diff --git a/src/agora/abc.py b/src/agora/abc.py index 0aaefdfd8a4d315079eb1a952715e0f327667834..0a0407b9687fc070f9db8ee05801bd422bddf19a 100644 --- a/src/agora/abc.py +++ b/src/agora/abc.py @@ -118,7 +118,7 @@ class ParametersABC(ABC): assert name not in ( "parameters", "params", - ), "Attribute can't be named params or parameters" + ), "Attribute cannot be named params or parameters." if name in self.__dict__: if check_type_recursive(getattr(self, name), new_value): diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py index 61fb89a906c5b42260ac18e78c41b1447b06eea0..ba1d491e86dc20edea9e42ac39cbcdc5139f77a5 100644 --- a/src/extraction/core/extractor.py +++ b/src/extraction/core/extractor.py @@ -33,7 +33,38 @@ extraction_result = t.Dict[ # defined simply in extraction/core/functions. CELL_FUNS, TRAP_FUNS, ALL_FUNS = load_funs() CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args() -RED_FUNS = load_redfuns() +REDUCTION_FUNS = load_redfuns() + + +def extraction_params_from_meta( + meta: t.Union[dict, Path, str], extras: t.Collection[str] = ["ph"] +): + """Obtain parameters for extraction from meta data.""" + if not isinstance(meta, dict): + # load meta data + with h5py.File(meta, "r") as f: + meta = dict(f["/"].attrs.items()) + base = { + "tree": {"general": {"None": ["area", "volume", "eccentricity"]}}, + "multichannel_ops": {}, + } + candidate_channels = set(global_parameters.possible_imaging_channels) + default_reductions = {"max"} + default_metrics = set(global_parameters.fluorescence_functions) + default_reduction_metrics = { + r: default_metrics for r in default_reductions + } + # default_rm["None"] = ["nuc_conv_3d"] # Uncomment this to add nuc_conv_3d (slow) + extant_fluorescence_ch = [] + for av_channel in candidate_channels: + # find matching channels in metadata + found_channel = find_channel_name(meta.get("channels", []), av_channel) + if found_channel is not None: + extant_fluorescence_ch.append(found_channel) + for ch in extant_fluorescence_ch: + base["tree"][ch] = default_reduction_metrics + base["sub_bg"] = extant_fluorescence_ch + return base class ExtractorParameters(ParametersABC): @@ -68,7 +99,7 @@ class ExtractorParameters(ParametersABC): @classmethod def from_meta(cls, meta): - """Instantiate using meta data.""" + """Instantiate from the meta data; used by Pipeline.""" return cls(**extraction_params_from_meta(meta)) @@ -115,8 +146,8 @@ class Extractor(StepABC): """ self.params = parameters if store: - self.local = store - self.meta = load_meta(self.local) + self.h5path = store + self.meta = load_meta(self.h5path) else: # if no h5 file, use the parameters directly self.meta = {"channel": parameters.to_dict()["tree"].keys()} @@ -172,18 +203,26 @@ class Extractor(StepABC): @property def current_position(self): - return str(self.local).split("/")[-1][:-3] + """Return position being analysed.""" + return str(self.h5path).split("/")[-1][:-3] @property def group(self): - """Return path within the h5 file.""" + """Return out path to write in the h5 file.""" if not hasattr(self, "_out_path"): self._group = "/extraction/" return self._group + def load_funs(self): + """Define all functions, including custom ones.""" + self.load_custom_funs() + self.all_cell_funs = set(self.custom_funs.keys()).union(CELL_FUNS) + # merge the two dicts + self.all_funs = {**self.custom_funs, **ALL_FUNS} + def load_custom_funs(self): """ - Incorporate the extra arguments of custom functions into their definitions. + Incorporate extra arguments of custom functions into their definitions. Normal functions only have cell_masks and trap_image as their arguments, and here custom functions are made the same by @@ -197,20 +236,20 @@ class Extractor(StepABC): funs = set( [ fun - for ch in self.params.tree.values() - for red in ch.values() - for fun in red + for channel in self.params.tree.values() + for reduction in channel.values() + for fun in reduction ] ) # consider only those already loaded from CUSTOM_FUNS funs = funs.intersection(CUSTOM_FUNS.keys()) # find their arguments - self._custom_arg_vals = { + self.custom_arg_vals = { k: {k2: self.get_meta(k2) for k2 in v} for k, v in CUSTOM_ARGS.items() } # define custom functions - self._custom_funs = {} + self.custom_funs = {} for k, f in CUSTOM_FUNS.items(): def tmp(f): @@ -220,17 +259,10 @@ class Extractor(StepABC): f, cell_masks, trap_image, - **self._custom_arg_vals.get(k, {}), + **self.custom_arg_vals.get(k, {}), ) - self._custom_funs[k] = tmp(f) - - def load_funs(self): - """Define all functions, including custom ones.""" - self.load_custom_funs() - self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS) - # merge the two dicts - self._all_funs = {**self._custom_funs, **ALL_FUNS} + self.custom_funs[k] = tmp(f) def get_tiles( self, @@ -266,13 +298,13 @@ class Extractor(StepABC): # include all Z channels z = list(range(self.tiler.shape[-3])) # get the image data via tiler - res = ( + tiles = ( self.tiler.get_tiles_timepoint(tp, channels=channel_ids, z=z) if channel_ids else None ) - # res has dimensions (tiles, channels, 1, Z, X, Y) - return res + # tiles has dimensions (tiles, channels, 1, Z, X, Y) + return tiles def extract_traps( self, @@ -304,7 +336,7 @@ class Extractor(StepABC): """ if cell_labels is None: self._log("No cell labels given. Sorting cells using index.") - cell_fun = True if cell_property in self._all_cell_funs else False + cell_fun = True if cell_property in self.all_cell_funs else False idx = [] results = [] for trap_id, (mask_set, trap, local_cell_labels) in enumerate( @@ -313,7 +345,7 @@ class Extractor(StepABC): # ignore empty traps if len(mask_set): # find property from the tile - result = self._all_funs[cell_property](mask_set, trap) + result = self.all_funs[cell_property](mask_set, trap) if cell_fun: # store results for each cell separately for cell_label, val in zip(local_cell_labels, result): @@ -326,31 +358,31 @@ class Extractor(StepABC): res_idx = (tuple(results), tuple(idx)) return res_idx - def extract_funs( + def apply_cell_funs( self, - traps: t.List[np.array], + tiles: t.List[np.array], masks: t.List[np.array], - cell_properties: t.List[str], + cell_funs: t.List[str], **kwargs, ) -> t.Dict[str, pd.Series]: """ - Return dict with metrics as key and cell_properties as values. + Return dict with cell_funs as keys and the corresponding results as values. Data from one time point is used. """ d = { - cell_property: self.extract_traps( - traps=traps, masks=masks, cell_property=cell_property, **kwargs + cell_fun: self.extract_traps( + traps=tiles, masks=masks, cell_property=cell_fun, **kwargs ) - for cell_property in cell_properties + for cell_fun in cell_funs } return d def reduce_extract( self, - traps: np.ndarray, + tiles: np.ndarray, masks: t.List[np.ndarray], - tree_branch: t.Dict[reduction_method, t.Collection[str]], + reduction_cell_funs: t.Dict[reduction_method, t.Collection[str]], **kwargs, ) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]: """ @@ -358,17 +390,17 @@ class Extractor(StepABC): Parameters ---------- - tiles_data: array + tiles: array An array of image data arranged as (tiles, X, Y, Z) masks: list of arrays An array of masks for each trap: one per cell at the trap - tree_branch: dict + reduction_cell_funs: dict An upper branch of the extraction tree: a dict for which keys are reduction functions and values are either a list - or a set of strings giving the cell properties to be found. + or a set of strings giving the cell functions to apply. For example: {'np_max': {'max5px', 'mean', 'median'}} **kwargs: dict - All other arguments passed to Extractor.extract_funs. + All other arguments passed to Extractor.apply_cell_funs. Returns ------ @@ -377,21 +409,23 @@ class Extractor(StepABC): # create dict with keys naming the reduction in the z-direction # and the reduced data as values reduced_tiles = {} - if traps is not None: - for red_fun in tree_branch.keys(): - reduced_tiles[red_fun] = [ - self.reduce_dims(tile_data, method=RED_FUNS[red_fun]) - for tile_data in traps + if tiles is not None: + for reduction in reduction_cell_funs.keys(): + reduced_tiles[reduction] = [ + self.reduce_dims( + tile_data, method=REDUCTION_FUNS[reduction] + ) + for tile_data in tiles ] # calculate cell and tile properties d = { - red_fun: self.extract_funs( - cell_properties=cell_properties, - traps=reduced_tiles.get(red_fun, [None for _ in masks]), + reduction: self.apply_cell_funs( + tiles=reduced_tiles.get(reduction, [None for _ in masks]), masks=masks, + cell_funs=cell_funs, **kwargs, ) - for red_fun, cell_properties in tree_branch.items() + for reduction, cell_funs in reduction_cell_funs.items() } return d @@ -415,21 +449,22 @@ class Extractor(StepABC): reduced = reduce_z(img, method) return reduced - def make_tree_bits(self, tree): - """Put extraction tree and information for the channels into a dict.""" + def make_tree_dict(self, tree: extraction_tree): + """Put extraction tree into a dict.""" if tree is None: # use default - tree: extraction_tree = self.params.tree - tree_bits = { + tree = self.params.tree + tree_dict = { + # the whole extraction tree "tree": tree, - # dictionary with channel: {reduction algorithm : metric} - "channel_tree": { + # the extraction tree for fluorescence channels + "channels_tree": { ch: v for ch, v in tree.items() if ch != "general" }, } # tuple of the fluorescence channels - tree_bits["tree_channels"] = (*tree_bits["channel_tree"],) - return tree_bits + tree_dict["channels"] = (*tree_dict["channels_tree"],) + return tree_dict def get_masks(self, tp, masks, cells): """Get the masks as a list with an array of masks for each trap.""" @@ -478,7 +513,7 @@ class Extractor(StepABC): return bgs def extract_one_channel( - self, tree_bits, cell_labels, tiles, masks, bgs, **kwargs + self, tree_dict, cell_labels, tiles, masks, bgs, **kwargs ): """ Extract all metrics requiring only a single channel. @@ -489,62 +524,48 @@ class Extractor(StepABC): corrected images. """ d = {} - img_bgsub = {} - for ch, tree_branch in tree_bits["tree"].items(): + for ch, reduction_cell_funs in tree_dict["tree"].items(): # NB ch != is necessary for threading if ch != "general" and tiles is not None and len(tiles): # image data for all traps for a particular channel and # time point arranged as (traps, Z, X, Y) # we use 0 here to access the single time point available - img = tiles[:, tree_bits["tree_channels"].index(ch), 0] + channel_tile = tiles[:, tree_dict["channels"].index(ch), 0] else: - # no reduction applied to bright-field images - img = None + # no reduction applied to "general" - bright-field images + channel_tile = None # apply metrics to image data d[ch] = self.reduce_extract( - traps=img, + tiles=channel_tile, masks=masks, - tree_branch=tree_branch, + reduction_cell_funs=reduction_cell_funs, cell_labels=cell_labels, **kwargs, ) # apply metrics to image data with the background subtracted - if bgs.any() and ch in self.params.sub_bg and img is not None: - # calculate metrics with background subtracted - ch_bs = ch + "_bgsub" - # subtract median background - bgsub_mapping = map( - # move Z to last column to allow subtraction - lambda img, bgs: np.moveaxis(img, 0, -1) - # median of background over all pixels for each Z section - - bn.median(img[:, bgs], axis=1), - img, - bgs, - ) - # apply map and convert to array - mapping_result = np.stack(list(bgsub_mapping)) - # move Z axis back to the second column - img_bgsub[ch_bs] = np.moveaxis(mapping_result, -1, 1) + if ( + bgs.any() + and ch in self.params.sub_bg + and channel_tile is not None + ): # apply metrics to background-corrected data - d[ch_bs] = self.reduce_extract( - traps=img_bgsub[ch_bs], + d[ch + "_bgsub"] = self.reduce_extract( + tiles=self.img_bgsub[ch + "_bgsub"], masks=masks, - tree_branch=tree_bits["channel_tree"][ch], + reduction_cell_funs=tree_dict["channels_tree"][ch], cell_labels=cell_labels, **kwargs, ) - return d, img_bgsub + return d def extract_multiple_channels( - self, tree_bits, cell_labels, tiles, masks, **kwargs + self, tree_dict, cell_labels, tiles, masks, **kwargs ): """ Extract all metrics requiring multiple channels. """ # channels and background corrected channels - available_chs = set(self.img_bgsub.keys()).union( - tree_bits["tree_channels"] - ) + available_chs = set(self.img_bgsub.keys()).union(tree_dict["channels"]) d = {} for name, ( chs, @@ -556,13 +577,13 @@ class Extractor(StepABC): if len(common_chs) == len(chs): channels_stack = np.stack( [ - self.get_imgs(ch, tiles, tree_bits["tree_channels"]) + self.get_imgs(ch, tiles, tree_dict["channels"]) for ch in chs ], axis=-1, ) # reduce in Z - traps = RED_FUNS[reduction_fun](channels_stack, axis=1) + traps = REDUCTION_FUNS[reduction_fun](channels_stack, axis=1) # evaluate multichannel op if name not in d: d[name] = {} @@ -623,28 +644,63 @@ class Extractor(StepABC): of the calculated mean GFP fluorescence for all cells. """ # dict of information from extraction tree - tree_bits = self.make_tree_bits(tree) + tree_dict = self.make_tree_dict(tree) # create a Cells object to extract information from the h5 file - cells = Cells(self.local) + cells = Cells(self.h5path) # find the cell labels as dict with trap_ids as keys cell_labels = self.get_cell_labels(tp, cell_labels, cells) # get masks one per cell per trap masks = self.get_masks(tp, masks, cells) - # find image data at the time point + # find image data for all traps at the time point # stored as an array arranged as (traps, channels, 1, Z, X, Y) - tiles = self.get_tiles(tp, channels=tree_bits["tree_channels"]) + tiles = self.get_tiles(tp, channels=tree_dict["channels"]) # generate boolean masks for background for each trap bgs = self.get_background_masks(masks, tile_size) + # perform background subtraction for all traps at this time point + self.img_bgsub = self.perform_background_subtraction( + tree_dict, tiles, bgs + ) # perform extraction - res_one, self.img_bgsub = self.extract_one_channel( - tree_bits, cell_labels, tiles, masks, bgs, **kwargs + res_one = self.extract_one_channel( + tree_dict, cell_labels, tiles, masks, bgs, **kwargs ) res_multiple = self.extract_multiple_channels( - tree_bits, cell_labels, tiles, masks, **kwargs + tree_dict, cell_labels, tiles, masks, **kwargs ) res = {**res_one, **res_multiple} return res + def perform_background_subtraction(self, tree_dict, tiles, bgs): + """Subtract background for fluorescence channels.""" + img_bgsub = {} + for ch, _ in tree_dict["channels_tree"].items(): + if tiles is not None and len(tiles): + # image data for all traps for a particular channel and + # time point arranged as (traps, Z, X, Y) + # we use 0 here to access the single time point available + channel_tile = tiles[:, tree_dict["channels"].index(ch), 0] + if ( + bgs.any() + and ch in self.params.sub_bg + and channel_tile is not None + ): + # subtract median background + bgsub_mapping = map( + # move Z to last column to allow subtraction + lambda img, bgs: np.moveaxis(img, 0, -1) + # median of background over all pixels for each Z section + - bn.median(img[:, bgs], axis=1), + channel_tile, + bgs, + ) + # apply map and convert to array + mapping_result = np.stack(list(bgsub_mapping)) + # move Z axis back to the second column + img_bgsub[ch + "_bgsub"] = np.moveaxis( + mapping_result, -1, 1 + ) + return img_bgsub + def get_imgs(self, channel: t.Optional[str], tiles, channels=None): """ Return image from a correct source, either raw or bgsub. @@ -719,7 +775,7 @@ class Extractor(StepABC): to="series", tp=tp, ) - # concatenate with data extracted from early time points + # concatenate with data extracted from earlier time points for k in new.keys(): d[k] = pd.concat((d.get(k, None), new[k]), axis=1) # add indices to pd.Series containing the extracted data @@ -748,7 +804,7 @@ class Extractor(StepABC): To the h5 file. """ if path is None: - path = self.local + path = self.h5path self.writer = Writer(path) for extract_name, series in dict_series.items(): dset_path = "/extraction/" + extract_name @@ -796,43 +852,3 @@ def flatten_nesteddict( pd.Series(*v2, name=tp) if to == "series" else v2 ) return d - - -def extraction_params_from_meta( - meta: t.Union[dict, Path, str], extras: t.Collection[str] = ["ph"] -): - """ - Obtain parameters from metadata of the h5 file. - - Compares a list of candidate channels using case-insensitive - regex to identify valid channels. - """ - meta = meta if isinstance(meta, dict) else load_metadata(meta) - base = { - "tree": {"general": {"None": ["area", "volume", "eccentricity"]}}, - "multichannel_ops": {}, - } - candidate_channels = set(global_parameters.possible_imaging_channels) - default_reductions = {"max"} - default_metrics = set(global_parameters.fluorescence_functions) - default_reduction_metrics = { - r: default_metrics for r in default_reductions - } - # default_rm["None"] = ["nuc_conv_3d"] # Uncomment this to add nuc_conv_3d (slow) - extant_fluorescence_ch = [] - for av_channel in candidate_channels: - # find matching channels in metadata - found_channel = find_channel_name(meta.get("channels", []), av_channel) - if found_channel is not None: - extant_fluorescence_ch.append(found_channel) - for ch in extant_fluorescence_ch: - base["tree"][ch] = default_reduction_metrics - base["sub_bg"] = extant_fluorescence_ch - return base - - -def load_metadata(file: t.Union[str, Path], group="/"): - """Get meta data from an h5 file.""" - with h5py.File(file, "r") as f: - meta = dict(f[group].attrs.items()) - return meta