From 18dc1a37f1b580647db6b0f495e8df28f4bebc86 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Wed, 28 Sep 2022 13:35:36 +0100
Subject: [PATCH] fix(extraction): pass axis to div0 mergefun

---
 src/extraction/core/extractor.py            |  2 +-
 src/extraction/core/functions/math_utils.py | 14 ++++++++++++--
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py
index ecc34f3f..e94bdc4b 100644
--- a/src/extraction/core/extractor.py
+++ b/src/extraction/core/extractor.py
@@ -546,7 +546,7 @@ class Extractor(ProcessABC):
                 )
             ) == len(chs):
                 channels_stack = np.stack(
-                    [self.get_imgs(ch, tiles, tree_chs) for ch in chs]
+                    [self.get_imgs(ch, tiles, tree_chs) for ch in chs], axis=-1
                 )
                 merged = RED_FUNS[merge_fun](channels_stack, axis=-1)
                 d[name] = self.reduce_extract(
diff --git a/src/extraction/core/functions/math_utils.py b/src/extraction/core/functions/math_utils.py
index b94a0897..eeae8e0c 100644
--- a/src/extraction/core/functions/math_utils.py
+++ b/src/extraction/core/functions/math_utils.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 
-def div0(a, b, fill=0):
+def div0(array, fill=0, axis=-1):
     """
     Divide array a by array b.
 
@@ -13,9 +13,19 @@ def div0(a, b, fill=0):
     ----------
     a: array
     b: array
+    fill: float
+    **kwargs: kwargs
     """
+    assert array.shape[axis] == 2, f"Array has the wrong shape in axis {axis}"
+    slices_0, slices_1 = [[slice(None)] * len(array.shape)] * 2
+    slices_0[axis] = 0
+    slices_1[axis] = 1
+
     with np.errstate(divide="ignore", invalid="ignore"):
-        c = np.true_divide(a, b)
+        c = np.true_divide(
+            array[tuple(slices_0)],
+            array[tuple(slices_1)],
+        )
     if np.isscalar(c):
         return c if np.isfinite(c) else fill
     else:
-- 
GitLab