From a348bc3c04c56fb3ca6662a49330541a51a91aa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk> Date: Mon, 5 Jul 2021 21:28:55 +0100 Subject: [PATCH] bugfixed and add npairs function Former-commit-id: e14a74604bb44168ab2212ea787a9d4095a9f3c0 --- core/io/base.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/core/io/base.py b/core/io/base.py index cb771ced..181f23b0 100644 --- a/core/io/base.py +++ b/core/io/base.py @@ -49,9 +49,11 @@ class BridgeH5: ] # get indices of consecutive elements return where_consec - def get_npairs(self, nstepsback=2): - tree = b.cell_tree - consecutive = b.get_consecutives(tree, nstepsback=nstepsback) + def get_npairs(self, nstepsback=2, tree=None): + if tree is None: + tree = self.cell_tree + + consecutive = self.get_consecutives(tree, nstepsback=nstepsback) flat_tree = flatten(tree) n_predictions = 0 @@ -59,10 +61,23 @@ class BridgeH5: flat = list(chain(*[product([k], list(v)) for k, v in d.items()])) pairs = [(f, (f[0], f[1] + i)) for f in flat] for p in pairs: - n_predictions += len(flat_tree[p[0]]) * len(flat_tree[p[1]]) + n_predictions += len(flat_tree.get(p[0], [])) * len( + flat_tree.get(p[1], []) + ) return n_predictions + def get_npairs_over_time(self, nstepsback=2): + tree = self.cell_tree + npairs = [] + for t in self._hdf["cell_info"]["processed_timepoints"][()]: + tmp_tree = { + k: {k2: v2 for k2, v2 in v.items() if k2 <= t} for k, v in tree.items() + } + npairs.append(self.get_npairs(tree=tmp_tree)) + + return np.diff(npairs) + def get_info_tree( self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label") ): -- GitLab