Skip to content
Snippets Groups Projects
Commit dca52155 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

bugfix and add benchmarking functions

Former-commit-id: e4d098d75fa4a9f91744c2f984ab9d2bee897e64
parent 942f6a60
No related branches found
No related tags found
No related merge requests found
from typing import Union
from itertools import groupby
import collections
from itertools import groupby, chain, product
import numpy as np
import h5py
......@@ -29,6 +31,38 @@ class BridgeH5:
def cell_tree(self):
return self.get_info_tree()
def get_n_cellpairs(self, nstepsback=2):
cell_tree = self.cell_tree
# get pair of consecutive trap-time points
pass
@staticmethod
def get_consecutives(tree, nstepsback):
# Receives a sorted tree and returns the keys of consecutive elements
vals = {k: np.array(list(v)) for k, v in tree.items()} # get tp level
where_consec = [
{
k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0]
for k, v in vals.items()
}
for n in range(nstepsback)
] # get indices of consecutive elements
return where_consec
def get_npairs(self, nstepsback=2):
tree = b.cell_tree
consecutive = b.get_consecutives(tree, nstepsback=nstepsback)
flat_tree = flatten(tree)
n_predictions = 0
for i, d in enumerate(consecutive, 1):
flat = list(chain(*[product([k], list(v)) for k, v in d.items()]))
pairs = [(f, (f[0], f[1] + i)) for f in flat]
for p in pairs:
n_predictions += len(flat_tree[p[0]]) * len(flat_tree[p[1]])
return n_predictions
def get_info_tree(
self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
):
......@@ -59,11 +93,11 @@ class BridgeH5:
def groupsort(iterable: Union[tuple, list]):
# Groups a list or tuple by the first element and returns
# a dictionary that follows {v[0]:sorted(v[1:]) for v in iterable}.
# Sorted by the first element in the remaining values
# Sorts iterable and returns a dictionary where the values are grouped by the first element.
return {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])}
iterable = sorted(iterable, key=lambda x: x[0])
grouped = {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])}
return grouped
def recursive_groupsort(iterable):
......@@ -72,3 +106,15 @@ def recursive_groupsort(iterable):
return {k: recursive_groupsort(v) for k, v in groupsort(iterable).items()}
else: # Only two elements in list
return [x[0] for x in iterable]
def flatten(d, parent_key="", sep="_"):
"""Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615"""
items = []
for k, v in d.items():
new_key = parent_key + (k,) if parent_key else (k,)
if isinstance(v, collections.MutableMapping):
items.extend(flatten(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment