diff --git a/src/postprocessor/core/functions/tracks.py b/src/postprocessor/core/functions/tracks.py
index d49a4cfbbfea06f1554aaae290f54bf358478484..8deecb39e73f63d02c517bdc95b93804a1a183bc 100644
--- a/src/postprocessor/core/functions/tracks.py
+++ b/src/postprocessor/core/functions/tracks.py
@@ -1,8 +1,13 @@
 """
-Functions to process, filter and merge tracks.
-"""
+Functions to process, filter, and merge tracks.
+
+We call two tracks contiguous if they are adjacent in time: the
+maximal time point of one is one time point less than the
+minimal time point of the other.
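+For example, a track occupying time points 1-5 is contiguous
+with a track occupying time points 6-9, because 5 is one time
+point less than 6.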
 
-# from collections import Counter
+A right track can have multiple potential left tracks. We must
+pick the best.
+"""
 
 import typing as t
 from copy import copy
@@ -17,6 +22,76 @@ from utils_find_1st import cmp_larger, find_1st
 from postprocessor.core.processes.savgol import non_uniform_savgol
 
 
+def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> list:
+    """
+    Get the pairs of tracks (without repeats) whose error is smaller
+    than the tolerance. If a track can be assigned to two or more
+    others, choose the one with the lowest error.
+
+    Parameters
+    ----------
+    tracks: (m x n) Signal
+        A Signal, usually area, dataframe where rows are cell tracks and
+        columns are time points.
+    smooth: boolean
+        If True, smooth the tracks with a Savitzky-Golay filter before
+        comparing them.
+    tol: float or int
+        Threshold on the average prediction error required to consider
+        two tracks the same. If a float, the threshold is a fraction of
+        the first track; if an int, it is in absolute units.
+    window: int
+        Size of the window used for savgol_filter.
+    degree: int
+        Polynomial degree passed to savgol_filter.
+
+    Returns
+    -------
+    joinable_pairs: list
+        Pairs of (trap, cell_label) indices of tracks to join.
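+
+    Examples
+    --------
+    An illustrative sketch with made-up values: two tracks in one
+    trap that abut at time point 3::
+
+        tracks = pd.DataFrame(
+            [[10.0, 11.0, 12.0, np.nan, np.nan, np.nan],
+             [np.nan, np.nan, np.nan, 13.0, 14.0, 15.0]],
+            index=pd.MultiIndex.from_tuples(
+                [(1, 1), (1, 2)], names=["trap", "cell_label"]
+            ),
+        )
+        joinable = get_joinable(tracks, tol=0.5)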
+    """
+    # only consider time series with more than two non-NaN data points
+    tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
+    # find pairs of contiguous tracks for each trap
+    if smooth:
+        # keep only long-duration tracks of growing cells
+        clean = clean_tracks(tracks, min_duration=window + 1, min_gr=0.9)
+        contigs = clean.groupby(["trap"]).apply(get_contiguous_pairs)
+    else:
+        contigs = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
+    # remove traps with no contiguous tracks
+    contigs = contigs.loc[contigs.apply(len) > 0]
+    # flatten to (trap, cell_id) pairs
+    flat = {k for v in contigs.values for i in v for j in i for k in j}
+    # make a Series of the tracks as arrays, smoothed if requested
+    if smooth:
+        smoothed_tracks = clean.loc[flat].apply(
+            lambda x: non_uniform_savgol(x.index, x.values, window, degree),
+            axis=1,
+        )
+    else:
+        smoothed_tracks = tracks.loc[flat].apply(
+            lambda x: np.array(x.values), axis=1
+        )
+    # get the Signal values for neighbouring end points of contiguous tracks
+    actual_edges = contigs.apply(lambda x: get_edge_values(x, smoothed_tracks))
+    # get the predicted values at the adjacent end points
+    predicted_edges = contigs.apply(
+        lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
+    )
+    # costs of assigning each left track to each right track
+    prediction_costs = predicted_edges.apply(get_dMetric_wrap, tol=tol)
+    # find the best assignment within tolerance for each trap
+    solutions = [
+        solve_matrices_wrap(cost, edges, tol=tol)
+        for (trap_id, cost), edges in zip(
+            prediction_costs.items(), actual_edges
+        )
+    ]
+    closest_pairs = pd.Series(
+        solutions,
+        index=prediction_costs.index,
+    )
+    # match local with global ids
+    joinable_ids = [
+        localid_to_idx(closest_pairs.loc[i], contigs.loc[i])
+        for i in closest_pairs.index
+    ]
+    return [pair for pairset in joinable_ids for pair in pairset]
+
+
 def load_test_dset():
     """Load development dataset to test functions."""
     return pd.DataFrame(
@@ -45,46 +120,21 @@ def max_nonstop_ntps(track: pd.Series) -> int:
     return max(consecutive_nonas_grouped)
 
 
-def get_tracks_ntps(tracks: pd.DataFrame) -> pd.Series:
-    return tracks.apply(max_ntps, axis=1)
-
-
-def get_avg_gr(track: pd.Series) -> int:
-    """
-    Get average growth rate for a track.
-
-    :param tracks: Series with volume and timepoints as indices
-    """
+def get_avg_gr(track: pd.Series) -> float:
+    """Get average growth rate for a track."""
     ntps = max_ntps(track)
     vals = track.dropna().values
     gr = (vals[-1] - vals[0]) / ntps
     return gr
 
 
-def get_avg_grs(tracks: pd.DataFrame) -> pd.DataFrame:
-    """
-    Get average growth rate for a group of tracks
-
-    :param tracks: (m x n) dataframe where rows are cell tracks and
-        columns are timepoints
-    """
-    return tracks.apply(get_avg_gr, axis=1)
-
-
 def clean_tracks(
-    tracks, min_len: int = 15, min_gr: float = 1.0
+    tracks, min_duration: int = 15, min_gr: float = 1.0
 ) -> pd.DataFrame:
-    """
-    Clean small non-growing tracks and return the reduced dataframe
-
-    :param tracks: (m x n) dataframe where rows are cell tracks and
-        columns are timepoints
-    :param min_len: int number of timepoints cells must have not to be removed
-    :param min_gr: float Minimum mean growth rate to assume an outline is growing
-    """
-    ntps = get_tracks_ntps(tracks)
-    grs = get_avg_grs(tracks)
-    growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]
+    """Remove small non-growing tracks and return the reduced data frame."""
+    ntps = tracks.apply(max_ntps, axis=1)
+    grs = tracks.apply(get_avg_gr, axis=1)
+    growing_long_tracks = tracks.loc[(ntps >= min_duration) & (grs > min_gr)]
     return growing_long_tracks
 
 
@@ -191,125 +241,78 @@ def join_track_pair(target, source):
     return tgt_copy
 
 
-def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
-    """
-    Get the pair of track (without repeats) that have a smaller error than the
-    tolerance. If there is a track that can be assigned to two or more other
-    ones, choose the one with lowest error.
+def get_edge_values(contigs_ids, smoothed_tracks):
+    """Get Signal values for adjacent end points for each contiguous track."""
+    values = [
+        (
+            [get_value(smoothed_tracks.loc[pre_id], -1) for pre_id in pre_ids],
+            [
+                get_value(smoothed_tracks.loc[post_id], 0)
+                for post_id in post_ids
+            ],
+        )
+        for pre_ids, post_ids in contigs_ids
+    ]
+    return values
 
-    Parameters
-    ----------
-    tracks: (m x n) Signal
-        A Signal, usually area, dataframe where rows are cell tracks and
-        columns are time points.
-    tol: float or int
-        threshold of average (prediction error/std) necessary
-        to consider two tracks the same. If float is fraction of first track,
-        if int it is absolute units.
-    window: int
-        value of window used for savgol_filter
-    degree: int
-        value of polynomial degree passed to savgol_filter
-    """
-    # only consider time series with more than two non-NaN data points
-    tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
 
-    # smooth all relevant tracks
-    if smooth:
-        # Apply savgol filter TODO fix nans affecting edge placing
-        clean = clean_tracks(
-            tracks, min_len=window + 1, min_gr=0.9
-        )  # get useful tracks
-
-        def savgol_on_srs(x):
-            return non_uniform_savgol(x.index, x.values, window, degree)
-
-        contig = clean.groupby(["trap"]).apply(get_contiguous_pairs)
-        contig = contig.loc[contig.apply(len) > 0]
-        flat = set([k for v in contig.values for i in v for j in i for k in j])
-        smoothed_tracks = clean.loc[flat].apply(savgol_on_srs, 1)
-    else:
-        contig = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
-        contig = contig.loc[contig.apply(len) > 0]
-        flat = set([k for v in contig.values for i in v for j in i for k in j])
-        smoothed_tracks = tracks.loc[flat].apply(
-            lambda x: np.array(x.values), axis=1
-        )
+def get_predicted_edge_values(contigs_ids, smoothed_tracks, window):
+    """
+    Find the neighbouring values of each pair of contiguous tracks.
 
-    # fetch edges from ids TODO (IF necessary, here we can compare growth rates)
-    def idx_to_edge(preposts):
-        return [
-            (
-                [get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
-                [get_val(smoothed_tracks.loc[post], 0) for post in posts],
+    For each left track, fit a line to its last window values and
+    extrapolate to predict a later value; for each right track, take
+    the mean of its first window values.
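+
+    Examples
+    --------
+    A sketch with made-up ids and values, using window=3::
+
+        smoothed = pd.Series(
+            {(1, 1): np.array([1.0, 2.0, 3.0]),
+             (1, 2): np.array([5.0, 5.0, 5.0])}
+        )
+        get_predicted_edge_values([([(1, 1)], [(1, 2)])], smoothed, 3)
+        # -> approximately [[[5.0], [5.0]]]: the line fitted to the
+        # left track extrapolates to 5.0, and the right track's mean
+        # of its first three values is 5.0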
+    """
+    result = []
+    for pre_ids, post_ids in contigs_ids:
+        pre_res = []
+        # left contiguous tracks
+        for pre_id in pre_ids:
+            # get last window values of a track
+            y = get_values_i(smoothed_tracks.loc[pre_id], -window)
+            # predict next value
+            pre_res.append(
+                np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
             )
-            for pres, posts in preposts
+        # right contiguous tracks
+        post_res = [
+            # mean of the initial window values of a track
+            get_mean_value_i(smoothed_tracks.loc[post_id], window)
+            for post_id in post_ids
+        ]
+        result.append([pre_res, post_res])
+    return result
 
-    def idx_to_pred(preposts):
-        result = []
-        for pres, posts in preposts:
-            pre_res = []
-            for pre in pres:
-                y = get_last_i(smoothed_tracks.loc[pre], -window)
-                pre_res.append(
-                    np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
-                )
-            pos_res = [
-                get_means(smoothed_tracks.loc[post], window) for post in posts
-            ]
-            result.append([pre_res, pos_res])
-
-        return result
-
-    edges = contig.apply(idx_to_edge)  # Raw edges
-    # edges_mean = contig.apply(idx_to_means)  # Mean of both
-    pre_pred = contig.apply(idx_to_pred)  # Prediction of pre and mean of post
-
-    # edges_dMetric = edges.apply(get_dMetric_wrap, tol=tol)
-    # edges_dMetric_mean = edges_mean.apply(get_dMetric_wrap, tol=tol)
-    edges_dMetric_pred = pre_pred.apply(get_dMetric_wrap, tol=tol)
-    solutions = []
-    # for (i, dMetrics), edgeset in zip(combined_dMetric.items(), edges):
-    for (i, dMetrics), edgeset in zip(edges_dMetric_pred.items(), edges):
-        solutions.append(solve_matrices_wrap(dMetrics, edgeset, tol=tol))
-
-    closest_pairs = pd.Series(
-        solutions,
-        index=edges_dMetric_pred.index,
-    )
-
-    # match local with global ids
-    joinable_ids = [
-        localid_to_idx(closest_pairs.loc[i], contig.loc[i])
-        for i in closest_pairs.index
-    ]
-
-    return [pair for pairset in joinable_ids for pair in pairset]
 
+def get_value(x, n):
+    """Get value from an array ignoring NaN."""
+    val = x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
+    return val
 
-def get_val(x, n):
-    return x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
 
-
-def get_means(x, i):
+def get_mean_value_i(x, i):
+    """Get track's mean Signal value from values either from or up to an index."""
     if not len(x[~np.isnan(x)]):
         return np.nan
-    if i > 0:
-        v = x[~np.isnan(x)][:i]
     else:
-        v = x[~np.isnan(x)][i:]
-    return np.nanmean(v)
+        if i > 0:
+            v = x[~np.isnan(x)][:i]
+        else:
+            v = x[~np.isnan(x)][i:]
+        return np.nanmean(v)
 
 
-def get_last_i(x, i):
+def get_values_i(x, i):
+    """Get track's Signal values either from or up to an index."""
     if not len(x[~np.isnan(x)]):
         return np.nan
-    if i > 0:
-        v = x[~np.isnan(x)][:i]
     else:
-        v = x[~np.isnan(x)][i:]
-    return v
+        if i > 0:
+            v = x[~np.isnan(x)][:i]
+        else:
+            v = x[~np.isnan(x)][i:]
+        return v
 
 
 def localid_to_idx(local_ids, contig_trap):
@@ -338,57 +341,55 @@ def get_vec_closest_pairs(lst: List, **kwargs):
 
 
 def get_dMetric_wrap(lst: List, **kwargs):
+    """Calculate dMetric on a list."""
     return [get_dMetric(*sublist, **kwargs) for sublist in lst]
 
 
 def solve_matrices_wrap(dMetric: List, edges: List, **kwargs):
+    """Calculate solve_matrices on a list."""
     return [
         solve_matrices(mat, edgeset, **kwargs)
         for mat, edgeset in zip(dMetric, edges)
     ]
 
 
-def get_dMetric(
-    pre: List[float], post: List[float], tol: Union[float, int] = 1
-):
-    """Calculate a cost matrix
-
-    input
-    :param pre: list of floats with edges on left
-    :param post: list of floats with edges on right
-    :param tol: int or float if int metrics of tolerance, if float fraction
-
-    returns
-    :: list of indices corresponding to the best solutions for matrices
+def get_dMetric(pre: List[float], post: List[float], tol):
+    """
+    Calculate a cost matrix from the differences between the Signal
+    values of pairs of tracks.
 
+    Parameters
+    ----------
+    pre: list of floats
+        Values of the Signal for left contiguous tracks.
+    post: list of floats
+        Values of the Signal for right contiguous tracks.
+    tol: float or int
+        Used to replace any NaNs in the cost matrix with a cost above
+        the maximum so that they are later filtered out.
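+
+    Examples
+    --------
+    >>> get_dMetric([10.0, 20.0], [11.0], tol=1)
+    array([[1., 9.]])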
     """
     if len(pre) > len(post):
         dMetric = np.abs(np.subtract.outer(post, pre))
     else:
         dMetric = np.abs(np.subtract.outer(pre, post))
-    dMetric[np.isnan(dMetric)] = (
-        tol + 1 + np.nanmax(dMetric)
-    )  # nans will be filtered
+    # replace NaNs with maximal cost values
+    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric)
     return dMetric
 
 
-def solve_matrices(
-    dMetric: np.ndarray, prepost: List, tol: Union[float, int] = 1
-):
+def solve_matrices(cost: np.ndarray, edges: List, tol: Union[float, int] = 1):
     """
-    Solve the distance matrices obtained in get_dMetric and/or merged from
-    independent dMetric matrices.
+    Find the best-matching pairs of tracks from a cost matrix,
+    keeping only those within tolerance.
     """
-    ids = solve_matrix(dMetric)
-    if not len(ids[0]):
+    ids = solve_matrix(cost)
+    if len(ids[0]):
+        pre, post = edges
+        # tol is a fraction of the first track if a float,
+        # absolute units if an int
+        norm = np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
+        result = cost[ids] / norm
+        ids = ids if len(pre) < len(post) else ids[::-1]
+        return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
+    else:
         return []
-    pre, post = prepost
-    norm = (
-        np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
-    )  # relative or absolute tol
-    result = dMetric[ids] / norm
-    ids = ids if len(pre) < len(post) else ids[::-1]
-    return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
 
 
 def get_closest_pairs(
@@ -412,37 +413,31 @@ def get_closest_pairs(
 
 
 def solve_matrix(dMetric):
-    """
-    Solve cost matrix focusing on getting the smallest cost at each iteration.
-
-    input
-    :param dMetric: np.array cost matrix
-
-    returns
-    tuple of np.arrays indicating picks with lowest individual value
-    """
+    """Arrange indices to the cost matrix in order of increasing cost."""
     glob_is = []
     glob_js = []
     if (~np.isnan(dMetric)).any():
-        tmp = copy(dMetric)
-        std = sorted(tmp[~np.isnan(tmp)])
-        while (~np.isnan(std)).any():
-            v = std[0]
-            i_s, j_s = np.where(tmp == v)
+        lMetric = copy(dMetric)
+        sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
+        while (~np.isnan(sortedMetric)).any():
+            # indices of point with minimal cost
+            i_s, j_s = np.where(lMetric == sortedMetric[0])
             i = i_s[0]
             j = j_s[0]
-            tmp[i, :] += np.nan
-            tmp[:, j] += np.nan
+            # store this point
             glob_is.append(i)
             glob_js.append(j)
-            std = sorted(tmp[~np.isnan(tmp)])
-    return (np.array(glob_is), np.array(glob_js))
+            # mark the point's row and column as used
+            lMetric[i, :] += np.nan
+            lMetric[:, j] += np.nan
+            sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
+    return (np.array(glob_is), np.array(glob_js))
 
 
 def plot_joinable(tracks, joinable_pairs):
-    """
-    Convenience plotting function for debugging and data vis
-    """
+    """Convenience plotting function for debugging."""
     nx = 8
     ny = 8
     _, axes = plt.subplots(nx, ny)
@@ -467,53 +462,31 @@ def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
     """
-    Get all pair of contiguous track ids from a tracks data frame.
+    Get all pairs of contiguous track ids from a tracks data frame.
 
+    For two tracks to be contiguous, they must be exactly adjacent.
+
     Parameters
     ----------
-    tracks:  pd.Dataframe
-        A dataframe for one trap where rows are cell tracks and columns
-        are time points.
+    tracks: pd.DataFrame
+        A dataframe where rows are cell tracks and columns are time
+        points.
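+
+    Examples
+    --------
+    A small example with made-up values:
+
+    >>> tracks = pd.DataFrame(
+    ...     [[1.0, 1.2, np.nan, np.nan], [np.nan, np.nan, 1.3, 1.4]],
+    ...     index=[101, 102],
+    ... )
+    >>> get_contiguous_pairs(tracks)
+    [([101], [102])]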
     """
-    # time points bounding a tracklet of non-NaN values
+    # TODO add support for skipping time points
+    # find time points bounding tracks of non-NaN values
     mins, maxs = [
         tracks.notna().apply(np.where, axis=1).apply(fn)
         for fn in (np.min, np.max)
     ]
+    # mins.name = "min_tpt"
+    # maxs.name = "max_tpt"
+    # df = pd.merge(mins, maxs, right_index=True, left_index=True)
+    # df["duration"] = df.max_tpt - df.min_tpt
+    #
+    # flip so that time points become the index
     mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
-    mins_d.index = mins_d.index - 1  # make indices equal
-    # TODO add support for skipping time points
     maxs_d = maxs.groupby(maxs).apply(lambda x: x.index.tolist())
+    # shift the minimal time points down by one so that a right
+    # track's start aligns with a left track's end
+    mins_d.index = mins_d.index - 1
+    # find common end points
     common = sorted(set(mins_d.index).intersection(maxs_d.index), reverse=True)
-    return [(maxs_d[t], mins_d[t]) for t in common]
-
-
-# def fit_track(track: pd.Series, obj=None):
-#     if obj is None:
-#         obj = objective
-
-#     x = track.dropna().index
-#     y = track.dropna().values
-#     popt, _ = curve_fit(obj, x, y)
-
-#     return popt
-
-# def interpolate(track, xs) -> list:
-#     '''
-#     Interpolate next timepoint from a track
-
-#     :param track: pd.Series of volume growth over a time period
-#     :param t: int timepoint to interpolate
-#     '''
-#     popt = fit_track(track)
-#     # perr = np.sqrt(np.diag(pcov))
-#     return objective(np.array(xs), *popt)
-
-
-# def objective(x,a,b,c,d) -> float:
-#     # return (a)/(1+b*np.exp(c*x))+d
-#     return (((x+d)*a)/((x+d)+b))+c
-
-# def cand_pairs_to_dict(candidates):
-#     d={x:[] for x,_ in candidates}
-#     for x,y in candidates:
-#         d[x].append(y)
-#     return d
+    contigs = [(maxs_d[t], mins_d[t]) for t in common]
+    return contigs