Skip to content
Snippets Groups Projects
Commit 9ce33256 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

increase prepost robustness

Former-commit-id: b093f9bd7365a9a494147f92b9c8d68cf9be09b6
parent e7f52582
No related branches found
No related tags found
No related merge requests found
......@@ -13,17 +13,17 @@ class mergerParameters(ParametersABC):
def __init__(
self,
tolerance: float,
smooth: bool = False,
tolerance: float = 0.1,
window: int = 5,
degree: int = 3,
min_avg_delta: float = 0.9,
):
self.smooth = smooth
self.tolerance = tolerance
self.smooth = smooth
self.window = window
self.degree = degree
......@@ -35,7 +35,7 @@ class mergerParameters(ParametersABC):
return cls.from_dict(
{
"smooth": False,
"tolerance": 0.1,
"tolerance": 0.2,
"window": 5,
"degree": 3,
"min_avg_delta": 0.9,
......@@ -52,7 +52,7 @@ class merger(ProcessABC):
super().__init__(parameters)
def run(self, signal):
joinable = get_joinable(signal)
joinable = get_joinable(signal, tol=self.parameters.tolerance)
# merged, _ = merge_tracks(signal) # , min_len=self.window + 1)
# indices = (*zip(*merged.index.tolist()),)
# names = merged.index.names
......
......@@ -2,6 +2,7 @@ from typing import Tuple, Union, List
from abc import ABC, abstractmethod
from itertools import groupby
from utils_find_1st import find_1st, cmp_equal
import numpy as np
import pandas as pd
......@@ -24,9 +25,11 @@ class pickerParameters(ParametersABC):
return cls.from_dict(
{
"sequence": [
("condition", "present", 0.8),
("lineage", "families", "union"),
("condition", "present", 10),
# ("lineage", "intersection", "families"),
("condition", "intersection", "any_present", 0.8),
("condition", "intersection", "growing", 50),
("condition", "intersection", "present", 10),
# ("lineage", "full_families", "intersection"),
],
}
)
......@@ -114,6 +117,8 @@ class picker(ProcessABC):
for x in np.array(self.daughters)[search(self.mothers, m)]
]
)
print("associated daughters: ", idx)
elif how == "mothers_w_daughters":
present_daughters = idx.intersection(daughters)
idx = set(
......@@ -189,18 +194,21 @@ class picker(ProcessABC):
def run(self, signals):
indices = set(signals.index)
self.mothers, self.daughters = self.get_mothers_daughters()
for alg, param1, param2 in self.sequence:
for alg, op, *params in self.sequence:
if alg is "lineage":
param1 = params[0]
new_indices = getattr(self, "pick_by_" + alg)(signals, param1)
if param2 is "union":
new_indices = new_indices.intersection(set(signals.index))
indices = indices.union(set(new_indices))
else:
param1, param2 = params
new_indices = getattr(self, "pick_by_" + alg)(signals, param1, param2)
if op is "union":
new_indices = new_indices.intersection(set(signals.index))
new_indices = indices.union(set(new_indices))
indices = indices.intersection(new_indices)
mothers, daughters = self.mothers, self.daughters
return np.array(mothers), np.array(daughters), np.array(list(indices))
return np.array(list(indices))
@staticmethod
def switch_case(
......@@ -212,13 +220,39 @@ class picker(ProcessABC):
if isinstance(threshold, list):
thresh_presence = threshold[0]
case_mgr = {
"present": signals.notna().sum(axis=1) > threshold_asint,
"nonstoply_present": signals.apply(max_nonstop_ntps, axis=1)
"any_present": lambda s, thresh: any_present(s, threshold_asint),
"present": lambda s, thresh: signals.notna().sum(axis=1) > threshold_asint,
"nonstoply_present": lambda s, thresh: signals.apply(
max_nonstop_ntps, axis=1
)
> threshold_asint,
"growing": signals.diff(axis=1).sum(axis=1) > threshold,
"growing": lambda s, thresh: signals.diff(axis=1).sum(axis=1) > threshold,
# "quantile": [np.quantile(signals.values[signals.notna()], threshold)],
}
return set(signals.index[case_mgr[condition]])
return set(signals.index[case_mgr[condition](signals, threshold)])
from copy import copy
def any_present(signals, threshold):
"""
Returns a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints.
"""
any_present = pd.Series(
np.sum(
[
np.isin([x[0] for x in signals.index], i) & v
for i, v in (signals.notna().sum(axis=1) > threshold)
.groupby("trap")
.any()
.items()
],
axis=0,
).astype(bool),
index=signals.index,
)
return any_present
def _as_int(threshold: Union[float, int], ntps: int):
......
......@@ -122,9 +122,13 @@ class PostProcessor:
if "modifiers/picks" in f:
del f["modifiers/picks"]
mothers, daughters, indices = self.picker.run(
self._signal[self.targets["prepost"]["picker"][0]]
indices = self.picker.run(self._signal[self.targets["prepost"]["picker"][0]])
from collections import Counter
mothers, daughters = np.array(self.picker.mothers), np.array(
self.picker.daughters
)
self.tmp = [y for y in Counter([tuple(x) for x in mothers]).items() if y[1] > 2]
self._writer.write(
"postprocessing/lineage",
data=pd.MultiIndex.from_arrays(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment