diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py index b58302c3e69d64b3b84c0fd0a389425eb5c5dbf4..e67c6e6f659fee169c8325a125d0f23a6a11b460 100644 --- a/src/aliby/pipeline.py +++ b/src/aliby/pipeline.py @@ -141,7 +141,7 @@ class PipelineParameters(ParametersABC): # define defaults and update with any inputs defaults["tiler"] = TilerParameters.default(**tiler).to_dict() - # Generate a backup channel, for when logfile meta is available + # generate a backup channel, for when logfile meta is available # but not image metadata. backup_ref_channel = None if "channels" in meta_d and isinstance( @@ -384,7 +384,7 @@ class Pipeline(ProcessABC): session = None filename = None # - run_kwargs = {"extraction": {"labels": None, "masks": None}} + run_kwargs = {"extraction": {"cell_labels": None, "masks": None}} try: ( filename, @@ -507,7 +507,7 @@ class Pipeline(ProcessABC): ) elif step == "extraction": # remove masks and labels after extraction - for k in ["masks", "labels"]: + for k in ["masks", "cell_labels"]: run_kwargs[step][k] = None # check and report clogging frac_clogged_traps = self.check_earlystop( diff --git a/src/aliby/tile/tiler.py b/src/aliby/tile/tiler.py index 27c1e814d4682228149fc0343d9a1a564cc7ad14..be01900ee28c521e4f9f421dedf4d1474e39594c 100644 --- a/src/aliby/tile/tiler.py +++ b/src/aliby/tile/tiler.py @@ -344,7 +344,7 @@ class Tiler(StepABC): return tiler @lru_cache(maxsize=2) - def get_tc(self, t: int, c: int) -> np.ndarray: + def get_tc(self, tp: int, c: int) -> np.ndarray: """ Load image using dask. @@ -357,7 +357,7 @@ class Tiler(StepABC): Parameters ---------- - t: integer + tp: integer An index for a time point c: integer An index for a channel @@ -366,7 +366,7 @@ class Tiler(StepABC): ------- full: an array of images """ - full = self.image[t, c] + full = self.image[tp, c] if hasattr(full, "compute"): # if using dask fetch images full = full.compute(scheduler="synchronous") @@ -570,9 +570,8 @@ class Tiler(StepABC): Returns ------- res: array - Data arranged as (tiles, channels, time points, X, Y, Z) + Data arranged as (tiles, channels, Z, X, Y) """ - # FIXME add support for sub-tiling a tile # FIXME can we ignore z if channels is None: channels = [0] @@ -583,8 +582,7 @@ class Tiler(StepABC): for c in channels: # only return requested z val = self.get_tp_data(tp, c)[:, z] - # starts with the order: tiles, z, y, x - # returns the order: tiles, C, T, Z, X, Y + # starts with the order: tiles, Z, Y, X val = np.expand_dims(val, axis=1) res.append(val) if tile_shape is not None: @@ -596,7 +594,10 @@ class Tiler(StepABC): for tile_size, ax in zip(tile_shape, res[0].shape[-3:-2]) ] ) - return np.stack(res, axis=1) + # convert to array with channels as first column + # final has dimensions (tiles, channels, 1, Z, X, Y) + final = np.stack(res, axis=1) + return final @property def ref_channel_index(self): diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py index cebb0b2a1d9229174c18f1ea9a9f5cb0a80383cf..ad42f9b9620e5d711d234c298d718a504d9206ad 100644 --- a/src/extraction/core/extractor.py +++ b/src/extraction/core/extractor.py @@ -26,14 +26,14 @@ extraction_result = t.Dict[ str, t.Dict[reduction_method, t.Dict[str, t.Dict[str, pd.Series]]] ] -# Global variables used to load functions that either analyse cells or their background. These global variables both allow the functions to be stored in a dictionary for access only on demand and to be defined simply in extraction/core/functions. +# Global variables used to load functions that either analyse cells +# or their background. These global variables both allow the functions +# to be stored in a dictionary for access only on demand and to be +# defined simply in extraction/core/functions. CELL_FUNS, TRAPFUNS, FUNS = load_funs() CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args() RED_FUNS = load_redfuns() -# Assign datatype depending on the metric used -# m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte} - class ExtractorParameters(ParametersABC): """Base class to define parameters for extraction.""" @@ -74,13 +74,19 @@ class Extractor(StepABC): """ Apply a metric to cells identified in the tiles. - Using the cell masks, the Extractor applies a metric, such as area or median, to cells identified in the image tiles. + Using the cell masks, the Extractor applies a metric, such as + area or median, to cells identified in the image tiles. Its methods require both tile images and masks. - Usually the metric is applied to only a tile's masked area, but some metrics depend on the whole tile. + Usually the metric is applied to only a tile's masked area, but + some metrics depend on the whole tile. - Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third or leaf level. + Extraction follows a three-level tree structure. Channels, such + as GFP, are the root level; the reduction algorithm, such as + maximum projection, is the second level; the specific metric, + or operation, to apply to the masks, such as mean, is the third + or leaf level. """ # TODO Alan: Move this to a location with the SwainLab defaults @@ -107,7 +113,8 @@ class Extractor(StepABC): store: str Path to the h5 file containing the cell masks. tiler: pipeline-core.core.segmentation tiler - Class that contains or fetches the images used for segmentation. + Class that contains or fetches the images used for + segmentation. """ self.params = parameters if store: @@ -161,13 +168,16 @@ class Extractor(StepABC): def load_custom_funs(self): """ - Incorporate the extra arguments of custom functions into their definitions. + Incorporate the extra arguments of custom functions into their + definitions. Normal functions only have cell_masks and trap_image as their arguments, and here custom functions are made the same by setting the values of their extra arguments. - Any other parameters are taken from the experiment's metadata and automatically applied. These parameters therefore must be loaded within an Extractor instance. + Any other parameters are taken from the experiment's metadata + and automatically applied. These parameters therefore must be + loaded within an Extractor instance. """ # find functions specified in params.tree funs = set( @@ -222,7 +232,8 @@ class Extractor(StepABC): """ Find tiles for a given time point, channels, and z-stacks. - Any additional keyword arguments are passed to tiler.get_tiles_timepoint + Any additional keyword arguments are passed to + tiler.get_tiles_timepoint Parameters ---------- @@ -243,8 +254,9 @@ class Extractor(StepABC): # a list of the indices of the z stacks channel_ids = None if z is None: - # gets the tiles data via tiler + # include all Z channels z = list(range(self.tiler.shape[-3])) + # get the image data via tiler res = ( self.tiler.get_tiles_timepoint( tp, channels=channel_ids, z=z, **kwargs @@ -252,7 +264,7 @@ class Extractor(StepABC): if channel_ids else None ) - # data arranged as (tiles, channels, time points, X, Y, Z) + # res has dimensions (tiles, channels, 1, Z, X, Y) return res def extract_traps( @@ -260,7 +272,7 @@ class Extractor(StepABC): traps: t.List[np.ndarray], masks: t.List[np.ndarray], metric: str, - labels: t.Dict[int, t.List[int]], + cell_labels: t.Dict[int, t.List[int]], ) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]: """ Apply a function to a whole position. @@ -273,23 +285,25 @@ class Extractor(StepABC): t.List of masks. metric: str Metric to extract. - labels: dict - A dict of cell labels with trap_ids as keys and a list of cell labels as values. + cell_labels: dict + A dict of cell labels with trap_ids as keys and a list + of cell labels as values. pos_info: bool Whether to add the position as an index or not. Returns ------- res_idx: a tuple of tuples - A two-tuple comprising a tuple of results and a tuple of the tile_id and cell labels + A two-tuple comprising a tuple of results and a tuple of + the tile_id and cell labels """ - if labels is None: - self._log("No labels given. Sorting cells using index.") + if cell_labels is None: + self._log("No cell labels given. Sorting cells using index.") cell_fun = True if metric in self._all_cell_funs else False idx = [] results = [] for trap_id, (mask_set, trap, lbl_set) in enumerate( - zip(masks, traps, labels.values()) + zip(masks, traps, cell_labels.values()) ): # ignore empty traps if len(mask_set): @@ -344,7 +358,9 @@ class Extractor(StepABC): masks: list of arrays An array of masks for each trap: one per cell at the trap red_metrics: dict - dict for which keys are reduction functions and values are either a list or a set of strings giving the metric functions. + dict for which keys are reduction functions and values are + either a list or a set of strings giving the metric + functions. For example: {'np_max': {'max5px', 'mean', 'median'}} **kwargs: dict All other arguments passed to Extractor.extract_funs. @@ -353,7 +369,8 @@ class Extractor(StepABC): ------ Dict of dataframes with the corresponding reductions and metrics nested. """ - # create dict with keys naming the reduction in the z-direction and the reduced data as values + # create dict with keys naming the reduction in the z-direction + # and the reduced data as values reduced_tiles_data = {} if traps is not None: for red_fun in red_metrics.keys(): @@ -392,64 +409,24 @@ class Extractor(StepABC): reduced = reduce_z(img, method) return reduced - def extract_tp( - self, - tp: int, - tree: t.Optional[extraction_tree] = None, - tile_size: int = 117, - masks: t.Optional[t.List[np.ndarray]] = None, - labels: t.Optional[t.List[int]] = None, - **kwargs, - ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]: - """ - Extract for an individual time point. - - Parameters - ---------- - tp : int - Time point being analysed. - tree : dict - Nested dictionary indicating channels, reduction functions and - metrics to be used. - For example: {'general': {'None': ['area', 'volume', 'eccentricity']}} - tile_size : int - Size of the tile to be extracted. - masks : list of arrays - A list of masks per trap with each mask having dimensions (ncells, tile_size, - tile_size). - labels : dict - A dictionary with trap_ids as keys and cell_labels as values. - **kwargs : keyword arguments - Passed to extractor.reduce_extract. - - Returns - ------- - d: dict - Dictionary of the results with three levels of dictionaries. - The first level has channels as keys. - The second level has reduction metrics as keys. - The third level has cell or background metrics as keys and a two-tuple as values. - The first tuple is the result of applying the metrics to a particular cell or trap; the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap. - - An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells. - """ - # TODO Can we split the different extraction types into sub-methods to make this easier to read? + def make_tree_bits(self, tree): + """Put extraction tree and information for the channels into a dict.""" if tree is None: # use default tree: extraction_tree = self.params.tree - # dictionary with channel: {reduction algorithm : metric} - ch_tree = {ch: v for ch, v in tree.items() if ch != "general"} - # tuple of the channels - tree_chs = (*ch_tree,) - # create a Cells object to extract information from the h5 file - cells = Cells(self.local) - # find the cell labels and store as dict with trap_ids as keys - if labels is None: - raw_labels = cells.labels_at_time(tp) - labels = { - trap_id: raw_labels.get(trap_id, []) - for trap_id in range(cells.ntraps) - } + tree_bits = { + "tree": tree, + # dictionary with channel: {reduction algorithm : metric} + "channel_tree": { + ch: v for ch, v in tree.items() if ch != "general" + }, + } + # tuple of the fluorescence channels + tree_bits["tree_channels"] = (*tree_bits["channel_tree"],) + return tree_bits + + def get_masks(self, tp, masks, cells): + """Get the masks as a list with an array of masks for each trap.""" # find the cell masks for a given trap as a dict with trap_ids as keys if masks is None: raw_masks = cells.at_time(tp, kind="mask") @@ -458,16 +435,31 @@ class Extractor(StepABC): if len(cells): masks[trap_id] = np.stack(np.array(cells)).astype(bool) # convert to a list of masks + # one array of size (no cells, tile_size, tile_size) per trap masks = [np.array(v) for v in masks.values()] - # find image data at the time point - # stored as an array arranged as (traps, channels, time points, X, Y, Z) - tiles = self.get_tiles(tp, tile_shape=tile_size, channels=tree_chs) - # generate boolean masks for background as a list with one mask per trap - bgs = np.array([]) + return masks + + def get_cell_labels(self, tp, cell_labels, cells): + """Get the cell labels per trap as a dict with trap_ids as keys.""" + if cell_labels is None: + raw_cell_labels = cells.labels_at_time(tp) + cell_labels = { + trap_id: raw_cell_labels.get(trap_id, []) + for trap_id in range(cells.ntraps) + } + return cell_labels + + def get_background_masks(self, masks, tile_size): + """ + Generate boolean background masks. + + Combine masks per trap and then take the logical inverse. + """ if self.params.sub_bg: bgs = ~np.array( list( map( + # sum over masks for each cell lambda x: np.sum(x, axis=0) if np.any(x) else np.zeros((tile_size, tile_size)), @@ -475,15 +467,29 @@ class Extractor(StepABC): ) ) ).astype(bool) - # perform extraction by applying metrics + else: + bgs = np.array([]) + return bgs + + def extract_one_channel( + self, tree_bits, cell_labels, tiles, masks, bgs, **kwargs + ): + """ + Extract using all metrics requiring a single channel. + + Apply first without and then with background subtraction. + + Return the extraction results and a dict of background corrected images. + """ d = {} - self.img_bgsub = {} - for ch, red_metrics in tree.items(): + img_bgsub = {} + for ch, red_metrics in tree_bits["tree"].items(): # NB ch != is necessary for threading if ch != "general" and tiles is not None and len(tiles): - # image data for all traps and z sections for a particular channel - # as an array arranged as (tiles, Z, X, Y, ) - img = tiles[:, tree_chs.index(ch), 0] + # image data for all traps for a particular channel and time point + # arranged as (traps, Z, X, Y) + # we use 0 here to access the single time point available + img = tiles[:, tree_bits["tree_channels"].index(ch), 0] else: img = None # apply metrics to image data @@ -491,37 +497,43 @@ class Extractor(StepABC): traps=img, masks=masks, red_metrics=red_metrics, - labels=labels, + cell_labels=cell_labels, **kwargs, ) # apply metrics to image data with the background subtracted if bgs.any() and ch in self.params.sub_bg and img is not None: - # calculate metrics with subtracted bg + # calculate metrics with background subtracted ch_bs = ch + "_bgsub" # subtract median background - self.img_bgsub[ch_bs] = np.moveaxis( - np.stack( - list( - map( - lambda tile, mask: np.moveaxis(tile, 0, -1) - - bn.median(tile[:, mask], axis=1), - img, - bgs, - ) - ) - ), - -1, - 1, - ) # End with tiles, z, y, x + bgsub_mapping = map( + # move time to last column to allow subtraction + lambda img, bgs: np.moveaxis(img, 0, -1) + # median of background over all pixels for each time point + - bn.median(img[:, bgs], axis=1), + img, + bgs, + ) + # apply map and convert to array + mapping_result = np.stack(list(bgsub_mapping)) + # move time axis back to the second column + img_bgsub[ch_bs] = np.moveaxis(mapping_result, -1, 1) # apply metrics to background-corrected data d[ch_bs] = self.reduce_extract( - red_metrics=ch_tree[ch], - traps=self.img_bgsub[ch_bs], + red_metrics=tree_bits["channel_tree"][ch], + traps=img_bgsub[ch_bs], masks=masks, - labels=labels, + cell_labels=cell_labels, **kwargs, ) - # apply any metrics using multiple channels, such as pH calculations + return d, img_bgsub + + def extract_multiple_channels( + self, tree_bits, cell_labels, tiles, masks, **kwargs + ): + """ + Extract using all metrics requiring multiple channels. + """ + d = {} for name, ( chs, merge_fun, @@ -529,22 +541,99 @@ class Extractor(StepABC): ) in self.params.multichannel_ops.items(): if len( set(chs).intersection( - set(self.img_bgsub.keys()).union(tree_chs) + set(self.img_bgsub.keys()).union( + tree_bits["tree_channels"] + ) ) ) == len(chs): channels_stack = np.stack( - [self.get_imgs(ch, tiles, tree_chs) for ch in chs], axis=-1 + [ + self.get_imgs(ch, tiles, tree_bits["tree_channels"]) + for ch in chs + ], + axis=-1, ) merged = RED_FUNS[merge_fun](channels_stack, axis=-1) d[name] = self.reduce_extract( red_metrics=red_metrics, traps=merged, masks=masks, - labels=labels, + cell_labels=cell_labels, **kwargs, ) return d + def extract_tp( + self, + tp: int, + tree: t.Optional[extraction_tree] = None, + tile_size: int = 117, + masks: t.Optional[t.List[np.ndarray]] = None, + cell_labels: t.Optional[t.List[int]] = None, + **kwargs, + ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]: + """ + Extract for an individual time point. + + Parameters + ---------- + tp : int + Time point being analysed. + tree : dict + Nested dictionary indicating channels, reduction functions + and metrics to be used. + For example: {'general': {'None': ['area', 'volume', 'eccentricity']}} + tile_size : int + Size of the tile to be extracted. + masks : list of arrays + A list of masks per trap with each mask having dimensions + (ncells, tile_size, tile_size) and with one mask per cell. + cell_labels : dict + A dictionary with trap_ids as keys and cell_labels as values. + **kwargs : keyword arguments + Passed to extractor.reduce_extract. + + Returns + ------- + d: dict + Dictionary of the results with three levels of dictionaries. + The first level has channels as keys. + The second level has reduction metrics as keys. + The third level has cell or background metrics as keys and a + two-tuple as values. + The first tuple is the result of applying the metrics to a + particular cell or trap; the second tuple is either + (trap_id, cell_label) for a metric applied to a cell or a + trap_id for a metric applied to a trap. + + An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple + of the calculated mean GFP fluorescence for all cells. + """ + # dict of information from extraction tree + tree_bits = self.make_tree_bits(tree) + # create a Cells object to extract information from the h5 file + cells = Cells(self.local) + # find the cell labels as dict with trap_ids as keys + cell_labels = self.get_cell_labels(tp, cell_labels, cells) + # get masks one per cell per trap + masks = self.get_masks(tp, masks, cells) + # find image data at the time point + # stored as an array arranged as (traps, channels, 1, Z, X, Y) + tiles = self.get_tiles( + tp, tile_shape=tile_size, channels=tree_bits["tree_channels"] + ) + # generate boolean masks for background for each trap + bgs = self.get_background_masks(masks, tile_size) + # perform extraction + res_one, self.img_bgsub = self.extract_one_channel( + tree_bits, cell_labels, tiles, masks, bgs, **kwargs + ) + res_two = self.extract_multiple_channels( + tree_bits, cell_labels, tiles, masks, **kwargs + ) + res = {**res_one, **res_two} + return res + def get_imgs(self, channel: t.Optional[str], tiles, channels=None): """ Return image from a correct source, either raw or bgsub. @@ -554,14 +643,16 @@ class Extractor(StepABC): channel: str Name of channel to get. tiles: ndarray - An array of the image data having dimensions of (tile_id, channel, tp, tile_size, tile_size, n_zstacks). + An array of the image data having dimensions of + (tile_id, channel, tp, tile_size, tile_size, n_zstacks). channels: list of str (optional) t.List of available channels. Returns ------- img: ndarray - An array of image data with dimensions (no tiles, X, Y, no Z channels) + An array of image data with dimensions + (no tiles, X, Y, no Z channels) """ if channels is None: channels = (*self.params.tree,) @@ -598,7 +689,9 @@ class Extractor(StepABC): Returns ------- d: dict - A dict of the extracted data for one position with a concatenated string of channel, reduction metric, and cell metric as keys and pd.DataFrame of the extracted data for all time points as values. + A dict of the extracted data for one position with a concatenated + string of channel, reduction metric, and cell metric as keys and + pd.DataFrame of the extracted data for all time points as values. """ if tree is None: tree = self.params.tree @@ -673,14 +766,17 @@ def flatten_nesteddict( nest: dict of dicts Contains the nested results of extraction. to: str (optional) - Specifies the format of the output, either pd.Series (default) or a list + Specifies the format of the output, either pd.Series (default) + or a list tp: int Time point used to name the pd.Series Returns ------- d: dict - A dict with a concatenated string of channel, reduction metric, and cell metric as keys and either a pd.Series or a list of the corresponding extracted data as values. + A dict with a concatenated string of channel, reduction metric, + and cell metric as keys and either a pd.Series or a list of the + corresponding extracted data as values. """ d = {} for k0, v0 in nest.items(): @@ -690,14 +786,3 @@ def flatten_nesteddict( pd.Series(*v2, name=tp) if to == "series" else v2 ) return d - - -class hollowExtractor(Extractor): - """ - Extractor that only cares about receiving images and masks. - - Used for testing. - """ - - def __init__(self, parameters): - self.params = parameters diff --git a/src/extraction/core/functions/distributors.py b/src/extraction/core/functions/distributors.py index e9b5265f55373af6acd409d4a018d9b6341dbd7b..90838e61b21ef505f87df784fbe82da6781efce1 100644 --- a/src/extraction/core/functions/distributors.py +++ b/src/extraction/core/functions/distributors.py @@ -44,5 +44,6 @@ def reduce_z(trap_image: np.ndarray, fun: t.Callable, axis: int = 0): elif isinstance(fun, np.ufunc): # optimise the reduction function if possible return fun.reduce(trap_image, axis=axis) - else: # WARNING: Very slow, only use when no alternatives exist + else: + # WARNING: Very slow, only use when no alternatives exist return np.apply_along_axis(fun, axis, trap_image)