Skip to content
Snippets Groups Projects
Commit a348bc3c authored by Alán Muñoz's avatar Alán Muñoz
Browse files

bugfixed and add npairs function

Former-commit-id: e14a74604bb44168ab2212ea787a9d4095a9f3c0
parent dca52155
No related branches found
No related tags found
No related merge requests found
......@@ -49,9 +49,11 @@ class BridgeH5:
] # get indices of consecutive elements
return where_consec
def get_npairs(self, nstepsback=2):
tree = b.cell_tree
consecutive = b.get_consecutives(tree, nstepsback=nstepsback)
def get_npairs(self, nstepsback=2, tree=None):
if tree is None:
tree = self.cell_tree
consecutive = self.get_consecutives(tree, nstepsback=nstepsback)
flat_tree = flatten(tree)
n_predictions = 0
......@@ -59,10 +61,23 @@ class BridgeH5:
flat = list(chain(*[product([k], list(v)) for k, v in d.items()]))
pairs = [(f, (f[0], f[1] + i)) for f in flat]
for p in pairs:
n_predictions += len(flat_tree[p[0]]) * len(flat_tree[p[1]])
n_predictions += len(flat_tree.get(p[0], [])) * len(
flat_tree.get(p[1], [])
)
return n_predictions
def get_npairs_over_time(self, nstepsback=2):
tree = self.cell_tree
npairs = []
for t in self._hdf["cell_info"]["processed_timepoints"][()]:
tmp_tree = {
k: {k2: v2 for k2, v2 in v.items() if k2 <= t} for k, v in tree.items()
}
npairs.append(self.get_npairs(tree=tmp_tree))
return np.diff(npairs)
def get_info_tree(
self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment