From 2747ac375fd75ac6b3917c703804800586e718f8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Sun, 10 Oct 2021 21:51:56 +0100
Subject: [PATCH] bugfixes

Former-commit-id: 6e32d26b4f402947eff408c04e2f937d8fa758d3
---
 core/processes/picker.py | 97 ++++++++++++++++++++++++++++++++++------
 1 file changed, 84 insertions(+), 13 deletions(-)

diff --git a/core/processes/picker.py b/core/processes/picker.py
index 375e6db1..ec0814f3 100644
--- a/core/processes/picker.py
+++ b/core/processes/picker.py
@@ -28,7 +28,8 @@ class pickerParameters(ParametersABC):
                     # ("lineage", "intersection", "families"),
                     ("condition", "intersection", "any_present", 0.8),
                     ("condition", "intersection", "growing", 50),
-                    ("condition", "intersection", "present", 10),
+                    ("condition", "intersection", "present", 5),
+                    ("condition", "intersection", "mother_buds", 5, 0.5),
                     # ("lineage", "full_families", "intersection"),
                 ],
             }
@@ -161,6 +162,9 @@ class picker(ProcessABC):
 
         return idx
 
+    def pick_by_custom(self, signals, condition, thresh):
+        pass
+
     def pick_by_condition(self, signals, condition, thresh):
         idx = self.switch_case(signals, condition, thresh)
         return idx
@@ -199,7 +203,7 @@ class picker(ProcessABC):
                 param1 = params[0]
                 new_indices = getattr(self, "pick_by_" + alg)(signals, param1)
             else:
-                param1, param2 = params
+                param1, *param2 = params
                 new_indices = getattr(self, "pick_by_" + alg)(signals, param1, param2)
 
             if op is "union":
@@ -216,23 +220,20 @@ class picker(ProcessABC):
         condition: str,
         threshold: Union[float, int, list],
     ):
-        threshold_asint = _as_int(threshold, signals.shape[1])
-        if isinstance(threshold, list):
-            thresh_presence = threshold[0]
+        if len(threshold) == 1:
+            threshold = [_as_int(*threshold, signals.shape[1])]
         case_mgr = {
-            "any_present": lambda s, thresh: any_present(s, threshold_asint),
-            "present": lambda s, thresh: signals.notna().sum(axis=1) > threshold_asint,
+            "any_present": lambda s, thresh: any_present(s, thresh),
+            "present": lambda s, thresh: signals.notna().sum(axis=1) > thresh,
             "nonstoply_present": lambda s, thresh: signals.apply(
                 max_nonstop_ntps, axis=1
             )
-            > threshold_asint,
-            "growing": lambda s, thresh: signals.diff(axis=1).sum(axis=1) > threshold,
+            > thresh_asint,
+            "growing": lambda s, thresh: signals.diff(axis=1).sum(axis=1) > thresh,
+            "mother_buds": lambda s, p1, p2: mother_buds_wrap(s, p1, p2)
             # "quantile": [np.quantile(signals.values[signals.notna()], threshold)],
         }
-        return set(signals.index[case_mgr[condition](signals, threshold)])
-
-
-from copy import copy
+        return set(signals.index[case_mgr[condition](signals, *threshold)])
 
 
 def any_present(signals, threshold):
@@ -255,6 +256,76 @@ def any_present(signals, threshold):
     return any_present
 
 
+from copy import copy
+
+
+def mother_buds(df, min_budgrowth_t, min_mobud_ratio):
+    """
+    Parameters
+    ----------
+    signals : pd.DataFrame
+    min_budgrowth_t: Minimal number of timepoints we lock reassignment after assigning bud
+    min_initial_size: Minimal mother-bud ratio at the assignment
+    #TODO incorporate bud-assignment data?
+
+    # If more than one bud start in the same time point pick the smallest one
+    """
+
+    ntps = df.notna().sum(axis=1)
+    mother_id = df.index[ntps.argmax()]
+    nomother = df.drop(mother_id)
+    if not len(nomother):
+        return []
+    nomother = (  # Clean short-lived cells outside our mother cell's timepoints
+        nomother.loc[
+            nomother.apply(
+                lambda x: x.first_valid_index() >= df.loc[mother_id].first_valid_index()
+                and x.last_valid_index() <= df.loc[mother_id].last_valid_index(),
+                axis=1,
+            )
+        ]
+    )
+
+    start = nomother.apply(pd.Series.first_valid_index, axis=1)
+
+    # clean duplicates
+    duplicates = start.duplicated(False)
+    if duplicates.any():
+        dup_tps = np.unique(start[duplicates])
+        idx, tps = zip(
+            *[(nomother.loc[start == tp, tp].idxmin(), tp) for tp in dup_tps]
+        )
+        start = start[~duplicates]
+        start = pd.concat(
+            (start, pd.Series(tps, index=idx, dtype="int", name="cell_label"))
+        )
+        nomother = nomother.loc[start.index]
+        nomother.index = nomother.index.astype("int")
+
+    d_to_mother = nomother[start] - df.loc[mother_id, start] * min_mobud_ratio
+    size_filter = d_to_mother[
+        d_to_mother.apply(lambda x: x.dropna().iloc[0], axis=1) < 0
+    ]
+    cols_sorted = size_filter.sort_index(axis=1)
+    if not len(cols_sorted):
+        return []
+    bud_candidates = start[[True, *(np.diff(cols_sorted.columns) > min_budgrowth_t)]]
+
+    return [mother_id] + [int(i) for i in bud_candidates.index.tolist()]
+
+
+def mother_buds_wrap(signals, *args):
+    ids = []
+    for trap in signals.index.unique(level="trap"):
+        df = signals.loc[trap]
+        selected_ids = mother_buds(df, *args)
+        ids += [(trap, i) for i in selected_ids]
+
+    idx_srs = pd.Series(False, signals.index).astype(bool)
+    idx_srs.loc[ids] = True
+    return idx_srs
+
+
 def _as_int(threshold: Union[float, int], ntps: int):
     if type(threshold) is float:
         threshold = ntps * threshold
-- 
GitLab