Commit d99e55a3 authored by Alán Muñoz

tweak(agora): isolate merge as functions

parent e318491f
agora/io/signal.py:

 import typing as t
 from copy import copy
+from functools import cached_property, lru_cache
 from pathlib import PosixPath

 import bottleneck as bn
 import h5py
 import numpy as np
 import pandas as pd
-from utils_find_1st import cmp_larger, find_1st

 from agora.io.bridge import BridgeH5
 from agora.io.decorators import _first_arg_str_to_df
+from agora.utils.merge import apply_merges


 class Signal(BridgeH5):
@@ -34,10 +35,6 @@ class Signal(BridgeH5):
     def __getitem__(self, dsets: t.Union[str, t.Collection]):
-        assert isinstance(
-            dsets, (str, t.Collection)
-        ), "Incorrect type for dset"
         if isinstance(dsets, str) and dsets.endswith("imBackground"):
             df = self.get_raw(dsets)
@@ -52,6 +49,8 @@ class Signal(BridgeH5):
             return [
                 self.add_name(self.apply_prepost(dset), dset) for dset in dsets
             ]
+        else:
+            raise Exception(f"Invalid type {type(dsets)} to get datasets")
         # return self.cols_in_mins(self.add_name(df, dsets))
         return self.add_name(df, dsets)
@@ -74,12 +73,12 @@ class Signal(BridgeH5):
         )
         return df

-    @property
+    @cached_property
     def ntimepoints(self):
         with h5py.File(self.filename, "r") as f:
             return f["extraction/general/None/area/timepoint"][-1] + 1

-    @property
+    @cached_property
     def tinterval(self) -> int:
         tinterval_location = "time_settings/timeinterval"
         with h5py.File(self.filename, "r") as f:
@@ -89,6 +88,7 @@ class Signal(BridgeH5):
     def get_retained(df, cutoff):
         return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]

+    @lru_cache(30)
     def retained(self, signal, cutoff=0.8):
         df = self[signal]
@@ -98,6 +98,7 @@ class Signal(BridgeH5):
         elif isinstance(df, list):
             return [self.get_retained(d, cutoff=cutoff) for d in df]

+    @lru_cache(2)
     def lineage(
         self, lineage_location: t.Optional[str] = None, merged: bool = False
     ) -> np.ndarray:
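
The retention rule above keeps a row only when its count of non-NaN timepoints exceeds the cutoff fraction of all columns. A minimal sketch on made-up data (hypothetical 3x4 DataFrame, the default cutoff of 0.8):

import bottleneck as bn
import numpy as np
import pandas as pd

# Hypothetical toy data: rows are cells, columns are timepoints.
df = pd.DataFrame(
    [
        [1.0, 2.0, 3.0, 4.0],           # tracked at 4/4 timepoints
        [1.0, np.nan, 3.0, 4.0],        # tracked at 3/4 timepoints
        [np.nan, np.nan, np.nan, 4.0],  # tracked at 1/4 timepoints
    ]
)
cutoff = 0.8
# Same expression as Signal.get_retained: count non-NaN entries per row
# and require the count to exceed cutoff * number_of_columns (here 3.2).
retained = df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]
print(retained.index.tolist())  # [0] -- only the fully tracked row survives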
@@ -127,40 +128,48 @@ class Signal(BridgeH5):
     def apply_prepost(
         self,
         data: t.Union[str, pd.DataFrame],
-        merges: np.ndarray = None,
-        picks: t.Optional[bool] = None,
+        merges: t.Union[np.ndarray, bool] = True,
+        picks: t.Union[t.Collection, bool] = True,
     ):
-        """
-        Apply modifier operations (picker, merger) to a given dataframe.
-        """
-        if merges is None:
-            merges = self.get_merges()
+        """Apply modifier operations (picker, merger) to a given dataframe.
+
+        Parameters
+        ----------
+        data : t.Union[str, pd.DataFrame]
+            DataFrame or url to one.
+        merges : t.Union[np.ndarray, bool]
+            (optional) 2-D array with three columns and variable length. The
+            first column is the trap id, the second is the mother label and
+            the third is the daughter id.
+            If True, fetch merges from the file; if False, skip the merging step.
+        picks : t.Union[np.ndarray, bool]
+            (optional) 2-D ndarray where the first column is traps and the
+            second column is cell labels.
+            If True, fetch picks from the file; if False, skip the picking step.
+
+        Examples
+        --------
+        FIXME: Add docs.
+        """
+        if isinstance(merges, bool):
+            merges: np.ndarray = self.get_merges() if merges else np.array([])
         merged = copy(data)
         if merges.any():
-            # Split in two dfs, one with rows relevant for merging and one
-            # without them
-            valid_merges = validate_merges(merges, np.array(list(data.index)))
-            # TODO use the same info from validate_merges to select both
-            valid_indices = [
-                tuple(x)
-                for x in (np.unique(valid_merges.reshape(-1, 2), axis=0))
-            ]
-            merged = self.apply_merge(
-                data.loc[valid_indices],
-                valid_merges,
-            )
-            nonmergeable_ids = data.index.difference(valid_indices)
-            merged = pd.concat(
-                (merged, data.loc[nonmergeable_ids]), names=data.index.names
-            )
+            merged = apply_merges(data, merges)
+
+        if isinstance(picks, bool):
+            picks = (
+                self.get_picks(names=merged.index.names)
+                if picks
+                else set(merged.index)
+            )
         with h5py.File(self.filename, "r") as f:
-            if "modifiers/picks" in f and not picks:
-                picks = self.get_picks(names=merged.index.names)
+            if "modifiers/picks" in f and picks:
                 # missing_cells = [i for i in picks if tuple(i) not in
                 # set(merged.index)]
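
With the new signature, merges and picks default to boolean toggles instead of None sentinels, and an explicit array or collection can still be passed. A hypothetical call sketch ("experiment.h5" is a placeholder filename; the dataset path follows the extraction layout used above):

import numpy as np

# Hypothetical usage; assumes Signal is constructed from an HDF5 path.
signal = Signal("experiment.h5")

# Defaults (merges=True, picks=True): fetch both from the file and apply.
df = signal.apply_prepost("extraction/general/None/area")

# Skip merging, still apply the picks stored in the file.
df_unmerged = signal.apply_prepost(
    "extraction/general/None/area", merges=False
)

# Or supply an explicit merges array: columns are
# (trap, mother label, daughter label), as documented above.
df_custom = signal.apply_prepost(
    "extraction/general/None/area",
    merges=np.array([[0, 1, 2]]),
    picks=False,
)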
@@ -184,7 +193,7 @@ class Signal(BridgeH5):
             merged = pd.DataFrame([], index=index)
         return merged

-    @property
+    @cached_property
     def datasets(self):
         if not hasattr(self, "_siglist"):
             self._siglist = []
@@ -195,12 +204,12 @@ class Signal(BridgeH5):
         for sig in self.siglist:
             print(sig)

-    @property
+    @cached_property
     def p_siglist(self):
         """Print signal list"""
         self.datasets

-    @property
+    @cached_property
     def siglist(self):
         """Return list of signals"""
         try:
@@ -215,34 +224,24 @@ class Signal(BridgeH5):
         return self._siglist

     def get_merged(self, dataset):
-        return self.apply_prepost(dataset, skip_pick=True)
+        return self.apply_prepost(dataset, skip_picks=True)

-    @property
+    @cached_property
     def merges(self):
         with h5py.File(self.filename, "r") as f:
             dsets = f.visititems(self._if_merges)
         return dsets

-    @property
+    @cached_property
     def n_merges(self):
         return len(self.merges)

-    @property
+    @cached_property
     def picks(self):
         with h5py.File(self.filename, "r") as f:
             dsets = f.visititems(self._if_picks)
         return dsets

-    def apply_merge(self, df, changes):
-        if len(changes):
-            for target, source in changes:
-                df.loc[tuple(target)] = self.join_tracks_pair(
-                    df.loc[tuple(target)], df.loc[tuple(source)]
-                )
-                df.drop(tuple(source), inplace=True)
-        return df
-
     def get_raw(self, dataset: str, in_minutes: bool = True):
         try:
             if isinstance(dataset, str):
...@@ -266,14 +265,20 @@ class Signal(BridgeH5): ...@@ -266,14 +265,20 @@ class Signal(BridgeH5):
return merges return merges
# def get_picks(self, levels): def get_picks(
def get_picks(self, names, path="modifiers/picks/"): self,
names: t.Tuple[str, ...] = ("trap", "cell_label"),
path: str = "modifiers/picks/",
) -> t.Set[t.Tuple[int, str]]:
"""
Return the relevant picks based on names
"""
with h5py.File(self.filename, "r") as f: with h5py.File(self.filename, "r") as f:
picks = set()
if path in f: if path in f:
return list(zip(*[f[path + name] for name in names])) picks = set(zip(*[f[path + name] for name in names]))
# return f["modifiers/picks"]
else: return picks
return None
def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame: def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame:
""" """
@@ -322,7 +327,7 @@ class Signal(BridgeH5):
             # columns=f[path + "/timepoint"][()],
             # )

-    def get_siglist(self, name: str, node):
+    def get_siglist(self, node):
         fullname = node.name
         if isinstance(node, h5py.Group) and np.all(
             [isinstance(x, h5py.Dataset) for x in node.values()]
@@ -348,17 +353,6 @@ class Signal(BridgeH5):
         if isinstance(obj, h5py.Group) and name.endswith("picks"):
             return obj[()]

-    @staticmethod
-    def join_tracks_pair(target: pd.Series, source: pd.Series):
-        """
-        Join two tracks and return the new value of the target.
-        TODO replace this with arrays only.
-        """
-        tgt_copy = copy(target)
-        end = find_1st(target.values[::-1], 0, cmp_larger)
-        tgt_copy.iloc[-end:] = source.iloc[-end:].values
-        return tgt_copy
-
     # TODO FUTURE add stages support to fluigent system
     @property
     def ntps(self) -> int:
@@ -401,44 +395,3 @@ class Signal(BridgeH5):
                 if end <= self.max_span
             ]
         return tuple((stage, ntps) for stage, ntps in zip(self.stages, spans))
-
-
-def validate_merges(merges: np.ndarray, indices: np.ndarray) -> np.ndarray:
-    """Select rows from the first array that are present in both.
-    We use casting for fast multiindexing.
-
-    Parameters
-    ----------
-    merges : np.ndarray
-        2-D array where columns are (trap, mother, daughter) or 3-D array
-        where dimensions are (X, (trap, mother), (trap, daughter)).
-    indices : np.ndarray
-        2-D array where each column is a different level.
-
-    Returns
-    -------
-    np.ndarray
-        3-D array with elements in both arrays.
-
-    Examples
-    --------
-    FIXME: Add docs.
-    """
-    if merges.ndim < 3:
-        # Reshape into 3-D array for casting if needed
-        merges = np.stack((merges[:, [0, 1]], merges[:, [0, 2]]), axis=1)
-    # Compare existing merges with available indices
-    # Swap trap and label axes for the merges array to correctly cast
-    # valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :]
-    valid_ndmerges = merges[..., None] == indices.T[None, ...]
-    # Casting is confusing (but efficient):
-    # - First we check the dimension across trap and cell id, to ensure both match
-    # - Then we check the dimension that crosses all indices, to ensure the pair is present there
-    # - Finally we check the merge tuples to check which cases have both target and source
-    valid_merges = merges[valid_ndmerges.all(axis=2).any(axis=2).all(axis=1)]
-    # valid_merges = merges[allnan.any(axis=1)]
-    return valid_merges
agora/utils/merge.py (new file):

+#!/usr/bin/env jupyter
+"""
+Functions to efficiently merge rows in DataFrames.
+"""
+import typing as t
+from copy import copy
+
+import numpy as np
+import pandas as pd
+from utils_find_1st import cmp_larger, find_1st
+
+
+def apply_merges(data: pd.DataFrame, merges: np.ndarray):
+    """Split data in two, one subset with rows relevant for merging and one
+    without them. Uses an array of source and target tracklets to merge
+    them efficiently.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Input DataFrame.
+    merges : np.ndarray
+        3-D ndarray with dimensions (X, 2, 2): the number of merges, the
+        source-target pair, and the single-cell identifiers, respectively.
+
+    Examples
+    --------
+    FIXME: Add docs.
+    """
+    valid_merges, indices = validate_merges(merges, np.array(list(data.index)))
+
+    # Assign non-merged
+    merged = data.loc[~indices]
+
+    # Implement the merges and drop source rows.
+    if valid_merges.any():
+        to_merge = data.loc[indices]
+        for target, source in merges[valid_merges]:
+            target, source = tuple(target), tuple(source)
+            to_merge.loc[target] = join_tracks_pair(
+                to_merge.loc[target].values,
+                to_merge.loc[source].values,
+            )
+            to_merge.drop(source, inplace=True)
+        merged = pd.concat((merged, to_merge), names=data.index.names)
+    return merged
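
A runnable sketch of apply_merges on toy tracklets, assuming the two-level (trap, cell_label) index Signal uses (all values made up):

import numpy as np
import pandas as pd
from agora.utils.merge import apply_merges

# Toy tracklets: rows are (trap, cell_label), columns are timepoints.
index = pd.MultiIndex.from_tuples(
    [(0, 1), (0, 2), (1, 1)], names=["trap", "cell_label"]
)
data = pd.DataFrame(
    [
        [1.0, 2.0, 0.0, 0.0],  # target track, ends after timepoint 1
        [0.0, 0.0, 3.0, 4.0],  # source track, starts where the target ends
        [5.0, 5.0, 5.0, 5.0],  # uninvolved track, passed through untouched
    ],
    index=index,
)
# One merge event, in (X, 2, 2) form: (trap, label) identifiers for the
# target (0, 1) and the source (0, 2).
merges = np.array([[[0, 1], [0, 2]]])

merged = apply_merges(data, merges)
# The source row (0, 2) is dropped and the target row (0, 1) now ends
# with the source's tail: [1.0, 2.0, 3.0, 4.0].
print(merged)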
+def validate_merges(
+    merges: np.ndarray, indices: np.ndarray
+) -> t.Tuple[np.ndarray, np.ndarray]:
+    """Select rows from the first array that are present in both.
+    We use casting for fast multiindexing.
+
+    Parameters
+    ----------
+    merges : np.ndarray
+        2-D array where columns are (trap, mother, daughter) or 3-D array
+        where dimensions are (X, (trap, mother), (trap, daughter)).
+    indices : np.ndarray
+        2-D array where each column is a different level.
+
+    Returns
+    -------
+    np.ndarray
+        1-D boolean array indicating valid merge events.
+    np.ndarray
+        1-D boolean array indicating indices involved in merging.
+
+    Examples
+    --------
+    FIXME: Add docs.
+    """
+    if merges.ndim < 3:
+        # Reshape into 3-D array for broadcasting if needed
+        merges = np.stack((merges[:, [0, 1]], merges[:, [0, 2]]), axis=1)
+    # Compare existing merges with available indices
+    # Swap trap and label axes for the merges array to correctly cast
+    # valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :]
+    valid_ndmerges = merges[..., None] == indices.T[None, ...]
+    # Broadcasting is confusing (but efficient):
+    # First we check the dimension across trap and cell id, to ensure both match
+    valid_cell_ids = valid_ndmerges.all(axis=2)
+    # Then we check the merge tuples to see which cases have both target and source
+    valid_merges = valid_cell_ids.any(axis=2).all(axis=1)
+    # Finally we check the dimension that crosses all indices, to ensure the
+    # pair is present in a valid merge event.
+    valid_indices = valid_ndmerges[valid_merges].all(axis=2).any(axis=(0, 1))
+    return valid_merges, valid_indices
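
The broadcasting can be checked in isolation: a merge event whose source or target is absent from the available indices is rejected, and valid_indices only flags rows that take part in a surviving merge (toy identifiers):

import numpy as np
from agora.utils.merge import validate_merges

# Available (trap, cell_label) pairs.
indices = np.array([[0, 1], [0, 2], [1, 1]])
# Two candidate merges in 2-D (trap, mother, daughter) form; the second
# references a daughter (1, 9) that is not among the indices.
merges = np.array([[0, 1, 2], [1, 1, 9]])

valid_merges, valid_indices = validate_merges(merges, indices)
print(valid_merges)   # [ True False] -- only the first event survives
print(valid_indices)  # [ True  True False] -- (1, 1) is in no valid merge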
+
+
+def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
+    """
+    Join two tracks and return the new value of the target.
+    """
+    target_copy = copy(target)
+    end = find_1st(target_copy[::-1], 0, cmp_larger)
+    target_copy[-end:] = source[-end:]
+    return target_copy
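
And the pairwise join itself: find_1st scans the reversed target for the first value greater than 0, i.e. the length of the trailing zero-padding, which the source's tail then overwrites (toy values):

import numpy as np
from agora.utils.merge import join_tracks_pair

target = np.array([1.0, 2.0, 3.0, 0.0, 0.0])  # track ends, two padding zeros
source = np.array([0.0, 0.0, 0.0, 4.0, 5.0])  # continuation track

joined = join_tracks_pair(target, source)
print(joined)  # [1. 2. 3. 4. 5.]
print(target)  # unchanged: the function copies the target first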
@@ -298,11 +298,7 @@ class PostProcessor(ProcessABC):
         self.run_prepost()

-        for i, (process, datasets) in tqdm(
-            enumerate(self.targets["processes"])
-        ):
-            if i == 3:
-                print("stop")
+        for process, datasets in tqdm(enumerate(self.targets["processes"])):
             if process in self.parameters["param_sets"].get(
                 "processes", {}
             ):  # If we assigned parameters