From 148d92c32ed250ff70def7430a03e94d7f9c8338 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Wed, 5 Oct 2022 17:40:48 +0100 Subject: [PATCH] refactor(picker): use shared function for indices --- src/postprocessor/core/reshapers/picker.py | 41 ++++++++++++++-------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index 117b5d4b..4d8a6b2d 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -1,19 +1,23 @@ -from abc import ABC, abstractmethod +# from abc import ABC, abstractmethod # from copy import copy -from itertools import groupby -from typing import List, Tuple, Union +# from itertools import groupby +# from typing import List, Tuple, Union +import typing as t +from typing import Union -import igraph as ig +# import igraph as ig import numpy as np import pandas as pd + from agora.abc import ParametersABC from agora.io.cells import Cells -from utils_find_1st import cmp_equal, find_1st -from postprocessor.core.lineageprocess import LineageProcess -from postprocessor.core.functions.tracks import max_nonstop_ntps, max_ntps +# from postprocessor.core.functions.tracks import max_nonstop_ntps, max_ntps from agora.utils.association import validate_association +from postprocessor.core.lineageprocess import LineageProcess + +# from utils_find_1st import cmp_equal, find_1st class pickerParameters(ParametersABC): @@ -47,11 +51,18 @@ class picker(LineageProcess): self.cells = cells - def pick_by_lineage(self, signal, how): + def pick_by_lineage( + self, + signal: pd.DataFrame, + how: str, + mothers_daughters: t.Optional[np.ndarray] = None, + ): self.orig_signals = signal idx = np.array(signal.index.to_list()) - mothers_daughters = self.cells.mothers_daughters + + if mothers_daughters is None: + mothers_daughters = self.cells.mothers_daughters valid_indices, valid_lineage = [slice(None)] * 2 if how == "mothers": @@ -60,11 +71,11 @@ class picker(LineageProcess): ) elif how == "daughters": valid_lineage, valid_indices = validate_association( - mothers_daughters, idx, match_column=0 + mothers_daughters, idx, match_column=1 ) elif how == "families": # Mothers and daughters that are still present valid_lineage, valid_indices = validate_association( - mothers_daughters, idx, match_column=0 + mothers_daughters, idx ) idx = idx[valid_indices] @@ -72,9 +83,11 @@ class picker(LineageProcess): return mothers_daughters, idx - def loc_lineage(self, signals: pd.DataFrame, how: str): - _, valid_indices = self.pick_by_lineage(signals, how) - return signals.loc[valid_indices] + def loc_lineage(self, kymo: pd.DataFrame, how: str, lineage=None): + _, valid_indices = self.pick_by_lineage( + kymo, how, mothers_daughters=lineage + ) + return kymo.loc[[tuple(x) for x in valid_indices]] def pick_by_condition(self, signals, condition, thresh): idx = self.switch_case(signals, condition, thresh) -- GitLab