From 7917862b9903f28dca6f1df4fb3da1803e69560e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Thu, 16 Mar 2023 17:02:14 +0000
Subject: [PATCH] [WIP]: refactor merging for test

---
 src/agora/io/signal.py                     |   6 +-
 src/agora/utils/association.py             | 120 --------------
 src/agora/utils/merge.py                   |   2 +-
 src/postprocessor/core/processor.py        | 181 ++++++++++++---------
 src/postprocessor/core/reshapers/picker.py |   2 +-
 5 files changed, 106 insertions(+), 205 deletions(-)

diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index 6f7ea3e4..08e4c944 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -11,7 +11,7 @@ import pandas as pd
 
 from agora.io.bridge import BridgeH5
 from agora.io.decorators import _first_arg_str_to_df
-from agora.utils.association import validate_association
+from agora.utils.indexing import validate_association
 from agora.utils.kymograph import add_index_levels
 from agora.utils.merge import apply_merges
 
@@ -171,7 +171,7 @@ class Signal(BridgeH5):
 
         """
         if isinstance(merges, bool):
-            merges: np.ndarray = self.get_merges() if merges else np.array([])
+            merges: np.ndarray = self.load_merges() if merges else np.array([])
         if merges.any():
             merged = apply_merges(data, merges)
         else:
@@ -292,7 +292,7 @@ class Signal(BridgeH5):
             self._log(f"Could not fetch dataset {dataset}: {e}", "error")
             raise e
 
-    def get_merges(self):
+    def load_merges(self):
         """Get merge events going up to the first level."""
         with h5py.File(self.filename, "r") as f:
             merges = f.get("modifiers/merges", np.array([]))
diff --git a/src/agora/utils/association.py b/src/agora/utils/association.py
index d523e427..a051bb08 100644
--- a/src/agora/utils/association.py
+++ b/src/agora/utils/association.py
@@ -1,121 +1 @@
 #!/usr/bin/env jupyter
-"""
-Utilities based on association are used to efficiently acquire indices of tracklets with some kind of relationship.
-This can be:
-    - Cells that are to be merged
-    - Cells that have a linear relationship
-"""
-
-import numpy as np
-import typing as t
-
-
-def validate_association(
-    association: np.ndarray,
-    indices: np.ndarray,
-    match_column: t.Optional[int] = None,
-) -> t.Tuple[np.ndarray, np.ndarray]:
-
-    """Select rows from the first array that are present in both.
-        We use casting for fast multiindexing, generalising for lineage dynamics
-
-
-        Parameters
-        ----------
-        association : np.ndarray
-            2-D array where columns are (trap, mother, daughter) or 3-D array where
-            dimensions are (X,trap,2), containing tuples ((trap,mother), (trap,daughter))
-            across the 3rd dimension.
-        indices : np.ndarray
-            2-D array where each column is a different level. This should not include mother_label.
-        match_column: int
-            int indicating a specific column is required to match (i.e.
-            0-1 for target-source when trying to merge tracklets or mother-bud for lineage)
-            must be present in indices. If it is false one match suffices for the resultant indices
-            vector to be True.
-
-        Returns
-        -------
-        np.ndarray
-            1-D boolean array indicating valid merge events.
-        np.ndarray
-            1-D boolean array indicating indices with an association relationship.
-
-        Examples
-        --------
-
-        >>> import numpy as np
-        >>> from agora.utils.association import validate_association
-        >>> merges = np.array(range(12)).reshape(3,2,2)
-        >>> indices = np.array(range(6)).reshape(3,2)
-
-        >>> print(merges, indices)
-        >>> print(merges); print(indices)
-        [[[ 0  1]
-          [ 2  3]]
-
-         [[ 4  5]
-          [ 6  7]]
-
-         [[ 8  9]
-          [10 11]]]
-
-        [[0 1]
-         [2 3]
-         [4 5]]
-
-        >>> valid_associations, valid_indices  = validate_association(merges, indices)
-        >>> print(valid_associations, valid_indices)
-    [ True False False] [ True  True False]
-
-    """
-    if association.ndim == 2:
-        # Reshape into 3-D array for broadcasting if neded
-        # association = np.stack(
-        #     (association[:, [0, 1]], association[:, [0, 2]]), axis=1
-        # )
-        association = last_col_as_rows(association)
-
-    # Compare existing association with available indices
-    # Swap trap and label axes for the association array to correctly cast
-    valid_ndassociation = association[..., None] == indices.T[None, ...]
-
-    # Broadcasting is confusing (but efficient):
-    # First we check the dimension across trap and cell id, to ensure both match
-    valid_cell_ids = valid_ndassociation.all(axis=2)
-
-    if match_column is None:
-        # Then we check the merge tuples to check which cases have both target and source
-        valid_association = valid_cell_ids.any(axis=2).all(axis=1)
-
-        # Finally we check the dimension that crosses all indices, to ensure the pair
-        # is present in a valid merge event.
-        valid_indices = (
-            valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1))
-        )
-    else:  # We fetch specific indices if we aim for the ones with one present
-        valid_indices = valid_cell_ids[:, match_column].any(axis=0)
-        # Valid association then becomes a boolean array, true means that there is a
-        # match (match_column) between that cell and the index
-        valid_association = (
-            valid_cell_ids[:, match_column] & valid_indices
-        ).any(axis=1)
-
-    return valid_association, valid_indices
-
-
-def last_col_as_rows(ndarray: np.ndarray):
-    """
-    Convert the last column to a new row while repeating all previous indices.
-
-    This is useful when converting a signal multiindex before comparing association.
-    """
-    columns = np.arange(ndarray.shape[1])
-
-    return np.stack(
-        (
-            ndarray[:, np.delete(columns, -1)],
-            ndarray[:, np.delete(columns, -2)],
-        ),
-        axis=1,
-    )
diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py
index aec59a60..8f3aee4e 100644
--- a/src/agora/utils/merge.py
+++ b/src/agora/utils/merge.py
@@ -9,7 +9,7 @@ import numpy as np
 import pandas as pd
 from utils_find_1st import cmp_larger, find_1st
 
-from agora.utils.association import validate_association
+from agora.utils.indexing import validate_association
 
 
 def apply_merges(data: pd.DataFrame, merges: np.ndarray):
diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index 381729f8..e6ff00db 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -1,3 +1,4 @@
+import typing as t
 from itertools import takewhile
 from typing import Dict, List, Union
 
@@ -10,6 +11,11 @@ from agora.abc import ParametersABC, ProcessABC
 from agora.io.cells import Cells
 from agora.io.signal import Signal
 from agora.io.writer import Writer
+from agora.utils.indexing import (
+    _assoc_indices_to_3d,
+    validate_association,
+)
+from agora.utils.kymograph import get_index_as_np
 from postprocessor.core.abc import get_parameters, get_process
 from postprocessor.core.lineageprocess import LineageProcessParameters
 from postprocessor.core.reshapers.merger import Merger, MergerParameters
@@ -146,53 +152,14 @@ class PostProcessor(ProcessABC):
     def run_prepost(self):
         # TODO Split function
         """Important processes run before normal post-processing ones"""
+        record = self._signal.get_raw(self.targets["prepost"]["merger"])
+        merge_events = self.merger.run(record)
 
-        merge_events = self.merger.run(
-            self._signal[self.targets["prepost"]["merger"]]
-        )
-
-        prev_idchanges = self._signal.get_merges()
-
-        changes_history = list(prev_idchanges) + [
-            np.array(x) for x in merge_events
-        ]
-        self._writer.write("modifiers/merges", data=changes_history)
-
-        # TODO Remove this once test is wriiten for consecutive postprocesses
-        with h5py.File(self._filename, "a") as f:
-            if "modifiers/picks" in f:
-                del f["modifiers/picks"]
-
-        indices = self.picker.run(
-            self._signal[self.targets["prepost"]["picker"][0]]
-        )
-
-        combined_idx = ([], [], [])
-        trap, mother, daughter = combined_idx
-
-        lineage = self.picker.cells.mothers_daughters
-
-        if lineage.any():
-            trap, mother, daughter = lineage.T
-            combined_idx = np.vstack((trap, mother, daughter))
-
-        trap_mother = np.vstack((trap, mother)).T
-        trap_daughter = np.vstack((trap, daughter)).T
-
-        multii = pd.MultiIndex.from_arrays(
-            combined_idx,
-            names=["trap", "mother_label", "daughter_label"],
-        )
         self._writer.write(
-            "postprocessing/lineage",
-            data=multii,
-            overwrite="overwrite",
+            "modifiers/merges", data=[np.array(x) for x in merge_events]
         )
 
-        # apply merge to mother-trap_daughter
-        moset = set([tuple(x) for x in trap_mother])
-        daset = set([tuple(x) for x in trap_daughter])
-        picked_set = set([tuple(x) for x in indices])
+        lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters)
 
         with h5py.File(self._filename, "a") as f:
             merge_events = f["modifiers/merges"][()]
@@ -203,31 +170,41 @@ class PostProcessor(ProcessABC):
         )
         self.lineage_merged = multii
 
-        if merge_events.any():
+        indices = get_index_as_np(record)
+        if merge_events.any():  # Update lineages after merge events
+            # We validate merges that associate existing mothers and daughters
+            valid_merges, valid_indices = validate_association(merges, indices)
 
-            def search(a, b):
-                return np.where(
-                    np.in1d(
-                        np.ravel_multi_index(a.T, a.max(0) + 1),
-                        np.ravel_multi_index(b.T, a.max(0) + 1),
-                    )
-                )
+            grouped_merges = group_merges(merges)
+            # Summarise the merges linking the first and final id
+            # Shape (X,2,2)
+            summarised = np.array(
+                [(x[0][0], x[-1][1]) for x in grouped_merges]
+            )
+            # List the indices that will be deleted, as they are in-between
+            # Shape (Y,2)
+            to_delete = np.vstack(
+                [
+                    x.reshape(-1, x.shape[-1])[1:-1]
+                    for x in grouped_merges
+                    if len(x) > 1
+                ]
+            )
 
-            for target, source in merge_events:
-                if (
-                    tuple(source) in moset
-                ):  # update mother to lowest positive index among the two
-                    mother_ids = search(trap_mother, source)
-                    trap_mother[mother_ids] = (
-                        target[0],
-                        self.pick_mother(
-                            trap_mother[mother_ids][0][1], target[1]
-                        ),
-                    )
-                if tuple(source) in daset:
-                    trap_daughter[search(trap_daughter, source)] = target
-                if tuple(source) in picked_set:
-                    indices[search(indices, source)] = target
+            flat_indices = lineage.reshape(-1, 2)
+            valid_merges, valid_indices = validate_association(
+                summarised, flat_indices
+            )
+            # Replace
+            id_eq_matrix = compare_indices(flat_indices, to_delete)
+
+            # Update labels of merged tracklets
+            flat_indices[valid_indices] = summarised[valid_merges, 1]
+
+            # Remove labels that will be removed when merging
+            flat_indices = flat_indices[id_eq_matrix.any(axis=1)]
+
+            lineage_merged = flat_indices.reshape(-1, 2)
 
             self.lineage_merged = pd.MultiIndex.from_arrays(
                 np.unique(
@@ -240,21 +217,30 @@ class PostProcessor(ProcessABC):
                 ).T,
                 names=["trap", "mother_label", "daughter_label"],
             )
-        self._writer.write(
-            "postprocessing/lineage_merged",
-            data=self.lineage_merged,
-            overwrite="overwrite",
-        )
 
-        self._writer.write(
-            "modifiers/picks",
-            data=pd.MultiIndex.from_arrays(
-                # TODO Check if multiindices are still repeated
-                np.unique(indices, axis=0).T if indices.any() else [[], []],
-                names=["trap", "cell_label"],
-            ),
-            overwrite="overwrite",
-        )
+        # Remove after implementing outside
+        # self._writer.write(
+        #     "modifiers/picks",
+        #     data=pd.MultiIndex.from_arrays(
+        #         # TODO Check if multiindices are still repeated
+        #         np.unique(indices, axis=0).T if indices.any() else [[], []],
+        #         names=["trap", "cell_label"],
+        #     ),
+        #     overwrite="overwrite",
+        # )
+
+        # combined_idx = ([], [], [])
+
+        # multii = pd.MultiIndex.from_arrays(
+        #     combined_idx,
+        #     names=["trap", "mother_label", "daughter_label"],
+        # )
+        # self._writer.write(
+        #     "postprocessing/lineage",
+        #     data=multii,
+        #     # TODO check if overwrite is still needed
+        #     overwrite="overwrite",
+        # )
 
     @staticmethod
     def pick_mother(a, b):
@@ -357,3 +343,38 @@ class PostProcessor(ProcessABC):
         metadata: Dict,
     ):
         self._writer.write(path, result, meta=metadata, overwrite="overwrite")
+
+
+def union_find(lsts):  # merge overlapping groups into disjoint sets
+    sets = [set(lst) for lst in lsts if lst]  # empty groups are dropped
+    merged = True
+    while merged:  # repeat full passes until nothing was joined
+        merged = False
+        results = []
+        while sets:
+            common, rest = sets[0], sets[1:]
+            sets = []
+            for x in rest:
+                if x.isdisjoint(common):
+                    sets.append(x)  # no shared element: keep for later
+                else:
+                    merged = True
+                    common |= x  # shared element: absorb x into common
+            results.append(common)
+        sets = results
+    return sets
+
+
+def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
+    # Group merge events: events linked through compare_indices matches are
+    # bundled together (multimerges); every other event is its own group.
+    # NOTE(review): compare_indices is not imported by this patch - confirm.
+    sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
+    is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
+    is_monomerge = ~is_multimerge
+    # NOTE(review): set iteration order decides the row order below - confirm.
+    multimerge_subsets = union_find(list(zip(*np.where(sources_targets))))
+    return [
+        *[merges[np.array(tuple(x))] for x in multimerge_subsets],
+        *[[event] for event in merges[is_monomerge]],
+    ]
diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py
index 84e1c21f..ca666bc4 100644
--- a/src/postprocessor/core/reshapers/picker.py
+++ b/src/postprocessor/core/reshapers/picker.py
@@ -5,7 +5,7 @@ import pandas as pd
 
 from agora.abc import ParametersABC
 from agora.io.cells import Cells
-from agora.utils.association import validate_association
+from agora.utils.indexing import validate_association
 from agora.utils.cast import _str_to_int
 from agora.utils.kymograph import drop_mother_label
 from postprocessor.core.lineageprocess import LineageProcess
-- 
GitLab