diff --git a/core/functions/tracks.py b/core/functions/tracks.py
index 7523213520890d73bd074f3bd0ff304cd96077f3..c446c6ee76dd7b02cad45a5135fc47a98d400783 100644
--- a/core/functions/tracks.py
+++ b/core/functions/tracks.py
@@ -216,11 +216,6 @@ def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> dict:
 
     """
 
-    tracks.index.names = [
-        "trap",
-        "cell",
-    ]  # TODO remove this once it is integrated in the tracker
-    # contig=tracks.groupby(['pos','trap']).apply(tracks2contig)
     clean = clean_tracks(tracks, min_len=window + 1, min_gr=0.9)  # get useful tracks
     contig = clean.groupby(["trap"]).apply(get_contiguous_pairs)
     contig = contig.loc[contig.apply(len) > 0]
diff --git a/core/io/base.py b/core/io/base.py
deleted file mode 100644
index a05aa0354697d4a3a4a81f8edcf0a6091ea72a8a..0000000000000000000000000000000000000000
--- a/core/io/base.py
+++ /dev/null
@@ -1,141 +0,0 @@
-from typing import Union
-import collections
-from itertools import groupby, chain, product
-
-import numpy as np
-import h5py
-
-
-class BridgeH5:
-    """
-    Base class to interact with h5 data stores.
-    It also contains functions useful to predict how long should segmentation take.
-    """
-
-    def __init__(self, file):
-        self.filename = file
-        self._hdf = h5py.File(file, "r")
-
-        self._filecheck()
-
-    def _filecheck(self):
-        assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found."
-
-    def close(self):
-        self._hdf.close()
-
-    def max_ncellpairs(self, nstepsback):
-        """
-        Get maximum number of cell pairs to be calculated
-        """
-
-        dset = self._hdf["cell_info"][()]
-        # attrs = self._hdf[dataset].attrs
-        pass
-
-    @property
-    def cell_tree(self):
-        return self.get_info_tree()
-
-    def get_n_cellpairs(self, nstepsback=2):
-        cell_tree = self.cell_tree
-        # get pair of consecutive trap-time points
-        pass
-
-    @staticmethod
-    def get_consecutives(tree, nstepsback):
-        # Receives a sorted tree and returns the keys of consecutive elements
-        vals = {k: np.array(list(v)) for k, v in tree.items()}  # get tp level
-        where_consec = [
-            {
-                k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0]
-                for k, v in vals.items()
-            }
-            for n in range(nstepsback)
-        ]  # get indices of consecutive elements
-        return where_consec
-
-    def get_npairs(self, nstepsback=2, tree=None):
-        if tree is None:
-            tree = self.cell_tree
-
-        consecutive = self.get_consecutives(tree, nstepsback=nstepsback)
-        flat_tree = flatten(tree)
-
-        n_predictions = 0
-        for i, d in enumerate(consecutive, 1):
-            flat = list(chain(*[product([k], list(v)) for k, v in d.items()]))
-            pairs = [(f, (f[0], f[1] + i)) for f in flat]
-            for p in pairs:
-                n_predictions += len(flat_tree.get(p[0], [])) * len(
-                    flat_tree.get(p[1], [])
-                )
-
-        return n_predictions
-
-    def get_npairs_over_time(self, nstepsback=2):
-        tree = self.cell_tree
-        npairs = []
-        for t in self._hdf["cell_info"]["processed_timepoints"][()]:
-            tmp_tree = {
-                k: {k2: v2 for k2, v2 in v.items() if k2 <= t} for k, v in tree.items()
-            }
-            npairs.append(self.get_npairs(tree=tmp_tree))
-
-        return np.diff(npairs)
-
-    def get_info_tree(
-        self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
-    ):
-        """
-        Returns traps, time points and labels for this position in form of a tree
-        in the hierarchy determined by the argument fields. Note that it is
-        compressed to non-empty elements and timepoints.
-
-        Default hierarchy is:
-        - trap
-        - time point
-        - cell label
-
-        This function currently produces trees of depth 3, but it can easily be
-        extended for deeper trees if needed (e.g. considering groups,
-        chambers and/or positions).
-
-        input
-        :fields: Fields to fetch from 'cell_info' inside the hdf5 storage
-
-        returns
-        :tree: Nested dictionary where keys (or branches) are the upper levels
-            and the leaves are the last element of :fields:.
-        """
-        zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),)
-
-        return recursive_groupsort(zipped_info)
-
-
-def groupsort(iterable: Union[tuple, list]):
-    # Sorts iterable and returns a dictionary where the values are grouped by the first element.
-
-    iterable = sorted(iterable, key=lambda x: x[0])
-    grouped = {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])}
-    return grouped
-
-
-def recursive_groupsort(iterable):
-    # Recursive extension of groupsort
-    if len(iterable[0]) > 1:
-        return {k: recursive_groupsort(v) for k, v in groupsort(iterable).items()}
-    else:  # Only two elements in list
-        return [x[0] for x in iterable]
-
-
-def flatten(d, parent_key="", sep="_"):
-    """Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615"""
-    items = []
-    for k, v in d.items():
-        new_key = parent_key + (k,) if parent_key else (k,)
-        if isinstance(v, collections.MutableMapping):
-            items.extend(flatten(v, new_key, sep=sep).items())
-        else:
-            items.append((new_key, v))
-    return dict(items)
diff --git a/core/io/signal.py b/core/io/signal.py
deleted file mode 100644
index da5f73f9dc9cd13fe7cf32fe5240edc35b50d1b9..0000000000000000000000000000000000000000
--- a/core/io/signal.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import pandas as pd
-
-from postprocessor.core.io.base import BridgeH5
-
-
-class Signal(BridgeH5):
-    """
-    Class that fetches data from the hdf5 storage for post-processing
-    """
-
-    def __init__(self, file):
-        super().__init__(file)
-
-    def __getitem__(self, dataset):
-        dset = self._hdf[dataset]
-        index = pd.MultiIndex.from_arrays(
-            [dset[lbl][()] for lbl in dset.keys() if "axis1_label" in lbl]
-        )
-
-        columns = dset["axis0"][()]
-
-        return pd.DataFrame(dset[("block0_values")][()], index=index, columns=columns)
-
-    @staticmethod
-    def _if_ext_or_post(name):
-        if name.startswith("extraction") or name.startswith("postprocessing"):
-            if len(name.split("/")) > 3:
-                return name
-
-    @property
-    def datasets(self):
-        return self._hdf.visit(self._if_ext_or_post)
diff --git a/core/io/writer.py b/core/io/writer.py
deleted file mode 100644
index 240331bb8a9008ae37a128f9d12a437e9937050d..0000000000000000000000000000000000000000
--- a/core/io/writer.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from itertools import accumulate
-
-import h5py
-import pandas as pd
-
-from postprocessor.core.io.base import BridgeH5
-
-
-def Writer(BridgeH5):
-    """
-    Class in charge of transforming data into compatible formats
-
-    Decoupling interface from implementation!
-
-    :hdfname: Name of file to write into
-    """
-
-    def __init__(self, hdfname):
-        self._hdf = h5py.Hdf(hdfname, "a")
-
-    def write(self, address, data):
-        self._hdf.add_group(address)
-        if type(data) is pd.DataFrame:
-            self.write_df(address, data)
-        elif type(data) is np.array:
-            self.write_np(address, data)
-
-    def write_np(self, address, array):
-        pass
-
-    def write_df(self, df, tps, path):
-        print("writing to ", path)
-        for item in accummulate(path.split("/")[:-2]):
-            if item not in self._hdf:
-                self._hdf.create_group(item)
-        pos_group = f[path.split("/")[1]]
-
-        if path not in pos_group:
-            pos_group.create_dataset(name=path, shape=df.shape, dtype=df.dtypes[0])
-        new_dset = f[path]
-        new_dset[()] = df.values
-        if len(df.index.names) > 1:
-            trap, cell_label = zip(*list(df.index.values))
-            new_dset.attrs["trap"] = trap
-            new_dset.attrs["cell_label"] = cell_label
-            new_dset.attrs["idnames"] = ["trap", "cell_label"]
-        else:
-            new_dset.attrs["trap"] = list(df.index.values)
-            new_dset.attrs["idnames"] = ["trap"]
-        pos_group.attrs["processed_timepoints"] = tps
diff --git a/core/processes/merger.py b/core/processes/merger.py
index 51a91bf98e22e2c9e4c82bac10830494a839b513..259d752fb1130ee52eaf1e20c106dfb0ccd79594 100644
--- a/core/processes/merger.py
+++ b/core/processes/merger.py
@@ -52,4 +52,7 @@ class Merger(ProcessABC):
         super().__init__(parameters)
 
     def run(self, signal):
-        merged, joint_pairs = merge_tracks(signal)  # , min_len=self.window + 1)
+        merged, _ = merge_tracks(signal)  # , min_len=self.window + 1)
+        indices = (*zip(*merged.index.tolist()),)
+        names = merged.index.names
+        return {name: ids for name, ids in zip(names, indices)}
diff --git a/core/processor.py b/core/processor.py
index eeb3cf9af24ab0ebedb215cfef740c553e9edde4..ce910885a251a10458927e5b09adbc8744e031ea 100644
--- a/core/processor.py
+++ b/core/processor.py
@@ -4,8 +4,8 @@ import pandas as pd
 from postprocessor.core.processes.base import ParametersABC
 from postprocessor.core.processes.merger import MergerParameters, Merger
 from postprocessor.core.processes.picker import PickerParameters, Picker
-from postprocessor.core.io.writer import Writer
-from postprocessor.core.io.signal import Signal
+from core.io.writer import Writer
+from core.io.signal import Signal
 
 from core.cells import Cells
 
@@ -19,10 +19,13 @@ class PostProcessorParameters(ParametersABC):
     :datasets: Dictionary
     """
 
-    def __init__(self, merger=None, picker=None, processes=[], datasets=None):
+    def __init__(
+        self, merger=None, picker=None, processes=[], datasets=[], outpaths=[]
+    ):
         self.merger: MergerParameters = merger
         self.picker: PickerParameters = picker
         self.processes: List = processes
+        self.outpaths = outpaths
         self.datasets: Dict = datasets
 
 
@@ -49,10 +52,11 @@ class PostProcessor:
 
     def __init__(self, filename, parameters):
         self.parameters = parameters
-        self._signals = Signal(filename)
+        # self._signals = Signal(filename)
         self._writer = Writer(filename)
 
         self.datasets = parameters["datasets"]
+        self.outpaths = parameters["outpaths"]
         self.merger = Merger(parameters["merger"])
         self.picker = Picker(
             parameters=parameters["picker"], cells=Cells.from_source(filename)
@@ -62,11 +66,17 @@ class PostProcessor:
         ]
 
     def run(self):
-        self.merger.run(self._signals[self.datasets["merger"]])
-        self.picker.run(self._signals[self.datasets["picker"]])
-        for process, dataset in zip(self.processes, self.datasets["processes"]):
-            processed_result = process.run(self._signals.get_dataset(dataset))
-            self.writer.write(processed_result, dataset, outpath)
+        new_ids = self.merger.run(self._signals[self.datasets["merger"]])
+        for name, ids in new_ids.items():
+            self._writer.write(ids, "/postprocessing/cell_info/")
+        picks = self.picker.run(self._writer[self.datasets["picker"]])
+        return merge
+        # print(merge, picks)
+        # for process, dataset, outpath in zip(
+        #     self.processes, self.datasets["processes"], self.outpaths
+        # ):
+        #     processed_result = process.run(self._signals.get_dataset(dataset))
+        #     self.writer.write(processed_result, dataset, outpath)
 
 
 def _if_dict(item):
diff --git a/examples/basic_processes.py b/examples/basic_processes.py
index 8dd85d5fc83c54c112a67a324ff3e55c5d264cba..b3bf6a2455bfd0ee95a521d2ed1f20f55e584292 100644
--- a/examples/basic_processes.py
+++ b/examples/basic_processes.py
@@ -2,7 +2,14 @@ from postprocessor.core.processor import PostProcessor, PostProcessorParameters
 
 params = PostProcessorParameters.default()
 pp = PostProcessor(
-    "/shared_libs/pipeline-core/scripts/data/ph_calibration_dual_phl_ura8_5_04_5_83_7_69_7_13_6_59__01/ph_5_04_005store.h5",
+    "/shared_libs/pipeline-core/scripts/pH_calibration_dual_phl__ura8__by4741__01/ph_5_29_025store.h5",
     params,
 )
-tmp = pp.run()
+# tmp = pp.run()
+
+import h5py
+
+f = h5py.File(
+    "/shared_libs/pipeline-core/scripts/pH_calibration_dual_phl__ura8__by4741__01/ph_5_29_025store.h5",
+    "a",
+)
diff --git a/examples/signals.py b/examples/signals.py
deleted file mode 100644
index 98bdc941eaedaf9b650b13ed17634927fc731ac0..0000000000000000000000000000000000000000
--- a/examples/signals.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from postprocessor.core.io.signal import Signal
-
-signal = Signal(
-    "/shared_libs/pipeline-core/scripts/data/ph_calibration_dual_phl_ura8_5_04_5_83_7_69_7_13_6_59__01/ph_5_04_001store.h5"
-)
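Reviewer note, not part of the patch: Merger.run now returns a dict mapping each index-level name to a tuple of ids instead of discarding the merge result, and PostProcessor.run writes those ids under "/postprocessing/cell_info/". Below is a minimal sketch of the shape this produces, using a made-up two-cell MultiIndex; the real level names and values come from the merged signal's index, and whether Writer.write accepts a bare tuple of ids for that group is an assumption here, not something the patch shows.

import pandas as pd

# Hypothetical merged signal: two cells in one trap (illustrative values only).
merged = pd.DataFrame(
    [[1.0, 2.0], [3.0, 4.0]],
    index=pd.MultiIndex.from_tuples([(0, 1), (0, 2)], names=["trap", "cell_label"]),
)

# Same reshaping as the new Merger.run body: transpose the index tuples so each
# level becomes one tuple of ids, keyed by its level name.
indices = (*zip(*merged.index.tolist()),)
new_ids = {name: ids for name, ids in zip(merged.index.names, indices)}
# new_ids == {"trap": (0, 0), "cell_label": (1, 2)}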