Skip to content
Snippets Groups Projects
Commit e4e31e1d authored by Alán Muñoz's avatar Alán Muñoz
Browse files

refactor(merge): isolate validate_association

parent c6fdaaa7
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,8 @@ import numpy as np ...@@ -9,6 +9,8 @@ import numpy as np
import pandas as pd import pandas as pd
from utils_find_1st import cmp_larger, find_1st from utils_find_1st import cmp_larger, find_1st
from agora.utils.association import validate_association
def apply_merges(data: pd.DataFrame, merges: np.ndarray): def apply_merges(data: pd.DataFrame, merges: np.ndarray):
"""Split data in two, one subset for rows relevant for merging and one """Split data in two, one subset for rows relevant for merging and one
...@@ -29,7 +31,9 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray): ...@@ -29,7 +31,9 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):
""" """
valid_merges, indices = validate_merges(merges, np.array(list(data.index))) valid_merges, indices = validate_association(
merges, np.array(list(data.index))
)
# Assign non-merged # Assign non-merged
merged = data.loc[~indices] merged = data.loc[~indices]
...@@ -49,59 +53,6 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray): ...@@ -49,59 +53,6 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):
return merged return merged
def validate_merges(
merges: np.ndarray, indices: np.ndarray
) -> t.Tuple[np.ndarray, np.ndarray]:
"""Select rows from the first array that are present in both.
We use casting for fast multiindexing.
Parameters
----------
merges : np.ndarray
2-D array where columns are (trap, mother, daughter) or 3-D array where
dimensions are (X, (trap,mother), (trap,daughter))
indices : np.ndarray
2-D array where each column is a different level.
Returns
-------
np.ndarray
1-D boolean array indicating valid merge events.
np.ndarray
1-D boolean array indicating indices involved in merging.
Examples
--------
FIXME: Add docs.
"""
if merges.ndim < 3:
# Reshape into 3-D array for broadcasting if neded
merges = np.stack((merges[:, [0, 1]], merges[:, [0, 2]]), axis=1)
# Compare existing merges with available indices
# Swap trap and label axes for the merges array to correctly cast
# valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :]
valid_ndmerges = merges[..., None] == indices.T[None, ...]
# Broadcasting is confusing (but efficient):
# First we check the dimension across trap and cell id, to ensure both match
valid_cell_ids = valid_ndmerges.all(axis=2)
# Then we check the merge tuples to check which cases have both target and source
valid_merges = valid_cell_ids.any(axis=2).all(axis=1)
# Finalle we check the dimension that crosses all indices, to ensure the pair
# is present in a valid merge event.
valid_indices = valid_ndmerges[valid_merges].all(axis=2).any(axis=(0, 1))
return valid_merges, valid_indices
def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray: def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
""" """
Join two tracks and return the new value of the target. Join two tracks and return the new value of the target.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment