From f328b45fcf3fef85a15c1b2b820fd17362abd248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Wed, 15 Mar 2023 14:41:53 +0000 Subject: [PATCH] refactor(lineage): move functions to parent/utils --- src/agora/utils/kymograph.py | 12 +++++ src/postprocessor/core/lineageprocess.py | 14 +++--- src/postprocessor/core/processor.py | 3 +- src/postprocessor/core/reshapers/picker.py | 51 ++++++++++------------ 4 files changed, 45 insertions(+), 35 deletions(-) diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py index 71411e1d..f33c1c1d 100644 --- a/src/agora/utils/kymograph.py +++ b/src/agora/utils/kymograph.py @@ -163,3 +163,15 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]: for start, end in zip(cumsum[:-1], cumsum[1:]) ] return slices + + +def drop_mother_label(index: pd.MultiIndex) -> np.ndarray: + no_mother_label = index + if "mother_label" in index.names: + no_mother_label = index.droplevel("mother_label") + return np.array(no_mother_label.tolist()) + + +def get_index_as_np(signal: pd.DataFrame): + # Get mother labels from multiindex dataframe + return np.array(signal.index.to_list()) diff --git a/src/postprocessor/core/lineageprocess.py b/src/postprocessor/core/lineageprocess.py index f10d5b3e..1c875020 100644 --- a/src/postprocessor/core/lineageprocess.py +++ b/src/postprocessor/core/lineageprocess.py @@ -51,9 +51,11 @@ class LineageProcess(PostProcessABC): data, lineage=lineage, *extra_data ) - def load_lineage(self, lineage): - """ - Reshape the lineage information if needed - """ - # TODO does this need to be a function? - self.lineage = lineage + def get_lineage_information(self, signal): + if "mother_label" in signal.index.names: + lineage = get_index_as_np(signal) + elif self.cells is not None: + lineage = self.cells.mothers_daughters + else: + raise Exception("No linage information found") + return lineage diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py index ec3b74ab..b010efd1 100644 --- a/src/postprocessor/core/processor.py +++ b/src/postprocessor/core/processor.py @@ -286,7 +286,8 @@ class PostProcessor(ProcessABC): # self.parameters.lineage_location ) loaded_process = self.classfun[process](parameters) - loaded_process.load_lineage(lineage) + loaded_process.lineage = lineage + else: loaded_process = self.classfun[process](parameters) diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index c331c842..4132cf8d 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -6,7 +6,8 @@ import pandas as pd from agora.abc import ParametersABC from agora.io.cells import Cells -from agora.utils.association import validate_association, last_col_as_rows +from agora.utils.association import validate_association +from agora.utils.kymograph import drop_mother_label, get_index_as_np from postprocessor.core.lineageprocess import LineageProcess @@ -24,14 +25,14 @@ class Picker(LineageProcess): :cells: Cell object passed to the constructor :condition: Tuple with condition and associated parameter(s), conditions can be "present", "nonstoply_present" or "quantile". - Determines the thersholds or fractions of signals/signals to use. + Determines the thersholds or fractions of signals to use. :lineage: str {"mothers", "daughters", "families" (mothers AND daughters), "orphans"}. Mothers/daughters picks cells with those tags, families pick the union of both and orphans the difference between the total and families. """ def __init__( self, parameters: PickerParameters, - cells: Cells, + cells: Cells or None = None, ): super().__init__(parameters=parameters) @@ -46,8 +47,7 @@ class Picker(LineageProcess): cells_present = drop_mother_label(signal.index) - if mothers_daughters is None: - mothers_daughters = self.cells.mothers_daughters + mothers_daughters = self.get_lineage_information(signal) valid_indices = slice(None) @@ -66,15 +66,17 @@ class Picker(LineageProcess): return signal.index[valid_indices] - def pick_by_condition(self, signals, condition, thresh): - idx = self.switch_case(signals, condition, thresh) + def pick_by_condition(self, signal, condition, thresh): + idx = self.switch_case(signal, condition, thresh) return idx - def run(self, signals): - self.orig_signals = signals - indices = set(signals.index) - lineage = self.cells.mothers_daughters - if lineage.any(): + def run(self, signal): + self.orig_signal = signal + indices = set(signal.index) + + lineage = self.get_lineage_information(signal) + + if len(lineage): self.mothers = lineage[:, :2] self.daughters = lineage[:, [0, 2]] @@ -84,12 +86,12 @@ class Picker(LineageProcess): if alg == "lineage": param1 = params[0] new_indices = getattr(self, "pick_by_" + alg)( - signals.loc[list(indices)], param1 + signal.loc[list(indices)], param1 ) else: param1, *param2 = params new_indices = getattr(self, "pick_by_" + alg)( - signals.loc[list(indices)], param1, param2 + signal.loc[list(indices)], param1, param2 ) new_indices = [tuple(x) for x in new_indices] @@ -102,12 +104,12 @@ class Picker(LineageProcess): def switch_case( self, - signals: pd.DataFrame, + signal: pd.DataFrame, condition: str, threshold: t.Union[float, int, list], ): if len(threshold) == 1: - threshold = [_as_int(*threshold, signals.shape[1])] + threshold = [_as_int(*threshold, signal.shape[1])] case_mgr = { "any_present": lambda s, thresh: any_present(s, thresh), "present": lambda s, thresh: s.notna().sum(axis=1) > thresh, @@ -115,7 +117,7 @@ class Picker(LineageProcess): > thresh, "growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh, } - return set(signals.index[case_mgr[condition](signals, *threshold)]) + return set(signal.index[case_mgr[condition](signal, *threshold)]) def _as_int(threshold: t.Union[float, int], ntps: int): @@ -124,28 +126,21 @@ def _as_int(threshold: t.Union[float, int], ntps: int): return threshold -def any_present(signals, threshold): +def any_present(signal, threshold): """ Returns a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints. """ any_present = pd.Series( np.sum( [ - np.isin([x[0] for x in signals.index], i) & v - for i, v in (signals.notna().sum(axis=1) > threshold) + np.isin([x[0] for x in signal.index], i) & v + for i, v in (signal.notna().sum(axis=1) > threshold) .groupby("trap") .any() .items() ], axis=0, ).astype(bool), - index=signals.index, + index=signal.index, ) return any_present - - -def drop_mother_label(index: pd.MultiIndex) -> np.ndarray: - no_mother_label = index - if "mother_label" in index.names: - no_mother_label = index.droplevel("mother_label") - return np.array(no_mother_label.tolist()) -- GitLab