diff --git a/src/agora/utils/indexing.py b/src/agora/utils/indexing.py index 0e50be02012bf7ca0198dbdaf07e28d984c976d8..a86fce665cc8aa610fc9a96458d76f2f4109e991 100644 --- a/src/agora/utils/indexing.py +++ b/src/agora/utils/indexing.py @@ -21,6 +21,8 @@ def validate_association( We use broadcasting for speed. + Both mother and bud in association must be in indices. + Parameters ---------- association : np.ndarray @@ -49,12 +51,22 @@ def validate_association( >>> indices = np.array([ [0, 1], [0, 2], [0, 3]]) >>> print(indices.T) - >>> valid_associations, valid_indices = validate_association(association, indices) + >>> valid_association, valid_indices = validate_association(association, indices) + + >>> print(valid_association) + array([ True, False, False, False]) + >>> print(valid_indices) + array([ True, False, True]) + + and - >>> print(valid_associations) - array([ True, False, False, False]) + >>> association = np.array([[[0,3], [0,1]], [[0,2], [0,4]]]) + >>> indices = np.array([[0,1], [0,2], [0,3]]) + >>> valid_association, valid_indices = validate_association(association, indices) + >>> print(valid_association) + array([ True, False]) >>> print(valid_indices) - array([ True, False, True]) + array([ True, False, True]) """ if association.ndim == 2: # reshape into 3D array for broadcasting @@ -69,33 +81,29 @@ def validate_association( valid_ndassociation = ( association[..., np.newaxis] == indicesT[np.newaxis, ...] ) + # find matches in association + ### # make True comparisons have both trap_ids and cell labels matching valid_cell_ids = valid_ndassociation.all(axis=2) if match_column is None: - # 1. find matches in association # make True comparisons match at least one row in indices va_intermediate = valid_cell_ids.any(axis=2) - # make True comparisons have both mother and daughter matching rows in indices + # make True comparisons have both mother and bud matching rows in indices valid_association = va_intermediate.all(axis=1) - # 2. find matches in indices - # make True comparisons match for at least one mother or daughter in association - ind_intermediate = valid_cell_ids.any(axis=1) - # make True comparisons match for at least one row in association - valid_indices = ind_intermediate.any(axis=0) - # OLD - # valid_indices = ( - # valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1)) - # ) else: # match_column selects mothers if 0 and daughters if 1 # make True match at least one row in indices valid_association = valid_cell_ids[:, match_column].any(axis=1) - # make True match at least one row in association - valid_indices = valid_cell_ids[:, match_column].any(axis=0) - # OLD - # valid_association = ( - # valid_cell_ids[:, match_column] & valid_indices - # ).any(axis=1) + # find matches in indices + ### + # make True comparisons have a validated association for both the mother and bud + # make True comparisons have both trap_ids and cell labels matching + valid_cell_ids_va = valid_ndassociation[valid_association].all(axis=2) + if match_column is None: + # make True comparisons match at least one mother or bud in association + valid_indices = valid_cell_ids_va.any(axis=1)[0] + else: + valid_indices = valid_cell_ids_va[:, match_column][0] return valid_association, valid_indices @@ -132,7 +140,7 @@ def _assoc_indices_to_3d(ndarray: np.ndarray): def _3d_index_to_2d(array: np.ndarray): - """Perform opposite switch to _assoc_indices_to_3d.""" + """Revert switch from _assoc_indices_to_3d.""" result = array if len(array): result = np.concatenate( @@ -143,7 +151,8 @@ def _3d_index_to_2d(array: np.ndarray): def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray: """ - Fetch two 2-D indices and return a binary 2-D matrix - where a True value links two cells where all cells are the same. + Compare two 2D arrays using broadcasting. + + Return a binary array where a True value links two cells where all cells are the same. """ - return (x[..., None] == y.T[None, ...]).all(axis=1) + return (x[..., np.newaxis] == y.T[np.newaxis, ...]).all(axis=1) diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index e6d8b5982445b7419923b849e95c21584c2f31e1..168b022af5bccb85139cfd2e4c385da840f56b45 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -12,6 +12,14 @@ from postprocessor.core.lineageprocess import LineageProcess class PickerParameters(ParametersABC): + """ + A dictionary specifying the sequence of picks in order. + + "lineage" is further specified by "mothers", "daughters", "families" (mother-bud pairs), and "orphans", where orphans picks cells that are not in families. + + "condition" is further specified by "present", "continuously_present", "any_present", or "growing" and a threshold, either a number of time points or a fraction of the total duration of the experiment. + """ + _defaults = { "sequence": [ ["lineage", "families"], @@ -22,18 +30,7 @@ class PickerParameters(ParametersABC): class Picker(LineageProcess): """ - - Parameters - ---------- - cells: Cell object - Passed to the class's constructor. - condition: tuple - Tuple with condition and associated parameter(s), conditions can be - "present", "nonstoply_present" or "quantile". - Determine the thresholds or fractions of signals to use. - lineage: str - Either "mothers", "daughters", "families" (mothers AND daughters), "orphans". Mothers or daughters picks cells with those tags, families picks the union of - both, and orphans picks the difference between the total and families. + Picker selects cells from a signal using lineage information and by how long and by how they are retained in the data set. """ def __init__( @@ -51,10 +48,10 @@ class Picker(LineageProcess): how: str, mothers_daughters: t.Optional[np.ndarray] = None, ) -> pd.MultiIndex: - """""" + """Return rows of a signal corresponding to either mothers, daughters, or mother-daughter pairs using lineage information.""" cells_present = drop_mother_label(signal.index) mothers_daughters = self.get_lineage_information(signal) - valid_indices = slice(None) + #: might be better if match_column defined as a string to make everything one line if how == "mothers": _, valid_indices = validate_association( mothers_daughters, cells_present, match_column=0 @@ -64,18 +61,20 @@ class Picker(LineageProcess): mothers_daughters, cells_present, match_column=1 ) elif how == "families": - # mothers and daughters that are still present + # mother-daughter pairs _, valid_indices = validate_association( mothers_daughters, cells_present ) + else: + valid_indices = slice(None) return signal.index[valid_indices] - def pick_by_condition(self, signal, condition, thresh): - idx = self.switch_case(signal, condition, thresh) - return idx - def run(self, signal): - """Pick indices from the index of a signal's dataframe and return as an array.""" + """ + Pick indices from the index of a signal's dataframe and return as an array. + + Typically, we first pick by lineage, then by condition. + """ self.orig_signal = signal indices = set(signal.index) lineage = self.get_lineage_information(signal) @@ -83,63 +82,74 @@ class Picker(LineageProcess): self.mothers = lineage[:, [0, 1]] self.daughters = lineage[:, [0, 2]] for alg, *params in self.sequence: - new_indices = tuple() if indices: - # pick new indices if alg == "lineage": + # pick mothers, buds, or mother-bud pairs param1 = params[0] new_indices = getattr(self, "pick_by_" + alg)( signal.loc[list(indices)], param1 ) else: + # pick by condition param1, *param2 = params new_indices = getattr(self, "pick_by_" + alg)( signal.loc[list(indices)], param1, param2 ) - new_indices = [tuple(x) for x in new_indices] + else: + new_indices = tuple() + # number of indices reduces for each iteration of the loop indices = indices.intersection(new_indices) else: self._log("No lineage assignment") indices = np.array([]) + # convert to array indices_arr = np.array([tuple(map(_str_to_int, x)) for x in indices]) return indices_arr - def switch_case( + # def pick_by_condition(self, signal, condition, thresh): + # idx = self.switch_case(signal, condition, thresh) + # return idx + + def pick_by_condition( self, signal: pd.DataFrame, condition: str, threshold: t.Union[float, int, list], ): + """Pick indices from signal by any_present, present, continuously_present, and growing.""" if len(threshold) == 1: threshold = [_as_int(*threshold, signal.shape[1])] + #: is this correct for "growing"? case_mgr = { "any_present": lambda s, thresh: any_present(s, thresh), "present": lambda s, thresh: s.notna().sum(axis=1) > thresh, + #: continuously_present looks incorrect "continuously_present": lambda s, thresh: s.apply(thresh, axis=1) > thresh, "growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh, } - return set(signal.index[case_mgr[condition](signal, *threshold)]) + # apply condition + idx = set(signal.index[case_mgr[condition](signal, *threshold)]) + new_indices = [tuple(x) for x in idx] + return new_indices def _as_int(threshold: t.Union[float, int], ntps: int): + """Convert a fraction of the total experiment duration into a number of time points.""" if type(threshold) is float: threshold = ntps * threshold return threshold def any_present(signal, threshold): - """ - Return a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints. - """ + """Return pd.Series for cells where True indicates that cell was present for more than threshold time points.""" + #: isn't full_traps all we need? + full_traps = (signal.notna().sum(axis=1) > threshold).groupby("trap") any_present = pd.Series( np.sum( [ - np.isin([x[0] for x in signal.index], i) & v - for i, v in (signal.notna().sum(axis=1) > threshold) - .groupby("trap") - .any() - .items() + np.isin([x[0] for x in signal.index], i) & full + for i, full in full_traps.any().items() ], axis=0, ).astype(bool),