Skip to content
Snippets Groups Projects
Commit 7917862b authored by Alán Muñoz's avatar Alán Muñoz
Browse files

[WIP]: refactor merging for test

parent 802396af
No related branches found
No related tags found
No related merge requests found
......@@ -11,7 +11,7 @@ import pandas as pd
from agora.io.bridge import BridgeH5
from agora.io.decorators import _first_arg_str_to_df
from agora.utils.association import validate_association
from agora.utils.indexing import validate_association
from agora.utils.kymograph import add_index_levels
from agora.utils.merge import apply_merges
......@@ -171,7 +171,7 @@ class Signal(BridgeH5):
"""
if isinstance(merges, bool):
merges: np.ndarray = self.get_merges() if merges else np.array([])
merges: np.ndarray = self.load_merges() if merges else np.array([])
if merges.any():
merged = apply_merges(data, merges)
else:
......@@ -292,7 +292,7 @@ class Signal(BridgeH5):
self._log(f"Could not fetch dataset {dataset}: {e}", "error")
raise e
def get_merges(self):
def load_merges(self):
"""Get merge events going up to the first level."""
with h5py.File(self.filename, "r") as f:
merges = f.get("modifiers/merges", np.array([]))
......
#!/usr/bin/env jupyter
"""
Utilities based on association are used to efficiently acquire indices of tracklets with some kind of relationship.
This can be:
- Cells that are to be merged
- Cells that have a linear relationship
"""
import numpy as np
import typing as t
def validate_association(
    association: np.ndarray,
    indices: np.ndarray,
    match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
    """Select rows from the first array that are present in both.

    We use broadcasting for fast multiindexing, generalising for lineage
    dynamics.

    Parameters
    ----------
    association : np.ndarray
        2-D array where columns are (trap, mother, daughter) or 3-D array where
        dimensions are (X, trap, 2), containing tuples ((trap, mother),
        (trap, daughter)) across the 3rd dimension.
    indices : np.ndarray
        2-D array where each column is a different level. This should not
        include mother_label.
    match_column : int, optional
        Index of the pair member (0-1 for target-source when merging
        tracklets, or mother-bud for lineage) that must be present in
        `indices`. If None, one match suffices for the resultant indices
        vector to be True.

    Returns
    -------
    np.ndarray
        1-D boolean array indicating valid merge events.
    np.ndarray
        1-D boolean array indicating indices with an association relationship.

    Examples
    --------
    >>> import numpy as np
    >>> from agora.utils.indexing import validate_association
    >>> merges = np.arange(12).reshape(3, 2, 2)
    >>> indices = np.arange(6).reshape(3, 2)
    >>> print(merges); print(indices)
    [[[ 0  1]
      [ 2  3]]
     [[ 4  5]
      [ 6  7]]
     [[ 8  9]
      [10 11]]]
    [[0 1]
     [2 3]
     [4 5]]
    >>> valid_associations, valid_indices = validate_association(merges, indices)
    >>> print(valid_associations, valid_indices)
    [ True False False] [ True  True False]
    """
    if association.ndim == 2:
        # Reshape into 3-D array for broadcasting if needed
        association = last_col_as_rows(association)
    # Compare existing association with available indices.
    # Swap trap and label axes for the association array to correctly cast.
    # Resulting shape: (n_associations, 2, n_levels, n_indices)
    valid_ndassociation = association[..., None] == indices.T[None, ...]
    # Collapse the level axis so both trap and cell id must match
    valid_cell_ids = valid_ndassociation.all(axis=2)
    if match_column is None:
        # An association is valid when both pair members match some index row
        valid_association = valid_cell_ids.any(axis=2).all(axis=1)
        # An index is valid when it appears in at least one valid association
        valid_indices = (
            valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1))
        )
    else:  # Only the requested pair member must be present in indices
        valid_indices = valid_cell_ids[:, match_column].any(axis=0)
        # True where the chosen member of the association matches a valid index
        valid_association = (
            valid_cell_ids[:, match_column] & valid_indices
        ).any(axis=1)
    return valid_association, valid_indices
def last_col_as_rows(ndarray: np.ndarray):
    """
    Convert the last column to a new row while repeating all previous indices.

    This is useful when converting a signal multiindex before comparing
    association.
    """
    col_ids = np.arange(ndarray.shape[1])
    # Pair each row's leading columns with, in turn, its penultimate
    # and its last column
    without_last = ndarray[:, np.delete(col_ids, -1)]
    without_penultimate = ndarray[:, np.delete(col_ids, -2)]
    return np.stack((without_last, without_penultimate), axis=1)
......@@ -9,7 +9,7 @@ import numpy as np
import pandas as pd
from utils_find_1st import cmp_larger, find_1st
from agora.utils.association import validate_association
from agora.utils.indexing import validate_association
def apply_merges(data: pd.DataFrame, merges: np.ndarray):
......
import typing as t
from itertools import takewhile
from typing import Dict, List, Union
......@@ -10,6 +11,11 @@ from agora.abc import ParametersABC, ProcessABC
from agora.io.cells import Cells
from agora.io.signal import Signal
from agora.io.writer import Writer
from agora.utils.indexing import (
_assoc_indices_to_3d,
validate_association,
)
from agora.utils.kymograph import get_index_as_np
from postprocessor.core.abc import get_parameters, get_process
from postprocessor.core.lineageprocess import LineageProcessParameters
from postprocessor.core.reshapers.merger import Merger, MergerParameters
......@@ -146,53 +152,14 @@ class PostProcessor(ProcessABC):
def run_prepost(self):
# TODO Split function
"""Important processes run before normal post-processing ones"""
record = self._signal.get_raw(self.targets["prepost"]["merger"])
merge_events = self.merger.run(record)
merge_events = self.merger.run(
self._signal[self.targets["prepost"]["merger"]]
)
prev_idchanges = self._signal.get_merges()
changes_history = list(prev_idchanges) + [
np.array(x) for x in merge_events
]
self._writer.write("modifiers/merges", data=changes_history)
# TODO Remove this once test is written for consecutive postprocesses
with h5py.File(self._filename, "a") as f:
if "modifiers/picks" in f:
del f["modifiers/picks"]
indices = self.picker.run(
self._signal[self.targets["prepost"]["picker"][0]]
)
combined_idx = ([], [], [])
trap, mother, daughter = combined_idx
lineage = self.picker.cells.mothers_daughters
if lineage.any():
trap, mother, daughter = lineage.T
combined_idx = np.vstack((trap, mother, daughter))
trap_mother = np.vstack((trap, mother)).T
trap_daughter = np.vstack((trap, daughter)).T
multii = pd.MultiIndex.from_arrays(
combined_idx,
names=["trap", "mother_label", "daughter_label"],
)
self._writer.write(
"postprocessing/lineage",
data=multii,
overwrite="overwrite",
"modifiers/merges", data=[np.array(x) for x in merge_events]
)
# apply merge to mother-trap_daughter
moset = set([tuple(x) for x in trap_mother])
daset = set([tuple(x) for x in trap_daughter])
picked_set = set([tuple(x) for x in indices])
lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters)
with h5py.File(self._filename, "a") as f:
merge_events = f["modifiers/merges"][()]
......@@ -203,31 +170,41 @@ class PostProcessor(ProcessABC):
)
self.lineage_merged = multii
if merge_events.any():
indices = get_index_as_np(record)
if merge_events.any(): # Update lineages after merge events
# We validate merges that associate existing mothers and daughters
valid_merges, valid_indices = validate_association(merges, indices)
def search(a, b):
return np.where(
np.in1d(
np.ravel_multi_index(a.T, a.max(0) + 1),
np.ravel_multi_index(b.T, a.max(0) + 1),
)
)
grouped_merges = group_merges(merges)
# Summarise the merges linking the first and final id
# Shape (X,2,2)
summarised = np.array(
[(x[0][0], x[-1][1]) for x in grouped_merges]
)
# List the indices that will be deleted, as they are in-between
# Shape (Y,2)
to_delete = np.vstack(
[
x.reshape(-1, x.shape[-1])[1:-1]
for x in grouped_merges
if len(x) > 1
]
)
for target, source in merge_events:
if (
tuple(source) in moset
): # update mother to lowest positive index among the two
mother_ids = search(trap_mother, source)
trap_mother[mother_ids] = (
target[0],
self.pick_mother(
trap_mother[mother_ids][0][1], target[1]
),
)
if tuple(source) in daset:
trap_daughter[search(trap_daughter, source)] = target
if tuple(source) in picked_set:
indices[search(indices, source)] = target
flat_indices = lineage.reshape(-1, 2)
valid_merges, valid_indices = validate_association(
summarised, flat_indices
)
# Replace
id_eq_matrix = compare_indices(flat_indices, to_delete)
# Update labels of merged tracklets
flat_indices[valid_indices] = summarised[valid_merges, 1]
# Remove labels that will be removed when merging
flat_indices = flat_indices[id_eq_matrix.any(axis=1)]
lineage_merged = flat_indices.reshape(-1, 2)
self.lineage_merged = pd.MultiIndex.from_arrays(
np.unique(
......@@ -240,21 +217,30 @@ class PostProcessor(ProcessABC):
).T,
names=["trap", "mother_label", "daughter_label"],
)
self._writer.write(
"postprocessing/lineage_merged",
data=self.lineage_merged,
overwrite="overwrite",
)
self._writer.write(
"modifiers/picks",
data=pd.MultiIndex.from_arrays(
# TODO Check if multiindices are still repeated
np.unique(indices, axis=0).T if indices.any() else [[], []],
names=["trap", "cell_label"],
),
overwrite="overwrite",
)
# Remove after implementing outside
# self._writer.write(
# "modifiers/picks",
# data=pd.MultiIndex.from_arrays(
# # TODO Check if multiindices are still repeated
# np.unique(indices, axis=0).T if indices.any() else [[], []],
# names=["trap", "cell_label"],
# ),
# overwrite="overwrite",
# )
# combined_idx = ([], [], [])
# multii = pd.MultiIndex.from_arrays(
# combined_idx,
# names=["trap", "mother_label", "daughter_label"],
# )
# self._writer.write(
# "postprocessing/lineage",
# data=multii,
# # TODO check if overwrite is still needed
# overwrite="overwrite",
# )
@staticmethod
def pick_mother(a, b):
......@@ -357,3 +343,38 @@ class PostProcessor(ProcessABC):
metadata: Dict,
):
self._writer.write(path, result, meta=metadata, overwrite="overwrite")
def union_find(lsts):
    """Merge the input lists into disjoint sets of connected elements.

    Two lists belong to the same set if they share at least one element
    (transitively). Empty lists are ignored.
    """
    pools = [set(lst) for lst in lsts if lst]
    changed = True
    while changed:
        changed = False
        combined = []
        while pools:
            current, remainder = pools[0], pools[1:]
            pools = []
            for candidate in remainder:
                if candidate.isdisjoint(current):
                    # No overlap: keep it for the next sweep
                    pools.append(candidate)
                else:
                    # Overlap found: absorb and flag another pass
                    changed = True
                    current |= candidate
            combined.append(current)
        pools = combined
    return pools
def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
    """Group merge events into chains where a cell appears as both
    source and target (multimerges); remaining events are returned as
    single-event groups.
    """
    # Cross-compare targets (column 0) against sources (column 1)
    sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
    # An event belongs to a chain if its cell shows up on either side
    is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
    # Connect linked event pairs into transitive chains
    linked_pairs = list(zip(*np.where(sources_targets)))
    chains = union_find(linked_pairs)
    grouped = [merges[np.array(tuple(chain))] for chain in chains]
    singles = [[event] for event in merges[~is_multimerge]]
    return grouped + singles
......@@ -5,7 +5,7 @@ import pandas as pd
from agora.abc import ParametersABC
from agora.io.cells import Cells
from agora.utils.association import validate_association
from agora.utils.indexing import validate_association
from agora.utils.cast import _str_to_int
from agora.utils.kymograph import drop_mother_label
from postprocessor.core.lineageprocess import LineageProcess
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment