From 07ac0b9f4667237d7e30a5510397ea5ac2ca4009 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Fri, 7 Oct 2022 18:37:15 +0100
Subject: [PATCH] feat(chainer): add channels and aliases
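
Add a `channels` property to Signal and give Chainer a set of common
chains ("aliases") produced on the fly: "m5m" maps to the max5px/median
ratio of the first GFP-like channel (falling back to mCherry for
GFPFast experiments that also carry mCherry), and every alias gets a
"_bgsub" variant computed from background-subtracted signals.
Chainer.get also gains an optional `retain` cutoff to drop cells
present in too few time points, and the lineage filter is rewritten as
standard_filtering.

A usage sketch (the h5 path is hypothetical):

    chainer = Chainer("experiment.h5")
    ratio = chainer.get("m5m_bgsub", retain=0.8)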

---
 src/agora/io/signal.py       |  14 ++-
 src/postprocessor/chainer.py | 187 ++++++++++++++++++++++++++++-------
 2 files changed, 163 insertions(+), 38 deletions(-)

diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index 539a71d1..489a5256 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -33,6 +33,10 @@ class Signal(BridgeH5):
             "mother_label",
         )
 
+        # Shorthand aliases for pairs of signals that are commonly combined
+        self.equivalences = {
+            "m5m": ("extraction/GFP/max/max5px", "extraction/GFP/max/median")
+        }
+
     def __getitem__(self, dsets: t.Union[str, t.Collection]):
 
         if isinstance(dsets, str) and dsets.endswith("imBackground"):
@@ -88,10 +92,16 @@ class Signal(BridgeH5):
     def get_retained(df, cutoff):
         return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]
 
-    @lru_cache(30)
+    @property
+    def channels(self):
+        """Return the channel names available in the h5 file."""
+        with h5py.File(self.filename, "r") as f:
+            return f.attrs["channels/channel"]
+
+    @_first_arg_str_to_df
     def retained(self, signal, cutoff=0.8):
 
-        df = self[signal]
+        df = signal
         if isinstance(df, pd.DataFrame):
             return self.get_retained(df, cutoff)
 
diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py
index 56123a64..8276f541 100644
--- a/src/postprocessor/chainer.py
+++ b/src/postprocessor/chainer.py
@@ -11,6 +11,8 @@ from postprocessor.core.abc import get_parameters, get_process
 from postprocessor.core.lineageprocess import LineageProcessParameters
 from agora.utils.association import validate_association
 
+import re
+
 
 class Chainer(Signal):
     """
@@ -23,17 +25,54 @@ class Chainer(Signal):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        # Pick the first GFP-like channel available
+        channel = [ch for ch in self.channels if re.match("GFP", ch)][0]
+        if (
+            channel == "GFPFast" and "mCherry" in self.channels
+        ):  # Use mCherry for Batman experiments
+            channel = "mCherry"
+
+        # Aliases for signals produced on the fly; "m5m" is the max5px
+        # signal divided by the median signal of the chosen channel
+        equivalences = {
+            "m5m": (
+                f"extraction/{channel}/max/max5px",
+                f"extraction/{channel}/max/median",
+            )
+        }
+
+        def replace_url(url: str, suffix: str = ""):
+            # Point the url at the background-subtracted channel when
+            # the requested alias ends in "bgsub"
+            chan = url.split("/")[1]
+            if "bgsub" in suffix:
+                url = re.sub(chan, f"{chan}_bgsub", url)
+            return url
+
+        # Bind the loop variables as lambda defaults; a plain closure
+        # would late-bind them all to the final loop iteration
+        self.common_chains = {
+            alias
+            + bgsub: lambda alias=alias, bgsub=bgsub, num=num, den=den, **kwargs: self.get(
+                replace_url(num, alias + bgsub), **kwargs
+            )
+            / self.get(replace_url(den, alias + bgsub), **kwargs)
+            for alias, (num, den) in equivalences.items()
+            for bgsub in ("", "_bgsub")
+        }
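+
+        # Hypothetical example: self.get("m5m") yields max5px/median for
+        # the chosen channel, and self.get("m5m_bgsub") the same ratio
+        # computed from the background-subtracted signals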
 
     def get(
         self,
         dataset: str,
         chain: t.Collection[str] = ("standard", "interpolate", "savgol"),
         in_minutes: bool = True,
+        retain: t.Optional[float] = None,
         **kwargs,
     ):
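+        """
+        Fetch a dataset and apply a chain of processes to it.
+
+        If `dataset` names a common chain (e.g. "m5m"), the combined
+        signal is produced on the fly instead of being read from file.
+        When `retain` is given, only cells present in more than that
+        fraction of time points are kept.
+        """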
-        data = self.get_raw(dataset, in_minutes=in_minutes)
-        if chain:
-            data = self.apply_chain(data, chain, **kwargs)
+        if dataset in self.common_chains:  # Produce dataset on the fly
+            data = self.common_chains[dataset](**kwargs)
+        else:
+            data = self.get_raw(dataset, in_minutes=in_minutes)
+            if chain:
+                data = self.apply_chain(data, chain, **kwargs)
+
+        if retain is not None:
+            data = self.get_retained(data, retain)
+
         return data
 
     def apply_chain(
@@ -65,7 +104,7 @@ class Chainer(Signal):
         self._intermediate_steps = []
         for process in chain:
             if process == "standard":
-                result = standard(result, self.lineage())
+                result = standard_filtering(result, self.lineage())
             else:
                 params = kwargs.get(process, {})
                 process_cls = get_process(process)
@@ -80,46 +119,122 @@ class Chainer(Signal):
         return result
 
 
-def standard(
+def standard_filtering(
     raw: pd.DataFrame,
     lin: np.ndarray,
-    presence_filter_min: int = 7,
-    presence_filter_mothers: float = 0.8,
+    presence_high: float = 0.8,
+    presence_low: int = 7,
 ):
-    """
-    This requires a double-check that mothers-that-are-daughters still are accounted for after
-    filtering daughters by the minimal threshold.
-    """
     # Get all mothers
-    raw = raw.loc[raw.notna().sum(axis=1) > presence_filter_min].sort_index()
-    indices = np.array(raw.index.to_list())
-    valid_lineages, valid_indices = validate_association(lin, indices)
-
-    daughters = lin[valid_lineages][:, [0, 2]]
-    mothers = lin[valid_lineages][:, :2]
-    in_lineage = raw.loc[valid_indices].copy()
-    mother_label = np.repeat(0, in_lineage.shape[0])
-
-    daughter_ids = (
-        (
-            np.array(in_lineage.index.to_list())
-            == np.unique(daughters, axis=0)[:, None]
-        )
-        .all(axis=2)
-        .any(axis=0)
+    _, valid_indices = validate_association(
+        lin, np.array(raw.index.to_list()), match_column=0
     )
-    mother_label[daughter_ids] = mothers[:, 1]
+    in_lineage = raw.loc[valid_indices]
+
     # Filter mothers by presence
-    in_lineage["mother_label"] = mother_label
     present = in_lineage.loc[
-        (
-            in_lineage.iloc[:, :-2].notna().sum(axis=1)
-            > (in_lineage.shape[1] * presence_filter_mothers)
-        )
-        | mother_label
+        in_lineage.notna().sum(axis=1) > (in_lineage.shape[1] * presence_high)
     ]
-    present.set_index("mother_label", append=True, inplace=True)
+
+    # Stack (trap, mother) and (trap, daughter) pairs for every lineage
+    # and compare them against the indices of the retained mothers
+    indices = np.array(present.index.to_list())
+    to_cast = np.stack((lin[:, :2], lin[:, [0, 2]]), axis=1)
+    ndin = to_cast[..., None] == indices.T[None, ...]
+
+    # Keep the lineages whose mother survived the presence filter
+    valid_association = ndin.all(axis=2)[:, 0].any(axis=-1)
+
+    # Split the surviving pairs back into mother and daughter indices
+    mothers, daughters = np.split(to_cast[valid_association], 2, axis=1)
+    mothers = mothers[:, 0]
+    daughters = daughters[:, 0]
+    # Map each daughter (trap, label) to its mother's label
+    d_m_dict = {tuple(d): m[-1] for m, d in zip(mothers, daughters)}
+
+    # Label mothers with 0 and daughters with their mother's label,
+    # looking the labels up by key so that the row order of
+    # raw_daughters (sorted by _as_tuples) cannot drift from the values
+    raw_mothers = raw.loc[_as_tuples(mothers)].copy()
+    raw_mothers["mother_label"] = 0
+    raw_daughters = raw.loc[_as_tuples(daughters)].copy()
+    raw_daughters["mother_label"] = [
+        d_m_dict[d] for d in _as_tuples(daughters)
+    ]
+    concat = pd.concat((raw_mothers, raw_daughters)).sort_index()
+    concat.set_index("mother_label", append=True, inplace=True)
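+    # concat now carries a (trap, cell_label, mother_label) MultiIndex,
+    # with mother_label == 0 marking the mothers themselves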
+
+    # Drop tracklets present for presence_low time points or fewer
+    removed_buds = concat.notna().sum(axis=1) <= presence_low
+    filt = concat.loc[~removed_buds]
+
+    # Check that no mothers are left without any bud after the filter
+    m_d_dict = {tuple(m): [] for m in mothers}
+    for (trap, d), m in d_m_dict.items():
+        m_d_dict[(trap, m)].append(d)
+
+    for trap, daughter, mother in concat.index[removed_buds]:
+        if mother:  # mother_label == 0 means the row is a mother itself
+            m_d_dict[(trap, mother)].remove(daughter)
+
+    bud_free = [m for m, buds in m_d_dict.items() if not buds]
+    final_result = filt.drop(bud_free)
 
-    # In the end, we get the mothers present for more than {presence_lineage1}% of the experiment
-    # and their tracklets present for more than {presence_lineage2} time-points
-    return present
+    # In the end, we keep the mothers present for more than presence_high
+    # (a fraction) of the experiment and their buds present for more than
+    # presence_low time points
+    return final_result
+
+
+def _as_tuples(array: t.Collection) -> t.List[t.Tuple[int, int]]:
+    # Return the unique rows of a 2-d index array as sorted (trap, label) tuples
+    return [tuple(x) for x in np.unique(array, axis=0)]
-- 
GitLab