From 43b390905538c514f425d9b6be4758d55158c954 Mon Sep 17 00:00:00 2001
From: Swainlab <peter.swain@ed.ac.uk>
Date: Mon, 14 Aug 2023 16:26:46 +0100
Subject: [PATCH] partway through add multichannel op to extractor

---
 src/extraction/core/extractor.py            | 44 +++++++++++----------
 src/extraction/core/functions/cell.py       | 11 ++++++
 tests/aliby/network/extraction/test_base.py |  6 +--
 3 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py
index 24cf7171..6d025188 100644
--- a/src/extraction/core/extractor.py
+++ b/src/extraction/core/extractor.py
@@ -140,7 +140,7 @@ class Extractor(StepABC):
                 [c + "_bgsub" for c in self.params.sub_bg]
             )
             # remove any multichannel operations requiring a missing channel
-            for op, (input_ch, _, _) in dict(self.params.multichannel_ops):
+            for op, (input_ch, _, _) in self.params.multichannel_ops.items():
                 if not set(input_ch).issubset(available_channels_bgsub):
                     self.params.multichannel_ops.pop(op)
         self.load_funs()
@@ -306,8 +306,6 @@ class Extractor(StepABC):
         cell_labels: dict
             A dict of cell labels with trap_ids as keys and a list
             of cell labels as values.
-        pos_info: bool
-            Whether to add the position as an index or not.
 
         Returns
         -------
@@ -557,19 +555,18 @@ class Extractor(StepABC):
         """
         Extract using all metrics requiring multiple channels.
         """
+        available_chs = set(self.img_bgsub.keys()).union(
+            tree_bits["tree_channels"]
+        )
         d = {}
         for name, (
             chs,
-            merge_fun,
-            red_metrics,
+            reduction_fun,
+            op,
         ) in self.params.multichannel_ops.items():
-            if len(
-                set(chs).intersection(
-                    set(self.img_bgsub.keys()).union(
-                        tree_bits["tree_channels"]
-                    )
-                )
-            ) == len(chs):
+            common_chs = set(chs).intersection(available_chs)
+            # all required channels should be available
+            if len(common_chs) == len(chs):
                 channels_stack = np.stack(
                     [
                         self.get_imgs(ch, tiles, tree_bits["tree_channels"])
@@ -577,13 +574,18 @@ class Extractor(StepABC):
                     ],
                     axis=-1,
                 )
-                merged = RED_FUNS[merge_fun](channels_stack, axis=-1)
-                d[name] = self.reduce_extract(
-                    red_metrics=red_metrics,
-                    traps=merged,
-                    masks=masks,
-                    cell_labels=cell_labels,
-                    **kwargs,
+                # reduce in Z
+                traps = RED_FUNS[reduction_fun](channels_stack, axis=1)
+                # evaluate multichannel op
+                if name not in d:
+                    d[name] = {}
+                if reduction_fun not in d[name]:
+                    d[name][reduction_fun] = {}
+                d[name][reduction_fun][op] = self.extract_traps(
+                    traps,
+                    masks,
+                    op,
+                    cell_labels,
                 )
         return d
 
@@ -652,10 +654,10 @@ class Extractor(StepABC):
         res_one, self.img_bgsub = self.extract_one_channel(
             tree_bits, cell_labels, tiles, masks, bgs, **kwargs
         )
-        res_two = self.extract_multiple_channels(
+        res_multiple = self.extract_multiple_channels(
             tree_bits, cell_labels, tiles, masks, **kwargs
         )
-        res = {**res_one, **res_two}
+        res = {**res_one, **res_multiple}
         return res
 
     def get_imgs(self, channel: t.Optional[str], tiles, channels=None):
diff --git a/src/extraction/core/functions/cell.py b/src/extraction/core/functions/cell.py
index 4d97f23a..f668c567 100644
--- a/src/extraction/core/functions/cell.py
+++ b/src/extraction/core/functions/cell.py
@@ -229,3 +229,14 @@ def moment_of_inertia(cell_mask, trap_image):
         return moi
     else:
         return np.nan
+
+
+def ratio(cell_mask, trap_image):
+    """Find the median ratio between two fluorescence channels."""
+    if trap_image.ndim == 3 and trap_image.shape[-1] == 2:
+        fl_1 = trap_image[..., 0][cell_mask]
+        fl_2 = trap_image[..., 1][cell_mask]
+        div = np.median(fl_1 / fl_2)
+    else:
+        div = np.nan
+    return div
diff --git a/tests/aliby/network/extraction/test_base.py b/tests/aliby/network/extraction/test_base.py
index 14209dbd..1ba8b452 100644
--- a/tests/aliby/network/extraction/test_base.py
+++ b/tests/aliby/network/extraction/test_base.py
@@ -56,9 +56,9 @@ def test_extractor(imgs, masks, tree):
     for ch_branches in extractor.params.tree.values():
         print(
             extractor.reduce_extract(
-                red_metrics=ch_branches,
-                traps=[traps],
-                masks=[masks],
+                [traps],
+                [masks],
+                ch_branches,
                 labels={0: labels},
             )
         )
-- 
GitLab