diff --git a/core/io/base.py b/core/io/base.py
index b2314d649681dabcc2c73a917f8f412108ab58c8..cb771cedd8eb2cb0af575e705317c3cf9e90221e 100644
--- a/core/io/base.py
+++ b/core/io/base.py
@@ -1,6 +1,8 @@
 from typing import Union
-from itertools import groupby
+import collections.abc
+from itertools import groupby, chain, product
 
+import numpy as np
 import h5py
 
 
@@ -29,6 +31,58 @@ class BridgeH5:
     def cell_tree(self):
         return self.get_info_tree()
 
+    def get_n_cellpairs(self, nstepsback=2):
+        """Placeholder: not implemented yet, always returns None."""
+        cell_tree = self.cell_tree
+        # get pair of consecutive trap-time points
+        pass
+
+    @staticmethod
+    def get_consecutives(tree, nstepsback):
+        """Find positions of consecutive second-level keys, per lag.
+
+        For each lag 1..nstepsback and each top-level key of ``tree``, find
+        the positions ``i`` in that key's sub-keys (assumed sorted and
+        numeric) where the sub-key ``lag`` places later is exactly ``lag``
+        larger.
+
+        Returns a list with one dict per lag, mapping each top-level key to
+        a numpy array of those positions.
+        """
+        vals = {k: np.array(list(v)) for k, v in tree.items()}  # get tp level
+        where_consec = [
+            {
+                k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0]
+                for k, v in vals.items()
+            }
+            for n in range(nstepsback)
+        ]  # get indices of consecutive elements
+        return where_consec
+
+    def get_npairs(self, nstepsback=2):
+        """Count candidate cell pairings between nearby timepoints of a trap.
+
+        For each lag up to nstepsback and each pair of timepoints of the same
+        trap separated by exactly that lag (assuming sorted integer keys), add
+        the product of the number of entries stored at the two timepoints.
+        """
+        tree = self.cell_tree
+        consecutive = self.get_consecutives(tree, nstepsback=nstepsback)
+        flat_tree = flatten(tree)
+        # get_consecutives yields positions, so map them back to timepoints.
+        timepoints = {trap: list(tps) for trap, tps in tree.items()}
+
+        n_predictions = 0
+        for i, d in enumerate(consecutive, 1):
+            flat = chain.from_iterable(product([k], v) for k, v in d.items())
+            for trap, pos in flat:
+                tps = timepoints[trap]
+                n_first = len(flat_tree[(trap, tps[pos])])
+                n_second = len(flat_tree[(trap, tps[pos + i])])
+                n_predictions += n_first * n_second
+
+        return n_predictions
+
     def get_info_tree(
         self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
     ):
@@ -59,11 +113,11 @@ class BridgeH5:
 
 
 def groupsort(iterable: Union[tuple, list]):
-    # Groups a list or tuple by the first element and returns
-    # a dictionary that follows {v[0]:sorted(v[1:]) for v in iterable}.
-    # Sorted by the first element in the remaining values
+    """Sort by first element; return {first: [remaining fields of each item]}."""
 
-    return {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])}
+    iterable = sorted(iterable, key=lambda x: x[0])
+    grouped = {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])}
+    return grouped
 
 
 def recursive_groupsort(iterable):
@@ -72,3 +126,22 @@ def recursive_groupsort(iterable):
         return {k: recursive_groupsort(v) for k, v in groupsort(iterable).items()}
     else:  # Only two elements in list
         return [x[0] for x in iterable]
+
+
+def flatten(d, parent_key="", sep="_"):
+    """Flatten a nested dict into one level keyed by tuples of the key path.
+
+    Adapted from https://stackoverflow.com/a/6027615. ``parent_key`` is the
+    path prefix (a tuple when recursing; the empty default means no prefix).
+    ``sep`` is accepted for compatibility but not used, since keys are
+    tuples rather than joined strings.
+    """
+    items = []
+    for k, v in d.items():
+        new_key = parent_key + (k,) if parent_key else (k,)
+        # collections.MutableMapping was removed in Python 3.10.
+        if isinstance(v, collections.abc.MutableMapping):
+            items.extend(flatten(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)