From 04546d20c4eeebbf63ec0a42d2b5a80b2a953ac1 Mon Sep 17 00:00:00 2001
From: Alan Munoz <afer.mg@gmail.com>
Date: Mon, 7 Jun 2021 13:14:08 +0100
Subject: [PATCH] restructure postprocessor

Former-commit-id: 955b36c88d92d1ee4c0f68947b81e8660fc0dec5
---
 core/{ => functions}/tracks.py  | 207 ++++++++++--------
 core/io/writer.py               |  16 ++
 core/merger.py                  |   5 +
 core/{ => old}/cell.py          |   0
 core/old/ph.py                  | 364 ++++++++++++++++++++++++++++++++
 core/{ => old}/postprocessor.py |  40 +++-
 core/picker.py                  |  98 +++++++++
 core/processor.py               |   0
 examples/testing.py             | 352 +++++++++++++++---------------
 9 files changed, 819 insertions(+), 263 deletions(-)
 rename core/{ => functions}/tracks.py (76%)
 create mode 100644 core/io/writer.py
 create mode 100644 core/merger.py
 rename core/{ => old}/cell.py (100%)
 create mode 100644 core/old/ph.py
 rename core/{ => old}/postprocessor.py (65%)
 create mode 100644 core/picker.py
 create mode 100644 core/processor.py

diff --git a/core/tracks.py b/core/functions/tracks.py
similarity index 76%
rename from core/tracks.py
rename to core/functions/tracks.py
index 71002b72..bff48680 100644
--- a/core/tracks.py
+++ b/core/functions/tracks.py
@@ -1,6 +1,6 @@
-'''
+"""
 Functions to process, filter and merge tracks.
-'''
+"""
 
 # from collections import Counter
 
@@ -10,69 +10,87 @@ from typing import Union, List
 import numpy as np
 import pandas as pd
 
-from scipy.signal   import savgol_filter
+import more_itertools as mit
+from scipy.signal import savgol_filter
+
 # from scipy.optimize import linear_sum_assignment
 # from scipy.optimize import curve_fit
 
 from matplotlib import pyplot as plt
 
+
 def load_test_dset():
     # Load development dataset to test functions
-    return pd.DataFrame({('a',1,1):[2, 5, np.nan, 6,8] + [np.nan] * 5,
-                         ('a',1,2):list(range(2,12)),
-                         ('a',1,3):[np.nan] * 8 + [6,7],
-                         ('a',1,4):[np.nan] * 5 + [9,12,10,14,18]},
-                        index=range(1,11)).T
-
-def get_ntps(track:pd.Series) -> int:
+    return pd.DataFrame(
+        {
+            ("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
+            ("a", 1, 2): list(range(2, 12)),
+            ("a", 1, 3): [np.nan] * 8 + [6, 7],
+            ("a", 1, 4): [np.nan] * 5 + [9, 12, 10, 14, 18],
+        },
+        index=range(1, 11),
+    ).T
+
+
+def max_ntps(track: pd.Series) -> int:
     # Get number of timepoints
     indices = np.where(track.notna())
     return np.max(indices) - np.min(indices)
 
 
-def get_tracks_ntps(tracks:pd.DataFrame) -> pd.Series:
+def max_nonstop_ntps(track: pd.Series) -> int:
+    nona_tracks = track.notna()
+    consecutive_nonas_grouped = [
+        len(list(x)) for x in mit.consecutive_groups(np.flatnonzero(nona_tracks))
+    ]
+    return max(consecutive_nonas_grouped)
+
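+# For example (illustrative): a track with data at positions 0-1 and 3-6 gives
+# max_ntps == 6 (difference between the last and first observed positions),
+# while max_nonstop_ntps == 4 (its longest uninterrupted run of timepoints).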
+
+def get_tracks_ntps(tracks: pd.DataFrame) -> pd.Series:
-    return tracks.apply(get_ntps, axis=1)
+    return tracks.apply(max_ntps, axis=1)
 
-def get_avg_gr(track:pd.Series) -> int:
-    '''
+
+def get_avg_gr(track: pd.Series) -> int:
+    """
     Get average growth rate for a track.
 
     :param tracks: Series with volume and timepoints as indices
-    '''
+    """
-    ntps = get_ntps(track)
+    ntps = max_ntps(track)
     vals = track.dropna().values
-    gr = (vals[-1] - vals[0] )/ ntps
+    gr = (vals[-1] - vals[0]) / ntps
     return gr
 
 
-def get_avg_grs(tracks:pd.DataFrame) -> pd.DataFrame:
-    '''
+def get_avg_grs(tracks: pd.DataFrame) -> pd.DataFrame:
+    """
     Get average growth rate for a group of tracks
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
         columns are timepoints
-    '''
+    """
     return tracks.apply(get_avg_gr, axis=1)
 
 
-def clean_tracks(tracks, min_len:int=6, min_gr:float=0.5) -> pd.DataFrame:
-    '''
+def clean_tracks(tracks, min_len: int = 15, min_gr: float = 1.0) -> pd.DataFrame:
+    """
     Clean small non-growing tracks and return the reduced dataframe
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
         columns are timepoints
     :param min_len: int number of timepoints cells must have not to be removed
     :param min_gr: float Minimum mean growth rate to assume an outline is growing
-    '''
+    """
     ntps = get_tracks_ntps(tracks)
     grs = get_avg_grs(tracks)
 
     growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]
 
-    return (growing_long_tracks)
+    return growing_long_tracks
+
 
 def merge_tracks(tracks, drop=False, **kwargs) -> pd.DataFrame:
-    '''
+    """
     Join tracks that are contiguous and within a volume threshold of each other
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
@@ -84,7 +102,7 @@ def merge_tracks(tracks, drop=False, **kwargs) -> pd.DataFrame:
     :joint_tracks: (m x n) Dataframe where rows are cell tracks and
         columns are timepoints. Merged tracks are still present but filled
         with np.nans.
-    '''
+    """
 
     # calculate tracks that can be merged until no more traps can be merged
     joinable_pairs = get_joinable(tracks, **kwargs)
@@ -92,11 +110,11 @@ def merge_tracks(tracks, drop=False, **kwargs) -> pd.DataFrame:
         tracks = join_tracks(tracks, joinable_pairs, drop=drop)
     joint_ids = get_joint_ids(joinable_pairs)
 
-    return (tracks, joint_ids)
+    return (tracks, joinable_pairs)
 
 
 def get_joint_ids(merging_seqs) -> dict:
-    '''
+    """
     Convert a series of merges into a dictionary where
     the key is the cell_id of destination and the value a list
     of the other track ids that were merged into the key
@@ -119,7 +137,7 @@ def get_joint_ids(merging_seqs) -> dict:
 
     output {a:a, b:a, c:a, d:a}
 
-    '''
+    """
     targets, origins = list(zip(*merging_seqs))
     static_tracks = set(targets).difference(origins)
 
@@ -127,13 +145,17 @@ def get_joint_ids(merging_seqs) -> dict:
     for target, origin in merging_seqs:
         joint[origin] = target
 
-    moved_target = [k for k,v in joint.items() \
-                      if joint[v]!=v and v in joint.values()]
+    moved_target = [
+        k for k, v in joint.items() if joint[v] != v and v in joint.values()
+    ]
 
     for orig in moved_target:
         joint[orig] = rec_bottom(joint, orig)
 
-    return {k:v for k,v in joint.items() if k!=v} # remove ids that point to themselves
+    return {
+        k: v for k, v in joint.items() if k != v
+    }  # remove ids that point to themselves
+
 
 def rec_bottom(d, k):
     if d[k] == k:
@@ -141,8 +163,9 @@ def rec_bottom(d, k):
     else:
         return rec_bottom(d, d[k])
 
+
 def join_tracks(tracks, joinable_pairs, drop=False) -> pd.DataFrame:
-    '''
+    """
     Join pairs of tracks from later tps towards the start.
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
@@ -155,8 +178,7 @@ def join_tracks(tracks, joinable_pairs, drop=False) -> pd.DataFrame:
         with np.nans.
     :param drop: bool indicating whether or not to drop moved rows
 
-    '''
-
+    """
 
     tmp = copy(tracks)
     for target, source in joinable_pairs:
@@ -165,7 +187,8 @@ def join_tracks(tracks, joinable_pairs, drop=False) -> pd.DataFrame:
         if drop:
             tmp = tmp.drop(source)
 
-    return (tmp)
+    return tmp
+
 
 def join_track_pairs(track1, track2):
     tmp = copy(track1)
@@ -173,8 +196,9 @@ def join_track_pairs(track1, track2):
 
     return tmp
 
+
 def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
-    '''
+    """
     Get the pair of track (without repeats) that have a smaller error than the
     tolerance. If there is a track that can be assigned to two or more other
     ones, it chooses the one with a lowest error.
@@ -187,39 +211,48 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
     :param window: int value of window used for savgol_filter
     :param degree: int value of polynomial degree passed to savgol_filter
 
-    '''
+    """
 
-    tracks.index.names = ['pos', 'trap', 'cell'] #TODO remove this once it is integrated in the tracker
+    tracks.index.names = [
+        "pos",
+        "trap",
+        "cell",
+    ]  # TODO remove this once it is integrated in the tracker
     # contig=tracks.groupby(['pos','trap']).apply(tracks2contig)
-    clean = clean_tracks(tracks, min_len=window+1, min_gr = 0.9) # get useful tracks
-    contig=clean.groupby(['pos','trap']).apply(get_contiguous_pairs)
+    clean = clean_tracks(tracks, min_len=window + 1, min_gr=0.9)  # get useful tracks
+    contig = clean.groupby(["pos", "trap"]).apply(get_contiguous_pairs)
     contig = contig.loc[contig.apply(len) > 0]
     # candict = {k:v for d in contig.values for k,v in d.items()}
 
     # smooth all relevant tracks
 
-    linear=set([k for v in contig.values for i in v for j in i for k in j])
-    if smooth: # Apply savgol filter TODO fix nans affecting edge placing
-        savgol_on_srs = lambda x: non_uniform_savgol(x.index, x.values,
-                                            window, degree)
-        smoothed_tracks = clean.loc[linear].apply(savgol_on_srs,1)
+    linear = set([k for v in contig.values for i in v for j in i for k in j])
+    if smooth:  # Apply savgol filter TODO fix nans affecting edge placing
+        savgol_on_srs = lambda x: non_uniform_savgol(x.index, x.values, window, degree)
+        smoothed_tracks = clean.loc[linear].apply(savgol_on_srs, 1)
     else:
         smoothed_tracks = clean.loc[linear].apply(lambda x: np.array(x.values), axis=1)
 
     # fetch edges from ids TODO (IF necessary, here we can compare growth rates)
-    idx_to_edge = lambda preposts: [([get_val(smoothed_tracks.loc[pre],-1) for pre in pres],
-                            [get_val(smoothed_tracks.loc[post],0) for post in posts])
-                            for pres, posts in preposts]
+    idx_to_edge = lambda preposts: [
+        (
+            [get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
+            [get_val(smoothed_tracks.loc[post], 0) for post in posts],
+        )
+        for pres, posts in preposts
+    ]
     edges = contig.apply(idx_to_edge)
 
     closest_pairs = edges.apply(get_vec_closest_pairs, tol=tol)
 
-    #match local with global ids
-    joinable_ids = [localid_to_idx(closest_pairs.loc[i], contig.loc[i])\
-                    for i in closest_pairs.index]
+    # match local with global ids
+    joinable_ids = [
+        localid_to_idx(closest_pairs.loc[i], contig.loc[i]) for i in closest_pairs.index
+    ]
 
     return [pair for pairset in joinable_ids for pair in pairset]
 
+
 get_val = lambda x, n: x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
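+
+# Illustrative sketch (uses this module's own helpers): starting from a volume
+# DataFrame such as load_test_dset(), one could do
+#     pairs = get_joinable(load_test_dset(), tol=0.2)
+#     joined = join_tracks(load_test_dset(), pairs, drop=True)
+# to merge contiguous tracks whose end/start values differ by less than 20%.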
 
 
@@ -234,17 +267,18 @@ def localid_to_idx(local_ids, contig_trap):
     list of pairs with (experiment-level) ids to be joint
     """
     lin_pairs = []
-    for i,pairs in enumerate(local_ids):
+    for i, pairs in enumerate(local_ids):
         if len(pairs):
-            for left,right in pairs:
-                lin_pairs.append((contig_trap[i][0][left],
-                                  contig_trap[i][1][right]))
+            for left, right in pairs:
+                lin_pairs.append((contig_trap[i][0][left], contig_trap[i][1][right]))
     return lin_pairs
 
-def get_vec_closest_pairs(lst:List, **kwags):
+
+def get_vec_closest_pairs(lst: List, **kwags):
     return [get_closest_pairs(*l, **kwags) for l in lst]
 
-def get_closest_pairs(pre:List[float], post:List[float], tol:Union[float,int]=1):
+
+def get_closest_pairs(pre: List[float], post: List[float], tol: Union[float, int] = 1):
     """Calculate a cost matrix the Hungarian algorithm to pick the best set of
     options
 
@@ -258,22 +292,25 @@ def get_closest_pairs(pre:List[float], post:List[float], tol:Union[float,int]=1)
 
     """
     if len(pre) > len(post):
-        dMetric = np.abs(np.subtract.outer(post,pre))
+        dMetric = np.abs(np.subtract.outer(post, pre))
     else:
-        dMetric = np.abs(np.subtract.outer(pre,post))
+        dMetric = np.abs(np.subtract.outer(pre, post))
     # dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric) # nans will be filtered
     # ids = linear_sum_assignment(dMetric)
-    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric) # nans will be filtered
+    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric)  # nans will be filtered
 
     ids = solve_matrix(dMetric)
     if not len(ids[0]):
         return []
 
-    norm = np.array(pre)[ids[len(pre)>len(post)]] if tol<1 else 1 # relative or absolute tol
-    result = dMetric[ids]/norm
-    ids = ids if len(pre)<len(post) else ids[::-1]
+    norm = (
+        np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
+    )  # relative or absolute tol
+    result = dMetric[ids] / norm
+    ids = ids if len(pre) < len(post) else ids[::-1]
+
+    return [idx for idx, res in zip(zip(*ids), result) if res < tol]
 
-    return [idx for idx,res in zip(zip(*ids), result) if res < tol]
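+# Worked example (illustrative): with pre=[10.1, 20.3], post=[10.4, 35.0] and
+# tol=0.1, only the pair (0, 0) has a relative error below 10%, i.e. the track
+# ending at volume 10.1 is matched to the track starting at 10.4.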
 
 def solve_matrix(dMetric):
     """
@@ -288,70 +325,72 @@ def solve_matrix(dMetric):
     glob_is = []
     glob_js = []
     if (~np.isnan(dMetric)).any():
-        tmp = copy(dMetric )
+        tmp = copy(dMetric)
         std = sorted(tmp[~np.isnan(tmp)])
         while (~np.isnan(std)).any():
             v = std[0]
-            i_s,j_s = np.where(tmp==v)
+            i_s, j_s = np.where(tmp == v)
             i = i_s[0]
             j = j_s[0]
-            tmp[i,:]+= np.nan
-            tmp[:,j]+= np.nan
+            tmp[i, :] += np.nan
+            tmp[:, j] += np.nan
             glob_is.append(i)
             glob_js.append(j)
 
             std = sorted(tmp[~np.isnan(tmp)])
 
+    return (np.array(glob_is), np.array(glob_js))
 
-    return (np.array( glob_is ), np.array( glob_js ))
 
 def plot_joinable(tracks, joinable_pairs, max=64):
     """
     Convenience plotting function for debugging and data vis
     """
 
-    nx=8
-    ny=8
-    _, axes = plt.subplots(nx,ny)
+    nx = 8
+    ny = 8
+    _, axes = plt.subplots(nx, ny)
     for i in range(nx):
         for j in range(ny):
-            if i*ny+j < len(joinable_pairs):
+            if i * ny + j < len(joinable_pairs):
                 ax = axes[i, j]
                 pre, post = joinable_pairs[i * ny + j]
                 pre_srs = tracks.loc[pre].dropna()
                 post_srs = tracks.loc[post].dropna()
-                ax.plot(pre_srs.index, pre_srs.values , 'b')
+                ax.plot(pre_srs.index, pre_srs.values, "b")
                 # try:
                 #     totrange = np.arange(pre_srs.index[0],post_srs.index[-1])
                 #     ax.plot(totrange, interpolate(pre_srs, totrange), 'r-')
                 # except:
                 #     pass
-                ax.plot(post_srs.index, post_srs.values, 'g')
+                ax.plot(post_srs.index, post_srs.values, "g")
 
     plt.show()
 
+
 def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
-    '''
+    """
     Get all pair of contiguous track ids from a tracks dataframe.
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
         columns are timepoints
     :param min_dgr: float minimum difference in growth rate from the interpolation
-    '''
+    """
     # indices = np.where(tracks.notna())
 
-
-    mins, maxes = [tracks.notna().apply(np.where, axis=1).apply(fn)
-                   for fn in (np.min, np.max)]
+    mins, maxes = [
+        tracks.notna().apply(np.where, axis=1).apply(fn) for fn in (np.min, np.max)
+    ]
 
     mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
-    mins_d.index = mins_d.index - 1 # make indices equal
+    mins_d.index = mins_d.index - 1  # make indices equal
     maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist())
 
     common = sorted(set(mins_d.index).intersection(maxes_d.index), reverse=True)
 
     return [(maxes_d[t], mins_d[t]) for t in common]
 
+
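+# For example (sketch): if tracks 3 and 7 end at timepoint 10 and track 12
+# starts at timepoint 11, the output contains the pair ([3, 7], [12]).
+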
 # def fit_track(track: pd.Series, obj=None):
 #     if obj is None:
 #         obj = objective
@@ -416,7 +455,7 @@ def non_uniform_savgol(x, y, window, polynom):
         raise ValueError('"x" and "y" must be of the same size')
 
     if len(x) < window:
-        raise ValueError('The data size must be larger than the window size')
+        raise ValueError("The data size must be larger than the window size")
 
     if type(window) is not int:
         raise TypeError('"window" must be an integer')
@@ -434,9 +473,9 @@ def non_uniform_savgol(x, y, window, polynom):
     polynom += 1
 
     # Initialize variables
-    A = np.empty((window, polynom))     # Matrix
-    tA = np.empty((polynom, window))    # Transposed matrix
-    t = np.empty(window)                # Local x variables
+    A = np.empty((window, polynom))  # Matrix
+    tA = np.empty((polynom, window))  # Transposed matrix
+    t = np.empty(window)  # Local x variables
     y_smoothed = np.full(len(y), np.nan)
 
     # Start smoothing
diff --git a/core/io/writer.py b/core/io/writer.py
new file mode 100644
index 00000000..de1660e0
--- /dev/null
+++ b/core/io/writer.py
@@ -0,0 +1,16 @@
+import h5py
+import numpy as np
+import pandas as pd
+
+
+class Writer:
+    # Skeleton writer for extraction results; an HDF5 backend via h5py is assumed
+    def __init__(self, filename):
+        self._file = h5py.File(filename, "a")
+
+    def write(self, address, data):
+        self._file.require_group(address)
+        if isinstance(data, pd.DataFrame):
+            self.write_df(address, data)
+        elif isinstance(data, np.ndarray):
+            self.write_np(address, data)  # TODO write_np is not implemented yet
+
+    def write_df(self, address, df):
+        self._file.get(address)[()] = df  # placeholder until DataFrame serialisation is defined
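+
+
+# Illustrative usage sketch (hypothetical file and address; the class above is
+# still a skeleton):
+# writer = Writer("extraction.h5")
+# writer.write("/extraction/area", pd.DataFrame(np.eye(3)))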
diff --git a/core/merger.py b/core/merger.py
new file mode 100644
index 00000000..6000a2b5
--- /dev/null
+++ b/core/merger.py
@@ -0,0 +1,5 @@
+# Classes in charge of merging tracks
+
+class Merger:  # TODO derive from BasePicker once that base class exists
+    def __init__(self, parameters):
+        pass
diff --git a/core/cell.py b/core/old/cell.py
similarity index 100%
rename from core/cell.py
rename to core/old/cell.py
diff --git a/core/old/ph.py b/core/old/ph.py
new file mode 100644
index 00000000..a28a8055
--- /dev/null
+++ b/core/old/ph.py
@@ -0,0 +1,364 @@
+from typing import Dict, List, Union
+import re
+
+import numpy as np
+import pandas as pd
+from pandas import Series, DataFrame
+from sklearn.cluster import KMeans
+from matplotlib import pyplot as plt
+import seaborn as sns
+
+import compress_pickle
+
+from postprocessor.core.postprocessor import PostProcessor
+from postprocessor.core.tracks import get_avg_grs, non_uniform_savgol
+from postprocessor.core.ph import *
+
+
+def filter_by_gfp(dfs):
+    gfps = pd.concat([t[("GFPFast_bgsub", np.maximum, "median")] for t in dfs])
+    avgs_gfp = gfps.mean(axis=1)
+    high_gfp = get_high_k2(avgs_gfp)
+    # high_gfp = avgs_gfp[avgs_gfp > 200]
+
+    return high_gfp
+
+
+def filter_by_area(dfs, min=50):
+    areas = pd.concat([t[("general", None, "area")] for t in dfs])
+    avgs_areas = areas[(areas.notna().sum(axis=1) > areas.shape[1] // (1.25))].mean(
+        axis=1
+    )
+    avgs_areas = avgs_areas[(avgs_areas > min)]
+
+    return avgs_areas
+
+
+def get_high_k2(df):
+    kmeans = KMeans(n_clusters=2)
+    vals = df.values.reshape(-1, 1)
+    kmeans.fit(vals)
+    high_clust_id = kmeans.cluster_centers_.argmax()
+
+    return df.loc[kmeans.predict(vals) == high_clust_id]
+
+
+def get_concats(dfs, keys):
+    return pd.concat([t.get((keys), pd.DataFrame()) for t in dfs])
+
+
+def get_dfs(pp):
+    dfs = [pp.extraction[pp.expt.positions[i]] for i in range(len(pp.expt.positions))]
+    return dfs
+
+
+combine_dfs = lambda dfs: {k: get_concats(dfs, k) for k in dfs[0].keys()}
+
+
+def merge_channels(pp, min_area=50):
+    dfs = get_dfs(pp)
+    # rats = get_concats(dfs, ("em_ratio_bgsub", np.maximum, "median"))
+    # gfps = filter_by_gfp(dfs)
+
+    avgs_area = filter_by_area(dfs, min=50)
+    # ids = [x for x in set(gfps.index).intersection(avgs_area.index)]
+    ids = avgs_area.index
+
+    new_dfs = combine_dfs(dfs)
+
+    h = pd.DataFrame(
+        {
+            k[0] + "_" + k[2]: v.loc[ids].mean(axis=1)
+            for k, v in new_dfs.items()
+            if k[-1] != "imBackground"
+        }
+    )
+    return h
+
+
+def process_phs(pp, min_area=200):
+    h = merge_channels(pp, min_area)
+    h.index.names = ["pos", "trap", "cell"]
+    ids = h.index
+    h = h.reset_index()
+
+    h["ph"] = h["pos"].apply(lambda x: float(x[3:7].replace("_", ".")))
+    h["max5_d_med"] = h["mCherry_max2p5pc"] / h["mCherry_median"]
+
+    h = h.set_index(ids)
+    h = h.drop(["pos", "trap", "cell"], axis=1)
+    return h
+
+
+def growth_rate(
+    data: Series, alg=None, filt={"kind": "savgol", "window": 5, "degree": 3}
+):
+    window = filt["window"]
+    degree = filt["degree"]
+    if alg is None:
+        alg = "standard"
+
+    if filt:  # TODO add support for multiple algorithms
+        data = Series(
+            non_uniform_savgol(
+                data.dropna().index, data.dropna().values, window, degree
+            ),
+            index=data.dropna().index,
+        )
+
+    diff_kernel = np.array([1, -1])
+    return Series(np.convolve(data, diff_kernel, "same"), index=data.dropna().index)
+
+
+# import numpy as np
+
+# diff_kernel = np.array([1, -1])
+# gr = clean.apply(growth_rate, axis=1)
+# from postprocessor.core.tracks import non_uniform_savgol, clean_tracks
+
+
+def sort_df(df, by="first", rev=True):
+    nona = df.notna()
+    if by == "len":
+        idx = nona.sum(axis=1)
+    elif by == "first":
+        idx = nona.idxmax(axis=1)
+    idx = idx.sort_values().index
+
+    if rev:
+        idx = idx[::-1]
+
+    return df.loc[idx]
+
+
+# test = tmp[("GFPFast", np.maximum, "median")]
+# test2 = tmp[("pHluorin405", np.maximum, "median")]
+# ph = test / test2
+# ph = ph.stack().reset_index(1)
+# ph.columns = ["tp", "fl"]
+
+
+def m2p5_med(ext, ch, red=np.maximum):
+    m2p5pc = ext[(ch, red, "max2p5pc")]
+    med = ext[(ch, red, "median")]
+
+    result = m2p5pc / med
+
+    return result
+
+
+def plot_avg(df):
+    df = df.stack().reset_index(1)
+    df.columns = ["tp", "val"]
+
+    sns.relplot(x=df["tp"], y=df["val"], kind="line")
+    plt.show()
+
+
+def split_data(df: DataFrame, splits: List[int]):
+    dfs = [df.iloc[:, i:j] for i, j in zip((0,) + splits, splits + (df.shape[1],))]
+    return dfs
+
+
+def growth_rate(
+    data: Series, alg=None, filt={"kind": "savgol", "window": 7, "degree": 3}
+):
+    if alg is None:
+        alg = "standard"
+
+    if filt:  # TODO add support for multiple algorithms
+        window = filt["window"]
+        degree = filt["degree"]
+        data = Series(
+            non_uniform_savgol(
+                data.dropna().index, data.dropna().values, window, degree
+            ),
+            index=data.dropna().index,
+        )
+
+    diff_kernel = np.array([1, -1])
+
+    return Series(np.convolve(data, diff_kernel, "same"), index=data.dropna().index)
+
+
+# pp = PostProcessor(source=19831)
+# pp.load_tiler_cells()
+# f = "/home/alan/Documents/sync_docs/libs/postproc/gluStarv_2_0_x2_dual_phl_ura8_00/extraction"
+# pp.load_extraction(
+#     "/home/alan/Documents/sync_docs/libs/postproc/postprocessor/"
+#     + pp.expt.name
+#     + "/extraction/"
+# )
+# tmp = pp.extraction["phl_ura8_002"]
+
+
+def _check_bg(data):
+    for k in list(pp.extraction.values())[0].keys():
+        for p in pp.expt.positions:
+            if k not in pp.extraction[p]:
+                print(p, k)
+
+
+# data = {
+#     k: pd.concat([pp.extraction[pos][k] for pos in pp.expt.positions[:-3]])
+#     for k in list(pp.extraction.values())[0].keys()
+# }
+
+
+def hmap(df, **kwargs):
+    g = sns.heatmap(sort_df(df), robust=True, cmap="mako_r", **kwargs)
+    plt.xlabel("")
+    return g
+
+
+# from random import randint
+# x = randint(0, len(smooth))
+# plt.plot(clean.iloc[x], 'b')
+# plt.plot(smooth.iloc[x], 'r')
+# plt.show()
+
+
+# data = tmp
+# df = data[("general", None, "area")]
+# clean = clean_tracks(df, min_len=160)
+# clean = clean.loc[clean.notna().sum(axis=1) > 9]
+# gr = clean.apply(growth_rate, axis=1)
+# splits = (72, 108, 180)
+# gr_sp = split_data(gr, splits)
+
+# idx = gr.index
+
+# bg = get_bg(data)
+# test = data[("GFPFast", np.maximum, "median")]
+# test2 = data[("pHluorin405", np.maximum, "median")]
+# ph = (test / test2).loc[idx]
+# c = pd.concat((ph.mean(1), gr.max(1)), axis=1)
+# c.columns = ["ph", "gr_max"]
+# # ph = ph.stack().reset_index(1)
+# # ph.columns = ['tp', 'fl']
+
+# ph_sp = split_data(gr, splits)
+
+
+def get_bg(data):
+    bg = {}
+    fl_subkeys = [
+        x
+        for x in data.keys()
+        if x[0] in ["GFP", "GFPFast", "mCherry", "pHluorin405"]
+        and x[-1] != "imBackground"
+    ]
+    for k in fl_subkeys:
+        nk = list(k)
+        bk = tuple(nk[:-1] + ["imBackground"])
+        nk = tuple(nk[:-1] + [nk[-1] + "_BgSub"])
+        tmp = []
+        for i, v in data[bk].iterrows():
+            if i in data[k].index:
+                newdf = data[k].loc[i] / v
+                newdf.index = pd.MultiIndex.from_tuples([(*i, c) for c in newdf.index])
+                tmp.append(newdf)
+        bg[nk] = pd.concat(tmp)
+
+    return bg
+
+
+def calc_ph(bg):
+    fl_subkeys = [x for x in bg.keys() if x[0] in ["GFP", "GFPFast", "pHluorin405"]]
+    chs = list(set([x[0] for x in fl_subkeys]))
+    assert len(chs) == 2, "Too many channels"
+    ch1 = [x[1:] for x in fl_subkeys if x[0] == chs[0]]
+    ch2 = [x[1:] for x in fl_subkeys if x[0] == chs[1]]
+    inter = list(set(ch1).intersection(ch2))
+    ph = {}
+    for red_fld in inter:
+        ph[tuple(("ph",) + red_fld)] = (
+            bg[tuple((chs[0],) + red_fld)] / bg[tuple((chs[1],) + red_fld)]
+        )
+
+
+def get_traps(pp):
+    t0 = {}
+    for pos in pp.tiler.positions:
+        pp.tiler.current_position = pos
+        t0[pos] = pp.tiler.get_traps_timepoint(
+            0, channels=[0, pp.tiler.channels.index("mCherry")], z=[0, 1, 2, 3, 4]
+        )
+
+    return t0
+
+
+def get_pos_ph(pp):
+    pat = re.compile(r"ph_([0-9]_[0-9][0-9])")
+    return {
+        pos: float(pat.findall(pos)[0].replace("_", ".")) for pos in pp.tiler.positions
+    }
+
+
+def plot_sample_bf_mch(pp):
+    bf_mch = get_traps(pp)
+    ts = [{i: v[:, j, ...] for i, v in bf_mch.items()} for j in [0, 1]]
+    tsbf = {i: v[:, 0, ...] for i, v in bf_mch.items()}
+
+    posdict = {k: v for k, v in get_pos_ph(pp).items()}
+    posdict = {v: k for k, v in posdict.items()}
+    posdict = {v: k for k, v in posdict.items()}
+    ph = np.unique(list(posdict.values())).tolist()
+    counters = {ph: 0 for ph in ph}
+    n = [np.random.randint(ts[0][k].shape[0]) for k in posdict.keys()]
+
+    fig, axes = plt.subplots(2, 5)
+    for k, (t, name) in enumerate(zip(ts, ["Bright field", "mCherry"])):
+        for i, (pos, ph) in enumerate(posdict.items()):
+            # i = ph.index(posdict[pos])
+            axes[k, i].grid(False)
+            axes[k, i].set(
+                xticklabels=[],
+                yticklabels=[],
+            )
+            axes[k, i].set_xlabel(posdict[pos] if k else None, fontsize=28)
+            axes[k, i].imshow(
+                np.maximum.reduce(t[pos][n[i], 0], axis=2),
+                cmap="gist_gray" if not k else None,
+            )
+            # counters[posdict[pos]] += 1
+        plt.tick_params(
+            axis="x",  # changes apply to the x-axis
+            which="both",  # both major and minor ticks are affected
+            bottom=False,  # ticks along the bottom edge are off
+            top=False,  # ticks along the top edge are off
+            labelbottom=False,
+        )  # labels along the bottom edge are off
+        axes[k, 0].set_ylabel(name, fontsize=28)
+    plt.tight_layout()
+    plt.show()
+
+
+# Plotting calibration curve
+from scipy.optimize import curve_fit
+
+
+def fit_calibration(h):
+    ycols = [x for x in h.columns if "em_ratio" in x]
+    xcol = "ph"
+
+    def objective(x, a, b):
+        return a * x + b
+
+    # fig, axes = plt.subplots(1, len(ycols))
+    # for i, ycol in enumerate(ycols):
+    #     d = h[[xcol, ycol]]
+    #     params, _ = curve_fit(objective, *[d[col].values for col in d.columns])
+    #     sns.lineplot(x=xcol, y=ycol, data=h, alpha=0.5, err_style="bars", ax=axes[i])
+    #     # sns.lineplot(d[xcol], objective(d[xcol].values, *params), ax=axes[i])
+    # plt.show()
+
+    ycol = "em_ratio_mean"
+    d = h[[xcol, *ycols]]
+    tmp = d.groupby("ph").mean()
+    calibs = {ycol: curve_fit(objective, tmp.index, tmp[ycol])[0] for ycol in ycols}
+    # sns.lineplot(x=xcol, y=ycol, data=d, alpha=0.5, err_style="bars")
+    # plt.xlabel("pH")
+    # plt.ylabel("pHluorin emission ratio")
+    # sns.lineplot(d[xcol], objective(d[xcol], *params))
+
+    return calibs
diff --git a/core/postprocessor.py b/core/old/postprocessor.py
similarity index 65%
rename from core/postprocessor.py
rename to core/old/postprocessor.py
index 9a04b66f..7e024d26 100644
--- a/core/postprocessor.py
+++ b/core/old/postprocessor.py
@@ -20,7 +20,7 @@ class PostProcessor:
     :param source: Origin of experiment, if int it is assumed from Omero, if str
         or Path
     '''
-    def __init__(self, parameters=None, source: Union[int,str,Path]=None):
+    def __init__(self, parameters=None, source: Union[int, str, Path] = None):
         # self.params = parameters
         if source is not None:
             if type(source) is int:
@@ -46,14 +46,15 @@ class PostProcessor:
 
         return self.expt.current_position
 
-    def load_expt(self, source:Union[int,str], omero:bool=False) -> None:
+    def load_expt(self, source: Union[int, str], omero: bool = False) -> None:
         if omero:
-            self.expt = Experiment.from_source(self.expt_id, #Experiment ID on OMERO
-                                'upload', #OMERO Username
-                                '***REMOVED***', #OMERO Password
-                                'islay.bio.ed.ac.uk', #OMERO host
-                                port=4064 #This is default
-                                )
+            self.expt = Experiment.from_source(
+                self.expt_id,  #Experiment ID on OMERO
+                'upload',  #OMERO Username
+                '***REMOVED***',  #OMERO Password
+                'islay.bio.ed.ac.uk',  #OMERO host
+                port=4064  #This is default
+            )
         else:
             self.expt = Experiment.from_source(source)
             self.expt_id = self.expt.exptID
@@ -64,6 +65,24 @@ class PostProcessor:
         self.cells = self.cells.from_source(
             self.expt.current_position.annotation)
 
+    def get_pos_mo_bud(self):
+        annot = self.expt._get_position_annotation(
+            self.expt.current_position.name)
+        matob = matObject(annot)
+        m = matob["timelapseTrapsOmero"].get("cellMothers", None)
+        if m is not None:
+            ids = np.nonzero(m.todense())
+            d = {(self.expt.current_position.name, i, int(m[i, j])): []
+                 for i, j in zip(*ids)}
+            for i, j, k in zip(*ids, d.keys()):
+                d[k].append((self.expt.current_position.name, i, j + 1))
+        else:
+            print("Pos {} has no mother matrix".format(
+                self.expt.current_position.name))
+            d = {}
+
+        return d
+
     def get_exp_mo_bud(self):
         d = {}
         for pos in self.expt.positions:
@@ -73,10 +92,11 @@ class PostProcessor:
         self.expt.current_position = self.expt.positions[0]
 
         return d
-    def load_extraction(self, folder=None)-> None:
+
+    def load_extraction(self, folder=None) -> None:
         if folder is None:
             folder = Path(self.expt.name + '/extraction')
-        
+
         self.extraction = {}
         for pos in self.expt.positions:
             try:
diff --git a/core/picker.py b/core/picker.py
new file mode 100644
index 00000000..666eaa47
--- /dev/null
+++ b/core/picker.py
@@ -0,0 +1,98 @@
+# from abc import ABC, abstractmethod
+from typing import List, Tuple, Union
+
+import numpy as np
+import pandas as pd
+
+from core.cells import CellsHDF
+from postprocessor.core.functions.tracks import max_ntps, max_nonstop_ntps
+
+
+# def BasePicker(ABC):
+#     """
+#     Base class to add mother-bud filtering support
+#     """
+#     def __init__(self, branch=None, lineage=None):
+#         self.lineage = lineage
+
+
+class Picker:
+    """
+    :tracks: pd.DataFrame
+    :cells: Cell object passed to the constructor
+    :condition: Tuple with condition and associated parameter(s), conditions can be
+    "present", "nonstoply_present" or "quantile".
+    Determines the thresholds or fractions of tracks/signals to use.
+    :lineage: str {"mothers", "daughters", "families", "orphans"}. Mothers/daughters
+        picks cells with those tags, families picks the union of both, and orphans
+        picks the difference between the total and families.
+    """
+
+    def __init__(
+        self,
+        tracks: pd.DataFrame,
+        cells: CellsHDF,
+        condition: Tuple[str, Union[float, int]] = None,
+        lineage: str = None,
+    ):
+        self._tracks = tracks
+        self._cells = cells
+        self.condition = condition
+        self.lineage = lineage
+
+    @staticmethod
+    def mother_assign_to_mb_matrix(ma: List[np.ndarray]):
+        # Convert from list of lists to mother_bud sparse matrix
+        ncells = sum([len(t) for t in ma])
+        mb_matrix = np.zeros((ncells, ncells), dtype=bool)
+        c = 0
+        for cells in ma:
+            for d, m in enumerate(cells):
+                if m:
+                    mb_matrix[c + d, c + m] = True
+
+            c += len(cells)
+
+        return mb_matrix
+
+    def pick_by_lineage(self):
+        idx = self._tracks.index
+
+        if self.lineage:
+            ma = self._cells["mother_assign"]
+            mother_bud_mat = self.mother_assign_to_mb_matrix(ma)
+            daughters, mothers = np.where(mother_bud_mat)
+            if self.lineage == "mothers":
+                idx = idx[mothers]
+            elif self.lineage == "daughters":
+                idx = idx[daughters]
+            elif self.lineage == "families" or self.lineage == "orphans":
+                families = list(set(np.append(daughters, mothers)))
+                if self.lineage == "families":
+                    idx = idx[families]
+                else:  # orphans
+                    idx = idx[list(set(range(len(idx))).difference(families))]
+
+        return self._tracks.loc[idx]
+
+    def pick_by_condition(self):
+        idx = switch_case(self.condition[0], self._tracks, self.condition[1])
+        return self._tracks.loc[idx]
+
+
+def as_int(threshold: Union[float, int], ntps: int):
+    if type(threshold) is float:
+        threshold = threshold * ntps  # a float threshold is a fraction of the timepoints
+    return threshold
+
+
+def switch_case(
+    condition: str,
+    tracks: pd.DataFrame,
+    threshold: Union[float, int],
+):
+    threshold_asint = as_int(threshold, tracks.shape[1])
+    case_mgr = {
+        "present": tracks.apply(max_ntps, axis=1) > threshold_asint,
+        "nonstoply_present": tracks.apply(max_nonstop_ntps, axis=1) > threshold_asint,
+        "quantile": [np.quantile(tracks.values[tracks.notna()], threshold)],
+    }
+    return case_mgr[condition]
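+
+
+# Illustrative sketch (hypothetical objects, not part of this patch): given an
+# extracted volume DataFrame `tracks` and a CellsHDF instance `cells`,
+# picker = Picker(tracks, cells, condition=("present", 0.8), lineage="mothers")
+# mothers = picker.pick_by_lineage()
+# long_lived = picker.pick_by_condition()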
diff --git a/core/processor.py b/core/processor.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/testing.py b/examples/testing.py
index d6b96f61..5a0b6516 100644
--- a/examples/testing.py
+++ b/examples/testing.py
@@ -1,207 +1,221 @@
 from typing import Dict, List, Union
+import re
 
 import numpy as np
 import pandas as pd
 from pandas import Series, DataFrame
+from sklearn.cluster import KMeans
 from matplotlib import pyplot as plt
 import seaborn as sns
 
+import compress_pickle
+
 from postprocessor.core.postprocessor import PostProcessor
-<<<<<<< HEAD
-from postprocessor.core.tracks import non_uniform_savgol
+from postprocessor.core.tracks import get_avg_grs, non_uniform_savgol
+from postprocessor.core.ph import *
+
+sns.set_context("talk", font_scale=1.8)
+sns.set_theme(style="whitegrid")
+
+# pp_c = PostProcessor(source=19920)  # 19916
+# pp_c.load_tiler_cells()
+# # f = "/home/alan/Documents/libs/extraction/extraction/examples/gluStarv_2_0_x2_dual_phl_ura8_00/extraction"
+# # f = "/home/alan/Documents/tmp/pH_calibration_dual_phl__ura8__by4741__01/extraction/"
+# f = "/home/alan/Documents/tmp/pH_calibration_dual_phl__ura8__by4741_Alan4_00/extraction/"
+# pp_c.load_extraction(f)
+
+# calib = process_phs(pp_c)
+# # c = calib.loc[(5.4 < calib["ph"]) & (calib["ph"] < 8)].groupby("ph").mean()
+# sns.violinplot(x="ph", y="em_ratio_mean", data=calib)
+# plt.show()
+
 
-pp = PostProcessor(source=19916)  # 19916
+# bring timelapse data and convert it to pH
+
+pp = PostProcessor(source=19831)  # 19831
 pp.load_tiler_cells()
-# f = '/home/alan/Documents/libs/extraction/extraction/examples/gluStarv_2_0_x2_dual_phl_ura8_00/extraction'
-f = "/home/alan/Documents/libs/extraction/extraction/examples/pH_calibration_dual_phl__ura8__by4741__01"
-pp.load_extraction(
-    "/home/alan/Documents/libs/extraction/extraction/examples/"
-    + pp.expt.name
-    + "/extraction/"
-)
-
-tmp = pp.extraction[pp.expt.positions[0]]
-
-# prepare data
-test = tmp[("GFPFast", np.maximum, "mean")]
-clean = test.loc[test.notna().sum(axis=1) > 30]
-
-window = 9
-degree = 3
-savgol_on_srs = lambda x: Series(
-    non_uniform_savgol(x.dropna().index, x.dropna().values, window, degree),
-    index=x.dropna().index,
-)
-
-smooth = clean.apply(savgol_on_srs, axis=1)
-
-from random import randint
-
-x = randint(0, len(smooth))
-plt.plot(clean.iloc[x], "b")
-plt.plot(smooth.iloc[x], "r")
-plt.show()
-
-
-def growth_rate(
-    data: Series, alg=None, filt={"kind": "savgol", "window": 9, "degree": 3}
-):
-    if alg is None:
-        alg = "standard"
-
-    if filt:  # TODO add support for multiple algorithms
-        data = Series(
-            non_uniform_savgol(
-                data.dropna().index, data.dropna().values, window, degree
-            ),
-            index=data.dropna().index,
+f = "/home/alan/Documents/tmp/gluStarv_2_0_x2_dual_phl_ura8_00/extraction/"
+# f = "/home/alan/Documents/tmp/downUpshift_2_0_2_glu_dual_phluorin__glt1_psa1_ura7__thrice_00/extraction/"
+pp.load_extraction(f)
+
+import compress_pickle
+
+compress_pickle.dump(pp.extraction, "/home/alan/extraction_example.pkl")
+
+if True:  # Load extracted data or pkld version
+    new_dfs = compress_pickle.load("/home/alan/Documents/tmp/new_dfs.gz")
+# Combine dataframes
+else:
+    new_dfs = combine_dfs(get_dfs(pp))
+    # t = [x.index for x in new_dfs.values()]
+    # i = set(t[0])
+    # for j in t:
+    #     i = i.intersection(j)
+    new_dfs = {
+        k: v
+        for k, v in new_dfs.items()
+        if k[2] != "imBackground"
+        and k[2] != "median"
+        and ~(((k[0] == "GFPFast") | (k[0] == "pHluorin405")) and k[2] == "max2p5pc")
+    }
+
+del pp
+compress_pickle.dump(new_dfs, "/home/alan/Documents/tmp/new_dfs.gz")
+
+
+def get_clean_dfs(dfs=None):
+    if dfs is None:
+        clean_dfs = compress_pickle.load("/home/alan/Documents/tmp/clean_dfs.gz")
+    else:
+
+        from postprocessor.core.tracks import clean_tracks, merge_tracks, join_tracks
+
+        # Clean timelapse
+        clean = clean_tracks(new_dfs[("general", None, "area")])
+        tra, joint = merge_tracks(clean)
+        clean_dfs = new_dfs
+        i_ids = set(clean.index).intersection(
+            clean_dfs[("general", None, "area")].index
         )
+        clean_dfs = {k: v.loc[i_ids] for k, v in clean_dfs.items()}
+        clean_dfs = {k: join_tracks(v, joint, drop=True) for k, v in clean_dfs.items()}
+
+        del new_dfs
+        compress_pickle.dump(clean_dfs, "/home/alan/Documents/tmp/clean_dfs.gz")
+
+
+def plot_ph_hmap(clean_dfs):
+    GFPFast = clean_dfs[("GFPFast", np.maximum, "mean")]
+    phluorin = clean_dfs[("pHluorin405", np.maximum, "mean")]
+    ph = GFPFast / phluorin
+    ph = ph.loc[ph.notna().sum(axis=1) > 0.7 * ph.shape[1]]
+    ph = 1 / ph
+
+    fig, ax = plt.subplots()
+    hmap(ph, cbar_kws={"label": r"emission ratio $\propto$ pH"})
+    plt.xlabel("Time (hours)")
+    plt.ylabel("Cells")
+    xticks = plt.xticks(fontsize=15)[0]
+    ax.set(yticklabels=[], xticklabels=[str(round(i * 5 / 60, 1)) for i in xticks])
+    # plt.setp(ax.get_xticklabels(), Rotation=90)
+    plt.show()
 
-    return Series(np.convolve(data, diff_kernel, "same"), index=data.dropna().index)
 
+def fit_calibs(c, h):
+    h = process_phs(pp_c)
+    h["ratio"] = h["GFPFast_bgsub_median"] / h["pHluorin405_bgsub_median"]
+    sns.lineplot(x="ph", y="ratio", data=h, err_style="bars")
+    plt.show()
 
-import numpy as np
+    calibs = fit_calibration(c)
+    for k, params in calibs.items():
+        i, j = ("_".join(k.split("_")[:-1]), k.split("_")[-1])
+        if j == "mean" and "k2" not in k:
+            clean_dfs[k] = objective(clean_dfs[i, np.maximum, j], *params)
+
+
+# max 2.5% / med
+def plot_ratio_vs_max2p5(h):
+    fig, ax = plt.subplots()
+    sns.regplot(
+        x="em_ratio_median",
+        y="mCherry_max2p5pc",
+        data=h,
+        scatter=False,
+        ax=ax,
+        color="teal",
+    )
+    sns.scatterplot(x="em_ratio_median", y="max5_d_med", data=h, hue="ph", ax=ax)
+    plt.xlabel(r"Fluorescence ratio $R \propto (1/pH)$")
+    plt.ylabel("Max 2.5% px / median")
+    plt.show()
 
-diff_kernel = np.array([1, -1])
-gr = clean.apply(growth_rate, axis=1)
-=======
-from postprocessor.core.tracks import non_uniform_savgol, clean_tracks
->>>>>>> 96f513af38080e6ebb6d301159ca973b5d90ce81
 
+em = clean_dfs[("em_ratio", np.maximum, "mean")]
+area = clean_dfs[("general", None, "area")]
 
-def sort_df(df, by="first", rev=True):
-    nona = df.notna()
-    if by == "len":
-        idx = nona.sum(axis=1)
-    elif by == "first":
-        idx = nona.idxmax(axis=1)
-    idx = idx.sort_values().index
 
-    if rev:
-        idx = idx[::-1]
+def get_grs(clean_dfs):
+    area = clean_dfs[("general", None, "area")]
+    area = area.loc[area.notna().sum(axis=1) > 10]
+    return area.apply(growth_rate, axis=1)
 
-    return df.loc[idx]
-<<<<<<< HEAD
 
+def get_agg(dfs, rng):
+    # df dict of DataFrames containing an area/vol one TODO generalise this beyond area
+    # rng tuple of section to use
+    grs = get_grs(dfs)
+    smooth = grs.loc(axis=1)[list(range(rng[0], rng[1]))].dropna(how="all")
 
-test = tmp[("GFPFast", np.maximum, "median")]
-test2 = tmp[("pHluorin405", np.maximum, "median")]
-ph = test / test2
-ph = ph.stack().reset_index(1)
-ph.columns = ["tp", "fl"]
+    aggregate_mean = lambda dfs, rng: pd.concat(
+        {
+            k[0] + "_" + k[2]: dfs[k].loc[smooth.index, rng[0] : rng[1]].mean(axis=1)
+            for k in dfs.keys()
+        },
+        axis=1,
+    )
+    # f_comp_df = comp_df.loc[(comp_df["gr"] > 0) & (area.notna().sum(axis=1) > 50)]
 
+    agg = aggregate_mean(dfs, rng)
+    agg["max2_med"] = agg["mCherry_max2p5pc"] / agg["mCherry_mean"]
 
-=======
->>>>>>> 96f513af38080e6ebb6d301159ca973b5d90ce81
-def m2p5_med(ext, ch, red=np.maximum):
-    m2p5pc = ext[(ch, red, "max2p5pc")]
-    med = ext[(ch, red, "median")]
+    for c in agg.columns:
+        agg[c + "_log"] = np.log(agg[c])
 
-    result = m2p5pc / med
+    agg["gr_mean"] = smooth.loc[set(agg.index).intersection(smooth.index)].mean(axis=1)
+    agg["gr_max"] = smooth.loc[set(agg.index).intersection(smooth.index)].max(axis=1)
 
-    return result
+    return agg
 
 
-def plot_avg(df):
-    df = df.stack().reset_index(1)
-    df.columns = ["tp", "val"]
+def plot_scatter_fit(x, y, data, hue=None, xlabel=None, ylabel=None, ylim=None):
+    fig, ax = plt.subplots()
+    sns.regplot(x=x, y=y, data=data, scatter=False, ax=ax)
+    sns.scatterplot(x=x, y=y, data=data, ax=ax, alpha=0.1, hue=hue)
+    # plt.show()
+    if xlabel is not None:
+        plt.xlabel(xlabel)
+    if ylabel is not None:
+        plt.ylabel(ylabel)
+    if ylim is not None:
+        plt.ylim(ylim)
 
-    sns.relplot(x=df["tp"], y=df["val"], kind="line")
-    plt.show()
+    fig.savefig(
+        "/home/alan/Documents/sync_docs/drafts/third_year_pres/figs/"
+        + str(len(data))
+        + "_"
+        + x
+        + "_vs_"
+        + y
+        + ".png",
+        dpi=200,
+    )
 
-def split_data(df:DataFrame, splits:List[int]):
-    dfs = [df.iloc[:,i:j] for i,j in zip( (0,) + splits,
-                                                splits + (df.shape[1],))]
-    return dfs
 
-def growth_rate(data:Series, alg=None, filt = {'kind':'savgol','window':7, 'degree':3}):
-    if alg is None:
-        alg='standard'
+from extraction.core.argo import Argo, annot_from_dset
 
-    if filt: #TODO add support for multiple algorithms
-        window = filt['window']
-        degree = filt['degree']
-        data = Series(non_uniform_savgol(data.dropna().index, data.dropna().values,
-                                         window, degree), index = data.dropna().index)
 
-    diff_kernel = np.array([1,-1])
+def additional_feats(aggs):
+    aggs["gr_mean_norm"] = aggs["gr_mean"] * 12
+    aggs["log_ratio_r"] = np.log(1 / aggs["em_ratio_mean"])
+    return aggs
 
 
-    return Series(np.convolve(data,diff_kernel ,'same'), index=data.dropna().index)
+def compare_methods_ph_calculation(dfs):
+    GFPFast = dfs[("GFPFast", np.maximum, "mean")]
+    phluorin = dfs[("pHluorin405", np.maximum, "mean")]
+    ph = GFPFast / phluorin
 
-pp = PostProcessor(source=19831)
-pp.load_tiler_cells()
-f = '/home/alan/Documents/sync_docs/libs/postproc/gluStarv_2_0_x2_dual_phl_ura8_00/extraction'
-pp.load_extraction('/home/alan/Documents/sync_docs/libs/postproc/postprocessor/' + pp.expt.name + '/extraction/')
-tmp=pp.extraction['phl_ura8_002']
-
-def _check_bg(data):
-    for k in list(pp.extraction.values())[0].keys():
-        for p in pp.expt.positions:
-            if k not in pp.extraction[p]:
-                print(p, k)
-data = {k:pd.concat([pp.extraction[pos][k] for pos in \
-                     pp.expt.positions[:-3]]) for k in list(pp.extraction.values())[0].keys()}
-
-hmap = lambda df: sns.heatmap(sort_df(df), robust=True);
-# from random import randint
-# x = randint(0, len(smooth))
-# plt.plot(clean.iloc[x], 'b')
-# plt.plot(smooth.iloc[x], 'r')
-# plt.show()
+    sns.scatterplot(
+        dfs["em_ratio", np.maximum, "mean"].values.flatten(),
+        ph.values.flatten(),
+        alpha=0.1,
+    )
+    plt.xlabel("ratio_median")
+    plt.ylabel("median_ratio")
+    plt.title("Comparison of ph calculation")
+    plt.show()
 
 
-# data = tmp
-df= data[('general',None,'area')]
-clean = clean_tracks(df, min_len=160)
-clean = clean.loc[clean.notna().sum(axis=1) > 9]
-gr = clean.apply(growth_rate, axis=1)
-splits = (72,108,180)
-gr_sp = split_data(gr, splits)
-
-idx = gr.index
-
-bg = get_bg(data)
-test = data[('GFPFast', np.maximum, 'median')]
-test2 = data[('pHluorin405', np.maximum, 'median')]
-ph = (test/test2).loc[idx]
-c=pd.concat((ph.mean(1), gr.max(1)), axis=1); c.columns = ['ph', 'gr_max']
-# ph = ph.stack().reset_index(1)
-# ph.columns = ['tp', 'fl']
-
-ph_sp=split_data(gr, splits)
-
-def get_bg(data):
-    bg = {}
-    fl_subkeys = [x for x in data.keys() if x[0] in \
-                  ['GFP', 'GFPFast', 'mCherry', 'pHluorin405'] and x[-1]!='imBackground']
-    for k in fl_subkeys:
-            nk = list(k)
-            bk = tuple(nk[:-1] + ['imBackground'])
-            nk = tuple(nk[:-1] +  [nk[-1] + '_BgSub'])
-            tmp = []
-            for i,v in data[bk].iterrows():
-                if i in data[k].index:
-                    newdf = data[k].loc[i] / v
-                    newdf.index = pd.MultiIndex.from_tuples([(*i, c) for c in \
-                                                          newdf.index])
-                tmp.append(newdf)
-            bg[nk] = pd.concat(tmp)
-
-    return bg
-
-def calc_ph(bg):
-    fl_subkeys = [x for x in bg.keys() if x[0] in \
-                  ['GFP', 'GFPFast', 'pHluorin405']]
-    chs = list(set([x[0] for x in fl_subkeys]))
-    assert len(chs)==2, 'Too many channels'
-    ch1 = [x[1:] for x in fl_subkeys if x[0]==chs[0]]
-    ch2 = [x[1:] for x in fl_subkeys if x[0]==chs[1]]
-    inter = list(set(ch1).intersection(ch2))
-    ph = {}
-    for red_fld in inter:
-        ph[tuple(('ph',) + red_fld)] = bg[tuple((chs[0],) + red_fld)] / bg[tuple((chs[1],) + red_fld)]
-
-# sns.heatmap(sort_df(data[('mCherry', np.maximum, 'max2p5pc_BgSub')] / data[('mCherry', np.maximum, 'median_BgSub')]), robust=True)
-
-# from postprocessor.core.tracks import clean_tracks
+# get the number of changes in a bool list
+nchanges = lambda x: sum(i != j for i, j in zip(x[:-1], x[1:]))
-- 
GitLab