diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index da4e9c744b66f478b066fd0f671a57730b6e06ee..dbdd597cd8cbcafce10eededf4d0759502d45279 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -47,20 +47,25 @@ class Signal(BridgeH5): def __getitem__(self, dsets: t.Union[str, t.Collection]): """Get and potentially pre-process data from h5 file and return as a dataframe.""" if isinstance(dsets, str): # no pre-processing - df = self.apply_prepost(dsets) - return self.add_name(df, dsets) + return self.get(dsets) elif isinstance(dsets, list): # pre-processing is_bgd = [dset.endswith("imBackground") for dset in dsets] # Check we are not comparing tile-indexed and cell-indexed data assert sum(is_bgd) == 0 or sum(is_bgd) == len( dsets ), "Tile data and cell data can't be mixed" - return [ - self.add_name(self.apply_prepost(dset), dset) for dset in dsets - ] + return [self.get(dset) for dset in dsets] else: raise Exception(f"Invalid type {type(dsets)} to get datasets") + def get(self, dsets: t.Union[str, t.Collection], **kwargs): + """Get and potentially pre-process data from h5 file and return as a dataframe.""" + if isinstance(dsets, str): # no pre-processing + df = self.get_raw(dsets, **kwargs) + prepost_applied = self.apply_prepost(dsets, **kwargs) + + return self.add_name(prepost_applied, dsets) + @staticmethod def add_name(df, name): """Add column of identical strings to a dataframe.""" @@ -129,18 +134,24 @@ class Signal(BridgeH5): Returns an array with three columns: the tile id, the mother label, and the daughter label. 
""" if lineage_location is None: - lineage_location = "postprocessing/lineage_merged" + lineage_location = "modifiers/lineage_merged" with h5py.File(self.filename, "r") as f: + # if lineage_location not in f: + # lineage_location = lineage_location.split("_")[0] if lineage_location not in f: - lineage_location = f[lineage_location.split("_")[0]] - tile_mo_da = f[lineage_location.split("_")[0]] - lineage = np.array( - ( - tile_mo_da["trap"], - tile_mo_da["mother_label"], - tile_mo_da["daughter_label"], - ) - ).T + lineage_location = "postprocessor/lineage" + tile_mo_da = f[lineage_location] + + if isinstance(tile_mo_da, h5py.Dataset): + lineage = tile_mo_da[()] + else: + lineage = np.array( + ( + tile_mo_da["trap"], + tile_mo_da["mother_label"], + tile_mo_da["daughter_label"], + ) + ).T return lineage @_first_arg_str_to_df diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py index f33c1c1dc1294e9cdf74aa4a2f6ad086ff2e0e7c..fe7aef17a0ab7f5c7c3372e59942ecb1f85cf5d8 100644 --- a/src/agora/utils/kymograph.py +++ b/src/agora/utils/kymograph.py @@ -5,6 +5,7 @@ from copy import copy import numpy as np import pandas as pd from sklearn.cluster import KMeans +from agora.utils.indexing import validate_association index_row = t.Tuple[str, str, int, int] @@ -175,3 +176,67 @@ def drop_mother_label(index: pd.MultiIndex) -> np.ndarray: def get_index_as_np(signal: pd.DataFrame): # Get mother labels from multiindex dataframe return np.array(signal.index.to_list()) + + +def standard_filtering( + raw: pd.DataFrame, + lin: np.ndarray, + presence_high: float = 0.8, + presence_low: int = 7, +): + # Get all mothers + _, valid_indices = validate_association( + lin, np.array(raw.index.to_list()), match_column=0 + ) + in_lineage = raw.loc[valid_indices] + + # Filter mothers by presence + present = in_lineage.loc[ + in_lineage.notna().sum(axis=1) > (in_lineage.shape[1] * presence_high) + ] + + # Get indices + indices = np.array(present.index.to_list()) + to_cast = 
np.stack((lin[:, :2], lin[:, [0, 2]]), axis=1) + ndin = to_cast[..., None] == indices.T[None, ...] + + # use indices to fetch all daughters + valid_association = ndin.all(axis=2)[:, 0].any(axis=-1) + + # Remove repeats + mothers, daughters = np.split(to_cast[valid_association], 2, axis=1) + mothers = mothers[:, 0] + daughters = daughters[:, 0] + d_m_dict = {tuple(d): m[-1] for m, d in zip(mothers, daughters)} + + # assuming unique sorts + raw_mothers = raw.loc[_as_tuples(mothers)] + raw_mothers["mother_label"] = 0 + raw_daughters = raw.loc[_as_tuples(daughters)] + raw_daughters["mother_label"] = d_m_dict.values() + concat = pd.concat((raw_mothers, raw_daughters)).sort_index() + concat.set_index("mother_label", append=True, inplace=True) + + # Last filter to remove tracklets that are too short + removed_buds = concat.notna().sum(axis=1) <= presence_low + filt = concat.loc[~removed_buds] + + # We check that no mothers are left child-less + m_d_dict = {tuple(m): [] for m in mothers} + for (trap, d), m in d_m_dict.items(): + m_d_dict[(trap, m)].append(d) + + for trap, daughter, mother in concat.index[removed_buds]: + idx_to_delete = m_d_dict[(trap, mother)].index(daughter) + del m_d_dict[(trap, mother)][idx_to_delete] + + bud_free = [] + for m, d in m_d_dict.items(): + if not d: + bud_free.append(m) + + final_result = filt.drop(bud_free) + + # In the end, we get the mothers present for more than the presence_high fraction of the experiment + # and their tracklets present for more than presence_low time-points + return final_result