From a1d7372b00c85e5be55536f684a01fb48a3a3a61 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alan=20Mu=C3=B1oz?= <alan_munoz@protonmail.com>
Date: Thu, 23 Mar 2023 07:34:55 +0000
Subject: [PATCH] Merge pp dev

---
 src/postprocessor/core/functions/tracks.py |  44 +----
 src/postprocessor/core/processor.py        | 206 ++++++++++++---------
 src/postprocessor/core/reshapers/merger.py |  32 +++-
 src/postprocessor/core/reshapers/picker.py |  17 +-
 4 files changed, 151 insertions(+), 148 deletions(-)

diff --git a/src/postprocessor/core/functions/tracks.py b/src/postprocessor/core/functions/tracks.py
index db8ed272..5fa04a15 100644
--- a/src/postprocessor/core/functions/tracks.py
+++ b/src/postprocessor/core/functions/tracks.py
@@ -18,7 +18,7 @@ from postprocessor.core.processes.savgol import non_uniform_savgol
 
 
 def load_test_dset():
-    # Load development dataset to test functions
+    """Load development dataset to test functions."""
     return pd.DataFrame(
         {
             ("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
@@ -31,7 +31,7 @@ def load_test_dset():
 
 
 def max_ntps(track: pd.Series) -> int:
-    # Get number of timepoints
+    """Get number of time points."""
     indices = np.where(track.notna())
     return np.max(indices) - np.min(indices)
 
@@ -84,9 +84,7 @@ def clean_tracks(
     """
     ntps = get_tracks_ntps(tracks)
     grs = get_avg_grs(tracks)
-
     growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]
-
     return growing_long_tracks
 
 
@@ -111,7 +109,6 @@ def merge_tracks(
     joinable_pairs = get_joinable(tracks, **kwargs)
     if joinable_pairs:
         tracks = join_tracks(tracks, joinable_pairs, drop=drop)
-
     return (tracks, joinable_pairs)
 
 
@@ -135,28 +132,23 @@ def get_joint_ids(merging_seqs) -> dict:
     2 ab cd
     3 abcd
 
-    We shold get:
+    We should get:
 
     output {a:a, b:a, c:a, d:a}
 
     """
     if not merging_seqs:
         return {}
-
     targets, origins = list(zip(*merging_seqs))
     static_tracks = set(targets).difference(origins)
-
     joint = {track_id: track_id for track_id in static_tracks}
     for target, origin in merging_seqs:
         joint[origin] = target
-
     moved_target = [
         k for k, v in joint.items() if joint[v] != v and v in joint.values()
     ]
-
     for orig in moved_target:
         joint[orig] = rec_bottom(joint, orig)
-
     return {
         k: v for k, v in joint.items() if k != v
     }  # remove ids that point to themselves
@@ -184,14 +176,11 @@ def join_tracks(tracks, joinable_pairs, drop=True) -> pd.DataFrame:
     :param drop: bool indicating whether or not to drop moved rows
 
     """
-
     tmp = copy(tracks)
     for target, source in joinable_pairs:
         tmp.loc[target] = join_track_pair(tmp.loc[target], tmp.loc[source])
-
         if drop:
             tmp = tmp.drop(source)
-
     return tmp
 
 
@@ -206,7 +195,7 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
     """
     Get the pair of track (without repeats) that have a smaller error than the
     tolerance. If there is a track that can be assigned to two or more other
-    ones, it chooses the one with a lowest error.
+    ones, choose the one with the lowest error.
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
         columns are timepoints
@@ -324,7 +313,6 @@ def get_means(x, i):
         v = x[~np.isnan(x)][:i]
     else:
         v = x[~np.isnan(x)][i:]
-
     return np.nanmean(v)
 
 
@@ -335,12 +323,12 @@ def get_last_i(x, i):
         v = x[~np.isnan(x)][:i]
     else:
         v = x[~np.isnan(x)][i:]
-
     return v
 
 
 def localid_to_idx(local_ids, contig_trap):
-    """Fetch then original ids from a nested list with joinable local_ids
+    """
+    Fetch the original ids from a nested list with joinable local_ids.
 
     input
     :param local_ids: list of list of pairs with cell ids to be joint
@@ -392,7 +380,6 @@ def get_dMetric(
         dMetric = np.abs(np.subtract.outer(post, pre))
     else:
         dMetric = np.abs(np.subtract.outer(pre, post))
-
     dMetric[np.isnan(dMetric)] = (
         tol + 1 + np.nanmax(dMetric)
     )  # nans will be filtered
@@ -404,28 +391,26 @@ def solve_matrices(
 ):
     """
     Solve the distance matrices obtained in get_dMetric and/or merged from
-    independent dMetric matrices
+    independent dMetric matrices.
     """
-
     ids = solve_matrix(dMetric)
     if not len(ids[0]):
         return []
     pre, post = prepost
-
     norm = (
         np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
     )  # relative or absolute tol
     result = dMetric[ids] / norm
     ids = ids if len(pre) < len(post) else ids[::-1]
-
     return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
 
 
 def get_closest_pairs(
     pre: List[float], post: List[float], tol: Union[float, int] = 1
 ):
-    """Calculate a cost matrix the Hungarian algorithm to pick the best set of
-    options
+    """
+    Calculate a cost matrix for the Hungarian algorithm to pick the best set of
+    options.
 
     input
     :param pre: list of floats with edges on left
@@ -437,7 +422,6 @@ def get_closest_pairs(
 
     """
     dMetric = get_dMetric(pre, post, tol)
-
     return solve_matrices(dMetric, pre, post, tol)
 
 
@@ -465,9 +449,7 @@ def solve_matrix(dMetric):
             tmp[:, j] += np.nan
             glob_is.append(i)
             glob_js.append(j)
-
             std = sorted(tmp[~np.isnan(tmp)])
-
     return (np.array(glob_is), np.array(glob_js))
 
 
@@ -475,7 +457,6 @@ def plot_joinable(tracks, joinable_pairs):
     """
     Convenience plotting function for debugging and data vis
     """
-
     nx = 8
     ny = 8
     _, axes = plt.subplots(nx, ny)
@@ -493,7 +474,6 @@ def plot_joinable(tracks, joinable_pairs):
                 # except:
                 #     pass
                 ax.plot(post_srs.index, post_srs.values, "g")
-
     plt.show()
 
 
@@ -506,21 +486,17 @@ def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
     :param min_dgr: float minimum difference in growth rate from
         the interpolation
     """
-
     mins, maxes = [
         tracks.notna().apply(np.where, axis=1).apply(fn)
         for fn in (np.min, np.max)
     ]
-
     mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
     mins_d.index = mins_d.index - 1  # make indices equal
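+    # the shift makes tracks that start at time t + 1 share an index with
+    # tracks that end at time t, so candidate pre/post pairs can be matched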
     # TODO add support for skipping time points
     maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist())
-
     common = sorted(
         set(mins_d.index).intersection(maxes_d.index), reverse=True
     )
-
     return [(maxes_d[t], mins_d[t]) for t in common]
 
 
diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index d4cd0367..e0adac14 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -1,7 +1,5 @@
-import logging
 import typing as t
 from itertools import takewhile
-from typing import Dict, List, Union
 
 import numpy as np
 import pandas as pd
@@ -16,7 +14,6 @@ from agora.utils.indexing import (
     _assoc_indices_to_3d,
 )
 from agora.utils.merge import merge_association
-from agora.utils.kymograph import get_index_as_np
 from postprocessor.core.abc import get_parameters, get_process
 from postprocessor.core.lineageprocess import (
     LineageProcess,
@@ -35,24 +32,37 @@ class PostProcessorParameters(ParametersABC):
         while objectives are relative or absolute paths to datasets. If relative paths the
         post-processed addresses are used. The order of processes matters.
 
+    Supply parameters for the picker, merger, and processes.
     """
 
     def __init__(
         self,
-        targets={},
-        param_sets={},
-        outpaths={},
+        targets: t.Dict = {},
+        param_sets: t.Dict = {},
+        outpaths: t.Dict = {},
     ):
-        self.targets: Dict = targets
-        self.param_sets: Dict = param_sets
-        self.outpaths: Dict = outpaths
+        self.targets = targets
+        self.param_sets = param_sets
+        self.outpaths = outpaths
 
     def __getitem__(self, item):
         return getattr(self, item)
 
     @classmethod
     def default(cls, kind=[]):
-        """Sequential postprocesses to be operated"""
+        """
+        Include buddings and bud_metric, plus estimates of their time derivatives.
+
+        Parameters
+        ----------
+        kind: list of str
+            If "ph_batman" is included, add targets for experiments using pHluorin.
+        """
+        # each subitem specifies the function to be called and the location
+        # on the h5 file to be written
         targets = {
             "prepost": {
                 "merger": "/extraction/general/None/area",
@@ -61,9 +71,7 @@ class PostProcessorParameters(ParametersABC):
             "processes": [
                 [
                     "buddings",
-                    [
-                        "/extraction/general/None/volume",
-                    ],
+                    ["/extraction/general/None/volume"],
                 ],
                 [
                     "dsignal",
@@ -93,7 +101,7 @@ class PostProcessorParameters(ParametersABC):
         }
         outpaths = {}
         outpaths["aggregate"] = "/postprocessing/experiment_wide/aggregated/"
-
+        # pHluorin experiments are special
         if "ph_batman" in kind:
             targets["processes"]["dsignal"].append(
                 [
@@ -115,46 +123,56 @@ class PostProcessorParameters(ParametersABC):
                     ]
                 ],
             )
-
         return cls(targets=targets, param_sets=param_sets, outpaths=outpaths)
 
 
 class PostProcessor(ProcessABC):
     def __init__(self, filename, parameters):
+        """
+        Initialise the PostProcessor.
+
+        Parameters
+        ----------
+        filename: str or PosixPath
+            Name of h5 file.
+        parameters: PostProcessorParameters object
+            An instance of PostProcessorParameters.
+        """
         super().__init__(parameters)
         self._filename = filename
         self._signal = Signal(filename)
         self._writer = Writer(filename)
-
+        # parameters for merger and picker
         dicted_params = {
             i: parameters["param_sets"]["prepost"][i]
             for i in ["merger", "picker"]
         }
-
         for k in dicted_params.keys():
             if not isinstance(dicted_params[k], dict):
                 dicted_params[k] = dicted_params[k].to_dict()
-
+        # merger and picker
         self.merger = Merger(
             MergerParameters.from_dict(dicted_params["merger"])
         )
-
         self.picker = Picker(
             PickerParameters.from_dict(dicted_params["picker"]),
             cells=Cells.from_source(filename),
         )
+        # processes, such as buddings
         self.classfun = {
             process: get_process(process)
             for process, _ in parameters["targets"]["processes"]
         }
+        # parameters for the process in classfun
         self.parameters_classfun = {
             process: get_parameters(process)
             for process, _ in parameters["targets"]["processes"]
         }
+        # locations to be written in the h5 file
         self.targets = parameters["targets"]
 
     def run_prepost(self):
-        # TODO Split function
-        """Important processes run before normal post-processing ones"""
+        """Run essential pre-processing steps: use the picker to get and write lineages, returning mothers and daughters."""
         record = self._signal.get_raw(self.targets["prepost"]["merger"])
         merges = np.array(self.merger.run(record), dtype=int)
@@ -175,7 +193,6 @@ class PostProcessor(ProcessABC):
         self.lineage = _3d_index_to_2d(
             lineage_merged if len(lineage_merged) else lineage
         )
-
         self._writer.write(
             "modifiers/lineage_merged", _3d_index_to_2d(lineage_merged)
         )
@@ -204,91 +221,98 @@ class PostProcessor(ProcessABC):
         return x
 
     def run(self):
-        # TODO Documentation :) + Split
+        """
+        Write the results to the h5 file.
+        Processes include identifying buddings and finding bud metrics.
+        """
+        # run the merger and picker, and find lineages
         self.run_prepost()
-
+        # run processes
         for process, datasets in tqdm(self.targets["processes"]):
-            if process in self.parameters["param_sets"].get(
-                "processes", {}
-            ):  # If we assigned parameters
+            if process in self.parameters["param_sets"].get("processes", {}):
+                # parameters already assigned
                 parameters = self.parameters_classfun[process](
                     self.parameters[process]
                 )
-
             else:
+                # use default parameters
                 parameters = self.parameters_classfun[process].default()
-
+            # load process
             loaded_process = self.classfun[process](parameters)
             if isinstance(parameters, LineageProcessParameters):
                 loaded_process.lineage = self.lineage
 
+            # apply process to each dataset
             for dataset in datasets:
-                if isinstance(dataset, list):  # multisignal process
-                    signal = [self._signal[d] for d in dataset]
-                elif isinstance(dataset, str):
-                    signal = self._signal[dataset]
-                else:
-                    raise Exception("Unavailable record")
-
-                if len(signal) and (
-                    not isinstance(loaded_process, LineageProcess)
-                    or len(loaded_process.lineage)
-                ):
-                    result = loaded_process.run(signal)
-                else:
-                    result = pd.DataFrame(
-                        [], columns=signal.columns, index=signal.index
-                    )
-                    result.columns.names = ["timepoint"]
-
-                if process in self.parameters["outpaths"]:
-                    outpath = self.parameters["outpaths"][process]
-                elif isinstance(dataset, list):
-                    # If no outpath defined, place the result in the minimum common
-                    # branch of all signals used
-                    prefix = "".join(
-                        c[0]
-                        for c in takewhile(
-                            lambda x: all(x[0] == y for y in x), zip(*dataset)
-                        )
-                    )
-                    outpath = (
-                        prefix
-                        + "_".join(  # TODO check that it always finishes in '/'
-                            [
-                                d[len(prefix) :].replace("/", "_")
-                                for d in dataset
-                            ]
-                        )
-                    )
-                elif isinstance(dataset, str):
-                    outpath = dataset[1:].replace("/", "_")
-                else:
-                    raise ("Outpath not defined", type(dataset))
-
-                if process not in self.parameters["outpaths"]:
-                    outpath = "/postprocessing/" + process + "/" + outpath
-
-                if isinstance(result, dict):  # Multiple Signals as output
-                    for k, v in result.items():
-                        self.write_result(
-                            outpath + f"/{k}",
-                            v,
-                            metadata={},
-                        )
-                else:
-                    self.write_result(
-                        outpath,
-                        result,
-                        metadata={},
-                    )
+                self.run_process(dataset, process, loaded_process)
+
+    def run_process(self, dataset, process, loaded_process):
+        """Run process on a single dataset and write the result."""
+        # define signal
+        if isinstance(dataset, list):
+            # multisignal process
+            signal = [self._signal[d] for d in dataset]
+        elif isinstance(dataset, str):
+            signal = self._signal[dataset]
+        else:
+            raise ("Incorrect dataset")
+        # run process on signal
+        if len(signal) and (
+            not isinstance(loaded_process, LineageProcess)
+            or len(loaded_process.lineage)
+        ):
+            result = loaded_process.run(signal)
+        else:
+            result = pd.DataFrame(
+                [], columns=signal.columns, index=signal.index
+            )
+            result.columns.names = ["timepoint"]
+        # define outpath, where result will be written
+        if process in self.parameters["outpaths"]:
+            outpath = self.parameters["outpaths"][process]
+        elif isinstance(dataset, list):
+            # no outpath is defined
+            # place the result in the minimum common branch of all signals
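+            # e.g. for the hypothetical pair of datasets
+            # ["/extraction/general/None/volume",
+            #  "/extraction/general/None/area"], the character-wise common
+            # prefix is "/extraction/general/None/" and the joined outpath
+            # becomes "/extraction/general/None/volume_area"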
+            prefix = "".join(
+                c[0]
+                for c in takewhile(
+                    lambda x: all(x[0] == y for y in x), zip(*dataset)
+                )
+            )
+            outpath = (
+                prefix
+                + "_".join(  # TODO check that it always finishes in '/'
+                    [d[len(prefix) :].replace("/", "_") for d in dataset]
+                )
+            )
+        elif isinstance(dataset, str):
+            outpath = dataset[1:].replace("/", "_")
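+            # e.g. "/extraction/general/None/volume" becomes
+            # "extraction_general_None_volume"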
+        else:
+            raise ("Outpath not defined", type(dataset))
+        # add postprocessing to outpath when required
+        if process not in self.parameters["outpaths"]:
+            outpath = "/postprocessing/" + process + "/" + outpath
+        # write result
+        if isinstance(result, dict):
+            # multiple Signals as output
+            for k, v in result.items():
+                self.write_result(
+                    outpath + f"/{k}",
+                    v,
+                    metadata={},
+                )
+        else:
+            # a single Signal as output
+            self.write_result(
+                outpath,
+                result,
+                metadata={},
+            )
 
     def write_result(
         self,
         path: str,
-        result: Union[List, pd.DataFrame, np.ndarray],
-        metadata: Dict,
+        result: t.Union[t.List, pd.DataFrame, np.ndarray],
+        metadata: t.Dict,
     ):
-        if not result.any().any():
-            logging.getLogger("aliby").warning(f"Record {path} is empty")
         self._writer.write(path, result, meta=metadata, overwrite="overwrite")
diff --git a/src/postprocessor/core/reshapers/merger.py b/src/postprocessor/core/reshapers/merger.py
index 67c14b27..e18fbf7b 100644
--- a/src/postprocessor/core/reshapers/merger.py
+++ b/src/postprocessor/core/reshapers/merger.py
@@ -6,11 +6,21 @@ from postprocessor.core.functions.tracks import get_joinable
 
 class MergerParameters(ParametersABC):
     """
-    :param tol: float or int threshold of average (prediction error/std) necessary
-        to consider two tracks the same. If float is fraction of first track,
-        if int it is absolute units.
-    :param window: int value of window used for savgol_filter
-    :param degree: int value of polynomial degree passed to savgol_filter
+    Define the parameters for merger from a dict.
+
+    There are four parameters expected in the dict:
+
+    smooth: boolean
+        Whether or not to smooth with a savgol_filter.
+    tolerance: float or int
+        The threshold of the average prediction error/std necessary to
+        consider two tracks the same.
+        If float, the threshold is a fraction of the first track;
+        if int, the threshold is in absolute units.
+    window: int
+        The size of the window used by the savgol_filter.
+    degree: int
+        The order of the polynomial used by the savgol_filter.
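+
+    A hypothetical example dict (illustrative values only)::
+
+        {"smooth": False, "tolerance": 0.2, "window": 5, "degree": 3}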
     """
 
     _defaults = {
@@ -23,9 +33,7 @@ class MergerParameters(ParametersABC):
 
 
 class Merger(PostProcessABC):
-    """
-    Combines rows of tracklet that are likely to be the same.
-    """
+    """Combine rows of tracklet that are likely to be the same."""
 
     def __init__(self, parameters):
         super().__init__(parameters)
@@ -33,5 +41,11 @@ class Merger(PostProcessABC):
     def run(self, signal):
         joinable = []
         if signal.shape[1] > 4:
-            joinable = get_joinable(signal, tol=self.parameters.tolerance)
+            joinable = get_joinable(
+                signal,
+                smooth=self.parameters.smooth,
+                tol=self.parameters.tolerance,
+                window=self.parameters.window,
+                degree=self.parameters.degree,
+            )
         return joinable
diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py
index 2d3f8cc4..1796fe65 100644
--- a/src/postprocessor/core/reshapers/picker.py
+++ b/src/postprocessor/core/reshapers/picker.py
@@ -25,7 +25,7 @@ class Picker(LineageProcess):
     :cells: Cell object passed to the constructor
     :condition: Tuple with condition and associated parameter(s), conditions can be
     "present", "nonstoply_present" or "quantile".
-    Determines the thersholds or fractions of signals to use.
+    Determine the thresholds or fractions of signals to use.
     :lineage: str {"mothers", "daughters", "families" (mothers AND daughters), "orphans"}. Mothers/daughters picks cells with those tags, families pick the union of both and orphans the difference between the total and families.
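+
+    A hypothetical illustration (parameter names assumed, not confirmed by
+    this patch): lineage "families" with condition ("present", 3) would pick
+    cells tagged as mothers or daughters that are present for more than
+    three time points.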
     """
 
@@ -35,7 +35,6 @@ class Picker(LineageProcess):
         cells: Cells or None = None,
     ):
         super().__init__(parameters=parameters)
-
         self.cells = cells
 
     def pick_by_lineage(
@@ -44,13 +43,9 @@ class Picker(LineageProcess):
         how: str,
         mothers_daughters: t.Optional[np.ndarray] = None,
     ) -> pd.MultiIndex:
-
         cells_present = drop_mother_label(signal.index)
-
         mothers_daughters = self.get_lineage_information(signal)
-
         valid_indices = slice(None)
-
         if how == "mothers":
             _, valid_indices = validate_association(
                 mothers_daughters, cells_present, match_column=0
@@ -63,7 +58,6 @@ class Picker(LineageProcess):
             _, valid_indices = validate_association(
                 mothers_daughters, cells_present
             )
-
         return signal.index[valid_indices]
 
     def pick_by_condition(self, signal, condition, thresh):
@@ -73,13 +67,10 @@ class Picker(LineageProcess):
     def run(self, signal):
         self.orig_signal = signal
         indices = set(signal.index)
-
         lineage = self.get_lineage_information(signal)
-
         if len(lineage):
             self.mothers = lineage[:, :2]
             self.daughters = lineage[:, [0, 2]]
-
             for alg, *params in self.sequence:
                 new_indices = tuple()
                 if indices:
@@ -94,13 +85,11 @@ class Picker(LineageProcess):
                             signal.loc[list(indices)], param1, param2
                         )
                         new_indices = [tuple(x) for x in new_indices]
-
                 indices = indices.intersection(new_indices)
         else:
             self._log(f"No lineage assignment")
             indices = np.array([])
-
-        return np.array([tuple(map(_str_to_int, x)) for x in indices]).T
+        return np.array([tuple(map(_str_to_int, x)) for x in indices])
 
     def switch_case(
         self,
@@ -128,7 +117,7 @@ def _as_int(threshold: t.Union[float, int], ntps: int):
 
 def any_present(signal, threshold):
     """
-    Returns a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints.
+    Return a mask for cells that is True if a cell in that trap was present for more than :threshold: time points.
     """
     any_present = pd.Series(
         np.sum(
-- 
GitLab