Commit df58c878 authored by Alán Muñoz

refactor(postpro): major changes to internals

parent cc1fd0c0
@@ -14,6 +14,7 @@ from agora.io.writer import Writer
from agora.utils.indexing import (
_assoc_indices_to_3d,
validate_association,
compare_indices,
)
from agora.utils.kymograph import get_index_as_np
from postprocessor.core.abc import get_parameters, get_process
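The newly imported compare_indices is used below to match (trap, label) rows between index arrays. As a minimal sketch of the assumed behaviour (the real implementation lives in agora.utils.indexing and may differ), it can be thought of as a pairwise row-equality matrix:

import numpy as np

def compare_indices_sketch(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    # Boolean matrix of shape (len(x), len(y)); entry (i, j) is True
    # when row x[i] equals row y[j] across all columns.
    return (x[:, None, :] == y[None, :, :]).all(axis=-1)

a = np.array([[1, 2], [1, 3]])
b = np.array([[1, 3], [4, 5]])
print(compare_indices_sketch(a, b))
# [[False False]
#  [ True False]]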
@@ -153,94 +154,52 @@ class PostProcessor(ProcessABC):
# TODO Split function
"""Important processes run before normal post-processing ones"""
record = self._signal.get_raw(self.targets["prepost"]["merger"])
merge_events = self.merger.run(record)
merges = np.array(self.merger.run(record), dtype=int)
self._writer.write(
"modifiers/merges", data=[np.array(x) for x in merge_events]
"modifiers/merges", data=[np.array(x) for x in merges]
)
lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters)
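# Illustrative note (assumption, not part of the commit): merges is expected
# to have shape (N, 2, 2), each event pairing a source and a target track as
# ((trap, source_label), (trap, target_label)); lineage is likewise assumed
# to be an (M, 2, 2) array of ((trap, mother_label), (trap, daughter_label))
# pairs produced by _assoc_indices_to_3d, e.g. [[[3, 1], [3, 4]], ...].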
with h5py.File(self._filename, "a") as f:
merge_events = f["modifiers/merges"][()]
multii = pd.MultiIndex(
[[], [], []],
[[], [], []],
names=["trap", "mother_label", "daughter_label"],
)
self.lineage_merged = multii
indices = get_index_as_np(record)
if merge_events.any(): # Update lineages after merge events
# We validate merges that associate existing mothers and daughters
valid_merges, valid_indices = validate_association(merges, indices)
if merges.any(): # Update lineages after merge events
grouped_merges = group_merges(merges)
# Summarise the merges, linking the first and final IDs
# Shape (X,2,2)
summarised = np.array(
[(x[0][0], x[-1][1]) for x in grouped_merges]
)
# List the indices that will be deleted, as they are in-between
# Shape (Y,2)
to_delete = np.vstack(
[
x.reshape(-1, x.shape[-1])[1:-1]
for x in grouped_merges
if len(x) > 1
]
)
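# Worked example (illustrative, not taken from the commit): for a grouped
# chain of merges ((1, 2)->(1, 5)), ((1, 5)->(1, 8)), ((1, 8)->(1, 9)) the
# summarised entry links the first source to the final target, ((1, 2), (1, 9)),
# while the in-between pairs (1, 5) and (1, 8) land in to_delete (possibly
# repeated after the reshape), since their tracklets are absorbed by the merge.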
flat_indices = lineage.reshape(-1, 2)
valid_merges, valid_indices = validate_association(
summarised, flat_indices
)
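# Note (assumed semantics, not part of the commit): validate_association is
# taken to return two boolean masks, valid_merges over the summarised merges
# that can be found among the flattened lineage indices, and valid_indices
# over the rows of flat_indices they correspond to, which is how they are
# used in the assignments below.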
# Replace merged labels and drop the in-between tracklets
id_eq_matrix = compare_indices(flat_indices, to_delete)
# Update labels of merged tracklets
flat_indices[valid_indices] = summarised[valid_merges, 1]
# Remove labels that will be removed when merging
flat_indices = flat_indices[id_eq_matrix.any(axis=1)]
lineage_merged = flat_indices.reshape(-1, 2)
self.lineage_merged = pd.MultiIndex.from_arrays(
np.unique(
np.append(
trap_mother,
trap_daughter[:, 1].reshape(-1, 1),
axis=1,
),
axis=0,
).T,
names=["trap", "mother_label", "daughter_label"],
)
comparison_mat = compare_indices(merges[:, 0], flat_indices)
valid_indices = comparison_mat.any(axis=0)
replacement_d = {}
for dataset in grouped_merges:
for k in dataset:
replacement_d[tuple(k[0])] = dataset[-1][1]
flat_indices[valid_indices] = [
replacement_d[tuple(i)] for i in flat_indices[valid_indices]
]
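# Illustrative note (values assumed): for the chain in the example above,
# replacement_d maps every source in the group to the group's final target,
# roughly {(1, 2): (1, 9), (1, 5): (1, 9), (1, 8): (1, 9)}, so any lineage
# row that refers to a merged-away tracklet is relabelled to the track that
# survives the merge.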
# Remove after implementing outside
# self._writer.write(
# "modifiers/picks",
# data=pd.MultiIndex.from_arrays(
# # TODO Check if multiindices are still repeated
# np.unique(indices, axis=0).T if indices.any() else [[], []],
# names=["trap", "cell_label"],
# ),
# overwrite="overwrite",
# )
# combined_idx = ([], [], [])
# multii = pd.MultiIndex.from_arrays(
# combined_idx,
# names=["trap", "mother_label", "daughter_label"],
# )
# self._writer.write(
# "postprocessing/lineage",
# data=multii,
# # TODO check if overwrite is still needed
# overwrite="overwrite",
# )
# Remove repeated labels post-merging
lineage_merged = np.unique(flat_indices.reshape(-1, 2, 2), axis=0)
self._writer.write("modifiers/lineage_merged", lineage_merged)
picked_indices = self.picker.run(
self._signal[self.targets["prepost"]["picker"][0]]
)
self._writer.write(
"modifiers/picks",
data=pd.MultiIndex.from_arrays(
# TODO Check if multiindices are still repeated
np.unique(picked_indices, axis=0).T
if indices.any()
else [[], []],
names=["trap", "cell_label"],
),
overwrite="overwrite",
)
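# Illustrative note (shapes assumed): picked_indices is expected to be an
# (N, 2) array of (trap, cell_label) pairs, so np.unique(..., axis=0).T
# yields the two level arrays pd.MultiIndex.from_arrays expects, e.g.
# np.unique([[0, 1], [0, 1], [2, 3]], axis=0).T -> [[0, 2], [1, 3]].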
@staticmethod
def pick_mother(a, b):
@@ -373,8 +332,26 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
is_monomerge = ~is_multimerge
multimerge_subsets = union_find(list(zip(*np.where(sources_targets))))
multimerge_subsets = union_find(zip(*np.where(sources_targets)))
merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets]
sorted_merges = list(map(sort_association, merge_groups))
# Ensure that source and target are at the edges
return [
*[merges[np.array(tuple(x))] for x in multimerge_subsets],
*sorted_merges,
*[[event] for event in merges[is_monomerge]],
]
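As an illustrative sketch (names, shapes and the behaviour of union_find are assumptions, not taken from the commit), grouping merge events that share a track is what lets group_merges turn chains of pairwise merges into ordered groups:

import numpy as np

def union_find_sketch(pairs):
    # Merge overlapping pairs of row indices into disjoint sets.
    groups = []
    for a, b in pairs:
        hits = [g for g in groups if a in g or b in g]
        merged = {a, b}.union(*hits)
        groups = [g for g in groups if g not in hits] + [merged]
    return groups

events = np.array([[[1, 2], [1, 5]], [[1, 5], [1, 8]], [[7, 3], [7, 4]]])
# sources_targets[i, j] is True when the source of event i equals the
# target of event j, i.e. event j feeds event i within a chain.
sources_targets = (events[:, 0, None, :] == events[None, :, 1, :]).all(axis=-1)
linked = [(int(i), int(j)) for i, j in zip(*np.where(sources_targets))]
print(union_find_sketch(linked))  # [{0, 1}]: rows 0 and 1 form one chain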
def sort_association(array: np.ndarray) -> np.ndarray:
    # Sort the associations within a multi-merge group so that the chain's
    # first source and final target sit at the edges of the array
    order = np.where(
        (array[:, 0, ..., None] == array[:, 1].T[None, ...]).all(axis=1)
    )
    res = []
    for x in np.flip(order).flatten():
        # Keep only the first occurrence of each row index, preserving order
        if x not in res:
            res.append(x)
    sorted_array = array[np.array(res)]
    return sorted_array
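An illustrative check of the intent (values assumed, not taken from the commit): for a group whose events form the chain (1, 2)->(1, 5)->(1, 8)->(1, 9), the comparison inside sort_association marks which event's source matches which event's target, and the resulting order is meant to leave the chain's first source and final target at the edges, as the comment in group_merges notes:

import numpy as np

# Events of one multi-merge group, deliberately out of chain order.
group = np.array([[[1, 5], [1, 8]],
                  [[1, 2], [1, 5]],
                  [[1, 8], [1, 9]]])
# matches[i, j] is True when the source of row i equals the target of row j,
# i.e. row j feeds row i within the chain.
matches = (group[:, 0, ..., None] == group[:, 1].T[None, ...]).all(axis=1)
print(np.where(matches))  # (array([0, 2]), array([1, 0]))
# Chain order of the rows is therefore 1 -> 0 -> 2, i.e.
# (1, 2)->(1, 5), (1, 5)->(1, 8), (1, 8)->(1, 9).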