diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index 2fc1ece75e5d078dbe05a309bf3a8e30aed06f07..117b5d4b1159e85bddfd8710bfc7fa28a58059c8 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -11,8 +11,9 @@ from agora.abc import ParametersABC from agora.io.cells import Cells from utils_find_1st import cmp_equal, find_1st -from postprocessor.core.abc import PostProcessABC +from postprocessor.core.lineageprocess import LineageProcess from postprocessor.core.functions.tracks import max_nonstop_ntps, max_ntps +from agora.utils.association import validate_association class pickerParameters(ParametersABC): @@ -28,7 +29,7 @@ class pickerParameters(ParametersABC): } -class picker(PostProcessABC): +class picker(LineageProcess): """ :cells: Cell object passed to the constructor :condition: Tuple with condition and associated parameter(s), conditions can be @@ -44,124 +45,45 @@ class picker(PostProcessABC): ): super().__init__(parameters=parameters) - self._cells = cells + self.cells = cells - def pick_by_lineage(self, signals, how): - self.orig_signals = signals + def pick_by_lineage(self, signal, how): + self.orig_signals = signal - idx = signals.index - - if how: - mothers = set(self.mothers) - daughters = set(self.daughters) - # daughters, mothers = np.where(mother_bud_mat) - - def search(a, b): - return np.where( - np.in1d( - np.ravel_multi_index( - np.array(a).T, np.array(a).max(0) + 1 - ), - np.ravel_multi_index( - np.array(b).T, np.array(a).max(0) + 1 - ), - ) - ) - - if how == "mothers": - idx = mothers - elif how == "daughters": - idx = daughters - elif how == "daughters_w_mothers": - present_mothers = idx.intersection(mothers) - idx = set( - [ - tuple(x) - for m in present_mothers - for x in np.array(self.daughters)[ - search(self.mothers, m) - ] - ] - ) - - print("associated daughters: ", idx) - elif how == "mothers_w_daughters": - present_daughters = idx.intersection(daughters) - idx = set( - [ - tuple(x) - for d in present_daughters - for x in np.array(self.mothers)[ - search(self.daughters, d) - ] - ] - ) - elif how == "full_families": - present_mothers = idx.intersection(mothers) - dwm_idx = set( - [ - tuple(x) - for m in present_mothers - for x in np.array(self.daughters)[ - search(np.array(self.mothers), m) - ] - ] - ) - present_daughters = idx.intersection(daughters) - mwd_idx = set( - [ - tuple(x) - for d in present_daughters - for x in np.array(self.mothers)[ - search(np.array(self.daughters), d) - ] - ] - ) - idx = mwd_idx.union(dwm_idx) - elif how == "families" or how == "orphans": - families = mothers.union(daughters) - if how == "families": - idx = families - elif how == "orphans": - idx = idx.diference(families) - - idx = idx.intersection(signals.index) + idx = np.array(signal.index.to_list()) + mothers_daughters = self.cells.mothers_daughters + valid_indices, valid_lineage = [slice(None)] * 2 - return idx + if how == "mothers": + valid_lineage, valid_indices = validate_association( + mothers_daughters, idx, match_column=0 + ) + elif how == "daughters": + valid_lineage, valid_indices = validate_association( + mothers_daughters, idx, match_column=0 + ) + elif how == "families": # Mothers and daughters that are still present + valid_lineage, valid_indices = validate_association( + mothers_daughters, idx, match_column=0 + ) + + idx = idx[valid_indices] + mothers_daughters = mothers_daughters[valid_lineage] + + return mothers_daughters, idx + + def loc_lineage(self, signals: pd.DataFrame, how: str): + _, valid_indices = self.pick_by_lineage(signals, how) + return signals.loc[valid_indices] def pick_by_condition(self, signals, condition, thresh): idx = self.switch_case(signals, condition, thresh) return idx - def get_mothers_daughters(self): - ma = self._cells["mother_assign_dynamic"] - trap = self._cells["trap"] - label = self._cells["cell_label"] - nested_massign = self._cells.mother_assign_from_dynamic( - ma, label, trap, self._cells.ntraps - ) - # mother_bud_mat = self.mother_assign_to_mb_matrix(nested_massign) - - if sum([x for y in nested_massign for x in y]): - - mothers, daughters = zip( - *[ - ((tid, m), (tid, d)) - for tid, trapcells in enumerate(nested_massign) - for d, m in enumerate(trapcells, 1) - if m - ] - ) - else: - mothers, daughters = ([], []) - print("Warning:Picker: No mother-daughters assigned") - - return mothers, daughters - def run(self, signals): self.orig_signals = signals indices = set(signals.index) - self.mothers, self.daughters = self.get_mothers_daughters() + self.mothers, self.daughters = self.cells.mothers_daughters for alg, op, *params in self.sequence: new_indices = tuple() if indices: