diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py index 1b3aa6451b34152dc13779dbf172647ad54cfc63..ecc34f3f232f7e3b87c18225d56360fb6eb4a38c 100644 --- a/src/extraction/core/extractor.py +++ b/src/extraction/core/extractor.py @@ -16,7 +16,6 @@ from extraction.core.functions.distributors import reduce_z, trap_apply from extraction.core.functions.loaders import ( load_custom_args, load_funs, - load_mergefuns, load_redfuns, ) @@ -33,7 +32,6 @@ extraction_result = t.Dict[ CELL_FUNS, TRAPFUNS, FUNS = load_funs() CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args() RED_FUNS = load_redfuns() -MERGE_FUNS = load_mergefuns() # Assign datatype depending on the metric used # m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte} @@ -406,10 +404,10 @@ class Extractor(ProcessABC): method: function The reduction function """ - if method is None: - return img - else: - return reduce_z(img, method) + reduced = img + if method is not None: + reduced = reduce_z(img, method) + return reduced def extract_tp( self, @@ -547,8 +545,10 @@ class Extractor(ProcessABC): set(self.img_bgsub.keys()).union(tree_chs) ) ) == len(chs): - imgs = [self.get_imgs(ch, tiles, tree_chs) for ch in chs] - merged = MERGE_FUNS[merge_fun](*imgs) + channels_stack = np.stack( + [self.get_imgs(ch, tiles, tree_chs) for ch in chs] + ) + merged = RED_FUNS[merge_fun](channels_stack, axis=-1) d[name] = self.reduce_extract( red_metrics=red_metrics, traps=merged, diff --git a/src/extraction/core/functions/defaults.py b/src/extraction/core/functions/defaults.py index 4d60c701c95062784c4853e55ed11b27a398d1f4..ad95d5ff9917f68dd9bd4b810657c9556744f4bb 100644 --- a/src/extraction/core/functions/defaults.py +++ b/src/extraction/core/functions/defaults.py @@ -28,7 +28,7 @@ def exparams_from_meta( "mKO2", } - default_reductions = {"np_max"} + default_reductions = {"max"} default_metrics = { "mean", "median", @@ -62,7 +62,7 @@ def exparams_from_meta( ["GFPFast_bgsub", "pHluorin405_bgsub"], ), ) - for b, y in zip(["em_ratio", "gsum"], ["div0", "np_add"]) + for b, y in zip(["em_ratio", "gsum"], ["div0", "add"]) } for i, v in sets.items(): base["multichannel_ops"][i] = [ diff --git a/src/extraction/core/functions/distributors.py b/src/extraction/core/functions/distributors.py index 23d1ff09bc21a167a3b644a397122f3eaa42efec..40e8fff96b4da6fe417a7ead89d9eb6b54993137 100644 --- a/src/extraction/core/functions/distributors.py +++ b/src/extraction/core/functions/distributors.py @@ -1,3 +1,6 @@ +import typing as t + +import bottleneck as bn import numpy as np @@ -22,7 +25,7 @@ def trap_apply(cell_fun, cell_masks, *args, **kwargs): return [cell_fun(cell_masks[..., i], *args, **kwargs) for i in cells_iter] -def reduce_z(trap_image, fun): +def reduce_z(trap_image: np.ndarray, fun: t.Callable): """ Reduce the trap_image to 2d. @@ -32,9 +35,15 @@ def reduce_z(trap_image, fun): Images for all the channels associated with a trap fun: function Function to execute the reduction + """ - if isinstance(fun, np.ufunc): + # FUTURE replace with py3.10's match-case. + if ( + hasattr(fun, "__module__") and fun.__module__[:10] == "bottleneck" + ): # Bottleneck type + return getattr(bn.reduce, fun.__name__)(trap_image, axis=2) + elif isinstance(fun, np.ufunc): # optimise the reduction function if possible return fun.reduce(trap_image, axis=2) - else: + else: # WARNING: Very slow, only use when no alternatives exist return np.apply_along_axis(fun, 2, trap_image) diff --git a/src/extraction/core/functions/loaders.py b/src/extraction/core/functions/loaders.py index cb26d47ea832a6365bd655d0b3e29e3cc73114cf..9d2e9c479428db9db1189726a7ba8478b9923646 100644 --- a/src/extraction/core/functions/loaders.py +++ b/src/extraction/core/functions/loaders.py @@ -1,7 +1,8 @@ import typing as t -from inspect import getfullargspec, getmembers, isfunction +from types import FunctionType +from inspect import getfullargspec, getmembers, isfunction, isbuiltin -import numpy as np +import bottleneck as bn from extraction.core.functions import cell, trap from extraction.core.functions.custom import localisation @@ -102,22 +103,30 @@ def load_funs(): return CELLFUNS, TRAPFUNS, {**TRAPFUNS, **CELLFUNS} -def load_redfuns(): # TODO make defining reduction functions more flexible +def load_redfuns( + additional_reducers: t.Optional[ + t.Union[t.Dict[str, t.Callable], t.Callable] + ] = None, +) -> t.Dict[str, t.Callable]: """ - Load functions to reduce the z-stack to two dimensions. + Load functions to reduce a multidimensional image by one dimension. + + It can take custom functions as arguments. """ RED_FUNS = { - "np_max": np.maximum, - "np_mean": np.mean, - "np_median": np.median, + "max": bn.nanmax, + "mean": bn.nanmean, + "median": bn.nanmedian, + "div0": div0, + "add": bn.nansum, "None": None, } - return RED_FUNS + if additional_reducers is not None: + if isinstance(additional_reducers, FunctionType): + additional_reducers = [ + (additional_reducers.__name__, additional_reducers) + ] + RED_FUNS.update(name, fun) -def load_mergefuns(): - """ - Load functions to merge multiple channels - """ - MERGE_FUNS = {"div0": div0, "np_add": np.add} - return MERGE_FUNS + return RED_FUNS