diff --git a/core/processes/picker.py b/core/processes/picker.py index 54c658bc61160500a3b6ab153ee0e4d2d043469a..7649ab5995683a5e1ad14b3980e78504403704e0 100644 --- a/core/processes/picker.py +++ b/core/processes/picker.py @@ -15,24 +15,19 @@ from postprocessor.core.functions.tracks import max_ntps, max_nonstop_ntps class pickerParameters(ParametersABC): def __init__( self, - condition: Tuple[str, Union[float, int]] = None, - lineage: str = None, - lineage_conditional: str = None, sequence: List[str] = ["lineage", "condition"], ): - self.condition = condition - self.lineage = lineage - self.lineage_conditional = lineage_conditional self.sequence = sequence @classmethod def default(cls): return cls.from_dict( { - "condition": ["present", 0.8], - "lineage": "families", - "lineage_conditional": "include", - "sequence": ["condition", "lineage"], + "sequence": [ + ("condition", "present", 0.8), + ("lineage", "families", "union"), + ("condition", "present", 10), + ], } ) @@ -91,10 +86,10 @@ class picker(ProcessABC): return nested_massign - def pick_by_lineage(self, signals): + def pick_by_lineage(self, signals, how): idx = signals.index - if self.lineage: + if how: ma = self._cells["mother_assign_dynamic"] trap = self._cells["trap"] label = self._cells["cell_label"] @@ -124,38 +119,45 @@ class picker(ProcessABC): mothers = set(mothers) daughters = set(daughters) # daughters, mothers = np.where(mother_bud_mat) - if self.lineage == "mothers": + if how == "mothers": idx = mothers - elif self.lineage == "daughters": + elif how == "daughters": idx = daughters - elif self.lineage == "families" or self.lineage == "orphans": + elif how == "families" or how == "orphans": families = mothers.union(daughters) - if self.lineage == "families": + if how == "families": idx = families - elif self.lineage == "orphans": # orphans + elif how == "orphans": idx = idx.diference(families) idx = idx.intersection(signals.index) return idx - def pick_by_condition(self, signals): - idx = self.switch_case(self.condition[0], signals, self.condition[1]) + def pick_by_condition(self, signals, condition, thresh): + idx = self.switch_case(signals, condition, thresh) return idx def run(self, signals): indices = set(signals.index) daughters, mothers = (None, None) - for alg in self.sequence: - indices = getattr(self, "pick_by_" + alg)(signals) + for alg, param1, param2 in self.sequence: + if alg is "lineage": + new_indices = getattr(self, "pick_by_" + alg)(signals, param1) + if param2 is "union": + new_indices = new_indices.intersection(set(signals.index)) + indices = indices.union(set(new_indices)) + else: + new_indices = getattr(self, "pick_by_" + alg)(signals, param1, param2) + indices = indices.intersection(new_indices) daughters, mothers = self.daughters, self.mothers return np.array(daughters), np.array(mothers), np.array(list(indices)) @staticmethod def switch_case( - condition: str, signals: pd.DataFrame, + condition: str, threshold: Union[float, int], ): threshold_asint = _as_int(threshold, signals.shape[1]) @@ -163,9 +165,9 @@ class picker(ProcessABC): "present": signals.notna().sum(axis=1) > threshold_asint, "nonstoply_present": signals.apply(max_nonstop_ntps, axis=1) > threshold_asint, - "quantile": [np.quantile(signals.values[signals.notna()], threshold)], + # "quantile": [np.quantile(signals.values[signals.notna()], threshold)], } - return set(case_mgr[condition].index) + return set(signals.index[case_mgr[condition]]) def _as_int(threshold: Union[float, int], ntps: int): diff --git a/core/processor.py b/core/processor.py index 09caa6a14203da47ff0054ebc5c7a22bd7358fee..4e60a4fa61e2ffad784fcb6b97186fa556e3df12 100644 --- a/core/processor.py +++ b/core/processor.py @@ -143,11 +143,26 @@ class PostProcessor: merged_moda = set([tuple(x) for x in merge_events[:, 0, :]]).intersection( set([*moset, *daset, *picked_set]) ) - for source, target in merge_events: - if tuple(source) in merged_moda: - mothers[np.isin(mothers, source).all(axis=1)] = target - daughters[np.isin(daughters, source).all(axis=1)] = target - indices[np.isin(indices, source).all(axis=1)] = target + search = lambda a, b: np.where( + np.in1d( + np.ravel_multi_index(a.T, a.max(0) + 1), + np.ravel_multi_index(b.T, a.max(0) + 1), + ) + ) + + for target, source in merge_events: + if ( + tuple(source) in moset + ): # update mother to lowest positive index among the two + mother_ids = search(mothers, source) + mothers[mother_ids] = ( + target[0], + self.pick_mother(mothers[mother_ids][0][1], target[1]), + ) + if tuple(source) in daset: + daughters[search(daughters, source)] = target + if tuple(source) in picked_set: + indices[search(indices, source)] = target self._writer.write( "postprocessing/lineage_merged", @@ -167,6 +182,17 @@ class PostProcessor: overwrite="overwrite", ) + @staticmethod + def pick_mother(a, b): + """Update the mother id following this priorities: + + The mother has a lower id + """ + x = max(a, b) + if min([a, b]): + x = [a, b][np.argmin([a, b])] + return x + def run(self): self.run_prepost()