diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index 117b5d4b1159e85bddfd8710bfc7fa28a58059c8..4d8a6b2daac2ab621bc80d1c4ccd9d5c20d657d2 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -1,19 +1,23 @@ -from abc import ABC, abstractmethod +# from abc import ABC, abstractmethod # from copy import copy -from itertools import groupby -from typing import List, Tuple, Union +# from itertools import groupby +# from typing import List, Tuple, Union +import typing as t +from typing import Union -import igraph as ig +# import igraph as ig import numpy as np import pandas as pd + from agora.abc import ParametersABC from agora.io.cells import Cells -from utils_find_1st import cmp_equal, find_1st -from postprocessor.core.lineageprocess import LineageProcess -from postprocessor.core.functions.tracks import max_nonstop_ntps, max_ntps +# from postprocessor.core.functions.tracks import max_nonstop_ntps, max_ntps from agora.utils.association import validate_association +from postprocessor.core.lineageprocess import LineageProcess + +# from utils_find_1st import cmp_equal, find_1st class pickerParameters(ParametersABC): @@ -47,11 +51,18 @@ class picker(LineageProcess): self.cells = cells - def pick_by_lineage(self, signal, how): + def pick_by_lineage( + self, + signal: pd.DataFrame, + how: str, + mothers_daughters: t.Optional[np.ndarray] = None, + ): self.orig_signals = signal idx = np.array(signal.index.to_list()) - mothers_daughters = self.cells.mothers_daughters + + if mothers_daughters is None: + mothers_daughters = self.cells.mothers_daughters valid_indices, valid_lineage = [slice(None)] * 2 if how == "mothers": @@ -60,11 +71,11 @@ class picker(LineageProcess): ) elif how == "daughters": valid_lineage, valid_indices = validate_association( - mothers_daughters, idx, match_column=0 + mothers_daughters, idx, match_column=1 ) elif how == "families": # Mothers and daughters that are still present valid_lineage, valid_indices = validate_association( - mothers_daughters, idx, match_column=0 + mothers_daughters, idx ) idx = idx[valid_indices] @@ -72,9 +83,11 @@ class picker(LineageProcess): return mothers_daughters, idx - def loc_lineage(self, signals: pd.DataFrame, how: str): - _, valid_indices = self.pick_by_lineage(signals, how) - return signals.loc[valid_indices] + def loc_lineage(self, kymo: pd.DataFrame, how: str, lineage=None): + _, valid_indices = self.pick_by_lineage( + kymo, how, mothers_daughters=lineage + ) + return kymo.loc[[tuple(x) for x in valid_indices]] def pick_by_condition(self, signals, condition, thresh): idx = self.switch_case(signals, condition, thresh)