diff --git a/extraction/core/tracks.py b/extraction/core/tracks.py
index 71002b72bf096332b29507e40dd39f97b951149e..62e7bff4df9b3d00013347870631ac9e813f16d7 100644
--- a/extraction/core/tracks.py
+++ b/extraction/core/tracks.py
@@ -1,6 +1,6 @@
-'''
+"""
 Functions to process, filter and merge tracks.
-'''
+"""
 
 # from collections import Counter
 
@@ -10,69 +10,78 @@ from typing import Union, List
 import numpy as np
 import pandas as pd
 
-from scipy.signal   import savgol_filter
+from scipy.signal import savgol_filter
+
 # from scipy.optimize import linear_sum_assignment
 # from scipy.optimize import curve_fit
 
 from matplotlib import pyplot as plt
 
+
 def load_test_dset():
     # Load development dataset to test functions
-    return pd.DataFrame({('a',1,1):[2, 5, np.nan, 6,8] + [np.nan] * 5,
-                         ('a',1,2):list(range(2,12)),
-                         ('a',1,3):[np.nan] * 8 + [6,7],
-                         ('a',1,4):[np.nan] * 5 + [9,12,10,14,18]},
-                        index=range(1,11)).T
-
-def get_ntps(track:pd.Series) -> int:
+    return pd.DataFrame(
+        {
+            ("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
+            ("a", 1, 2): list(range(2, 12)),
+            ("a", 1, 3): [np.nan] * 8 + [6, 7],
+            ("a", 1, 4): [np.nan] * 5 + [9, 12, 10, 14, 18],
+        },
+        index=range(1, 11),
+    ).T
+
+
+def get_ntps(track: pd.Series) -> int:
     # Get number of timepoints
     indices = np.where(track.notna())
     return np.max(indices) - np.min(indices)
 
 
-def get_tracks_ntps(tracks:pd.DataFrame) -> pd.Series:
+def get_tracks_ntps(tracks: pd.DataFrame) -> pd.Series:
     return tracks.apply(get_ntps, axis=1)
 
-def get_avg_gr(track:pd.Series) -> int:
-    '''
+
+def get_avg_gr(track: pd.Series) -> int:
+    """
     Get average growth rate for a track.
 
     :param tracks: Series with volume and timepoints as indices
-    '''
+    """
     ntps = get_ntps(track)
     vals = track.dropna().values
-    gr = (vals[-1] - vals[0] )/ ntps
+    gr = (vals[-1] - vals[0]) / ntps
     return gr
 
 
-def get_avg_grs(tracks:pd.DataFrame) -> pd.DataFrame:
-    '''
+def get_avg_grs(tracks: pd.DataFrame) -> pd.DataFrame:
+    """
     Get average growth rate for a group of tracks
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
         columns are timepoints
-    '''
+    """
     return tracks.apply(get_avg_gr, axis=1)
 
 
-def clean_tracks(tracks, min_len:int=6, min_gr:float=0.5) -> pd.DataFrame:
-    '''
+def clean_tracks(tracks, min_len: int = 6, min_gr: float = 0.5) -> pd.DataFrame:
+    """
     Clean small non-growing tracks and return the reduced dataframe
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
         columns are timepoints
     :param min_len: int number of timepoints cells must have not to be removed
     :param min_gr: float Minimum mean growth rate to assume an outline is growing
-    '''
+    """
     ntps = get_tracks_ntps(tracks)
     grs = get_avg_grs(tracks)
 
     growing_long_tracks = tracks.loc[(ntps >= min_len) & (grs > min_gr)]
 
-    return (growing_long_tracks)
+    return growing_long_tracks
+
 
 def merge_tracks(tracks, drop=False, **kwargs) -> pd.DataFrame:
-    '''
+    """
     Join tracks that are contiguous and within a volume threshold of each other
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
@@ -84,7 +93,7 @@ def merge_tracks(tracks, drop=False, **kwargs) -> pd.DataFrame:
     :joint_tracks: (m x n) Dataframe where rows are cell tracks and
         columns are timepoints. Merged tracks are still present but filled
         with np.nans.
-    '''
+    """
 
     # calculate tracks that can be merged until no more traps can be merged
     joinable_pairs = get_joinable(tracks, **kwargs)
@@ -96,7 +105,7 @@ def merge_tracks(tracks, drop=False, **kwargs) -> pd.DataFrame:
 
 
 def get_joint_ids(merging_seqs) -> dict:
-    '''
+    """
     Convert a series of merges into a dictionary where
     the key is the cell_id of destination and the value a list
     of the other track ids that were merged into the key
@@ -119,7 +128,7 @@ def get_joint_ids(merging_seqs) -> dict:
 
     output {a:a, b:a, c:a, d:a}
 
-    '''
+    """
     targets, origins = list(zip(*merging_seqs))
     static_tracks = set(targets).difference(origins)
 
@@ -127,13 +136,17 @@ def get_joint_ids(merging_seqs) -> dict:
     for target, origin in merging_seqs:
         joint[origin] = target
 
-    moved_target = [k for k,v in joint.items() \
-                      if joint[v]!=v and v in joint.values()]
+    moved_target = [
+        k for k, v in joint.items() if joint[v] != v and v in joint.values()
+    ]
 
     for orig in moved_target:
         joint[orig] = rec_bottom(joint, orig)
 
-    return {k:v for k,v in joint.items() if k!=v} # remove ids that point to themselves
+    return {
+        k: v for k, v in joint.items() if k != v
+    }  # remove ids that point to themselves
+
 
 def rec_bottom(d, k):
     if d[k] == k:
@@ -141,8 +154,9 @@ def rec_bottom(d, k):
     else:
         return rec_bottom(d, d[k])
 
+
 def join_tracks(tracks, joinable_pairs, drop=False) -> pd.DataFrame:
-    '''
+    """
     Join pairs of tracks from later tps towards the start.
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
@@ -155,8 +169,7 @@ def join_tracks(tracks, joinable_pairs, drop=False) -> pd.DataFrame:
         with np.nans.
     :param drop: bool indicating whether or not to drop moved rows
 
-    '''
-
+    """
 
     tmp = copy(tracks)
     for target, source in joinable_pairs:
@@ -165,7 +178,8 @@ def join_tracks(tracks, joinable_pairs, drop=False) -> pd.DataFrame:
         if drop:
             tmp = tmp.drop(source)
 
-    return (tmp)
+    return tmp
+
 
 def join_track_pairs(track1, track2):
     tmp = copy(track1)
@@ -173,8 +187,9 @@ def join_track_pairs(track1, track2):
 
     return tmp
 
+
 def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
-    '''
+    """
     Get the pair of track (without repeats) that have a smaller error than the
     tolerance. If there is a track that can be assigned to two or more other
     ones, it chooses the one with a lowest error.
@@ -187,39 +202,48 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
     :param window: int value of window used for savgol_filter
     :param degree: int value of polynomial degree passed to savgol_filter
 
-    '''
+    """
 
-    tracks.index.names = ['pos', 'trap', 'cell'] #TODO remove this once it is integrated in the tracker
+    tracks.index.names = [
+        "pos",
+        "trap",
+        "cell",
+    ]  # TODO remove this once it is integrated in the tracker
     # contig=tracks.groupby(['pos','trap']).apply(tracks2contig)
-    clean = clean_tracks(tracks, min_len=window+1, min_gr = 0.9) # get useful tracks
-    contig=clean.groupby(['pos','trap']).apply(get_contiguous_pairs)
+    clean = clean_tracks(tracks, min_len=window + 1, min_gr=0.9)  # get useful tracks
+    contig = clean.groupby(["pos", "trap"]).apply(get_contiguous_pairs)
     contig = contig.loc[contig.apply(len) > 0]
     # candict = {k:v for d in contig.values for k,v in d.items()}
 
     # smooth all relevant tracks
 
-    linear=set([k for v in contig.values for i in v for j in i for k in j])
-    if smooth: # Apply savgol filter TODO fix nans affecting edge placing
-        savgol_on_srs = lambda x: non_uniform_savgol(x.index, x.values,
-                                            window, degree)
-        smoothed_tracks = clean.loc[linear].apply(savgol_on_srs,1)
+    linear = set([k for v in contig.values for i in v for j in i for k in j])
+    if smooth:  # Apply savgol filter TODO fix nans affecting edge placing
+        savgol_on_srs = lambda x: non_uniform_savgol(x.index, x.values, window, degree)
+        smoothed_tracks = clean.loc[linear].apply(savgol_on_srs, 1)
     else:
         smoothed_tracks = clean.loc[linear].apply(lambda x: np.array(x.values), axis=1)
 
     # fetch edges from ids TODO (IF necessary, here we can compare growth rates)
-    idx_to_edge = lambda preposts: [([get_val(smoothed_tracks.loc[pre],-1) for pre in pres],
-                            [get_val(smoothed_tracks.loc[post],0) for post in posts])
-                            for pres, posts in preposts]
+    idx_to_edge = lambda preposts: [
+        (
+            [get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
+            [get_val(smoothed_tracks.loc[post], 0) for post in posts],
+        )
+        for pres, posts in preposts
+    ]
     edges = contig.apply(idx_to_edge)
 
     closest_pairs = edges.apply(get_vec_closest_pairs, tol=tol)
 
-    #match local with global ids
-    joinable_ids = [localid_to_idx(closest_pairs.loc[i], contig.loc[i])\
-                    for i in closest_pairs.index]
+    # match local with global ids
+    joinable_ids = [
+        localid_to_idx(closest_pairs.loc[i], contig.loc[i]) for i in closest_pairs.index
+    ]
 
     return [pair for pairset in joinable_ids for pair in pairset]
 
+
 get_val = lambda x, n: x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
 
 
@@ -234,17 +258,18 @@ def localid_to_idx(local_ids, contig_trap):
     list of pairs with (experiment-level) ids to be joint
     """
     lin_pairs = []
-    for i,pairs in enumerate(local_ids):
+    for i, pairs in enumerate(local_ids):
         if len(pairs):
-            for left,right in pairs:
-                lin_pairs.append((contig_trap[i][0][left],
-                                  contig_trap[i][1][right]))
+            for left, right in pairs:
+                lin_pairs.append((contig_trap[i][0][left], contig_trap[i][1][right]))
     return lin_pairs
 
-def get_vec_closest_pairs(lst:List, **kwags):
+
+def get_vec_closest_pairs(lst: List, **kwags):
     return [get_closest_pairs(*l, **kwags) for l in lst]
 
-def get_closest_pairs(pre:List[float], post:List[float], tol:Union[float,int]=1):
+
+def get_closest_pairs(pre: List[float], post: List[float], tol: Union[float, int] = 1):
     """Calculate a cost matrix the Hungarian algorithm to pick the best set of
     options
 
@@ -258,22 +283,25 @@ def get_closest_pairs(pre:List[float], post:List[float], tol:Union[float,int]=1)
 
     """
     if len(pre) > len(post):
-        dMetric = np.abs(np.subtract.outer(post,pre))
+        dMetric = np.abs(np.subtract.outer(post, pre))
     else:
-        dMetric = np.abs(np.subtract.outer(pre,post))
+        dMetric = np.abs(np.subtract.outer(pre, post))
     # dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric) # nans will be filtered
     # ids = linear_sum_assignment(dMetric)
-    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric) # nans will be filtered
+    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric)  # nans will be filtered
 
     ids = solve_matrix(dMetric)
     if not len(ids[0]):
         return []
 
-    norm = np.array(pre)[ids[len(pre)>len(post)]] if tol<1 else 1 # relative or absolute tol
-    result = dMetric[ids]/norm
-    ids = ids if len(pre)<len(post) else ids[::-1]
+    norm = (
+        np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
+    )  # relative or absolute tol
+    result = dMetric[ids] / norm
+    ids = ids if len(pre) < len(post) else ids[::-1]
+
+    return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
 
-    return [idx for idx,res in zip(zip(*ids), result) if res < tol]
 
 def solve_matrix(dMetric):
     """
@@ -288,70 +316,72 @@ def solve_matrix(dMetric):
     glob_is = []
     glob_js = []
     if (~np.isnan(dMetric)).any():
-        tmp = copy(dMetric )
+        tmp = copy(dMetric)
         std = sorted(tmp[~np.isnan(tmp)])
         while (~np.isnan(std)).any():
             v = std[0]
-            i_s,j_s = np.where(tmp==v)
+            i_s, j_s = np.where(tmp == v)
             i = i_s[0]
             j = j_s[0]
-            tmp[i,:]+= np.nan
-            tmp[:,j]+= np.nan
+            tmp[i, :] += np.nan
+            tmp[:, j] += np.nan
             glob_is.append(i)
             glob_js.append(j)
 
             std = sorted(tmp[~np.isnan(tmp)])
 
+    return (np.array(glob_is), np.array(glob_js))
 
-    return (np.array( glob_is ), np.array( glob_js ))
 
 def plot_joinable(tracks, joinable_pairs, max=64):
     """
     Convenience plotting function for debugging and data vis
     """
 
-    nx=8
-    ny=8
-    _, axes = plt.subplots(nx,ny)
+    nx = 8
+    ny = 8
+    _, axes = plt.subplots(nx, ny)
     for i in range(nx):
         for j in range(ny):
-            if i*ny+j < len(joinable_pairs):
+            if i * ny + j < len(joinable_pairs):
                 ax = axes[i, j]
                 pre, post = joinable_pairs[i * ny + j]
                 pre_srs = tracks.loc[pre].dropna()
                 post_srs = tracks.loc[post].dropna()
-                ax.plot(pre_srs.index, pre_srs.values , 'b')
+                ax.plot(pre_srs.index, pre_srs.values, "b")
                 # try:
                 #     totrange = np.arange(pre_srs.index[0],post_srs.index[-1])
                 #     ax.plot(totrange, interpolate(pre_srs, totrange), 'r-')
                 # except:
                 #     pass
-                ax.plot(post_srs.index, post_srs.values, 'g')
+                ax.plot(post_srs.index, post_srs.values, "g")
 
     plt.show()
 
+
 def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
-    '''
+    """
     Get all pair of contiguous track ids from a tracks dataframe.
 
     :param tracks: (m x n) dataframe where rows are cell tracks and
         columns are timepoints
     :param min_dgr: float minimum difference in growth rate from the interpolation
-    '''
+    """
     # indices = np.where(tracks.notna())
 
-
-    mins, maxes = [tracks.notna().apply(np.where, axis=1).apply(fn)
-                   for fn in (np.min, np.max)]
+    mins, maxes = [
+        tracks.notna().apply(np.where, axis=1).apply(fn) for fn in (np.min, np.max)
+    ]
 
     mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
-    mins_d.index = mins_d.index - 1 # make indices equal
+    mins_d.index = mins_d.index - 1  # make indices equal
     maxes_d = maxes.groupby(maxes).apply(lambda x: x.index.tolist())
 
     common = sorted(set(mins_d.index).intersection(maxes_d.index), reverse=True)
 
     return [(maxes_d[t], mins_d[t]) for t in common]
 
+
 # def fit_track(track: pd.Series, obj=None):
 #     if obj is None:
 #         obj = objective
@@ -416,7 +446,7 @@ def non_uniform_savgol(x, y, window, polynom):
         raise ValueError('"x" and "y" must be of the same size')
 
     if len(x) < window:
-        raise ValueError('The data size must be larger than the window size')
+        raise ValueError("The data size must be larger than the window size")
 
     if type(window) is not int:
         raise TypeError('"window" must be an integer')
@@ -434,9 +464,9 @@ def non_uniform_savgol(x, y, window, polynom):
     polynom += 1
 
     # Initialize variables
-    A = np.empty((window, polynom))     # Matrix
-    tA = np.empty((polynom, window))    # Transposed matrix
-    t = np.empty(window)                # Local x variables
+    A = np.empty((window, polynom))  # Matrix
+    tA = np.empty((polynom, window))  # Transposed matrix
+    t = np.empty(window)  # Local x variables
     y_smoothed = np.full(len(y), np.nan)
 
     # Start smoothing
diff --git a/tests/extraction/__pycache__/log_test.cpython-37-pytest-6.2.5.pyc b/tests/extraction/__pycache__/log_test.cpython-37-pytest-6.2.5.pyc
deleted file mode 100644
index 6d598779a3e3d03ea6335aa26a3c1b73d713e53e..0000000000000000000000000000000000000000
Binary files a/tests/extraction/__pycache__/log_test.cpython-37-pytest-6.2.5.pyc and /dev/null differ
diff --git a/tests/extraction/test_tracks.py b/tests/extraction/test_tracks.py
index 40976fcca875972cf01d5018b0ce64205ce13c46..c51fd9e2394a376f3b4530848f7124284d1948d4 100644
--- a/tests/extraction/test_tracks.py
+++ b/tests/extraction/test_tracks.py
@@ -1,6 +1,6 @@
-
 from extraction.core.tracks import load_test_dset, clean_tracks, merge_tracks
 
+
 def test_clean_tracks():
     tracks = load_test_dset()
     clean = clean_tracks(tracks, min_len=3)
@@ -8,24 +8,26 @@ def test_clean_tracks():
     assert len(clean) < len(tracks)
     pass
 
+
 def test_merge_tracks_drop():
     tracks = load_test_dset()
 
-    joint_tracks,joint_ids = merge_tracks(tracks,window=3, degree=2, drop=True)
+    joint_tracks = merge_tracks(tracks, window=3, degree=2, drop=True, tol=1)
 
-    assert len(joint_tracks)<len(tracks), 'Error when merging'
-
-    assert len(joint_ids), 'No joint ids found'
+    assert len(joint_tracks) < len(tracks), "Error when merging"
 
     pass
 
+
 def test_merge_tracks_nodrop():
     tracks = load_test_dset()
 
-    joint_tracks,joint_ids = merge_tracks(tracks,window=3, degree=2, drop=False)
+    joint_tracks, joint_ids = merge_tracks(
+        tracks, window=3, degree=2, drop=False, tol=1
+    )
 
-    assert len(joint_tracks)==len(tracks), 'Error when merging'
+    assert len(joint_tracks) == len(tracks), "Error when merging"
 
-    assert len(joint_ids), 'No joint ids found'
+    assert len(joint_ids), "No joint ids found"
 
     pass