From 9c3bee4232c1c1751908fef130d23932544aa18e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Fri, 7 Oct 2022 18:37:34 +0100 Subject: [PATCH] refactor(grouper): to use chainer --- src/postprocessor/grouper.py | 207 ++++++++++++++++++++--------------- 1 file changed, 118 insertions(+), 89 deletions(-) diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py index e8a6f3ce..daaec4aa 100644 --- a/src/postprocessor/grouper.py +++ b/src/postprocessor/grouper.py @@ -13,7 +13,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns -from agora.io.signal import Signal +from postprocessor.chainer import Chainer from pathos.multiprocessing import Pool @@ -28,27 +28,27 @@ class Grouper(ABC): assert path.exists(), "Dir does not exist" self.files = list(path.glob("*.h5")) assert len(self.files), "No valid h5 files in dir" - self.load_signals() + self.load_chains() - def load_signals(self) -> None: - # Sets self.signals - self.signals = {f.name[:-3]: Signal(f) for f in self.files} + def load_chains(self) -> None: + # Sets self.chainers + self.chainers = {f.name[:-3]: Chainer(f) for f in self.files} @property - def fsignal(self) -> Signal: + def fsignal(self) -> Chainer: # Returns first signal - return list(self.signals.values())[0] + return list(self.chainers.values())[0] @property def ntimepoints(self) -> int: - return max([s.ntimepoints for s in self.signals.values()]) + return max([s.ntimepoints for s in self.chainers.values()]) @property def tintervals(self) -> float: - tintervals = set([s.tinterval / 60 for s in self.signals.values()]) + tintervals = set([s.tinterval / 60 for s in self.chainers.values()]) assert ( len(tintervals) == 1 - ), "Not all signals have the same time interval" + ), "Not all chains have the same time interval" return max(tintervals) @@ -60,7 +60,7 @@ class Grouper(ABC): def available_grouped(self) -> None: if not hasattr(self, "_available_grouped"): self._available_grouped = Counter( - [x for s in self.signals.values() for x in s.available] + [x for s in self.chainers.values() for x in s.available] ) for s, n in self._available_grouped.items(): @@ -68,7 +68,7 @@ class Grouper(ABC): @property def datasets(self) -> None: - """Print available datasets in first Signal instance.""" + """Print available datasets in first Chainer instance.""" return self.fsignal.datasets @abstractproperty @@ -78,13 +78,12 @@ class Grouper(ABC): def concat_signal( self, path: str, - reduce_cols: t.Optional[bool] = None, - axis: int = 0, - mode: str = "retained", pool: t.Optional[int] = None, + mode: str = "retained", + standard: t.Optional[bool] = False, **kwargs, ): - """Concatenate a single signal. + """Concate Parameters ---------- @@ -106,36 +105,58 @@ class Grouper(ABC): if path.startswith("/"): path = path.strip("/") - # Check the path is in a given signal - sitems = {k: v for k, v in self.signals.items() if path in v.available} - nsignals_dif = len(self.signals) - len(sitems) - if nsignals_dif: - print( - f"Grouper:Warning: {nsignals_dif} signals do not contain" - f" channel {path}" - ) + sitems = self.filter_path(path) + if standard: + fn_pos = concat_standard + else: + fn_pos = concat_signal_ind + kwargs["mode"] = mode - signals = self.pool_function( + kymographs = self.pool_function( path=path, - f=concat_signal_ind, - mode=mode, + f=fn_pos, pool=pool, - signals=sitems, + chainers=sitems, **kwargs, ) - errors = [k for s, k in zip(signals, self.signals.keys()) if s is None] - signals = [s for s in signals if s is not None] + errors = [ + k + for kymo, k in zip(kymographs, self.chainers.keys()) + if kymo is None + ] + kymographs = [kymo for kymo in kymographs if kymo is not None] if len(errors): print("Warning: Positions contain errors {errors}") - assert len(signals), "All datasets contain errors" - sorted = pd.concat(signals, axis=axis).sort_index() - if reduce_cols: - sorted = sorted.apply(np.nanmean, axis=1) - spath = path.split("/") - sorted.name = "_".join([spath[1], spath[-1]]) - return sorted + assert len(kymographs), "All datasets contain errors" + + concat_sorted = ( + pd.concat(kymographs, axis=0) + .reorder_levels( + ("group", "position", "trap", "cell_label", "mother_label") + ) + .sort_index() + ) + return concat_sorted + + def filter_path(self, path: str) -> t.Dict[str, Chainer]: + # Check the path is in a given signal + sitems = { + k: v + for k, v in self.chainers.items() + if path in [*v.available, *v.common_chains] + } + nchains_dif = len(self.chainers) - len(sitems) + if nchains_dif: + print( + f"Grouper:Warning: {nchains_dif} chains do not contain" + f" channel {path}" + ) + + assert len(sitems), "No valid dataset to use" + + return sitems @property def nmembers(self) -> t.Dict[str, int]: @@ -144,7 +165,7 @@ class Grouper(ABC): @property def ntraps(self): - for pos, s in self.signals.items(): + for pos, s in self.chainers.items(): with h5py.File(s.filename, "r") as f: print(pos, f["/trap_info/trap_locations"].shape[0]) @@ -152,7 +173,7 @@ class Grouper(ABC): def ntraps_by_pos(self) -> t.Dict[str, int]: # Return total number of traps grouped ntraps = {} - for pos, s in self.signals.items(): + for pos, s in self.chainers.items(): with h5py.File(s.filename, "r") as f: ntraps[pos] = f["/trap_info/trap_locations"].shape[0] @@ -164,7 +185,7 @@ class Grouper(ABC): def traplocs(self): d = {} - for pos, s in self.signals.items(): + for pos, s in self.chainers.items(): with h5py.File(s.filename, "r") as f: d[pos] = f["/trap_info/trap_locations"][()] return d @@ -201,15 +222,16 @@ class Grouper(ABC): path: str, f: t.Callable, pool: t.Optional[int] = None, - signals: t.Dict[str, Signal] = None, + chainers: t.Dict[str, Chainer] = None, **kwargs, ): """ - Wrapper to add support for threading to process independent signals. + Wrapper to add support for threading to process independent chains. Particularly useful when aggregating multiple elements. """ - pool = pool or 8 - signals = signals or self.signals + if pool is None: + pool = 8 + chainers = chainers or self.chainers if pool: @@ -217,25 +239,37 @@ class Grouper(ABC): kymographs = p.map( lambda x: f( path=path, - signal=x[1], + chainer=x[1], group=self.positions_groups[x[0]], + position=x[0], **kwargs, ), - signals.items(), + chainers.items(), ) else: kymographs = [ f( path=path, - signal=signal, + chainer=chainer, group=self.positions_groups[name], + position=name, **kwargs, ) - for name, signal in self.signals.items() + for name, chainer in self.chainers.items() ] return kymographs + @property + def channels(self): + return set( + [ + channel + for chainer in self.chainers.values() + for channel in chainer.channels + ] + ) + @property def stages_span(self): return self.fsignal.stages_span @@ -273,33 +307,13 @@ class NameGrouper(Grouper): def positions_groups(self) -> t.Dict[str, str]: if not hasattr(self, "_positions_groups"): self._positions_groups = {} - for name in self.signals.keys(): + for name in self.chainers.keys(): self._positions_groups[name] = name[ self.criteria[0] : self.criteria[1] ] return self._positions_groups - # def aggregate_multisignals(self, paths=None, **kwargs): - # aggregated = pd.concat( - # [ - # self.concat_signal(path, reduce_cols=np.nanmean, **kwargs) - # for path in paths - # ], - # axis=1, - # ) - # # ph = pd.Series( - # # [ - # # self.ph_from_group(x[list(aggregated.index.names).index("group")]) - # # for x in aggregated.index - # # ], - # # index=aggregated.index, - # # name="media_pH", - # # ) - # # self.aggregated = pd.concat((aggregated, ph), axis=1) - - # return aggregated - class phGrouper(NameGrouper): """Grouper for pH calibration experiments where all surveyed media pH @@ -318,8 +332,8 @@ class phGrouper(NameGrouper): return float(group_name.replace("_", ".")) - def aggregate_multisignals(self, paths: list) -> pd.DataFrame: - """Accumulate multiple signals.""" + def aggregate_multichains(self, paths: list) -> pd.DataFrame: + """Accumulate multiple chains.""" aggregated = pd.concat( [ @@ -343,9 +357,26 @@ class phGrouper(NameGrouper): return aggregated +def concat_standard( + path: str, + chainer: Chainer, + group: str, + position: t.Optional[str] = None, + **kwargs, +) -> pd.DataFrame: + + combined = chainer.get(path, **kwargs).copy() + combined["position"] = position + combined["group"] = group + combined.set_index(["group", "position"], inplace=True, append=True) + combined.index = combined.index.copy().swaplevel(-2, 0).swaplevel(-1, 1) + + return combined + + def concat_signal_ind( path: str, - signal: Signal, + chainer: Chainer, group: str, mode: str = "retained", position=None, @@ -354,23 +385,21 @@ def concat_signal_ind( """Core function that handles retrieval of an individual signal, applies filtering if requested and adjusts indices.""" if position is None: - position = signal.stem + position = chainer.stem if mode == "retained": - combined = signal.retained(path, **kwargs) + combined = chainer.retained(path, **kwargs) if mode == "mothers": raise (NotImplementedError) elif mode == "raw": - combined = signal.get_raw(path, **kwargs) + combined = chainer.get_raw(path, **kwargs) elif mode == "families": - combined = signal[path] + combined = chainer[path] combined["position"] = position combined["group"] = group combined.set_index(["group", "position"], inplace=True, append=True) combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) return combined - # except: - # return None class MultiGrouper: @@ -385,7 +414,7 @@ class MultiGrouper: self.exp_dirs = [Path(x) for x in source] self.groupers = [NameGrouper(d) for d in self.exp_dirs] for group in self.groupers: - group.load_signals() + group.load_chains() @property def available(self) -> None: @@ -406,7 +435,7 @@ class MultiGrouper: if not hasattr(self, "_sigtable"): raw_mat = [ - [s.available for s in gpr.signals.values()] + [s.available for s in gpr.chains.values()] for gpr in self.groupers ] available_grouped = [ @@ -435,7 +464,7 @@ class MultiGrouper: return self._sigtable def sigtable_plot(self) -> None: - """Plot number of signals for all available experiments. + """Plot number of chains for all available experiments. Examples -------- @@ -452,17 +481,17 @@ class MultiGrouper: def aggregate_signal( self, - signals: Union[str, list], + path: Union[str, list], **kwargs, ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: - """Aggregate signals from multiple Groupers (and thus experiments) + """Aggregate chains from multiple Groupers (and thus experiments) Parameters ---------- - signals : Union[str, list] + chains : Union[str, list] string or list of strings indicating the signal(s) to fetch. **kwargs : keyword arguments to pass to Grouper.concat_signal - Customise the filters and format to fetch signals. + Customise the filters and format to fetch chains. Returns ------- @@ -473,11 +502,11 @@ class MultiGrouper: -------- FIXME: Add docs. """ - if isinstance(signals, str): - signals = [signals] + if isinstance(path, str): + path = [path] - sigs = {s: [] for s in signals} - for s in signals: + sigs = {s: [] for s in path} + for s in path: for grp in self.groupers: try: sigset = grp.concat_signal(s, **kwargs) -- GitLab