import logging import typing as t from time import perf_counter from typing import Callable, Dict, List import h5py import numpy as np import pandas as pd from agora.abc import ParametersABC, ProcessABC from agora.io.cells import Cells from agora.io.writer import Writer, load_attributes from aliby.tile.tiler import Tiler from extraction.core.functions.defaults import exparams_from_meta from extraction.core.functions.distributors import reduce_z, trap_apply from extraction.core.functions.loaders import ( load_custom_args, load_funs, load_mergefuns, load_redfuns, ) from extraction.core.functions.utils import depth # Global parameters used to load functions that either analyse cells or their background. These global parameters both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions. CELL_FUNS, TRAPFUNS, FUNS = load_funs() CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args() RED_FUNS = load_redfuns() MERGE_FUNS = load_mergefuns() # Assign datatype depending on the metric used # m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte} class ExtractorParameters(ParametersABC): """ Base class to define parameters for extraction """ def __init__( self, tree: Dict[str, Dict[Callable, List[str]]] = None, sub_bg: set = set(), multichannel_ops: Dict = {}, ): """ Parameters ---------- tree: dict Nested dictionary indicating channels, reduction functions and metrics to be used. str channel -> U(function,None) reduction -> str metric If not of depth three, tree will be filled with Nones. sub_bg: set multichannel_ops: dict """ self.tree = fill_tree(tree) self.sub_bg = sub_bg self.multichannel_ops = multichannel_ops @staticmethod def guess_from_meta(store_name: str, suffix="fast"): """ Find the microscope used from the h5 metadata Parameters ---------- store_name : str or Path For a h5 file suffix : str Added at the end of the predicted parameter set """ with h5py.File(store_name, "r") as f: microscope = f["/"].attrs.get("microscope") assert microscope, "No metadata found" return "_".join((microscope, suffix)) @classmethod def default(cls): return cls({}) @classmethod def from_meta(cls, meta): return cls(**exparams_from_meta(meta)) class Extractor(ProcessABC): """ The Extractor applies a metric, such as area or median, to cells identified in the image tiles using the cell masks. Its methods therefore require both tile images and masks. Usually one metric is applied per mask, but there are tile-specific backgrounds (Alan), which apply one metric per tile. Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the second level is the reduction algorithm, such as maximum projection; the last level is the metric - the specific operation to apply to the cells in the image identified by the mask, such as median, which is the median value of the pixels in each cell. Parameters ---------- parameters: core.extractor Parameters Parameters that include with channels, reduction and extraction functions to use. store: str Path to hdf5 storage file. Must contain cell outlines. tiler: pipeline-core.core.segmentation tiler Class that contains or fetches the image to be used for segmentation. """ default_meta = { "pixel_size": 0.236, "z_size": 0.6, "spacing": 0.6, } def __init__( self, parameters: ExtractorParameters, store: str = None, tiler: Tiler = None, ): """ Initialise Extractor. Parameters ---------- parameters: ExtractorParameters object store: str Name of h5 file tiler: Tiler object """ self.params = parameters if store: self.local = store self.load_meta() else: # if no h5 file, use the parameters directly self.meta = {"channel": parameters.to_dict()["tree"].keys()} if tiler: self.tiler = tiler self.load_funs() @classmethod def from_tiler( cls, parameters: ExtractorParameters, store: str, tiler: Tiler, ): # initate from tiler return cls(parameters, store=store, tiler=tiler) @classmethod def from_img( cls, parameters: ExtractorParameters, store: str, img_meta: tuple, ): # initiate from image return cls(parameters, store=store, tiler=Tiler(*img_meta)) @property def channels(self): # returns a tuple of strings of the available channels if not hasattr(self, "_channels"): if type(self.params.tree) is dict: self._channels = tuple(self.params.tree.keys()) return self._channels @property def current_position(self): return self.local.split("/")[-1][:-3] @property def group(self): # returns path within h5 file if not hasattr(self, "_out_path"): self._group = "/extraction/" return self._group def load_custom_funs(self): """ Define any custom functions to be functions of cell_masks and trap_image only. Any other parameters are taken from the experiment's metadata and automatically applied. These parameters therefore must be loaded within an Extractor instance. """ # find functions specified in params.tree funs = set( [ fun for ch in self.params.tree.values() for red in ch.values() for fun in red ] ) # consider only those already loaded from CUSTOM_FUNS funs = funs.intersection(CUSTOM_FUNS.keys()) # find their arguments ARG_VALS = { k: {k2: self.get_meta(k2) for k2 in v} for k, v in CUSTOM_ARGS.items() } # define custom functions - those with extra arguments other than cell_masks and trap_image - as functions of two variables self._custom_funs = {} for k, f in CUSTOM_FUNS.items(): def tmp(f): # pass extra arguments to custom function return lambda cell_masks, trap_image: trap_apply( f, cell_masks, trap_image, **ARG_VALS.get(k, {}) ) self._custom_funs[k] = tmp(f) def load_funs(self): self.load_custom_funs() self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS) # merge the two dicts self._all_funs = {**self._custom_funs, **FUNS} def load_meta(self): # load metadata from h5 file whose name is given by self.local self.meta = load_attributes(self.local) def get_traps( self, tp: int, channels: list = None, z: list = None, **kwargs ) -> tuple: """ Finds traps for a given time point and given channels and z-stacks. Returns None if no traps are found. Any additional keyword arguments are passed to tiler.get_traps_timepoint Parameters ---------- tp: int Time point of interest channels: list of strings (optional) Channels of interest z: list of integers (optional) Indices for the z-stacks of interest """ if channels is None: # find channels from tiler channel_ids = list(range(len(self.tiler.channels))) elif len(channels): # a subset of channels was specified channel_ids = [self.tiler.get_channel_index(ch) for ch in channels] else: # oh oh channel_ids = None # a list of the indices of the z stacks if z is None: z = list(range(self.tiler.shape[-1])) # find the appropiate traps from tiler traps = ( self.tiler.get_traps_timepoint( tp, channels=channel_ids, z=z, **kwargs ) if channel_ids else None ) return traps def extract_traps( self, traps: List[np.array], masks: List[np.array], metric: str, labels: List[int] = None, ) -> dict: """ Apply a function for a whole position. :traps: List[np.array] list of images :masks: List[np.array] list of masks :metric:str metric to extract :labels: List[int] cell Labels to use as indices for output DataFrame :pos_info: bool Whether to add the position as index or not returns :d: Dictionary of dataframe """ if labels is None: raise Warning("No labels given. Sorting cells using index.") cell_fun = True if metric in self._all_cell_funs else False idx = [] results = [] for trap_id, (mask_set, trap, lbl_set) in enumerate( zip(masks, traps, labels.values()) ): if len(mask_set): # ignore empty traps result = self._all_funs[metric](mask_set, trap) if cell_fun: for lbl, val in zip(lbl_set, result): results.append(val) idx.append((trap_id, lbl)) else: results.append(result) idx.append(trap_id) return (tuple(results), tuple(idx)) def extract_funs( self, traps: List[np.array], masks: List[np.array], metrics: List[str], **kwargs, ) -> dict: """ Extract multiple metrics from a timepoint """ d = { metric: self.extract_traps( traps=traps, masks=masks, metric=metric, **kwargs ) for metric in metrics } return d def reduce_extract( self, traps: np.array, masks: list, red_metrics: dict, **kwargs, ) -> dict: """ Wrapper to apply reduction and then extraction. Parameters ---------- param red_metrics: dict dict for which keys are reduction funcions and values are strings indicating the metric function **kwargs: dict All other arguments and must include masks and traps. Returns ------ Dictionary of dataframes with the corresponding reductions and metrics nested. """ # create dict of traps with reduction in the z-direction reduced_traps = {} if traps is not None: for red_fun in red_metrics.keys(): reduced_traps[red_fun] = [ self.reduce_dims(trap, method=RED_FUNS[red_fun]) for trap in traps ] d = { red_fun: self.extract_funs( metrics=metrics, traps=reduced_traps.get(red_fun, [None for _ in masks]), masks=masks, **kwargs, ) for red_fun, metrics in red_metrics.items() } return d def reduce_dims(self, img: np.array, method=None) -> np.array: """ Collapse a z-stack into 2d array. It may perform a null operation. """ if method is None: return img else: return reduce_z(img, method) def extract_tp( self, tp: int, tree: dict = None, tile_size: int = 117, masks=None, labels=None, **kwargs, ) -> t.Dict[str, t.Dict[str, pd.Series]]: """ Core extraction method for an individual time-point. Parameters ---------- tp : int Time point being analysed. tree : dict Nested dictionary indicating channels, reduction functions and metrics to be used. tile_size : int size of the tile to be extracted. masks : np.ndarray A 3-D boolean numpy array with dimensions (ncells, tile_size, tile_size). labels : t.List[t.List[int]] List of lists of ints indicating the ids of masks. **kwargs : Additional keyword arguments to be passed to extractor.reduce_extract. Returns ------- dict """ if tree is None: # use default tree = self.params.tree # dictionary with channel: {reduction algorithm : metric} ch_tree = {ch: v for ch, v in tree.items() if ch != "general"} # tuple of the channels tree_chs = (*ch_tree,) # create a Cells object to extract information from the h5 file cells = Cells(self.local) # find the cell labels and store as dict with trap_ids as keys if labels is None: raw_labels = cells.labels_at_time(tp) labels = { trap_id: raw_labels.get(trap_id, []) for trap_id in range(cells.ntraps) } # find the cell masks as a dict with trap_ids as keys if masks is None: raw_masks = cells.at_time(tp, kind="mask") masks = {trap_id: [] for trap_id in range(cells.ntraps)} for trap_id, cells in raw_masks.items(): if len(cells): masks[trap_id] = np.dstack(np.array(cells)).astype(bool) # convert to a list of masks masks = [np.array(v) for v in masks.values()] # find traps at the time point traps = self.get_traps(tp, tile_shape=tile_size, channels=tree_chs) self.img_bgsub = {} if self.params.sub_bg: # generate boolean masks for background bg = [ ~np.sum(m, axis=2).astype(bool) if np.any(m) else np.zeros((tile_size, tile_size)) for m in masks ] d = {} for ch, red_metrics in tree.items(): img = None # ch != is necessary for threading if ch != "general" and traps is not None and len(traps): img = traps[:, tree_chs.index(ch), 0] d[ch] = self.reduce_extract( red_metrics=red_metrics, traps=img, masks=masks, labels=labels, **kwargs, ) if ( ch in self.params.sub_bg and img is not None ): # Calculate metrics with subtracted bg ch_bs = ch + "_bgsub" self.img_bgsub[ch_bs] = [] for trap, maskset in zip(img, bg): cells_fl = np.zeros_like(trap) is_cell = np.where(maskset) if len(is_cell[0]): # skip calculation for empty traps cells_fl = np.median(trap[is_cell], axis=0) self.img_bgsub[ch_bs].append(trap - cells_fl) d[ch_bs] = self.reduce_extract( red_metrics=ch_tree[ch], traps=self.img_bgsub[ch_bs], masks=masks, labels=labels, **kwargs, ) # Additional operations between multiple channels (e.g. pH calculations) for name, ( chs, merge_fun, red_metrics, ) in self.params.multichannel_ops.items(): if len( set(chs).intersection( set(self.img_bgsub.keys()).union(tree_chs) ) ) == len(chs): imgs = [self.get_imgs(ch, traps, tree_chs) for ch in chs] merged = MERGE_FUNS[merge_fun](*imgs) d[name] = self.reduce_extract( red_metrics=red_metrics, traps=merged, masks=masks, labels=labels, **kwargs, ) return d def get_imgs(self, channel, traps, channels=None): """ Returns the image from a correct source, either raw or bgsub :channel: str name of channel to get :img: ndarray (trap_id, channel, tp, tile_size, tile_size, n_zstacks) of standard channels :channels: List of channels """ if channels is None: channels = (*self.params.tree,) if channel in channels: return traps[:, channels.index(channel), 0] elif channel in self.img_bgsub: return self.img_bgsub[channel] def run_tp(self, tp, **kwargs): """ Wrapper to add compatiblibility with other pipeline steps """ return self.run(tps=[tp], **kwargs) def run( self, tree=None, tps: List[int] = None, save=True, **kwargs ) -> dict: if tree is None: tree = self.params.tree if tps is None: tps = list(range(self.meta["time_settings/ntimepoints"][0])) d = {} for tp in tps: new = flatten_nest( self.extract_tp(tp=tp, tree=tree, **kwargs), to="series", tp=tp, ) for k in new.keys(): n = new[k] d[k] = pd.concat((d.get(k, None), n), axis=1) for k in d.keys(): indices = ["experiment", "position", "trap", "cell_label"] idx = ( indices[-d[k].index.nlevels :] if d[k].index.nlevels > 1 else [indices[-2]] ) d[k].index.names = idx toreturn = d if save: self.save_to_hdf(toreturn) return toreturn def extract_pos( self, tree=None, tps: List[int] = None, save=True, **kwargs ) -> dict: if tree is None: tree = self.params.tree if tps is None: tps = list(range(self.meta["time_settings/ntimepoints"])) d = {} for tp in tps: new = flatten_nest( self.extract_tp(tp=tp, tree=tree, **kwargs), to="series", tp=tp, ) for k in new.keys(): n = new[k] d[k] = pd.concat((d.get(k, None), n), axis=1) for k in d.keys(): indices = ["experiment", "position", "trap", "cell_label"] idx = ( indices[-d[k].index.nlevels :] if d[k].index.nlevels > 1 else [indices[-2]] ) d[k].index.names = idx toreturn = d if save: self.save_to_hdf(toreturn) return toreturn def save_to_hdf(self, group_df, path=None): if path is None: path = self.local self.writer = Writer(path) for path, df in group_df.items(): dset_path = "/extraction/" + path self.writer.write(dset_path, df) self.writer.id_cache.clear() def get_meta(self, flds): if not hasattr(flds, "__iter__"): # make flds a list flds = [flds] meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()} return { f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds } ### Helpers def flatten_nest(nest: dict, to="series", tp: int = None) -> dict: """ Convert a nested extraction dict into a dict of series :param nest: dict contained the nested results of extraction :param to: str = 'series' Determine output format, either list or pd.Series :param tp: int timepoint used to name the series """ d = {} for k0, v0 in nest.items(): for k1, v1 in v0.items(): for k2, v2 in v1.items(): d["/".join((k0, k1, k2))] = ( pd.Series(*v2, name=tp) if to == "series" else v2 ) return d def fill_tree(tree): if tree is None: return None tree_depth = depth(tree) if depth(tree) < 3: d = {None: {None: {None: []}}} for _ in range(2 - tree_depth): d = d[None] d[None] = tree tree = d return tree class hollowExtractor(Extractor): """Extractor that only cares about receiving image and masks, used for testing. """ def __init__(self, parameters): self.params = parameters