Skip to content
Snippets Groups Projects
Commit e38b5a3d authored by Alán Muñoz's avatar Alán Muñoz
Browse files

refactor(processor): isolate merge_association

parent d345aa21
No related branches found
No related tags found
No related merge requests found
......@@ -14,8 +14,8 @@ from agora.io.writer import Writer
from agora.utils.indexing import (
_3d_index_to_2d,
_assoc_indices_to_3d,
compare_indices,
)
from agora.utils.merge import merge_association
from agora.utils.kymograph import get_index_as_np
from postprocessor.core.abc import get_parameters, get_process
from postprocessor.core.lineageprocess import (
......@@ -168,24 +168,9 @@ class PostProcessor(ProcessABC):
if merges.any(): # Update lineages after merge events
grouped_merges = group_merges(merges)
flat_indices = lineage.reshape(-1, 2)
comparison_mat = compare_indices(merges[:, 0], flat_indices)
valid_indices = comparison_mat.any(axis=0)
replacement_d = {}
for dataset in grouped_merges:
for k in dataset:
replacement_d[tuple(k[0])] = dataset[-1][1]
flat_indices[valid_indices] = [
replacement_d[tuple(i)] for i in flat_indices[valid_indices]
]
merged_indices = merge_association(lineage, merges)
# Remove repeated labels post-merging
lineage_merged = np.unique(flat_indices.reshape(-1, 2, 2), axis=0)
lineage_merged = np.unique(merged_indices, axis=0)
self.lineage = _3d_index_to_2d(
lineage_merged if len(lineage_merged) else lineage
......@@ -307,56 +292,3 @@ class PostProcessor(ProcessABC):
if not result.any().any():
logging.getLogger("aliby").warning(f"Record {path} is empty")
self._writer.write(path, result, meta=metadata, overwrite="overwrite")
def union_find(lsts):
sets = [set(lst) for lst in lsts if lst]
merged = True
while merged:
merged = False
results = []
while sets:
common, rest = sets[0], sets[1:]
sets = []
for x in rest:
if x.isdisjoint(common):
sets.append(x)
else:
merged = True
common |= x
results.append(common)
sets = results
return sets
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
# Return a list where the cell is present as source and target
# (multimerges)
sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
is_monomerge = ~is_multimerge
multimerge_subsets = union_find(zip(*np.where(sources_targets)))
merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
sorted_merges = list(map(sort_association, merge_groups))
# Ensure that source and target are at the edges
return [
*sorted_merges,
*[[event] for event in merges[is_monomerge]],
]
def sort_association(array: np.ndarray):
# Sort the internal associations
order = np.where(
(array[:, 0, ..., None] == array[:, 1].T[None, ...]).all(axis=1)
)
res = []
[res.append(x) for x in np.flip(order).flatten() if x not in res]
sorted_array = array[np.array(res)]
return sorted_array
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment