From 18dc1a37f1b580647db6b0f495e8df28f4bebc86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Wed, 28 Sep 2022 13:35:36 +0100 Subject: [PATCH] fix(extraction): pass axis to div0 mergefun --- src/extraction/core/extractor.py | 2 +- src/extraction/core/functions/math_utils.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py index ecc34f3f..e94bdc4b 100644 --- a/src/extraction/core/extractor.py +++ b/src/extraction/core/extractor.py @@ -546,7 +546,7 @@ class Extractor(ProcessABC): ) ) == len(chs): channels_stack = np.stack( - [self.get_imgs(ch, tiles, tree_chs) for ch in chs] + [self.get_imgs(ch, tiles, tree_chs) for ch in chs], axis=-1 ) merged = RED_FUNS[merge_fun](channels_stack, axis=-1) d[name] = self.reduce_extract( diff --git a/src/extraction/core/functions/math_utils.py b/src/extraction/core/functions/math_utils.py index b94a0897..eeae8e0c 100644 --- a/src/extraction/core/functions/math_utils.py +++ b/src/extraction/core/functions/math_utils.py @@ -1,7 +1,7 @@ import numpy as np -def div0(a, b, fill=0): +def div0(array, fill=0, axis=-1): """ Divide array a by array b. @@ -13,9 +13,19 @@ def div0(a, b, fill=0): ---------- a: array b: array + fill: float + **kwargs: kwargs """ + assert array.shape[axis] == 2, f"Array has the wrong shape in axis {axis}" + slices_0, slices_1 = [[slice(None)] * len(array.shape)] * 2 + slices_0[axis] = 0 + slices_1[axis] = 1 + with np.errstate(divide="ignore", invalid="ignore"): - c = np.true_divide(a, b) + c = np.true_divide( + array[tuple(slices_0)], + array[tuple(slices_1)], + ) if np.isscalar(c): return c if np.isfinite(c) else fill else: -- GitLab