From a1d7372b00c85e5be55536f684a01fb48a3a3a61 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alan=20Mu=C3=B1oz?= <alan_munoz@protonmail.com>
Date: Thu, 23 Mar 2023 07:34:55 +0000
Subject: [PATCH] Merge pp dev

---
 src/postprocessor/core/functions/tracks.py |  44 +----
 src/postprocessor/core/processor.py        | 206 ++++++++++++---------
 src/postprocessor/core/reshapers/merger.py |  32 +++-
 src/postprocessor/core/reshapers/picker.py |  17 +-
 4 files changed, 151 insertions(+), 148 deletions(-)

diff --git a/src/postprocessor/core/functions/tracks.py b/src/postprocessor/core/functions/tracks.py
index db8ed272..5fa04a15 100644
--- a/src/postprocessor/core/functions/tracks.py
+++ b/src/postprocessor/core/functions/tracks.py
@@ -18,7 +18,7 @@ from postprocessor.core.processes.savgol import non_uniform_savgol
 
 
 def load_test_dset():
-    # Load development dataset to test functions
+    """Load development dataset to test functions."""
     return pd.DataFrame(
         {
             ("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
@@ -31,7 +31,7 @@
 
 
 def max_ntps(track: pd.Series) -> int:
-    # Get number of timepoints
+    """Get number of time points."""
    indices = np.where(track.notna())
     return np.max(indices) - np.min(indices)
 
@@ -84,9 +84,7 @@ def clean_tracks(
     """
     ntps = get_tracks_ntps(tracks)
     grs = get_avg_grs(tracks)
-
     growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]
-
     return growing_long_tracks
 
 
@@ -111,7 +109,6 @@ def merge_tracks(
     joinable_pairs = get_joinable(tracks, **kwargs)
     if joinable_pairs:
         tracks = join_tracks(tracks, joinable_pairs, drop=drop)
-
     return (tracks, joinable_pairs)
 
 
@@ -135,28 +132,23 @@ def get_joint_ids(merging_seqs) -> dict:
     2    ab   cd
     3    abcd
 
-    We shold get:
+    We should get:
 
     output {a:a, b:a, c:a, d:a}
 
     """
     if not merging_seqs:
         return {}
-
     targets, origins = list(zip(*merging_seqs))
     static_tracks = set(targets).difference(origins)
-
     joint = {track_id: track_id for track_id in static_tracks}
     for target, origin in merging_seqs:
         joint[origin] = target
-
     moved_target = [
         k for k, v in joint.items() if joint[v] != v and v in joint.values()
     ]
-
     for orig in moved_target:
         joint[orig] = rec_bottom(joint, orig)
-
     return {
         k: v for k, v in joint.items() if k != v
     }  # remove ids that point to themselves
@@ -184,14 +176,11 @@ def join_tracks(tracks, joinable_pairs, drop=True) -> pd.DataFrame:
 
     :param drop: bool indicating whether or not to drop moved rows
     """
-
     tmp = copy(tracks)
     for target, source in joinable_pairs:
         tmp.loc[target] = join_track_pair(tmp.loc[target], tmp.loc[source])
-
         if drop:
             tmp = tmp.drop(source)
-
     return tmp
 
 
@@ -206,7 +195,7 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
     """
     Get the pair of track (without repeats) that have a smaller error than the
     tolerance. If there is a track that can be assigned to two or more other
-    ones, it chooses the one with a lowest error.
+    ones, choose the one with the lowest error.
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
         columns are timepoints
@@ -324,7 +313,6 @@ def get_means(x, i):
         v = x[~np.isnan(x)][:i]
     else:
         v = x[~np.isnan(x)][i:]
-
     return np.nanmean(v)
 
 
@@ -335,12 +323,12 @@ def get_last_i(x, i):
         v = x[~np.isnan(x)][:i]
     else:
         v = x[~np.isnan(x)][i:]
-
     return v
 
 
 def localid_to_idx(local_ids, contig_trap):
-    """Fetch then original ids from a nested list with joinable local_ids
+    """
+    Fetch the original ids from a nested list with joinable local_ids.
 
     input
     :param local_ids: list of list of pairs with cell ids to be joint
@@ -392,7 +380,6 @@ def get_dMetric(
         dMetric = np.abs(np.subtract.outer(post, pre))
     else:
         dMetric = np.abs(np.subtract.outer(pre, post))
-
     dMetric[np.isnan(dMetric)] = (
         tol + 1 + np.nanmax(dMetric)
     )  # nans will be filtered
@@ -404,28 +391,26 @@ def solve_matrices(
 ):
     """
     Solve the distance matrices obtained in get_dMetric and/or merged from
-    independent dMetric matrices
+    independent dMetric matrices.
     """
-
     ids = solve_matrix(dMetric)
     if not len(ids[0]):
         return []
     pre, post = prepost
-
     norm = (
         np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
     )  # relative or absolute tol
     result = dMetric[ids] / norm
     ids = ids if len(pre) < len(post) else ids[::-1]
-
     return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
 
 
 def get_closest_pairs(
     pre: List[float], post: List[float], tol: Union[float, int] = 1
 ):
-    """Calculate a cost matrix the Hungarian algorithm to pick the best set of
-    options
+    """
+    Calculate a cost matrix for the Hungarian algorithm to pick the best set of
+    options.
 
     input
     :param pre: list of floats with edges on left
@@ -437,7 +422,6 @@
     """
     dMetric = get_dMetric(pre, post, tol)
-
     return solve_matrices(dMetric, pre, post, tol)
 
 
@@ -465,9 +449,7 @@ def solve_matrix(dMetric):
                 tmp[:, j] += np.nan
                 glob_is.append(i)
                 glob_js.append(j)
-
             std = sorted(tmp[~np.isnan(tmp)])
-
     return (np.array(glob_is), np.array(glob_js))
 
 
@@ -475,7 +457,6 @@ def plot_joinable(tracks, joinable_pairs):
     """
     Convenience plotting function for debugging and data vis
     """
-
     nx = 8
     ny = 8
     _, axes = plt.subplots(nx, ny)
@@ -493,7 +474,6 @@
             #     except:
             #         pass
             ax.plot(post_srs.index, post_srs.values, "g")
-
     plt.show()
 
 
@@ -506,21 +486,17 @@ def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
     :param min_dgr: float minimum difference in growth rate from the
         interpolation
     """
-
     mins, maxes = [
         tracks.notna().apply(np.where, axis=1).apply(fn)
         for fn in (np.min, np.max)
     ]
-
     mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
     mins_d.index = mins_d.index - 1  # make indices equal
     # TODO add support for skipping time points
     maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist())
-
     common = sorted(
         set(mins_d.index).intersection(maxes_d.index), reverse=True
     )
-
     return [(maxes_d[t], mins_d[t]) for t in common]
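Note: the pairing logic in get_dMetric/solve_matrix above reduces to an
assignment problem over track edges. A minimal, self-contained sketch with toy
values (not the aliby API; this shows only the absolute-tolerance case):

    import numpy as np

    # values at which candidate tracks end (pre) and begin (post)
    pre = np.array([5.1, 8.0, 2.3])
    post = np.array([5.0, 2.5])

    # cost matrix of absolute differences, as in get_dMetric
    cost = np.abs(np.subtract.outer(pre, post))

    # greedily take the cheapest (pre, post) pair, using each track at most
    # once and keeping only pairs within tolerance (cf. solve_matrix)
    tol = 0.5
    pairs = []
    for _ in range(min(cost.shape)):
        i, j = np.unravel_index(np.nanargmin(cost), cost.shape)
        if cost[i, j] > tol:
            break
        pairs.append((i, j))
        cost[i, :] = np.nan  # mask the used pre edge
        cost[:, j] = np.nan  # mask the used post edge

    print(pairs)  # [(0, 0), (2, 1)]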
diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index d4cd0367..e0adac14 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -1,7 +1,5 @@
-import logging
 import typing as t
 from itertools import takewhile
-from typing import Dict, List, Union
 
 import numpy as np
 import pandas as pd
@@ -16,7 +14,6 @@ from agora.utils.indexing import (
     _assoc_indices_to_3d,
 )
 from agora.utils.merge import merge_association
-from agora.utils.kymograph import get_index_as_np
 from postprocessor.core.abc import get_parameters, get_process
 from postprocessor.core.lineageprocess import (
     LineageProcess,
@@ -35,24 +32,37 @@ class PostProcessorParameters(ParametersABC):
     while objectives are relative or absolute paths to datasets. If relative
     paths the post-processed addresses are used. The order of processes matters.
 
+    Supply parameters for picker, merger, and processes.
     """
 
     def __init__(
         self,
-        targets={},
-        param_sets={},
-        outpaths={},
+        targets: t.Dict = {},
+        param_sets: t.Dict = {},
+        outpaths: t.Dict = {},
     ):
-        self.targets: Dict = targets
-        self.param_sets: Dict = param_sets
-        self.outpaths: Dict = outpaths
+        self.targets = targets
+        self.param_sets = param_sets
+        self.outpaths = outpaths
 
     def __getitem__(self, item):
         return getattr(self, item)
 
     @classmethod
     def default(cls, kind=[]):
-        """Sequential postprocesses to be operated"""
+        """
+        Include buddings and bud_metric, plus estimates of their time derivatives.
+
+        Parameters
+        ----------
+        kind: list of str
+            If "ph_batman" is included, add targets for experiments using pHluorin.
+        """
+        # each subitem specifies the function to be called and the location
+        # on the h5 file to be written
         targets = {
             "prepost": {
                 "merger": "/extraction/general/None/area",
@@ -61,9 +71,7 @@ class PostProcessorParameters(ParametersABC):
             "processes": [
                 [
                     "buddings",
-                    [
-                        "/extraction/general/None/volume",
-                    ],
+                    ["/extraction/general/None/volume"],
                 ],
                 [
                     "dsignal",
@@ -93,7 +101,7 @@ class PostProcessorParameters(ParametersABC):
         }
         outpaths = {}
         outpaths["aggregate"] = "/postprocessing/experiment_wide/aggregated/"
-
+        # pHluorin experiments are special
         if "ph_batman" in kind:
             targets["processes"]["dsignal"].append(
                 [
@@ -115,46 +123,56 @@ class PostProcessorParameters(ParametersABC):
                     ]
                 ],
             )
-
         return cls(targets=targets, param_sets=param_sets, outpaths=outpaths)
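For orientation, a hypothetical usage sketch of the two classes in this file
(the h5 filename is illustrative; it must contain the datasets that the
default targets point to):

    from postprocessor.core.processor import (
        PostProcessor,
        PostProcessorParameters,
    )

    params = PostProcessorParameters.default()
    # which processes will run, and the datasets they read
    print(params.targets["processes"])
    print(params.targets["prepost"]["merger"])  # /extraction/general/None/area

    pp = PostProcessor("experiment.h5", params)
    pp.run()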
""" def __init__( self, - targets={}, - param_sets={}, - outpaths={}, + targets: t.Dict = {}, + param_sets: t.Dict = {}, + outpaths: t.Dict = {}, ): - self.targets: Dict = targets - self.param_sets: Dict = param_sets - self.outpaths: Dict = outpaths + self.targets = targets + self.param_sets = param_sets + self.outpaths = outpaths def __getitem__(self, item): return getattr(self, item) @classmethod def default(cls, kind=[]): - """Sequential postprocesses to be operated""" + """ + Include buddings and bud_metric and estimates of their time derivatives. + + Parameters + ---------- + kind: list of str + If "ph_batman" included, add targets for experiments using pHlourin. + """ + # each subitem specifies the function to be called and the location + # on the h5 file to be written targets = { "prepost": { "merger": "/extraction/general/None/area", @@ -61,9 +71,7 @@ class PostProcessorParameters(ParametersABC): "processes": [ [ "buddings", - [ - "/extraction/general/None/volume", - ], + ["/extraction/general/None/volume"], ], [ "dsignal", @@ -93,7 +101,7 @@ class PostProcessorParameters(ParametersABC): } outpaths = {} outpaths["aggregate"] = "/postprocessing/experiment_wide/aggregated/" - + # pHlourin experiments are special if "ph_batman" in kind: targets["processes"]["dsignal"].append( [ @@ -115,46 +123,56 @@ class PostProcessorParameters(ParametersABC): ] ], ) - return cls(targets=targets, param_sets=param_sets, outpaths=outpaths) class PostProcessor(ProcessABC): def __init__(self, filename, parameters): + """ + Initialise PostProcessor + + Parameters + ---------- + filename: str or PosixPath + Name of h5 file. + parameters: PostProcessorParameters object + An instance of PostProcessorParameters. + """ super().__init__(parameters) self._filename = filename self._signal = Signal(filename) self._writer = Writer(filename) - + # parameters for merger and picker dicted_params = { i: parameters["param_sets"]["prepost"][i] for i in ["merger", "picker"] } - for k in dicted_params.keys(): if not isinstance(dicted_params[k], dict): dicted_params[k] = dicted_params[k].to_dict() - + # merger and picker self.merger = Merger( MergerParameters.from_dict(dicted_params["merger"]) ) - self.picker = Picker( PickerParameters.from_dict(dicted_params["picker"]), cells=Cells.from_source(filename), ) + # processes, such as buddings self.classfun = { process: get_process(process) for process, _ in parameters["targets"]["processes"] } + # parameters for the process in classfun self.parameters_classfun = { process: get_parameters(process) for process, _ in parameters["targets"]["processes"] } + # locations to be written in the h5 file self.targets = parameters["targets"] def run_prepost(self): - # TODO Split function + """Using picker, get and write lineages, returning mothers and daughters.""" """Important processes run before normal post-processing ones""" record = self._signal.get_raw(self.targets["prepost"]["merger"]) merges = np.array(self.merger.run(record), dtype=int) @@ -175,7 +193,6 @@ class PostProcessor(ProcessABC): self.lineage = _3d_index_to_2d( lineage_merged if len(lineage_merged) else lineage ) - self._writer.write( "modifiers/lineage_merged", _3d_index_to_2d(lineage_merged) ) @@ -204,91 +221,98 @@ class PostProcessor(ProcessABC): return x def run(self): - # TODO Documentation :) + Split + """ + Write the results to the h5 file. + Processes include identifying buddings and finding bud metrics. 
@@ -204,91 +221,98 @@ class PostProcessor(ProcessABC):
         return x
 
     def run(self):
-        # TODO Documentation :) + Split
+        """
+        Write the results to the h5 file.
+
+        Processes include identifying buddings and finding bud metrics.
+        """
+        # run merger, picker, and find lineages
         self.run_prepost()
-
+        # run processes
         for process, datasets in tqdm(self.targets["processes"]):
-            if process in self.parameters["param_sets"].get(
-                "processes", {}
-            ):  # If we assigned parameters
+            if process in self.parameters["param_sets"].get("processes", {}):
+                # parameters already assigned
                 parameters = self.parameters_classfun[process](
                     self.parameters[process]
                 )
             else:
+                # assign default parameters
                 parameters = self.parameters_classfun[process].default()
-
+            # load process
             loaded_process = self.classfun[process](parameters)
             if isinstance(parameters, LineageProcessParameters):
                 loaded_process.lineage = self.lineage
+            # apply process to each dataset
             for dataset in datasets:
-                if isinstance(dataset, list):  # multisignal process
-                    signal = [self._signal[d] for d in dataset]
-                elif isinstance(dataset, str):
-                    signal = self._signal[dataset]
-                else:
-                    raise Exception("Unavailable record")
-
-                if len(signal) and (
-                    not isinstance(loaded_process, LineageProcess)
-                    or len(loaded_process.lineage)
-                ):
-                    result = loaded_process.run(signal)
-                else:
-                    result = pd.DataFrame(
-                        [], columns=signal.columns, index=signal.index
-                    )
-                    result.columns.names = ["timepoint"]
-
-                if process in self.parameters["outpaths"]:
-                    outpath = self.parameters["outpaths"][process]
-                elif isinstance(dataset, list):
-                    # If no outpath defined, place the result in the minimum common
-                    # branch of all signals used
-                    prefix = "".join(
-                        c[0]
-                        for c in takewhile(
-                            lambda x: all(x[0] == y for y in x), zip(*dataset)
-                        )
-                    )
-                    outpath = (
-                        prefix
-                        + "_".join(  # TODO check that it always finishes in '/'
-                            [
-                                d[len(prefix) :].replace("/", "_")
-                                for d in dataset
-                            ]
-                        )
-                    )
-                elif isinstance(dataset, str):
-                    outpath = dataset[1:].replace("/", "_")
-                else:
-                    raise ("Outpath not defined", type(dataset))
-
-                if process not in self.parameters["outpaths"]:
-                    outpath = "/postprocessing/" + process + "/" + outpath
-
-                if isinstance(result, dict):  # Multiple Signals as output
-                    for k, v in result.items():
-                        self.write_result(
-                            outpath + f"/{k}",
-                            v,
-                            metadata={},
-                        )
-                else:
-                    self.write_result(
-                        outpath,
-                        result,
-                        metadata={},
-                    )
+                self.run_process(dataset, process, loaded_process)
+
+    def run_process(self, dataset, process, loaded_process):
+        """Run process on a single dataset and write the result."""
+        # define signal
+        if isinstance(dataset, list):
+            # multisignal process
+            signal = [self._signal[d] for d in dataset]
+        elif isinstance(dataset, str):
+            signal = self._signal[dataset]
+        else:
+            raise ValueError("Incorrect dataset")
+        # run process on signal
+        if len(signal) and (
+            not isinstance(loaded_process, LineageProcess)
+            or len(loaded_process.lineage)
+        ):
+            result = loaded_process.run(signal)
+        else:
+            result = pd.DataFrame(
+                [], columns=signal.columns, index=signal.index
+            )
+            result.columns.names = ["timepoint"]
+        # define outpath, where result will be written
+        if process in self.parameters["outpaths"]:
+            outpath = self.parameters["outpaths"][process]
+        elif isinstance(dataset, list):
+            # no outpath is defined
+            # place the result in the minimum common branch of all signals
+            prefix = "".join(
+                c[0]
+                for c in takewhile(
+                    lambda x: all(x[0] == y for y in x), zip(*dataset)
+                )
+            )
+            outpath = (
+                prefix
+                + "_".join(  # TODO check that it always finishes in '/'
+                    [d[len(prefix) :].replace("/", "_") for d in dataset]
+                )
+            )
+        elif isinstance(dataset, str):
+            outpath = dataset[1:].replace("/", "_")
+        else:
+            raise ValueError(f"Outpath not defined for type {type(dataset)}")
+        # add postprocessing to outpath when required
+        if process not in self.parameters["outpaths"]:
+            outpath = "/postprocessing/" + process + "/" + outpath
+        # write result
+        if isinstance(result, dict):
+            # multiple Signals as output
+            for k, v in result.items():
+                self.write_result(
+                    outpath + f"/{k}",
+                    v,
+                    metadata={},
+                )
+        else:
+            # a single Signal as output
+            self.write_result(
+                outpath,
+                result,
+                metadata={},
+            )
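The minimum-common-branch rule in run_process can be seen in isolation; with
two hypothetical signal paths it keeps the longest common prefix and joins the
distinguishing suffixes with underscores:

    from itertools import takewhile

    dataset = [
        "/extraction/GFP/np_max/median",
        "/extraction/GFP/np_max/volume",
    ]
    # longest common prefix, character by character
    prefix = "".join(
        c[0]
        for c in takewhile(lambda x: all(x[0] == y for y in x), zip(*dataset))
    )
    # distinguishing suffixes, joined by underscores
    outpath = prefix + "_".join(
        d[len(prefix):].replace("/", "_") for d in dataset
    )
    print(outpath)  # /extraction/GFP/np_max/median_volume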
 
     def write_result(
         self,
         path: str,
-        result: Union[List, pd.DataFrame, np.ndarray],
-        metadata: Dict,
+        result: t.Union[t.List, pd.DataFrame, np.ndarray],
+        metadata: t.Dict,
     ):
-        if not result.any().any():
-            logging.getLogger("aliby").warning(f"Record {path} is empty")
         self._writer.write(path, result, meta=metadata, overwrite="overwrite")

diff --git a/src/postprocessor/core/reshapers/merger.py b/src/postprocessor/core/reshapers/merger.py
index 67c14b27..e18fbf7b 100644
--- a/src/postprocessor/core/reshapers/merger.py
+++ b/src/postprocessor/core/reshapers/merger.py
@@ -6,11 +6,21 @@ from postprocessor.core.functions.tracks import get_joinable
 
 class MergerParameters(ParametersABC):
     """
-    :param tol: float or int threshold of average (prediction error/std) necessary
-        to consider two tracks the same. If float is fraction of first track,
-        if int it is absolute units.
-    :param window: int value of window used for savgol_filter
-    :param degree: int value of polynomial degree passed to savgol_filter
+    Define the parameters for merger from a dict.
+
+    The following parameters are expected in the dict:
+
+    smooth: boolean
+        Whether or not to smooth with a savgol_filter.
+    tolerance: float or int
+        The threshold of average prediction error/std necessary to
+        consider two tracks the same.
+        If float, the threshold is the fraction of the first track;
+        if int, the threshold is in absolute units.
+    window: int
+        The size of the window of the savgol_filter.
+    degree: int
+        The order of the polynomial used by the savgol_filter.
     """
 
     _defaults = {
@@ -23,9 +33,7 @@
 
 
 class Merger(PostProcessABC):
-    """
-    Combines rows of tracklet that are likely to be the same.
-    """
+    """Combine rows of tracklets that are likely to be the same."""
 
     def __init__(self, parameters):
         super().__init__(parameters)
@@ -33,5 +41,11 @@
     def run(self, signal):
         joinable = []
         if signal.shape[1] > 4:
-            joinable = get_joinable(signal, tol=self.parameters.tolerance)
+            joinable = get_joinable(
+                signal,
+                smooth=self.parameters.smooth,
+                tol=self.parameters.tolerance,
+                window=self.parameters.window,
+                degree=self.parameters.degree,
+            )
         return joinable
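A hypothetical Merger run on a toy signal, assuming MergerParameters.default()
builds an instance from _defaults in the way the other parameter classes in
this patch do (the data and index values are illustrative):

    import numpy as np
    import pandas as pd
    from postprocessor.core.reshapers.merger import Merger, MergerParameters

    # two tracks in the same trap that plausibly form one track
    signal = pd.DataFrame(
        [
            [1.0, 1.1, 1.2, np.nan, np.nan, np.nan],
            [np.nan, np.nan, np.nan, 1.3, 1.4, 1.5],
        ],
        index=pd.MultiIndex.from_tuples(
            [(1, 1), (1, 2)], names=["trap", "cell"]
        ),
    )
    merger = Merger(MergerParameters.default())
    # pairs of track indices judged to belong to the same cell;
    # always empty when the signal has four or fewer timepoints
    joinable = merger.run(signal)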
""" @@ -35,7 +35,6 @@ class Picker(LineageProcess): cells: Cells or None = None, ): super().__init__(parameters=parameters) - self.cells = cells def pick_by_lineage( @@ -44,13 +43,9 @@ class Picker(LineageProcess): how: str, mothers_daughters: t.Optional[np.ndarray] = None, ) -> pd.MultiIndex: - cells_present = drop_mother_label(signal.index) - mothers_daughters = self.get_lineage_information(signal) - valid_indices = slice(None) - if how == "mothers": _, valid_indices = validate_association( mothers_daughters, cells_present, match_column=0 @@ -63,7 +58,6 @@ class Picker(LineageProcess): _, valid_indices = validate_association( mothers_daughters, cells_present ) - return signal.index[valid_indices] def pick_by_condition(self, signal, condition, thresh): @@ -73,13 +67,10 @@ class Picker(LineageProcess): def run(self, signal): self.orig_signal = signal indices = set(signal.index) - lineage = self.get_lineage_information(signal) - if len(lineage): self.mothers = lineage[:, :2] self.daughters = lineage[:, [0, 2]] - for alg, *params in self.sequence: new_indices = tuple() if indices: @@ -94,13 +85,11 @@ class Picker(LineageProcess): signal.loc[list(indices)], param1, param2 ) new_indices = [tuple(x) for x in new_indices] - indices = indices.intersection(new_indices) else: self._log(f"No lineage assignment") indices = np.array([]) - - return np.array([tuple(map(_str_to_int, x)) for x in indices]).T + return np.array([tuple(map(_str_to_int, x)) for x in indices]) def switch_case( self, @@ -128,7 +117,7 @@ def _as_int(threshold: t.Union[float, int], ntps: int): def any_present(signal, threshold): """ - Returns a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints. + Return a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints. """ any_present = pd.Series( np.sum( -- GitLab