diff --git a/src/postprocessor/core/lineageprocess.py b/src/postprocessor/core/lineageprocess.py index e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe..296effce76612577ff97c8783a2d615389b7fc31 100644 --- a/src/postprocessor/core/lineageprocess.py +++ b/src/postprocessor/core/lineageprocess.py @@ -1 +1,85 @@ -#!/usr/bin/env python3 +import typing as t +from abc import abstractmethod + +import numpy as np +import pandas as pd + +from agora.abc import ParametersABC +from postprocessor.core.abc import PostProcessABC + +# from agora.utils.lineage import group_matrix + + +class LineageProcessParameters(ParametersABC): + """ + Parameters + """ + + _defaults = {} + + +class LineageProcess(PostProcessABC): + """ + Lineage process that must be passed a (N,3) lineage matrix (where the coliumns are trap, mother, daughter respectively) + """ + + def __init__(self, parameters: LineageProcessParameters): + super().__init__(parameters) + + def filter_signal_cells( + self, signal: pd.DataFrame, lineage: np.ndarray = None + ): + """ + Use casting to filter cell ids in signal and lineage + """ + if lineage is None: + lineage = self.lineage + + sig_ind = np.array(list(signal.index)).T[:, None, :] + mo_av = ( + (lineage[:, :2].T[:, :, None] == sig_ind).all(axis=0).any(axis=1) + ) + da_av = ( + (lineage[:, [0, 2]].T[:, :, None] == sig_ind) + .all(axis=0) + .any(axis=1) + ) + + return lineage[mo_av & da_av] + + @abstractmethod + def run( + self, + signal: pd.DataFrame, + lineage: np.ndarray, + *args, + ): + pass + + @classmethod + def as_function( + cls, + data: pd.DataFrame, + lineage: t.Union[t.Dict[t.Tuple[int], t.List[int]]], + *extra_data, + **kwargs, + ): + """ + Overrides PostProcess.as_function classmethod. + Lineage functions require lineage information to be passed if run as function. + """ + # if isinstance(lineage, np.ndarray): + # lineage = group_matrix(lineage, n_keys=2) + + parameters = cls.default_parameters(**kwargs) + return cls(parameters=parameters).run( + data, lineage=lineage, *extra_data + ) + # super().as_function(data, *extra_data, lineage=lineage, **kwargs) + + def load_lineage(self, lineage): + """ + Reshape the lineage information if needed + """ + + self.lineage = lineage diff --git a/src/postprocessor/core/processes/lineageprocess.py b/src/postprocessor/core/processes/lineageprocess.py deleted file mode 100644 index 38565e18240b8a2b1d8f6f937e7be7a6ea0beb14..0000000000000000000000000000000000000000 --- a/src/postprocessor/core/processes/lineageprocess.py +++ /dev/null @@ -1,75 +0,0 @@ -import typing as t -from abc import abstractmethod - -import numpy as np -import pandas as pd - -from agora.abc import ParametersABC -from agora.utils.lineage import group_matrix -from postprocessor.core.abc import PostProcessABC - - -class LineageProcessParameters(ParametersABC): - """ - Parameters - """ - - _defaults = {} - - -class LineageProcess(PostProcessABC): - """ - Lineage process that must be passed a (N,3) lineage matrix (where the coliumns are trap, mother, daughter respectively) - """ - - def __init__(self, parameters: LineageProcessParameters): - super().__init__(parameters) - - def filter_signal_cells(self, signal: pd.DataFrame): - """ - Use casting to filter cell ids in signal and lineage - """ - - sig_ind = np.array(list(signal.index)).T[:, None, :] - mo_av = ( - (self.lineage[:, :2].T[:, :, None] == sig_ind) - .all(axis=0) - .any(axis=1) - ) - da_av = ( - (self.lineage[:, [0, 2]].T[:, :, None] == sig_ind) - .all(axis=0) - .any(axis=1) - ) - - return self.lineage[mo_av & da_av] - - @abstractmethod - def run( - self, - data: pd.DataFrame, - mother_bud_ids: t.Dict[t.Tuple[int], t.Collection[int]], - *args, - ): - pass - - @classmethod - def as_function( - cls, - data: pd.DataFrame, - lineage: t.Union[t.Dict[t.Tuple[int], t.List[int]]], - *extra_data, - **kwargs, - ): - """ - Overrides PostProcess.as_function classmethod. - Lineage functions require lineage information to be passed if run as function. - """ - if isinstance(lineage, np.ndarray): - lineage = group_matrix(lineage, n_keys=2) - - parameters = cls.default_parameters(**kwargs) - return cls(parameters=parameters).run( - data, mother_bud_ids=lineage, *extra_data - ) - # super().as_function(data, *extra_data, lineage=lineage, **kwargs) diff --git a/src/postprocessor/core/multisignal/aggregate.py b/src/postprocessor/core/reshapers/aggregate.py similarity index 100% rename from src/postprocessor/core/multisignal/aggregate.py rename to src/postprocessor/core/reshapers/aggregate.py diff --git a/src/postprocessor/core/processes/bud_metric.py b/src/postprocessor/core/reshapers/bud_metric.py similarity index 92% rename from src/postprocessor/core/processes/bud_metric.py rename to src/postprocessor/core/reshapers/bud_metric.py index f4d42ab275e8075b82db31a55a502fb9ee7ceefb..be4a978c1c70aafa84c5157de49d447a6c587571 100644 --- a/src/postprocessor/core/processes/bud_metric.py +++ b/src/postprocessor/core/reshapers/bud_metric.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd from agora.utils.lineage import mb_array_to_dict -from postprocessor.core.processes.lineageprocess import ( +from postprocessor.core.lineageprocess import ( LineageProcess, LineageProcessParameters, ) @@ -74,10 +74,3 @@ class bud_metric(LineageProcess): df = pd.DataFrame(mothers_mat, index=md.keys(), columns=signal.columns) df.index.names = signal.index.names return df - - def load_lineage(self, lineage): - """ - Reshape the lineage information if needed - """ - - self.lineage = lineage diff --git a/src/postprocessor/core/processes/buddings.py b/src/postprocessor/core/reshapers/buddings.py similarity index 97% rename from src/postprocessor/core/processes/buddings.py rename to src/postprocessor/core/reshapers/buddings.py index 839f0fdc0c2e04863ae33ac8b522120ac9256023..0b01dad70c5eab067f26b430e5451fdf70bcb4e2 100644 --- a/src/postprocessor/core/processes/buddings.py +++ b/src/postprocessor/core/reshapers/buddings.py @@ -6,7 +6,7 @@ from itertools import product import numpy as np import pandas as pd -from postprocessor.core.processes.lineageprocess import ( +from postprocessor.core.lineageprocess import ( LineageProcess, LineageProcessParameters, ) diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index cf06dd3643c415e1a16e5cb59fee36a6d2f25f91..2fc1ece75e5d078dbe05a309bf3a8e30aed06f07 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -198,175 +198,9 @@ class picker(PostProcessABC): "nonstoply_present": lambda s, thresh: s.apply(thresh, axis=1) > thresh, "growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh, - "mb_guess": lambda s, p1, p2: self.mb_guess_wrap(s, p1, p2) - # "quantile": [np.quantile(signals.values[signals.notna()], threshold)], } return set(signals.index[case_mgr[condition](signals, *threshold)]) - def mb_guess(self, df, ba, trap, min_budgrowth_t, min_mobud_ratio): - """ - Parameters - ---------- - signals : pd.DataFrame - ba : list of cell_labels that come from bud assignment - trap : Trap id (used to fetch raw bud) - min_budgrowth_t: Minimal number of timepoints we lock reassignment after assigning bud - min_initial_size: Minimal mother-bud ratio when it was first identified - add_ba: Bool that incorporates bud_assignment data after the normal assignment - - Thinking this problem as the Movie Scheduling problem (Skiena's the algorithm design manual chapter 1.2), - we will try to pick the set of filtered cells that grow the fastest and don't overlap within 5 time points - TODO adjust overlap to minutes using metadata - """ - - # if trap == 21: # Use this to check specific trap problems through a debugger - # print("stop") - ntps = df.notna().sum(axis=1) - mother_id = df.index[ntps.argmax()] - nomother = df.drop(mother_id) - if not len(nomother): - return [] - nomother = nomother.loc[ # Clean short-lived cells outside our mother cell's timepoints - nomother.apply( - lambda x: x.first_valid_index() - >= df.loc[mother_id].first_valid_index() - and x.first_valid_index() - <= df.loc[mother_id].last_valid_index(), - axis=1, - ) - ] - - score = -nomother.apply( # Get slope of candidate daughters - lambda x: self.get_slope(x.dropna()), axis=1 - ) - start = nomother.apply(pd.Series.first_valid_index, axis=1) - - # clean duplicates - duplicates = start.duplicated(False) - if duplicates.any(): - score = self.get_nodup_idx(start, score, duplicates, nomother) - nomother = nomother.loc[score.index] - nomother.index = nomother.index.astype("int") - start = start.loc[score.index] - start.index = start.index.astype(int) - - d_to_mother = ( - nomother[start] - df.loc[mother_id, start] * min_mobud_ratio - ).sort_index(axis=1) - size_filter = d_to_mother[ - d_to_mother.apply(lambda x: x.dropna().iloc[0], axis=1) < 0 - ] - cols_sorted = ( - size_filter.sort_index(axis=1) - .apply(pd.Series.first_valid_index, axis=1) - .sort_values() - ) - score = score.loc[cols_sorted.index] - if not len(cols_sorted): - bud_candidates = pd.DataFrame() - else: - # Find the set with the highest number of growing cells and highest avg growth rate for this # - mivs = self.max_ind_vertex_sets( - cols_sorted.values, min_budgrowth_t - ) - best_set = list( - mivs[np.argmin([sum(score.iloc[list(s)]) for s in mivs])] - ) - best_indices = cols_sorted.index[best_set] - - start = start.loc[best_indices] - bud_candidates = cols_sorted.loc[best_indices] - # bud_candidates = cols_sorted.loc[ - # [True, *(np.diff(cols_sorted.values) > min_budgrowth_t)] - # ] - - # Add random-forest bud assignment information here - new_ba_cells = [] - if ( - ba - ): # Use the mother-daughter rf information to prioritise tracks over others - # TODO add merge application to indices and see if that recovers more cells - ba = set(ba).intersection(nomother.index) - ba_df = nomother.loc[ba, :] - start_ba = ba_df.apply(pd.Series.first_valid_index, axis=1) - new_ba_cells = list(set(start_ba.index).difference(start.index)) - - distances = np.subtract.outer( - start.values, start_ba.loc[new_ba_cells].values - ) - todrop, _ = np.where(abs(distances) < min_budgrowth_t) - bud_candidates = bud_candidates.drop(bud_candidates.index[todrop]) - - return [mother_id] + bud_candidates.index.tolist() + new_ba_cells - - @staticmethod - def max_ind_vertex_sets(values, min_distance): - """ - Generates an adjacency matrix from multiple points, joining neighbours closer than min_distance - Then returns the maximal independent vertex sets - values: list of int values - min_distance: int minimal distance to cluster - """ - adj = np.zeros((len(values), len(values))).astype(bool) - dist = abs(np.subtract.outer(values, values)) - adj[dist <= min_distance] = True - - g = ig.Graph.Adjacency(adj, mode="undirected") - miv_sets = g.maximal_independent_vertex_sets() - return miv_sets - - def get_nodup_idx(self, start, score, duplicates, nomother): - """ - Return the start DataFrame without duplicates - - :start: pd.Series indicating the first valid time point - :score: pd.Series containing a score to minimise - :duplicates: Dataframe containing duplicated entries - :nomother: Dataframe with non-mother cells - """ - dup_tps = np.unique(start[duplicates]) - idx, tps = zip( - *[ - (score.loc[nomother.loc[start == tp, tp].index].idxmin(), tp) - for tp in dup_tps - ] - ) - score = score[~duplicates] - score = pd.concat( - (score, pd.Series(tps, index=idx, dtype="int", name="cell_label")) - ) - return score - - def mb_guess_wrap(self, signals, *args): - if not len(signals): - return pd.Series([]) - ids = [] - mothers, buds = self.get_mothers_daughters() - mothers = np.array(mothers) - buds = np.array(buds) - ba = [] - # if buds.any(): - # ba_bytrap = { - # i: np.where(buds[:, 0] == i) for i in range(buds[:, 0].max() + 1) - # } - for trap in signals.index.unique(level="trap"): - # ba = list( - # set(mothers[ba_bytrap[trap], 1][0].tolist()).union( - # buds[ba_bytrap[trap], 1][0].tolist() - # ) - # ) - df = signals.loc[trap] - selected_ids = self.mb_guess(df, ba, trap, *args) - ids += [(trap, i) for i in selected_ids] - - idx_srs = pd.Series(False, signals.index).astype(bool) - idx_srs.loc[ids] = True - return idx_srs - - @staticmethod - def get_slope(x): - return np.polyfit(range(len(x)), x, 1)[0] - def _as_int(threshold: Union[float, int], ntps: int): if type(threshold) is float: