diff --git a/src/agora/io/decorators.py b/src/agora/io/decorators.py index f0eea7e0e363e8a9d7075b6b3a6e236e08fccf31..f4d8d023cae59c45c736142c63b66ef8672441c8 100644 --- a/src/agora/io/decorators.py +++ b/src/agora/io/decorators.py @@ -1,6 +1,6 @@ #!/usr/bin/env jupyter """ -Convenience decorators that extend commonly-used methods or functions. +Convenience decorators to extend commonly-used methods or functions. """ import typing as t from functools import wraps @@ -9,14 +9,14 @@ from functools import wraps def _first_arg_str_to_df( fn: t.Callable, ): - """Ensures Signal-like classes convert strings to datasets when calling them""" - + """Enable Signal-like classes to convert strings to data sets.""" @wraps(fn) def format_input(*args, **kwargs): cls = args[0] data = args[1] if isinstance(data, str): + # get data from h5 file data = cls.get_raw(data) + # replace path in the undecorated function with data return fn(cls, data, *args[2:], **kwargs) - return format_input diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 0210ef3427552c8d7fa62a9f366976459d112e51..dc951d87565acdadc2bcd58dda964897993009d5 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -10,23 +10,21 @@ import pandas as pd from agora.io.bridge import BridgeH5 from agora.io.decorators import _first_arg_str_to_df -from agora.utils.merge import apply_merges from agora.utils.association import validate_association from agora.utils.kymograph import add_index_levels +from agora.utils.merge import apply_merges class Signal(BridgeH5): """ - Class that fetches data from the hdf5 storage for post-processing + Fetch data from h5 files for post-processing. - Signal is works under the assumption that metadata and data are - accessible, to perform time-adjustments and apply previously-recorded - postprocesses. + Signal assumes that the metadata and data are accessible to perform time-adjustments and apply previously recorded post-processes. """ def __init__(self, file: t.Union[str, PosixPath]): + """Define index_names for dataframes, candidate fluorescence channels, and composite statistics.""" super().__init__(file, flag=None) - self.index_names = ( "experiment", "position", @@ -34,7 +32,6 @@ class Signal(BridgeH5): "cell_label", "mother_label", ) - self.candidate_channels = ( "GFP", "GFPFast", @@ -45,40 +42,44 @@ class Signal(BridgeH5): "Cy5", "pHluorin405", ) - + # Alan: why "equivalences"? this variable is unused. equivalences = { "m5m": ("extraction/GFP/max/max5px", "extraction/GFP/max/median") } def __getitem__(self, dsets: t.Union[str, t.Collection]): - - if isinstance( - dsets, str - ): # or isinstance(Dsets,dsets.endswith("imBackground"): + """Get and potentially pre-process data from h5 file and return as a dataframe.""" + if isinstance(dsets, str): + # no pre-processing df = self.get_raw(dsets) - - # elif isinstance(dsets, str): - # df = self.apply_prepost(dsets) - + return self.add_name(df, dsets) elif isinstance(dsets, list): + # pre-processing is_bgd = [dset.endswith("imBackground") for dset in dsets] + # Alan: what does this error message mean? 
assert sum(is_bgd) == 0 or sum(is_bgd) == len( dsets - ), "Trap data and cell data can't be mixed" + ), "Tile data and cell data can't be mixed" return [ self.add_name(self.apply_prepost(dset), dset) for dset in dsets ] else: raise Exception(f"Invalid type {type(dsets)} to get datasets") - # return self.cols_in_mins(self.add_name(df, dsets)) - return self.add_name(df, dsets) - @staticmethod def add_name(df, name): + """Add column of identical strings to a dataframe.""" df.name = name return df + # def cols_in_mins_old(self, df: pd.DataFrame): + # """Convert numerical columns in a dataframe to minutes.""" + # try: + # df.columns = (df.columns * self.tinterval // 60).astype(int) + # except Exception as e: + # print(f"Warning:Signal: Unable to convert columns to minutes: {e}") + # return df + def cols_in_mins(self, df: pd.DataFrame): # Convert numerical columns in a dataframe to minutes try: @@ -89,57 +90,68 @@ class Signal(BridgeH5): @cached_property def ntimepoints(self): + """Find the number of time points for one position, or one h5 file.""" with h5py.File(self.filename, "r") as f: return f["extraction/general/None/area/timepoint"][-1] + 1 @cached_property def tinterval(self) -> int: + """Find the interval between time points (minutes).""" tinterval_location = "time_settings/timeinterval" with h5py.File(self.filename, "r") as f: - return f.attrs[tinterval_location][0] + if tinterval_location in f: + return f.attrs[tinterval_location][0] + else: + print( + f"{str(self.filename).split('/')[-1]}: using default time interval of 5 minutes" + ) + return 5 @staticmethod def get_retained(df, cutoff): + """Return a fraction of the df, one without later time points.""" return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff] @property - def channels(self): + def channels(self) -> t.Collection[str]: + """Get channels as an array of strings.""" with h5py.File(self.filename, "r") as f: - return f.attrs["channels"] + return list(f.attrs["channels"]) @_first_arg_str_to_df def retained(self, signal, cutoff=0.8): + """ + Load data (via decorator) and reduce the resulting dataframe. - df = signal - # df = self[signal] - if isinstance(df, pd.DataFrame): - return self.get_retained(df, cutoff) - - elif isinstance(df, list): - return [self.get_retained(d, cutoff=cutoff) for d in df] + Load data for a signal or a list of signals and reduce the resulting + dataframes to a fraction of their original size, losing late time + points. + """ + if isinstance(signal, pd.DataFrame): + return self.get_retained(signal, cutoff) + elif isinstance(signal, list): + return [self.get_retained(d, cutoff=cutoff) for d in signal] @lru_cache(2) def lineage( self, lineage_location: t.Optional[str] = None, merged: bool = False ) -> np.ndarray: """ - Return lineage data from a given location as a matrix where - the first column is the trap id, - the second column is the mother label and - the third column is the daughter label. + Get lineage data from a given location in the h5 file. + + Returns an array with three columns: the tile id, the mother label, and the daughter label. 
""" if lineage_location is None: lineage_location = "postprocessing/lineage" if merged: lineage_location += "_merged" - with h5py.File(self.filename, "r") as f: - trap_mo_da = f[lineage_location] + tile_mo_da = f[lineage_location] lineage = np.array( ( - trap_mo_da["trap"], - trap_mo_da["mother_label"], - trap_mo_da["daughter_label"], + tile_mo_da["trap"], + tile_mo_da["mother_label"], + tile_mo_da["daughter_label"], ) ).T return lineage @@ -151,22 +163,20 @@ class Signal(BridgeH5): merges: t.Union[np.ndarray, bool] = True, picks: t.Union[t.Collection, bool] = True, ): - """Apply modifier operations (picker, merger) to a given dataframe. - + """ + Apply modifier operations (picker or merger) to a dataframe. Parameters ---------- data : t.Union[str, pd.DataFrame] - DataFrame or url to one. + DataFrame or path to one. merges : t.Union[np.ndarray, bool] - (optional) 2-D array with three columns and variable length. The - first column is the trap id, second is mother label and third one is - daughter id. - If it is True it fetches merges from file, if false it skips merging step. + (optional) 2-D array with three columns: the tile id, the mother label, and the daughter id. + If True, fetch merges from file. picks : t.Union[np.ndarray, bool] - (optional) 2-D ndarray where first column is traps and second column - is cell labels. - If it is True it fetches picks from file, if false it skips picking step. + (optional) 2-D array with two columns: the tiles and + the cell labels. + If True, fetch picks from file. Examples -------- @@ -175,31 +185,24 @@ class Signal(BridgeH5): """ if isinstance(merges, bool): merges: np.ndarray = self.get_merges() if merges else np.array([]) - - merged = copy(data) - if merges.any(): merged = apply_merges(data, merges) - + else: + merged = copy(data) if isinstance(picks, bool): picks = ( self.get_picks(names=merged.index.names) if picks else set(merged.index) ) - with h5py.File(self.filename, "r") as f: if "modifiers/picks" in f and picks: - # missing_cells = [i for i in picks if tuple(i) not in - # set(merged.index)] - if picks: return merged.loc[ set(picks).intersection( [tuple(x) for x in merged.index] ) ] - else: if isinstance(merged.index, pd.MultiIndex): empty_lvls = [[] for i in merged.index.names] @@ -213,59 +216,75 @@ class Signal(BridgeH5): merged = pd.DataFrame([], index=index) return merged + # Alan: do we need two similar properties - see below? 
@property def datasets(self): + """Print data sets available in h5 file.""" if not hasattr(self, "_available"): self._available = [] - with h5py.File(self.filename, "r") as f: - f.visititems(self.store_signal_url) - + f.visititems(self.store_signal_path) for sig in self._available: print(sig) @cached_property def p_available(self): - """Print signal list""" + """Print data sets available in h5 file.""" self.datasets @cached_property def available(self): - """Return list of available signals""" + """Get data sets available in h5 file.""" try: if not hasattr(self, "_available"): self._available = [] - with h5py.File(self.filename, "r") as f: - f.visititems(self.store_signal_url) - + f.visititems(self.store_signal_path) except Exception as e: self._log("Exception when visiting h5: {}".format(e), "exception") return self._available def get_merged(self, dataset): + """Run preprocessing for merges.""" return self.apply_prepost(dataset, picks=False) @cached_property - def merges(self): + def merges(self) -> np.ndarray: + """Get merges.""" with h5py.File(self.filename, "r") as f: dsets = f.visititems(self._if_merges) return dsets @cached_property def n_merges(self): + """Get number of merges.""" return len(self.merges) @cached_property - def picks(self): + def picks(self) -> np.ndarray: + """Get picks.""" with h5py.File(self.filename, "r") as f: dsets = f.visititems(self._if_picks) return dsets def get_raw( - self, dataset: str, in_minutes: bool = True, lineage: bool = False - ): + self, + dataset: str, + in_minutes: bool = True, + lineage: bool = False, + ) -> pd.DataFrame: + """ + Load data from a h5 file and return as a dataframe. + + Parameters + ---------- + dataset: str or list of strs + The name of the h5 file or a list of h5 file names + in_minutes: boolean + If True, + lineage: boolean + """ try: if isinstance(dataset, str): with h5py.File(self.filename, "r") as f: @@ -273,9 +292,10 @@ class Signal(BridgeH5): if in_minutes: df = self.cols_in_mins(df) elif isinstance(dataset, list): + # Alan: no mother_labels in this case? return [self.get_raw(dset) for dset in dataset] - - if lineage: # This assumes that df is sorted + if lineage: + # assumes that df is sorted mother_label = np.zeros(len(df), dtype=int) lineage = self.lineage() a, b = validate_association( @@ -285,20 +305,17 @@ class Signal(BridgeH5): ) mother_label[b] = lineage[a, 1] df = add_index_levels(df, {"mother_label": mother_label}) - return df - except Exception as e: self._log(f"Could not fetch dataset {dataset}", "error") raise e def get_merges(self): - # fetch merge events going up to the first level + """Get merge events going up to the first level.""" with h5py.File(self.filename, "r") as f: merges = f.get("modifiers/merges", np.array([])) if not isinstance(merges, np.ndarray): merges = merges[()] - return merges def get_picks( @@ -306,58 +323,42 @@ class Signal(BridgeH5): names: t.Tuple[str, ...] = ("trap", "cell_label"), path: str = "modifiers/picks/", ) -> t.Set[t.Tuple[int, str]]: - """ - Return the relevant picks based on names - """ + """Get the relevant picks based on names.""" with h5py.File(self.filename, "r") as f: picks = set() if path in f: picks = set(zip(*[f[path + name] for name in names])) - return picks def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame: - """ - Fetch DataFrame from results storage file. 
- """ - + """Get data from h5 file as a dataframe.""" assert path in f, f"{path} not in {f}" - dset = f[path] - - values, index, columns = ([], [], []) - + values, index, columns = [], [], [] index_names = copy(self.index_names) valid_names = [lbl for lbl in index_names if lbl in dset.keys()] if valid_names: - index = pd.MultiIndex.from_arrays( [dset[lbl] for lbl in valid_names], names=valid_names ) - - columns = dset.attrs.get("columns", None) # dset.attrs["columns"] + columns = dset.attrs.get("columns", None) if "timepoint" in dset: columns = f[path + "/timepoint"][()] - values = f[path + "/values"][()] - - return pd.DataFrame( - values, - index=index, - columns=columns, - ) + df = pd.DataFrame(values, index=index, columns=columns) + return df @property def stem(self): + """Get name of h5 file.""" return self.filename.stem - def store_signal_url( - self, fullname: str, node: t.Union[h5py.Dataset, h5py.Group] + def store_signal_path( + self, + fullname: str, + node: t.Union[h5py.Dataset, h5py.Group], ): - """ - Store the name of a signal it is a leaf node (a group with no more groups inside) - and starts with extraction - """ + """Store the name of a signal if it is a leaf node (a group with no more groups inside) and if it starts with extraction.""" if isinstance(node, h5py.Group) and np.all( [isinstance(x, h5py.Dataset) for x in node.values()] ): @@ -383,19 +384,15 @@ class Signal(BridgeH5): # TODO FUTURE add stages support to fluigent system @property def ntps(self) -> int: - # Return number of time-points according to the metadata + """Get number of time points from the metadata.""" return self.meta_h5["time_settings/ntimepoints"][0] @property def stages(self) -> t.List[str]: - """ - Return the contents of the pump with highest flow rate - at each stage. - """ + """Get the contents of the pump with highest flow rate at each stage.""" flowrate_name = "pumpinit/flowrate" pumprate_name = "pumprate" switchtimes_name = "switchtimes" - main_pump_id = np.concatenate( ( (np.argmax(self.meta_h5[flowrate_name]),), @@ -418,7 +415,6 @@ class Signal(BridgeH5): def switch_times(self) -> t.List[int]: switchtimes_name = "switchtimes" switches_minutes = self.meta_h5[switchtimes_name] - return [ t_min for t_min in switches_minutes @@ -427,7 +423,7 @@ class Signal(BridgeH5): @property def stages_span(self) -> t.Tuple[t.Tuple[str, int], ...]: - # Return consecutive stages and their corresponding number of time-points + """Get consecutive stages and their corresponding number of time points.""" transition_tps = (0, *self.switch_times, self.max_span) spans = [ end - start diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py index e6af6f647f5c2575b9b235086c0b78b11417f324..71411e1dab85442fd986dd9c925a11f50785e3c2 100644 --- a/src/agora/utils/kymograph.py +++ b/src/agora/utils/kymograph.py @@ -23,23 +23,21 @@ def add_index_levels( def drop_level( - df: pd.DataFrame, name: str = "mother_label", as_list: bool = True + df: pd.DataFrame, + name: str = "mother_label", + as_list: bool = True, ) -> t.Union[t.List[index_row], pd.Index]: - """Drop index level + """ + Drop index level. Parameters ---------- df : pd.DataFrame - dataframe whose multiindex we will drop + Dataframe whose multiindex we will drop name : str - name of index level to drop + Mame of index level to drop as_list : bool Whether to return as a list instead of an index - - Examples - -------- - FIXME: Add docs. 
- """ short_index = df.index.droplevel(name) if as_list: @@ -50,22 +48,17 @@ def drop_level( def intersection_matrix( index1: pd.MultiIndex, index2: pd.MultiIndex ) -> np.ndarray: - """ - Use casting to obtain the boolean mask of the intersection of two multiindices - """ + """Use casting to obtain the boolean mask of the intersection of two multi-indices.""" indices = [index1, index2] for i in range(2): if hasattr(indices[i], "to_list"): indices[i]: t.List = indices[i].to_list() indices[i]: np.ndarray = np.array(indices[i]) - return (indices[0][..., None] == indices[1].T).all(axis=1) def get_mother_ilocs_from_daughters(df: pd.DataFrame) -> np.ndarray: - """ - Fetch mother locations in the index of df for all daughters in df. - """ + """Fetch mother locations in the index of df for all daughters in df.""" daughter_ids = df.index[df.index.get_level_values("mother_label") > 0] mother_ilocs = intersection_matrix( daughter_ids.droplevel("cell_label"), @@ -86,36 +79,41 @@ def get_mothers_from_another_df(whole_df: pd.DataFrame, da_df: pd.DataFrame): def bidirectional_retainment_filter( - df: pd.DataFrame, mothers_thresh: float = 0.8, daughters_thresh: int = 7 -): + df: pd.DataFrame, + mothers_thresh: float = 0.8, + daughters_thresh: int = 7, +) -> pd.DataFrame: """ Retrieve families where mothers are present for more than a fraction of the experiment, and daughters for longer than some number of time-points. + + Parameters + ---------- + df: pd.DataFrame + Data + mothers_thresh: float + Minimum fraction of experiment's total duration for which mothers must be present. + daughters_thresh: int + Minimum number of time points for which daughters must be observed """ + # daughters all_daughters = df.loc[df.index.get_level_values("mother_label") > 0] - - # Filter daughters + # keep daughters observed sufficiently often retained_daughters = all_daughters.loc[ all_daughters.notna().sum(axis=1) > daughters_thresh ] - - # Fectch mother using existing daughters + # fetch mother using existing daughters mothers = df.loc[get_mothers_from_another_df(df, retained_daughters)] - - # Get mothers + # keep mothers present for at least a fraction of the experiment's duration retained_mothers = mothers.loc[ mothers.notna().sum(axis=1) > mothers.shape[1] * mothers_thresh ] - - # Filter-out daughters with no valid mothers + # drop daughters with no valid mothers final_da_mask = intersection_matrix( drop_level(retained_daughters, "cell_label", as_list=False), drop_level(retained_mothers, "mother_label", as_list=False), ) - final_daughters = retained_daughters.loc[final_da_mask.any(axis=1)] - - # Join mothers and daughters and sort index - # + # join mothers and daughters and sort index return pd.concat((final_daughters, retained_mothers), axis=0).sort_index() diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py index 36d03cb8c814de84e1fec6cc8a4e667733abf980..eee3185a613eb19af8faa5b88566d76d1f197d53 100644 --- a/src/extraction/core/extractor.py +++ b/src/extraction/core/extractor.py @@ -38,7 +38,7 @@ RED_FUNS = load_redfuns() class ExtractorParameters(ParametersABC): """ - Base class to define parameters for extraction + Base class to define parameters for extraction. """ def __init__( @@ -54,7 +54,7 @@ class ExtractorParameters(ParametersABC): Nested dictionary indicating channels, reduction functions and metrics to be used. str channel -> U(function,None) reduction -> str metric - If not of depth three, tree will be filled with Nones. + If not of depth three, tree will be filled with None. 
sub_bg: set multichannel_ops: dict """ @@ -65,7 +65,7 @@ class ExtractorParameters(ParametersABC): @staticmethod def guess_from_meta(store_name: str, suffix="fast"): """ - Find the microscope used from the h5 metadata + Find the microscope used from the h5 metadata. Parameters ---------- @@ -90,25 +90,28 @@ class ExtractorParameters(ParametersABC): class Extractor(StepABC): """ - The Extractor applies a metric, such as area or median, to cells identified in the image tiles using the cell masks. + Apply a metric to cells identified in the tiles. - Its methods therefore require both tile images and masks. + Using the cell masks, the Extractor applies a metric, such as area or median, to cells identified in the image tiles. - Usually one metric is applied to the masked area in a tile, but there are metrics that depend on the whole tile. + Its methods require both tile images and masks. - Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the second level is the reduction algorithm, such as maximum projection; the last level is the metric - the specific operation to apply to the cells in the image identified by the mask, such as median, which is the median value of the pixels in each cell. + Usually the metric is applied to only a tile's masked area, but some metrics depend on the whole tile. + + Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks is the third level. Parameters ---------- parameters: core.extractor Parameters - Parameters that include with channels, reduction and - extraction functions to use. + Parameters that include the channels, and reduction and + extraction functions. store: str - Path to hdf5 storage file. Must contain cell outlines. + Path to the h5 file, which must contain the cell masks. tiler: pipeline-core.core.segmentation tiler - Class that contains or fetches the image to be used for segmentation. + Class that contains or fetches the images used for segmentation. """ + # Alan: should this data be stored here or all such data in a separate file default_meta = { "pixel_size": 0.236, "z_size": 0.6, @@ -149,7 +152,7 @@ class Extractor(StepABC): store: str, tiler: Tiler, ): - # initate from tiler + """Initiate from a tiler instance.""" return cls(parameters, store=store, tiler=tiler) @classmethod @@ -159,12 +162,12 @@ class Extractor(StepABC): store: str, img_meta: tuple, ): - # initiate from image + """Initiate from images.""" return cls(parameters, store=store, tiler=Tiler(*img_meta)) @property def channels(self): - # returns a tuple of strings of the available channels + """Get a tuple of the available channels.""" if not hasattr(self, "_channels"): if type(self.params.tree) is dict: self._channels = tuple(self.params.tree.keys()) @@ -225,7 +228,7 @@ class Extractor(StepABC): self._all_funs = {**self._custom_funs, **FUNS} def load_meta(self): - # load metadata from h5 file whose name is given by self.local + """Load metadata from h5 file.""" self.meta = load_attributes(self.local) def get_tiles( @@ -236,8 +239,9 @@ class Extractor(StepABC): **kwargs, ) -> t.Optional[np.ndarray]: """ - Finds traps for a given time point and given channels and z-stacks. - Returns None if no traps are found. + Find tiles for a given time point and given channels and z-stacks. + + Returns None if no tiles are found. 
Any additional keyword arguments are passed to tiler.get_tiles_timepoint @@ -377,7 +381,6 @@ class Extractor(StepABC): self.reduce_dims(trap, method=RED_FUNS[red_fun]) for trap in traps ] - d = { red_fun: self.extract_funs( metrics=metrics, @@ -394,6 +397,7 @@ class Extractor(StepABC): ) -> np.ndarray: """ Collapse a z-stack into 2d array using method. + If method is None, return the original data. Parameters @@ -418,7 +422,7 @@ class Extractor(StepABC): **kwargs, ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]: """ - Core extraction method for an individual time-point. + Extract for an individual time-point. Parameters ---------- @@ -561,7 +565,7 @@ class Extractor(StepABC): def get_imgs(self, channel: t.Optional[str], traps, channels=None): """ - Returns the image from a correct source, either raw or bgsub + Return image from a correct source, either raw or bgsub. Parameters ---------- @@ -592,6 +596,8 @@ class Extractor(StepABC): **kwargs, ) -> dict: """ + Wrapper to add compatibility with other steps of the pipeline. + Parameters ---------- tps: list of int (optional) @@ -677,7 +683,7 @@ def flatten_nesteddict( nest: dict, to="series", tp: int = None ) -> t.Dict[str, pd.Series]: """ - Converts a nested extraction dict into a dict of pd.Series + Convert a nested extraction dict into a dict of pd.Series. Parameters ---------- @@ -706,6 +712,7 @@ def flatten_nesteddict( class hollowExtractor(Extractor): """ Extractor that only cares about receiving images and masks. + Used for testing. """ diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py index b9c43b0c1f0cedc63bd16296b58a32e25c45aa5f..a5387ccb1138357e54c46fb3e2cf6abadc5b6424 100644 --- a/src/postprocessor/chainer.py +++ b/src/postprocessor/chainer.py @@ -16,26 +16,33 @@ from postprocessor.core.lineageprocess import LineageProcessParameters class Chainer(Signal): """ - Class that extends signal by applying postprocesess. + Extend Signal by applying post-processes and allowing composite signals that combine basic signals. + Instead of reading processes previously applied, it executes them when called. """ - process_types = ("multisignal", "processes", "reshapers") - common_chains = {} + # these no longer seem to be used + #process_types = ("multisignal", "processes", "reshapers") + #common_chains = {} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for channel in self.candidate_channels: + # find first channel in h5 file that corresponds to a candidate_channel + # but channel is redefined. why is there a loop over candidate channels? + # what about capitals? try: channel = [ ch for ch in self.channels if re.match("channel", ch) ][0] break except: + # is this still a good idea? pass - try: + # what's this? 
+            # composite statistic comprising the quotient of two others
             equivalences = {
                 "m5m": (
                     f"extraction/{channel}/max/max5px",
@@ -43,24 +50,27 @@
                 ),
             }

-        def replace_url(url: str, bgsub: str = ""):
-            # return pattern with bgsub
-            channel = url.split("/")[1]
+        # function to add bgsub to paths
+        def replace_path(path: str, bgsub: str = ""):
+            channel = path.split("/")[1]
             if "bgsub" in bgsub:
-                url = re.sub(channel, f"{channel}_bgsub", url)
-            return url
+                # add bgsub to path
+                path = re.sub(channel, f"{channel}_bgsub", path)
+            return path

+        # for composite statistics
+        # add chain with and without bgsub
         self.common_chains = {
             alias
             + bgsub: lambda **kwargs: self.get(
-                replace_url(denominator, alias + bgsub), **kwargs
+                replace_path(denominator, alias + bgsub), **kwargs
             )
-            / self.get(replace_url(numerator, alias + bgsub), **kwargs)
+            / self.get(replace_path(numerator, alias + bgsub), **kwargs)
             for alias, (denominator, numerator) in equivalences.items()
             for bgsub in ("", "_bgsub")
         }
-
-        except:
+            # Is this still a good idea?
             pass

     def get(
@@ -72,20 +82,22 @@
         retain: t.Optional[float] = None,
         **kwargs,
     ):
-        if dataset in self.common_chains:  # Produce dataset on the fly
+        """Load data from an h5 file."""
+        if dataset in self.common_chains:
+            # produce a composite data set on the fly
             data = self.common_chains[dataset](**kwargs)
         else:
+            # use Signal's get_raw
             data = self.get_raw(dataset, in_minutes=in_minutes)
         if chain:
             data = self.apply_chain(data, chain, **kwargs)
         if retain:
-            data = data.loc[data.notna().sum(axis=1) > data.shape[1] * retain]
-
-        if (
-            stages and "stage" not in data.columns.names
-        ):  # Return stages as additional column level
-
+            # keep only cells present for at least a fraction retain of the time points
+            data = self.get_retained(data, retain)
+        # data = data.loc[data.notna().sum(axis=1) > data.shape[1] * retain]
+        if stages and "stage" not in data.columns.names:
+            # return stages as an additional column level
             stages_index = [
                 x
                 for i, (name, span) in enumerate(self.stages_span_tp)
@@ -95,25 +107,26 @@
                 zip(stages_index, data.columns),
                 names=("stage", "time"),
             )
-
         return data

     def apply_chain(
         self, input_data: pd.DataFrame, chain: t.Tuple[str, ...], **kwargs
     ):
-        """Apply a series of processes to a dataset.
+        """
+        Apply a series of processes to a data set.
+
+        Like postprocessing, Chainer consecutively applies processes.

-        In a similar fashion to how postprocessing works, Chainer allows the
-        consecutive application of processes to a dataset. Parameters can be
-        passed as kwargs. It does not support the same process multiple times
-        with different parameters.
+        Parameters can be passed as kwargs.
+
+        Chainer does not support applying the same process multiple times with different parameters.

         Parameters
         ----------
         input_data : pd.DataFrame
-            Input data to iteratively process.
+            Input data to process.
         chain : t.Tuple[str, ...]
-            Tuple of strings with the name of processes.
+            Tuple of strings with the names of the processes.
         **kwargs : kwargs
             Arguments passed on to Process.as_function() method to modify the
             parameters.
@@ -138,6 +151,5 @@ class Chainer(Signal): raise (NotImplementedError) merges = process.as_function(result, **params) result = self.apply_merges(result, merges) - self._intermediate_steps.append(result) return result diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py index ec04a6f5b64a783884807e783f641dc43622497b..09bff52e5c22b54f5b6346a17a3d7e1a6cafc0af 100644 --- a/src/postprocessor/grouper.py +++ b/src/postprocessor/grouper.py @@ -21,55 +21,56 @@ from postprocessor.chainer import Chainer class Grouper(ABC): """Base grouper class.""" - files = [] - def __init__(self, dir: Union[str, PosixPath]): + """Find h5 files and load a chain for each one.""" path = Path(dir) - self.name = path.name assert path.exists(), "Dir does not exist" + self.name = path.name self.files = list(path.glob("*.h5")) assert len(self.files), "No valid h5 files in dir" self.load_chains() def load_chains(self) -> None: - # Sets self.chainers + """Load a chain for each position, or h5 file.""" self.chainers = {f.name[:-3]: Chainer(f) for f in self.files} @property def fsignal(self) -> Chainer: - # Returns first signal + """Get first chain.""" return list(self.chainers.values())[0] @property def ntimepoints(self) -> int: + """Find number of time points.""" return max([s.ntimepoints for s in self.chainers.values()]) @property def tintervals(self) -> float: + """Find the maximum time interval for all chains.""" tintervals = set([s.tinterval / 60 for s in self.chainers.values()]) assert ( len(tintervals) == 1 ), "Not all chains have the same time interval" - return max(tintervals) @property - def available(self) -> None: + def available(self) -> t.Collection[str]: + """Generate list of available signals in the first chain.""" return self.fsignal.available @property def available_grouped(self) -> None: + """Display available signals and the number of chains for each.""" if not hasattr(self, "_available_grouped"): self._available_grouped = Counter( [x for s in self.chainers.values() for x in s.available] ) - for s, n in self._available_grouped.items(): print(f"{s} - {n}") @property def datasets(self) -> None: - """Print available datasets in first Chainer instance.""" + """Print available data sets in the first chain.""" return self.fsignal.datasets @abstractproperty @@ -84,113 +85,142 @@ class Grouper(ABC): standard: t.Optional[bool] = False, **kwargs, ): - """Concatenate multiple signals - - Parameters - ---------- - path : str - signal address within h5py file. - reduce_cols : bool - Whether or not to collapse columns into a single one. - axis : int - Concatenation axis. - pool : int - Number of threads used. If 0 or None only one core is used. - **kwargs : key, value pairings - Named arguments to pass to concat_ind_function. - - Examples - -------- - FIXME: Add docs. + """ + Concatenate data for one signal from different h5 files, with + one h5 file per position, into a dataframe. 
+ + Parameters + ---------- + path : str + Signal location within h5py file + pool : int + Number of threads used; if 0 or None only one core is used + mode: str + standard: boolean + **kwargs : key, value pairings + Named arguments for concat_ind_function + + Examples + -------- + >>> record = grouper.concat_signal("extraction/GFP/max/median") """ if path.startswith("/"): path = path.strip("/") - - sitems = self.filter_path(path) + good_chains = self.filter_chains(path) if standard: fn_pos = concat_standard else: fn_pos = concat_signal_ind kwargs["mode"] = mode - - kymographs = self.pool_function( + records = self.pool_function( path=path, f=fn_pos, pool=pool, - chainers=sitems, + chainers=good_chains, **kwargs, ) - + # check for errors errors = [ - k - for kymo, k in zip(kymographs, self.chainers.keys()) - if kymo is None + k for kymo, k in zip(records, self.chainers.keys()) if kymo is None ] - kymographs = [kymo for kymo in kymographs if kymo is not None] + records = [record for record in records if record is not None] if len(errors): print("Warning: Positions contain errors {errors}") - - assert len(kymographs), "All datasets contain errors" - - concat = pd.concat(kymographs, axis=0) - - if ( - len(concat.index.names) > 4 - ): # Reorder levels when mother_label is present + assert len(records), "All data sets contain errors" + # combine into one dataframe + concat = pd.concat(records, axis=0) + if len(concat.index.names) > 4: + # reorder levels in the multi-index dataframe when mother_label is present concat = concat.reorder_levels( ("group", "position", "trap", "cell_label", "mother_label") ) - concat_sorted = concat.sort_index() - return concat_sorted - def filter_path(self, path: str) -> t.Dict[str, Chainer]: - # Check the path is in a given signal - sitems = { + def filter_chains(self, path: str) -> t.Dict[str, Chainer]: + """Filter chains to those whose data is available in the h5 file.""" + good_chains = { k: v for k, v in self.chainers.items() if path in [*v.available, *v.common_chains] } - nchains_dif = len(self.chainers) - len(sitems) + nchains_dif = len(self.chainers) - len(good_chains) if nchains_dif: print( f"Grouper:Warning: {nchains_dif} chains do not contain" f" channel {path}" ) - assert len( - sitems + good_chains ), f"No valid dataset to use. 
Valid datasets are {self.available}" + return good_chains - return sitems + def pool_function( + self, + path: str, + f: t.Callable, + pool: t.Optional[int] = None, + chainers: t.Dict[str, Chainer] = None, + **kwargs, + ): + """Enable different threads for independent chains, particularly useful when aggregating multiple elements.""" + if pool is None: + # Alan: why is None changed to 8 + # pool = 8 + pass + chainers = chainers or self.chainers + if pool: + with Pool(pool) as p: + records = p.map( + lambda x: f( + path=path, + chainer=x[1], + group=self.positions_groups[x[0]], + position=x[0], + **kwargs, + ), + chainers.items(), + ) + else: + records = [ + f( + path=path, + chainer=chainer, + group=self.positions_groups[name], + position=name, + **kwargs, + ) + for name, chainer in self.chainers.items() + ] + return records @property def nmembers(self) -> t.Dict[str, int]: - # Return the number of positions belonging to each group + """Get the number of positions belonging to each group.""" return Counter(self.positions_groups.values()) @property - def ntraps(self): + def ntiles(self): + """Get total number of tiles per position (h5 file).""" for pos, s in self.chainers.items(): with h5py.File(s.filename, "r") as f: print(pos, f["/trap_info/trap_locations"].shape[0]) @property - def ntraps_by_pos(self) -> t.Dict[str, int]: - # Return total number of traps grouped - ntraps = {} + def ntiles_by_group(self) -> t.Dict[str, int]: + """Get total number of tiles per group.""" + ntiles = {} for pos, s in self.chainers.items(): with h5py.File(s.filename, "r") as f: - ntraps[pos] = f["/trap_info/trap_locations"].shape[0] - - ntraps_by_pos = {k: 0 for k in self.groups} - for posname, vals in ntraps.items(): - ntraps_by_pos[self.positions_groups[posname]] += vals + ntiles[pos] = f["/trap_info/trap_locations"].shape[0] + ntiles_by_group = {k: 0 for k in self.groups} + for posname, vals in ntiles.items(): + ntiles_by_group[self.positions_groups[posname]] += vals + return ntiles_by_group - return ntraps_by_pos - - def traplocs(self): + @property + def tilelocs(self) -> t.Dict[str, np.ndarray]: + """Get the locations of the tiles for each position as a dictionary.""" d = {} for pos, s in self.chainers.items(): with h5py.File(s.filename, "r") as f: @@ -199,20 +229,21 @@ class Grouper(ABC): @property def groups(self) -> t.Tuple[str]: - # Return groups sorted alphabetically + """Get groups, sorted alphabetically.""" return tuple(sorted(set(self.positions_groups.values()))) @property def positions(self) -> t.Tuple[str]: - # Return positions sorted alphabetically + """Get positions, sorted alphabetically.""" return tuple(sorted(set(self.positions_groups.keys()))) def ncells( - self, path="extraction/general/None/area", mode="retained", **kwargs + self, + path="extraction/general/None/area", + mode="retained", + **kwargs, ) -> t.Dict[str, int]: - """ - Returns number of cells retained per position in base channel - """ + """Get number of cells retained per position in base channel as a dictionary.""" return ( self.concat_signal(path=path, mode=mode, **kwargs) .groupby("group") @@ -222,53 +253,12 @@ class Grouper(ABC): @property def nretained(self) -> t.Dict[str, int]: + """Get number of cells retained per position in base channel as a dictionary.""" return self.ncells() - def pool_function( - self, - path: str, - f: t.Callable, - pool: t.Optional[int] = None, - chainers: t.Dict[str, Chainer] = None, - **kwargs, - ): - """ - Wrapper to add support for threading to process independent chains. 
- Particularly useful when aggregating multiple elements. - """ - if pool is None: - pool = 8 - chainers = chainers or self.chainers - - if pool: - - with Pool(pool) as p: - kymographs = p.map( - lambda x: f( - path=path, - chainer=x[1], - group=self.positions_groups[x[0]], - position=x[0], - **kwargs, - ), - chainers.items(), - ) - else: - kymographs = [ - f( - path=path, - chainer=chainer, - group=self.positions_groups[name], - position=name, - **kwargs, - ) - for name, chainer in self.chainers.items() - ] - - return kymographs - @property def channels(self): + """Get unique channels for all chains as a set.""" return set( [ channel @@ -279,20 +269,24 @@ class Grouper(ABC): @property def stages_span(self): + # FAILS on my example return self.fsignal.stages_span @property def max_span(self): + # FAILS on my example return self.fsignal.max_span - @property - def tinterval(self): - return self.fsignal.tinterval - @property def stages(self): + # FAILS on my example return self.fsignal.stages + @property + def tinterval(self): + """Get interval between time points.""" + return self.fsignal.tinterval + class MetaGrouper(Grouper): """Group positions using metadata's 'group' number.""" @@ -301,51 +295,49 @@ class MetaGrouper(Grouper): class NameGrouper(Grouper): - """Group a set of positions using a subsection of the name.""" + """Group a set of positions with a shorter version of the group's name.""" - def __init__(self, dir, criteria=None): + def __init__(self, dir, name_inds=(0, -4)): + """Define the indices to slice names.""" super().__init__(dir=dir) - - if criteria is None: - criteria = (0, -4) - self.criteria = criteria + self.name_inds = name_inds @property def positions_groups(self) -> t.Dict[str, str]: + """Get a dictionary with the positions as keys and groups as items.""" if not hasattr(self, "_positions_groups"): self._positions_groups = {} for name in self.chainers.keys(): self._positions_groups[name] = name[ - self.criteria[0] : self.criteria[1] + self.name_inds[0] : self.name_inds[1] ] - return self._positions_groups class phGrouper(NameGrouper): - """Grouper for pH calibration experiments where all surveyed media pH - values are within a single experiment.""" + """Grouper for pH calibration experiments where all surveyed media pH values are within a single experiment.""" - def __init__(self, dir, criteria=(3, 7)): - super().__init__(dir=dir, criteria=criteria) + def __init__(self, dir, name_inds=(3, 7)): + """Initialise via NameGrouper.""" + super().__init__(dir=dir, name_inds=name_inds) def get_ph(self) -> None: + """Find the pH from the group names and store as a dictionary.""" self.ph = {gn: self.ph_from_group(gn) for gn in self.positions_groups} @staticmethod - def ph_from_group(group_name: int) -> float: - if group_name.startswith("ph_"): + def ph_from_group(group_name: str) -> float: + """Find the pH from the name of a group.""" + if group_name.startswith("ph_") or group_name.startswith("pH_"): group_name = group_name[3:] - return float(group_name.replace("_", ".")) - def aggregate_multichains(self, paths: list) -> pd.DataFrame: - """Accumulate multiple chains.""" - + def aggregate_multichains(self, signals: list) -> pd.DataFrame: + """Get data from a list of signals and combine into one multi-index dataframe with 'media-pH' included.""" aggregated = pd.concat( [ - self.concat_signal(path, reduce_cols=np.nanmean) - for path in paths + self.concat_signal(signal, reduce_cols=np.nanmean) + for signal in signals ], axis=1, ) @@ -360,10 +352,10 @@ class phGrouper(NameGrouper): 
name="media_pH", ) aggregated = pd.concat((aggregated, ph), axis=1) - return aggregated +# Alan: why are these separate functions? def concat_standard( path: str, chainer: Chainer, @@ -371,7 +363,6 @@ def concat_standard( position: t.Optional[str] = None, **kwargs, ) -> pd.DataFrame: - combined = chainer.get(path, **kwargs).copy() combined["position"] = position combined["group"] = group @@ -379,10 +370,10 @@ def concat_standard( combined.index = combined.index.reorder_levels( ("group", "position", "trap", "cell_label", "mother_label") ) - return combined +# why _ind ? def concat_signal_ind( path: str, chainer: Chainer, @@ -391,33 +382,34 @@ def concat_signal_ind( position=None, **kwargs, ) -> pd.DataFrame: - """Core function that handles retrieval of an individual signal, applies - filtering if requested and adjusts indices.""" + """ + Retrieve an individual signal. + Applies filtering if requested and adjusts indices. + """ if position is None: + # name of h5 file position = chainer.stem - if mode == "retained": combined = chainer.retained(path, **kwargs) - elif mode == "raw": combined = chainer.get_raw(path, **kwargs) - elif mode == "daughters": combined = chainer.get_raw(path, **kwargs) combined = combined.loc[ combined.index.get_level_values("mother_label") > 0 ] - elif mode == "families": combined = chainer[path] - + else: + raise Exception(f"{mode} not recognised.") if combined is not None: + # adjust indices combined["position"] = position combined["group"] = group combined.set_index(["group", "position"], inplace=True, append=True) combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) - + # should there be an error message if None is returned? return combined @@ -426,6 +418,14 @@ class MultiGrouper: folder.""" def __init__(self, source: Union[str, list]): + """ + Create NameGroupers for each experiment. + + Parameters + ---------- + source: list of str + List of folders, one per experiment, containing h5 files. 
+ """ if isinstance(source, str): source = Path(source) self.exp_dirs = list(source.glob("*")) @@ -437,43 +437,38 @@ class MultiGrouper: @property def available(self) -> None: + """Print available signals and number of chains, one per position, for each Grouper.""" for gpr in self.groupers: print(gpr.available_grouped) @property def sigtable(self) -> pd.DataFrame: - """Generate a matrix containing the number of datasets for each signal - and experiment.""" + """Generate a table showing the number of positions, or h5 files, available for each signal with one column per experiment.""" def regex_cleanup(x): x = re.sub(r"extraction\/", "", x) x = re.sub(r"postprocessing\/", "", x) x = re.sub(r"\/max", "", x) - return x if not hasattr(self, "_sigtable"): raw_mat = [ - [s.available for s in gpr.chains.values()] + [s.available for s in gpr.chainers.values()] for gpr in self.groupers ] available_grouped = [ Counter([x for y in grp for x in y]) for grp in raw_mat ] - nexps = len(available_grouped) sigs_idx = list( set([y for x in available_grouped for y in x.keys()]) ) sigs_idx = [regex_cleanup(x) for x in sigs_idx] - nsigs = len(sigs_idx) - sig_matrix = np.zeros((nsigs, nexps)) for i, c in enumerate(available_grouped): for k, v in c.items(): sig_matrix[sigs_idx.index(regex_cleanup(k)), i] = v - sig_matrix[sig_matrix == 0] = np.nan self._sigtable = pd.DataFrame( sig_matrix, @@ -482,8 +477,11 @@ class MultiGrouper: ) return self._sigtable + # Alan: function seems out of place + # seaborn is not in pyproject.toml def sigtable_plot(self) -> None: - """Plot number of chains for all available experiments. + """ + Plot number of chains for all available experiments. Examples -------- @@ -503,27 +501,30 @@ class MultiGrouper: path: Union[str, list], **kwargs, ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: - """Aggregate chains from multiple Groupers (and thus experiments) + """ + Aggregate chains, one per position, from multiple Groupers, one per experiment. Parameters ---------- - chains : Union[str, list] - string or list of strings indicating the signal(s) to fetch. - **kwargs : keyword arguments to pass to Grouper.concat_signal - Customise the filters and format to fetch chains. + path : Union[str, list] + String or list of strings indicating the signal(s) to fetch. + **kwargs : + Passed to Grouper.concat_signal. Returns ------- - Union[pd.DataFrame, Dict[str, pd.DataFrame]] - DataFrame or list of DataFrames + concatenated: Union[pd.DataFrame, Dict[str, pd.DataFrame]] + A multi-index dataFrame or a dictionary of multi-index dataframes, one per signal Examples -------- - FIXME: Add docs. + >>> mg = MultiGrouper(["pHCalibrate7_24", "pHCalibrate6_7"]) + >>> p405 = mg.aggregate_signal("extraction/pHluorin405_0_4/max/median") + >>> p588 = mg.aggregate_signal("extraction/pHluorin488_0_4/max/median") + >>> ratio = p405 / p488 """ if isinstance(path, str): path = [path] - sigs = {s: [] for s in path} for s in path: for grp in self.groupers: @@ -533,18 +534,14 @@ class MultiGrouper: [(grp.name, *x) for x in sigset.index], names=("experiment", *sigset.index.names), ) - sigset.index = new_idx sigs[s].append(sigset) except Exception as e: print("Grouper {} failed: {}".format(grp.name, e)) - # raise (e) - - concated = { + concatenated = { name: pd.concat(multiexp_sig) for name, multiexp_sig in sigs.items() } - if len(concated) == 1: - concated = list(concated.values())[0] - - return concated + if len(concatenated) == 1: + concatenated = list(concatenated.values())[0] + return concatenated
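
Not part of the patch above: a minimal usage sketch of the Signal interface documented in signal.py and decorators.py. The h5 file name is hypothetical; the signal path is the default used elsewhere in the diff. Because retained and apply_prepost are decorated with _first_arg_str_to_df, they accept either a path within the h5 file or a dataframe that has already been loaded.

    from agora.io.signal import Signal

    signal = Signal("position_001.h5")  # hypothetical h5 file for one position
    print(signal.available)  # list the signals stored in the h5 file
    # raw data as a dataframe with columns in minutes
    area = signal.get_raw("extraction/general/None/area")
    # the same data with recorded merges and picks applied
    merged = signal.apply_prepost("extraction/general/None/area")
    # keep only cells with data for more than 80% of the time points
    kept = signal.retained("extraction/general/None/area", cutoff=0.8)
    # equivalent call with a pre-loaded dataframe, converted by the decorator
    kept_too = signal.retained(area, cutoff=0.8)
    # add a mother_label index level from the stored lineage information
    with_lineage = signal.get_raw("extraction/general/None/area", lineage=True)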
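
Under the same assumptions, a sketch of the retention filter documented in kymograph.py; it expects a dataframe whose index includes cell_label and mother_label levels, such as the one returned above by get_raw with lineage=True.

    from agora.utils.kymograph import bidirectional_retainment_filter

    # keep families whose mothers are present for at least 80% of the experiment
    # and whose daughters are observed for more than seven time points
    families = bidirectional_retainment_filter(
        with_lineage, mothers_thresh=0.8, daughters_thresh=7
    )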
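
The grouper.py docstrings above already carry doctest-style examples; the sketch below strings them together, with directory and signal names taken from those examples.

    from postprocessor.grouper import NameGrouper, MultiGrouper

    # one experiment: a directory with one h5 file per position
    grouper = NameGrouper("pHCalibrate7_24")
    medians = grouper.concat_signal("extraction/GFP/max/median")

    # several experiments, each in its own directory
    mg = MultiGrouper(["pHCalibrate7_24", "pHCalibrate6_7"])
    p405 = mg.aggregate_signal("extraction/pHluorin405_0_4/max/median")
    p488 = mg.aggregate_signal("extraction/pHluorin488_0_4/max/median")
    ratio = p405 / p488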