Source code for spateo.tools.labels

import copy
from typing import List, Tuple, Union

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import numpy as np
import scipy.optimize
import scipy.sparse
from anndata import AnnData

from ..logging import logger_manager as lm

# --------------------------------------- Normalizing sparse arrays --------------------------------------- #


def row_normalize(
    graph: scipy.sparse.csr_matrix,
    copy: bool = False,
    verbose: bool = True,
) -> scipy.sparse.csr_matrix:
    """Normalize a compressed sparse row (CSR) matrix by row. Written for sparse pairwise distance arrays, but can
    be applied to any sparse matrix.

    Args:
        graph: Sparse array of shape [n_samples, n_features]. If a pairwise distance array, shape [n_samples,
            n_samples].
        copy: If True, create a copy of the graph before computations so that the original is preserved.
        verbose: If True, prints number of nonzero entries.

    Returns:
        graph: Input array (or the copy of the input array) post-normalization.
    """
    logger = lm.get_main_logger()

    if copy:
        logger.info(
            "Deep copying the sparse array and working on the new copy. The original array will not be modified.",
            indent_level=1,
        )
        graph = graph.copy()

    data = graph.data
    for start_ptr, end_ptr in zip(graph.indptr[:-1], graph.indptr[1:]):
        row_sum = data[start_ptr:end_ptr].sum()
        if row_sum != 0:
            data[start_ptr:end_ptr] /= row_sum
        if verbose:
            logger.info(
                f"Computed normalized sum from ptr {start_ptr} to {end_ptr}. "
                f"Total entries: {end_ptr - start_ptr}, sum: {np.sum(graph.data[start_ptr:end_ptr])}"
            )

    return graph
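# Illustrative usage sketch for row_normalize (the input matrix below is an assumption for the
# example, not part of the module):
#
#     mat = scipy.sparse.csr_matrix(np.array([[1.0, 3.0], [2.0, 2.0]]))
#     normed = row_normalize(mat, copy=True, verbose=False)
#     normed.toarray()  # -> [[0.25, 0.75], [0.5, 0.5]]
#
# With copy=False the normalization is performed in place on mat.data.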
# --------------------------------------- Label class --------------------------------------- #
class Label(object):
    """Given categorizations for a set of points, wrap into a Label class.

    Args:
        labels_dense: Numerical labels.
        str_map: Optional mapping of numerical labels (keys) to strings (values).
        verbose: Whether to print running info of row_normalize.
    """

    def __init__(
        self,
        labels_dense: Union[np.ndarray, list],
        str_map: Union[None, dict] = None,
        verbose: bool = False,
    ) -> None:
        logger = lm.get_main_logger()

        # Check type and dimensions, and ensure all elements are non-negative
        if isinstance(labels_dense, list):
            labels_dense = np.asarray(labels_dense, dtype=np.int32)
        elif isinstance(labels_dense, np.ndarray):
            pass
        else:
            logger.error(
                f"Labels provided are of type {type(labels_dense)}. Should be a list or 1-dimensional ndarray."
            )
            raise TypeError(
                f"Labels provided are of type {type(labels_dense)}. "
                f"Should be a list or 1-dimensional numpy ndarray.\n"
            )

        if labels_dense.ndim != 1:
            logger.error(f"Label array has {labels_dense.ndim} dimensions, should be 1-dimensional.")
            raise ValueError(f"Label array has {labels_dense.ndim} dimensions, should be 1-dimensional.")

        if not np.issubdtype(labels_dense.dtype, np.integer):
            logger.error(f"Label array data type is {labels_dense.dtype}, should be integer.")
            raise TypeError(f"Label array data type is {labels_dense.dtype}, should be integer.")

        if np.amin(labels_dense) < 0:
            logger.error("Some of the labels have negative values. All labels must be 0 or positive integers.")
            raise ValueError(
                "Some of the labels have negative values.\n"
                "All labels must be 0 or positive integers.\n"
            )

        # Initialize attributes
        self.dense = labels_dense
        # Total number of data points with a label (e.g. number of cells)
        self.num_samples = len(labels_dense)
        # Number of instances of each integer up to the maximum label id
        self.bins = np.bincount(self.dense)
        # Unique labels (all non-negative integers)
        self.ids = np.nonzero(self.bins)[0]
        # Counts per label (same order as self.ids)
        self.counts = self.bins[self.ids]
        # Highest integer id for a label
        self.max_id = np.amax(self.ids)
        # Total number of labels
        self.num_labels = len(self.ids)
        # Verbose
        self.verbose = verbose
        # Mapping from numerical labels to strings
        if str_map is not None:
            self.str_map = str_map
            self.str_labels = list(map(self.str_map.get, labels_dense))
            self.str_ids = list(map(self.str_map.get, self.ids))

        self.onehot = None
        self.normalized_onehot = None
    def __repr__(self) -> str:
        return f"{self.num_labels} labels, {self.num_samples} samples, ids: {self.ids}, counts: {self.counts}"
    def __str__(self) -> str:
        return (
            f"Label object:\n"
            f"Number of labels: {self.num_labels}, "
            f"number of samples: {self.num_samples}\n"
            f"ids: {self.ids}, counts: {self.counts}\n"
        )
    def get_onehot(self) -> scipy.sparse.csr_matrix:
        """Return the one-hot sparse array of labels. If not already computed, generate the sparse array from the
        dense label array."""
        if self.onehot is None:
            self.onehot = self.generate_onehot()
        return self.onehot
    def get_normalized_onehot(self) -> scipy.sparse.csr_matrix:
        """Return the normalized one-hot sparse array of labels."""
        if self.normalized_onehot is None:
            self.normalized_onehot = self.generate_normalized_onehot()
        return self.normalized_onehot
    def generate_normalized_onehot(self) -> scipy.sparse.csr_matrix:
        """Generate a normalized one-hot matrix where each row is normalized by the count of that label, e.g. a row
        [0 1 1 0 0] will be converted to [0 0.5 0.5 0 0]."""
        return row_normalize(self.get_onehot().astype(np.float64), verbose=self.verbose, copy=True)
    def generate_onehot(self) -> scipy.sparse.csr_matrix:
        """Convert an array of labels to a num_labels x num_samples sparse one-hot matrix.

        Labels MUST be integers starting from 0, but can have gaps in between, e.g. [0, 1, 5, 9].
        """
        logger = lm.get_main_logger()

        # Initialize the fields of the CSR
        indptr = np.zeros((self.num_labels + 1,), dtype=np.int32)
        indices = np.zeros((self.num_samples,), dtype=np.int32)
        data = np.ones_like(indices, dtype=np.int32)

        logger.info(
            f"\n--- {self.num_labels} labels, "
            f"{self.num_samples} samples ---\n"
            f"initialized {indptr.shape} index ptr: {indptr}\n"
            f"initialized {indices.shape} indices: {indices}\n"
            f"initialized {data.shape} data: {data}\n"
        )

        # Update index pointer and indices row by row
        for n, label in enumerate(self.ids):
            label_indices = np.nonzero(self.dense == label)[0]
            label_count = len(label_indices)
            previous_ptr = indptr[n]
            current_ptr = previous_ptr + label_count
            indptr[n + 1] = current_ptr

            if self.verbose:
                logger.info(
                    f"indices for label {label}: {label_indices}\n"
                    f"previous pointer: {previous_ptr}, "
                    f"current pointer: {current_ptr}\n"
                )

            if current_ptr > previous_ptr:
                indices[previous_ptr:current_ptr] = label_indices

        return scipy.sparse.csr_matrix((data, indices, indptr), shape=(self.num_labels, self.num_samples))
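# Illustrative usage sketch for the Label class (the example labels and string map are assumptions):
#
#     lab = Label([0, 0, 2, 2, 2], str_map={0: "neuron", 2: "glia"})
#     lab.ids      # -> array([0, 2]); the gap at 1 is allowed
#     lab.counts   # -> array([2, 3])
#     lab.get_onehot().toarray()
#     # -> [[1, 1, 0, 0, 0],
#     #     [0, 0, 1, 1, 1]]      (shape: num_labels x num_samples)
#     lab.get_normalized_onehot().toarray()
#     # -> [[0.5, 0.5, 0,   0,   0  ],
#     #     [0,   0,   1/3, 1/3, 1/3]]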
# --------------------------------------- Label Curation and Label Processing --------------------------------------- #
def _rand_binary_array(array_length, num_onbits):
    """Return a random binary array of length `array_length` with exactly `num_onbits` ones."""
    array = np.zeros(array_length, dtype=np.int32)
    array[:num_onbits] = 1
    np.random.shuffle(array)
    return array
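# Illustrative sketch (the output shown is one possible random draw):
#
#     _rand_binary_array(5, 2)  # -> e.g. array([0, 1, 0, 0, 1], dtype=int32)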
def expand_labels(
    label: Label,
    max_label_id: int,
    sort_labels: bool = False,
) -> Label:
    """Spread out label IDs such that they range evenly from 0 to max_label_id, e.g. [0 1 2] -> [0 5 10].

    Useful if you need to be consistent with other label sets with many more label IDs. This spreads labels out
    along the color spectrum/map so that the colors are not too similar to each other. Use sort_labels if the list
    of IDs is not already sorted (although IDs are typically already sorted).
    """
    logger = lm.get_main_logger()
    logger.info(f"Expanding labels with ids: {label.ids} so that ids range from 0 to {max_label_id}")

    if sort_labels:
        ids = np.sort(copy.copy(label.ids))
    else:
        ids = copy.copy(label.ids)

    # Make sure the smallest label ID is zero
    ids_zeroed = ids - np.amin(label.ids)
    num_extra_labels = max_label_id - np.amax(ids_zeroed)
    multiple, remainder = np.divmod(num_extra_labels, label.num_labels - 1)

    # Insert regular spaces between each id
    inserted = np.arange(label.num_labels) * multiple
    # Insert remaining spaces so that the max label id equals the given max_label_id
    extra = _rand_binary_array(label.num_labels - 1, remainder)

    expanded_ids = ids_zeroed + inserted
    expanded_ids[1:] += np.cumsum(extra)  # only add to the 2nd label and above

    logger.info(
        f"Label ids zeroed: {ids_zeroed}.\n"
        f"{multiple} to be inserted between each id: {inserted}\n"
        f"{remainder} extra rows to be randomly inserted: {extra}\n"
        f"New ids: {expanded_ids}"
    )

    expanded_dense = (expanded_ids @ label.get_onehot()).astype(np.int32)

    return Label(expanded_dense)
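# Illustrative usage sketch for expand_labels (the input labels are assumptions):
#
#     lab = Label([0, 1, 1, 2])                      # ids [0, 1, 2]
#     expanded = expand_labels(lab, max_label_id=10)
#     expanded.ids                                   # -> array([0, 5, 10])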
def match_labels(
    labels_1: Label,
    labels_2: Label,
    extra_labels_assignment: str = "random",
    verbose: bool = False,
) -> Label:
    """Match the second set of labels to the first, returning a new Label object.

    Uses scipy's version of the Hungarian algorithm (linear_sum_assignment).
    """
    logger = lm.get_main_logger()

    max_id = max(labels_1.max_id, labels_2.max_id)
    num_extra_labels = labels_2.num_labels - labels_1.num_labels

    logger.info(
        f"Matching {labels_2.num_labels} labels against {labels_1.num_labels} labels.\n"
        f"Highest label ID in both is {max_id}.\n"
    )

    onehot_1, onehot_2 = labels_1.get_onehot(), labels_2.get_onehot()
    cost_matrix = (onehot_1 @ onehot_2.T).toarray()
    labels_match_1, labels_match_2 = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
    logger.info(f"\nMatches:\n{list(zip(labels_match_1, labels_match_2))}")

    # Temporary list keeping track of which labels are still available for use
    available_labels = list(range(max_id + 1))
    # Array to be filled with new label ids
    relabeled_ids = -1 * np.ones((labels_2.num_labels,), dtype=np.int32)

    # Reassign labels
    for index_1, index_2 in zip(labels_match_1, labels_match_2):
        label_1 = labels_1.ids[index_1]
        label_2 = labels_2.ids[index_2]
        if verbose:
            logger.info(
                f"Assigning first set's {label_1} to second set's {label_2}.\n"
                f"labels left: {available_labels}"
            )
        relabeled_ids[index_2] = label_1
        available_labels.remove(label_1)

    # Assign remaining labels (if the 2nd set has more labels than the 1st)
    if num_extra_labels > 0:
        unmatched_indices = np.nonzero(relabeled_ids == -1)[0]
        assert num_extra_labels == len(unmatched_indices), (
            f"Number of unmatched label IDs {len(unmatched_indices)} does not match number of "
            f"extra labels in the second set {num_extra_labels}.\n"
        )

        if extra_labels_assignment == "random":
            relabeled_ids[unmatched_indices] = np.random.choice(available_labels, size=num_extra_labels, replace=False)
        elif extra_labels_assignment == "greedy":

            def _insert_label(
                array: np.ndarray,
                max_length: int,
                added_labels: Union[list, None] = None,
            ) -> Tuple[np.ndarray, int, list]:
                """Insert a label in the middle of the largest interval. Assumes the array is already sorted!"""
                # Default to None rather than [] to avoid the mutable-default-argument pitfall
                if added_labels is None:
                    added_labels = []
                if len(array) >= max_length:
                    return array, max_length, added_labels
                else:
                    intervals = array[1:] - array[:-1]
                    max_interval_index = np.argmax(intervals)
                    increment = intervals[max_interval_index] // 2
                    label_to_add = array[max_interval_index] + increment
                    inserted_array = np.insert(
                        array,
                        max_interval_index + 1,
                        label_to_add,
                    )
                    added_labels.append(label_to_add)
                    return _insert_label(inserted_array, max_length, added_labels)

            sorted_matched = np.sort(relabeled_ids[relabeled_ids != -1])
            logger.info(f"Already matched ids (sorted): {sorted_matched}")
            _, _, added_labels = _insert_label(sorted_matched, labels_2.num_labels)
            relabeled_ids[unmatched_indices] = np.random.choice(added_labels, size=num_extra_labels, replace=False)
        else:
            logger.error("Extra labels assignment method not recognized, should be 'random' or 'greedy'.")
            raise ValueError("Extra labels assignment method not recognized, should be 'random' or 'greedy'.")

    logger.info(f"\nRelabeled labels: {relabeled_ids}\n")
    relabeled_dense = (relabeled_ids @ onehot_2).astype(np.int32)

    return Label(relabeled_dense)
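# Illustrative usage sketch for match_labels (the labels are assumptions): two clusterings of the
# same five points, where the second clustering's ids are permuted relative to the first:
#
#     lab_1 = Label([0, 0, 1, 1, 2])
#     lab_2 = Label([1, 1, 2, 2, 0])
#     matched = match_labels(lab_1, lab_2)
#     matched.dense  # -> array([0, 0, 1, 1, 2]); the second set relabeled to agree with the first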
def match_label_series(
    label_list: List[Label],
    least_labels_first: bool = True,
    extra_labels_assignment: str = "greedy",
) -> Tuple[List[Label], int]:
    """Match a list of labels to each other, one after another, in order of increasing (if least_labels_first is
    True) or decreasing (if least_labels_first is False) number of label ids.

    Returns the relabeled list in the original order.
    """
    logger = lm.get_main_logger()

    num_label_list = [label.num_labels for label in label_list]
    max_num_labels = max(num_label_list)
    sort_indices = np.argsort(num_label_list)

    logger.info(
        f"\nMaximum number of labels across all datasets = {max_num_labels}\n"
        f"Indices of sorted list: {sort_indices}\n"
    )

    ordered_relabels = []

    if least_labels_first:
        ordered_relabels.append(expand_labels(label_list[sort_indices[0]], max_num_labels - 1))
        logger.info(f"First label, expanded label ids: {ordered_relabels[0]}")
    else:
        # Argsort is in ascending order, so reverse it
        sort_indices = sort_indices[::-1]
        # Already has the max number of labels, no need to expand
        ordered_relabels.append(label_list[sort_indices[0]])

    for index in sort_indices[1:]:
        current_label = label_list[index]
        previous_label = ordered_relabels[-1]
        logger.info(f"\nRelabeling:\n{current_label}\nwith reference to\n{previous_label}\n" + "-" * 70 + "\n")
        relabeled = match_labels(previous_label, current_label, extra_labels_assignment=extra_labels_assignment)
        ordered_relabels.append(relabeled)

    sort_indices_list = list(sort_indices)
    original_order_relabels = [ordered_relabels[sort_indices_list.index(n)] for n in range(len(label_list))]

    return original_order_relabels, max_num_labels
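# Illustrative usage sketch for match_label_series (label_run_1/2/3 are hypothetical Label objects,
# e.g. from several Louvain runs):
#
#     relabeled_list, max_num_labels = match_label_series(
#         [label_run_1, label_run_2, label_run_3],
#         least_labels_first=True,
#         extra_labels_assignment="greedy",
#     )
#     # relabeled_list preserves the original ordering of the input list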
def interlabel_connections(
    label: Label,
    weights_matrix: Union[scipy.sparse.csr_matrix, np.ndarray],
) -> np.ndarray:
    """Compute connection strength between labels (based on pairwise distances), normalized by the counts of each
    label.

    Args:
        label: Instance of class 'Label', with the dense label array in "dense", list of unique labels in "ids",
            counts per label in "counts", etc.
        weights_matrix: Pairwise adjacency matrix, weighted by e.g. spatial distance between points.

    Returns:
        connections: Pairwise connection strength array, shape [n_labels, n_labels].
    """
    logger = lm.get_main_logger()

    if weights_matrix.ndim != 2:
        logger.error(f"Weights matrix has {weights_matrix.ndim} dimensions, should be 2.")
    if weights_matrix.shape[0] != label.num_samples or weights_matrix.shape[1] != label.num_samples:
        logger.error("Weights matrix dimensions do not match the number of samples.")

    normalized_onehot = label.generate_normalized_onehot()

    logger.info(
        f"Matrix multiplying labels x weights x labels-transpose, shape {normalized_onehot.shape} x "
        f"{weights_matrix.shape} x {normalized_onehot.T.shape}."
    )

    connections = normalized_onehot @ weights_matrix @ normalized_onehot.T
    if scipy.sparse.issparse(connections):
        connections = connections.toarray()

    return connections
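# Illustrative usage sketch for interlabel_connections (the toy weights below are assumptions):
#
#     lab = Label([0, 0, 1, 1])
#     weights = np.ones((4, 4)) - np.eye(4)   # toy adjacency: every point connected to every other
#     interlabel_connections(lab, weights)
#     # -> array([[0.5, 1. ],
#     #           [1. , 0.5]])   entry [i, j] is the count-normalized strength between labels i and j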
def create_label_class(
    adata: AnnData,
    cat_key: Union[str, List[str]],
) -> Union[Label, List[Label]]:
    """Wraps categorical labels into the custom Label class for downstream processing.

    Args:
        adata: An anndata object.
        cat_key: Key(s) in .obs containing categorical labels. This function and the Label class provide the most
            utility when this is used in conjunction with the results of multiple different runs of the Louvain
            algorithm.

    Returns:
        label: Either an object of Label class or a list where each element is an object of Label class. Will
            return a list if given multiple arguments to 'cat_key'.
    """
    # Convert categorical labels to numerical and save the mapping to have both numerical and categorical labels:
    if isinstance(cat_key, str):
        str_cat = np.unique(adata.obs[cat_key].values)
        num_cat = range(len(str_cat))
        map_dict = dict(zip(num_cat, str_cat))
        all_num_labels = adata.obs[cat_key].replace(str_cat, num_cat)
        label = Label(all_num_labels.to_numpy(), str_map=map_dict)
        return label
    else:
        all_labels = []
        for key in cat_key:
            str_cat = np.unique(adata.obs[key].values)
            num_cat = range(len(str_cat))
            map_dict = dict(zip(num_cat, str_cat))
            all_num_labels = adata.obs[key].replace(str_cat, num_cat)
            label = Label(all_num_labels.to_numpy(), str_map=map_dict)
            all_labels.append(label)
        return all_labels
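# Illustrative usage sketch for create_label_class (the .obs keys are assumptions):
#
#     label = create_label_class(adata, "louvain")               # a single Label object
#     labels = create_label_class(adata, ["louvain", "leiden"])  # a list of Label objects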