From 9f9d2ba0ecf8a1e9dd61f8b354ed9af7b521d3db Mon Sep 17 00:00:00 2001 From: Swainlab <peter.swain@ed.ac.uk> Date: Sun, 2 Jul 2023 18:08:43 +0100 Subject: [PATCH] docs for budding; re-write of budmetric --- src/agora/io/signal.py | 2 +- src/agora/utils/lineage.py | 5 - .../core/reshapers/bud_metric.py | 100 +++++++++++------- 3 files changed, 61 insertions(+), 46 deletions(-) diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 20b6d8ba..7a7f871f 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -70,7 +70,7 @@ class Signal(BridgeH5): @staticmethod def add_name(df, name): - """Add column of identical strings to a dataframe.""" + """TODO""" df.name = name return df diff --git a/src/agora/utils/lineage.py b/src/agora/utils/lineage.py index 52fb552b..05c161a5 100644 --- a/src/agora/utils/lineage.py +++ b/src/agora/utils/lineage.py @@ -1,12 +1,8 @@ #!/usr/bin/env python3 -import re -import typing as t import numpy as np -import pandas as pd from agora.io.bridge import groupsort -from itertools import groupby def mb_array_to_dict(mb_array: np.ndarray): @@ -19,4 +15,3 @@ def mb_array_to_dict(mb_array: np.ndarray): for trap, mo_da in groupsort(mb_array).items() for mo, daughters in groupsort(mo_da).items() } - diff --git a/src/postprocessor/core/reshapers/bud_metric.py b/src/postprocessor/core/reshapers/bud_metric.py index b8952288..ee239221 100644 --- a/src/postprocessor/core/reshapers/bud_metric.py +++ b/src/postprocessor/core/reshapers/bud_metric.py @@ -20,8 +20,9 @@ class BudMetricParameters(LineageProcessParameters): class BudMetric(LineageProcess): """ - Requires mother-bud information to create a new dataframe where the indices are mother ids and - values are the daughters' values for a given signal. + Requires mother-bud information to create a new dataframe where the + indices are mother ids and values are the daughters' values for a + given signal. """ def __init__(self, parameters: BudMetricParameters): @@ -38,7 +39,6 @@ class BudMetric(LineageProcess): else: assert "mother_label" in signal.index.names lineage = signal.index.to_list() - return self.get_bud_metric(signal, mb_array_to_dict(lineage)) @staticmethod @@ -48,7 +48,9 @@ class BudMetric(LineageProcess): """ signal: Daughter-inclusive dataframe - md: Mother-daughters dictionary where key is mother's index and its values are a list of daughter indices + md: dictionary where key is mother's index, + defined as (trap, cell_label), and its values are a list of + daughter indices, as (trap, cell_label). Get fvi (First Valid Index) for all cells Create empty matrix @@ -61,63 +63,81 @@ class BudMetric(LineageProcess): Convert matrix into dataframe using mother indices """ - mothers_mat = np.zeros((len(md), signal.shape[1])) - cells_were_dropped = 0 # Flag determines if mothers (1), daughters (2) or both were missing (3) - md_index = signal.index - if ( - "mother_label" not in md_index.names - ): # Generate mother label from md dict if unavailable + # md_index should only comprise (trap, cell_label) + if "mother_label" not in md_index.names: + # dict with daughter indices as keys d = {v: k for k, values in md.items() for v in values} + # generate mother_label in signal using the mother's cell_label signal["mother_label"] = list( map(lambda x: d.get(x, [0])[-1], signal.index) ) signal.set_index("mother_label", append=True, inplace=True) - related_items = set( - [*md.keys(), *[y for x in md.values() for y in x]] - ) - md_index = md_index.intersection(related_items) - elif "mother_label" in md_index.names: - md_index = md_index.droplevel("mother_label") + # combine mothers and daughter indices + mothers_index = md.keys() + daughters_index = [y for x in md.values() for y in x] + relations = set([*mothers_index, *daughters_index]) + # keep from md_index only mother and daughters + md_index = md_index.intersection(relations) else: - raise ("Unavailable relationship information") - + md_index = md_index.droplevel("mother_label") if len(md_index) < len(signal): - print("Dropped cells before bud_metric") # TODO log - + print("Dropped cells before applying bud_metric") # TODO log + # restrict signal to the cells in md_index, moving mother_label to do so signal = ( signal.reset_index("mother_label") .loc(axis=0)[md_index] .set_index("mother_label", append=True) ) - - names = list(signal.index.names) - del names[-2] - - output_df = ( - signal.loc[signal.index.get_level_values("mother_label") > 0] - .groupby(names) - .apply(lambda x: _combine_daughter_tracks(x)) + # restrict to daughters: cells with a mother + mother_labels = signal.index.get_level_values("mother_label") + daughter_df = signal.loc[mother_labels > 0] + # join data for daughters with the same mother + output_df = daughter_df.groupby(["trap", "mother_label"]).apply( + lambda x: _combine_daughter_tracks(x) ) + output_df.columns = signal.columns output_df["padding_level"] = 0 output_df.set_index("padding_level", append=True, inplace=True) - if len(output_df): output_df.index.names = signal.index.names return output_df -def _combine_daughter_tracks(tracks: t.Collection[pd.Series]): +def _combine_daughter_tracks(tracks: pd.DataFrame): """ - Combine multiple time series of cells into one, overwriting values - prioritising the most recent entity. + Combine multiple time series of daughter cells into one time series. + + At any one time, a mother cell should have only one daughter. + + Two daughters are still sometimes present at the same time point, and we + then choose the daughter that appears first. + + TODO We need to fix examples with more than one daughter at a time point. + + Parameters + ---------- + tracks: a Signal + Data for all daughters, which are distinguished by different cell_labels, + for a particular trap and mother_label. """ - sorted_da_ids = tracks.sort_index(level="cell_label") - sorted_da_ids.index = range(len(sorted_da_ids)) - tp_fvt = sorted_da_ids.apply(lambda x: x.first_valid_index(), axis=0) - tp_fvt = sorted_da_ids.columns.get_indexer(tp_fvt) - tp_fvt[tp_fvt < 0] = len(sorted_da_ids) - 1 - - _metric = np.choose(tp_fvt, sorted_da_ids.values) - return pd.Series(_metric, index=tracks.columns) + # sort by daughter IDs + bud_df = tracks.sort_index(level="cell_label") + # remove multi-index + bud_df.index = range(len(bud_df)) + # find which row of sorted_df has the daughter for each time point + tp_fvt: pd.Series = bud_df.apply(lambda x: x.first_valid_index(), axis=0) + # combine data for all daughters + combined_tracks = np.nan * np.ones(tracks.columns.size) + for bud_row in np.unique(tp_fvt.dropna().values).astype(int): + ilocs = np.where(tp_fvt.values == bud_row)[0] + combined_tracks[ilocs] = bud_df.values[bud_row, ilocs] + # TODO delete old version + tp_fvt = bud_df.columns.get_indexer(tp_fvt) + tp_fvt[tp_fvt == -1] = len(bud_df) - 1 + old = np.choose(tp_fvt, bud_df.values) + assert ( + (combined_tracks == old) | (np.isnan(combined_tracks) & np.isnan(old)) + ).all(), "yikes" + return pd.Series(combined_tracks, index=tracks.columns) -- GitLab