Commit 5b885b36 authored by Alan Muñoz

Merge branch 'merge_pp_dev' into 'dev'

Merge pp dev

See merge request aliby/aliby!17
parents 4e096148 a1d7372b
@@ -18,7 +18,7 @@ from postprocessor.core.processes.savgol import non_uniform_savgol
def load_test_dset():
- # Load development dataset to test functions
+ """Load development dataset to test functions."""
return pd.DataFrame(
{
("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
@@ -31,7 +31,7 @@ def load_test_dset():
def max_ntps(track: pd.Series) -> int:
- # Get number of timepoints
+ """Get number of time points."""
indices = np.where(track.notna())
return np.max(indices) - np.min(indices)
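For a quick sanity check (hypothetical series), note that the value returned is the index span between the first and last observations, not the count of non-NaN time points:

import numpy as np
import pandas as pd

track = pd.Series([np.nan, 2.0, 3.0, np.nan, 5.0, np.nan])
# notna() indices are [1, 2, 4], so the span is 4 - 1 = 3
max_ntps(track)  # 3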
@@ -84,9 +84,7 @@ def clean_tracks(
"""
ntps = get_tracks_ntps(tracks)
grs = get_avg_grs(tracks)
growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]
return growing_long_tracks
@@ -111,7 +109,6 @@ def merge_tracks(
joinable_pairs = get_joinable(tracks, **kwargs)
if joinable_pairs:
tracks = join_tracks(tracks, joinable_pairs, drop=drop)
return (tracks, joinable_pairs)
@@ -135,28 +132,23 @@ def get_joint_ids(merging_seqs) -> dict:
2 ab cd
3 abcd
- We shold get:
+ We should get:
output {a:a, b:a, c:a, d:a}
"""
if not merging_seqs:
return {}
targets, origins = list(zip(*merging_seqs))
static_tracks = set(targets).difference(origins)
joint = {track_id: track_id for track_id in static_tracks}
for target, origin in merging_seqs:
joint[origin] = target
moved_target = [
k for k, v in joint.items() if joint[v] != v and v in joint.values()
]
for orig in moved_target:
joint[orig] = rec_bottom(joint, orig)
return {
k: v for k, v in joint.items() if k != v
} # remove ids that point to themselves
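A minimal sketch of the docstring's example, assuming each pair is (target, origin) as unpacked above; note that the returned dict drops ids that map to themselves, so "a" itself is absent:

# merges b -> a, d -> c, then c -> a collapse b, c, and d onto a
merging_seqs = [("a", "b"), ("c", "d"), ("a", "c")]
get_joint_ids(merging_seqs)  # {"b": "a", "c": "a", "d": "a"}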
@@ -184,14 +176,11 @@ def join_tracks(tracks, joinable_pairs, drop=True) -> pd.DataFrame:
:param drop: bool indicating whether or not to drop moved rows
"""
tmp = copy(tracks)
for target, source in joinable_pairs:
tmp.loc[target] = join_track_pair(tmp.loc[target], tmp.loc[source])
if drop:
tmp = tmp.drop(source)
return tmp
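For illustration, with hypothetical ids in the style of load_test_dset above: a joinable pair is (target, source), and the source row is folded into the target and dropped by default:

# assumes `tracks` contains rows indexed ("a", 1, 1) and ("a", 1, 2)
joined = join_tracks(tracks, [(("a", 1, 1), ("a", 1, 2))], drop=True)
# joined.loc[("a", 1, 1)] now carries both tracks' values and
# row ("a", 1, 2) has been dropped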
@@ -206,7 +195,7 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
"""
Get the pairs of tracks (without repeats) that have a smaller error than the
tolerance. If there is a track that can be assigned to two or more other
- ones, it chooses the one with a lowest error.
+ ones, choose the one with lowest error.
:param tracks: (m x n) dataframe where rows are cell tracks and
columns are timepoints
@@ -324,7 +313,6 @@ def get_means(x, i):
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return np.nanmean(v)
@@ -335,12 +323,12 @@ def get_last_i(x, i):
v = x[~np.isnan(x)][:i]
else:
v = x[~np.isnan(x)][i:]
return v
def localid_to_idx(local_ids, contig_trap):
- """Fetch then original ids from a nested list with joinable local_ids
+ """
+ Fetch the original ids from a nested list with joinable local_ids.
input
:param local_ids: list of lists of pairs with cell ids to be joined
@@ -392,7 +380,6 @@ def get_dMetric(
dMetric = np.abs(np.subtract.outer(post, pre))
else:
dMetric = np.abs(np.subtract.outer(pre, post))
dMetric[np.isnan(dMetric)] = (
tol + 1 + np.nanmax(dMetric)
) # nans will be filtered
@@ -404,28 +391,26 @@ def solve_matrices(
):
"""
Solve the distance matrices obtained in get_dMetric and/or merged from
- independent dMetric matrices
+ independent dMetric matrices.
"""
ids = solve_matrix(dMetric)
if not len(ids[0]):
return []
pre, post = prepost
norm = (
np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
) # relative or absolute tol
result = dMetric[ids] / norm
ids = ids if len(pre) < len(post) else ids[::-1]
return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
def get_closest_pairs(
pre: List[float], post: List[float], tol: Union[float, int] = 1
):
"""Calculate a cost matrix the Hungarian algorithm to pick the best set of
options
"""
Calculate a cost matrix for the Hungarian algorithm to pick the best set of
options.
input
:param pre: list of floats with edges on left
@@ -437,7 +422,6 @@ def get_closest_pairs(
"""
dMetric = get_dMetric(pre, post, tol)
return solve_matrices(dMetric, pre, post, tol)
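A rough sketch with hypothetical edge values; since tol >= 1 here, the comparison is in absolute units, so only pairs within tol of each other survive:

pre, post = [10.0, 20.0], [10.5, 19.0, 50.0]
# |10.0 - 10.5| = 0.5 and |20.0 - 19.0| = 1.0 both pass tol=2;
# 50.0 is at least 30 away from every pre edge and stays unpaired
get_closest_pairs(pre, post, tol=2)  # pairs indices 0-0 and 1-1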
@@ -465,9 +449,7 @@ def solve_matrix(dMetric):
tmp[:, j] += np.nan
glob_is.append(i)
glob_js.append(j)
std = sorted(tmp[~np.isnan(tmp)])
return (np.array(glob_is), np.array(glob_js))
@@ -475,7 +457,6 @@ def plot_joinable(tracks, joinable_pairs):
"""
Convenience plotting function for debugging and data visualisation
"""
nx = 8
ny = 8
_, axes = plt.subplots(nx, ny)
@@ -493,7 +474,6 @@ def plot_joinable(tracks, joinable_pairs):
# except:
# pass
ax.plot(post_srs.index, post_srs.values, "g")
plt.show()
@@ -506,21 +486,17 @@ def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
:param min_dgr: float minimum difference in growth rate from
the interpolation
"""
mins, maxes = [
tracks.notna().apply(np.where, axis=1).apply(fn)
for fn in (np.min, np.max)
]
mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
mins_d.index = mins_d.index - 1 # make indices equal
# TODO add support for skipping time points
maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist())
common = sorted(
set(mins_d.index).intersection(maxes_d.index), reverse=True
)
return [(maxes_d[t], mins_d[t]) for t in common]
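A minimal runnable sketch (hypothetical tracks): shifting the start indices by one makes a track ending at time point t line up with tracks beginning at t + 1:

import numpy as np
import pandas as pd

tracks = pd.DataFrame(
    [[1, 2, 3, 4] + [np.nan] * 4, [np.nan] * 4 + [5, 6, 7, 8]],
    index=["A", "B"],
)
# A ends at time point 3 and B begins at 4, giving one candidate pair
get_contiguous_pairs(tracks)  # [(['A'], ['B'])]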
......
import logging
import typing as t
from itertools import takewhile
from typing import Dict, List, Union
import numpy as np
import pandas as pd
@@ -16,7 +14,6 @@ from agora.utils.indexing import (
_assoc_indices_to_3d,
)
from agora.utils.merge import merge_association
- from agora.utils.kymograph import get_index_as_np
from postprocessor.core.abc import get_parameters, get_process
from postprocessor.core.lineageprocess import (
LineageProcess,
@@ -35,24 +32,37 @@ class PostProcessorParameters(ParametersABC):
- while objectives are relative or absolute paths to datasets. If relative paths the
- post-processed addresses are used. The order of processes matters.
+ Supply parameters for picker, merger, and processes.
+ The order of processes matters.
+ 'processes' are defined in ./processes/, while objectives are relative or absolute paths to datasets. If relative, the post-processed addresses are used.
"""
def __init__(
self,
- targets={},
- param_sets={},
- outpaths={},
+ targets: t.Dict = {},
+ param_sets: t.Dict = {},
+ outpaths: t.Dict = {},
):
- self.targets: Dict = targets
- self.param_sets: Dict = param_sets
- self.outpaths: Dict = outpaths
+ self.targets = targets
+ self.param_sets = param_sets
+ self.outpaths = outpaths
def __getitem__(self, item):
return getattr(self, item)
@classmethod
def default(cls, kind=[]):
"""Sequential postprocesses to be operated"""
"""
Include buddings and bud_metric and estimates of their time derivatives.
Parameters
----------
kind: list of str
If "ph_batman" included, add targets for experiments using pHlourin.
"""
# each subitem specifies the function to be called and the location
# on the h5 file to be written
targets = {
"prepost": {
"merger": "/extraction/general/None/area",
@@ -61,9 +71,7 @@ class PostProcessorParameters(ParametersABC):
"processes": [
[
"buddings",
- [
- "/extraction/general/None/volume",
- ],
+ ["/extraction/general/None/volume"],
],
[
"dsignal",
@@ -93,7 +101,7 @@
}
outpaths = {}
outpaths["aggregate"] = "/postprocessing/experiment_wide/aggregated/"
+ # pHluorin experiments are special
if "ph_batman" in kind:
targets["processes"]["dsignal"].append(
[
@@ -115,46 +123,56 @@
]
],
)
return cls(targets=targets, param_sets=param_sets, outpaths=outpaths)
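For example, the defaults can be inspected through __getitem__ as defined above (paths as listed in targets):

params = PostProcessorParameters.default()
params["targets"]["prepost"]["merger"]  # "/extraction/general/None/area"
# extend the targets for pHluorin experiments
ph_params = PostProcessorParameters.default(kind=["ph_batman"])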
class PostProcessor(ProcessABC):
def __init__(self, filename, parameters):
"""
Initialise PostProcessor
Parameters
----------
filename: str or PosixPath
Name of h5 file.
parameters: PostProcessorParameters object
An instance of PostProcessorParameters.
"""
super().__init__(parameters)
self._filename = filename
self._signal = Signal(filename)
self._writer = Writer(filename)
+ # parameters for merger and picker
dicted_params = {
i: parameters["param_sets"]["prepost"][i]
for i in ["merger", "picker"]
}
for k in dicted_params.keys():
if not isinstance(dicted_params[k], dict):
dicted_params[k] = dicted_params[k].to_dict()
+ # merger and picker
self.merger = Merger(
MergerParameters.from_dict(dicted_params["merger"])
)
self.picker = Picker(
PickerParameters.from_dict(dicted_params["picker"]),
cells=Cells.from_source(filename),
)
+ # processes, such as buddings
self.classfun = {
process: get_process(process)
for process, _ in parameters["targets"]["processes"]
}
+ # parameters for the process in classfun
self.parameters_classfun = {
process: get_parameters(process)
for process, _ in parameters["targets"]["processes"]
}
+ # locations to be written in the h5 file
self.targets = parameters["targets"]
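A hypothetical end-to-end driver, assuming experiment.h5 already holds the /extraction datasets written by the upstream pipeline:

params = PostProcessorParameters.default()
pp = PostProcessor("experiment.h5", params)
pp.run()  # run_prepost, then every process in targets["processes"]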
def run_prepost(self):
- # TODO Split function
- """Important processes run before normal post-processing ones"""
+ """Using picker, get and write lineages, returning mothers and daughters."""
record = self._signal.get_raw(self.targets["prepost"]["merger"])
merges = np.array(self.merger.run(record), dtype=int)
@@ -175,7 +193,6 @@
self.lineage = _3d_index_to_2d(
lineage_merged if len(lineage_merged) else lineage
)
self._writer.write(
"modifiers/lineage_merged", _3d_index_to_2d(lineage_merged)
)
@@ -204,91 +221,98 @@
return x
def run(self):
- # TODO Documentation :) + Split
+ """
+ Write the results to the h5 file.
+
+ Processes include identifying buddings and finding bud metrics.
+ """
+ # run merger, picker, and find lineages
self.run_prepost()
+ # run processes
for process, datasets in tqdm(self.targets["processes"]):
- if process in self.parameters["param_sets"].get(
- "processes", {}
- ): # If we assigned parameters
+ if process in self.parameters["param_sets"].get("processes", {}):
+ # parameters already assigned
parameters = self.parameters_classfun[process](
self.parameters[process]
)
else:
+ # assign parameters
parameters = self.parameters_classfun[process].default()
+ # load process
loaded_process = self.classfun[process](parameters)
if isinstance(parameters, LineageProcessParameters):
loaded_process.lineage = self.lineage
+ # apply process to each dataset
for dataset in datasets:
- if isinstance(dataset, list): # multisignal process
- signal = [self._signal[d] for d in dataset]
- elif isinstance(dataset, str):
- signal = self._signal[dataset]
- else:
- raise Exception("Unavailable record")
- if len(signal) and (
- not isinstance(loaded_process, LineageProcess)
- or len(loaded_process.lineage)
- ):
- result = loaded_process.run(signal)
- else:
- result = pd.DataFrame(
- [], columns=signal.columns, index=signal.index
- )
- result.columns.names = ["timepoint"]
- if process in self.parameters["outpaths"]:
- outpath = self.parameters["outpaths"][process]
- elif isinstance(dataset, list):
- # If no outpath defined, place the result in the minimum common
- # branch of all signals used
- prefix = "".join(
- c[0]
- for c in takewhile(
- lambda x: all(x[0] == y for y in x), zip(*dataset)
- )
- )
- outpath = (
- prefix
- + "_".join( # TODO check that it always finishes in '/'
- [
- d[len(prefix) :].replace("/", "_")
- for d in dataset
- ]
- )
- )
- elif isinstance(dataset, str):
- outpath = dataset[1:].replace("/", "_")
- else:
- raise ("Outpath not defined", type(dataset))
- if process not in self.parameters["outpaths"]:
- outpath = "/postprocessing/" + process + "/" + outpath
- if isinstance(result, dict): # Multiple Signals as output
- for k, v in result.items():
- self.write_result(
- outpath + f"/{k}",
- v,
- metadata={},
- )
- else:
- self.write_result(
- outpath,
- result,
- metadata={},
- )
+ self.run_process(dataset, process, loaded_process)
+ def run_process(self, dataset, process, loaded_process):
+ """Run process on a single dataset and write the result."""
+ # define signal
+ if isinstance(dataset, list):
+ # multisignal process
+ signal = [self._signal[d] for d in dataset]
+ elif isinstance(dataset, str):
+ signal = self._signal[dataset]
+ else:
+ raise ValueError("Incorrect dataset")
+ # run process on signal
+ if len(signal) and (
+ not isinstance(loaded_process, LineageProcess)
+ or len(loaded_process.lineage)
+ ):
+ result = loaded_process.run(signal)
+ else:
+ result = pd.DataFrame(
+ [], columns=signal.columns, index=signal.index
+ )
+ result.columns.names = ["timepoint"]
+ # define outpath, where result will be written
+ if process in self.parameters["outpaths"]:
+ outpath = self.parameters["outpaths"][process]
+ elif isinstance(dataset, list):
+ # no outpath is defined
+ # place the result in the minimum common branch of all signals
+ prefix = "".join(
+ c[0]
+ for c in takewhile(
+ lambda x: all(x[0] == y for y in x), zip(*dataset)
+ )
+ )
+ outpath = (
+ prefix
+ + "_".join( # TODO check that it always finishes in '/'
+ [d[len(prefix) :].replace("/", "_") for d in dataset]
+ )
+ )
+ elif isinstance(dataset, str):
+ outpath = dataset[1:].replace("/", "_")
+ else:
+ raise ValueError("Outpath not defined", type(dataset))
+ # add postprocessing to outpath when required
+ if process not in self.parameters["outpaths"]:
+ outpath = "/postprocessing/" + process + "/" + outpath
+ # write result
+ if isinstance(result, dict):
+ # multiple Signals as output
+ for k, v in result.items():
+ self.write_result(
+ outpath + f"/{k}",
+ v,
+ metadata={},
+ )
+ else:
+ # a single Signal as output
+ self.write_result(
+ outpath,
+ result,
+ metadata={},
+ )
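The minimum-common-branch rule can be checked in isolation; with two hypothetical sibling paths, the character-wise takewhile yields the shared prefix:

from itertools import takewhile

dataset = ["/extraction/general/None/volume", "/extraction/general/None/area"]
prefix = "".join(
    c[0] for c in takewhile(lambda x: all(x[0] == y for y in x), zip(*dataset))
)
# prefix == "/extraction/general/None/"
outpath = prefix + "_".join(d[len(prefix):].replace("/", "_") for d in dataset)
# outpath == "/extraction/general/None/volume_area"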
def write_result(
self,
path: str,
- result: Union[List, pd.DataFrame, np.ndarray],
- metadata: Dict,
+ result: t.Union[t.List, pd.DataFrame, np.ndarray],
+ metadata: t.Dict,
):
if not result.any().any():
logging.getLogger("aliby").warning(f"Record {path} is empty")
self._writer.write(path, result, meta=metadata, overwrite="overwrite")
@@ -6,11 +6,21 @@ from postprocessor.core.functions.tracks import get_joinable
class MergerParameters(ParametersABC):
"""
- :param tol: float or int threshold of average (prediction error/std) necessary
- to consider two tracks the same. If float is fraction of first track,
- if int it is absolute units.
- :param window: int value of window used for savgol_filter
- :param degree: int value of polynomial degree passed to savgol_filter
+ Define the parameters for the merger from a dict.
+
+ There are five parameters expected in the dict:
+
+ smooth: boolean
+ Whether or not to smooth with a savgol_filter.
+ tol: float or int
+ The threshold of average prediction error/std necessary to
+ consider two tracks the same.
+ If float, the threshold is a fraction of the first track;
+ if int, the threshold is in absolute units.
+ window: int
+ The size of the window of the savgol_filter.
+ degree: int
+ The order of the polynomial used by the savgol_filter.
"""
_defaults = {
@@ -23,9 +33,7 @@
class Merger(PostProcessABC):
"""
Combines rows of tracklet that are likely to be the same.
"""
"""Combine rows of tracklet that are likely to be the same."""
def __init__(self, parameters):
super().__init__(parameters)
@@ -33,5 +41,11 @@
def run(self, signal):
joinable = []
if signal.shape[1] > 4:
- joinable = get_joinable(signal, tol=self.parameters.tolerance)
+ joinable = get_joinable(
+ signal,
+ smooth=self.parameters.smooth,
+ tol=self.parameters.tolerance,
+ window=self.parameters.window,
+ degree=self.parameters.degree,
+ )
return joinable
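A minimal usage sketch, with the parameter names read in run above and hypothetical values; from_dict mirrors its use in PostProcessor.__init__:

merger = Merger(
    MergerParameters.from_dict(
        {"smooth": False, "tolerance": 0.2, "window": 5, "degree": 3}
    )
)
# `signal` is a (tracks x timepoints) DataFrame from Signal
joinable_pairs = merger.run(signal)  # empty unless > 4 timepoints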
@@ -25,7 +25,7 @@ class Picker(LineageProcess):
:cells: Cell object passed to the constructor
:condition: Tuple with condition and associated parameter(s), conditions can be
"present", "nonstoply_present" or "quantile".
- Determines the thersholds or fractions of signals to use.
+ Determine the thresholds or fractions of signals to use.
:lineage: str {"mothers", "daughters", "families" (mothers AND daughters), "orphans"}. Mothers/daughters picks cells with those tags, families pick the union of both and orphans the difference between the total and families.
"""
@@ -35,7 +35,6 @@
cells: Cells or None = None,
):
super().__init__(parameters=parameters)
self.cells = cells
def pick_by_lineage(
@@ -44,13 +43,9 @@
how: str,
mothers_daughters: t.Optional[np.ndarray] = None,
) -> pd.MultiIndex:
cells_present = drop_mother_label(signal.index)
mothers_daughters = self.get_lineage_information(signal)
valid_indices = slice(None)
if how == "mothers":
_, valid_indices = validate_association(
mothers_daughters, cells_present, match_column=0
@@ -63,7 +58,6 @@
_, valid_indices = validate_association(
mothers_daughters, cells_present
)
return signal.index[valid_indices]
def pick_by_condition(self, signal, condition, thresh):
@@ -73,13 +67,10 @@
def run(self, signal):
self.orig_signal = signal
indices = set(signal.index)
lineage = self.get_lineage_information(signal)
if len(lineage):
self.mothers = lineage[:, :2]
self.daughters = lineage[:, [0, 2]]
for alg, *params in self.sequence:
new_indices = tuple()
if indices:
@@ -94,13 +85,11 @@
signal.loc[list(indices)], param1, param2
)
new_indices = [tuple(x) for x in new_indices]
indices = indices.intersection(new_indices)
else:
self._log(f"No lineage assignment")
indices = np.array([])
- return np.array([tuple(map(_str_to_int, x)) for x in indices]).T
+ return np.array([tuple(map(_str_to_int, x)) for x in indices])
def switch_case(
self,
@@ -128,7 +117,7 @@ def _as_int(threshold: t.Union[float, int], ntps: int):
def any_present(signal, threshold):
"""
- Returns a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints.
+ Return a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints.
"""
any_present = pd.Series(
np.sum(
......
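The body of any_present is truncated above; as a sketch of the described semantics, assuming the first index level of signal identifies the trap:

def any_present_sketch(signal: pd.DataFrame, threshold: int) -> pd.Series:
    # True for every row of a trap if any cell in that trap has more
    # than `threshold` non-NaN time points
    long_enough = signal.notna().sum(axis=1) > threshold
    return long_enough.groupby(level=0).transform("any")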