Skip to content
Snippets Groups Projects
Commit 496c2d40 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

refactor(extraction): simplify loaders and np->bn

parent 8614df0f
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,6 @@ from extraction.core.functions.distributors import reduce_z, trap_apply
from extraction.core.functions.loaders import (
load_custom_args,
load_funs,
load_mergefuns,
load_redfuns,
)
......@@ -33,7 +32,6 @@ extraction_result = t.Dict[
CELL_FUNS, TRAPFUNS, FUNS = load_funs()
CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
RED_FUNS = load_redfuns()
MERGE_FUNS = load_mergefuns()
# Assign datatype depending on the metric used
# m2type = {"mean": np.float32, "median": np.ubyte, "imBackground": np.ubyte}
......@@ -406,10 +404,10 @@ class Extractor(ProcessABC):
method: function
The reduction function
"""
if method is None:
return img
else:
return reduce_z(img, method)
reduced = img
if method is not None:
reduced = reduce_z(img, method)
return reduced
def extract_tp(
self,
......@@ -547,8 +545,10 @@ class Extractor(ProcessABC):
set(self.img_bgsub.keys()).union(tree_chs)
)
) == len(chs):
imgs = [self.get_imgs(ch, tiles, tree_chs) for ch in chs]
merged = MERGE_FUNS[merge_fun](*imgs)
channels_stack = np.stack(
[self.get_imgs(ch, tiles, tree_chs) for ch in chs]
)
merged = RED_FUNS[merge_fun](channels_stack, axis=-1)
d[name] = self.reduce_extract(
red_metrics=red_metrics,
traps=merged,
......
......@@ -28,7 +28,7 @@ def exparams_from_meta(
"mKO2",
}
default_reductions = {"np_max"}
default_reductions = {"max"}
default_metrics = {
"mean",
"median",
......@@ -62,7 +62,7 @@ def exparams_from_meta(
["GFPFast_bgsub", "pHluorin405_bgsub"],
),
)
for b, y in zip(["em_ratio", "gsum"], ["div0", "np_add"])
for b, y in zip(["em_ratio", "gsum"], ["div0", "add"])
}
for i, v in sets.items():
base["multichannel_ops"][i] = [
......
import typing as t
import bottleneck as bn
import numpy as np
......@@ -22,7 +25,7 @@ def trap_apply(cell_fun, cell_masks, *args, **kwargs):
return [cell_fun(cell_masks[..., i], *args, **kwargs) for i in cells_iter]
def reduce_z(trap_image, fun):
def reduce_z(trap_image: np.ndarray, fun: t.Callable):
"""
Reduce the trap_image to 2d.
......@@ -32,9 +35,15 @@ def reduce_z(trap_image, fun):
Images for all the channels associated with a trap
fun: function
Function to execute the reduction
"""
if isinstance(fun, np.ufunc):
# FUTURE replace with py3.10's match-case.
if (
hasattr(fun, "__module__") and fun.__module__[:10] == "bottleneck"
): # Bottleneck type
return getattr(bn.reduce, fun.__name__)(trap_image, axis=2)
elif isinstance(fun, np.ufunc):
# optimise the reduction function if possible
return fun.reduce(trap_image, axis=2)
else:
else: # WARNING: Very slow, only use when no alternatives exist
return np.apply_along_axis(fun, 2, trap_image)
import typing as t
from inspect import getfullargspec, getmembers, isfunction
from types import FunctionType
from inspect import getfullargspec, getmembers, isfunction, isbuiltin
import numpy as np
import bottleneck as bn
from extraction.core.functions import cell, trap
from extraction.core.functions.custom import localisation
......@@ -102,22 +103,30 @@ def load_funs():
return CELLFUNS, TRAPFUNS, {**TRAPFUNS, **CELLFUNS}
def load_redfuns(): # TODO make defining reduction functions more flexible
def load_redfuns(
additional_reducers: t.Optional[
t.Union[t.Dict[str, t.Callable], t.Callable]
] = None,
) -> t.Dict[str, t.Callable]:
"""
Load functions to reduce the z-stack to two dimensions.
Load functions to reduce a multidimensional image by one dimension.
It can take custom functions as arguments.
"""
RED_FUNS = {
"np_max": np.maximum,
"np_mean": np.mean,
"np_median": np.median,
"max": bn.nanmax,
"mean": bn.nanmean,
"median": bn.nanmedian,
"div0": div0,
"add": bn.nansum,
"None": None,
}
return RED_FUNS
if additional_reducers is not None:
if isinstance(additional_reducers, FunctionType):
additional_reducers = [
(additional_reducers.__name__, additional_reducers)
]
RED_FUNS.update(name, fun)
def load_mergefuns():
"""
Load functions to merge multiple channels
"""
MERGE_FUNS = {"div0": div0, "np_add": np.add}
return MERGE_FUNS
return RED_FUNS
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment