From 4fc31051cbd791378bcbbdd941173b3f93ff4a04 Mon Sep 17 00:00:00 2001 From: Peter Swain <peter.swain@ed.ac.uk> Date: Sun, 29 Oct 2023 16:09:44 +0000 Subject: [PATCH] extended bgsub to multichannel functions --- src/extraction/core/extractor.py | 159 ++++++++++++-------------- src/extraction/core/functions/cell.py | 9 +- 2 files changed, 80 insertions(+), 88 deletions(-) diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py index ba1d491..44677b3 100644 --- a/src/extraction/core/extractor.py +++ b/src/extraction/core/extractor.py @@ -306,15 +306,15 @@ class Extractor(StepABC): # tiles has dimensions (tiles, channels, 1, Z, X, Y) return tiles - def extract_traps( + def apply_cell_function( self, traps: t.List[np.ndarray], masks: t.List[np.ndarray], - cell_property: str, + cell_function: str, cell_labels: t.Dict[int, t.List[int]], ) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]: """ - Apply a function to a whole position. + Apply a cell function to all cells at all traps for one time point. Parameters ---------- @@ -322,11 +322,11 @@ class Extractor(StepABC): t.List of images. masks: list of arrays t.List of masks. - cell_property: str - Property to extract, including imBackground. + cell_function: str + Function to apply. cell_labels: dict - A dict of cell labels with trap_ids as keys and a list - of cell labels as values. + A dict with trap_ids as keys and a list of cell labels as + values. Returns ------- @@ -336,7 +336,7 @@ class Extractor(StepABC): """ if cell_labels is None: self._log("No cell labels given. Sorting cells using index.") - cell_fun = True if cell_property in self.all_cell_funs else False + cell_fun = True if cell_function in self.all_cell_funs else False idx = [] results = [] for trap_id, (mask_set, trap, local_cell_labels) in enumerate( @@ -345,7 +345,7 @@ class Extractor(StepABC): # ignore empty traps if len(mask_set): # find property from the tile - result = self.all_funs[cell_property](mask_set, trap) + result = self.all_funs[cell_function](mask_set, trap) if cell_fun: # store results for each cell separately for cell_label, val in zip(local_cell_labels, result): @@ -371,8 +371,8 @@ class Extractor(StepABC): Data from one time point is used. """ d = { - cell_fun: self.extract_traps( - traps=tiles, masks=masks, cell_property=cell_fun, **kwargs + cell_fun: self.apply_cell_function( + traps=tiles, masks=masks, cell_function=cell_fun, **kwargs ) for cell_fun in cell_funs } @@ -513,88 +513,66 @@ class Extractor(StepABC): return bgs def extract_one_channel( - self, tree_dict, cell_labels, tiles, masks, bgs, **kwargs + self, tree_dict, cell_labels, img, img_bgsub, masks, **kwargs ): - """ - Extract all metrics requiring only a single channel. - - Apply first without and then with background subtraction. - - Return the extraction results and a dict of background - corrected images. - """ + """Extract as dict all metrics requiring only a single channel.""" d = {} for ch, reduction_cell_funs in tree_dict["tree"].items(): - # NB ch != is necessary for threading - if ch != "general" and tiles is not None and len(tiles): - # image data for all traps for a particular channel and - # time point arranged as (traps, Z, X, Y) - # we use 0 here to access the single time point available - channel_tile = tiles[:, tree_dict["channels"].index(ch), 0] - else: - # no reduction applied to "general" - bright-field images - channel_tile = None - # apply metrics to image data + # extract from all images including bright field d[ch] = self.reduce_extract( - tiles=channel_tile, + # use None for "general"; no fluorescence image + tiles=img.get(ch, None), masks=masks, reduction_cell_funs=reduction_cell_funs, cell_labels=cell_labels, **kwargs, ) - # apply metrics to image data with the background subtracted - if ( - bgs.any() - and ch in self.params.sub_bg - and channel_tile is not None - ): - # apply metrics to background-corrected data + if ch != "general": + # extract from background-corrected fluorescence images d[ch + "_bgsub"] = self.reduce_extract( - tiles=self.img_bgsub[ch + "_bgsub"], + tiles=img_bgsub[ch + "_bgsub"], masks=masks, - reduction_cell_funs=tree_dict["channels_tree"][ch], + reduction_cell_funs=reduction_cell_funs, cell_labels=cell_labels, **kwargs, ) return d - def extract_multiple_channels( - self, tree_dict, cell_labels, tiles, masks, **kwargs - ): - """ - Extract all metrics requiring multiple channels. - """ - # channels and background corrected channels - available_chs = set(self.img_bgsub.keys()).union(tree_dict["channels"]) + def extract_multiple_channels(self, cell_labels, img, img_bgsub, masks): + """Extract as a dict all metrics requiring multiple channels.""" + # NB multichannel functions do not use tree_dict + available_channels = set(list(img.keys()) + list(img_bgsub.keys())) d = {} - for name, ( - chs, - reduction_fun, - op, + for multichannel_fun_name, ( + channels, + reduction, + multichannel_function, ) in self.params.multichannel_ops.items(): - common_chs = set(chs).intersection(available_chs) + common_channels = set(channels).intersection(available_channels) # all required channels should be available - if len(common_chs) == len(chs): - channels_stack = np.stack( - [ - self.get_imgs(ch, tiles, tree_dict["channels"]) - for ch in chs - ], - axis=-1, - ) - # reduce in Z - traps = REDUCTION_FUNS[reduction_fun](channels_stack, axis=1) - # evaluate multichannel op - if name not in d: - d[name] = {} - if reduction_fun not in d[name]: - d[name][reduction_fun] = {} - d[name][reduction_fun][op] = self.extract_traps( - traps, - masks, - op, - cell_labels, - ) + if len(common_channels) == len(channels): + for images, suffix in zip([img, img_bgsub], ["", "_bgsub"]): + # channels + channels_stack = np.stack( + [images[ch + suffix] for ch in channels], + axis=-1, + ) + # reduce in Z + tiles = REDUCTION_FUNS[reduction](channels_stack, axis=1) + # set up dict + if multichannel_fun_name not in d: + d[multichannel_fun_name] = {} + if reduction not in d[multichannel_fun_name]: + d[multichannel_fun_name][reduction] = {} + # apply multichannel function + d[multichannel_fun_name][reduction][ + multichannel_function + suffix + ] = self.apply_cell_function( + tiles, + masks, + multichannel_function, + cell_labels, + ) return d def extract_tp( @@ -656,33 +634,41 @@ class Extractor(StepABC): tiles = self.get_tiles(tp, channels=tree_dict["channels"]) # generate boolean masks for background for each trap bgs = self.get_background_masks(masks, tile_size) - # perform background subtraction for all traps at this time point - self.img_bgsub = self.perform_background_subtraction( + # get images and background corrected images as dicts + # with fluorescnce channels as keys + img, img_bgsub = self.get_imgs_background_subtract( tree_dict, tiles, bgs ) # perform extraction res_one = self.extract_one_channel( - tree_dict, cell_labels, tiles, masks, bgs, **kwargs + tree_dict, cell_labels, img, img_bgsub, masks, **kwargs ) res_multiple = self.extract_multiple_channels( - tree_dict, cell_labels, tiles, masks, **kwargs + cell_labels, img, img_bgsub, masks ) res = {**res_one, **res_multiple} return res - def perform_background_subtraction(self, tree_dict, tiles, bgs): - """Subtract background for fluorescence channels.""" + def get_imgs_background_subtract(self, tree_dict, tiles, bgs): + """ + Get two dicts of fluorescence images. + + Return images and background subtracted image for all traps + for one time point. + """ + img = {} img_bgsub = {} for ch, _ in tree_dict["channels_tree"].items(): + # NB ch != is necessary for threading if tiles is not None and len(tiles): # image data for all traps for a particular channel and # time point arranged as (traps, Z, X, Y) # we use 0 here to access the single time point available - channel_tile = tiles[:, tree_dict["channels"].index(ch), 0] + img[ch] = tiles[:, tree_dict["channels"].index(ch), 0] if ( bgs.any() and ch in self.params.sub_bg - and channel_tile is not None + and img[ch] is not None ): # subtract median background bgsub_mapping = map( @@ -690,7 +676,7 @@ class Extractor(StepABC): lambda img, bgs: np.moveaxis(img, 0, -1) # median of background over all pixels for each Z section - bn.median(img[:, bgs], axis=1), - channel_tile, + img[ch], bgs, ) # apply map and convert to array @@ -699,9 +685,12 @@ class Extractor(StepABC): img_bgsub[ch + "_bgsub"] = np.moveaxis( mapping_result, -1, 1 ) - return img_bgsub + else: + img[ch] = None + img_bgsub[ch] = None + return img, img_bgsub - def get_imgs(self, channel: t.Optional[str], tiles, channels=None): + def get_imgs_old(self, channel: t.Optional[str], tiles, channels=None): """ Return image from a correct source, either raw or bgsub. diff --git a/src/extraction/core/functions/cell.py b/src/extraction/core/functions/cell.py index 21f9545..5de868e 100644 --- a/src/extraction/core/functions/cell.py +++ b/src/extraction/core/functions/cell.py @@ -239,9 +239,12 @@ def moment_of_inertia(cell_mask, trap_image): def ratio(cell_mask, trap_image): """Find the median ratio between two fluorescence channels.""" if trap_image.ndim == 3 and trap_image.shape[-1] == 2: - fl_1 = trap_image[..., 0][cell_mask] - fl_2 = trap_image[..., 1][cell_mask] - div = np.median(fl_1 / fl_2) + fl_0 = trap_image[..., 0][cell_mask] + fl_1 = trap_image[..., 1][cell_mask] + if np.any(fl_1 == 0): + div = np.nan + else: + div = np.median(fl_0 / fl_1) else: div = np.nan return div -- GitLab