import typing as t
from pathlib import Path

import bottleneck as bn
import h5py
import numpy as np
import pandas as pd

import aliby.global_parameters as global_parameters
from agora.abc import ParametersABC, StepABC
from agora.io.cells import Cells
from agora.io.writer import Writer, load_meta
from aliby.tile.tiler import Tiler, find_channel_name
from extraction.core.functions.distributors import reduce_z, trap_apply
from extraction.core.functions.loaders import (
    load_custom_args,
    load_funs,
    load_redfuns,
)

# define types
reduction_method = t.Union[t.Callable, str, None]
extraction_tree = t.Dict[
    str, t.Dict[reduction_method, t.Dict[str, t.Collection]]
]
extraction_result = t.Dict[
    str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]]
]

# Global variables used to load functions that either analyse cells
# or their background. These global variables both allow the functions
# to be stored in a dictionary for access only on demand and to be
# defined simply in extraction/core/functions.
CELL_FUNS, TRAP_FUNS, ALL_FUNS = load_funs()
CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
REDUCTION_FUNS = load_redfuns()


def extraction_params_from_meta(
    meta: t.Union[dict, Path, str], extras: t.Collection[str] = ("ph",)
):
    """
    Obtain parameters for extraction from meta data.

    Parameters
    ----------
    meta: dict, Path or str
        Either the meta data as a dict or the path to an h5 file whose
        root attributes are read as the meta data.
    extras: collection of str (optional)
        Not used by this function; kept so existing callers still work.

    Returns
    -------
    base: dict
        A dict with keys "tree", "multichannel_ops" and "sub_bg".
    """
    if not isinstance(meta, dict):
        # load meta data
        with h5py.File(meta, "r") as f:
            meta = dict(f["/"].attrs.items())
    base = {
        "tree": {"general": {"None": ["area", "volume", "eccentricity"]}},
        "multichannel_ops": {},
    }
    candidate_channels = set(global_parameters.possible_imaging_channels)
    default_reductions = {"max"}
    default_metrics = set(global_parameters.fluorescence_functions)
    # NB adding a "None" reduction with ["nuc_conv_3d"] to each channel
    # would enable nuc_conv_3d (slow)
    meta_channels = meta.get("channels", [])
    extant_fluorescence_ch = []
    for av_channel in candidate_channels:
        # find matching channels in metadata
        found_channel = find_channel_name(meta_channels, av_channel)
        if found_channel is not None:
            extant_fluorescence_ch.append(found_channel)
    for ch in extant_fluorescence_ch:
        # give each channel its own containers so that editing the tree
        # for one channel cannot change another
        base["tree"][ch] = {
            reduction: set(default_metrics)
            for reduction in default_reductions
        }
    base["sub_bg"] = extant_fluorescence_ch
    return base


class ExtractorParameters(ParametersABC):
    """Base class to define parameters for extraction."""

    def __init__(
        self,
        tree: extraction_tree,
        sub_bg: t.Optional[t.Collection[str]] = None,
        multichannel_ops: t.Optional[t.Dict] = None,
    ):
        """
        Initialise.

        Parameters
        ----------
        tree: dict
            Nested dictionary indicating channels, reduction functions and
            metrics to be used.
            str channel -> U(function, None) reduction -> str metric
            If not of depth three, tree will be filled with None.
        sub_bg: collection of str (optional)
            Channels for which background subtraction is requested.
            Defaults to an empty set.
        multichannel_ops: dict (optional)
            Operations requiring multiple channels. Each value is unpacked
            by Extractor as (channels, reduction, function name).
            Defaults to an empty dict.
        """
        self.tree = tree
        # create fresh containers per instance: Extractor modifies these
        # attributes and a shared default would leak between instances
        self.sub_bg = set() if sub_bg is None else sub_bg
        self.multichannel_ops = (
            {} if multichannel_ops is None else multichannel_ops
        )

    @classmethod
    def default(cls):
        """Instantiate with an empty extraction tree."""
        return cls({})

    @classmethod
    def from_meta(cls, meta):
        """Instantiate from the meta data; used by Pipeline."""
        return cls(**extraction_params_from_meta(meta))


class Extractor(StepABC):
    """
    Apply a metric to cells identified in the tiles.

    Using the cell masks, the Extractor applies a metric, such as
    area or median, to cells identified in the image tiles.

    Its methods require both tile images and masks.

    Usually the metric is applied to only a tile's masked area, but
    some metrics depend on the whole tile.

    Extraction follows a three-level tree structure. Channels, such
    as GFP, are the root level; the reduction algorithm, such as
    maximum projection, is the second level; the specific metric,
    or operation, to apply to the masks, such as mean, is the third
    or leaf level.
    """

    default_meta = global_parameters.imaging_specifications

    def __init__(
        self,
        parameters: ExtractorParameters,
        store: t.Optional[str] = None,
        tiler: t.Optional[Tiler] = None,
    ):
        """
        Initialise Extractor.

        Parameters
        ----------
        parameters: core.extractor Parameters
            Parameters that include the channels, reduction and
            extraction functions.
        store: str
            Path to the h5 file containing the cell masks.
        tiler: pipeline-core.core.segmentation tiler
            Class that contains or fetches the images used for
            segmentation.
        """
        self.params = parameters
        if store:
            self.h5path = store
            self.meta = load_meta(self.h5path)
        else:
            # if no h5 file, use the parameters directly
            self.meta = {"channel": parameters.to_dict()["tree"].keys()}
        if tiler:
            self.tiler = tiler
            available_channels = set((*tiler.channels, "general"))
            # only extract for channels available
            self.params.tree = {
                k: v
                for k, v in self.params.tree.items()
                if k in available_channels
            }
            self.params.sub_bg = available_channels.intersection(
                self.params.sub_bg
            )
            # add background subtracted channels to those available
            available_channels_bgsub = available_channels.union(
                [c + "_bgsub" for c in self.params.sub_bg]
            )
            # remove any multichannel operations requiring a missing
            # channel; build a new dict because removing entries from a
            # dict while iterating over it raises a RuntimeError
            self.params.multichannel_ops = {
                op: spec
                for op, spec in self.params.multichannel_ops.items()
                if set(spec[0]).issubset(available_channels_bgsub)
            }
        self.load_funs()

    @classmethod
    def from_tiler(
        cls,
        parameters: ExtractorParameters,
        store: str,
        tiler: Tiler,
    ):
        """Initiate from a tiler instance."""
        return cls(parameters, store=store, tiler=tiler)

    @classmethod
    def from_img(
        cls,
        parameters: ExtractorParameters,
        store: str,
        img_meta: tuple,
    ):
        """Initiate from images."""
        return cls(parameters, store=store, tiler=Tiler(*img_meta))

    @property
    def channels(self):
        """
        Get a tuple of the available channels.

        The tuple is cached on first access. An empty tuple is returned
        if the extraction tree is not a dict.
        """
        if not hasattr(self, "_channels"):
            tree = self.params.tree
            self._channels = tuple(tree) if isinstance(tree, dict) else ()
        return self._channels

    @property
    def current_position(self):
        """Return position being analysed: the h5 file name sans suffix."""
        return Path(self.h5path).stem

    @property
    def group(self):
        """Return out path to write in the h5 file."""
        if not hasattr(self, "_group"):
            self._group = "/extraction/"
        return self._group

    def load_funs(self):
        """Define all functions, including custom ones."""
        self.load_custom_funs()
        self.all_cell_funs = set(self.custom_funs.keys()).union(CELL_FUNS)
        # merge the two dicts; entries in ALL_FUNS take precedence
        self.all_funs = {**self.custom_funs, **ALL_FUNS}

    def load_custom_funs(self):
        """
        Incorporate extra arguments of custom functions into their definitions.

        Normal functions only have cell_masks and trap_image as their
        arguments, and here custom functions are made the same by
        setting the values of their extra arguments.

        Any other parameters are taken from the experiment's metadata
        and automatically applied. These parameters therefore must be
        loaded within an Extractor instance.
        """
        # find the values of the extra arguments via get_meta
        self.custom_arg_vals = {
            k: {k2: self.get_meta(k2) for k2 in v}
            for k, v in CUSTOM_ARGS.items()
        }

        def bind(fun, name):
            # Return a function of cell_masks and trap_image only.
            # Both fun and name are passed in as arguments so that each
            # wrapper keeps its own values; closing over the loop
            # variables would make every wrapper use the last ones.
            return lambda cell_masks, trap_image: trap_apply(
                fun,
                cell_masks,
                trap_image,
                **self.custom_arg_vals.get(name, {}),
            )

        # define custom functions
        self.custom_funs = {k: bind(f, k) for k, f in CUSTOM_FUNS.items()}

    def get_tiles(
        self,
        tp: int,
        channels: t.Optional[t.List[t.Union[str, int]]] = None,
        z: t.Optional[t.List[int]] = None,
    ) -> t.Optional[np.ndarray]:
        """
        Find tiles for a given time point, channels, and z-stacks.

        Parameters
        ----------
        tp: int
            Time point of interest.
        channels: list of strings (optional)
            Channels of interest. If None, all of the tiler's channels
            are used.
        z: list of integers (optional)
            Indices for the z-stacks of interest. If None, all are used.

        Returns
        -------
        tiles: array or None
            The image data from tiler.get_tiles_timepoint, or None if
            there are no channels to fetch.
        """
        if channels is None:
            # find channels from tiler
            channel_ids = list(range(len(self.tiler.channels)))
        elif len(channels):
            # a subset of channels was specified
            channel_ids = [
                self.tiler.get_channel_index(ch) for ch in channels
            ]
        else:
            # an empty collection of channels: nothing to fetch
            channel_ids = None
        if z is None:
            # include all Z channels
            z = list(range(self.tiler.shape[-3]))
        # get the image data via tiler
        tiles = (
            self.tiler.get_tiles_timepoint(tp, channels=channel_ids, z=z)
            if channel_ids
            else None
        )
        # tiles has dimensions (tiles, channels, 1, Z, X, Y)
        return tiles

    def apply_cell_function(
        self,
        traps: t.List[np.ndarray],
        masks: t.List[np.ndarray],
        cell_function: str,
        cell_labels: t.Dict[int, t.List[int]],
    ) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]:
        """
        Apply a cell function to all cells at all traps for one time point.

        Parameters
        ----------
        traps: list of arrays
            t.List of images.
        masks: list of arrays
            t.List of masks.
        cell_function: str
            Function to apply.
        cell_labels: dict
            A dict with trap_ids as keys and a list of cell labels as
            values. If None, cells are labelled by their index in the
            trap's masks.

        Returns
        -------
        res_idx: a tuple of tuples
            A two-tuple comprising a tuple of results and a tuple of
            the tile_id and cell labels
        """
        if cell_labels is None:
            self._log("No cell labels given. Sorting cells using index.")
            cell_labels = {
                trap_id: list(range(len(mask_set)))
                for trap_id, mask_set in enumerate(masks)
            }
        is_cell_fun = cell_function in self.all_cell_funs
        idx = []
        results = []
        for trap_id, (mask_set, trap, local_cell_labels) in enumerate(
            zip(masks, traps, cell_labels.values())
        ):
            # ignore empty traps
            if len(mask_set):
                # find property from the tile
                result = self.all_funs[cell_function](mask_set, trap)
                if is_cell_fun:
                    # store results for each cell separately
                    for cell_label, val in zip(local_cell_labels, result):
                        results.append(val)
                        idx.append((trap_id, cell_label))
                else:
                    # background (trap) function
                    results.append(result)
                    idx.append(trap_id)
        res_idx = (tuple(results), tuple(idx))
        return res_idx

    def apply_cell_funs(
        self,
        tiles: t.List[np.array],
        masks: t.List[np.array],
        cell_funs: t.List[str],
        **kwargs,
    ) -> t.Dict[str, pd.Series]:
        """
        Return dict with cell_funs as keys and the corresponding results as values.

        Data from one time point is used. Keyword arguments are passed
        to apply_cell_function.
        """
        d = {
            cell_fun: self.apply_cell_function(
                traps=tiles, masks=masks, cell_function=cell_fun, **kwargs
            )
            for cell_fun in cell_funs
        }
        return d

    def reduce_extract(
        self,
        tiles: np.ndarray,
        masks: t.List[np.ndarray],
        reduction_cell_funs: t.Dict[reduction_method, t.Collection[str]],
        **kwargs,
    ) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]:
        """
        Reduce to a 2D image and then extract.

        Parameters
        ----------
        tiles: array
            An array of image data arranged as (tiles, X, Y, Z)
        masks: list of arrays
            An array of masks for each trap: one per cell at the trap
        reduction_cell_funs: dict
            An upper branch of the extraction tree: a dict for which
            keys are reduction functions and values are either a list
            or a set of strings giving the cell functions to apply.
            For example: {'np_max': {'max5px', 'mean', 'median'}}
        **kwargs: dict
            All other arguments passed to Extractor.apply_cell_funs.

        Returns
        ------
        Dict of dataframes with the corresponding reductions and metrics nested.
        """
        # create dict with keys naming the reduction in the z-direction
        # and the reduced data as values
        reduced_tiles = {}
        if tiles is not None:
            for reduction in reduction_cell_funs:
                reduced_tiles[reduction] = [
                    self.reduce_dims(
                        tile_data, method=REDUCTION_FUNS[reduction]
                    )
                    for tile_data in tiles
                ]
        # calculate cell and tile properties; without image data each
        # trap is given None as its image
        d = {
            reduction: self.apply_cell_funs(
                tiles=reduced_tiles.get(reduction, [None for _ in masks]),
                masks=masks,
                cell_funs=cell_funs,
                **kwargs,
            )
            for reduction, cell_funs in reduction_cell_funs.items()
        }
        return d

    def reduce_dims(
        self, img: np.ndarray, method: reduction_method = None
    ) -> np.ndarray:
        """
        Collapse a z-stack into 2d array using method.

        If method is None, return the original data.

        Parameters
        ----------
        img: array
            An array of the image data arranged as (X, Y, Z).
        method: function
            The reduction function.
        """
        if method is None:
            return img
        return reduce_z(img, method)

    def make_tree_dict(self, tree: extraction_tree):
        """
        Put extraction tree into a dict.

        The dict has keys "tree" for the whole tree, "channels_tree" for
        the tree without "general", and "channels" for a tuple of the
        keys of "channels_tree". If tree is None, self.params.tree is used.
        """
        if tree is None:
            # use default
            tree = self.params.tree
        tree_dict = {
            # the whole extraction tree
            "tree": tree,
            # the extraction tree for fluorescence channels
            "channels_tree": {
                ch: v for ch, v in tree.items() if ch != "general"
            },
        }
        # tuple of the fluorescence channels
        tree_dict["channels"] = (*tree_dict["channels_tree"],)
        return tree_dict

    def get_masks(self, tp, masks, cells):
        """
        Get the masks as a list with an array of masks for each trap.

        If masks is None, the masks are read from cells for time point
        tp; otherwise masks may be a dict with trap_ids as keys or a
        sequence with one entry per trap.
        """
        if masks is None:
            # find the cell masks as a dict with trap_ids as keys
            raw_masks = cells.at_time(tp, kind="mask")
            masks = {trap_id: [] for trap_id in range(cells.ntraps)}
            for trap_id, trap_masks in raw_masks.items():
                if len(trap_masks):
                    masks[trap_id] = np.stack(np.array(trap_masks)).astype(
                        bool
                    )
        per_trap = masks.values() if isinstance(masks, dict) else masks
        # convert to a list of masks
        # one array of size (no cells, tile_size, tile_size) per trap
        return [np.array(v) for v in per_trap]

    def get_cell_labels(self, tp, cell_labels, cells):
        """
        Get the cell labels per trap as a dict with trap_ids as keys.

        Labels are read from cells for time point tp only if cell_labels
        is None; traps without labels get an empty list.
        """
        if cell_labels is None:
            raw_cell_labels = cells.labels_at_time(tp)
            cell_labels = {
                trap_id: raw_cell_labels.get(trap_id, [])
                for trap_id in range(cells.ntraps)
            }
        return cell_labels

    def get_background_masks(self, masks, tile_size):
        """
        Generate boolean background masks.

        Combine masks per trap and then take the logical inverse. An
        empty array is returned if no background subtraction is requested.
        """
        if not self.params.sub_bg:
            return np.array([])
        occupied = np.array(
            [
                # sum over masks for each cell
                np.sum(trap_masks, axis=0)
                if np.any(trap_masks)
                else np.zeros((tile_size, tile_size))
                for trap_masks in masks
            ]
        ).astype(bool)
        return ~occupied

    def extract_one_channel(
        self, tree_dict, cell_labels, img, img_bgsub, masks, **kwargs
    ):
        """
        Extract as dict all metrics requiring only a single channel.

        The result has one entry per channel in the tree and, for each
        channel with an entry in img_bgsub, one entry named with the
        suffix "_bgsub". Keyword arguments are passed to reduce_extract.
        """
        d = {}
        for ch, reduction_cell_funs in tree_dict["tree"].items():
            # extract from all images including bright field
            d[ch] = self.reduce_extract(
                # use None for "general"; no fluorescence image
                tiles=img.get(ch, None),
                masks=masks,
                reduction_cell_funs=reduction_cell_funs,
                cell_labels=cell_labels,
                **kwargs,
            )
            bgsub_name = ch + "_bgsub"
            # only channels with background subtraction have an entry in
            # img_bgsub; skip the others rather than raise a KeyError
            if ch != "general" and bgsub_name in img_bgsub:
                # extract from background-corrected fluorescence images
                d[bgsub_name] = self.reduce_extract(
                    tiles=img_bgsub[bgsub_name],
                    masks=masks,
                    reduction_cell_funs=reduction_cell_funs,
                    cell_labels=cell_labels,
                    **kwargs,
                )
        return d

    def extract_multiple_channels(self, cell_labels, img, img_bgsub, masks):
        """
        Extract as a dict all metrics requiring multiple channels.

        The result is nested as operation name -> reduction -> function
        name, with the suffix "_bgsub" added to the function name for
        results from the background-subtracted images.
        """
        # NB multichannel functions do not use tree_dict
        available_channels = set(list(img.keys()) + list(img_bgsub.keys()))
        d = {}
        for multichannel_fun_name, (
            channels,
            reduction,
            multichannel_function,
        ) in self.params.multichannel_ops.items():
            common_channels = set(channels).intersection(available_channels)
            # all required channels should be available
            if len(common_channels) != len(channels):
                continue
            # NOTE(review): channels are looked up by name in img and by
            # name + "_bgsub" in img_bgsub, so this presumably expects
            # channel names without the suffix - confirm with callers
            for images, suffix in zip([img, img_bgsub], ["", "_bgsub"]):
                # channels
                channels_stack = np.stack(
                    [images[ch + suffix] for ch in channels],
                    axis=-1,
                )
                # reduce in Z
                tiles = REDUCTION_FUNS[reduction](channels_stack, axis=1)
                # set up dict
                per_reduction = d.setdefault(
                    multichannel_fun_name, {}
                ).setdefault(reduction, {})
                # apply multichannel function
                per_reduction[
                    multichannel_function + suffix
                ] = self.apply_cell_function(
                    tiles,
                    masks,
                    multichannel_function,
                    cell_labels,
                )
        return d

    def extract_tp(
        self,
        tp: int,
        tree: t.Optional[extraction_tree] = None,
        tile_size: int = 117,
        masks: t.Optional[t.List[np.ndarray]] = None,
        cell_labels: t.Optional[t.Dict[int, t.List[int]]] = None,
        **kwargs,
    ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
        """
        Extract for an individual time point.

        Parameters
        ----------
        tp : int
            Time point being analysed.
        tree : dict
            Nested dictionary indicating channels, reduction functions
            and metrics to be used.
            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
        tile_size : int
            Size of the tile to be extracted.
        masks : list of arrays
            A list of masks per trap with each mask having dimensions
            (ncells, tile_size, tile_size) and with one mask per cell.
        cell_labels : dict
            A dictionary with trap_ids as keys and cell_labels as values.
        **kwargs : keyword arguments
            Passed to extractor.reduce_extract.

        Returns
        -------
        d: dict
            Dictionary of the results with three levels of dictionaries.
            The first level has channels as keys.
            The second level has reduction metrics as keys.
            The third level has cell or background metrics as keys and a
            two-tuple as values.
            The first tuple is the result of applying the metrics to a
            particular cell or trap; the second tuple is either
            (trap_id, cell_label) for a metric applied to a cell or a
            trap_id for a metric applied to a trap.

            An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple
            of the calculated mean GFP fluorescence for all cells.
        """
        # dict of information from extraction tree
        tree_dict = self.make_tree_dict(tree)
        # create a Cells object to extract information from the h5 file
        cells = Cells(self.h5path)
        # find the cell labels as dict with trap_ids as keys
        cell_labels = self.get_cell_labels(tp, cell_labels, cells)
        # get masks one per cell per trap
        masks = self.get_masks(tp, masks, cells)
        # find image data for all traps at the time point
        # stored as an array arranged as (traps, channels, 1, Z, X, Y)
        tiles = self.get_tiles(tp, channels=tree_dict["channels"])
        # generate boolean masks for background for each trap
        bgs = self.get_background_masks(masks, tile_size)
        # get images and background corrected images as dicts
        # with fluorescence channels as keys
        img, img_bgsub = self.get_imgs_background_subtract(
            tree_dict, tiles, bgs
        )
        # perform extraction
        res_one = self.extract_one_channel(
            tree_dict, cell_labels, img, img_bgsub, masks, **kwargs
        )
        res_multiple = self.extract_multiple_channels(
            cell_labels, img, img_bgsub, masks
        )
        res = {**res_one, **res_multiple}
        return res

    def get_imgs_background_subtract(self, tree_dict, tiles, bgs):
        """
        Get two dicts of fluorescence images.

        Return images and background subtracted image for all traps
        for one time point.

        The first dict has channel names as keys; the second has channel
        names with the suffix "_bgsub" as keys. Without image data, the
        values in both dicts are None.
        """
        img = {}
        img_bgsub = {}
        have_tiles = tiles is not None and len(tiles)
        for ch in tree_dict["channels_tree"]:
            bgsub_name = ch + "_bgsub"
            if not have_tiles:
                img[ch] = None
                img_bgsub[bgsub_name] = None
                continue
            # image data for all traps for a particular channel and
            # time point arranged as (traps, Z, X, Y)
            # we use 0 here to access the single time point available
            img[ch] = tiles[:, tree_dict["channels"].index(ch), 0]
            if (
                bgs.any()
                and ch in self.params.sub_bg
                and img[ch] is not None
            ):
                # subtract median background
                bgsub_tiles = [
                    # move Z to last column to allow subtraction
                    np.moveaxis(tile, 0, -1)
                    # median of background over all pixels for each Z section
                    - bn.median(tile[:, bg], axis=1)
                    for tile, bg in zip(img[ch], bgs)
                ]
                # move Z axis back to the second column
                img_bgsub[bgsub_name] = np.moveaxis(
                    np.stack(bgsub_tiles), -1, 1
                )
        return img, img_bgsub

    def get_imgs_old(self, channel: t.Optional[str], tiles, channels=None):
        """
        Return image from a correct source, either raw or bgsub.

        Parameters
        ----------
        channel: str
            Name of channel to get.
        tiles: ndarray
            An array of the image data having dimensions of
            (tile_id, channel, tp, tile_size, tile_size, n_zstacks).
        channels: list of str (optional)
            t.List of available channels.

        Returns
        -------
        img: ndarray
            An array of image data with dimensions
            (no tiles, X, Y, no Z channels), or None if the channel is
            not found.
        """
        if channels is None:
            channels = (*self.params.tree,)
        if channel in channels:
            # TODO start here to fetch channel using regex
            return tiles[:, channels.index(channel), 0]
        # img_bgsub is not set anywhere in this class, so treat a
        # missing attribute as having no background-subtracted images
        img_bgsub = getattr(self, "img_bgsub", {})
        if channel in img_bgsub:
            return img_bgsub[channel]
        return None

    def _run_tp(
        self,
        tps: t.Optional[t.Union[int, t.List[int]]] = None,
        tree=None,
        save=True,
        **kwargs,
    ) -> dict:
        """
        Run extraction for one position and for the specified time points.

        Save the results to a h5 file.

        Parameters
        ----------
        tps: list of int (optional)
            Time points to include. A single int is also accepted. If
            None, all time points given in the meta data are used.
        tree: dict (optional)
            Nested dictionary indicating channels, reduction functions and
            metrics to be used.
            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
        save: boolean (optional)
            If True, save results to h5 file.
        kwargs: keyword arguments (optional)
            Passed to extract_tp.

        Returns
        -------
        d: dict
            A dict of the extracted data for one position with a concatenated
            string of channel, reduction metric, and cell metric as keys and
            pd.DataFrame of the extracted data for all time points as values.
        """
        if tree is None:
            tree = self.params.tree
        if tps is None:
            tps = list(range(self.meta["time_settings/ntimepoints"][0]))
        elif isinstance(tps, int):
            tps = [tps]
        # store results in dict
        d = {}
        for tp in tps:
            # extract for each time point and convert to dict of pd.Series
            new = flatten_nesteddict(
                self.extract_tp(tp=tp, tree=tree, **kwargs),
                to="series",
                tp=tp,
            )
            # concatenate with data extracted from earlier time points
            for k, series in new.items():
                d[k] = pd.concat((d.get(k, None), series), axis=1)
        # add indices to pd.Series containing the extracted data
        indices = ["experiment", "position", "trap", "cell_label"]
        for data in d.values():
            nlevels = data.index.nlevels
            # a single-level index is by trap; otherwise use the last
            # nlevels names
            data.index.names = (
                indices[-nlevels:] if nlevels > 1 else [indices[-2]]
            )
        # save
        if save:
            self.save_to_h5(d)
        return d

    def save_to_h5(self, dict_series, path=None):
        """
        Save the extracted data for one position to the h5 file.

        Parameters
        ----------
        dict_series: dict
            A dictionary of the extracted data, created by run.
        path: Path (optional)
            To the h5 file. Defaults to the Extractor's own h5 file.
        """
        if path is None:
            path = self.h5path
        self.writer = Writer(path)
        for extract_name, series in dict_series.items():
            dset_path = "/extraction/" + extract_name
            self.writer.write(dset_path, series)
        self.writer.id_cache.clear()

    def get_meta(self, flds: t.Union[str, t.Collection]):
        """
        Obtain metadata for one or multiple fields.

        Returns a dict with the fields as keys. Meta data keys are matched
        by the part after their last "/"; fields not found fall back to
        default_meta and then to None.
        """
        if isinstance(flds, str):
            flds = [flds]
        meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()}
        return {
            f: meta_short.get(f, self.default_meta.get(f, None)) for f in flds
        }


def flatten_nesteddict(
    nest: dict, to="series", tp: t.Optional[int] = None
) -> t.Dict[str, pd.Series]:
    """
    Convert a nested extraction dict into a dict of pd.Series.

    Parameters
    ----------
    nest: dict of dicts
        Contains the nested results of extraction.
    to: str (optional)
        Specifies the format of the output, either pd.Series (default)
        or a list
    tp: int
        Time point used to name the pd.Series

    Returns
    -------
    d: dict
        A dict with a concatenated string of channel, reduction metric,
        and cell metric as keys and either a pd.Series or a list of the
        corresponding extracted data as values.
    """
    d = {}
    for k0, v0 in nest.items():
        for k1, v1 in v0.items():
            for k2, v2 in v1.items():
                # v2 is unpacked as the positional arguments of pd.Series
                d["/".join((k0, k1, k2))] = (
                    pd.Series(*v2, name=tp) if to == "series" else v2
                )
    return d