From 13e49faa53e6c4bc03a56f9ec9ff27404652eb09 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Sun, 23 Oct 2022 19:28:26 +0100
Subject: [PATCH] refactor(chainer): Integrate brfilter

---
 src/agora/utils/kymograph.py |  63 +++++++++++++++++
 src/postprocessor/chainer.py | 130 ++---------------------------------
 2 files changed, 68 insertions(+), 125 deletions(-)

diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py
index f19b77fa..eb5844fa 100644
--- a/src/agora/utils/kymograph.py
+++ b/src/agora/utils/kymograph.py
@@ -58,3 +58,66 @@ def intersection_matrix(
         index2 = np.array(index2.to_list())
 
     return (index1[..., None] == index2.T).all(axis=1)
+
+
+def get_mother_ilocs_from_daughters(df: pd.DataFrame) -> np.ndarray:
+    """
+    Fetch mother locations in the index of df for all daughters in df.
+    """
+    daughter_ids = df.index[df.index.get_level_values("mother_label") > 0]
+    mother_ilocs = intersection_matrix(
+        daughter_ids.droplevel("cell_label"),
+        drop_level(df, "mother_label", as_list=False),
+    ).any(axis=0)
+    return mother_ilocs
+
+
+def get_mothers_from_another_df(whole_df: pd.DataFrame, da_df: pd.DataFrame):
+    daughter_ids = da_df.index[
+        da_df.index.get_level_values("mother_label") > 0
+    ]
+    mother_ilocs = intersection_matrix(
+        daughter_ids.droplevel("cell_label"),
+        drop_level(whole_df, "mother_label", as_list=False),
+    ).any(axis=0)
+    return mother_ilocs
+
+
+def bidirectional_retainment_filter(
+    df: pd.DataFrame, mothers_thresh: float = 0.8, daughters_thresh: int = 7
+):
+    """
+    Retrieve families where mothers are present for more than a fraction of the experiment, and daughters for longer than some number of time-points.
+    """
+    all_daughters = df.loc[df.index.get_level_values("mother_label") > 0]
+
+    # Filter daughters
+    retained_daughters = all_daughters.loc[
+        all_daughters.notna().sum(axis=1) > daughters_thresh
+    ]
+
+    # Fetch mothers using existing daughters
+    mothers = df.loc[get_mothers_from_another_df(df, retained_daughters)]
+
+    # Get mothers
+    retained_mothers = mothers.loc[
+        mothers.notna().sum(axis=1) > mothers.shape[1] * mothers_thresh
+    ]
+
+    # Filter out daughters with no valid mothers
+    final_da_mask = intersection_matrix(
+        drop_level(retained_daughters, "cell_label", as_list=False),
+        drop_level(retained_mothers, "mother_label", as_list=False),
+    )
+
+    final_daughters = retained_daughters.loc[final_da_mask.any(axis=1)]
+
+    # Join mothers and daughters and sort index
+    # so each family's rows appear contiguously in the result
+    return pd.concat((final_daughters, retained_mothers), axis=0).sort_index()
+
+
+def melt_reset(df: pd.DataFrame, additional_ids: t.Dict[str, pd.Series] = {}):
+    new_df = add_index_levels(df, additional_ids)
+
+    return new_df.melt(ignore_index=False).reset_index()
diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py
index d174a4b2..d53b8b3f 100644
--- a/src/postprocessor/chainer.py
+++ b/src/postprocessor/chainer.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env jupyter
 
+import re
 import typing as t
 from copy import copy
 
@@ -7,11 +8,10 @@ import numpy as np
 import pandas as pd
 
 from agora.io.signal import Signal
+from agora.utils.association import validate_association
+from agora.utils.kymograph import bidirectional_retainment_filter
 from postprocessor.core.abc import get_parameters, get_process
 from postprocessor.core.lineageprocess import LineageProcessParameters
-from agora.utils.association import validate_association
-
-import re
 
 
 class Chainer(Signal):
@@ -119,7 +119,7 @@ class Chainer(Signal):
         self._intermediate_steps = []
         for process in chain:
             if process == "standard":
-                result = standard_filtering(result, self.lineage())
+                result = bidirectional_retainment_filter(result)
             else:
                 params = kwargs.get(process, {})
                 process_cls = get_process(process)
@@ -127,129 +127,9 @@ class Chainer(Signal):
                 process_type = process_cls.__module__.split(".")[-2]
                 if process_type == "reshapers":
                     if process == "merger":
+                        raise (NotImplementedError)
                         merges = process.as_function(result, **params)
                         result = self.apply_merges(result, merges)
 
             self._intermediate_steps.append(result)
         return result
-
-
-# def standard(
-#     raw: pd.DataFrame,
-#     lin: np.ndarray,
-#     presence_filter_min: int = 7,
-#     presence_filter_mothers: float = 0.8,
-# ):
-#     """
-#     This requires a double-check that mothers-that-are-daughters still are accounted for after
-#     filtering daughters by the minimal threshold.
-#     """
-#     raw = raw.loc[raw.notna().sum(axis=1) > presence_filter_min].sort_index()
-#     indices = np.array(raw.index.to_list())
-#     # Get remaining full families
-#     valid_lineages, valid_indices = validate_association(lin, indices)
-
-#     daughters = lin[valid_lineages][:, [0, 2]]
-#     mothers = lin[valid_lineages][:, :2]
-#     in_lineage = raw.loc[valid_indices].copy()
-#     mother_label = np.repeat(0, in_lineage.shape[0])
-
-#     daughter_ids = (
-#         (
-#             np.array(in_lineage.index.to_list())
-#             == np.unique(daughters, axis=0)[:, None]
-#         )
-#         .all(axis=2)
-#         .any(axis=0)
-#     )
-#     mother_label[daughter_ids] = mothers[:, 1]
-#     # Filter mothers by presence
-#     in_lineage["mother_label"] = mother_label
-#     present = in_lineage.loc[
-#         (
-#             in_lineage.iloc[:, :-1].notna().sum(axis=1)
-#             > ((in_lineage.shape[1] - 1) * presence_filter_mothers)
-#         )
-#         | mother_label
-#     ]
-#     present.set_index("mother_label", append=True, inplace=True)
-
-#     # Finally, check full families again
-#     final_indices = np.array(present.index.to_list())
-#     _, final_mask = validate_association(
-#         np.array([tuple(x) for x in present.index.swaplevel(1, 2)]),
-#         final_indices[:, :2],
-#     )
-#     return present.loc[final_mask]
-
-#     # In the end, we get the mothers present for more than {presence_lineage1}% of the experiment
-#     # and their tracklets present for more than {presence_lineage2} time-points
-#     return present
-
-
-def standard_filtering(
-    raw: pd.DataFrame,
-    lin: np.ndarray,
-    presence_high: float = 0.8,
-    presence_low: int = 7,
-):
-    # Get all mothers
-    _, valid_indices = validate_association(
-        lin, np.array(raw.index.to_list()), match_column=0
-    )
-    in_lineage = raw.loc[valid_indices]
-
-    # Filter mothers by presence
-    present = in_lineage.loc[
-        in_lineage.notna().sum(axis=1) > (in_lineage.shape[1] * presence_high)
-    ]
-
-    # Get indices
-    indices = np.array(present.index.to_list())
-    to_cast = np.stack((lin[:, :2], lin[:, [0, 2]]), axis=1)
-    ndin = to_cast[..., None] == indices.T[None, ...]
-
-    # use indices to fetch all daughters
-    valid_association = ndin.all(axis=2)[:, 0].any(axis=-1)
-
-    # Remove repeats
-    mothers, daughters = np.split(to_cast[valid_association], 2, axis=1)
-    mothers = mothers[:, 0]
-    daughters = daughters[:, 0]
-    d_m_dict = {tuple(d): m[-1] for m, d in zip(mothers, daughters)}
-
-    # assuming unique sorts
-    raw_mothers = raw.loc[_as_tuples(mothers)]
-    raw_mothers["mother_label"] = 0
-    raw_daughters = raw.loc[_as_tuples(daughters)]
-    raw_daughters["mother_label"] = d_m_dict.values()
-    concat = pd.concat((raw_mothers, raw_daughters)).sort_index()
-    concat.set_index("mother_label", append=True, inplace=True)
-
-    # Last filter to remove tracklets that are too short
-    removed_buds = concat.notna().sum(axis=1) <= presence_low
-    filt = concat.loc[~removed_buds]
-
-    # We check that no mothers are left child-less
-    m_d_dict = {tuple(m): [] for m in mothers}
-    for (trap, d), m in d_m_dict.items():
-        m_d_dict[(trap, m)].append(d)
-
-    for trap, daughter, mother in concat.index[removed_buds]:
-        idx_to_delete = m_d_dict[(trap, mother)].index(daughter)
-        del m_d_dict[(trap, mother)][idx_to_delete]
-
-    bud_free = []
-    for m, d in m_d_dict.items():
-        if not d:
-            bud_free.append(m)
-
-    final_result = filt.drop(bud_free)
-
-    # In the end, we get the mothers present for more than {presence_lineage1}% of the experiment
-    # and their tracklets present for more than {presence_lineage2} time-points
-    return final_result
-
-
-def _as_tuples(array: t.Collection) -> t.List[t.Tuple[int, int]]:
-    return [tuple(x) for x in np.unique(array, axis=0)]
-- 
GitLab