Skip to content
Snippets Groups Projects
Commit 0c6d7660 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

refactor(post): relocate lineage processes

parent 0c34246c
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python3 import typing as t
from abc import abstractmethod
import numpy as np
import pandas as pd
from agora.abc import ParametersABC
from postprocessor.core.abc import PostProcessABC
# from agora.utils.lineage import group_matrix
class LineageProcessParameters(ParametersABC):
"""
Parameters
"""
_defaults = {}
class LineageProcess(PostProcessABC):
"""
Lineage process that must be passed a (N,3) lineage matrix (where the coliumns are trap, mother, daughter respectively)
"""
def __init__(self, parameters: LineageProcessParameters):
super().__init__(parameters)
def filter_signal_cells(
self, signal: pd.DataFrame, lineage: np.ndarray = None
):
"""
Use casting to filter cell ids in signal and lineage
"""
if lineage is None:
lineage = self.lineage
sig_ind = np.array(list(signal.index)).T[:, None, :]
mo_av = (
(lineage[:, :2].T[:, :, None] == sig_ind).all(axis=0).any(axis=1)
)
da_av = (
(lineage[:, [0, 2]].T[:, :, None] == sig_ind)
.all(axis=0)
.any(axis=1)
)
return lineage[mo_av & da_av]
@abstractmethod
def run(
self,
signal: pd.DataFrame,
lineage: np.ndarray,
*args,
):
pass
@classmethod
def as_function(
cls,
data: pd.DataFrame,
lineage: t.Union[t.Dict[t.Tuple[int], t.List[int]]],
*extra_data,
**kwargs,
):
"""
Overrides PostProcess.as_function classmethod.
Lineage functions require lineage information to be passed if run as function.
"""
# if isinstance(lineage, np.ndarray):
# lineage = group_matrix(lineage, n_keys=2)
parameters = cls.default_parameters(**kwargs)
return cls(parameters=parameters).run(
data, lineage=lineage, *extra_data
)
# super().as_function(data, *extra_data, lineage=lineage, **kwargs)
def load_lineage(self, lineage):
"""
Reshape the lineage information if needed
"""
self.lineage = lineage
import typing as t
from abc import abstractmethod
import numpy as np
import pandas as pd
from agora.abc import ParametersABC
from agora.utils.lineage import group_matrix
from postprocessor.core.abc import PostProcessABC
class LineageProcessParameters(ParametersABC):
"""
Parameters
"""
_defaults = {}
class LineageProcess(PostProcessABC):
"""
Lineage process that must be passed a (N,3) lineage matrix (where the coliumns are trap, mother, daughter respectively)
"""
def __init__(self, parameters: LineageProcessParameters):
super().__init__(parameters)
def filter_signal_cells(self, signal: pd.DataFrame):
"""
Use casting to filter cell ids in signal and lineage
"""
sig_ind = np.array(list(signal.index)).T[:, None, :]
mo_av = (
(self.lineage[:, :2].T[:, :, None] == sig_ind)
.all(axis=0)
.any(axis=1)
)
da_av = (
(self.lineage[:, [0, 2]].T[:, :, None] == sig_ind)
.all(axis=0)
.any(axis=1)
)
return self.lineage[mo_av & da_av]
@abstractmethod
def run(
self,
data: pd.DataFrame,
mother_bud_ids: t.Dict[t.Tuple[int], t.Collection[int]],
*args,
):
pass
@classmethod
def as_function(
cls,
data: pd.DataFrame,
lineage: t.Union[t.Dict[t.Tuple[int], t.List[int]]],
*extra_data,
**kwargs,
):
"""
Overrides PostProcess.as_function classmethod.
Lineage functions require lineage information to be passed if run as function.
"""
if isinstance(lineage, np.ndarray):
lineage = group_matrix(lineage, n_keys=2)
parameters = cls.default_parameters(**kwargs)
return cls(parameters=parameters).run(
data, mother_bud_ids=lineage, *extra_data
)
# super().as_function(data, *extra_data, lineage=lineage, **kwargs)
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
import pandas as pd import pandas as pd
from agora.utils.lineage import mb_array_to_dict from agora.utils.lineage import mb_array_to_dict
from postprocessor.core.processes.lineageprocess import ( from postprocessor.core.lineageprocess import (
LineageProcess, LineageProcess,
LineageProcessParameters, LineageProcessParameters,
) )
...@@ -74,10 +74,3 @@ class bud_metric(LineageProcess): ...@@ -74,10 +74,3 @@ class bud_metric(LineageProcess):
df = pd.DataFrame(mothers_mat, index=md.keys(), columns=signal.columns) df = pd.DataFrame(mothers_mat, index=md.keys(), columns=signal.columns)
df.index.names = signal.index.names df.index.names = signal.index.names
return df return df
def load_lineage(self, lineage):
"""
Reshape the lineage information if needed
"""
self.lineage = lineage
...@@ -6,7 +6,7 @@ from itertools import product ...@@ -6,7 +6,7 @@ from itertools import product
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from postprocessor.core.processes.lineageprocess import ( from postprocessor.core.lineageprocess import (
LineageProcess, LineageProcess,
LineageProcessParameters, LineageProcessParameters,
) )
......
...@@ -198,175 +198,9 @@ class picker(PostProcessABC): ...@@ -198,175 +198,9 @@ class picker(PostProcessABC):
"nonstoply_present": lambda s, thresh: s.apply(thresh, axis=1) "nonstoply_present": lambda s, thresh: s.apply(thresh, axis=1)
> thresh, > thresh,
"growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh, "growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh,
"mb_guess": lambda s, p1, p2: self.mb_guess_wrap(s, p1, p2)
# "quantile": [np.quantile(signals.values[signals.notna()], threshold)],
} }
return set(signals.index[case_mgr[condition](signals, *threshold)]) return set(signals.index[case_mgr[condition](signals, *threshold)])
def mb_guess(self, df, ba, trap, min_budgrowth_t, min_mobud_ratio):
"""
Parameters
----------
signals : pd.DataFrame
ba : list of cell_labels that come from bud assignment
trap : Trap id (used to fetch raw bud)
min_budgrowth_t: Minimal number of timepoints we lock reassignment after assigning bud
min_initial_size: Minimal mother-bud ratio when it was first identified
add_ba: Bool that incorporates bud_assignment data after the normal assignment
Thinking this problem as the Movie Scheduling problem (Skiena's the algorithm design manual chapter 1.2),
we will try to pick the set of filtered cells that grow the fastest and don't overlap within 5 time points
TODO adjust overlap to minutes using metadata
"""
# if trap == 21: # Use this to check specific trap problems through a debugger
# print("stop")
ntps = df.notna().sum(axis=1)
mother_id = df.index[ntps.argmax()]
nomother = df.drop(mother_id)
if not len(nomother):
return []
nomother = nomother.loc[ # Clean short-lived cells outside our mother cell's timepoints
nomother.apply(
lambda x: x.first_valid_index()
>= df.loc[mother_id].first_valid_index()
and x.first_valid_index()
<= df.loc[mother_id].last_valid_index(),
axis=1,
)
]
score = -nomother.apply( # Get slope of candidate daughters
lambda x: self.get_slope(x.dropna()), axis=1
)
start = nomother.apply(pd.Series.first_valid_index, axis=1)
# clean duplicates
duplicates = start.duplicated(False)
if duplicates.any():
score = self.get_nodup_idx(start, score, duplicates, nomother)
nomother = nomother.loc[score.index]
nomother.index = nomother.index.astype("int")
start = start.loc[score.index]
start.index = start.index.astype(int)
d_to_mother = (
nomother[start] - df.loc[mother_id, start] * min_mobud_ratio
).sort_index(axis=1)
size_filter = d_to_mother[
d_to_mother.apply(lambda x: x.dropna().iloc[0], axis=1) < 0
]
cols_sorted = (
size_filter.sort_index(axis=1)
.apply(pd.Series.first_valid_index, axis=1)
.sort_values()
)
score = score.loc[cols_sorted.index]
if not len(cols_sorted):
bud_candidates = pd.DataFrame()
else:
# Find the set with the highest number of growing cells and highest avg growth rate for this #
mivs = self.max_ind_vertex_sets(
cols_sorted.values, min_budgrowth_t
)
best_set = list(
mivs[np.argmin([sum(score.iloc[list(s)]) for s in mivs])]
)
best_indices = cols_sorted.index[best_set]
start = start.loc[best_indices]
bud_candidates = cols_sorted.loc[best_indices]
# bud_candidates = cols_sorted.loc[
# [True, *(np.diff(cols_sorted.values) > min_budgrowth_t)]
# ]
# Add random-forest bud assignment information here
new_ba_cells = []
if (
ba
): # Use the mother-daughter rf information to prioritise tracks over others
# TODO add merge application to indices and see if that recovers more cells
ba = set(ba).intersection(nomother.index)
ba_df = nomother.loc[ba, :]
start_ba = ba_df.apply(pd.Series.first_valid_index, axis=1)
new_ba_cells = list(set(start_ba.index).difference(start.index))
distances = np.subtract.outer(
start.values, start_ba.loc[new_ba_cells].values
)
todrop, _ = np.where(abs(distances) < min_budgrowth_t)
bud_candidates = bud_candidates.drop(bud_candidates.index[todrop])
return [mother_id] + bud_candidates.index.tolist() + new_ba_cells
@staticmethod
def max_ind_vertex_sets(values, min_distance):
"""
Generates an adjacency matrix from multiple points, joining neighbours closer than min_distance
Then returns the maximal independent vertex sets
values: list of int values
min_distance: int minimal distance to cluster
"""
adj = np.zeros((len(values), len(values))).astype(bool)
dist = abs(np.subtract.outer(values, values))
adj[dist <= min_distance] = True
g = ig.Graph.Adjacency(adj, mode="undirected")
miv_sets = g.maximal_independent_vertex_sets()
return miv_sets
def get_nodup_idx(self, start, score, duplicates, nomother):
"""
Return the start DataFrame without duplicates
:start: pd.Series indicating the first valid time point
:score: pd.Series containing a score to minimise
:duplicates: Dataframe containing duplicated entries
:nomother: Dataframe with non-mother cells
"""
dup_tps = np.unique(start[duplicates])
idx, tps = zip(
*[
(score.loc[nomother.loc[start == tp, tp].index].idxmin(), tp)
for tp in dup_tps
]
)
score = score[~duplicates]
score = pd.concat(
(score, pd.Series(tps, index=idx, dtype="int", name="cell_label"))
)
return score
def mb_guess_wrap(self, signals, *args):
if not len(signals):
return pd.Series([])
ids = []
mothers, buds = self.get_mothers_daughters()
mothers = np.array(mothers)
buds = np.array(buds)
ba = []
# if buds.any():
# ba_bytrap = {
# i: np.where(buds[:, 0] == i) for i in range(buds[:, 0].max() + 1)
# }
for trap in signals.index.unique(level="trap"):
# ba = list(
# set(mothers[ba_bytrap[trap], 1][0].tolist()).union(
# buds[ba_bytrap[trap], 1][0].tolist()
# )
# )
df = signals.loc[trap]
selected_ids = self.mb_guess(df, ba, trap, *args)
ids += [(trap, i) for i in selected_ids]
idx_srs = pd.Series(False, signals.index).astype(bool)
idx_srs.loc[ids] = True
return idx_srs
@staticmethod
def get_slope(x):
return np.polyfit(range(len(x)), x, 1)[0]
def _as_int(threshold: Union[float, int], ntps: int): def _as_int(threshold: Union[float, int], ntps: int):
if type(threshold) is float: if type(threshold) is float:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment