From 00fbb8fcaf578686f4a42fa48b4e1c029a6395f0 Mon Sep 17 00:00:00 2001
From: Swainlab <peter.swain@ed.ac.uk>
Date: Wed, 12 Jul 2023 18:27:30 +0100
Subject: [PATCH] mostly docs; removed decorator in signal

---
 src/agora/io/decorators.py                    |   6 +-
 src/agora/io/signal.py                        |  40 ++---
 src/agora/utils/kymograph.py                  |   9 +-
 src/postprocessor/core/processor.py           |  23 ++-
 .../core/reshapers/bud_metric.py              | 138 ++++++++----------
 src/postprocessor/core/reshapers/buddings.py  |   9 +-
 src/postprocessor/core/reshapers/picker.py    |  28 ++--
 src/postprocessor/grouper.py                  |  18 +--
 8 files changed, 130 insertions(+), 141 deletions(-)

diff --git a/src/agora/io/decorators.py b/src/agora/io/decorators.py
index f4d8d023..15f08578 100644
--- a/src/agora/io/decorators.py
+++ b/src/agora/io/decorators.py
@@ -6,17 +6,19 @@ import typing as t
 from functools import wraps
 
 
-def _first_arg_str_to_df(
+def _first_arg_str_to_raw_df(
     fn: t.Callable,
 ):
     """Enable Signal-like classes to convert strings to data sets."""
+
     @wraps(fn)
     def format_input(*args, **kwargs):
         cls = args[0]
         data = args[1]
         if isinstance(data, str):
-            # get data from h5 file
+            # get data from h5 file using Signal's get_raw
             data = cls.get_raw(data)
         # replace path in the undecorated function with data
         return fn(cls, data, *args[2:], **kwargs)
+
     return format_input
diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index 7a7f871f..8a3957c3 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -10,7 +10,7 @@ import numpy as np
 import pandas as pd
 
 from agora.io.bridge import BridgeH5
-from agora.io.decorators import _first_arg_str_to_df
+from agora.io.decorators import _first_arg_str_to_raw_df
 from agora.utils.indexing import validate_association
 from agora.utils.kymograph import add_index_levels
 from agora.utils.merge import apply_merges
@@ -26,7 +26,8 @@ class Signal(BridgeH5):
     """
 
     def __init__(self, file: t.Union[str, Path]):
-        """Define index_names for dataframes, candidate fluorescence channels, and composite statistics."""
+        """Define index_names for dataframes, candidate fluorescence channels,
+        and composite statistics."""
         super().__init__(file, flag=None)
         self.index_names = (
             "experiment",
@@ -48,9 +49,9 @@ class Signal(BridgeH5):
 
     def __getitem__(self, dsets: t.Union[str, t.Collection]):
         """Get and potentially pre-process data from h5 file and return as a dataframe."""
-        if isinstance(dsets, str):  # no pre-processing
+        if isinstance(dsets, str):
             return self.get(dsets)
-        elif isinstance(dsets, list):  # pre-processing
+        elif isinstance(dsets, list):
             is_bgd = [dset.endswith("imBackground") for dset in dsets]
             # Check we are not comparing tile-indexed and cell-indexed data
             assert sum(is_bgd) == 0 or sum(is_bgd) == len(
@@ -60,17 +61,18 @@ class Signal(BridgeH5):
         else:
             raise Exception(f"Invalid type {type(dsets)} to get datasets")
 
-    def get(self, dsets: t.Union[str, t.Collection], **kwargs):
-        """Get and potentially pre-process data from h5 file and return as a dataframe."""
-        if isinstance(dsets, str):
-            # no pre-processing
-            dsets = self.get_raw(dsets, **kwargs)
+    def get(self, dset_name: t.Union[str, t.Collection], **kwargs):
+        """Return pre-processed data as a dataframe."""
+        if isinstance(dset_name, str):
+            dsets = self.get_raw(dset_name, **kwargs)
             prepost_applied = self.apply_prepost(dsets, **kwargs)
-            return self.add_name(prepost_applied, dsets)
+            return self.add_name(prepost_applied, dset_name)
+        else:
+            raise Exception("Error in Signal.get")
 
     @staticmethod
     def add_name(df, name):
-        """TODO"""
+        """Add name of the Signal as an attribute to its corresponding dataframe."""
         df.name = name
         return df
 
@@ -103,7 +105,8 @@ class Signal(BridgeH5):
 
     @staticmethod
     def get_retained(df, cutoff):
-        """Return a fraction of the df, one without later time points."""
+        """Return rows of df with at least cutoff fraction of the total number
+        of time points."""
         return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]
 
     @property
@@ -112,15 +115,15 @@ class Signal(BridgeH5):
         with h5py.File(self.filename, "r") as f:
             return list(f.attrs["channels"])
 
-    @_first_arg_str_to_df
     def retained(self, signal, cutoff=0.8):
         """
         Load data (via decorator) and reduce the resulting dataframe.
 
         Load data for a signal or a list of signals and reduce the resulting
-        dataframes to a fraction of their original size, losing late time
-        points.
+        dataframes to rows with sufficient numbers of time points.
         """
+        if isinstance(signal, str):
+            signal = self.get_raw(signal)
         if isinstance(signal, pd.DataFrame):
             return self.get_retained(signal, cutoff)
         elif isinstance(signal, list):
@@ -133,13 +136,12 @@ class Signal(BridgeH5):
         """
         Get lineage data from a given location in the h5 file.
 
-        Returns an array with three columns: the tile id, the mother label, and the daughter label.
+        Returns an array with three columns: the tile id, the mother label,
+        and the daughter label.
         """
         if lineage_location is None:
             lineage_location = "modifiers/lineage_merged"
         with h5py.File(self.filename, "r") as f:
-            # if lineage_location not in f:
-            #     lineage_location = lineage_location.split("_")[0]
             if lineage_location not in f:
                 lineage_location = "postprocessing/lineage"
             tile_mo_da = f[lineage_location]
@@ -155,7 +157,7 @@ class Signal(BridgeH5):
                 ).T
         return lineage
 
-    @_first_arg_str_to_df
+    @_first_arg_str_to_raw_df
     def apply_prepost(
         self,
         data: t.Union[str, pd.DataFrame],
diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py
index 62a5a962..46df84ec 100644
--- a/src/agora/utils/kymograph.py
+++ b/src/agora/utils/kymograph.py
@@ -86,16 +86,19 @@ def bidirectional_retainment_filter(
     daughters_thresh: int = 7,
 ) -> pd.DataFrame:
     """
-    Retrieve families where mothers are present for more than a fraction of the experiment, and daughters for longer than some number of time-points.
+    Retrieve families where mothers are present for more than a fraction
+    of the experiment and daughters for longer than some number of
+    time-points.
 
     Parameters
     ----------
     df: pd.DataFrame
         Data
     mothers_thresh: float
-        Minimum fraction of experiment's total duration for which mothers must be present.
+        Minimum fraction of experiment's total duration for which mothers
+        must be present.
     daughters_thresh: int
-        Minimum number of time points for which daughters must be observed
+        Minimum number of time points for which daughters must be observed.
     """
     # daughters
     all_daughters = df.loc[df.index.get_level_values("mother_label") > 0]
diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index ec9bcd04..8865cbdf 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -72,14 +72,14 @@ class PostProcessorParameters(ParametersABC):
             },
             "processes": [
                 ["buddings", ["/extraction/general/None/volume"]],
-                ["dsignal", ["/extraction/general/None/volume"]],
+                # ["dsignal", ["/extraction/general/None/volume"]],
                 ["bud_metric", ["/extraction/general/None/volume"]],
-                [
-                    "dsignal",
-                    [
-                        "/postprocessing/bud_metric/extraction_general_None_volume"
-                    ],
-                ],
+                # [
+                #     "dsignal",
+                #     [
+                #         "/postprocessing/bud_metric/extraction_general_None_volume"
+                #     ],
+                # ],
             ],
         }
         param_sets = {
@@ -209,9 +209,8 @@ class PostProcessor(ProcessABC):
         """
         # run merger, picker, and find lineages
         self.run_prepost()
-        # run processes
+        # run processes: process is a str; datasets is a list of str
         for process, datasets in tqdm(self.targets["processes"]):
-            # process is a str; datasets is a list of str
             if process in self.parameters["param_sets"].get("processes", {}):
                 # parameters already assigned
                 parameters = self.parameters_classfun[process](
@@ -230,9 +229,8 @@ class PostProcessor(ProcessABC):
 
     def run_process(self, dataset, process, loaded_process):
         """Run process to obtain a single dataset and write the result."""
-        # define signal
+        # get pre-processed data
         if isinstance(dataset, list):
-            # multisignal process
             signal = [self._signal[d] for d in dataset]
         elif isinstance(dataset, str):
             signal = self._signal[dataset]
@@ -249,8 +247,9 @@ class PostProcessor(ProcessABC):
                 [], columns=signal.columns, index=signal.index
             )
             result.columns.names = ["timepoint"]
-        # define outpath to write result
+        # use outpath to write result
         if process in self.parameters["outpaths"]:
+            # outpath already defined
             outpath = self.parameters["outpaths"][process]
         elif isinstance(dataset, list):
             # no outpath is defined
diff --git a/src/postprocessor/core/reshapers/bud_metric.py b/src/postprocessor/core/reshapers/bud_metric.py
index 9e9e1f70..b5c9ab2c 100644
--- a/src/postprocessor/core/reshapers/bud_metric.py
+++ b/src/postprocessor/core/reshapers/bud_metric.py
@@ -11,9 +11,7 @@ from postprocessor.core.lineageprocess import (
 
 
 class BudMetricParameters(LineageProcessParameters):
-    """
-    Parameters
-    """
+    """Give default location of lineage information."""
 
     _defaults = {"lineage_location": "postprocessing/lineage_merged"}
 
@@ -34,9 +32,11 @@ class BudMetric(LineageProcess):
         lineage: t.Dict[pd.Index, t.Tuple[pd.Index]] = None,
     ):
         if lineage is None:
+            # define lineage
             if hasattr(self, "lineage"):
                 lineage = self.lineage
             else:
+                # lineage information in the Signal dataframe
                 assert "mother_label" in signal.index.names
                 lineage = signal.index.to_list()
         return self.get_bud_metric(signal, mb_array_to_dict(lineage))
@@ -46,44 +46,40 @@ class BudMetric(LineageProcess):
         signal: pd.DataFrame, md: t.Dict[t.Tuple, t.Tuple[t.Tuple]] = None
     ):
         """
-
-        signal: Daughter-inclusive dataframe
-        md: dictionary where key is mother's index,
-        defined as (trap, cell_label), and its values are a list of
-        daughter indices, as (trap, cell_label).
-
-        Get fvi (First Valid Index) for all cells
-        Create empty matrix
-        for every mother:
-         - Get daughters' subdataframe
-         - sort  daughters by cell label
-         - get series of fvis
-         - concatenate the values of these ranges from the dataframe
-        Fill the empty matrix
-        Convert matrix into dataframe using mother indices
-
+        Generate a dataframe of a Signal for buds indexed by their mothers,
+        concatenating data from all the buds for each mother.
+
+        Parameters
+        ---------
+        signal: pd.Dataframe
+            A dataframe that includes data for both mothers and daughters.
+        md: dict
+            A dict of lineage information with each key a mother's index,
+            defined as (trap, cell_label), and the corresponding values are a
+            list of daughter indices, also defined as (trap, cell_label).
         """
         md_index = signal.index
         # md_index should only comprise (trap, cell_label)
         if "mother_label" not in md_index.names:
-            # dict with daughter indices as keys
-            d = {v: k for k, values in md.items() for v in values}
-            # generate mother_label in signal using the mother's cell_label
+            # dict with daughter indices as keys and mother indices as values
+            bud_dict = {v: k for k, values in md.items() for v in values}
+            # generate mother_label in Signal using the mother's cell_label
+            # cells with no mothers have a mother_label of 0
             signal["mother_label"] = list(
-                map(lambda x: d.get(x, [0])[-1], signal.index)
+                map(lambda x: bud_dict.get(x, [0])[-1], signal.index)
             )
             signal.set_index("mother_label", append=True, inplace=True)
             # combine mothers and daughter indices
             mothers_index = md.keys()
             daughters_index = [y for x in md.values() for y in x]
             relations = set([*mothers_index, *daughters_index])
-            # keep from md_index only mother and daughters
+            # keep from md_index only cells that are mother or daughters
             md_index = md_index.intersection(relations)
         else:
             md_index = md_index.droplevel("mother_label")
         if len(md_index) < len(signal):
             print("Dropped cells before applying bud_metric")  # TODO log
-        # restrict signal to the cells in md_index, moving mother_label to do so
+        # restrict signal to the cells in md_index moving mother_label to do so
         signal = (
             signal.reset_index("mother_label")
             .loc(axis=0)[md_index]
@@ -96,16 +92,52 @@ class BudMetric(LineageProcess):
         output_df = daughter_df.groupby(["trap", "mother_label"]).apply(
             lambda x: _combine_daughter_tracks(x)
         )
-
         output_df.columns = signal.columns
-        output_df["padding_level"] = 0
-        output_df.set_index("padding_level", append=True, inplace=True)
+        # daughter data is indexed by mothers, which themselves have no mothers
+        output_df["temp_mother_label"] = 0
+        output_df.set_index("temp_mother_label", append=True, inplace=True)
         if len(output_df):
             output_df.index.names = signal.index.names
         return output_df
 
 
-def _combine_daughter_tracks_old(tracks: pd.DataFrame):
+def _combine_daughter_tracks(tracks: pd.DataFrame):
+    """
+    Combine multiple time series of daughter cells into one time series.
+
+    Concatenate daughter values into one time series starting with the first
+    daughter and replacing later values with the values from the next daughter,
+    and so on.
+
+    Parameters
+    ----------
+    tracks: a Signal
+        Data for all daughters, which are distinguished by different cell_labels,
+        for a particular trap and mother_label.
+    """
+    # sort by daughter IDs
+    bud_df = tracks.sort_index(level="cell_label")
+    # remove multi-index
+    no_rows = len(bud_df)
+    bud_df.index = range(no_rows)
+    # find time point of first non-NaN data point of each row
+    init_tps = [
+        bud_df.iloc[irow].first_valid_index() for irow in range(no_rows)
+    ]
+    # sort so that earliest daughter is first
+    sorted_rows = np.argsort(init_tps)
+    init_tps = np.sort(init_tps)
+    # combine data for all daughters
+    combined_tracks = np.nan * np.ones(tracks.columns.size)
+    for j, jrow in enumerate(sorted_rows):
+        # over-write with next earliest daughter
+        combined_tracks[bud_df.columns.get_loc(init_tps[j]) :] = (
+            bud_df.iloc[jrow].loc[init_tps[j] :].values
+        )
+    return pd.Series(combined_tracks, index=tracks.columns)
+
+
+def _combine_daughter_tracks_original(tracks: pd.DataFrame):
     """
     Combine multiple time series of daughter cells into one time series.
 
@@ -141,51 +173,3 @@ def _combine_daughter_tracks_old(tracks: pd.DataFrame):
         (combined_tracks == old) | (np.isnan(combined_tracks) & np.isnan(old))
     ).all(), "yikes"
     return pd.Series(combined_tracks, index=tracks.columns)
-
-
-def _combine_daughter_tracks(tracks: pd.DataFrame):
-    """
-    Combine multiple time series of daughter cells into one time series.
-
-    Concatenate daughter values into one time series starting with the first
-    daughter and replacing later values with the values from the next daughter,
-    and so on.
-
-    Parameters
-    ----------
-    tracks: a Signal
-        Data for all daughters, which are distinguished by different cell_labels,
-        for a particular trap and mother_label.
-    """
-    # sort by daughter IDs
-    bud_df = tracks.sort_index(level="cell_label")
-    # remove multi-index
-    no_rows = len(bud_df)
-    bud_df.index = range(no_rows)
-    # find time point of first non-NaN data point of each row
-    init_tps = [
-        bud_df.iloc[irow].first_valid_index() for irow in range(no_rows)
-    ]
-    # sort so that earliest daughter is first
-    sorted_rows = np.argsort(init_tps)
-    init_tps = np.sort(init_tps)
-    # combine data for all daughters
-    combined_tracks = np.nan * np.ones(tracks.columns.size)
-    for j, jrow in enumerate(sorted_rows):
-        # over-write with next earliest daughter
-        combined_tracks[bud_df.columns.get_loc(init_tps[j]) :] = (
-            bud_df.iloc[jrow].loc[init_tps[j] :].values
-        )
-    # ## OLD
-    # # find which row of sorted_df has the daughter for each time point
-    # tp_fvt: pd.Series = bud_df.apply(lambda x: x.first_valid_index(), axis=0)
-    # # combine data for all daughters
-    # old = np.nan * np.ones(tracks.columns.size)
-    # for bud_row in np.unique(tp_fvt.dropna().values).astype(int):
-    #     ilocs = np.where(tp_fvt.values == bud_row)[0]
-    #     old[ilocs] = bud_df.values[bud_row, ilocs]
-    # assert (
-    #     (combined_tracks == old) | (np.isnan(combined_tracks) & np.isnan(old))
-    # ).all(), "yikes"
-    # ###
-    return pd.Series(combined_tracks, index=tracks.columns)
diff --git a/src/postprocessor/core/reshapers/buddings.py b/src/postprocessor/core/reshapers/buddings.py
index 90785bce..9c17428f 100644
--- a/src/postprocessor/core/reshapers/buddings.py
+++ b/src/postprocessor/core/reshapers/buddings.py
@@ -51,17 +51,16 @@ class buddings(LineageProcess):
             for trap_mo in lineage[:, :2]
             if tuple(trap_mo) in signal.index
         }
-        # find daughters for these traps and mothers
+        # add daughters, potentially multiple, for these traps and mothers
         for trap, mother, daughter in lineage:
             if (trap, mother) in traps_mothers.keys():
                 traps_mothers[(trap, mother)].append(daughter)
-        # sub dataframe of signal for the selected mothers
+        # a new dataframe with dimensions (n_mother_cells * n_tps)
         mothers = signal.loc[
             set(signal.index).intersection(traps_mothers.keys())
         ]
-        # a new dataframe with dimensions (n_mother_cells * n_tps)
         buddings = pd.DataFrame(
-            np.zeros((mothers.shape[0], signal.shape[1])).astype(bool),
+            np.zeros(mothers.shape).astype(bool),
             index=mothers.index,
             columns=signal.columns,
         )
@@ -76,7 +75,7 @@ class buddings(LineageProcess):
             times_of_bud_appearance = fvi.loc[
                 fvi.index.intersection(trap_daughter_ids)
             ].values
-            # ignore zeros - ignore buds in first image
+            # ignore zeros - buds in first image are not budding events
             daughters_idx = set(times_of_bud_appearance).difference({0})
             buddings.loc[trap_mother_id, daughters_idx] = True
         return buddings
diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py
index f3c35852..74f3a4b0 100644
--- a/src/postprocessor/core/reshapers/picker.py
+++ b/src/postprocessor/core/reshapers/picker.py
@@ -15,9 +15,13 @@ class PickerParameters(ParametersABC):
     """
     A dictionary specifying the sequence of picks in order.
 
-    "lineage" is further specified by "mothers", "daughters", "families" (mother-bud pairs), and "orphans", where orphans picks cells that are not in families.
+    "lineage" is further specified by "mothers", "daughters",
+    "families" (mother-bud pairs), and "orphans", where orphans
+    picks cells that are not in families.
 
-    "condition" is further specified by "present", "continuously_present", "any_present", or "growing" and a threshold, either a number of time points or a fraction of the total duration of the experiment.
+    "condition" is further specified by "present", "continuously_present",
+    "any_present", or "growing" and a threshold, either a number of time
+    points or a fraction of the total duration of the experiment.
     """
 
     _defaults = {
@@ -30,7 +34,7 @@ class PickerParameters(ParametersABC):
 
 class Picker(LineageProcess):
     """
-    Picker selects cells from a signal using lineage information and
+    Picker selects cells using lineage information and
     by how and for how long they are retained in the data set.
     """
 
@@ -49,7 +53,10 @@ class Picker(LineageProcess):
         how: str,
         mothers_daughters: t.Optional[np.ndarray] = None,
     ) -> pd.MultiIndex:
-        """Return rows of a signal corresponding to either mothers, daughters, or mother-daughter pairs using lineage information."""
+        """
+        Return rows of a signal corresponding to either mothers, daughters,
+        or mother-daughter pairs using lineage information.
+        """
         cells_present = drop_mother_label(signal.index)
         mothers_daughters = self.get_lineage_information(signal)
         #: might be better if match_column defined as a string to make everything one line
@@ -72,7 +79,8 @@ class Picker(LineageProcess):
 
     def run(self, signal):
         """
-        Pick indices from the index of a signal's dataframe and return as an array.
+        Pick indices from the index of a signal's dataframe and return
+        as an array.
 
         Typically, we first pick by lineage, then by condition.
         """
@@ -85,15 +93,13 @@ class Picker(LineageProcess):
             for alg, *params in self.sequence:
                 if indices:
                     if alg == "lineage":
-                        # pick mothers, buds, or mother-bud pairs
                         param1 = params[0]
-                        new_indices = getattr(self, "pick_by_" + alg)(
+                        new_indices = self.pick_by_lineage(
                             signal.loc[list(indices)], param1
                         )
                     else:
-                        # pick by condition
                         param1, *param2 = params
-                        new_indices = getattr(self, "pick_by_" + alg)(
+                        new_indices = self.pick_by_condition(
                             signal.loc[list(indices)], param1, param2
                         )
                 else:
@@ -107,10 +113,6 @@ class Picker(LineageProcess):
         indices_arr = np.array([tuple(map(_str_to_int, x)) for x in indices])
         return indices_arr
 
-    # def pick_by_condition(self, signal, condition, thresh):
-    #     idx = self.switch_case(signal, condition, thresh)
-    #     return idx
-
     def pick_by_condition(
         self,
         signal: pd.DataFrame,
diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py
index 08440093..b665ac15 100644
--- a/src/postprocessor/grouper.py
+++ b/src/postprocessor/grouper.py
@@ -86,16 +86,15 @@ class Grouper(ABC):
         **kwargs,
     ):
         """
-        Concatenate data for one signal from different h5 files into a dataframe.
-
-        Each h5 file corresponds to one position
+        Concatenate data for one signal from different h5 files, one for
+        each position, into a dataframe.
 
         Parameters
         ----------
         path : str
-           Signal location within h5py file
+           Signal location within h5 file.
         pool : int
-           Number of threads used; if 0 or None only one core is used
+           Number of threads used; if 0 or None only one core is used.
         mode: str
         standard: boolean
         **kwargs : key, value pairings
@@ -111,7 +110,7 @@ class Grouper(ABC):
         if standard:
             fn_pos = concat_standard
         else:
-            fn_pos = concat_signal_ind
+            fn_pos = concat_one_signal
             kwargs["mode"] = mode
         records = self.pool_function(
             path=path,
@@ -164,9 +163,8 @@ class Grouper(ABC):
         chainers: t.Dict[str, Chainer] = None,
         **kwargs,
     ):
-        """Enable different threads for independent chains, particularly useful when aggregating multiple elements."""
-        if pool is None:
-            pass
+        """Enable different threads for independent chains, particularly
+        useful when aggregating multiple elements."""
         chainers = chainers or self.chainers
         if pool:
             with Pool(pool) as p:
@@ -371,7 +369,7 @@ def concat_standard(
     return combined
 
 
-def concat_signal_ind(
+def concat_one_signal(
     path: str,
     chainer: Chainer,
     group: str,
-- 
GitLab