diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py index 71411e1dab85442fd986dd9c925a11f50785e3c2..f33c1c1dc1294e9cdf74aa4a2f6ad086ff2e0e7c 100644 --- a/src/agora/utils/kymograph.py +++ b/src/agora/utils/kymograph.py @@ -163,3 +163,15 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]: for start, end in zip(cumsum[:-1], cumsum[1:]) ] return slices + + +def drop_mother_label(index: pd.MultiIndex) -> np.ndarray: + no_mother_label = index + if "mother_label" in index.names: + no_mother_label = index.droplevel("mother_label") + return np.array(no_mother_label.tolist()) + + +def get_index_as_np(signal: pd.DataFrame): + # Get mother labels from multiindex dataframe + return np.array(signal.index.to_list()) diff --git a/src/postprocessor/core/lineageprocess.py b/src/postprocessor/core/lineageprocess.py index f10d5b3e3fef47e23724115ab0f99669a5b6ad94..1c875020bbb9e1a92062ddc7d398b8b079b06437 100644 --- a/src/postprocessor/core/lineageprocess.py +++ b/src/postprocessor/core/lineageprocess.py @@ -51,9 +51,11 @@ class LineageProcess(PostProcessABC): data, lineage=lineage, *extra_data ) - def load_lineage(self, lineage): - """ - Reshape the lineage information if needed - """ - # TODO does this need to be a function? - self.lineage = lineage + def get_lineage_information(self, signal): + if "mother_label" in signal.index.names: + lineage = get_index_as_np(signal) + elif self.cells is not None: + lineage = self.cells.mothers_daughters + else: + raise Exception("No linage information found") + return lineage diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py index ec3b74abe1d8295a53d0a1f7210659852f136c50..b010efd16c8fdbddea83a8a8cf7a8a560bafbd05 100644 --- a/src/postprocessor/core/processor.py +++ b/src/postprocessor/core/processor.py @@ -286,7 +286,8 @@ class PostProcessor(ProcessABC): # self.parameters.lineage_location ) loaded_process = self.classfun[process](parameters) - loaded_process.load_lineage(lineage) + loaded_process.lineage = lineage + else: loaded_process = self.classfun[process](parameters) diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index c331c842030ae11ed5ecf0a2af639df6c700b7b5..4132cf8dbd742ebcb274efa44b97c3ab9c2476c4 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -6,7 +6,8 @@ import pandas as pd from agora.abc import ParametersABC from agora.io.cells import Cells -from agora.utils.association import validate_association, last_col_as_rows +from agora.utils.association import validate_association +from agora.utils.kymograph import drop_mother_label, get_index_as_np from postprocessor.core.lineageprocess import LineageProcess @@ -24,14 +25,14 @@ class Picker(LineageProcess): :cells: Cell object passed to the constructor :condition: Tuple with condition and associated parameter(s), conditions can be "present", "nonstoply_present" or "quantile". - Determines the thersholds or fractions of signals/signals to use. + Determines the thersholds or fractions of signals to use. :lineage: str {"mothers", "daughters", "families" (mothers AND daughters), "orphans"}. Mothers/daughters picks cells with those tags, families pick the union of both and orphans the difference between the total and families. """ def __init__( self, parameters: PickerParameters, - cells: Cells, + cells: Cells or None = None, ): super().__init__(parameters=parameters) @@ -46,8 +47,7 @@ class Picker(LineageProcess): cells_present = drop_mother_label(signal.index) - if mothers_daughters is None: - mothers_daughters = self.cells.mothers_daughters + mothers_daughters = self.get_lineage_information(signal) valid_indices = slice(None) @@ -66,15 +66,17 @@ class Picker(LineageProcess): return signal.index[valid_indices] - def pick_by_condition(self, signals, condition, thresh): - idx = self.switch_case(signals, condition, thresh) + def pick_by_condition(self, signal, condition, thresh): + idx = self.switch_case(signal, condition, thresh) return idx - def run(self, signals): - self.orig_signals = signals - indices = set(signals.index) - lineage = self.cells.mothers_daughters - if lineage.any(): + def run(self, signal): + self.orig_signal = signal + indices = set(signal.index) + + lineage = self.get_lineage_information(signal) + + if len(lineage): self.mothers = lineage[:, :2] self.daughters = lineage[:, [0, 2]] @@ -84,12 +86,12 @@ class Picker(LineageProcess): if alg == "lineage": param1 = params[0] new_indices = getattr(self, "pick_by_" + alg)( - signals.loc[list(indices)], param1 + signal.loc[list(indices)], param1 ) else: param1, *param2 = params new_indices = getattr(self, "pick_by_" + alg)( - signals.loc[list(indices)], param1, param2 + signal.loc[list(indices)], param1, param2 ) new_indices = [tuple(x) for x in new_indices] @@ -102,12 +104,12 @@ class Picker(LineageProcess): def switch_case( self, - signals: pd.DataFrame, + signal: pd.DataFrame, condition: str, threshold: t.Union[float, int, list], ): if len(threshold) == 1: - threshold = [_as_int(*threshold, signals.shape[1])] + threshold = [_as_int(*threshold, signal.shape[1])] case_mgr = { "any_present": lambda s, thresh: any_present(s, thresh), "present": lambda s, thresh: s.notna().sum(axis=1) > thresh, @@ -115,7 +117,7 @@ class Picker(LineageProcess): > thresh, "growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh, } - return set(signals.index[case_mgr[condition](signals, *threshold)]) + return set(signal.index[case_mgr[condition](signal, *threshold)]) def _as_int(threshold: t.Union[float, int], ntps: int): @@ -124,28 +126,21 @@ def _as_int(threshold: t.Union[float, int], ntps: int): return threshold -def any_present(signals, threshold): +def any_present(signal, threshold): """ Returns a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints. """ any_present = pd.Series( np.sum( [ - np.isin([x[0] for x in signals.index], i) & v - for i, v in (signals.notna().sum(axis=1) > threshold) + np.isin([x[0] for x in signal.index], i) & v + for i, v in (signal.notna().sum(axis=1) > threshold) .groupby("trap") .any() .items() ], axis=0, ).astype(bool), - index=signals.index, + index=signal.index, ) return any_present - - -def drop_mother_label(index: pd.MultiIndex) -> np.ndarray: - no_mother_label = index - if "mother_label" in index.names: - no_mother_label = index.droplevel("mother_label") - return np.array(no_mother_label.tolist())