diff --git a/src/agora/io/bridge.py b/src/agora/io/bridge.py index 42d9b6c5f51183a4ec38eebaf6cc5b3cbca20b6a..3f44541fc11b33a695d11e9e73793ccbc86e959f 100644 --- a/src/agora/io/bridge.py +++ b/src/agora/io/bridge.py @@ -23,20 +23,19 @@ class BridgeH5: """Initialise with the name of the h5 file.""" self.filename = filename if flag is not None: - self._hdf = h5py.File(filename, flag) - self._filecheck + self.hdf = h5py.File(filename, flag) + assert ( + "cell_info" in self.hdf + ), "Invalid file. No 'cell_info' found." def _log(self, message: str, level: str = "warn"): # Log messages in the corresponding level logger = logging.getLogger("aliby") getattr(logger, level)(f"{self.__class__.__name__}: {message}") - def _filecheck(self): - assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found." - def close(self): """Close the h5 file.""" - self._hdf.close() + self.hdf.close() @property def meta_h5(self) -> t.Dict[str, t.Any]: @@ -83,7 +82,7 @@ class BridgeH5: def get_npairs_over_time(self, nstepsback=2): tree = self.cell_tree npairs = [] - for tp in self._hdf["cell_info"]["processed_timepoints"][()]: + for tp in self.hdf["cell_info"]["processed_timepoints"][()]: tmp_tree = { k: {k2: v2 for k2, v2 in v.items() if k2 <= tp} for k, v in tree.items() @@ -115,7 +114,7 @@ class BridgeH5: ---------- Nested dictionary where keys (or branches) are the upper levels and the leaves are the last element of :fields:. """ - zipped_info = (*zip(*[self._hdf["cell_info"][f][()] for f in fields]),) + zipped_info = (*zip(*[self.hdf["cell_info"][f][()] for f in fields]),) return recursive_groupsort(zipped_info) diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py index 9a22c76431c0a7b42b31fc75ba040de24aaa2264..da6cbd67483be37f2d676f45cdc189e25228d8d6 100644 --- a/src/postprocessor/grouper.py +++ b/src/postprocessor/grouper.py @@ -1,42 +1,44 @@ #!/usr/bin/env python3 -import re import typing as t -from abc import ABC, abstractproperty +from abc import ABC from collections import Counter from functools import cached_property as property from pathlib import Path -from typing import Dict, List, Union +from typing import Union import h5py -import matplotlib.pyplot as plt import numpy as np import pandas as pd -import seaborn as sns from pathos.multiprocessing import Pool +from agora.io.signal import Signal from postprocessor.chainer import Chainer class Grouper(ABC): """Base grouper class.""" - def __init__(self, dir: Union[str, Path]): + def __init__(self, dir: Union[str, Path], name_inds=(0, -4)): """Find h5 files and load a chain for each one.""" path = Path(dir) assert path.exists(), f"{str(dir)} does not exist" self.name = path.name self.files = list(path.glob("*.h5")) assert len(self.files), "No valid h5 files in dir" - self.load_chains() + self.load_positions() + self.positions_groups = { + name: name[name_inds[0] : name_inds[1]] + for name in self.chainers.keys() + } - def load_chains(self) -> None: + def load_positions(self) -> None: """Load a chain for each position, or h5 file.""" - self.chainers = {f.name[:-3]: Chainer(f) for f in self.files} + self.chainers = {f.name[:-3]: Signal(f) for f in self.files} @property - def fsignal(self) -> Chainer: - """Get first chain.""" + def first_signal(self) -> Signal: + """Get Signal for the first position.""" return list(self.chainers.values())[0] @property @@ -45,7 +47,7 @@ class Grouper(ABC): return max([s.ntimepoints for s in self.chainers.values()]) @property - def tintervals(self) -> float: + def max_tinterval(self) -> float: """Find the maximum time interval for all chains.""" tintervals = set([s.tinterval / 60 for s in self.chainers.values()]) assert ( @@ -55,39 +57,30 @@ class Grouper(ABC): @property def available(self) -> t.Collection[str]: - """Generate list of available signals in the first chain.""" - return self.fsignal.available + """Generate list of available signals from the first position.""" + return self.first_signal.available @property def available_grouped(self) -> None: - """Display available signals and the number of chains for each.""" - if not hasattr(self, "_available_grouped"): + """Display available signals and the number of positions with these signals.""" + if not hasattr(self, "available_grouped"): self._available_grouped = Counter( [x for s in self.chainers.values() for x in s.available] ) for s, n in self._available_grouped.items(): print(f"{s} - {n}") - @property - def datasets(self) -> None: - """Print available data sets in the first chain.""" - return self.fsignal.datasets - - @abstractproperty - def positions_groups(self): - pass - def concat_signal( self, path: str, pool: t.Optional[int] = None, mode: str = "retained", - standard: t.Optional[bool] = False, **kwargs, ): """ - Concatenate data for one signal from different h5 files, one for - each position, into a dataframe. + Concatenate data for one signal from different h5 files into a data frame. + + Each h5 files corresponds to a different position. Parameters ---------- @@ -106,18 +99,15 @@ class Grouper(ABC): """ if path.startswith("/"): path = path.strip("/") - good_chains = self.filter_chains(path) - if good_chains: - if standard: - fn_pos = concat_standard - else: - fn_pos = concat_one_signal - kwargs["mode"] = mode + good_positions = self.filter_positions(path) + if good_positions: + fn_pos = concat_one_signal + kwargs["mode"] = mode records = self.pool_function( path=path, f=fn_pos, pool=pool, - chainers=good_chains, + chainers=good_positions, **kwargs, ) # check for errors @@ -140,20 +130,18 @@ class Grouper(ABC): concat_sorted = concat.sort_index() return concat_sorted - def filter_chains(self, path: str) -> t.Dict[str, Chainer]: + def filter_positions(self, path: str) -> t.Dict[str, Chainer]: """Filter chains to those whose data is available in the h5 file.""" - good_chains = { - k: v - for k, v in self.chainers.items() - if path in [*v.available, *v.common_chains] + good_positions = { + k: v for k, v in self.chainers.items() if path in [*v.available] } - nchains_dif = len(self.chainers) - len(good_chains) - if nchains_dif: + no_positions_dif = len(self.chainers) - len(good_positions) + if no_positions_dif: print( - f"Grouper:Warning: {nchains_dif} chains do not contain" - f" channel {path}" + f"Grouper: Warning: {no_positions_dif} positions do not contain" + f" {path}." ) - return good_chains + return good_positions def pool_function( self, @@ -163,8 +151,11 @@ class Grouper(ABC): chainers: t.Dict[str, Chainer] = None, **kwargs, ): - """Enable different threads for independent chains, particularly - useful when aggregating multiple elements.""" + """ + Enable different threads for independent chains. + + Particularly useful when aggregating multiple elements. + """ chainers = chainers or self.chainers if pool: with Pool(pool) as p: @@ -192,29 +183,12 @@ class Grouper(ABC): return records @property - def nmembers(self) -> t.Dict[str, int]: - """Get the number of positions belonging to each group.""" - return Counter(self.positions_groups.values()) - - @property - def ntiles(self): + def no_tiles(self): """Get total number of tiles per position (h5 file).""" for pos, s in self.chainers.items(): with h5py.File(s.filename, "r") as f: print(pos, f["/trap_info/trap_locations"].shape[0]) - @property - def ntiles_by_group(self) -> t.Dict[str, int]: - """Get total number of tiles per group.""" - ntiles = {} - for pos, s in self.chainers.items(): - with h5py.File(s.filename, "r") as f: - ntiles[pos] = f["/trap_info/trap_locations"].shape[0] - ntiles_by_group = {k: 0 for k in self.groups} - for posname, vals in ntiles.items(): - ntiles_by_group[self.positions_groups[posname]] += vals - return ntiles_by_group - @property def tilelocs(self) -> t.Dict[str, np.ndarray]: """Get the locations of the tiles for each position as a dictionary.""" @@ -224,17 +198,7 @@ class Grouper(ABC): d[pos] = f["/trap_info/trap_locations"][()] return d - @property - def groups(self) -> t.Tuple[str]: - """Get groups, sorted alphabetically.""" - return tuple(sorted(set(self.positions_groups.values()))) - - @property - def positions(self) -> t.Tuple[str]: - """Get positions, sorted alphabetically.""" - return tuple(sorted(set(self.positions_groups.keys()))) - - def ncells( + def no_cells( self, path="extraction/general/None/area", mode="retained", @@ -249,46 +213,52 @@ class Grouper(ABC): ) @property - def nretained(self) -> t.Dict[str, int]: + def no_retained(self) -> t.Dict[str, int]: """Get number of cells retained per position in base channel as a dictionary.""" - return self.ncells() + return self.no_cells() @property def channels(self): - """Get unique channels for all chains as a set.""" + """Get channels available over all positions as a set.""" return set( [ channel - for chainer in self.chainers.values() - for channel in chainer.channels + for position in self.chainers.values() + for channel in position.channels ] ) @property - def stages_span(self): - # TODO: fails on my example - return self.fsignal.stages_span + def tinterval(self): + """Get interval between time points in seconds.""" + return self.first_signal.tinterval @property - def max_span(self): - # TODO: fails on my example - return self.fsignal.max_span + def no_members(self) -> t.Dict[str, int]: + """Get the number of positions belonging to each group.""" + return Counter(self.positions_groups.values()) @property - def stages(self): - # TODO: fails on my example - return self.fsignal.stages + def no_tiles_by_group(self) -> t.Dict[str, int]: + """Get total number of tiles per group.""" + no_tiles = {} + for pos, s in self.chainers.items(): + with h5py.File(s.filename, "r") as f: + no_tiles[pos] = f["/trap_info/trap_locations"].shape[0] + no_tiles_by_group = {k: 0 for k in self.groups} + for posname, vals in no_tiles.items(): + no_tiles_by_group[self.positions_groups[posname]] += vals + return no_tiles_by_group @property - def tinterval(self): - """Get interval between time points.""" - return self.fsignal.tinterval - - -class MetaGrouper(Grouper): - """Group positions using metadata's 'group' number.""" + def groups(self) -> t.Tuple[str]: + """Get groups, sorted alphabetically, as a tuple.""" + return tuple(sorted(set(self.positions_groups.values()))) - pass + @property + def positions(self) -> t.Tuple[str]: + """Get positions, sorted alphabetically, as a tuple.""" + return tuple(sorted(set(self.positions_groups.keys()))) class NameGrouper(Grouper): @@ -299,75 +269,6 @@ class NameGrouper(Grouper): super().__init__(dir=dir) self.name_inds = name_inds - @property - def positions_groups(self) -> t.Dict[str, str]: - """Get a dictionary with the positions as keys and groups as items.""" - if not hasattr(self, "_positions_groups"): - self._positions_groups = {} - for name in self.chainers.keys(): - self._positions_groups[name] = name[ - self.name_inds[0] : self.name_inds[1] - ] - return self._positions_groups - - -class phGrouper(NameGrouper): - """Grouper for pH calibration experiments where all surveyed media pH values are within a single experiment.""" - - def __init__(self, dir, name_inds=(3, 7)): - """Initialise via NameGrouper.""" - super().__init__(dir=dir, name_inds=name_inds) - - def get_ph(self) -> None: - """Find the pH from the group names and store as a dictionary.""" - self.ph = {gn: self.ph_from_group(gn) for gn in self.positions_groups} - - @staticmethod - def ph_from_group(group_name: str) -> float: - """Find the pH from the name of a group.""" - if group_name.startswith("ph_") or group_name.startswith("pH_"): - group_name = group_name[3:] - return float(group_name.replace("_", ".")) - - def aggregate_multichains(self, signals: list) -> pd.DataFrame: - """Get data from a list of signals and combine into one multi-index dataframe with 'media-pH' included.""" - aggregated = pd.concat( - [ - self.concat_signal(signal, reduce_cols=np.nanmean) - for signal in signals - ], - axis=1, - ) - ph = pd.Series( - [ - self.ph_from_group( - x[list(aggregated.index.names).index("group")] - ) - for x in aggregated.index - ], - index=aggregated.index, - name="media_pH", - ) - aggregated = pd.concat((aggregated, ph), axis=1) - return aggregated - - -def concat_standard( - path: str, - chainer: Chainer, - group: str, - position: t.Optional[str] = None, - **kwargs, -) -> pd.DataFrame: - combined = chainer.get(path, **kwargs).copy() - combined["position"] = position - combined["group"] = group - combined.set_index(["group", "position"], inplace=True, append=True) - combined.index = combined.index.reorder_levels( - ("group", "position", "trap", "cell_label", "mother_label") - ) - return combined - def concat_one_signal( path: str, @@ -406,135 +307,3 @@ def concat_one_signal( combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) # should there be an error message if None is returned? return combined - - -class MultiGrouper: - """Wrap results from multiple experiments stored as folders inside a - folder.""" - - def __init__(self, source: Union[str, list]): - """ - Create NameGroupers for each experiment. - - Parameters - ---------- - source: list of str - List of folders, one per experiment, containing h5 files. - """ - if isinstance(source, str): - source = Path(source) - self.exp_dirs = list(source.glob("*")) - else: - self.exp_dirs = [Path(x) for x in source] - self.groupers = [NameGrouper(d) for d in self.exp_dirs] - for group in self.groupers: - group.load_chains() - - @property - def available(self) -> None: - """Print available signals and number of chains, one per position, for each Grouper.""" - for gpr in self.groupers: - print(gpr.available_grouped) - - @property - def sigtable(self) -> pd.DataFrame: - """Generate a table showing the number of positions, or h5 files, available for each signal with one column per experiment.""" - - def regex_cleanup(x): - x = re.sub(r"extraction\/", "", x) - x = re.sub(r"postprocessing\/", "", x) - x = re.sub(r"\/max", "", x) - return x - - if not hasattr(self, "_sigtable"): - raw_mat = [ - [s.available for s in gpr.chainers.values()] - for gpr in self.groupers - ] - available_grouped = [ - Counter([x for y in grp for x in y]) for grp in raw_mat - ] - nexps = len(available_grouped) - sigs_idx = list( - set([y for x in available_grouped for y in x.keys()]) - ) - sigs_idx = [regex_cleanup(x) for x in sigs_idx] - nsigs = len(sigs_idx) - sig_matrix = np.zeros((nsigs, nexps)) - for i, c in enumerate(available_grouped): - for k, v in c.items(): - sig_matrix[sigs_idx.index(regex_cleanup(k)), i] = v - sig_matrix[sig_matrix == 0] = np.nan - self._sigtable = pd.DataFrame( - sig_matrix, - index=sigs_idx, - columns=[x.name for x in self.exp_dirs], - ) - return self._sigtable - - def _sigtable_plot(self) -> None: - """ - Plot number of chains for all available experiments. - - Examples - -------- - FIXME: Add docs. - """ - ax = sns.heatmap(self.sigtable, cmap="viridis") - ax.set_xticklabels( - ax.get_xticklabels(), - rotation=10, - ha="right", - rotation_mode="anchor", - ) - plt.show() - - def aggregate_signal( - self, - path: Union[str, list], - **kwargs, - ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: - """ - Aggregate chains, one per position, from multiple Groupers, one per experiment. - - Parameters - ---------- - path : Union[str, list] - String or list of strings indicating the signal(s) to fetch. - **kwargs : - Passed to Grouper.concat_signal. - - Returns - ------- - concatenated: Union[pd.DataFrame, Dict[str, pd.DataFrame]] - A multi-index dataFrame or a dictionary of multi-index dataframes, one per signal - - Examples - -------- - >>> mg = MultiGrouper(["pHCalibrate7_24", "pHCalibrate6_7"]) - >>> p405 = mg.aggregate_signal("extraction/pHluorin405_0_4/max/median") - >>> p588 = mg.aggregate_signal("extraction/pHluorin488_0_4/max/median") - >>> ratio = p405 / p488 - """ - if isinstance(path, str): - path = [path] - sigs = {s: [] for s in path} - for s in path: - for grp in self.groupers: - try: - sigset = grp.concat_signal(s, **kwargs) - new_idx = pd.MultiIndex.from_tuples( - [(grp.name, *x) for x in sigset.index], - names=("experiment", *sigset.index.names), - ) - sigset.index = new_idx - sigs[s].append(sigset) - except Exception as e: - print("Grouper {} failed: {}".format(grp.name, e)) - concatenated = { - name: pd.concat(multiexp_sig) - for name, multiexp_sig in sigs.items() - } - if len(concatenated) == 1: - concatenated = list(concatenated.values())[0] - return concatenated