From f31d1b103b3b42f62881a84cb0ec88d2e909553d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Sun, 23 Oct 2022 17:58:23 +0100 Subject: [PATCH] feat(grouper): add mode "families" and "daughter" --- src/agora/io/signal.py | 8 ++++--- src/agora/utils/kymograph.py | 42 ++++++++++++++++++++++++++++++++++++ src/postprocessor/grouper.py | 29 +++++++++++++++++++++---- 3 files changed, 72 insertions(+), 7 deletions(-) diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 8349f201..797b8d99 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -41,11 +41,13 @@ class Signal(BridgeH5): def __getitem__(self, dsets: t.Union[str, t.Collection]): - if isinstance(dsets, str) and dsets.endswith("imBackground"): + if isinstance( + dsets, str + ): # or isinstance(Dsets,dsets.endswith("imBackground"): df = self.get_raw(dsets) - elif isinstance(dsets, str): - df = self.apply_prepost(dsets) + # elif isinstance(dsets, str): + # df = self.apply_prepost(dsets) elif isinstance(dsets, list): is_bgd = [dset.endswith("imBackground") for dset in dsets] diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py index 8d563de5..f19b77fa 100644 --- a/src/agora/utils/kymograph.py +++ b/src/agora/utils/kymograph.py @@ -2,8 +2,11 @@ import typing as t from copy import copy +import numpy as np import pandas as pd +index_row = t.Tuple[str, str, int, int] + def add_index_levels( df: pd.DataFrame, additional_ids: t.Dict[str, pd.Series] = {} @@ -16,3 +19,42 @@ def add_index_levels( new_df[k] = srs new_df.set_index(k, inplace=True, append=True) return new_df + + +def drop_level( + df: pd.DataFrame, name: str = "mother_label", as_list: bool = True +) -> t.Union[t.List[index_row], pd.Index]: + """Drop index level + + Parameters + ---------- + df : pd.DataFrame + dataframe whose multiindex we will drop + name : str + name of index level to drop + as_list : bool + Whether to return as a list instead of an index + + Examples + -------- + FIXME: Add docs. + + """ + short_index = df.index.droplevel(name) + if as_list: + short_index = short_index.to_list() + return short_index + + +def intersection_matrix( + index1: pd.MultiIndex, index2: pd.MultiIndex +) -> np.ndarray: + """ + Use casting to obtain the boolean mask of the intersection of two multiindices + """ + if not isinstance(index1, np.ndarray): + index1 = np.array(index1.to_list()) + if not isinstance(index2, np.ndarray): + index2 = np.array(index2.to_list()) + + return (index1[..., None] == index2.T).all(axis=1) diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py index ebfd8f63..b4aac80f 100644 --- a/src/postprocessor/grouper.py +++ b/src/postprocessor/grouper.py @@ -13,9 +13,11 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns -from postprocessor.chainer import Chainer from pathos.multiprocessing import Pool +from agora.utils.kymograph import drop_level, intersection_matrix +from postprocessor.chainer import Chainer + class Grouper(ABC): """Base grouper class.""" @@ -83,7 +85,7 @@ class Grouper(ABC): standard: t.Optional[bool] = False, **kwargs, ): - """Concate + """Concatenate multiple signals Parameters ---------- @@ -390,12 +392,31 @@ def concat_signal_ind( raise (NotImplementedError) elif mode == "raw": combined = chainer.get_raw(path, **kwargs) + elif mode == "daughters": + combined = chainer.get_raw(path, **kwargs) + combined = combined.loc[ + combined.index.get_level_values("mother_label") > 0 + ] elif mode == "families": - combined = chainer[path] + combined = chainer.get_raw(path, **kwargs) + daughter_ids = combined.index[ + combined.index.get_level_values("mother_label") > 0 + ] + mother_id_mask = intersection_matrix( + daughter_ids.droplevel("cell_label"), + drop_level(combined, "mother_label", as_list=False), + ).any(axis=0) + combined = combined.loc[ + combined.index[mother_id_mask].union(daughter_ids) + ] + combined["position"] = position combined["group"] = group combined.set_index(["group", "position"], inplace=True, append=True) - combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) + # combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) + combined.index = combined.index.reorder_levels( + ("group", "position", "trap", "cell_label", "mother_label") + ) return combined -- GitLab