diff --git a/core/io/base.py b/core/io/base.py
index b2314d649681dabcc2c73a917f8f412108ab58c8..cb771cedd8eb2cb0af575e705317c3cf9e90221e 100644
--- a/core/io/base.py
+++ b/core/io/base.py
@@ -1,6 +1,8 @@
 from typing import Union
-from itertools import groupby
+import collections
+from itertools import groupby, chain, product
 
+import numpy as np
 import h5py
 
 
@@ -29,6 +31,38 @@ class BridgeH5:
     def cell_tree(self):
         return self.get_info_tree()
 
+    def get_n_cellpairs(self, nstepsback=2):
+        cell_tree = self.cell_tree
+        # get pair of consecutive trap-time points
+        pass
+
+    @staticmethod
+    def get_consecutives(tree, nstepsback):
+        # Receives a sorted tree and returns the keys of consecutive elements
+        vals = {k: np.array(list(v)) for k, v in tree.items()}  # get tp level
+        where_consec = [
+            {
+                k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0]
+                for k, v in vals.items()
+            }
+            for n in range(nstepsback)
+        ]  # get indices of consecutive elements
+        return where_consec
+
+    def get_npairs(self, nstepsback=2):
+        tree = b.cell_tree
+        consecutive = b.get_consecutives(tree, nstepsback=nstepsback)
+        flat_tree = flatten(tree)
+
+        n_predictions = 0
+        for i, d in enumerate(consecutive, 1):
+            flat = list(chain(*[product([k], list(v)) for k, v in d.items()]))
+            pairs = [(f, (f[0], f[1] + i)) for f in flat]
+            for p in pairs:
+                n_predictions += len(flat_tree[p[0]]) * len(flat_tree[p[1]])
+
+        return n_predictions
+
     def get_info_tree(
         self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
     ):
@@ -59,11 +93,11 @@ class BridgeH5:
 
 
 def groupsort(iterable: Union[tuple, list]):
-    # Groups a list or tuple by the first element and returns
-    # a dictionary that follows {v[0]:sorted(v[1:]) for v in iterable}.
-    # Sorted by the first element in the remaining values
+    # Sorts iterable and returns a dictionary where the values are grouped by the first element.
 
-    return {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])}
+    iterable = sorted(iterable, key=lambda x: x[0])
+    grouped = {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])}
+    return grouped
 
 
 def recursive_groupsort(iterable):
@@ -72,3 +106,15 @@ def recursive_groupsort(iterable):
         return {k: recursive_groupsort(v) for k, v in groupsort(iterable).items()}
     else:  # Only two elements in list
         return [x[0] for x in iterable]
+
+
+def flatten(d, parent_key="", sep="_"):
+    """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615"""
+    items = []
+    for k, v in d.items():
+        new_key = parent_key + (k,) if parent_key else (k,)
+        if isinstance(v, collections.MutableMapping):
+            items.extend(flatten(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)