diff --git a/src/postprocessor/core/processes/lineageprocess.py b/src/postprocessor/core/processes/lineageprocess.py index 2c0cc6a0e6fa9687f530d96b6ca2f995de16bd6f..38565e18240b8a2b1d8f6f937e7be7a6ea0beb14 100644 --- a/src/postprocessor/core/processes/lineageprocess.py +++ b/src/postprocessor/core/processes/lineageprocess.py @@ -1,7 +1,11 @@ +import typing as t +from abc import abstractmethod + import numpy as np import pandas as pd -from agora.abc import ParametersABC +from agora.abc import ParametersABC +from agora.utils.lineage import group_matrix from postprocessor.core.abc import PostProcessABC @@ -21,11 +25,6 @@ class LineageProcess(PostProcessABC): def __init__(self, parameters: LineageProcessParameters): super().__init__(parameters) - def run( - self, - ): - pass - def filter_signal_cells(self, signal: pd.DataFrame): """ Use casting to filter cell ids in signal and lineage @@ -44,3 +43,33 @@ class LineageProcess(PostProcessABC): ) return self.lineage[mo_av & da_av] + + @abstractmethod + def run( + self, + data: pd.DataFrame, + mother_bud_ids: t.Dict[t.Tuple[int], t.Collection[int]], + *args, + ): + pass + + @classmethod + def as_function( + cls, + data: pd.DataFrame, + lineage: t.Union[t.Dict[t.Tuple[int], t.List[int]]], + *extra_data, + **kwargs, + ): + """ + Overrides PostProcess.as_function classmethod. + Lineage functions require lineage information to be passed if run as function. + """ + if isinstance(lineage, np.ndarray): + lineage = group_matrix(lineage, n_keys=2) + + parameters = cls.default_parameters(**kwargs) + return cls(parameters=parameters).run( + data, mother_bud_ids=lineage, *extra_data + ) + # super().as_function(data, *extra_data, lineage=lineage, **kwargs)