Skip to content
Snippets Groups Projects
Commit d29beecc authored by Alán Muñoz's avatar Alán Muñoz
Browse files

fix(extraction): merge validation works

parent f1aba90c
No related branches found
No related tags found
No related merge requests found
......@@ -129,24 +129,23 @@ class Signal(BridgeH5):
merges = self.get_merges()
df = self.get_raw(dataset)
merged = copy(df)
if merges.any():
# Split in two dfs, one with rows relevant for merging and one
# without them
valid_merges = merges[
(
merges[:, :, :, None]
== np.array(list(df.index)).T[:, None, :]
)
.all(axis=(1, 2))
.any(axis=1)
] # Casting allows fast multiindexing
valid_merges = validate_merges(merges, np.array(list(df.index)))
# TODO use the same info from validate_merges to select both
valid_indices = [
tuple(x)
for x in (np.unique(valid_merges.reshape(-1, 2), axis=0))
]
merged = self.apply_merge(
df.loc[map(tuple, valid_merges.reshape(-1, 2))],
df.loc[valid_indices],
valid_merges,
)
nonmergeable_ids = df.index.difference(valid_merges.reshape(-1, 2))
nonmergeable_ids = df.index.difference(valid_indices)
merged = pd.concat(
(merged, df.loc[nonmergeable_ids]), names=df.index.names
......@@ -339,7 +338,8 @@ class Signal(BridgeH5):
@staticmethod
def join_tracks_pair(target: pd.Series, source: pd.Series):
"""
Join two tracks
Join two tracks and return the new value of the target.
TODO replace this with arrays only.
"""
tgt_copy = copy(target)
end = find_1st(target.values[::-1], 0, cmp_larger)
......@@ -388,3 +388,40 @@ class Signal(BridgeH5):
if end <= self.max_span
]
return tuple((stage, ntps) for stage, ntps in zip(self.stages, spans))
def validate_merges(merges: np.ndarray, indices: np.ndarray) -> np.ndarray:
"""Select rows from the first array that are present in both.
We use casting for fast multiindexing
Parameters
----------
merges : np.ndarray
2-D array where columns are (trap, mother, daughter) or 3-D array where
dimensions are (X, (trap,mother), (trap,daughter))
indices : np.ndarray
2-D array where each column is a different level.
Returns
-------
np.ndarray
3-D array with elements in both arrays.
Examples
--------
FIXME: Add docs.
"""
if merges.ndim < 3:
# Reshape into 3-D array for casting if neded
merges = np.stack((merges[:, [0, 1]], merges[:, [0, 2]]), axis=1)
# Compare existing merges with available indices
# Swap trap and label axes for the merges array to correctly cast
# valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :]
valid_ndmerges = merges[..., None] == indices.T[None, ...]
valid_merges = merges[valid_ndmerges.all(axis=2).any(axis=2).any(axis=1)]
# valid_merges = merges[allnan.any(axis=1)]
return valid_merges
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment