import numpy as np
import pandas as pd

# data type to link together trap and cell ids
i_dtype = {"names": ["trap_id", "cell_id"], "formats": [np.int64, np.int64]}


def validate_lineage(
    lineage: np.ndarray,
    indices: np.ndarray,
    how: str = "families",
):
    """
    Identify mother-bud pairs both in lineage and a Signal's indices.

    We expect the lineage information to be unique: a bud should not have
    two mothers. Lineage is returned with buds assigned only to their first
    mother if they have multiple.

    Parameters
    ----------
    lineage : np.ndarray
        2D array of lineage associations where columns are
        (trap, mother, daughter)
        or
        a 3D array, which is an array of 2 X 2 arrays comprising
        [[trap_id, mother_label], [trap_id, daughter_label]].
    indices : np.ndarray
        A 2D array of cell indices from a Signal, (trap_id, cell_label).
        This array should not include mother_label.
    how: str
        If "mothers", matches indicate mothers from mother-bud pairs;
        If "daughters", matches indicate daughters from mother-bud pairs;
        If "families", matches indicate mothers and daughters in
        mother-bud pairs.

    Returns
    -------
    valid_lineage: boolean np.ndarray
        1D array indicating matched elements in the returned lineage.
    valid_indices: boolean np.ndarray
        1D array indicating matched elements in indices.
    lineage: np.ndarray
        The lineage with repeated buds removed: when a bud is listed with
        more than one mother only its first row is kept. It has the same
        layout (2D or 3D) as the lineage passed in.

    Raises
    ------
    ValueError
        If how is not one of "families", "mothers" or "daughters".

    Examples
    --------
    >>> import numpy as np
    >>> from agora.utils.indexing import validate_lineage

    >>> lineage = np.array([
    ...     [[0, 1], [0, 3]],
    ...     [[0, 1], [0, 4]],
    ...     [[0, 1], [0, 6]],
    ...     [[0, 4], [0, 7]],
    ... ])
    >>> indices = np.array([[0, 1], [0, 2], [0, 3]])

    >>> valid_lineage, valid_indices, lineage = validate_lineage(
    ...     lineage, indices
    ... )

    >>> valid_lineage
    array([ True, False, False, False])
    >>> valid_indices
    array([ True, False,  True])

    and

    >>> lineage = np.array([[[0, 3], [0, 1]], [[0, 2], [0, 4]]])
    >>> indices = np.array([[0, 1], [0, 2], [0, 3]])
    >>> valid_lineage, valid_indices, lineage = validate_lineage(
    ...     lineage, indices
    ... )
    >>> valid_lineage
    array([ True, False])
    >>> valid_indices
    array([ True, False,  True])
    """
    # column of each [[trap, mother], [trap, daughter]] pair to match on;
    # "families" matches on both columns, so it needs no single index
    column_for = {"families": None, "mothers": 0, "daughters": 1}
    if how not in column_for:
        raise ValueError(
            f"how must be one of {sorted(column_for)}, got {how!r}"
        )
    c_index = column_for[how]
    if len(lineage) == 0:
        # nothing can match; also avoids indexing an empty 2D array as 3D
        return (
            np.zeros(0, dtype=bool),
            np.zeros(len(indices), dtype=bool),
            lineage,
        )
    # remember the input layout so the lineage is returned in the same one
    invert_lineage = lineage.ndim == 2
    if invert_lineage:
        # [trap, mother, daughter] becomes [[trap, mother], [trap, daughter]]
        lineage = assoc_indices_to_3d(lineage)
    # if buds have two mothers, pick the first one
    lineage = lineage[
        ~pd.DataFrame(lineage[:, 1, :]).duplicated().values, :, :
    ]
    # find valid lineage
    valid_lineages = index_isin(lineage, indices)
    if how == "families":
        # both mother and bud must be in indices
        valid_lineage = valid_lineages.all(axis=1)
    else:
        valid_lineage = valid_lineages[:, c_index, :]
    flat_valid_lineage = valid_lineage.flatten()
    # find valid indices
    selected_lineages = lineage[flat_valid_lineage, ...]
    if how == "families":
        # select only pairs of mother and bud indices
        valid_indices = index_isin(indices, selected_lineages)
    else:
        valid_indices = index_isin(indices, selected_lineages[:, c_index, :])
    flat_valid_indices = valid_indices.flatten()
    # put the corrected lineage in the right format
    if invert_lineage:
        lineage = assoc_indices_to_2d(lineage)
    return flat_valid_lineage, flat_valid_indices, lineage


def index_isin(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Find those elements of x that are in y.

    Both arrays must be arrays of integer indices,
    such as (trap_id, cell_id).

    Each trailing pair of integers is viewed as one structured
    (trap_id, cell_id) record, so the result has the shape of x with its
    last axis reduced to length one.
    """
    # contiguous int64 copies are needed for the structured view below
    x = np.ascontiguousarray(x, dtype=np.int64)
    y = np.ascontiguousarray(y, dtype=np.int64)
    xv = x.view(i_dtype)
    inboth = np.intersect1d(xv, y.view(i_dtype))
    x_bool = np.isin(xv, inboth)
    return x_bool


def assoc_indices_to_3d(ndarray: np.ndarray) -> np.ndarray:
    """
    Convert the last column to a new row and repeat first column's values.

    For example: [trap, mother, daughter] becomes
        [[trap, mother], [trap, daughter]].

    The fast path assumes the input array has shape (N,3); other widths use
    a more general construction. Empty or 1D input is returned unchanged.
    """
    result = ndarray
    if len(ndarray) and ndarray.ndim > 1:
        # faster indexing for single positions
        if ndarray.shape[1] == 3:
            # prepend the trap column, split into [[trap, trap],
            # [mother, daughter]] and swap the last two axes
            result = np.transpose(
                np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
                axes=[0, 2, 1],
            )
        else:
            # 20% slower but more general indexing
            columns = np.arange(ndarray.shape[1])
            result = np.stack(
                (
                    ndarray[:, np.delete(columns, -1)],
                    ndarray[:, np.delete(columns, -2)],
                ),
                axis=1,
            )
    return result


def assoc_indices_to_2d(array: np.ndarray) -> np.ndarray:
    """
    Convert indices to 2d.

    [[trap, mother], [trap, daughter]] becomes [trap, mother, daughter].
    Empty input is returned unchanged.
    """
    result = array
    if len(array):
        # keep the mother row whole and append the daughter label to it
        result = np.concatenate(
            (array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1
        )
    return result