From d99e55a3c76ffd43913a627dcd91d3addc08eb7a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Fri, 30 Sep 2022 17:59:39 +0100
Subject: [PATCH] tweak(agora): isolate merge as functions

---
 src/agora/io/signal.py              | 169 ++++++++++------------------
 src/agora/utils/merge.py            | 113 +++++++++++++++++++
 src/postprocessor/core/processor.py |   6 +-
 3 files changed, 175 insertions(+), 113 deletions(-)
 create mode 100644 src/agora/utils/merge.py

diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index 7f8af048..604a7ef0 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -1,15 +1,16 @@
 import typing as t
 from copy import copy
+from functools import cached_property, lru_cache
 from pathlib import PosixPath
 
 import bottleneck as bn
 import h5py
 import numpy as np
 import pandas as pd
-from utils_find_1st import cmp_larger, find_1st
 
 from agora.io.bridge import BridgeH5
 from agora.io.decorators import _first_arg_str_to_df
+from agora.utils.merge import apply_merges
 
 
 class Signal(BridgeH5):
@@ -34,10 +35,6 @@ class Signal(BridgeH5):
 
     def __getitem__(self, dsets: t.Union[str, t.Collection]):
 
-        assert isinstance(
-            dsets, (str, t.Collection)
-        ), "Incorrect type for dset"
-
         if isinstance(dsets, str) and dsets.endswith("imBackground"):
             df = self.get_raw(dsets)
 
@@ -52,6 +49,8 @@ class Signal(BridgeH5):
             return [
                 self.add_name(self.apply_prepost(dset), dset) for dset in dsets
             ]
+        else:
+            raise TypeError(f"Invalid type {type(dsets)} to get datasets")
 
         # return self.cols_in_mins(self.add_name(df, dsets))
         return self.add_name(df, dsets)
@@ -74,12 +73,12 @@ class Signal(BridgeH5):
         )
         return df
 
-    @property
+    @cached_property
     def ntimepoints(self):
         with h5py.File(self.filename, "r") as f:
             return f["extraction/general/None/area/timepoint"][-1] + 1
 
-    @property
+    @cached_property
     def tinterval(self) -> int:
         tinterval_location = "time_settings/timeinterval"
         with h5py.File(self.filename, "r") as f:
@@ -89,6 +88,7 @@ class Signal(BridgeH5):
     def get_retained(df, cutoff):
         return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]
 
+    @lru_cache(30)
    def retained(self, signal, cutoff=0.8):

        df = self[signal]
@@ -98,6 +98,7 @@ class Signal(BridgeH5):
         elif isinstance(df, list):
             return [self.get_retained(d, cutoff=cutoff) for d in df]
 
+    @lru_cache(2)
     def lineage(
         self, lineage_location: t.Optional[str] = None, merged: bool = False
     ) -> np.ndarray:
@@ -127,40 +128,48 @@ class Signal(BridgeH5):
     def apply_prepost(
         self,
         data: t.Union[str, pd.DataFrame],
-        merges: np.ndarray = None,
-        picks: t.Optional[bool] = None,
+        merges: t.Union[np.ndarray, bool] = True,
+        picks: t.Union[t.Collection, bool] = True,
     ):
+        """Apply modifier operations (picker, merger) to a given dataframe.
+
+
+        Parameters
+        ----------
+        data : t.Union[str, pd.DataFrame]
+            DataFrame or path to one.
+        merges : t.Union[np.ndarray, bool]
+            (optional) 2-D array of variable length with three columns:
+            trap id, mother label and daughter label.
+            If True, fetch merges from the file; if False, skip the
+            merging step.
+        picks : t.Union[t.Collection, bool]
+            (optional) 2-D array where the first column is the trap id and
+            the second one is the cell label.
+            If True, fetch picks from the file; if False, skip the picking step.
+
+        Examples
+        --------
+        FIXME: Add docs.
+
         """
-        Apply modifier operations (picker, merger) to a given dataframe.
- """ - if merges is None: - merges = self.get_merges() + if isinstance(merges, bool): + merges: np.ndarray = self.get_merges() if merges else np.array([]) + merged = copy(data) if merges.any(): - # Split in two dfs, one with rows relevant for merging and one - # without them - valid_merges = validate_merges(merges, np.array(list(data.index))) - - # TODO use the same info from validate_merges to select both - valid_indices = [ - tuple(x) - for x in (np.unique(valid_merges.reshape(-1, 2), axis=0)) - ] - merged = self.apply_merge( - data.loc[valid_indices], - valid_merges, - ) - - nonmergeable_ids = data.index.difference(valid_indices) + merged = apply_merges(data, merges) - merged = pd.concat( - (merged, data.loc[nonmergeable_ids]), names=data.index.names + if isinstance(picks, bool): + picks = ( + self.get_picks(names=merged.index.names) + if picks + else set(merged.index) ) with h5py.File(self.filename, "r") as f: - if "modifiers/picks" in f and not picks: - picks = self.get_picks(names=merged.index.names) + if "modifiers/picks" in f and picks: # missing_cells = [i for i in picks if tuple(i) not in # set(merged.index)] @@ -184,7 +193,7 @@ class Signal(BridgeH5): merged = pd.DataFrame([], index=index) return merged - @property + @cached_property def datasets(self): if not hasattr(self, "_siglist"): self._siglist = [] @@ -195,12 +204,12 @@ class Signal(BridgeH5): for sig in self.siglist: print(sig) - @property + @cached_property def p_siglist(self): """Print signal list""" self.datasets - @property + @cached_property def siglist(self): """Return list of signals""" try: @@ -215,34 +224,24 @@ class Signal(BridgeH5): return self._siglist def get_merged(self, dataset): - return self.apply_prepost(dataset, skip_pick=True) + return self.apply_prepost(dataset, skip_picks=True) - @property + @cached_property def merges(self): with h5py.File(self.filename, "r") as f: dsets = f.visititems(self._if_merges) return dsets - @property + @cached_property def n_merges(self): return len(self.merges) - @property + @cached_property def picks(self): with h5py.File(self.filename, "r") as f: dsets = f.visititems(self._if_picks) return dsets - def apply_merge(self, df, changes): - if len(changes): - for target, source in changes: - df.loc[tuple(target)] = self.join_tracks_pair( - df.loc[tuple(target)], df.loc[tuple(source)] - ) - df.drop(tuple(source), inplace=True) - - return df - def get_raw(self, dataset: str, in_minutes: bool = True): try: if isinstance(dataset, str): @@ -266,14 +265,20 @@ class Signal(BridgeH5): return merges - # def get_picks(self, levels): - def get_picks(self, names, path="modifiers/picks/"): + def get_picks( + self, + names: t.Tuple[str, ...] 
= ("trap", "cell_label"), + path: str = "modifiers/picks/", + ) -> t.Set[t.Tuple[int, str]]: + """ + Return the relevant picks based on names + """ with h5py.File(self.filename, "r") as f: + picks = set() if path in f: - return list(zip(*[f[path + name] for name in names])) - # return f["modifiers/picks"] - else: - return None + picks = set(zip(*[f[path + name] for name in names])) + + return picks def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame: """ @@ -322,7 +327,7 @@ class Signal(BridgeH5): # columns=f[path + "/timepoint"][()], # ) - def get_siglist(self, name: str, node): + def get_siglist(self, node): fullname = node.name if isinstance(node, h5py.Group) and np.all( [isinstance(x, h5py.Dataset) for x in node.values()] @@ -348,17 +353,6 @@ class Signal(BridgeH5): if isinstance(obj, h5py.Group) and name.endswith("picks"): return obj[()] - @staticmethod - def join_tracks_pair(target: pd.Series, source: pd.Series): - """ - Join two tracks and return the new value of the target. - TODO replace this with arrays only. - """ - tgt_copy = copy(target) - end = find_1st(target.values[::-1], 0, cmp_larger) - tgt_copy.iloc[-end:] = source.iloc[-end:].values - return tgt_copy - # TODO FUTURE add stages support to fluigent system @property def ntps(self) -> int: @@ -401,44 +395,3 @@ class Signal(BridgeH5): if end <= self.max_span ] return tuple((stage, ntps) for stage, ntps in zip(self.stages, spans)) - - -def validate_merges(merges: np.ndarray, indices: np.ndarray) -> np.ndarray: - """Select rows from the first array that are present in both. - We use casting for fast multiindexing - - - Parameters - ---------- - merges : np.ndarray - 2-D array where columns are (trap, mother, daughter) or 3-D array where - dimensions are (X, (trap,mother), (trap,daughter)) - indices : np.ndarray - 2-D array where each column is a different level. - - Returns - ------- - np.ndarray - 3-D array with elements in both arrays. - - Examples - -------- - FIXME: Add docs. - - """ - if merges.ndim < 3: - # Reshape into 3-D array for casting if neded - merges = np.stack((merges[:, [0, 1]], merges[:, [0, 2]]), axis=1) - - # Compare existing merges with available indices - # Swap trap and label axes for the merges array to correctly cast - # valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :] - valid_ndmerges = merges[..., None] == indices.T[None, ...] - - # Casting is confusing (but efficient): - # - First we check the dimension across trap and cell id, to ensure both match - # - Then we check the dimension that crosses all indices, to ensure the pair is present there - # - Finally we check the merge tuples to check which cases have both target and source - valid_merges = merges[valid_ndmerges.all(axis=2).any(axis=2).all(axis=1)] - # valid_merges = merges[allnan.any(axis=1)] - return valid_merges diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py new file mode 100644 index 00000000..9a28fd59 --- /dev/null +++ b/src/agora/utils/merge.py @@ -0,0 +1,113 @@ +#!/usr/bin/env jupyter +""" +Functions to efficiently merge rows in DataFrames. +""" +import typing as t +from copy import copy + +import numpy as np +import pandas as pd +from utils_find_1st import cmp_larger, find_1st + + +def apply_merges(data: pd.DataFrame, merges: np.ndarray): + """Split data in two, one subset for rows relevant for merging and one + without them. It uses an array of source tracklets and target tracklets + to efficiently merge them. + + Parameters + ---------- + data : pd.DataFrame + Input DataFrame. 
+    merges : np.ndarray
+        3-D array of shape (X, 2, 2): merge events, source-target pair and
+        (trap, cell label) identifiers, respectively.
+
+    Examples
+    --------
+    FIXME: Add docs.
+
+    """
+
+    valid_merges, indices = validate_merges(merges, np.array(list(data.index)))
+
+    # Keep the rows that are not involved in any merge
+    merged = data.loc[~indices]
+
+    # Implement the merges and drop source rows.
+    if valid_merges.any():
+        to_merge = data.loc[indices]
+        for target, source in merges[valid_merges]:
+            target, source = tuple(target), tuple(source)
+            to_merge.loc[target] = join_tracks_pair(
+                to_merge.loc[target].values,
+                to_merge.loc[source].values,
+            )
+            to_merge.drop(source, inplace=True)
+
+        merged = pd.concat((merged, to_merge), names=data.index.names)
+    return merged
+
+
+def validate_merges(
+    merges: np.ndarray, indices: np.ndarray
+) -> t.Tuple[np.ndarray, np.ndarray]:
+
+    """Select rows from the first array that are present in both.
+    We use broadcasting for fast multiindexing.
+
+
+
+
+    Parameters
+    ----------
+    merges : np.ndarray
+        2-D array where columns are (trap, mother, daughter) or 3-D array where
+        dimensions are (X, (trap,mother), (trap,daughter))
+    indices : np.ndarray
+        2-D array where each column is a different level.
+
+    Returns
+    -------
+    np.ndarray
+        1-D boolean array indicating valid merge events.
+    np.ndarray
+        1-D boolean array indicating indices involved in merging.
+
+    Examples
+    --------
+    FIXME: Add docs.
+
+    """
+    if merges.ndim < 3:
+        # Reshape into 3-D array for broadcasting if needed
+        merges = np.stack((merges[:, [0, 1]], merges[:, [0, 2]]), axis=1)
+
+    # Compare existing merges with available indices
+    # Swap trap and label axes for the merges array to correctly cast
+    # valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :]
+    valid_ndmerges = merges[..., None] == indices.T[None, ...]
+
+    # Broadcasting is confusing (but efficient):
+    # First we check the dimension across trap and cell id, to ensure both match
+    valid_cell_ids = valid_ndmerges.all(axis=2)
+
+    # Then we check the merge tuples to see which cases have both target and source
+    valid_merges = valid_cell_ids.any(axis=2).all(axis=1)
+
+    # Finally we check the dimension that crosses all indices, to ensure the pair
+    # is present in a valid merge event.
+    valid_indices = valid_ndmerges[valid_merges].all(axis=2).any(axis=(0, 1))
+
+    return valid_merges, valid_indices
+
+
+def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
+    """
+    Join two tracks and return the new value of the target.
+    Both tracks are assumed to be 1-D arrays of the same length.
+    """
+    target_copy = copy(target)
+    end = find_1st(target_copy[::-1], 0, cmp_larger)
+    target_copy[-end:] = source[-end:]
+    return target_copy
diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index 7cc465c9..6efbddaf 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -298,11 +298,7 @@ class PostProcessor(ProcessABC):
 
         self.run_prepost()
 
-        for i, (process, datasets) in tqdm(
-            enumerate(self.targets["processes"])
-        ):
-            if i == 3:
-                print("stop")
+        for process, datasets in tqdm(self.targets["processes"]):
             if process in self.parameters["param_sets"].get(
                 "processes", {}
             ):  # If we assigned parameters
-- 
GitLab
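
A minimal usage sketch of the new apply_merges helper, on invented toy data; it is not part of the patch and assumes the patch is applied and the utils_find_1st dependency is installed. The trap/cell_label values are hypothetical.

    import numpy as np
    import pandas as pd

    from agora.utils.merge import apply_merges

    # Two tracklets in trap 0: the target track (0, 1) has data up to the
    # second timepoint, and the source track (0, 3) continues from there.
    index = pd.MultiIndex.from_tuples(
        [(0, 1), (0, 3)], names=["trap", "cell_label"]
    )
    data = pd.DataFrame(
        [[1.0, 2.0, 0.0, 0.0], [0.0, 0.0, 3.0, 4.0]], index=index
    )

    # One merge event, shaped (X, 2, 2): a (trap, mother) target paired
    # with a (trap, daughter) source.
    merges = np.array([[[0, 1], [0, 3]]])

    # The source row is folded into the target and then dropped, leaving a
    # single track (0, 1) with values [1.0, 2.0, 3.0, 4.0].
    print(apply_merges(data, merges))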