From 8e18a84404f633db867dba88aa61274704ed7b88 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Thu, 14 Jul 2022 11:56:07 +0100
Subject: [PATCH] style(all): limit docstring and line length, remove dead code

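Besides shortening docstrings, wrap long code lines, rewrite lambdas
bound to names as def statements, replace `is` comparisons against
string literals with `==`, and delete unused modules (peaks,
standardscaler) along with stale example scripts.

The most common mechanical change is the lambda-to-def rewrite; a
schematic before/after, mirroring the get_val change in
core/functions/tracks.py:

    # before: a lambda assigned to a name, with no docstring or useful repr
    get_val = lambda x, n: x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan

    # after: an equivalent named function
    def get_val(x, n):
        return x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan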
---
 core/functions/tracks.py         |  23 ++--
 core/processes/births.py         |   6 +-
 core/processes/knngraph.py       |   5 +-
 core/processes/leiden.py         |  16 ++-
 core/processes/peaks.py          |  41 ------
 core/processes/picker.py         |  70 ++++++----
 core/processes/savgol.py         |  65 ++++++---
 core/processes/standardscaler.py |  50 -------
 examples/basic_processes.py      |  15 ---
 examples/group.py                |  24 ----
 examples/testing.py              | 220 -------------------------------
 grouper.py                       |  39 ++----
 routines/heatmap.py              |   4 +-
 routines/mean_plot.py            |   2 +-
 routines/median_plot.py          |   7 +-
 15 files changed, 135 insertions(+), 452 deletions(-)
 delete mode 100644 core/processes/peaks.py
 delete mode 100644 core/processes/standardscaler.py
 delete mode 100644 examples/basic_processes.py
 delete mode 100644 examples/group.py
 delete mode 100644 examples/testing.py

diff --git a/core/functions/tracks.py b/core/functions/tracks.py
index b5481329..fd91bf5c 100644
--- a/core/functions/tracks.py
+++ b/core/functions/tracks.py
@@ -245,13 +245,15 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
         )
 
     # fetch edges from ids TODO (IF necessary, here we can compare growth rates)
-    idx_to_edge = lambda preposts: [
-        (
-            [get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
-            [get_val(smoothed_tracks.loc[post], 0) for post in posts],
-        )
-        for pres, posts in preposts
-    ]
+    def idx_to_edge(preposts):
+        return [
+            (
+                [get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
+                [get_val(smoothed_tracks.loc[post], 0) for post in posts],
+            )
+            for pres, posts in preposts
+        ]
+
     # idx_to_means = lambda preposts: [
     #     (
     #         [get_means(smoothed_tracks.loc[pre], -window) for pre in pres],
@@ -311,7 +313,8 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
     return [pair for pairset in joinable_ids for pair in pairset]
 
 
-get_val = lambda x, n: x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
+def get_val(x, n):
+    return x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
 
 
 def get_means(x, i):
@@ -357,11 +360,11 @@ def localid_to_idx(local_ids, contig_trap):
 
 
 def get_vec_closest_pairs(lst: List, **kwargs):
-    return [get_closest_pairs(*l, **kwargs) for l in lst]
+    return [get_closest_pairs(*sublist, **kwargs) for sublist in lst]
 
 
 def get_dMetric_wrap(lst: List, **kwargs):
-    return [get_dMetric(*l, **kwargs) for l in lst]
+    return [get_dMetric(*sublist, **kwargs) for sublist in lst]
 
 
 def solve_matrices_wrap(dMetric: List, edges: List, **kwargs):
diff --git a/core/processes/births.py b/core/processes/births.py
index 84a9e50c..9059b0bf 100644
--- a/core/processes/births.py
+++ b/core/processes/births.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+import typing as t
 from itertools import product
 
 import numpy as np
@@ -47,10 +48,9 @@ class births(LineageProcess):
         if lineage is None:
             lineage = self.lineage
 
-        def fvi(signal):
-            return signal.apply(lambda x: x.first_valid_index(), axis=1)
+        fvi = signal.apply(lambda x: x.first_valid_index(), axis=1)
 
-        traps_mothers = {
+        traps_mothers: t.Dict[tuple, list] = {
             tuple(mo): [] for mo in lineage[:, :2] if tuple(mo) in signal.index
         }
         for trap, mother, daughter in lineage:
diff --git a/core/processes/knngraph.py b/core/processes/knngraph.py
index d9532e0e..c3a165ab 100644
--- a/core/processes/knngraph.py
+++ b/core/processes/knngraph.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+import numpy as np
 import pandas as pd
 from sklearn.metrics.pairwise import euclidean_distances
 import igraph as ig
@@ -41,7 +42,9 @@ class knngraph(PostProcessABC):
             Feature matrix.
         """
         distance_matrix = euclidean_distances(signal)
-        distance_matrix_pruned = graph_prune(distance_matrix, self.n_neighbours)
+        distance_matrix_pruned = graph_prune(
+            distance_matrix, self.n_neighbours
+        )
         graph = ig.Graph.Weighted_Adjacency(
             distance_matrix_pruned.tolist(), mode="undirected"
         )
diff --git a/core/processes/leiden.py b/core/processes/leiden.py
index 131985d0..9cc31050 100644
--- a/core/processes/leiden.py
+++ b/core/processes/leiden.py
@@ -2,7 +2,6 @@ from itertools import product
 
 import numpy as np
 import pandas as pd
-from itertools import product
 import igraph as ig
 import leidenalg
 
@@ -28,15 +27,22 @@ class leiden(PostProcessABC):
 
     def run(self, features: pd.DataFrame):
         # Generate euclidean distance matrix
-        distances = np.linalg.norm(features.values - features.values[:, None], axis=2)
+        distances = np.linalg.norm(
+            features.values - features.values[:, None], axis=2
+        )
         ind = [
-            "_".join([str(y) for y in x[1:]]) for x in features.index.to_flat_index()
+            "_".join([str(y) for y in x[1:]])
+            for x in features.index.to_flat_index()
         ]
         source, target = zip(*product(ind, ind))
         df = pd.DataFrame(
-            {"source": source, "target": target, "distance": distances.flatten()}
+            {
+                "source": source,
+                "target": target,
+                "distance": distances.flatten(),
+            }
         )
         df = df.loc[df["source"] != df["target"]]
         g = ig.Graph.DataFrame(df, directed=False)
 
-        part = leidenalg.find_partition(g, leidenalg.ModularityVertexPartition)
+        return leidenalg.find_partition(g, leidenalg.ModularityVertexPartition)
diff --git a/core/processes/peaks.py b/core/processes/peaks.py
deleted file mode 100644
index 99dff4eb..00000000
--- a/core/processes/peaks.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from scipy.signal import argrelmax, argrelmin
-
-from agora.abc import ParametersABC
-from postprocessor.core.abc import PostProcessABC
-
-
-class PeaksParameters(ParametersABC):
-    """
-    Parameters
-        type : str {minima,  maxima, "all"}. Determines which type of peaks to identify
-        order : int Parameter to pass to scipy.signal.argrelextrema indicating
-            how many points to use for comparison.
-    """
-
-    _defaults = {"type": "minima", "order": 3}
-
-
-class Peaks(PostProcessABC):
-    """
-    Identifies a signal sharply dropping.
-    """
-
-    def __init__(self, parameters: PeaksParameters):
-        super().__init__(parameters)
-
-    def run(self, signal: pd.DataFrame):
-        """
-        Returns a boolean dataframe with the same shape as the
-        original signal but with peaks as true values.
-        """
-        peaks_mat = np.zeros_like(signal, dtype=bool)
-
-        comparator = np.less if self.parameters.type is "minima" else np.greater
-        peaks_ids = argrelextrema(new_df, comparator=comparator, order=order)
-        peaks_mat[peak_ids] = True
-
-        return pd.DataFrame(
-            peaks_mat,
-            index=signal.index,
-            columns=signal.columns,
-        )
diff --git a/core/processes/picker.py b/core/processes/picker.py
index 14b3abe8..6811231c 100644
--- a/core/processes/picker.py
+++ b/core/processes/picker.py
@@ -71,7 +71,11 @@ class picker(PostProcessABC):
         cell_gid = np.unique(idlist, axis=0)
 
         last_lin_preds = [
-            find_1st(((cell_label[::-1] == lbl) & (trap[::-1] == tr)), True, cmp_equal)
+            find_1st(
+                ((cell_label[::-1] == lbl) & (trap[::-1] == tr)),
+                True,
+                cmp_equal,
+            )
             for tr, lbl in cell_gid
         ]
         mother_assign_sorted = ma[::-1][last_lin_preds]
@@ -93,12 +97,18 @@ class picker(PostProcessABC):
             daughters = set(self.daughters)
             # daughters, mothers = np.where(mother_bud_mat)
 
-            search = lambda a, b: np.where(
-                np.in1d(
-                    np.ravel_multi_index(np.array(a).T, np.array(a).max(0) + 1),
-                    np.ravel_multi_index(np.array(b).T, np.array(a).max(0) + 1),
+            def search(a, b):
+                return np.where(
+                    np.in1d(
+                        np.ravel_multi_index(
+                            np.array(a).T, np.array(a).max(0) + 1
+                        ),
+                        np.ravel_multi_index(
+                            np.array(b).T, np.array(a).max(0) + 1
+                        ),
+                    )
                 )
-            )
+
             if how == "mothers":
                 idx = mothers
             elif how == "daughters":
@@ -109,7 +119,9 @@ class picker(PostProcessABC):
                     [
                         tuple(x)
                         for m in present_mothers
-                        for x in np.array(self.daughters)[search(self.mothers, m)]
+                        for x in np.array(self.daughters)[
+                            search(self.mothers, m)
+                        ]
                     ]
                 )
 
@@ -120,7 +132,9 @@ class picker(PostProcessABC):
                     [
                         tuple(x)
                         for d in present_daughters
-                        for x in np.array(self.mothers)[search(self.daughters, d)]
+                        for x in np.array(self.mothers)[
+                            search(self.daughters, d)
+                        ]
                     ]
                 )
             elif how == "full_families":
@@ -171,13 +185,6 @@ class picker(PostProcessABC):
 
         if sum([x for y in nested_massign for x in y]):
 
-            idx = set(
-                [
-                    (tid, i + 1)
-                    for tid, x in enumerate(nested_massign)
-                    for i in range(len(x))
-                ]
-            )
             mothers, daughters = zip(
                 *[
                     ((tid, m), (tid, d))
@@ -197,6 +204,7 @@ class picker(PostProcessABC):
         indices = set(signals.index)
         self.mothers, self.daughters = self.get_mothers_daughters()
         for alg, op, *params in self.sequence:
+            new_indices = tuple()
             if indices:
                 if alg == "lineage":
                     param1 = params[0]
@@ -209,7 +217,7 @@ class picker(PostProcessABC):
                         signals.loc[list(indices)], param1, param2
                     )
 
-            if op is "union":
+            if op == "union":
                 # new_indices = new_indices.intersection(set(signals.index))
                 new_indices = indices.union(new_indices)
 
@@ -228,7 +236,8 @@ class picker(PostProcessABC):
         case_mgr = {
             "any_present": lambda s, thresh: any_present(s, thresh),
             "present": lambda s, thresh: s.notna().sum(axis=1) > thresh,
-            "nonstoply_present": lambda s, thresh: s.apply(thresh, axis=1) > thresh,
+            "nonstoply_present": lambda s, thresh: s.apply(thresh, axis=1)
+            > thresh,
             "growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh,
             "mb_guess": lambda s, p1, p2: self.mb_guess_wrap(s, p1, p2)
             # "quantile": [np.quantile(signals.values[signals.notna()], threshold)],
@@ -258,16 +267,15 @@ class picker(PostProcessABC):
         nomother = df.drop(mother_id)
         if not len(nomother):
             return []
-        nomother = (  # Clean short-lived cells outside our mother cell's timepoints
-            nomother.loc[
-                nomother.apply(
-                    lambda x: x.first_valid_index()
-                    >= df.loc[mother_id].first_valid_index()
-                    and x.first_valid_index() <= df.loc[mother_id].last_valid_index(),
-                    axis=1,
-                )
-            ]
-        )
+        nomother = nomother.loc[  # Clean short-lived cells outside our mother cell's timepoints
+            nomother.apply(
+                lambda x: x.first_valid_index()
+                >= df.loc[mother_id].first_valid_index()
+                and x.first_valid_index()
+                <= df.loc[mother_id].last_valid_index(),
+                axis=1,
+            )
+        ]
 
         score = -nomother.apply(  # Get slope of candidate daughters
             lambda x: self.get_slope(x.dropna()), axis=1
@@ -299,8 +307,12 @@ class picker(PostProcessABC):
             bud_candidates = pd.DataFrame()
         else:
             # Find the set with the highest number of growing cells and highest avg growth rate for this #
-            mivs = self.max_ind_vertex_sets(cols_sorted.values, min_budgrowth_t)
-            best_set = list(mivs[np.argmin([sum(score.iloc[list(s)]) for s in mivs])])
+            mivs = self.max_ind_vertex_sets(
+                cols_sorted.values, min_budgrowth_t
+            )
+            best_set = list(
+                mivs[np.argmin([sum(score.iloc[list(s)]) for s in mivs])]
+            )
             best_indices = cols_sorted.index[best_set]
 
             start = start.loc[best_indices]
diff --git a/core/processes/savgol.py b/core/processes/savgol.py
index 2584c414..40cbf3d6 100644
--- a/core/processes/savgol.py
+++ b/core/processes/savgol.py
@@ -31,16 +31,23 @@ class savgol(PostProcessABC):
     def run(self, signal: pd.DataFrame):
         try:
             post_savgol = pd.DataFrame(
-                savgol_filter(signal, self.parameters.window, self.parameters.polynom),
+                savgol_filter(
+                    signal, self.parameters.window, self.parameters.polynom
+                ),
                 index=signal.index,
                 columns=signal.columns,
             )
         except Exception as e:
             print(e)
 
-            savgol_on_srs = lambda x: self.non_uniform_savgol(
-                x.index, x.values, self.parameters.window, self.parameters.polynom
-            )
+            def savgol_on_srs(x):
+                return self.non_uniform_savgol(
+                    x.index,
+                    x.values,
+                    self.parameters.window,
+                    self.parameters.polynom,
+                )
+
             post_savgol = signal.apply(savgol_on_srs, 1).apply(pd.Series)
         return post_savgol
 
@@ -72,23 +79,34 @@ class savgol(PostProcessABC):
         np.array of float
             The smoothed y values
         """
-        if len(x) != len(y):
-            raise ValueError('"x" and "y" must be of the same size')
-
-        if len(x) < window:
-            raise ValueError("The data size must be larger than the window size")
-
-        if type(window) is not int:
-            raise TypeError('"window" must be an integer')
-
-        if window % 2 == 0:
-            raise ValueError('The "window" must be an odd integer')
-
-        if type(polynom) is not int:
-            raise TypeError('"polynom" must be an integer')
-
-        if polynom >= window:
-            raise ValueError('"polynom" must be less than "window"')
+        _raiseif(
+            len(x) != len(y),
+            '"x" and "y" must be of the same size',
+            ValueError,
+        )
+        _raiseif(
+            len(x) < window,
+            "The data size must be larger than the window size",
+            ValueError,
+        )
+        _raiseif(
+            not isinstance(window, int),
+            '"window" must be an integer',
+            TypeError,
+        )
+        _raiseif(window % 2 == 0, 'The "window" must be an odd integer', ValueError)
+
+        _raiseif(
+            not isinstance(polynom, int),
+            '"polynom" must be an integer',
+            TypeError,
+        )
+
+        _raiseif(
+            polynom >= window,
+            '"polynom" must be less than "window"',
+            ValueError,
+        )
 
         half_window = window // 2
         polynom += 1
@@ -156,3 +174,8 @@ class savgol(PostProcessABC):
                 x_i *= x[i] - x[-half_window - 1]
 
         return y_smoothed
+
+
+def _raiseif(cond, msg="", exc=AssertionError):
+    if cond:
+        raise exc(msg)
diff --git a/core/processes/standardscaler.py b/core/processes/standardscaler.py
deleted file mode 100644
index be2b3e73..00000000
--- a/core/processes/standardscaler.py
+++ /dev/null
@@ -1,50 +0,0 @@
-#!/usr/bin/env python3
-
-import pandas as pd
-from sklearn.preprocessing import StandardScaler
-
-from agora.abc import ParametersABC
-from postprocessor.core.abc import PostProcessABC
-
-
-class standardscalerParameters(ParametersABC):
-    """
-    Parameters for the 'scale' process.
-    """
-
-    _defaults = {}
-
-
-class standardscaler(PostProcessABC):
-    """
-    Process to scale a DataFrame of a signal using the standard scaler.
-
-    Methods
-    -------
-    run(signal: pd.DataFrame)
-        Scale values in a dataframe of time series.
-    """
-
-    def __init__(self, parameters: standardscalerParameters):
-        super().__init__(parameters)
-
-    def run(self, signal: pd.DataFrame):
-        """Scale values in a dataframe of time series.
-
-        Scale values in a dataframe of time series.  This function is effectively a
-        wrapper for sklearn.preprocessing.StandardScaler.
-
-        Parameters
-        ----------
-        signal : pd.DataFrame
-            Time series, with rows indicating individual time series (e.g. from
-            each cell), and columns indicating time points.
-        """
-        signal_array = signal.to_numpy()
-        scaler = StandardScaler().fit(signal_array.transpose())
-        signal_scaled_array = scaler.transform(signal_array.transpose())
-        signal_scaled_array = signal_scaled_array.transpose()
-        signal_scaled = pd.DataFrame(
-            signal_scaled_array, columns=signal.columns, index=signal.index
-        )
-        return signal_scaled
diff --git a/examples/basic_processes.py b/examples/basic_processes.py
deleted file mode 100644
index c322fd4b..00000000
--- a/examples/basic_processes.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from postprocessor.core.processor import PostProcessor, PostProcessorParameters
-
-params = PostProcessorParameters.default()
-pp = PostProcessor(
-    "/shared_libs/pipeline-core/scripts/pH_calibration_dual_phl__ura8__by4741__01/ph_5_29_025store.h5",
-    params,
-)
-tmp = pp.run()
-
-import h5py
-
-# f = h5py.File(
-#     "/shared_libs/pipeline-core/scripts/pH_calibration_dual_phl__ura8__by4741__01/ph_5_29_025store.h5",
-#     "a",
-# )
diff --git a/examples/group.py b/examples/group.py
deleted file mode 100644
index 990f5616..00000000
--- a/examples/group.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from pathlib import Path
-
-from postprocessor.core.group import GroupParameters, Group
-
-poses = [
-    x.name.split("store")[0]
-    for x in Path(
-        "/shared_libs/pipeline-core/scripts/data/ph_calibration_dual_phl_ura8_5_04_5_83_7_69_7_13_6_59__01"
-    ).rglob("*")
-    if x.name != "images.h5"
-]
-
-gr = Group(
-    GroupParameters(
-        signals=[
-            "/extraction/general/None/area",
-            "/extraction/mCherry/np_max/median",
-        ]
-    )
-)
-gr.run(
-    central_store="/shared_libs/pipeline-core/scripts/data/ph_calibration_dual_phl_ura8_5_04_5_83_7_69_7_13_6_59__01",
-    poses=poses,
-)
diff --git a/examples/testing.py b/examples/testing.py
deleted file mode 100644
index d5420d83..00000000
--- a/examples/testing.py
+++ /dev/null
@@ -1,220 +0,0 @@
-from typing import Dict, List, Union
-import re
-
-import numpy as np
-import pandas as pd
-from pandas import Series, DataFrame
-from sklearn.cluster import KMeans
-from matplotlib import pyplot as plt
-import seaborn as sns
-
-import compress_pickle
-
-from postprocessor.core.postprocessor import PostProcessor
-from postprocessor.core.tracks import get_avg_grs, non_uniform_savgol
-from postprocessor.core.tracks import clean_tracks, merge_tracks, join_tracks
-from postprocessor.core.ph import *
-
-sns.set_context("talk", font_scale=1.8)
-sns.set_theme(style="whitegrid")
-
-# pp_c = PostProcessor(source=19920)  # 19916
-# pp_c.load_tiler_cells()
-# # f = "/home/alan/Documents/libs/extraction/extraction/examples/gluStarv_2_0_x2_dual_phl_ura8_00/extraction"
-# # f = "/home/alan/Documents/tmp/pH_calibration_dual_phl__ura8__by4741__01/extraction/"
-# f = "/home/alan/Documents/tmp/pH_calibration_dual_phl__ura8__by4741_Alan4_00/extraction/"
-# pp_c.load_extraction(f)
-
-# calib = process_phs(pp_c)
-# # c = calib.loc[(5.4 < calib["ph"]) & (calib["ph"] < 8)].groupby("ph").mean()
-# sns.violinplot(x="ph", y="em_ratio_mean", data=calib)
-# plt.show()
-
-
-# bring timelapse data and convert it to pH
-
-pp = PostProcessor(source=19831)  # 19831
-pp.load_tiler_cells()
-f = "/home/alan/Documents/tmp/gluStarv_2_0_x2_dual_phl_ura8_00/extraction/"
-# f = "/home/alan/Documents/tmp/downUpshift_2_0_2_glu_dual_phluorin__glt1_psa1_ura7__thrice_00/extraction/"
-pp.load_extraction(f)
-
-import compress_pickle
-
-compress_pickle.dump(pp.extraction, "/home/alan/extraction_example.pkl")
-
-if True:  # Load extracted data or pkld version
-    new_dfs = compress_pickle.load("/home/alan/Documents/tmp/new_dfs.gz")
-# Combine dataframes
-else:
-    new_dfs = combine_dfs(get_dfs(pp))
-    # t = [x.index for x in new_dfs.values()]
-    # i = set(t[0])
-    # for j in t:
-    #     i = i.intersection(j)
-    new_dfs = {
-        k: v
-        for k, v in new_dfs.items()
-        if k[2] != "imBackground"
-        and k[2] != "median"
-        and ~(((k[0] == "GFPFast") | (k[0] == "pHluorin405")) and k[2] == "max2p5pc")
-    }
-
-del pp
-compress_pickle.dump(new_dfs, "/home/alan/Documents/tmp/new_dfs.gz")
-
-
-def get_clean_dfs(dfs=None):
-    if dfs is None:
-        clean_dfs = compress_pickle.load("/home/alan/Documents/tmp/clean_dfs.gz")
-    else:
-
-        # Clean timelapse
-        clean = clean_tracks(new_dfs[("general", None, "area")])
-        tra, joint = merge_tracks(clean)
-        clean_dfs = new_dfs
-        i_ids = set(clean.index).intersection(
-            clean_dfs[("general", None, "area")].index
-        )
-        clean_dfs = {k: v.loc[i_ids] for k, v in clean_dfs.items()}
-        clean_dfs = {k: join_tracks(v, joint, drop=True) for k, v in clean_dfs.items()}
-
-        del new_dfs
-        compress_pickle.dump(clean_dfs, "/home/alan/Documents/tmp/clean_dfs.gz")
-
-
-def plot_ph_hmap(clean_dfs):
-    GFPFast = clean_dfs[("GFPFast", np.maximum, "mean")]
-    phluorin = clean_dfs[("pHluorin405", np.maximum, "mean")]
-    ph = GFPFast / phluorin
-    ph = ph.loc[ph.notna().sum(axis=1) > 0.7 * ph.shape[1]]
-    ph = 1 / ph
-
-    fig, ax = plt.subplots()
-    hmap(ph, cbar_kws={"label": r"emission ratio $\propto$ pH"})
-    plt.xlabel("Time (hours)")
-    plt.ylabel("Cells")
-    xticks = plt.xticks(fontsize=15)[0]
-    ax.set(yticklabels=[], xticklabels=[str(round(i * 5 / 60, 1)) for i in xticks])
-    # plt.setp(ax.get_xticklabels(), Rotation=90)
-    plt.show()
-
-
-def fit_calibs(c, h):
-    h = process_phs(pp_c)
-    h["ratio"] = h["GFPFast_bgsub_median"] / h["pHluorin405_bgsub_median"]
-    sns.lineplot(x="ph", y="ratio", data=h, err_style="bars")
-    plt.show()
-
-    calibs = fit_calibration(c)
-    for k, params in calibs.items():
-        i, j = ("_".join(k.split("_")[:-1]), k.split("_")[-1])
-        if j == "mean" and "k2" not in k:
-            clean_dfs[k] = objective(clean_dfs[i, np.maximum, j], *params)
-
-
-# max 2.5% / med
-def plot_ratio_vs_max2p5(h):
-    fig, ax = plt.subplots()
-    sns.regplot(
-        x="em_ratio_median",
-        y="mCherry_max2p5pc",
-        data=h,
-        scatter=False,
-        ax=ax,
-        color="teal",
-    )
-    sns.scatterplot(x="em_ratio_median", y="max5_d_med", data=h, hue="ph", ax=ax)
-    plt.xlabel(r"Fluorescence ratio $R \propto (1/pH)$")
-    plt.ylabel("Max 2.5% px / median")
-    plt.show()
-
-
-em = clean_dfs[("em_ratio", np.maximum, "mean")]
-area = clean_dfs[("general", None, "area")]
-
-
-def get_grs(clean_dfs):
-    area = clean_dfs[("general", None, "area")]
-    area = area.loc[area.notna().sum(axis=1) > 10]
-    return area.apply(growth_rate, axis=1)
-
-
-def get_agg(dfs, rng):
-    # df dict of DataFrames containing an area/vol one TODO generalise this beyond area
-    # rng tuple of section to use
-    grs = get_grs(dfs)
-    smooth = grs.loc(axis=1)[list(range(rng[0], rng[1]))].dropna(how="all")
-
-    aggregate_mean = lambda dfs, rng: pd.concat(
-        {
-            k[0] + "_" + k[2]: dfs[k].loc[smooth.index, rng[0] : rng[1]].mean(axis=1)
-            for k in clean_dfs.keys()
-        },
-        axis=1,
-    )
-    # f_comp_df = comp_df.loc[(comp_df["gr"] > 0) & (area.notna().sum(axis=1) > 50)]
-
-    agg = aggregate_mean(dfs, rng)
-    agg["max2_med"] = agg["mCherry_max2p5pc"] / agg["mCherry_mean"]
-
-    for c in agg.columns:
-        agg[c + "_log"] = np.log(agg[c])
-
-    agg["gr_mean"] = smooth.loc[set(agg.index).intersection(smooth.index)].mean(axis=1)
-    agg["gr_max"] = smooth.loc[set(agg.index).intersection(smooth.index)].max(axis=1)
-
-    return agg
-
-
-def plot_scatter_fit(x, y, data, hue=None, xlabel=None, ylabel=None, ylim=None):
-    fig, ax = plt.subplots()
-    sns.regplot(x=x, y=y, data=data, scatter=False, ax=ax)
-    sns.scatterplot(x=x, y=y, data=data, ax=ax, alpha=0.1, hue=hue)
-    # plt.show()
-    if xlabel is not None:
-        plt.xlabel(xlabel)
-    if ylabel is not None:
-        plt.ylabel(ylabel)
-    if ylim is not None:
-        plt.ylim(ylim)
-
-    fig.savefig(
-        "/home/alan/Documents/sync_docs/drafts/third_year_pres/figs/"
-        + str(len(data))
-        + "_"
-        + x
-        + "_vs_"
-        + y
-        + ".png",
-        dpi=200,
-    )
-
-
-from extraction.core.argo import Argo, annot_from_dset
-
-
-def additional_feats(aggs):
-    aggs["gr_mean_norm"] = aggs["gr_mean"] * 12
-    aggs["log_ratio_r"] = np.log(1 / aggs["em_ratio_mean"])
-    return aggs
-
-
-def compare_methods_ph_calculation(dfs):
-    GFPFast = dfs[("GFPFast", np.maximum, "mean")]
-    phluorin = dfs[("pHluorin405", np.maximum, "mean")]
-    ph = GFPFast / phluorin
-
-    sns.scatterplot(
-        dfs["em_ratio", np.maximum, "mean"].values.flatten(),
-        ph.values.flatten(),
-        alpha=0.1,
-    )
-    plt.xlabel("ratio_median")
-    plt.ylabel("median_ratio")
-    plt.title("Comparison of ph calculation")
-    plt.show()
-
-
-# get the number of changes in a bool list
-nchanges = lambda x: sum([i for i, j in zip(x[:-2], x[1:]) if operator.xor(i, j)])
diff --git a/grouper.py b/grouper.py
index a3b947b3..982522b7 100644
--- a/grouper.py
+++ b/grouper.py
@@ -19,9 +19,7 @@ from agora.io.signal import Signal
 
 
 class Grouper(ABC):
-    """
-    Base grouper class
-    """
+    """Base grouper class."""
 
     files = []
 
@@ -103,8 +101,6 @@ class Grouper(ABC):
         Examples
         --------
         FIXME: Add docs.
-
-
         """
         if not path.startswith("/"):
             path = "/" + path
@@ -168,15 +164,13 @@ class Grouper(ABC):
 
 
 class MetaGrouper(Grouper):
-    """Group positions using metadata's 'group' number"""
+    """Group positions using metadata's 'group' number."""
 
     pass
 
 
 class NameGrouper(Grouper):
-    """
-    Group a set of positions using a subsection of the name
-    """
+    """Group a set of positions using a subsection of the name."""
 
     def __init__(self, dir, by=None):
         super().__init__(dir=dir)
@@ -216,10 +210,8 @@ class NameGrouper(Grouper):
 
 
 class phGrouper(NameGrouper):
-    """
-    Grouper for pH calibration experiments where all surveyed media pH values
-    are within a single experiment.
-    """
+    """Grouper for pH calibration experiments where all surveyed media pH
+    values are within a single experiment."""
 
     def __init__(self, dir, by=(3, 7)):
         super().__init__(dir=dir, by=by)
@@ -235,9 +227,7 @@ class phGrouper(NameGrouper):
         return float(group_name.replace("_", "."))
 
     def aggregate_multisignals(self, paths: list) -> pd.DataFrame:
-        """
-        Accumulate multiple signals
-        """
+        """Accumulate multiple signals."""
 
         aggregated = pd.concat(
             [
@@ -269,9 +259,8 @@ def concat_signal_ind(
     mode: str = "retained",
     **kwargs,
 ) -> pd.DataFrame:
-    """
-    Core function that handles retrieval of an individual signal, applies filtering if requested and adjusts indices.
-    """
+    """Core function that handles retrieval of an individual signal, applies
+    filtering if requested and adjusts indices."""
     if mode == "retained":
         combined = signal.retained(path, **kwargs)
     if mode == "mothers":
@@ -291,9 +280,8 @@ def concat_signal_ind(
 
 
 class MultiGrouper:
-    """
-    Wrap results from multiple experiments stored as folders inside a folder.
-    """
+    """Wrap results from multiple experiments stored as folders inside a
+    folder."""
 
     def __init__(self, source: Union[str, list]):
         if isinstance(source, str):
@@ -312,9 +300,8 @@ class MultiGrouper:
 
     @property
     def sigtable(self) -> pd.DataFrame:
-        """
-        Generate a matrix containing the number of datasets for each signal and experiment
-        """
+        """Generate a matrix containing the number of datasets for each signal
+        and experiment."""
 
         def regex_cleanup(x):
             x = re.sub(r"\/extraction\/", "", x)
@@ -360,7 +347,6 @@ class MultiGrouper:
         Examples
         --------
         FIXME: Add docs.
-
         """
         ax = sns.heatmap(self.sigtable, cmap="viridis")
         ax.set_xticklabels(
@@ -393,7 +379,6 @@ class MultiGrouper:
         Examples
         --------
         FIXME: Add docs.
-
         """
         if isinstance(signals, str):
             signals = [signals]
diff --git a/routines/heatmap.py b/routines/heatmap.py
index d2a8927c..e3616a47 100644
--- a/routines/heatmap.py
+++ b/routines/heatmap.py
@@ -91,12 +91,12 @@ class _HeatmapPlotter(BasePlotter):
                 births_array == 0, births_array
             )
             # Overlay
-            births_heatmap = ax.imshow(
+            ax.imshow(
                 births_heatmap_mask,
                 interpolation="none",
             )
         # Draw colour bar
-        colorbar = ax.figure.colorbar(
+        ax.figure.colorbar(
             mappable=trace_heatmap, cax=cax, ax=ax, label=self.colorbarlabel
         )
 
diff --git a/routines/mean_plot.py b/routines/mean_plot.py
index e2594483..1d4b13f9 100644
--- a/routines/mean_plot.py
+++ b/routines/mean_plot.py
@@ -70,7 +70,7 @@ def mean_plot(
     error_color="lightblue",
     mean_linestyle="-",
     xlabel="Time (min)",
-    ylabel=f"Normalised flavin fluorescence (AU)",
+    ylabel="Normalised flavin fluorescence (AU)",
     plot_title="",
     ax=None,
 ):
diff --git a/routines/median_plot.py b/routines/median_plot.py
index 573263b3..6e5a9f66 100644
--- a/routines/median_plot.py
+++ b/routines/median_plot.py
@@ -3,6 +3,7 @@
 import numpy as np
 import matplotlib.pyplot as plt
 
+
 from postprocessor.routines.plottingabc import BasePlotter
 
 
@@ -71,11 +72,12 @@ def median_plot(
     error_color="lightblue",
     median_linestyle="-",
     xlabel="Time (min)",
-    ylabel=f"Normalised flavin fluorescence (AU)",
+    ylabel="Normalised flavin fluorescence (AU)",
     plot_title="",
     ax=None,
 ):
-    """Plot median time series of a DataFrame, with interquartile range shading.
+    """Plot median time series of a DataFrame, with interquartile range
+    shading.
 
     Parameters
     ----------
@@ -105,7 +107,6 @@ def median_plot(
     Examples
     --------
     FIXME: Add docs.
-
     """
     plotter = _MedianPlotter(
         trace_df,
-- 
GitLab