From a348bc3c04c56fb3ca6662a49330541a51a91aa1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Mon, 5 Jul 2021 21:28:55 +0100
Subject: [PATCH] bugfixed and add npairs function

Former-commit-id: e14a74604bb44168ab2212ea787a9d4095a9f3c0
---
 core/io/base.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/core/io/base.py b/core/io/base.py
index cb771ced..181f23b0 100644
--- a/core/io/base.py
+++ b/core/io/base.py
@@ -49,9 +49,11 @@ class BridgeH5:
         ]  # get indices of consecutive elements
         return where_consec
 
-    def get_npairs(self, nstepsback=2):
-        tree = b.cell_tree
-        consecutive = b.get_consecutives(tree, nstepsback=nstepsback)
+    def get_npairs(self, nstepsback=2, tree=None):
+        if tree is None:
+            tree = self.cell_tree
+
+        consecutive = self.get_consecutives(tree, nstepsback=nstepsback)
         flat_tree = flatten(tree)
 
         n_predictions = 0
@@ -59,10 +61,23 @@ class BridgeH5:
             flat = list(chain(*[product([k], list(v)) for k, v in d.items()]))
             pairs = [(f, (f[0], f[1] + i)) for f in flat]
             for p in pairs:
-                n_predictions += len(flat_tree[p[0]]) * len(flat_tree[p[1]])
+                n_predictions += len(flat_tree.get(p[0], [])) * len(
+                    flat_tree.get(p[1], [])
+                )
 
         return n_predictions
 
+    def get_npairs_over_time(self, nstepsback=2):
+        tree = self.cell_tree
+        npairs = []
+        for t in self._hdf["cell_info"]["processed_timepoints"][()]:
+            tmp_tree = {
+                k: {k2: v2 for k2, v2 in v.items() if k2 <= t} for k, v in tree.items()
+            }
+            npairs.append(self.get_npairs(tree=tmp_tree))
+
+        return np.diff(npairs)
+
     def get_info_tree(
         self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
     ):
-- 
GitLab