From 13e49faa53e6c4bc03a56f9ec9ff27404652eb09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Sun, 23 Oct 2022 19:28:26 +0100 Subject: [PATCH] refactor(chainer): Integrate brfilter --- src/agora/utils/kymograph.py | 63 +++++++++++++++++ src/postprocessor/chainer.py | 130 ++--------------------------------- 2 files changed, 68 insertions(+), 125 deletions(-) diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py index f19b77fa..eb5844fa 100644 --- a/src/agora/utils/kymograph.py +++ b/src/agora/utils/kymograph.py @@ -58,3 +58,66 @@ def intersection_matrix( index2 = np.array(index2.to_list()) return (index1[..., None] == index2.T).all(axis=1) + + +def get_mother_ilocs_from_daughters(df: pd.DataFrame) -> np.ndarray: + """ + Fetch mother locations in the index of df for all daughters in df. + """ + daughter_ids = df.index[df.index.get_level_values("mother_label") > 0] + mother_ilocs = intersection_matrix( + daughter_ids.droplevel("cell_label"), + drop_level(df, "mother_label", as_list=False), + ).any(axis=0) + return mother_ilocs + + +def get_mothers_from_another_df(whole_df: pd.DataFrame, da_df: pd.DataFrame): + daughter_ids = da_df.index[ + da_df.index.get_level_values("mother_label") > 0 + ] + mother_ilocs = intersection_matrix( + daughter_ids.droplevel("cell_label"), + drop_level(whole_df, "mother_label", as_list=False), + ).any(axis=0) + return mother_ilocs + + +def bidirectional_retainment_filter( + df: pd.DataFrame, mothers_thresh: float = 0.8, daughters_thresh: int = 7 +): + """ + Retrieve families where mothers are present for more than a fraction of the experiment, and daughters for longer than some number of time-points. + """ + all_daughters = df.loc[df.index.get_level_values("mother_label") > 0] + + # Filter daughters + retained_daughters = all_daughters.loc[ + all_daughters.notna().sum(axis=1) > daughters_thresh + ] + + # Fectch mother using existing daughters + mothers = df.loc[get_mothers_from_another_df(df, retained_daughters)] + + # Get mothers + retained_mothers = mothers.loc[ + mothers.notna().sum(axis=1) > mothers.shape[1] * mothers_thresh + ] + + # Filter-out daughters with no valid mothers + final_da_mask = intersection_matrix( + drop_level(retained_daughters, "cell_label", as_list=False), + drop_level(retained_mothers, "mother_label", as_list=False), + ) + + final_daughters = retained_daughters.loc[final_da_mask.any(axis=1)] + + # Join mothers and daughters and sort index + # + return pd.concat((final_daughters, retained_mothers), axis=0).sort_index() + + +def melt_reset(df: pd.DataFrame, additional_ids: t.Dict[str, pd.Series] = {}): + new_df = add_index_levels(df, additional_ids) + + return new_df.melt(ignore_index=False).reset_index() diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py index d174a4b2..d53b8b3f 100644 --- a/src/postprocessor/chainer.py +++ b/src/postprocessor/chainer.py @@ -1,5 +1,6 @@ #!/usr/bin/env jupyter +import re import typing as t from copy import copy @@ -7,11 +8,10 @@ import numpy as np import pandas as pd from agora.io.signal import Signal +from agora.utils.association import validate_association +from agora.utils.kymograph import bidirectional_retainment_filter from postprocessor.core.abc import get_parameters, get_process from postprocessor.core.lineageprocess import LineageProcessParameters -from agora.utils.association import validate_association - -import re class Chainer(Signal): @@ -119,7 +119,7 @@ class Chainer(Signal): self._intermediate_steps = [] for process in chain: if process == "standard": - result = standard_filtering(result, self.lineage()) + result = bidirectional_retainment_filter(result) else: params = kwargs.get(process, {}) process_cls = get_process(process) @@ -127,129 +127,9 @@ class Chainer(Signal): process_type = process_cls.__module__.split(".")[-2] if process_type == "reshapers": if process == "merger": + raise (NotImplementedError) merges = process.as_function(result, **params) result = self.apply_merges(result, merges) self._intermediate_steps.append(result) return result - - -# def standard( -# raw: pd.DataFrame, -# lin: np.ndarray, -# presence_filter_min: int = 7, -# presence_filter_mothers: float = 0.8, -# ): -# """ -# This requires a double-check that mothers-that-are-daughters still are accounted for after -# filtering daughters by the minimal threshold. -# """ -# raw = raw.loc[raw.notna().sum(axis=1) > presence_filter_min].sort_index() -# indices = np.array(raw.index.to_list()) -# # Get remaining full families -# valid_lineages, valid_indices = validate_association(lin, indices) - -# daughters = lin[valid_lineages][:, [0, 2]] -# mothers = lin[valid_lineages][:, :2] -# in_lineage = raw.loc[valid_indices].copy() -# mother_label = np.repeat(0, in_lineage.shape[0]) - -# daughter_ids = ( -# ( -# np.array(in_lineage.index.to_list()) -# == np.unique(daughters, axis=0)[:, None] -# ) -# .all(axis=2) -# .any(axis=0) -# ) -# mother_label[daughter_ids] = mothers[:, 1] -# # Filter mothers by presence -# in_lineage["mother_label"] = mother_label -# present = in_lineage.loc[ -# ( -# in_lineage.iloc[:, :-1].notna().sum(axis=1) -# > ((in_lineage.shape[1] - 1) * presence_filter_mothers) -# ) -# | mother_label -# ] -# present.set_index("mother_label", append=True, inplace=True) - -# # Finally, check full families again -# final_indices = np.array(present.index.to_list()) -# _, final_mask = validate_association( -# np.array([tuple(x) for x in present.index.swaplevel(1, 2)]), -# final_indices[:, :2], -# ) -# return present.loc[final_mask] - -# # In the end, we get the mothers present for more than {presence_lineage1}% of the experiment -# # and their tracklets present for more than {presence_lineage2} time-points -# return present - - -def standard_filtering( - raw: pd.DataFrame, - lin: np.ndarray, - presence_high: float = 0.8, - presence_low: int = 7, -): - # Get all mothers - _, valid_indices = validate_association( - lin, np.array(raw.index.to_list()), match_column=0 - ) - in_lineage = raw.loc[valid_indices] - - # Filter mothers by presence - present = in_lineage.loc[ - in_lineage.notna().sum(axis=1) > (in_lineage.shape[1] * presence_high) - ] - - # Get indices - indices = np.array(present.index.to_list()) - to_cast = np.stack((lin[:, :2], lin[:, [0, 2]]), axis=1) - ndin = to_cast[..., None] == indices.T[None, ...] - - # use indices to fetch all daughters - valid_association = ndin.all(axis=2)[:, 0].any(axis=-1) - - # Remove repeats - mothers, daughters = np.split(to_cast[valid_association], 2, axis=1) - mothers = mothers[:, 0] - daughters = daughters[:, 0] - d_m_dict = {tuple(d): m[-1] for m, d in zip(mothers, daughters)} - - # assuming unique sorts - raw_mothers = raw.loc[_as_tuples(mothers)] - raw_mothers["mother_label"] = 0 - raw_daughters = raw.loc[_as_tuples(daughters)] - raw_daughters["mother_label"] = d_m_dict.values() - concat = pd.concat((raw_mothers, raw_daughters)).sort_index() - concat.set_index("mother_label", append=True, inplace=True) - - # Last filter to remove tracklets that are too short - removed_buds = concat.notna().sum(axis=1) <= presence_low - filt = concat.loc[~removed_buds] - - # We check that no mothers are left child-less - m_d_dict = {tuple(m): [] for m in mothers} - for (trap, d), m in d_m_dict.items(): - m_d_dict[(trap, m)].append(d) - - for trap, daughter, mother in concat.index[removed_buds]: - idx_to_delete = m_d_dict[(trap, mother)].index(daughter) - del m_d_dict[(trap, mother)][idx_to_delete] - - bud_free = [] - for m, d in m_d_dict.items(): - if not d: - bud_free.append(m) - - final_result = filt.drop(bud_free) - - # In the end, we get the mothers present for more than {presence_lineage1}% of the experiment - # and their tracklets present for more than {presence_lineage2} time-points - return final_result - - -def _as_tuples(array: t.Collection) -> t.List[t.Tuple[int, int]]: - return [tuple(x) for x in np.unique(array, axis=0)] -- GitLab