From 05843ad28af4f487c183403487b3f982f07d2ea4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Sun, 10 Oct 2021 23:48:34 +0100
Subject: [PATCH] adjust parameters and bugfixes

Former-commit-id: 0669c9f294dcb76451ce6700f3f946cfcfae3d33
---
 core/processes/picker.py | 43 +++++++++++++++++++++++-----------------
 1 file changed, 25 insertions(+), 18 deletions(-)

diff --git a/core/processes/picker.py b/core/processes/picker.py
index ec0814f3..23be1f70 100644
--- a/core/processes/picker.py
+++ b/core/processes/picker.py
@@ -1,3 +1,6 @@
+import seaborn as sns
+from matplotlib import pyplot as plt  # TODO DELETE THIS
+
 from typing import Tuple, Union, List
 from abc import ABC, abstractmethod
 
@@ -27,9 +30,9 @@ class pickerParameters(ParametersABC):
                 "sequence": [
                     # ("lineage", "intersection", "families"),
                     ("condition", "intersection", "any_present", 0.8),
-                    ("condition", "intersection", "growing", 50),
-                    ("condition", "intersection", "present", 5),
-                    ("condition", "intersection", "mother_buds", 5, 0.5),
+                    ("condition", "intersection", "growing", 40),
+                    ("condition", "intersection", "present", 8),
+                    ("condition", "intersection", "mother_buds", 5, 0.8),
                     # ("lineage", "full_families", "intersection"),
                 ],
             }
@@ -162,9 +165,6 @@ class picker(ProcessABC):
 
         return idx
 
-    def pick_by_custom(self, signals, condition, thresh):
-        pass
-
     def pick_by_condition(self, signals, condition, thresh):
         idx = self.switch_case(signals, condition, thresh)
         return idx
@@ -201,14 +201,18 @@ class picker(ProcessABC):
         for alg, op, *params in self.sequence:
             if alg is "lineage":
                 param1 = params[0]
-                new_indices = getattr(self, "pick_by_" + alg)(signals, param1)
+                new_indices = getattr(self, "pick_by_" + alg)(
+                    signals.loc[list(indices)], param1
+                )
             else:
                 param1, *param2 = params
-                new_indices = getattr(self, "pick_by_" + alg)(signals, param1, param2)
+                new_indices = getattr(self, "pick_by_" + alg)(
+                    signals.loc[list(indices)], param1, param2
+                )
 
             if op is "union":
-                new_indices = new_indices.intersection(set(signals.index))
-                new_indices = indices.union(set(new_indices))
+                # new_indices = new_indices.intersection(set(signals.index))
+                new_indices = indices.union(new_indices)
 
             indices = indices.intersection(new_indices)
 
@@ -224,12 +228,9 @@ class picker(ProcessABC):
             threshold = [_as_int(*threshold, signals.shape[1])]
         case_mgr = {
             "any_present": lambda s, thresh: any_present(s, thresh),
-            "present": lambda s, thresh: signals.notna().sum(axis=1) > thresh,
-            "nonstoply_present": lambda s, thresh: signals.apply(
-                max_nonstop_ntps, axis=1
-            )
-            > thresh_asint,
-            "growing": lambda s, thresh: signals.diff(axis=1).sum(axis=1) > thresh,
+            "present": lambda s, thresh: s.notna().sum(axis=1) > thresh,
+            "nonstoply_present": lambda s, thresh: s.apply(thresh, axis=1) > thresh,
+            "growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh,
             "mother_buds": lambda s, p1, p2: mother_buds_wrap(s, p1, p2)
             # "quantile": [np.quantile(signals.values[signals.notna()], threshold)],
         }
@@ -306,10 +307,16 @@ def mother_buds(df, min_budgrowth_t, min_mobud_ratio):
     size_filter = d_to_mother[
         d_to_mother.apply(lambda x: x.dropna().iloc[0], axis=1) < 0
     ]
-    cols_sorted = size_filter.sort_index(axis=1)
+    cols_sorted = (
+        size_filter.sort_index(axis=1)
+        .apply(pd.Series.first_valid_index, axis=1)
+        .sort_values()
+    )
     if not len(cols_sorted):
         return []
-    bud_candidates = start[[True, *(np.diff(cols_sorted.columns) > min_budgrowth_t)]]
+    bud_candidates = cols_sorted.loc[
+        [True, *(np.diff(cols_sorted.values) > min_budgrowth_t)]
+    ]
 
     return [mother_id] + [int(i) for i in bud_candidates.index.tolist()]
 
-- 
GitLab