From dca52155527ab2b15a535ba75b3137db944eea2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk> Date: Mon, 5 Jul 2021 18:16:46 +0100 Subject: [PATCH] bugfix and add benchmarking functions Former-commit-id: e4d098d75fa4a9f91744c2f984ab9d2bee897e64 --- core/io/base.py | 56 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/core/io/base.py b/core/io/base.py index b2314d64..cb771ced 100644 --- a/core/io/base.py +++ b/core/io/base.py @@ -1,6 +1,8 @@ from typing import Union -from itertools import groupby +import collections +from itertools import groupby, chain, product +import numpy as np import h5py @@ -29,6 +31,38 @@ class BridgeH5: def cell_tree(self): return self.get_info_tree() + def get_n_cellpairs(self, nstepsback=2): + cell_tree = self.cell_tree + # get pair of consecutive trap-time points + pass + + @staticmethod + def get_consecutives(tree, nstepsback): + # Receives a sorted tree and returns the keys of consecutive elements + vals = {k: np.array(list(v)) for k, v in tree.items()} # get tp level + where_consec = [ + { + k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0] + for k, v in vals.items() + } + for n in range(nstepsback) + ] # get indices of consecutive elements + return where_consec + + def get_npairs(self, nstepsback=2): + tree = b.cell_tree + consecutive = b.get_consecutives(tree, nstepsback=nstepsback) + flat_tree = flatten(tree) + + n_predictions = 0 + for i, d in enumerate(consecutive, 1): + flat = list(chain(*[product([k], list(v)) for k, v in d.items()])) + pairs = [(f, (f[0], f[1] + i)) for f in flat] + for p in pairs: + n_predictions += len(flat_tree[p[0]]) * len(flat_tree[p[1]]) + + return n_predictions + def get_info_tree( self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label") ): @@ -59,11 +93,11 @@ class BridgeH5: def groupsort(iterable: Union[tuple, list]): - # Groups a list or tuple by the first element and returns - # a dictionary that follows {v[0]:sorted(v[1:]) for v in iterable}. - # Sorted by the first element in the remaining values + # Sorts iterable and returns a dictionary where the values are grouped by the first element. - return {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])} + iterable = sorted(iterable, key=lambda x: x[0]) + grouped = {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])} + return grouped def recursive_groupsort(iterable): @@ -72,3 +106,15 @@ def recursive_groupsort(iterable): return {k: recursive_groupsort(v) for k, v in groupsort(iterable).items()} else: # Only two elements in list return [x[0] for x in iterable] + + +def flatten(d, parent_key="", sep="_"): + """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615""" + items = [] + for k, v in d.items(): + new_key = parent_key + (k,) if parent_key else (k,) + if isinstance(v, collections.MutableMapping): + items.extend(flatten(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) -- GitLab