Skip to content
Snippets Groups Projects
Commit df58c878 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

refactor(postpro): major changes to internals

parent cc1fd0c0
No related branches found
No related tags found
No related merge requests found
...@@ -14,6 +14,7 @@ from agora.io.writer import Writer ...@@ -14,6 +14,7 @@ from agora.io.writer import Writer
from agora.utils.indexing import ( from agora.utils.indexing import (
_assoc_indices_to_3d, _assoc_indices_to_3d,
validate_association, validate_association,
compare_indices,
) )
from agora.utils.kymograph import get_index_as_np from agora.utils.kymograph import get_index_as_np
from postprocessor.core.abc import get_parameters, get_process from postprocessor.core.abc import get_parameters, get_process
...@@ -153,94 +154,52 @@ class PostProcessor(ProcessABC): ...@@ -153,94 +154,52 @@ class PostProcessor(ProcessABC):
# TODO Split function # TODO Split function
"""Important processes run before normal post-processing ones""" """Important processes run before normal post-processing ones"""
record = self._signal.get_raw(self.targets["prepost"]["merger"]) record = self._signal.get_raw(self.targets["prepost"]["merger"])
merge_events = self.merger.run(record) merges = np.array(self.merger.run(record), dtype=int)
self._writer.write( self._writer.write(
"modifiers/merges", data=[np.array(x) for x in merge_events] "modifiers/merges", data=[np.array(x) for x in merges]
) )
lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters) lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters)
with h5py.File(self._filename, "a") as f:
merge_events = f["modifiers/merges"][()]
multii = pd.MultiIndex(
[[], [], []],
[[], [], []],
names=["trap", "mother_label", "daughter_label"],
)
self.lineage_merged = multii
indices = get_index_as_np(record) indices = get_index_as_np(record)
if merge_events.any(): # Update lineages after merge events if merges.any(): # Update lineages after merge events
# We validate merges that associate existing mothers and daughters
valid_merges, valid_indices = validate_association(merges, indices)
grouped_merges = group_merges(merges) grouped_merges = group_merges(merges)
# Sumarise the merges linking the first and final id
# Shape (X,2,2)
summarised = np.array(
[(x[0][0], x[-1][1]) for x in grouped_merges]
)
# List the indices that weill be deleted, as they are in-between
# Shape (Y,2)
to_delete = np.vstack(
[
x.reshape(-1, x.shape[-1])[1:-1]
for x in grouped_merges
if len(x) > 1
]
)
flat_indices = lineage.reshape(-1, 2) flat_indices = lineage.reshape(-1, 2)
valid_merges, valid_indices = validate_association( comparison_mat = compare_indices(merges[:, 0], flat_indices)
summarised, flat_indices
) valid_indices = comparison_mat.any(axis=0)
# Replace
id_eq_matrix = compare_indices(flat_indices, to_delete) replacement_d = {}
for dataset in grouped_merges:
# Update labels of merged tracklets for k in dataset:
flat_indices[valid_indices] = summarised[valid_merges, 1] replacement_d[tuple(k[0])] = dataset[-1][1]
# Remove labels that will be removed when merging flat_indices[valid_indices] = [
flat_indices = flat_indices[id_eq_matrix.any(axis=1)] replacement_d[tuple(i)] for i in flat_indices[valid_indices]
]
lineage_merged = flat_indices.reshape(-1, 2)
self.lineage_merged = pd.MultiIndex.from_arrays(
np.unique(
np.append(
trap_mother,
trap_daughter[:, 1].reshape(-1, 1),
axis=1,
),
axis=0,
).T,
names=["trap", "mother_label", "daughter_label"],
)
# Remove after implementing outside # Remove repeated labels post-merging
# self._writer.write( lineage_merged = np.unique(flat_indices.reshape(-1, 2, 2), axis=0)
# "modifiers/picks",
# data=pd.MultiIndex.from_arrays( self._writer.write("modifiers/lineage_merged", lineage_merged)
# # TODO Check if multiindices are still repeated
# np.unique(indices, axis=0).T if indices.any() else [[], []], picked_indices = self.picker.run(
# names=["trap", "cell_label"], self._signal[self.targets["prepost"]["picker"][0]]
# ), )
# overwrite="overwrite", self._writer.write(
# ) "modifiers/picks",
data=pd.MultiIndex.from_arrays(
# combined_idx = ([], [], []) # TODO Check if multiindices are still repeated
np.unique(picked_indices, axis=0).T
# multii = pd.MultiIndex.from_arrays( if indices.any()
# combined_idx, else [[], []],
# names=["trap", "mother_label", "daughter_label"], names=["trap", "cell_label"],
# ) ),
# self._writer.write( overwrite="overwrite",
# "postprocessing/lineage", )
# data=multii,
# # TODO check if overwrite is still needed
# overwrite="overwrite",
# )
@staticmethod @staticmethod
def pick_mother(a, b): def pick_mother(a, b):
...@@ -373,8 +332,26 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]: ...@@ -373,8 +332,26 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1) is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
is_monomerge = ~is_multimerge is_monomerge = ~is_multimerge
multimerge_subsets = union_find(list(zip(*np.where(sources_targets)))) multimerge_subsets = union_find(zip(*np.where(sources_targets)))
merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
sorted_merges = list(map(sort_association, merge_groups))
# Ensure that source and target are at the edges
return [ return [
*[merges[np.array(tuple(x))] for x in multimerge_subsets], *sorted_merges,
*[[event] for event in merges[is_monomerge]], *[[event] for event in merges[is_monomerge]],
] ]
def sort_association(array: np.ndarray):
# Sort the internal associations
order = np.where(
(array[:, 0, ..., None] == array[:, 1].T[None, ...]).all(axis=1)
)
res = []
[res.append(x) for x in np.flip(order).flatten() if x not in res]
sorted_array = array[np.array(res)]
return sorted_array
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment