diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py
deleted file mode 100644
index 9fc68c63a7dfb4f9f8a9460d03d6bfee232b327f..0000000000000000000000000000000000000000
--- a/src/postprocessor/chainer.py
+++ /dev/null
@@ -1,123 +0,0 @@
-#!/usr/bin/env jupyter
-
-import re
-import typing as t
-from copy import copy
-
-import pandas as pd
-
-from agora.io.signal import Signal
-from agora.utils.kymograph import bidirectional_retainment_filter
-from postprocessor.core.abc import get_process
-
-
-class Chainer(Signal):
-    """
-    Extend Signal by applying post-processes and allowing composite signals that combine basic signals.
-
-    Chainer "chains" multiple processes upon fetching a dataset.
-
-    Instead of reading processes previously applied, Chainer executes
-    them when called.
-    """
-
-    _synonyms = {}
-
-    def __init__(self, *args, **kwargs):
-        """Initialise chainer."""
-        super().__init__(*args, **kwargs)
-
-        def replace_path(path: str, bgsub: bool = ""):
-            # function to add bgsub to paths
-            channel = path.split("/")[1]
-            suffix = "_bgsub" if bgsub else ""
-            path = re.sub(channel, f"{channel}{suffix}", path)
-            return path
-
-        # add chain with and without bgsub for composite statistics
-        self.common_chains = {
-            alias
-            + bgsub: lambda **kwargs: self.get(
-                replace_path(denominator, alias + bgsub), **kwargs
-            )
-            / self.get(replace_path(numerator, alias + bgsub), **kwargs)
-            for alias, (denominator, numerator) in self._synonyms.items()
-            for bgsub in ("", "_bgsub")
-        }
-
-    def get(
-        self,
-        dataset: str,
-        chain: t.Collection[str] = ("standard", "interpolate", "savgol"),
-        in_minutes: bool = True,
-        stages: bool = True,
-        retain: t.Optional[float] = None,
-        **kwargs,
-    ):
-        """Load data from an h5 file."""
-        if dataset in self.common_chains:
-            # get dataset for composite chains
-            data = self.common_chains[dataset](**kwargs)
-        else:
-            # use Signal's get_raw
-            data = self.get_raw(dataset, in_minutes=in_minutes, lineage=True)
-            if chain:
-                data = self.apply_chain(data, chain, **kwargs)
-        if retain:
-            # keep data only from early time points
-            data = self.get_retained(data, retain)
-        if stages and "stage" not in data.columns.names:
-            # return stages as additional column level
-            stages_index = [
-                x
-                for i, (name, span) in enumerate(self.stages_span_tp)
-                for x in (f"{i} { name }",) * span
-            ]
-            data.columns = pd.MultiIndex.from_tuples(
-                zip(stages_index, data.columns),
-                names=("stage", "time"),
-            )
-        return data
-
-    def apply_chain(
-        self, input_data: pd.DataFrame, chain: t.Tuple[str, ...], **kwargs
-    ):
-        """
-        Apply a series of processes to a data set.
-
-        Like postprocessing, Chainer consecutively applies processes.
-
-        Parameters can be passed as kwargs.
-
-        Chainer does not support applying the same process multiple times with different parameters.
-
-        Parameters
-        ----------
-        input_data : pd.DataFrame
-            Input data to process.
-        chain : t.Tuple[str, ...]
-            Tuple of strings with the names of the processes
-        **kwargs : kwargs
-            Arguments passed on to Process.as_function() method to modify the parameters.
-
-        Examples
-        --------
-        FIXME: Add docs.
- - - """ - result = copy(input_data) - self._intermediate_steps = [] - for process in chain: - if process == "standard": - result = bidirectional_retainment_filter(result) - else: - params = kwargs.get(process, {}) - process_cls = get_process(process) - result = process_cls.as_function(result, **params) - process_type = process_cls.__module__.split(".")[-2] - if process_type == "reshapers": - if process == "merger": - raise (NotImplementedError) - self._intermediate_steps.append(result) - return result diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py index da6cbd67483be37f2d676f45cdc189e25228d8d6..e9abac873eddc8fe82e5c3a6a81a34debfef5288 100644 --- a/src/postprocessor/grouper.py +++ b/src/postprocessor/grouper.py @@ -13,14 +13,13 @@ import pandas as pd from pathos.multiprocessing import Pool from agora.io.signal import Signal -from postprocessor.chainer import Chainer class Grouper(ABC): """Base grouper class.""" def __init__(self, dir: Union[str, Path], name_inds=(0, -4)): - """Find h5 files and load a chain for each one.""" + """Find h5 files and load each one.""" path = Path(dir) assert path.exists(), f"{str(dir)} does not exist" self.name = path.name @@ -29,31 +28,33 @@ class Grouper(ABC): self.load_positions() self.positions_groups = { name: name[name_inds[0] : name_inds[1]] - for name in self.chainers.keys() + for name in self.positions.keys() } def load_positions(self) -> None: - """Load a chain for each position, or h5 file.""" - self.chainers = {f.name[:-3]: Signal(f) for f in self.files} + """Load a Signal for each position, or h5 file.""" + self.positions = {f.name[:-3]: Signal(f) for f in self.files} @property def first_signal(self) -> Signal: """Get Signal for the first position.""" - return list(self.chainers.values())[0] + return list(self.positions.values())[0] @property def ntimepoints(self) -> int: """Find number of time points.""" - return max([s.ntimepoints for s in self.chainers.values()]) + return max([s.ntimepoints for s in self.positions.values()]) @property - def max_tinterval(self) -> float: - """Find the maximum time interval for all chains.""" - tintervals = set([s.tinterval / 60 for s in self.chainers.values()]) + def tinterval(self) -> float: + """Find the time interval for all positions.""" + tintervals = list( + set([s.tinterval / 60 for s in self.positions.values()]) + ) assert ( len(tintervals) == 1 - ), "Not all chains have the same time interval" - return max(tintervals) + ), "Not all positions have the same time interval." + return tintervals[0] @property def available(self) -> t.Collection[str]: @@ -65,7 +66,7 @@ class Grouper(ABC): """Display available signals and the number of positions with these signals.""" if not hasattr(self, "available_grouped"): self._available_grouped = Counter( - [x for s in self.chainers.values() for x in s.available] + [x for s in self.positions.values() for x in s.available] ) for s, n in self._available_grouped.items(): print(f"{s} - {n}") @@ -78,7 +79,7 @@ class Grouper(ABC): **kwargs, ): """ - Concatenate data for one signal from different h5 files into a data frame. + Concatenate data for one signal from different h5 files. Each h5 files corresponds to a different position. @@ -89,7 +90,6 @@ class Grouper(ABC): pool : int Number of threads used; if 0 or None only one core is used. 
         mode: str
-            standard: boolean
         **kwargs : key, value pairings
             Named arguments for concat_one_signal
 
@@ -101,19 +101,18 @@
         path = path.strip("/")
         good_positions = self.filter_positions(path)
         if good_positions:
-            fn_pos = concat_one_signal
             kwargs["mode"] = mode
             records = self.pool_function(
                 path=path,
-                f=fn_pos,
+                f=concat_one_signal,
                 pool=pool,
-                chainers=good_positions,
+                positions=good_positions,
                 **kwargs,
             )
             # check for errors
             errors = [
                 k
-                for kymo, k in zip(records, self.chainers.keys())
+                for kymo, k in zip(records, self.positions.keys())
                 if kymo is None
             ]
             records = [record for record in records if record is not None]
@@ -130,12 +129,12 @@
             concat_sorted = concat.sort_index()
         return concat_sorted
 
-    def filter_positions(self, path: str) -> t.Dict[str, Chainer]:
-        """Filter chains to those whose data is available in the h5 file."""
+    def filter_positions(self, path: str) -> t.Dict[str, Signal]:
+        """Filter positions to those whose data is available in the h5 file."""
         good_positions = {
-            k: v for k, v in self.chainers.items() if path in [*v.available]
+            k: v for k, v in self.positions.items() if path in [*v.available]
         }
-        no_positions_dif = len(self.chainers) - len(good_positions)
+        no_positions_dif = len(self.positions) - len(good_positions)
         if no_positions_dif:
             print(
                 f"Grouper: Warning: {no_positions_dif} positions do not contain"
@@ -148,52 +147,52 @@
         path: str,
         f: t.Callable,
         pool: t.Optional[int] = None,
-        chainers: t.Dict[str, Chainer] = None,
+        positions: t.Dict[str, Signal] = None,
         **kwargs,
     ):
         """
-        Enable different threads for independent chains.
+        Enable different threads for different positions.
 
         Particularly useful when aggregating multiple elements.
         """
-        chainers = chainers or self.chainers
+        positions = positions or self.positions
         if pool:
             with Pool(pool) as p:
                 records = p.map(
                     lambda x: f(
                         path=path,
-                        chainer=x[1],
+                        position=x[1],
                         group=self.positions_groups[x[0]],
-                        position=x[0],
+                        position_name=x[0],
                         **kwargs,
                     ),
-                    chainers.items(),
+                    positions.items(),
                 )
         else:
            records = [
                 f(
                     path=path,
-                    chainer=chainer,
+                    position=position,
                     group=self.positions_groups[name],
-                    position=name,
+                    position_name=name,
                     **kwargs,
                 )
-                for name, chainer in self.chainers.items()
+                for name, position in self.positions.items()
            ]
         return records
 
     @property
     def no_tiles(self):
         """Get total number of tiles per position (h5 file)."""
-        for pos, s in self.chainers.items():
+        for pos, s in self.positions.items():
             with h5py.File(s.filename, "r") as f:
                 print(pos, f["/trap_info/trap_locations"].shape[0])
 
     @property
-    def tilelocs(self) -> t.Dict[str, np.ndarray]:
+    def tile_locs(self) -> t.Dict[str, np.ndarray]:
         """Get the locations of the tiles for each position as a dictionary."""
         d = {}
-        for pos, s in self.chainers.items():
+        for pos, s in self.positions.items():
             with h5py.File(s.filename, "r") as f:
                 d[pos] = f["/trap_info/trap_locations"][()]
         return d
@@ -223,16 +222,11 @@
         return set(
             [
                 channel
-                for position in self.chainers.values()
+                for position in self.positions.values()
                 for channel in position.channels
             ]
         )
 
-    @property
-    def tinterval(self):
-        """Get interval between time points in seconds."""
-        return self.first_signal.tinterval
-
     @property
     def no_members(self) -> t.Dict[str, int]:
         """Get the number of positions belonging to each group."""
@@ -242,7 +236,7 @@
     def no_tiles_by_group(self) -> t.Dict[str, int]:
         """Get total number of tiles per group."""
         no_tiles = {}
-        for pos, s in self.chainers.items():
+        for pos, s in self.positions.items():
             with h5py.File(s.filename, "r") as f:
                 no_tiles[pos] = f["/trap_info/trap_locations"].shape[0]
         no_tiles_by_group = {k: 0 for k in self.groups}
@@ -261,21 +255,12 @@
         return tuple(sorted(set(self.positions_groups.keys())))
 
 
-class NameGrouper(Grouper):
-    """Group a set of positions with a shorter version of the group's name."""
-
-    def __init__(self, dir, name_inds=(0, -4)):
-        """Define the indices to slice names."""
-        super().__init__(dir=dir)
-        self.name_inds = name_inds
-
-
 def concat_one_signal(
     path: str,
-    chainer: Chainer,
+    position: Signal,
     group: str,
     mode: str = "retained",
-    position=None,
+    position_name=None,
     **kwargs,
 ) -> pd.DataFrame:
     """
@@ -283,27 +268,26 @@
 
     Applies filtering if requested and adjusts indices.
     """
-    if position is None:
+    if position_name is None:
         # name of h5 file
-        position = chainer.stem
+        position_name = position.stem
     if mode == "retained":
-        combined = chainer.retained(path, **kwargs)
+        combined = position.retained(path, **kwargs)
     elif mode == "raw":
-        combined = chainer.get_raw(path, **kwargs)
+        combined = position.get_raw(path, **kwargs)
     elif mode == "daughters":
-        combined = chainer.get_raw(path, **kwargs)
+        combined = position.get_raw(path, **kwargs)
         combined = combined.loc[
             combined.index.get_level_values("mother_label") > 0
         ]
     elif mode == "families":
-        combined = chainer[path]
+        combined = position[path]
     else:
-        raise Exception(f"{mode} not recognised.")
+        raise Exception(f"concat_one_signal: {mode} not recognised.")
     if combined is not None:
         # adjust indices
-        combined["position"] = position
+        combined["position"] = position_name
         combined["group"] = group
         combined.set_index(["group", "position"], inplace=True, append=True)
         combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1)
-    # should there be an error message if None is returned?
     return combined
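
For reviewers, a minimal usage sketch of the renamed API. Grouper, available, tinterval, and concat_signal are taken from the diff above; the experiment/ directory and the signal path are hypothetical placeholders, and instantiating Grouper directly assumes it is left with no abstract methods after NameGrouper's removal. Not part of the patch:

    from postprocessor.grouper import Grouper

    # "experiment/" and the signal path below are hypothetical placeholders
    grouper = Grouper("experiment/")  # loads one Signal per position h5 file
    print(grouper.available)   # signals present across the h5 files
    print(grouper.tinterval)   # shared sampling interval (minutes, given the /60)
    # concatenate one signal across positions using four workers; rows gain
    # "group" and "position" index levels (see concat_one_signal in the diff)
    df = grouper.concat_signal("extraction/general/None/volume", pool=4)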
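A design note on pool_function: it maps a lambda over the positions, which the standard library's multiprocessing cannot pickle; pathos is used because it serialises with dill and so can ship that lambda to worker processes. A standalone sketch of the behaviour, assuming only that pathos is installed:

    from pathos.multiprocessing import Pool

    # the stdlib multiprocessing.Pool would raise a PicklingError on this
    # lambda; pathos serialises it with dill and runs it in worker processes
    with Pool(2) as p:
        squares = p.map(lambda x: x * x, range(5))
    print(squares)  # [0, 1, 4, 9, 16]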
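The index adjustment at the end of concat_one_signal is easiest to see on a toy frame. A minimal sketch; the two-level ("trap", "cell_label") source index is an assumption about Signal's output, not something fixed by this patch:

    import pandas as pd

    # toy frame standing in for one position's signal
    toy = pd.DataFrame(
        {0: [1.0, 2.0], 1: [3.0, 4.0]},
        index=pd.MultiIndex.from_tuples(
            [(1, 1), (1, 2)], names=("trap", "cell_label")
        ),
    )
    toy["position"] = "pos001"
    toy["group"] = "glucose"
    # appending makes the levels (trap, cell_label, group, position) ...
    toy.set_index(["group", "position"], inplace=True, append=True)
    # ... and the two swaplevel calls bring the new levels to the front
    toy.index = toy.index.swaplevel(-2, 0).swaplevel(-1, 1)
    print(toy.index.names)  # ['group', 'position', 'trap', 'cell_label']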