diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py index 6f92799ec38b3928daa27e6f2ebfde394fa7f2b2..d4cd0367a7b1d0a9d29c65a1c3ea90263b1af0a1 100644 --- a/src/postprocessor/core/processor.py +++ b/src/postprocessor/core/processor.py @@ -14,8 +14,8 @@ from agora.io.writer import Writer from agora.utils.indexing import ( _3d_index_to_2d, _assoc_indices_to_3d, - compare_indices, ) +from agora.utils.merge import merge_association from agora.utils.kymograph import get_index_as_np from postprocessor.core.abc import get_parameters, get_process from postprocessor.core.lineageprocess import ( @@ -168,24 +168,9 @@ class PostProcessor(ProcessABC): if merges.any(): # Update lineages after merge events - grouped_merges = group_merges(merges) - - flat_indices = lineage.reshape(-1, 2) - comparison_mat = compare_indices(merges[:, 0], flat_indices) - - valid_indices = comparison_mat.any(axis=0) - - replacement_d = {} - for dataset in grouped_merges: - for k in dataset: - replacement_d[tuple(k[0])] = dataset[-1][1] - - flat_indices[valid_indices] = [ - replacement_d[tuple(i)] for i in flat_indices[valid_indices] - ] - + merged_indices = merge_association(lineage, merges) # Remove repeated labels post-merging - lineage_merged = np.unique(flat_indices.reshape(-1, 2, 2), axis=0) + lineage_merged = np.unique(merged_indices, axis=0) self.lineage = _3d_index_to_2d( lineage_merged if len(lineage_merged) else lineage @@ -307,56 +292,3 @@ class PostProcessor(ProcessABC): if not result.any().any(): logging.getLogger("aliby").warning(f"Record {path} is empty") self._writer.write(path, result, meta=metadata, overwrite="overwrite") - - -def union_find(lsts): - sets = [set(lst) for lst in lsts if lst] - merged = True - while merged: - merged = False - results = [] - while sets: - common, rest = sets[0], sets[1:] - sets = [] - for x in rest: - if x.isdisjoint(common): - sets.append(x) - else: - merged = True - common |= x - results.append(common) - sets = results - return sets - - -def group_merges(merges: np.ndarray) -> t.List[t.Tuple]: - # Return a list where the cell is present as source and target - # (multimerges) - - sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :]) - is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1) - is_monomerge = ~is_multimerge - - multimerge_subsets = union_find(zip(*np.where(sources_targets))) - merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets] - - sorted_merges = list(map(sort_association, merge_groups)) - - # Ensure that source and target are at the edges - return [ - *sorted_merges, - *[[event] for event in merges[is_monomerge]], - ] - - -def sort_association(array: np.ndarray): - # Sort the internal associations - - order = np.where( - (array[:, 0, ..., None] == array[:, 1].T[None, ...]).all(axis=1) - ) - - res = [] - [res.append(x) for x in np.flip(order).flatten() if x not in res] - sorted_array = array[np.array(res)] - return sorted_array