From 8ef3726cc4f980d8ac3c2ab3b67a10f905551245 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Mon, 3 Apr 2023 16:38:20 +0100
Subject: [PATCH] fix(pp): refresh bud_metric and merging

---
 src/agora/utils/merge.py                      | 20 ++++++----
 src/postprocessor/core/processor.py           |  4 +-
 .../core/reshapers/bud_metric.py              | 37 +++++++++++++------
 3 files changed, 40 insertions(+), 21 deletions(-)

diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py
index e74a9a03..93669912 100644
--- a/src/agora/utils/merge.py
+++ b/src/agora/utils/merge.py
@@ -31,8 +31,11 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):
 
     """
 
+    indices = data.index
+    if "mother_label" in indices.names:
+        indices = indices.droplevel("mother_label")
     valid_merges, indices = validate_association(
-        merges, np.array(list(data.index))
+        merges, np.array(list(indices))
     )
 
     # Assign non-merged
@@ -129,14 +132,15 @@ def merge_association(
 
     valid_indices = comparison_mat.any(axis=0)
 
-    replacement_d = {}
-    for dataset in grouped_merges:
-        for k in dataset:
-            replacement_d[tuple(k[0])] = dataset[-1][1]
+    if valid_indices.any():  # Where valid, perform transformation
+        replacement_d = {}
+        for dataset in grouped_merges:
+            for k in dataset:
+                replacement_d[tuple(k[0])] = dataset[-1][1]
 
-    flat_indices[valid_indices] = [
-        replacement_d[tuple(i)] for i in flat_indices[valid_indices]
-    ]
+        flat_indices[valid_indices] = [
+            replacement_d[tuple(i)] for i in flat_indices[valid_indices]
+        ]
 
     merged_indices = flat_indices.reshape(-1, 2, 2)
     return merged_indices
diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index e0adac14..72bbf3f7 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -204,7 +204,9 @@ class PostProcessor(ProcessABC):
             self._writer.write(
                 "modifiers/picks",
                 data=pd.MultiIndex.from_arrays(
-                    picked_indices, names=["trap", "cell_label"]
+                    picked_indices.T,
+                    # names=["trap", "cell_label", "mother_label"],
+                    names=["trap", "cell_label"],
                 ),
                 overwrite="overwrite",
             )
diff --git a/src/postprocessor/core/reshapers/bud_metric.py b/src/postprocessor/core/reshapers/bud_metric.py
index d23ea115..bf813739 100644
--- a/src/postprocessor/core/reshapers/bud_metric.py
+++ b/src/postprocessor/core/reshapers/bud_metric.py
@@ -65,23 +65,36 @@ class BudMetric(LineageProcess):
         mothers_mat = np.zeros((len(md), signal.shape[1]))
         cells_were_dropped = 0  # Flag determines if mothers (1), daughters (2) or both were missing (3)
 
-        md_index = signal.index.droplevel("mother_label")
-        if md is not None:  # Get intersection of lineage and current index
+        md_index = signal.index
+        if (
+            "mother_label" not in md_index.names
+        ):  # Generate mother label from md dict if unavailable
+            d = {v: k for k, values in md.items() for v in values}
+            signal["mother_label"] = list(
+                map(lambda x: d.get(x, [0])[-1], signal.index)
+            )
+            signal.set_index("mother_label", append=True, inplace=True)
             related_items = set(
                 [*md.keys(), *[y for x in md.values() for y in x]]
             )
-            intersection = md_index.intersection(related_items)
-            if len(intersection) < len(signal):
-                print("Dropped cells before bud_metric")  # TODO log
-
-            signal = (
-                signal.reset_index("mother_label")
-                .loc(axis=0)[intersection]
-                .set_index("mother_label", append=True)
-            )
+            md_index = md_index.intersection(related_items)
+        elif "mother_label" in md_index.names:
+            md_index = md_index.droplevel("mother_label")
+        else:  # Unreachable while the two branches above are complementary
+            raise ValueError("Unavailable relationship information")
+
+        if len(md_index) < len(signal):
+            print("Dropped cells before bud_metric")  # TODO log
+
+        signal = (
+            signal.reset_index("mother_label")
+            .loc(axis=0)[md_index]
+            .set_index("mother_label", append=True)
+        )
 
         names = list(signal.index.names)
         del names[-2]
+
         output_df = (
             signal.loc[signal.index.get_level_values("mother_label") > 0]
             .groupby(names)
@@ -101,7 +114,7 @@ def _combine_daughter_tracks(tracks: t.Collection[pd.Series]):
     prioritising the most recent entity.
     """
     sorted_da_ids = tracks.sort_index(level="cell_label")
-    tp_fvt = sorted_da_ids.apply(lambda x: x.last_valid_index(), axis=0)
+    tp_fvt = sorted_da_ids.apply(lambda x: x.first_valid_index(), axis=1)
     tp_fvt = sorted_da_ids.index.get_indexer(tp_fvt)
     tp_fvt[tp_fvt < 0] = sorted_da_ids.shape[0] - 1
 
-- 
GitLab