Source code for adata_query._core._locator


# -- set typing: ---------------------------------------------------------------
from typing import List, Optional


# -- import packages: ----------------------------------------------------------
import ABCParse
import anndata
import numpy as np


# -- operational class: --------------------------------------------------------
class AnnDataLocator(ABCParse.ABCParse):
    """Find which attribute of an ``AnnData`` object holds a given key.

    Operational class powering the ``locate`` function. On each call, the
    key listings of ``adata`` are collected into ``self._ATTRS``. Then the
    passed key is cross-referenced against them. The listings come from:

    - every callable attribute whose name contains ``"key"``
      (e.g. ``obsm_keys``);
    - the names stored in ``layers`` and ``obsp``;
    - the baseline ``searchable`` attribute names.
    """

    def __init__(
        self, searchable: Optional[List[str]] = None, *args, **kwargs
    ) -> None:
        """Query available key values of AnnData.

        Args:
            searchable (``Optional[List[str]]``): baseline query terms.
                These are attribute names of ``adata`` that are matched by
                name, in addition to ``"X"``.

                - **Default**: ``None``

        Returns:
            ``None``: Initializes class object.
        """
        self._ATTRS = {}
        self._searchable = ["X"]
        if searchable is not None:
            self._searchable += searchable

    def _stash(self, attr: str, attr_val: np.ndarray) -> None:
        """Record a searchable attribute and its value.

        Args:
            attr (``str``): attribute name to stash.
            attr_val: value to associate with ``attr``. This is a list of
                key names, or the attribute name itself for ``searchable``
                entries.

        Returns:
            ``None``: updates ``self._ATTRS[attr]`` and sets ``self.<attr>``.
        """
        self._ATTRS[attr] = attr_val
        setattr(self, attr, attr_val)

    def _intake(self, adata: anndata.AnnData) -> None:
        """Collect the key listings available on ``adata``.

        Args:
            adata (``anndata.AnnData``): object whose attributes are scanned.

        Returns:
            ``None``: populates ``self._ATTRS`` (and matching attributes).
        """
        for attr in adata.__dir__():
            if "key" in attr:
                accessor = getattr(adata, attr)
                # Only ``*_keys()``-style methods are meant to be collected;
                # skip non-callable matches instead of raising TypeError.
                if callable(accessor):
                    self._stash(attr, accessor())
            if attr in ["layers", "obsp"]:
                # Iterating these mappings yields their key names.
                self._stash(attr, list(getattr(adata, attr)))
            if attr in self._searchable:
                # Searchable attributes are matched by their own name.
                self._stash(attr, attr)

    def _cross_reference(self, passed_key: str) -> List[str]:
        """Return the stashed attribute names whose listing holds ``passed_key``.

        Attributes stashed by name (``searchable`` entries, whose value is a
        plain string) are compared by exact equality. A bare ``in`` on a
        ``str`` is a substring test, so ``""`` would otherwise match ``"X"``.

        Args:
            passed_key (``str``): key to look up.

        Returns:
            ``List[str]``: names of the ``self._ATTRS`` entries that match.
        """
        return [
            attr
            for attr, val in self._ATTRS.items()
            if (passed_key == val if isinstance(val, str) else passed_key in val)
        ]

    def _query_str_vals(self, query_result: List[str]) -> str:
        """Join matched attribute names into a comma-separated string.

        Args:
            query_result (``List[str]``): matched attribute names.

        Returns:
            ``str``: the names joined by ``", "``.
        """
        return ", ".join(query_result)

    def _format_error_msg(self, key: str, query_result: List[str]) -> str:
        """Build the error message for a lookup without exactly one match.

        Args:
            key (``str``): problematic key.
            query_result (``List[str]``): returned list of matching attributes.

        Returns:
            ``str``: a "more than one match" message listing the matches, or
            a "NOT FOUND" message when there are none.
        """
        if len(query_result) > 1:
            return f"Found more than one match: [{self._query_str_vals(query_result)}]"
        return f"{key} NOT FOUND"

    def _format_output_str(self, query_result: List[str]) -> str:
        """Convert the single matched entry into an ``adata`` attribute name.

        A trailing ``"_keys"`` suffix is removed, so ``"obsm_keys"`` becomes
        ``"obsm"``. Names without the suffix (``"layers"``, ``"X"``) are
        returned unchanged.

        Args:
            query_result (``List[str]``): matches; only the first is used.

        Returns:
            ``str``: Chosen attribute of ``adata`` containing the passed key.
        """
        return query_result[0].split("_keys")[0]

    def _forward(self, adata: anndata.AnnData, key: str) -> str:
        """Locate ``key`` within ``adata``.

        Args:
            adata (``anndata.AnnData``): The [annotated] single-cell data
                matrix of shape: ``[n_obs × n_vars]``. Rows correspond to
                cells and columns to genes. [1].
            key (``str``): Key to access a matrix in adata. For example, if
                you wanted to access ``adata.obsm['X_pca']``, you would pass:
                ``"X_pca"``.

        Returns:
            ``str``: Attribute of ``adata`` containing the passed key.

        Raises:
            KeyError: if ``key`` matches no attribute or more than one.
        """
        self._intake(adata)
        query_result = self._cross_reference(passed_key=key)
        if len(query_result) != 1:
            raise KeyError(self._format_error_msg(key, query_result))
        return self._format_output_str(query_result)

    def __call__(self, adata: anndata.AnnData, key: str) -> str:
        """Operator that mediates the retrieval of a matrix from ``adata``.

        Args:
            adata (``anndata.AnnData``): The [annotated] single-cell data
                matrix of shape: ``[n_obs × n_vars]``. Rows correspond to
                cells and columns to genes. [1].
            key (``str``): Key to access a matrix in adata. For example, if
                you wanted to access ``adata.obsm['X_pca']``, you would pass:
                ``"X_pca"``.

        Returns:
            ``str``: Attribute of adata containing the passed key.

        Raises:
            KeyError: if ``key`` matches no attribute or more than one.
        """
        return self._forward(adata, key)
# -- API-facing function: -----------------------------------------------------
def locate(adata: anndata.AnnData, key: str) -> str:
    """Return the name of the ``adata`` attribute that stores ``key``.

    Given ``adata`` and a key that points to a specific matrix stored in
    ``adata``, return the name of the attribute that contains it. For
    example, ``"X_pca"`` stored in ``adata.obsm`` yields ``"obsm"``. The data
    itself is not returned.

    Args:
        adata (``anndata.AnnData``): The [annotated] single-cell data matrix
            of shape: ``[n_obs × n_vars]``. Rows correspond to cells and
            columns to genes. [1].
        key (``str``): Key to access a matrix in adata. For example, if you
            wanted to access ``adata.obsm['X_pca']``, you would pass:
            ``"X_pca"``.

    Returns:
        ``str``: Attribute of adata containing the passed key.

    Raises:
        KeyError: if ``key`` is not found, or is found in more than one
            attribute of ``adata``.
    """
    locator = AnnDataLocator()
    return locator(adata=adata, key=key)