From 7917862b9903f28dca6f1df4fb3da1803e69560e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Thu, 16 Mar 2023 17:02:14 +0000 Subject: [PATCH] [WIP]: refactor merging for test --- src/agora/io/signal.py | 6 +- src/agora/utils/association.py | 120 -------------- src/agora/utils/merge.py | 2 +- src/postprocessor/core/processor.py | 181 ++++++++++++--------- src/postprocessor/core/reshapers/picker.py | 2 +- 5 files changed, 106 insertions(+), 205 deletions(-) diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 6f7ea3e4..08e4c944 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -11,7 +11,7 @@ import pandas as pd from agora.io.bridge import BridgeH5 from agora.io.decorators import _first_arg_str_to_df -from agora.utils.association import validate_association +from agora.utils.indexing import validate_association from agora.utils.kymograph import add_index_levels from agora.utils.merge import apply_merges @@ -171,7 +171,7 @@ class Signal(BridgeH5): """ if isinstance(merges, bool): - merges: np.ndarray = self.get_merges() if merges else np.array([]) + merges: np.ndarray = self.load_merges() if merges else np.array([]) if merges.any(): merged = apply_merges(data, merges) else: @@ -292,7 +292,7 @@ class Signal(BridgeH5): self._log(f"Could not fetch dataset {dataset}: {e}", "error") raise e - def get_merges(self): + def load_merges(self): """Get merge events going up to the first level.""" with h5py.File(self.filename, "r") as f: merges = f.get("modifiers/merges", np.array([])) diff --git a/src/agora/utils/association.py b/src/agora/utils/association.py index d523e427..a051bb08 100644 --- a/src/agora/utils/association.py +++ b/src/agora/utils/association.py @@ -1,121 +1 @@ #!/usr/bin/env jupyter -""" -Utilities based on association are used to efficiently acquire indices of tracklets with some kind of relationship. -This can be: - - Cells that are to be merged - - Cells that have a linear relationship -""" - -import numpy as np -import typing as t - - -def validate_association( - association: np.ndarray, - indices: np.ndarray, - match_column: t.Optional[int] = None, -) -> t.Tuple[np.ndarray, np.ndarray]: - - """Select rows from the first array that are present in both. - We use casting for fast multiindexing, generalising for lineage dynamics - - - Parameters - ---------- - association : np.ndarray - 2-D array where columns are (trap, mother, daughter) or 3-D array where - dimensions are (X,trap,2), containing tuples ((trap,mother), (trap,daughter)) - across the 3rd dimension. - indices : np.ndarray - 2-D array where each column is a different level. This should not include mother_label. - match_column: int - int indicating a specific column is required to match (i.e. - 0-1 for target-source when trying to merge tracklets or mother-bud for lineage) - must be present in indices. If it is false one match suffices for the resultant indices - vector to be True. - - Returns - ------- - np.ndarray - 1-D boolean array indicating valid merge events. - np.ndarray - 1-D boolean array indicating indices with an association relationship. - - Examples - -------- - - >>> import numpy as np - >>> from agora.utils.association import validate_association - >>> merges = np.array(range(12)).reshape(3,2,2) - >>> indices = np.array(range(6)).reshape(3,2) - - >>> print(merges, indices) - >>> print(merges); print(indices) - [[[ 0 1] - [ 2 3]] - - [[ 4 5] - [ 6 7]] - - [[ 8 9] - [10 11]]] - - [[0 1] - [2 3] - [4 5]] - - >>> valid_associations, valid_indices = validate_association(merges, indices) - >>> print(valid_associations, valid_indices) - [ True False False] [ True True False] - - """ - if association.ndim == 2: - # Reshape into 3-D array for broadcasting if neded - # association = np.stack( - # (association[:, [0, 1]], association[:, [0, 2]]), axis=1 - # ) - association = last_col_as_rows(association) - - # Compare existing association with available indices - # Swap trap and label axes for the association array to correctly cast - valid_ndassociation = association[..., None] == indices.T[None, ...] - - # Broadcasting is confusing (but efficient): - # First we check the dimension across trap and cell id, to ensure both match - valid_cell_ids = valid_ndassociation.all(axis=2) - - if match_column is None: - # Then we check the merge tuples to check which cases have both target and source - valid_association = valid_cell_ids.any(axis=2).all(axis=1) - - # Finally we check the dimension that crosses all indices, to ensure the pair - # is present in a valid merge event. - valid_indices = ( - valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1)) - ) - else: # We fetch specific indices if we aim for the ones with one present - valid_indices = valid_cell_ids[:, match_column].any(axis=0) - # Valid association then becomes a boolean array, true means that there is a - # match (match_column) between that cell and the index - valid_association = ( - valid_cell_ids[:, match_column] & valid_indices - ).any(axis=1) - - return valid_association, valid_indices - - -def last_col_as_rows(ndarray: np.ndarray): - """ - Convert the last column to a new row while repeating all previous indices. - - This is useful when converting a signal multiindex before comparing association. - """ - columns = np.arange(ndarray.shape[1]) - - return np.stack( - ( - ndarray[:, np.delete(columns, -1)], - ndarray[:, np.delete(columns, -2)], - ), - axis=1, - ) diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py index aec59a60..8f3aee4e 100644 --- a/src/agora/utils/merge.py +++ b/src/agora/utils/merge.py @@ -9,7 +9,7 @@ import numpy as np import pandas as pd from utils_find_1st import cmp_larger, find_1st -from agora.utils.association import validate_association +from agora.utils.indexing import validate_association def apply_merges(data: pd.DataFrame, merges: np.ndarray): diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py index 381729f8..e6ff00db 100644 --- a/src/postprocessor/core/processor.py +++ b/src/postprocessor/core/processor.py @@ -1,3 +1,4 @@ +import typing as t from itertools import takewhile from typing import Dict, List, Union @@ -10,6 +11,11 @@ from agora.abc import ParametersABC, ProcessABC from agora.io.cells import Cells from agora.io.signal import Signal from agora.io.writer import Writer +from agora.utils.indexing import ( + _assoc_indices_to_3d, + validate_association, +) +from agora.utils.kymograph import get_index_as_np from postprocessor.core.abc import get_parameters, get_process from postprocessor.core.lineageprocess import LineageProcessParameters from postprocessor.core.reshapers.merger import Merger, MergerParameters @@ -146,53 +152,14 @@ class PostProcessor(ProcessABC): def run_prepost(self): # TODO Split function """Important processes run before normal post-processing ones""" + record = self._signal.get_raw(self.targets["prepost"]["merger"]) + merge_events = self.merger.run(record) - merge_events = self.merger.run( - self._signal[self.targets["prepost"]["merger"]] - ) - - prev_idchanges = self._signal.get_merges() - - changes_history = list(prev_idchanges) + [ - np.array(x) for x in merge_events - ] - self._writer.write("modifiers/merges", data=changes_history) - - # TODO Remove this once test is wriiten for consecutive postprocesses - with h5py.File(self._filename, "a") as f: - if "modifiers/picks" in f: - del f["modifiers/picks"] - - indices = self.picker.run( - self._signal[self.targets["prepost"]["picker"][0]] - ) - - combined_idx = ([], [], []) - trap, mother, daughter = combined_idx - - lineage = self.picker.cells.mothers_daughters - - if lineage.any(): - trap, mother, daughter = lineage.T - combined_idx = np.vstack((trap, mother, daughter)) - - trap_mother = np.vstack((trap, mother)).T - trap_daughter = np.vstack((trap, daughter)).T - - multii = pd.MultiIndex.from_arrays( - combined_idx, - names=["trap", "mother_label", "daughter_label"], - ) self._writer.write( - "postprocessing/lineage", - data=multii, - overwrite="overwrite", + "modifiers/merges", data=[np.array(x) for x in merge_events] ) - # apply merge to mother-trap_daughter - moset = set([tuple(x) for x in trap_mother]) - daset = set([tuple(x) for x in trap_daughter]) - picked_set = set([tuple(x) for x in indices]) + lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters) with h5py.File(self._filename, "a") as f: merge_events = f["modifiers/merges"][()] @@ -203,31 +170,41 @@ class PostProcessor(ProcessABC): ) self.lineage_merged = multii - if merge_events.any(): + indices = get_index_as_np(record) + if merge_events.any(): # Update lineages after merge events + # We validate merges that associate existing mothers and daughters + valid_merges, valid_indices = validate_association(merges, indices) - def search(a, b): - return np.where( - np.in1d( - np.ravel_multi_index(a.T, a.max(0) + 1), - np.ravel_multi_index(b.T, a.max(0) + 1), - ) - ) + grouped_merges = group_merges(merges) + # Sumarise the merges linking the first and final id + # Shape (X,2,2) + summarised = np.array( + [(x[0][0], x[-1][1]) for x in grouped_merges] + ) + # List the indices that weill be deleted, as they are in-between + # Shape (Y,2) + to_delete = np.vstack( + [ + x.reshape(-1, x.shape[-1])[1:-1] + for x in grouped_merges + if len(x) > 1 + ] + ) - for target, source in merge_events: - if ( - tuple(source) in moset - ): # update mother to lowest positive index among the two - mother_ids = search(trap_mother, source) - trap_mother[mother_ids] = ( - target[0], - self.pick_mother( - trap_mother[mother_ids][0][1], target[1] - ), - ) - if tuple(source) in daset: - trap_daughter[search(trap_daughter, source)] = target - if tuple(source) in picked_set: - indices[search(indices, source)] = target + flat_indices = lineage.reshape(-1, 2) + valid_merges, valid_indices = validate_association( + summarised, flat_indices + ) + # Replace + id_eq_matrix = compare_indices(flat_indices, to_delete) + + # Update labels of merged tracklets + flat_indices[valid_indices] = summarised[valid_merges, 1] + + # Remove labels that will be removed when merging + flat_indices = flat_indices[id_eq_matrix.any(axis=1)] + + lineage_merged = flat_indices.reshape(-1, 2) self.lineage_merged = pd.MultiIndex.from_arrays( np.unique( @@ -240,21 +217,30 @@ class PostProcessor(ProcessABC): ).T, names=["trap", "mother_label", "daughter_label"], ) - self._writer.write( - "postprocessing/lineage_merged", - data=self.lineage_merged, - overwrite="overwrite", - ) - self._writer.write( - "modifiers/picks", - data=pd.MultiIndex.from_arrays( - # TODO Check if multiindices are still repeated - np.unique(indices, axis=0).T if indices.any() else [[], []], - names=["trap", "cell_label"], - ), - overwrite="overwrite", - ) + # Remove after implementing outside + # self._writer.write( + # "modifiers/picks", + # data=pd.MultiIndex.from_arrays( + # # TODO Check if multiindices are still repeated + # np.unique(indices, axis=0).T if indices.any() else [[], []], + # names=["trap", "cell_label"], + # ), + # overwrite="overwrite", + # ) + + # combined_idx = ([], [], []) + + # multii = pd.MultiIndex.from_arrays( + # combined_idx, + # names=["trap", "mother_label", "daughter_label"], + # ) + # self._writer.write( + # "postprocessing/lineage", + # data=multii, + # # TODO check if overwrite is still needed + # overwrite="overwrite", + # ) @staticmethod def pick_mother(a, b): @@ -357,3 +343,38 @@ class PostProcessor(ProcessABC): metadata: Dict, ): self._writer.write(path, result, meta=metadata, overwrite="overwrite") + + +def union_find(lsts): + sets = [set(lst) for lst in lsts if lst] + merged = True + while merged: + merged = False + results = [] + while sets: + common, rest = sets[0], sets[1:] + sets = [] + for x in rest: + if x.isdisjoint(common): + sets.append(x) + else: + merged = True + common |= x + results.append(common) + sets = results + return sets + + +def group_merges(merges: np.ndarray) -> t.List[t.Tuple]: + # Return a list where the cell is present as source and target + # (multimerges) + + sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :]) + is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1) + is_monomerge = ~is_multimerge + + multimerge_subsets = union_find(list(zip(*np.where(sources_targets)))) + return [ + *[merges[np.array(tuple(x))] for x in multimerge_subsets], + *[[event] for event in merges[is_monomerge]], + ] diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index 84e1c21f..ca666bc4 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -5,7 +5,7 @@ import pandas as pd from agora.abc import ParametersABC from agora.io.cells import Cells -from agora.utils.association import validate_association +from agora.utils.indexing import validate_association from agora.utils.cast import _str_to_int from agora.utils.kymograph import drop_mother_label from postprocessor.core.lineageprocess import LineageProcess -- GitLab