Skip to content
Snippets Groups Projects
Commit 12c08472 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

add postmerge mother-daughter picking

Former-commit-id: 7a0bd7540a0db12bbe0db452681d3e4bdff799e6
parent 25281873
No related branches found
No related tags found
No related merge requests found
...@@ -15,24 +15,19 @@ from postprocessor.core.functions.tracks import max_ntps, max_nonstop_ntps ...@@ -15,24 +15,19 @@ from postprocessor.core.functions.tracks import max_ntps, max_nonstop_ntps
class pickerParameters(ParametersABC): class pickerParameters(ParametersABC):
def __init__( def __init__(
self, self,
condition: Tuple[str, Union[float, int]] = None,
lineage: str = None,
lineage_conditional: str = None,
sequence: List[str] = ["lineage", "condition"], sequence: List[str] = ["lineage", "condition"],
): ):
self.condition = condition
self.lineage = lineage
self.lineage_conditional = lineage_conditional
self.sequence = sequence self.sequence = sequence
@classmethod @classmethod
def default(cls): def default(cls):
return cls.from_dict( return cls.from_dict(
{ {
"condition": ["present", 0.8], "sequence": [
"lineage": "families", ("condition", "present", 0.8),
"lineage_conditional": "include", ("lineage", "families", "union"),
"sequence": ["condition", "lineage"], ("condition", "present", 10),
],
} }
) )
...@@ -91,10 +86,10 @@ class picker(ProcessABC): ...@@ -91,10 +86,10 @@ class picker(ProcessABC):
return nested_massign return nested_massign
def pick_by_lineage(self, signals): def pick_by_lineage(self, signals, how):
idx = signals.index idx = signals.index
if self.lineage: if how:
ma = self._cells["mother_assign_dynamic"] ma = self._cells["mother_assign_dynamic"]
trap = self._cells["trap"] trap = self._cells["trap"]
label = self._cells["cell_label"] label = self._cells["cell_label"]
...@@ -124,38 +119,45 @@ class picker(ProcessABC): ...@@ -124,38 +119,45 @@ class picker(ProcessABC):
mothers = set(mothers) mothers = set(mothers)
daughters = set(daughters) daughters = set(daughters)
# daughters, mothers = np.where(mother_bud_mat) # daughters, mothers = np.where(mother_bud_mat)
if self.lineage == "mothers": if how == "mothers":
idx = mothers idx = mothers
elif self.lineage == "daughters": elif how == "daughters":
idx = daughters idx = daughters
elif self.lineage == "families" or self.lineage == "orphans": elif how == "families" or how == "orphans":
families = mothers.union(daughters) families = mothers.union(daughters)
if self.lineage == "families": if how == "families":
idx = families idx = families
elif self.lineage == "orphans": # orphans elif how == "orphans":
idx = idx.diference(families) idx = idx.diference(families)
idx = idx.intersection(signals.index) idx = idx.intersection(signals.index)
return idx return idx
def pick_by_condition(self, signals): def pick_by_condition(self, signals, condition, thresh):
idx = self.switch_case(self.condition[0], signals, self.condition[1]) idx = self.switch_case(signals, condition, thresh)
return idx return idx
def run(self, signals): def run(self, signals):
indices = set(signals.index) indices = set(signals.index)
daughters, mothers = (None, None) daughters, mothers = (None, None)
for alg in self.sequence: for alg, param1, param2 in self.sequence:
indices = getattr(self, "pick_by_" + alg)(signals) if alg is "lineage":
new_indices = getattr(self, "pick_by_" + alg)(signals, param1)
if param2 is "union":
new_indices = new_indices.intersection(set(signals.index))
indices = indices.union(set(new_indices))
else:
new_indices = getattr(self, "pick_by_" + alg)(signals, param1, param2)
indices = indices.intersection(new_indices)
daughters, mothers = self.daughters, self.mothers daughters, mothers = self.daughters, self.mothers
return np.array(daughters), np.array(mothers), np.array(list(indices)) return np.array(daughters), np.array(mothers), np.array(list(indices))
@staticmethod @staticmethod
def switch_case( def switch_case(
condition: str,
signals: pd.DataFrame, signals: pd.DataFrame,
condition: str,
threshold: Union[float, int], threshold: Union[float, int],
): ):
threshold_asint = _as_int(threshold, signals.shape[1]) threshold_asint = _as_int(threshold, signals.shape[1])
...@@ -163,9 +165,9 @@ class picker(ProcessABC): ...@@ -163,9 +165,9 @@ class picker(ProcessABC):
"present": signals.notna().sum(axis=1) > threshold_asint, "present": signals.notna().sum(axis=1) > threshold_asint,
"nonstoply_present": signals.apply(max_nonstop_ntps, axis=1) "nonstoply_present": signals.apply(max_nonstop_ntps, axis=1)
> threshold_asint, > threshold_asint,
"quantile": [np.quantile(signals.values[signals.notna()], threshold)], # "quantile": [np.quantile(signals.values[signals.notna()], threshold)],
} }
return set(case_mgr[condition].index) return set(signals.index[case_mgr[condition]])
def _as_int(threshold: Union[float, int], ntps: int): def _as_int(threshold: Union[float, int], ntps: int):
......
...@@ -143,11 +143,26 @@ class PostProcessor: ...@@ -143,11 +143,26 @@ class PostProcessor:
merged_moda = set([tuple(x) for x in merge_events[:, 0, :]]).intersection( merged_moda = set([tuple(x) for x in merge_events[:, 0, :]]).intersection(
set([*moset, *daset, *picked_set]) set([*moset, *daset, *picked_set])
) )
for source, target in merge_events: search = lambda a, b: np.where(
if tuple(source) in merged_moda: np.in1d(
mothers[np.isin(mothers, source).all(axis=1)] = target np.ravel_multi_index(a.T, a.max(0) + 1),
daughters[np.isin(daughters, source).all(axis=1)] = target np.ravel_multi_index(b.T, a.max(0) + 1),
indices[np.isin(indices, source).all(axis=1)] = target )
)
for target, source in merge_events:
if (
tuple(source) in moset
): # update mother to lowest positive index among the two
mother_ids = search(mothers, source)
mothers[mother_ids] = (
target[0],
self.pick_mother(mothers[mother_ids][0][1], target[1]),
)
if tuple(source) in daset:
daughters[search(daughters, source)] = target
if tuple(source) in picked_set:
indices[search(indices, source)] = target
self._writer.write( self._writer.write(
"postprocessing/lineage_merged", "postprocessing/lineage_merged",
...@@ -167,6 +182,17 @@ class PostProcessor: ...@@ -167,6 +182,17 @@ class PostProcessor:
overwrite="overwrite", overwrite="overwrite",
) )
@staticmethod
def pick_mother(a, b):
"""Update the mother id following this priorities:
The mother has a lower id
"""
x = max(a, b)
if min([a, b]):
x = [a, b][np.argmin([a, b])]
return x
def run(self): def run(self):
self.run_prepost() self.run_prepost()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment