diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py index e6ff00db9f8428c08c3c47320d20e92bc947f9fb..9613832114a473dc129bbe336a5f7f9b1f9ce031 100644 --- a/src/postprocessor/core/processor.py +++ b/src/postprocessor/core/processor.py @@ -14,6 +14,7 @@ from agora.io.writer import Writer from agora.utils.indexing import ( _assoc_indices_to_3d, validate_association, + compare_indices, ) from agora.utils.kymograph import get_index_as_np from postprocessor.core.abc import get_parameters, get_process @@ -153,94 +154,52 @@ class PostProcessor(ProcessABC): # TODO Split function """Important processes run before normal post-processing ones""" record = self._signal.get_raw(self.targets["prepost"]["merger"]) - merge_events = self.merger.run(record) + merges = np.array(self.merger.run(record), dtype=int) self._writer.write( - "modifiers/merges", data=[np.array(x) for x in merge_events] + "modifiers/merges", data=[np.array(x) for x in merges] ) lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters) - with h5py.File(self._filename, "a") as f: - merge_events = f["modifiers/merges"][()] - multii = pd.MultiIndex( - [[], [], []], - [[], [], []], - names=["trap", "mother_label", "daughter_label"], - ) - self.lineage_merged = multii - indices = get_index_as_np(record) - if merge_events.any(): # Update lineages after merge events - # We validate merges that associate existing mothers and daughters - valid_merges, valid_indices = validate_association(merges, indices) + if merges.any(): # Update lineages after merge events grouped_merges = group_merges(merges) - # Sumarise the merges linking the first and final id - # Shape (X,2,2) - summarised = np.array( - [(x[0][0], x[-1][1]) for x in grouped_merges] - ) - # List the indices that weill be deleted, as they are in-between - # Shape (Y,2) - to_delete = np.vstack( - [ - x.reshape(-1, x.shape[-1])[1:-1] - for x in grouped_merges - if len(x) > 1 - ] - ) flat_indices = lineage.reshape(-1, 2) - valid_merges, valid_indices = validate_association( - summarised, flat_indices - ) - # Replace - id_eq_matrix = compare_indices(flat_indices, to_delete) - - # Update labels of merged tracklets - flat_indices[valid_indices] = summarised[valid_merges, 1] - - # Remove labels that will be removed when merging - flat_indices = flat_indices[id_eq_matrix.any(axis=1)] - - lineage_merged = flat_indices.reshape(-1, 2) - - self.lineage_merged = pd.MultiIndex.from_arrays( - np.unique( - np.append( - trap_mother, - trap_daughter[:, 1].reshape(-1, 1), - axis=1, - ), - axis=0, - ).T, - names=["trap", "mother_label", "daughter_label"], - ) + comparison_mat = compare_indices(merges[:, 0], flat_indices) + + valid_indices = comparison_mat.any(axis=0) + + replacement_d = {} + for dataset in grouped_merges: + for k in dataset: + replacement_d[tuple(k[0])] = dataset[-1][1] + + flat_indices[valid_indices] = [ + replacement_d[tuple(i)] for i in flat_indices[valid_indices] + ] - # Remove after implementing outside - # self._writer.write( - # "modifiers/picks", - # data=pd.MultiIndex.from_arrays( - # # TODO Check if multiindices are still repeated - # np.unique(indices, axis=0).T if indices.any() else [[], []], - # names=["trap", "cell_label"], - # ), - # overwrite="overwrite", - # ) - - # combined_idx = ([], [], []) - - # multii = pd.MultiIndex.from_arrays( - # combined_idx, - # names=["trap", "mother_label", "daughter_label"], - # ) - # self._writer.write( - # "postprocessing/lineage", - # data=multii, - # # TODO check if overwrite is still needed - # overwrite="overwrite", - # ) + # Remove repeated labels post-merging + lineage_merged = np.unique(flat_indices.reshape(-1, 2, 2), axis=0) + + self._writer.write("modifiers/lineage_merged", lineage_merged) + + picked_indices = self.picker.run( + self._signal[self.targets["prepost"]["picker"][0]] + ) + self._writer.write( + "modifiers/picks", + data=pd.MultiIndex.from_arrays( + # TODO Check if multiindices are still repeated + np.unique(picked_indices, axis=0).T + if indices.any() + else [[], []], + names=["trap", "cell_label"], + ), + overwrite="overwrite", + ) @staticmethod def pick_mother(a, b): @@ -373,8 +332,26 @@ def group_merges(merges: np.ndarray) -> t.List[t.Tuple]: is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1) is_monomerge = ~is_multimerge - multimerge_subsets = union_find(list(zip(*np.where(sources_targets)))) + multimerge_subsets = union_find(zip(*np.where(sources_targets))) + merge_groups = [merges[np.array(tuple(x))] for x in multimerge_subsets] + + sorted_merges = list(map(sort_association, merge_groups)) + + # Ensure that source and target are at the edges return [ - *[merges[np.array(tuple(x))] for x in multimerge_subsets], + *sorted_merges, *[[event] for event in merges[is_monomerge]], ] + + +def sort_association(array: np.ndarray): + # Sort the internal associations + + order = np.where( + (array[:, 0, ..., None] == array[:, 1].T[None, ...]).all(axis=1) + ) + + res = [] + [res.append(x) for x in np.flip(order).flatten() if x not in res] + sorted_array = array[np.array(res)] + return sorted_array