Commit c74e9e88 authored by Alán Muñoz's avatar Alán Muñoz

encapsulated signal and writer

Former-commit-id: dce960267879405009096e4df3485425999d7655
parent db1193e7
@@ -216,11 +216,6 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
"""
tracks.index.names = [
"trap",
"cell",
] # TODO remove this once it is integrated in the tracker
# contig=tracks.groupby(['pos','trap']).apply(tracks2contig)
clean = clean_tracks(tracks, min_len=window + 1, min_gr=0.9) # get useful tracks
contig = clean.groupby(["trap"]).apply(get_contiguous_pairs)
contig = contig.loc[contig.apply(len) > 0]
......
from typing import Union
import collections.abc
from itertools import groupby, chain, product
import numpy as np
import h5py
class BridgeH5:
"""
Base class to interact with h5 data stores.
It also contains functions useful for predicting how long segmentation should take.
"""
def __init__(self, file):
self.filename = file
self._hdf = h5py.File(file, "r")
self._filecheck()
def _filecheck(self):
assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found."
def close(self):
self._hdf.close()
def max_ncellpairs(self, nstepsback):
"""
Get maximum number of cell pairs to be calculated
"""
dset = self._hdf["cell_info"][()]
# attrs = self._hdf[dataset].attrs
pass
@property
def cell_tree(self):
return self.get_info_tree()
def get_n_cellpairs(self, nstepsback=2):
cell_tree = self.cell_tree
# get pair of consecutive trap-time points
pass
@staticmethod
def get_consecutives(tree, nstepsback):
# Receives a sorted tree and returns the keys of consecutive elements
vals = {k: np.array(list(v)) for k, v in tree.items()} # get tp level
where_consec = [
{
k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0]
for k, v in vals.items()
}
for n in range(nstepsback)
] # get indices of consecutive elements
return where_consec
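# Illustrative example (not part of the commit): for a single trap (key 0) seen at
# time points 1, 2, 3 and 5,
#   BridgeH5.get_consecutives({0: {1: ["a"], 2: ["a", "b"], 3: ["a"], 5: ["a"]}}, 2)
# returns [{0: array([0, 1])}, {0: array([0])}]: positions 0 and 1 each start a pair
# one time point apart ((1, 2) and (2, 3)), and only position 0 starts a pair exactly
# two time points apart ((1, 3)).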
def get_npairs(self, nstepsback=2, tree=None):
if tree is None:
tree = self.cell_tree
consecutive = self.get_consecutives(tree, nstepsback=nstepsback)
flat_tree = flatten(tree)
n_predictions = 0
for i, d in enumerate(consecutive, 1):
flat = list(chain(*[product([k], list(v)) for k, v in d.items()]))
pairs = [(f, (f[0], f[1] + i)) for f in flat]
for p in pairs:
n_predictions += len(flat_tree.get(p[0], [])) * len(
flat_tree.get(p[1], [])
)
return n_predictions
def get_npairs_over_time(self, nstepsback=2):
tree = self.cell_tree
npairs = []
for t in self._hdf["cell_info"]["processed_timepoints"][()]:
tmp_tree = {
k: {k2: v2 for k2, v2 in v.items() if k2 <= t} for k, v in tree.items()
}
npairs.append(self.get_npairs(tree=tmp_tree))
return np.diff(npairs)
def get_info_tree(
self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
):
"""
Return traps, time points and labels for this position in the form of a tree,
with the hierarchy determined by the argument fields. Note that the tree is
compressed to non-empty elements and time points.
Default hierarchy is:
- trap
- time point
- cell label
This function currently produces trees of depth 3, but it can easily be
extended for deeper trees if needed (e.g. considering groups,
chambers and/or positions).
input
:fields: Fields to fetch from 'cell_info' inside the hdf5 storage
returns
:tree: Nested dictionary where keys (or branches) are the upper levels
and the leaves are the last element of :fields:.
"""
zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),)
return recursive_groupsort(zipped_info)
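# Illustrative shape of the returned tree (values hypothetical):
#   tree = BridgeH5("position.h5").cell_tree
#   tree[trap][timepoint] -> list of cell labels, e.g. tree[3][12] == [1, 2, 5]
# meaning trap 3 holds cells labelled 1, 2 and 5 at time point 12.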
def groupsort(iterable: Union[tuple, list]):
# Sorts iterable and returns a dictionary where the values are grouped by the first element.
iterable = sorted(iterable, key=lambda x: x[0])
grouped = {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])}
return grouped
def recursive_groupsort(iterable):
# Recursive extension of groupsort
if len(iterable[0]) > 1:
return {k: recursive_groupsort(v) for k, v in groupsort(iterable).items()}
else: # leaf level: each element is a 1-tuple holding the last field
return [x[0] for x in iterable]
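# Example (illustrative): with rows = [(0, 1, "a"), (0, 2, "b"), (1, 1, "c")],
#   groupsort(rows)           == {0: [(1, "a"), (2, "b")], 1: [(1, "c")]}
#   recursive_groupsort(rows) == {0: {1: ["a"], 2: ["b"]}, 1: {1: ["c"]}}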
def flatten(d, parent_key="", sep="_"):
"""Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615"""
items = []
for k, v in d.items():
new_key = parent_key + (k,) if parent_key else (k,)
if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
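A minimal sketch (not part of the commit) of how flatten inverts the nesting produced by recursive_groupsort:
tree = {0: {1: ["a"], 2: ["b"]}, 1: {1: ["c"]}}
flat = flatten(tree)
# flat == {(0, 1): ["a"], (0, 2): ["b"], (1, 1): ["c"]}; keys are (trap, timepoint) tuples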
import pandas as pd
from postprocessor.core.io.base import BridgeH5
class Signal(BridgeH5):
"""
Class that fetches data from the hdf5 storage for post-processing
"""
def __init__(self, file):
super().__init__(file)
def __getitem__(self, dataset):
dset = self._hdf[dataset]
index = pd.MultiIndex.from_arrays(
[dset[lbl][()] for lbl in dset.keys() if "axis1_label" in lbl]
)
columns = dset["axis0"][()]
return pd.DataFrame(dset[("block0_values")][()], index=index, columns=columns)
@staticmethod
def _if_ext_or_post(name):
if name.startswith("extraction") or name.startswith("postprocessing"):
if len(name.split("/")) > 3:
return name
@property
def datasets(self):
return self._hdf.visit(self._if_ext_or_post)
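A hedged usage sketch of Signal; the file name and dataset path are hypothetical and only illustrate the pandas fixed-format layout (axis0, axis1_label*, block0_values) that __getitem__ expects:
from postprocessor.core.io.signal import Signal
signal = Signal("position.h5")  # hypothetical h5 store containing a 'cell_info' group
df = signal["extraction/general/None/area"]  # hypothetical dataset path
# df has one row per (trap, cell_label) pair and one column per time point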
from itertools import accumulate
import h5py
import numpy as np
import pandas as pd
from postprocessor.core.io.base import BridgeH5
class Writer(BridgeH5):
"""
Class in charge of transforming data into compatible formats
Decoupling interface from implementation!
:hdfname: Name of file to write into
"""
def __init__(self, hdfname):
self._hdf = h5py.File(hdfname, "a")
def write(self, address, data):
self._hdf.require_group(address)
if isinstance(data, pd.DataFrame):
self.write_df(address, data)
elif isinstance(data, np.ndarray):
self.write_np(address, data)
def write_np(self, address, array):
pass
def write_df(self, df, tps, path):
print("writing to ", path)
# create any missing intermediate groups along the path
for item in accumulate(filter(None, path.split("/")[:-2]), lambda x, y: f"{x}/{y}"):
if item not in self._hdf:
self._hdf.create_group(item)
pos_group = self._hdf[path.split("/")[1]]
if path not in pos_group:
pos_group.create_dataset(name=path, shape=df.shape, dtype=df.dtypes[0])
new_dset = self._hdf[path]
new_dset[()] = df.values
if len(df.index.names) > 1:
trap, cell_label = zip(*list(df.index.values))
new_dset.attrs["trap"] = trap
new_dset.attrs["cell_label"] = cell_label
new_dset.attrs["idnames"] = ["trap", "cell_label"]
else:
new_dset.attrs["trap"] = list(df.index.values)
new_dset.attrs["idnames"] = ["trap"]
pos_group.attrs["processed_timepoints"] = tps
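The index is not stored with the values: write_df keeps it in HDF5 attributes. A sketch of reading such a dataset back, assuming a hypothetical file and dataset path:
import h5py
import pandas as pd
with h5py.File("position.h5", "r") as f:  # hypothetical file written by Writer
    dset = f["postprocessing/merged/values"]  # hypothetical dataset path
    names = list(dset.attrs["idnames"])  # e.g. ["trap", "cell_label"]
    index = pd.MultiIndex.from_arrays([dset.attrs[n] for n in names], names=names)
    df = pd.DataFrame(dset[()], index=index)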
@@ -52,4 +52,7 @@ class Merger(ProcessABC):
super().__init__(parameters)
def run(self, signal):
merged, joint_pairs = merge_tracks(signal) # , min_len=self.window + 1)
merged, _ = merge_tracks(signal) # , min_len=self.window + 1)
indices = (*zip(*merged.index.tolist()),)
names = merged.index.names
return {name: ids for name, ids in zip(names, indices)}
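For reference, a sketch (values illustrative) of the dict that run returns:
# merged.index == [(0, 1), (0, 2), (1, 1)] with names ["trap", "cell_label"]
# -> {"trap": (0, 0, 1), "cell_label": (1, 2, 1)}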
@@ -4,8 +4,8 @@ import pandas as pd
from postprocessor.core.processes.base import ParametersABC
from postprocessor.core.processes.merger import MergerParameters, Merger
from postprocessor.core.processes.picker import PickerParameters, Picker
from postprocessor.core.io.writer import Writer
from postprocessor.core.io.signal import Signal
from core.io.writer import Writer
from core.io.signal import Signal
from core.cells import Cells
@@ -19,10 +19,13 @@ class PostProcessorParameters(ParametersABC):
:datasets: Dictionary
"""
def __init__(self, merger=None, picker=None, processes=[], datasets=None):
def __init__(
self, merger=None, picker=None, processes=[], datasets=[], outpaths=[]
):
self.merger: MergerParameters = merger
self.picker: PickerParameters = picker
self.processes: List = processes
self.outpaths = outpaths
self.datasets: Dict = datasets
@@ -49,10 +52,11 @@ class PostProcessorParameters(ParametersABC):
class PostProcessor:
def __init__(self, filename, parameters):
self.parameters = parameters
self._signals = Signal(filename)
# self._signals = Signal(filename)
self._writer = Writer(filename)
self.datasets = parameters["datasets"]
self.outpaths = parameters["outpaths"]
self.merger = Merger(parameters["merger"])
self.picker = Picker(
parameters=parameters["picker"], cells=Cells.from_source(filename)
@@ -62,11 +66,17 @@ class PostProcessor:
]
def run(self):
self.merger.run(self._signals[self.datasets["merger"]])
self.picker.run(self._signals[self.datasets["picker"]])
for process, dataset in zip(self.processes, self.datasets["processes"]):
processed_result = process.run(self._signals.get_dataset(dataset))
self.writer.write(processed_result, dataset, outpath)
new_ids = self.merger.run(self._signals[self.datasets["merger"]])
for name, ids in new_ids.items():
self._writer.write(ids, "/postprocessing/cell_info/")
picks = self.picker.run(self._writer[self.datasets["picker"]])
return merge
# print(merge, picks)
# for process, dataset, outpath in zip(
# self.processes, self.datasets["processes"], self.outpaths
# ):
# processed_result = process.run(self._signals.get_dataset(dataset))
# self.writer.write(processed_result, dataset, outpath)
def _if_dict(item):
......
@@ -2,7 +2,14 @@ from postprocessor.core.processor import PostProcessor, PostProcessorParameters
params = PostProcessorParameters.default()
pp = PostProcessor(
"/shared_libs/pipeline-core/scripts/data/ph_calibration_dual_phl_ura8_5_04_5_83_7_69_7_13_6_59__01/ph_5_04_005store.h5",
"/shared_libs/pipeline-core/scripts/pH_calibration_dual_phl__ura8__by4741__01/ph_5_29_025store.h5",
params,
)
tmp = pp.run()
# tmp = pp.run()
import h5py
f = h5py.File(
"/shared_libs/pipeline-core/scripts/pH_calibration_dual_phl__ura8__by4741__01/ph_5_29_025store.h5",
"a",
)
from postprocessor.core.io.signal import Signal
signal = Signal(
"/shared_libs/pipeline-core/scripts/data/ph_calibration_dual_phl_ura8_5_04_5_83_7_69_7_13_6_59__01/ph_5_04_001store.h5"
)