Source code for adata_query._core._formatter


# -- import packages: ----------------------------------------------------------
import ABCParse
import autodevice
import anndata
import numpy as np
import torch as _torch


# -- set typing: ---------------------------------------------------------------
from typing import Union


# -- operational class: --------------------------------------------------------
class DataFormatter(ABCParse.ABCParse):
    """Convert array-like data to ``np.ndarray`` or ``torch.Tensor``.

    Wraps a single ``data`` object and exposes type/device introspection
    properties together with ``to_numpy`` and ``to_torch`` converters.
    """

    def __init__(
        self, data: Union[_torch.Tensor, np.ndarray], *args, **kwargs
    ) -> None:
        """Format data to interface with numpy or torch, on a specified device.

        Args:
            data (``Union[np.ndarray, torch.Tensor]``): Input ``data`` to be
                formatted. Typically an ``np.ndarray``, ``torch.Tensor``, or
                ``ArrayView``.

        Returns:
            None
        """
        # ``__parse__`` comes from ABCParse; the rest of the class reads the
        # stored input back as ``self._data``.
        self.__parse__(locals())

    @property
    def device_type(self) -> str:
        """Return the device type of the data (e.g. ``"cpu"``, ``"cuda"``).

        Falls back to ``"cpu"`` when the data has no ``device`` attribute or
        when that attribute has no ``type`` (NumPy >= 2.0 arrays expose
        ``device`` as the plain string ``"cpu"``).
        """
        device = getattr(self._data, "device", None)
        return getattr(device, "type", "cpu")

    @property
    def is_ArrayView(self) -> bool:
        """Check whether the data is an ``anndata`` ``ArrayView``."""
        return isinstance(self._data, anndata._core.views.ArrayView)

    @property
    def is_numpy_array(self) -> bool:
        """Check whether the data is an ``np.ndarray``."""
        return isinstance(self._data, np.ndarray)

    @property
    def is_torch_Tensor(self) -> bool:
        """Check whether the data is a ``torch.Tensor``."""
        return isinstance(self._data, _torch.Tensor)

    @property
    def on_cpu(self) -> bool:
        """Check whether the data resides on the CPU."""
        return self.device_type == "cpu"

    @property
    def on_gpu(self) -> bool:
        """Check whether the data resides on a ``cuda`` or ``mps`` device."""
        return self.device_type in ("cuda", "mps")

    def to_numpy(self) -> np.ndarray:
        """Return the data as an ``np.ndarray``.

        Returns:
            ``np.ndarray``: A tensor is detached and moved to the CPU before
            conversion; an ``ArrayView`` is converted with ``toarray()``;
            anything else is returned unchanged.
        """
        if self.is_torch_Tensor:
            # Always detach and move to CPU: ``Tensor.numpy()`` raises for
            # tensors that require grad or live off-CPU. Both calls are
            # no-ops for a plain CPU tensor.
            return self._data.detach().cpu().numpy()
        if self.is_ArrayView:
            return self._data.toarray()
        return self._data

    def to_torch(
        self, device: _torch.device = autodevice.AutoDevice()
    ) -> _torch.Tensor:
        """Return the data as a ``torch.Tensor`` on ``device``.

        Args:
            device (``Optional[torch.device]``): Target device (e.g.:
                ``"cpu"``, ``"cuda:0"``, ``"mps:0"``). The default value,
                ``autodevice.AutoDevice()`` will indicate the use of GPU, if
                available.

                - **Default**: ``autodevice.AutoDevice()``

        Returns:
            ``torch.Tensor``: The data on the requested device. Non-tensor
            input is passed through the ``torch.Tensor(...)`` constructor.
        """
        # Must stay the first statement: ``locals()`` should hold only
        # ``self`` and ``device``. The device is read back as ``self._device``.
        self.__update__(locals())

        if self.is_torch_Tensor:
            return self._data.to(self._device)

        # Convert through a local so the stored input is left untouched.
        array = self._data.toarray() if self.is_ArrayView else self._data
        return _torch.Tensor(array).to(self._device)
# -- functional wrap: ----------------------------------------------------------
def format_data(
    data: Union[np.ndarray, _torch.Tensor],
    torch: bool = False,
    device: _torch.device = autodevice.AutoDevice(),
) -> Union[np.ndarray, _torch.Tensor]:
    """Format data to interface with numpy or torch, on a specified device.

    Functional wrapper around ``DataFormatter``: builds a formatter for
    ``data`` and returns the result of ``to_torch`` or ``to_numpy``.

    Args:
        data (``Union[np.ndarray, torch.Tensor]``): Input ``data`` to be
            formatted. Typically an ``np.ndarray``, ``torch.Tensor``, or
            ``ArrayView``.

        torch (``Optional[bool]``): Toggle whether data should be formatted
            as ``torch.Tensor`` or ``np.ndarray``.

            - **Default**: ``False``

        device (``Optional[torch.device]``): If ``torch==True``, the device
            (e.g.: ``"cpu"``, ``"cuda:0"``, ``"mps:0"``) may be set. The
            default value, ``autodevice.AutoDevice()`` will indicate the use
            of GPU, if available. Only used when ``torch`` is truthy.

            - **Default**: ``autodevice.AutoDevice()``

    Returns:
        ``Union[np.ndarray, torch.Tensor]``: ``formatted_data``
    """
    converter = DataFormatter(data=data)
    return converter.to_torch(device=device) if torch else converter.to_numpy()