diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 8349f201254a4a8f73b41fc78eb92651f02773f7..797b8d994c8ec99961af325a8fbb1db3506791e6 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -41,11 +41,13 @@ class Signal(BridgeH5): def __getitem__(self, dsets: t.Union[str, t.Collection]): - if isinstance(dsets, str) and dsets.endswith("imBackground"): + if isinstance( + dsets, str + ): # or isinstance(Dsets,dsets.endswith("imBackground"): df = self.get_raw(dsets) - elif isinstance(dsets, str): - df = self.apply_prepost(dsets) + # elif isinstance(dsets, str): + # df = self.apply_prepost(dsets) elif isinstance(dsets, list): is_bgd = [dset.endswith("imBackground") for dset in dsets] diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py index 8d563de523c05144caaa98b477e4738c843904c3..f19b77fa05c70e1bb41b1de9979f8e230a42eb2b 100644 --- a/src/agora/utils/kymograph.py +++ b/src/agora/utils/kymograph.py @@ -2,8 +2,11 @@ import typing as t from copy import copy +import numpy as np import pandas as pd +index_row = t.Tuple[str, str, int, int] + def add_index_levels( df: pd.DataFrame, additional_ids: t.Dict[str, pd.Series] = {} @@ -16,3 +19,42 @@ def add_index_levels( new_df[k] = srs new_df.set_index(k, inplace=True, append=True) return new_df + + +def drop_level( + df: pd.DataFrame, name: str = "mother_label", as_list: bool = True +) -> t.Union[t.List[index_row], pd.Index]: + """Drop index level + + Parameters + ---------- + df : pd.DataFrame + dataframe whose multiindex we will drop + name : str + name of index level to drop + as_list : bool + Whether to return as a list instead of an index + + Examples + -------- + FIXME: Add docs. + + """ + short_index = df.index.droplevel(name) + if as_list: + short_index = short_index.to_list() + return short_index + + +def intersection_matrix( + index1: pd.MultiIndex, index2: pd.MultiIndex +) -> np.ndarray: + """ + Use casting to obtain the boolean mask of the intersection of two multiindices + """ + if not isinstance(index1, np.ndarray): + index1 = np.array(index1.to_list()) + if not isinstance(index2, np.ndarray): + index2 = np.array(index2.to_list()) + + return (index1[..., None] == index2.T).all(axis=1) diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py index ebfd8f637a73a6a592cb51888344d504d49fd61b..b4aac80f4097bec9785e510085859b1d47292f12 100644 --- a/src/postprocessor/grouper.py +++ b/src/postprocessor/grouper.py @@ -13,9 +13,11 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns -from postprocessor.chainer import Chainer from pathos.multiprocessing import Pool +from agora.utils.kymograph import drop_level, intersection_matrix +from postprocessor.chainer import Chainer + class Grouper(ABC): """Base grouper class.""" @@ -83,7 +85,7 @@ class Grouper(ABC): standard: t.Optional[bool] = False, **kwargs, ): - """Concate + """Concatenate multiple signals Parameters ---------- @@ -390,12 +392,31 @@ def concat_signal_ind( raise (NotImplementedError) elif mode == "raw": combined = chainer.get_raw(path, **kwargs) + elif mode == "daughters": + combined = chainer.get_raw(path, **kwargs) + combined = combined.loc[ + combined.index.get_level_values("mother_label") > 0 + ] elif mode == "families": - combined = chainer[path] + combined = chainer.get_raw(path, **kwargs) + daughter_ids = combined.index[ + combined.index.get_level_values("mother_label") > 0 + ] + mother_id_mask = intersection_matrix( + daughter_ids.droplevel("cell_label"), + drop_level(combined, "mother_label", as_list=False), + ).any(axis=0) + combined = combined.loc[ + combined.index[mother_id_mask].union(daughter_ids) + ] + combined["position"] = position combined["group"] = group combined.set_index(["group", "position"], inplace=True, append=True) - combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) + # combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) + combined.index = combined.index.reorder_levels( + ("group", "position", "trap", "cell_label", "mother_label") + ) return combined