Commit 7917862b authored by Alán Muñoz

[WIP]: refactor merging for test

parent 802396af
@@ -11,7 +11,7 @@ import pandas as pd
 from agora.io.bridge import BridgeH5
 from agora.io.decorators import _first_arg_str_to_df
-from agora.utils.association import validate_association
+from agora.utils.indexing import validate_association
 from agora.utils.kymograph import add_index_levels
 from agora.utils.merge import apply_merges
@@ -171,7 +171,7 @@ class Signal(BridgeH5):
         """
         if isinstance(merges, bool):
-            merges: np.ndarray = self.get_merges() if merges else np.array([])
+            merges: np.ndarray = self.load_merges() if merges else np.array([])
         if merges.any():
             merged = apply_merges(data, merges)
         else:
@@ -292,7 +292,7 @@ class Signal(BridgeH5):
             self._log(f"Could not fetch dataset {dataset}: {e}", "error")
             raise e

-    def get_merges(self):
+    def load_merges(self):
        """Get merge events going up to the first level."""
        with h5py.File(self.filename, "r") as f:
            merges = f.get("modifiers/merges", np.array([]))
...
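For context, a minimal usage sketch of the renamed method (the file name here is hypothetical; Signal wraps an HDF5 file via BridgeH5):

    signal = Signal("experiment.h5")
    merges = signal.load_merges()  # np.ndarray; empty if no merges were stored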
#!/usr/bin/env jupyter
"""
Utilities based on association are used to efficiently acquire indices of tracklets with some kind of relationship.
This can be:
- Cells that are to be merged
- Cells that have a linear relationship
"""
import numpy as np
import typing as t
def validate_association(
    association: np.ndarray,
    indices: np.ndarray,
    match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
    """Select rows from the first array that are present in both.

    We use broadcasting for fast multi-indexing, generalising for lineage dynamics.

    Parameters
    ----------
    association : np.ndarray
        2-D array where columns are (trap, mother, daughter), or 3-D array of
        shape (X, 2, 2) where each element along the first dimension is a pair
        ((trap, mother), (trap, daughter)).
    indices : np.ndarray
        2-D array where each column is a different level. This should not include mother_label.
    match_column : int, optional
        Indicates that a specific column must match (i.e. 0 or 1 for
        target-source when merging tracklets, or mother-bud for lineage)
        and be present in indices. If None, a single match suffices for the
        corresponding entry of the resulting indices vector to be True.

    Returns
    -------
    np.ndarray
        1-D boolean array indicating valid merge events.
    np.ndarray
        1-D boolean array indicating indices with an association relationship.

    Examples
    --------
    >>> import numpy as np
    >>> from agora.utils.indexing import validate_association
    >>> merges = np.array(range(12)).reshape(3, 2, 2)
    >>> indices = np.array(range(6)).reshape(3, 2)
    >>> print(merges); print(indices)
    [[[ 0  1]
      [ 2  3]]

     [[ 4  5]
      [ 6  7]]

     [[ 8  9]
      [10 11]]]
    [[0 1]
     [2 3]
     [4 5]]
    >>> valid_associations, valid_indices = validate_association(merges, indices)
    >>> print(valid_associations, valid_indices)
    [ True False False] [ True  True False]
    """
    if association.ndim == 2:
        # Reshape into 3-D array for broadcasting if needed
        association = last_col_as_rows(association)

    # Compare existing association with available indices
    # Swap trap and label axes for the association array to correctly broadcast
    valid_ndassociation = association[..., None] == indices.T[None, ...]

    # Broadcasting is confusing (but efficient):
    # First we check the dimension across trap and cell id, to ensure both match
    valid_cell_ids = valid_ndassociation.all(axis=2)

    if match_column is None:
        # Then we check the merge tuples to see which cases have both target and source
        valid_association = valid_cell_ids.any(axis=2).all(axis=1)

        # Finally we check the dimension that crosses all indices, to ensure the pair
        # is present in a valid merge event.
        valid_indices = (
            valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1))
        )
    else:  # Fetch specific indices when only one column needs to match
        valid_indices = valid_cell_ids[:, match_column].any(axis=0)

        # Valid association is then a boolean array; True means that there is a
        # match (in match_column) between that association and the indices
        valid_association = (
            valid_cell_ids[:, match_column] & valid_indices
        ).any(axis=1)

    return valid_association, valid_indices
def last_col_as_rows(ndarray: np.ndarray):
    """
    Convert the last column to a new row while repeating all previous indices.

    This is useful when converting a signal multiindex before comparing association.
    """
    columns = np.arange(ndarray.shape[1])
    return np.stack(
        (
            ndarray[:, np.delete(columns, -1)],
            ndarray[:, np.delete(columns, -2)],
        ),
        axis=1,
    )
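To complement the docstring example, here is a short sketch of the match_column path, reusing the same arrays; requiring a match in column 0 (the target) validates any event whose target row appears in indices. The printed outputs follow from the broadcasting logic above:

    import numpy as np
    from agora.utils.indexing import validate_association

    merges = np.arange(12).reshape(3, 2, 2)
    indices = np.arange(6).reshape(3, 2)
    valid_assoc, valid_idx = validate_association(merges, indices, match_column=0)
    print(valid_assoc)  # [ True  True False]
    print(valid_idx)    # [ True False  True]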
@@ -9,7 +9,7 @@ import numpy as np
 import pandas as pd
 from utils_find_1st import cmp_larger, find_1st

-from agora.utils.association import validate_association
+from agora.utils.indexing import validate_association


 def apply_merges(data: pd.DataFrame, merges: np.ndarray):
...
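Since the body of apply_merges is unchanged and not shown in this diff, only its call convention is sketched here; signal_df stands in for any Signal dataframe, and the merges array follows the ((trap, target), (trap, source)) convention documented in validate_association:

    import numpy as np
    from agora.utils.merge import apply_merges

    # In trap 0, the track of cell 2 is merged into that of cell 1
    merges = np.array([[[0, 1], [0, 2]]])
    merged_df = apply_merges(signal_df, merges)  # signal_df: a Signal dataframe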
+import typing as t
 from itertools import takewhile
 from typing import Dict, List, Union
@@ -10,6 +11,11 @@ from agora.abc import ParametersABC, ProcessABC
 from agora.io.cells import Cells
 from agora.io.signal import Signal
 from agora.io.writer import Writer
+from agora.utils.indexing import (
+    _assoc_indices_to_3d,
+    validate_association,
+)
+from agora.utils.kymograph import get_index_as_np
 from postprocessor.core.abc import get_parameters, get_process
 from postprocessor.core.lineageprocess import LineageProcessParameters
 from postprocessor.core.reshapers.merger import Merger, MergerParameters
@@ -146,53 +152,14 @@ class PostProcessor(ProcessABC):
     def run_prepost(self):
         # TODO Split function
         """Important processes run before normal post-processing ones"""
-        merge_events = self.merger.run(
-            self._signal[self.targets["prepost"]["merger"]]
-        )
-        prev_idchanges = self._signal.get_merges()
-        changes_history = list(prev_idchanges) + [
-            np.array(x) for x in merge_events
-        ]
-        self._writer.write("modifiers/merges", data=changes_history)
-        # TODO Remove this once test is written for consecutive postprocesses
-        with h5py.File(self._filename, "a") as f:
-            if "modifiers/picks" in f:
-                del f["modifiers/picks"]
-        indices = self.picker.run(
-            self._signal[self.targets["prepost"]["picker"][0]]
-        )
-        combined_idx = ([], [], [])
-        trap, mother, daughter = combined_idx
-        lineage = self.picker.cells.mothers_daughters
-        if lineage.any():
-            trap, mother, daughter = lineage.T
-            combined_idx = np.vstack((trap, mother, daughter))
-        trap_mother = np.vstack((trap, mother)).T
-        trap_daughter = np.vstack((trap, daughter)).T
-        multii = pd.MultiIndex.from_arrays(
-            combined_idx,
-            names=["trap", "mother_label", "daughter_label"],
-        )
-        self._writer.write(
-            "postprocessing/lineage",
-            data=multii,
-            overwrite="overwrite",
-        )
-        # apply merge to mother-trap_daughter
-        moset = set([tuple(x) for x in trap_mother])
-        daset = set([tuple(x) for x in trap_daughter])
-        picked_set = set([tuple(x) for x in indices])
+        record = self._signal.get_raw(self.targets["prepost"]["merger"])
+        merge_events = self.merger.run(record)
+        self._writer.write(
+            "modifiers/merges", data=[np.array(x) for x in merge_events]
+        )
+        lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters)
         with h5py.File(self._filename, "a") as f:
             merge_events = f["modifiers/merges"][()]
@@ -203,31 +170,41 @@ class PostProcessor(ProcessABC):
         )
         self.lineage_merged = multii

-        if merge_events.any():
-
-            def search(a, b):
-                return np.where(
-                    np.in1d(
-                        np.ravel_multi_index(a.T, a.max(0) + 1),
-                        np.ravel_multi_index(b.T, a.max(0) + 1),
-                    )
-                )
-
-            for target, source in merge_events:
-                if (
-                    tuple(source) in moset
-                ):  # update mother to lowest positive index among the two
-                    mother_ids = search(trap_mother, source)
-                    trap_mother[mother_ids] = (
-                        target[0],
-                        self.pick_mother(
-                            trap_mother[mother_ids][0][1], target[1]
-                        ),
-                    )
-                if tuple(source) in daset:
-                    trap_daughter[search(trap_daughter, source)] = target
-                if tuple(source) in picked_set:
-                    indices[search(indices, source)] = target
+        indices = get_index_as_np(record)
+        if merge_events.any():  # Update lineages after merge events
+            # We validate merges that associate existing mothers and daughters
+            valid_merges, valid_indices = validate_association(merges, indices)
+            grouped_merges = group_merges(merges)
+            # Summarise the merges, linking the first and final id
+            # Shape (X,2,2)
+            summarised = np.array(
+                [(x[0][0], x[-1][1]) for x in grouped_merges]
+            )
+            # List the indices that will be deleted, as they are in-between
+            # Shape (Y,2)
+            to_delete = np.vstack(
+                [
+                    x.reshape(-1, x.shape[-1])[1:-1]
+                    for x in grouped_merges
+                    if len(x) > 1
+                ]
+            )
+            flat_indices = lineage.reshape(-1, 2)
+            valid_merges, valid_indices = validate_association(
+                summarised, flat_indices
+            )
+            # Replace
+            id_eq_matrix = compare_indices(flat_indices, to_delete)
+            # Update labels of merged tracklets
+            flat_indices[valid_indices] = summarised[valid_merges, 1]
+            # Remove labels that will be removed when merging
+            flat_indices = flat_indices[id_eq_matrix.any(axis=1)]
+            lineage_merged = flat_indices.reshape(-1, 2)
             self.lineage_merged = pd.MultiIndex.from_arrays(
                 np.unique(
@@ -240,21 +217,30 @@
                 ).T,
                 names=["trap", "mother_label", "daughter_label"],
             )
-        self._writer.write(
-            "modifiers/picks",
-            data=pd.MultiIndex.from_arrays(
-                # TODO Check if multiindices are still repeated
-                np.unique(indices, axis=0).T if indices.any() else [[], []],
-                names=["trap", "cell_label"],
-            ),
-            overwrite="overwrite",
-        )
+        self._writer.write(
+            "postprocessing/lineage_merged",
+            data=self.lineage_merged,
+            overwrite="overwrite",
+        )
+        # Remove after implementing outside
+        # self._writer.write(
+        #     "modifiers/picks",
+        #     data=pd.MultiIndex.from_arrays(
+        #         # TODO Check if multiindices are still repeated
+        #         np.unique(indices, axis=0).T if indices.any() else [[], []],
+        #         names=["trap", "cell_label"],
+        #     ),
+        #     overwrite="overwrite",
+        # )
+        # combined_idx = ([], [], [])
+        # multii = pd.MultiIndex.from_arrays(
+        #     combined_idx,
+        #     names=["trap", "mother_label", "daughter_label"],
+        # )
+        # self._writer.write(
+        #     "postprocessing/lineage",
+        #     data=multii,
+        #     # TODO check if overwrite is still needed
+        #     overwrite="overwrite",
+        # )
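To make the summarising step concrete, here is a small sketch with one invented chain of merges, where in trap 0 cell 2 merges into cell 1 and cell 3 into cell 2:

    import numpy as np

    # One chained group, as group_merges (defined below) would return it
    grouped_merges = [np.array([[[0, 1], [0, 2]],
                                [[0, 2], [0, 3]]])]
    summarised = np.array([(x[0][0], x[-1][1]) for x in grouped_merges])
    # summarised -> [[[0 1], [0 3]]]: first target linked to final source
    to_delete = np.vstack(
        [x.reshape(-1, x.shape[-1])[1:-1] for x in grouped_merges if len(x) > 1]
    )
    # to_delete -> [[0 2], [0 2]]: the in-between id, listed for deletion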
     @staticmethod
     def pick_mother(a, b):
@@ -357,3 +343,38 @@ class PostProcessor(ProcessABC):
         metadata: Dict,
     ):
         self._writer.write(path, result, meta=metadata, overwrite="overwrite")
+
+
+def union_find(lsts):
+    sets = [set(lst) for lst in lsts if lst]
+    merged = True
+    while merged:
+        merged = False
+        results = []
+        while sets:
+            common, rest = sets[0], sets[1:]
+            sets = []
+            for x in rest:
+                if x.isdisjoint(common):
+                    sets.append(x)
+                else:
+                    merged = True
+                    common |= x
+            results.append(common)
+        sets = results
+    return sets
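union_find collapses overlapping groups into connected components, which group_merges below uses to chain merge events; a quick check of that behaviour:

    print(union_find([(0, 1), (1, 2), (3, 4)]))
    # [{0, 1, 2}, {3, 4}]: the pairs sharing element 1 collapse into one set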
+def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
+    # Group merges into chains (multimerges), where a cell appears as both
+    # source and target; lone events are returned as single-event groups
+    sources_targets = compare_indices(merges[:, 0, :], merges[:, 1, :])
+    is_multimerge = sources_targets.any(axis=0) | sources_targets.any(axis=1)
+    is_monomerge = ~is_multimerge
+    multimerge_subsets = union_find(list(zip(*np.where(sources_targets))))
+    return [
+        *[merges[np.array(tuple(x))] for x in multimerge_subsets],
+        *[[event] for event in merges[is_monomerge]],
+    ]
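group_merges also depends on compare_indices, which this diff does not define; a plausible implementation, consistent with its use here and in run_prepost (pairwise row equality between two 2-D index arrays), might be:

    import numpy as np

    def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
        # Entry (i, j) is True when row i of x equals row j of y in all columns
        return (x[..., None] == y.T[None, ...]).all(axis=1)

With this sketch, the two chained events from the earlier example fall into a single group, because the target of the second event, (0, 2), equals the source of the first.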
@@ -5,7 +5,7 @@ import pandas as pd
 from agora.abc import ParametersABC
 from agora.io.cells import Cells
-from agora.utils.association import validate_association
+from agora.utils.indexing import validate_association
 from agora.utils.cast import _str_to_int
 from agora.utils.kymograph import drop_mother_label
 from postprocessor.core.lineageprocess import LineageProcess
...