From e7f52582c64ecd713ca7c5e4f33bdc8d095da3e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk> Date: Sun, 10 Oct 2021 15:41:37 +0100 Subject: [PATCH] refactor picker Former-commit-id: 9b1ca612239fce8b05af7d6b223cd30aa0911632 --- core/functions/tracks.py | 1 + core/processes/picker.py | 117 ++++++++++++++++++++++++++++----------- 2 files changed, 85 insertions(+), 33 deletions(-) diff --git a/core/functions/tracks.py b/core/functions/tracks.py index 928fc2b1..e47e6023 100644 --- a/core/functions/tracks.py +++ b/core/functions/tracks.py @@ -388,6 +388,7 @@ def get_contiguous_pairs(tracks: pd.DataFrame) -> list: mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist()) mins_d.index = mins_d.index - 1 # make indices equal + # TODO add support for skipping time points maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist()) common = sorted(set(mins_d.index).intersection(maxes_d.index), reverse=True) diff --git a/core/processes/picker.py b/core/processes/picker.py index 7649ab59..2a26f2eb 100644 --- a/core/processes/picker.py +++ b/core/processes/picker.py @@ -87,42 +87,64 @@ class picker(ProcessABC): return nested_massign def pick_by_lineage(self, signals, how): + idx = signals.index if how: - ma = self._cells["mother_assign_dynamic"] - trap = self._cells["trap"] - label = self._cells["cell_label"] - nested_massign = self.mother_assign_from_dynamic( - ma, label, trap, self._cells.ntraps - ) - # mother_bud_mat = self.mother_assign_to_mb_matrix(nested_massign) - - idx = set( - [ - (tid, i + 1) - for tid, x in enumerate(nested_massign) - for i in range(len(x)) - ] - ) - mothers, daughters = zip( - *[ - ((tid, m), (tid, d)) - for tid, trapcells in enumerate(nested_massign) - for d, m in enumerate(trapcells, 1) - if m - ] - ) - self.mothers = mothers - self.daughters = daughters - - mothers = set(mothers) - daughters = set(daughters) + mothers = set(self.mothers) + daughters = set(self.daughters) # daughters, mothers = np.where(mother_bud_mat) + + search = lambda a, b: np.where( + np.in1d( + np.ravel_multi_index(np.array(a).T, np.array(a).max(0) + 1), + np.ravel_multi_index(np.array(b).T, np.array(a).max(0) + 1), + ) + ) if how == "mothers": idx = mothers elif how == "daughters": idx = daughters + elif how == "daughters_w_mothers": + present_mothers = idx.intersection(mothers) + idx = set( + [ + tuple(x) + for m in present_mothers + for x in np.array(self.daughters)[search(self.mothers, m)] + ] + ) + elif how == "mothers_w_daughters": + present_daughters = idx.intersection(daughters) + idx = set( + [ + tuple(x) + for d in present_daughters + for x in np.array(self.mothers)[search(self.daughters, d)] + ] + ) + elif how == "full_families": + present_mothers = idx.intersection(mothers) + dwm_idx = set( + [ + tuple(x) + for m in present_mothers + for x in np.array(self.daughters)[ + search(np.array(self.mothers), m) + ] + ] + ) + present_daughters = idx.intersection(daughters) + mwd_idx = set( + [ + tuple(x) + for d in present_daughters + for x in np.array(self.mothers)[ + search(np.array(self.daughters), d) + ] + ] + ) + idx = mwd_idx.union(dwm_idx) elif how == "families" or how == "orphans": families = mothers.union(daughters) if how == "families": @@ -138,9 +160,35 @@ class picker(ProcessABC): idx = self.switch_case(signals, condition, thresh) return idx + def get_mothers_daughters(self): + ma = self._cells["mother_assign_dynamic"] + trap = self._cells["trap"] + label = self._cells["cell_label"] + nested_massign = self.mother_assign_from_dynamic( + ma, label, trap, self._cells.ntraps + ) + # mother_bud_mat = self.mother_assign_to_mb_matrix(nested_massign) + + idx = set( + [ + (tid, i + 1) + for tid, x in enumerate(nested_massign) + for i in range(len(x)) + ] + ) + mothers, daughters = zip( + *[ + ((tid, m), (tid, d)) + for tid, trapcells in enumerate(nested_massign) + for d, m in enumerate(trapcells, 1) + if m + ] + ) + return mothers, daughters + def run(self, signals): indices = set(signals.index) - daughters, mothers = (None, None) + self.mothers, self.daughters = self.get_mothers_daughters() for alg, param1, param2 in self.sequence: if alg is "lineage": new_indices = getattr(self, "pick_by_" + alg)(signals, param1) @@ -149,22 +197,25 @@ class picker(ProcessABC): indices = indices.union(set(new_indices)) else: new_indices = getattr(self, "pick_by_" + alg)(signals, param1, param2) - indices = indices.intersection(new_indices) + indices = indices.intersection(new_indices) - daughters, mothers = self.daughters, self.mothers - return np.array(daughters), np.array(mothers), np.array(list(indices)) + mothers, daughters = self.mothers, self.daughters + return np.array(mothers), np.array(daughters), np.array(list(indices)) @staticmethod def switch_case( signals: pd.DataFrame, condition: str, - threshold: Union[float, int], + threshold: Union[float, int, list], ): threshold_asint = _as_int(threshold, signals.shape[1]) + if isinstance(threshold, list): + thresh_presence = threshold[0] case_mgr = { "present": signals.notna().sum(axis=1) > threshold_asint, "nonstoply_present": signals.apply(max_nonstop_ntps, axis=1) > threshold_asint, + "growing": signals.diff(axis=1).sum(axis=1) > threshold, # "quantile": [np.quantile(signals.values[signals.notna()], threshold)], } return set(signals.index[case_mgr[condition]]) -- GitLab