From 9f9d2ba0ecf8a1e9dd61f8b354ed9af7b521d3db Mon Sep 17 00:00:00 2001
From: Swainlab <peter.swain@ed.ac.uk>
Date: Sun, 2 Jul 2023 18:08:43 +0100
Subject: [PATCH] docs for budding; re-write of budmetric

---
 src/agora/io/signal.py                        |   2 +-
 src/agora/utils/lineage.py                    |   5 -
 .../core/reshapers/bud_metric.py              | 100 +++++++++++-------
 3 files changed, 61 insertions(+), 46 deletions(-)

diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index 20b6d8ba..7a7f871f 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -70,7 +70,7 @@ class Signal(BridgeH5):
 
     @staticmethod
     def add_name(df, name):
-        """Add column of identical strings to a dataframe."""
+        """Set `df.name` to `name` and return `df`."""
         df.name = name
         return df
 
diff --git a/src/agora/utils/lineage.py b/src/agora/utils/lineage.py
index 52fb552b..05c161a5 100644
--- a/src/agora/utils/lineage.py
+++ b/src/agora/utils/lineage.py
@@ -1,12 +1,8 @@
 #!/usr/bin/env python3
-import re
-import typing as t
 
 import numpy as np
-import pandas as pd
 
 from agora.io.bridge import groupsort
-from itertools import groupby
 
 
 def mb_array_to_dict(mb_array: np.ndarray):
@@ -19,4 +15,3 @@ def mb_array_to_dict(mb_array: np.ndarray):
         for trap, mo_da in groupsort(mb_array).items()
         for mo, daughters in groupsort(mo_da).items()
     }
-
diff --git a/src/postprocessor/core/reshapers/bud_metric.py b/src/postprocessor/core/reshapers/bud_metric.py
index b8952288..ee239221 100644
--- a/src/postprocessor/core/reshapers/bud_metric.py
+++ b/src/postprocessor/core/reshapers/bud_metric.py
@@ -20,8 +20,9 @@ class BudMetricParameters(LineageProcessParameters):
 
 class BudMetric(LineageProcess):
     """
-    Requires mother-bud information to create a new dataframe where the indices are mother ids and
-    values are the daughters' values for a given signal.
+    Requires mother-bud information to create a new dataframe where the
+    indices are mother ids and values are the daughters' values for a
+    given signal.
     """
 
     def __init__(self, parameters: BudMetricParameters):
@@ -38,7 +39,6 @@ class BudMetric(LineageProcess):
             else:
                 assert "mother_label" in signal.index.names
                 lineage = signal.index.to_list()
-
         return self.get_bud_metric(signal, mb_array_to_dict(lineage))
 
     @staticmethod
@@ -48,7 +48,9 @@ class BudMetric(LineageProcess):
         """
 
         signal: Daughter-inclusive dataframe
-        md: Mother-daughters dictionary where key is mother's index and its values are a list of daughter indices
+        md: dictionary where key is mother's index,
+        defined as (trap, cell_label), and its values are a list of
+        daughter indices, as (trap, cell_label).
 
         Get fvi (First Valid Index) for all cells
         Create empty matrix
@@ -61,63 +63,81 @@ class BudMetric(LineageProcess):
         Convert matrix into dataframe using mother indices
 
         """
-        mothers_mat = np.zeros((len(md), signal.shape[1]))
-        cells_were_dropped = 0  # Flag determines if mothers (1), daughters (2) or both were missing (3)
-
         md_index = signal.index
-        if (
-            "mother_label" not in md_index.names
-        ):  # Generate mother label from md dict if unavailable
+        # md_index should only comprise (trap, cell_label)
+        if "mother_label" not in md_index.names:
+            # dict with daughter indices as keys
             d = {v: k for k, values in md.items() for v in values}
+            # generate mother_label in signal using the mother's cell_label
             signal["mother_label"] = list(
                 map(lambda x: d.get(x, [0])[-1], signal.index)
             )
             signal.set_index("mother_label", append=True, inplace=True)
-            related_items = set(
-                [*md.keys(), *[y for x in md.values() for y in x]]
-            )
-            md_index = md_index.intersection(related_items)
-        elif "mother_label" in md_index.names:
-            md_index = md_index.droplevel("mother_label")
+            # combine mothers and daughter indices
+            mothers_index = md.keys()
+            daughters_index = [y for x in md.values() for y in x]
+            relations = set([*mothers_index, *daughters_index])
+            # keep from md_index only mother and daughters
+            md_index = md_index.intersection(relations)
         else:
-            raise ("Unavailable relationship information")
-
+            md_index = md_index.droplevel("mother_label")
         if len(md_index) < len(signal):
-            print("Dropped cells before bud_metric")  # TODO log
-
+            print("Dropped cells before applying bud_metric")  # TODO log
+        # restrict signal to the cells in md_index, moving mother_label to do so
         signal = (
             signal.reset_index("mother_label")
             .loc(axis=0)[md_index]
             .set_index("mother_label", append=True)
         )
-
-        names = list(signal.index.names)
-        del names[-2]
-
-        output_df = (
-            signal.loc[signal.index.get_level_values("mother_label") > 0]
-            .groupby(names)
-            .apply(lambda x: _combine_daughter_tracks(x))
+        # restrict to daughters: cells with a mother
+        mother_labels = signal.index.get_level_values("mother_label")
+        daughter_df = signal.loc[mother_labels > 0]
+        # join data for daughters with the same mother
+        output_df = daughter_df.groupby(["trap", "mother_label"]).apply(
+            lambda x: _combine_daughter_tracks(x)
         )
+
         output_df.columns = signal.columns
         output_df["padding_level"] = 0
         output_df.set_index("padding_level", append=True, inplace=True)
-
         if len(output_df):
             output_df.index.names = signal.index.names
         return output_df
 
 
-def _combine_daughter_tracks(tracks: t.Collection[pd.Series]):
+def _combine_daughter_tracks(tracks: pd.DataFrame):
     """
-    Combine multiple time series of cells into one, overwriting values
-    prioritising the most recent entity.
+    Combine multiple time series of daughter cells into one time series.
+
+    At any one time, a mother cell should have only one daughter.
+
+    Two daughters are still sometimes present at the same time point, and we
+    then choose the daughter with the lowest cell_label.
+
+    TODO We need to fix examples with more than one daughter at a time point.
+
+    Parameters
+    ----------
+    tracks: pd.DataFrame
+        Data for all daughters, which are distinguished by different cell_labels,
+        for a particular trap and mother_label.
     """
-    sorted_da_ids = tracks.sort_index(level="cell_label")
-    sorted_da_ids.index = range(len(sorted_da_ids))
-    tp_fvt = sorted_da_ids.apply(lambda x: x.first_valid_index(), axis=0)
-    tp_fvt = sorted_da_ids.columns.get_indexer(tp_fvt)
-    tp_fvt[tp_fvt < 0] = len(sorted_da_ids) - 1
-
-    _metric = np.choose(tp_fvt, sorted_da_ids.values)
-    return pd.Series(_metric, index=tracks.columns)
+    # sort by daughter IDs
+    bud_df = tracks.sort_index(level="cell_label")
+    # remove multi-index
+    bud_df.index = range(len(bud_df))
+    # find which row of bud_df has the daughter for each time point
+    tp_fvt: pd.Series = bud_df.apply(lambda x: x.first_valid_index(), axis=0)
+    # combine data for all daughters
+    combined_tracks = np.nan * np.ones(tracks.columns.size)
+    for bud_row in np.unique(tp_fvt.dropna().values).astype(int):
+        ilocs = np.where(tp_fvt.values == bud_row)[0]
+        combined_tracks[ilocs] = bud_df.values[bud_row, ilocs]
+    # TODO delete old version
+    tp_fvt = bud_df.columns.get_indexer(tp_fvt)
+    tp_fvt[tp_fvt == -1] = len(bud_df) - 1
+    old = np.choose(tp_fvt, bud_df.values)
+    assert (
+        (combined_tracks == old) | (np.isnan(combined_tracks) & np.isnan(old))
+    ).all(), "yikes"
+    return pd.Series(combined_tracks, index=tracks.columns)
-- 
GitLab