From 4fc31051cbd791378bcbbdd941173b3f93ff4a04 Mon Sep 17 00:00:00 2001
From: Peter Swain <peter.swain@ed.ac.uk>
Date: Sun, 29 Oct 2023 16:09:44 +0000
Subject: [PATCH] extended bgsub to multichannel functions

---
 src/extraction/core/extractor.py      | 159 ++++++++++++--------------
 src/extraction/core/functions/cell.py |   9 +-
 2 files changed, 80 insertions(+), 88 deletions(-)

diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py
index ba1d491..44677b3 100644
--- a/src/extraction/core/extractor.py
+++ b/src/extraction/core/extractor.py
@@ -306,15 +306,15 @@ class Extractor(StepABC):
         # tiles has dimensions (tiles, channels, 1, Z, X, Y)
         return tiles
 
-    def extract_traps(
+    def apply_cell_function(
         self,
         traps: t.List[np.ndarray],
         masks: t.List[np.ndarray],
-        cell_property: str,
+        cell_function: str,
         cell_labels: t.Dict[int, t.List[int]],
     ) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]:
         """
-        Apply a function to a whole position.
+        Apply a cell function to all cells at all traps for one time point.
 
         Parameters
         ----------
@@ -322,11 +322,11 @@ class Extractor(StepABC):
             t.List of images.
         masks: list of arrays
             t.List of masks.
-        cell_property: str
-            Property to extract, including imBackground.
+        cell_function: str
+            Function to apply.
         cell_labels: dict
-            A dict of cell labels with trap_ids as keys and a list
-            of cell labels as values.
+            A dict with trap_ids as keys and a list of cell labels as
+            values.
 
         Returns
         -------
@@ -336,7 +336,7 @@ class Extractor(StepABC):
         """
         if cell_labels is None:
             self._log("No cell labels given. Sorting cells using index.")
-        cell_fun = True if cell_property in self.all_cell_funs else False
+        cell_fun = True if cell_function in self.all_cell_funs else False
         idx = []
         results = []
         for trap_id, (mask_set, trap, local_cell_labels) in enumerate(
@@ -345,7 +345,7 @@ class Extractor(StepABC):
             # ignore empty traps
             if len(mask_set):
                 # find property from the tile
-                result = self.all_funs[cell_property](mask_set, trap)
+                result = self.all_funs[cell_function](mask_set, trap)
                 if cell_fun:
                     # store results for each cell separately
                     for cell_label, val in zip(local_cell_labels, result):
@@ -371,8 +371,8 @@ class Extractor(StepABC):
         Data from one time point is used.
         """
         d = {
-            cell_fun: self.extract_traps(
-                traps=tiles, masks=masks, cell_property=cell_fun, **kwargs
+            cell_fun: self.apply_cell_function(
+                traps=tiles, masks=masks, cell_function=cell_fun, **kwargs
             )
             for cell_fun in cell_funs
         }
@@ -513,88 +513,66 @@ class Extractor(StepABC):
         return bgs
 
     def extract_one_channel(
-        self, tree_dict, cell_labels, tiles, masks, bgs, **kwargs
+        self, tree_dict, cell_labels, img, img_bgsub, masks, **kwargs
     ):
-        """
-        Extract all metrics requiring only a single channel.
-
-        Apply first without and then with background subtraction.
-
-        Return the extraction results and a dict of background
-        corrected images.
-        """
+        """Extract as a dict all metrics requiring only a single channel."""
         d = {}
         for ch, reduction_cell_funs in tree_dict["tree"].items():
-            # NB ch != is necessary for threading
-            if ch != "general" and tiles is not None and len(tiles):
-                # image data for all traps for a particular channel and
-                # time point arranged as (traps, Z, X, Y)
-                # we use 0 here to access the single time point available
-                channel_tile = tiles[:, tree_dict["channels"].index(ch), 0]
-            else:
-                # no reduction applied to "general" - bright-field images
-                channel_tile = None
-            # apply metrics to image data
+            # extract from all images including bright field
             d[ch] = self.reduce_extract(
-                tiles=channel_tile,
+                # use None for "general"; no fluorescence image
+                tiles=img.get(ch, None),
                 masks=masks,
                 reduction_cell_funs=reduction_cell_funs,
                 cell_labels=cell_labels,
                 **kwargs,
             )
-            # apply metrics to image data with the background subtracted
-            if (
-                bgs.any()
-                and ch in self.params.sub_bg
-                and channel_tile is not None
-            ):
-                # apply metrics to background-corrected data
+            if ch != "general":
+                # extract from background-corrected fluorescence images
                 d[ch + "_bgsub"] = self.reduce_extract(
-                    tiles=self.img_bgsub[ch + "_bgsub"],
+                    tiles=img_bgsub[ch + "_bgsub"],
                     masks=masks,
-                    reduction_cell_funs=tree_dict["channels_tree"][ch],
+                    reduction_cell_funs=reduction_cell_funs,
                     cell_labels=cell_labels,
                     **kwargs,
                 )
         return d
 
-    def extract_multiple_channels(
-        self, tree_dict, cell_labels, tiles, masks, **kwargs
-    ):
-        """
-        Extract all metrics requiring multiple channels.
-        """
-        # channels and background corrected channels
-        available_chs = set(self.img_bgsub.keys()).union(tree_dict["channels"])
+    def extract_multiple_channels(self, cell_labels, img, img_bgsub, masks):
+        """Extract as a dict all metrics requiring multiple channels."""
+        # NB multichannel functions do not use tree_dict
+        available_channels = set(list(img.keys()) + list(img_bgsub.keys()))
         d = {}
-        for name, (
-            chs,
-            reduction_fun,
-            op,
+        for multichannel_fun_name, (
+            channels,
+            reduction,
+            multichannel_function,
         ) in self.params.multichannel_ops.items():
-            common_chs = set(chs).intersection(available_chs)
+            common_channels = set(channels).intersection(available_channels)
             # all required channels should be available
-            if len(common_chs) == len(chs):
-                channels_stack = np.stack(
-                    [
-                        self.get_imgs(ch, tiles, tree_dict["channels"])
-                        for ch in chs
-                    ],
-                    axis=-1,
-                )
-                # reduce in Z
-                traps = REDUCTION_FUNS[reduction_fun](channels_stack, axis=1)
-                # evaluate multichannel op
-                if name not in d:
-                    d[name] = {}
-                if reduction_fun not in d[name]:
-                    d[name][reduction_fun] = {}
-                d[name][reduction_fun][op] = self.extract_traps(
-                    traps,
-                    masks,
-                    op,
-                    cell_labels,
-                )
+            if len(common_channels) == len(channels):
+                for images, suffix in zip([img, img_bgsub], ["", "_bgsub"]):
+                    # stack the required channels along a new last axis
+                    channels_stack = np.stack(
+                        [images[ch + suffix] for ch in channels],
+                        axis=-1,
+                    )
+                    # reduce in Z
+                    tiles = REDUCTION_FUNS[reduction](channels_stack, axis=1)
+                    # set up dict
+                    if multichannel_fun_name not in d:
+                        d[multichannel_fun_name] = {}
+                    if reduction not in d[multichannel_fun_name]:
+                        d[multichannel_fun_name][reduction] = {}
+                    # apply multichannel function
+                    d[multichannel_fun_name][reduction][
+                        multichannel_function + suffix
+                    ] = self.apply_cell_function(
+                        tiles,
+                        masks,
+                        multichannel_function,
+                        cell_labels,
+                    )
         return d
 
     def extract_tp(
@@ -656,33 +634,41 @@ class Extractor(StepABC):
         tiles = self.get_tiles(tp, channels=tree_dict["channels"])
         # generate boolean masks for background for each trap
         bgs = self.get_background_masks(masks, tile_size)
-        # perform background subtraction for all traps at this time point
-        self.img_bgsub = self.perform_background_subtraction(
+        # get images and background corrected images as dicts
+        # with fluorescence channels as keys
+        img, img_bgsub = self.get_imgs_background_subtract(
             tree_dict, tiles, bgs
         )
         # perform extraction
         res_one = self.extract_one_channel(
-            tree_dict, cell_labels, tiles, masks, bgs, **kwargs
+            tree_dict, cell_labels, img, img_bgsub, masks, **kwargs
         )
         res_multiple = self.extract_multiple_channels(
-            tree_dict, cell_labels, tiles, masks, **kwargs
+            cell_labels, img, img_bgsub, masks
         )
         res = {**res_one, **res_multiple}
         return res
 
-    def perform_background_subtraction(self, tree_dict, tiles, bgs):
-        """Subtract background for fluorescence channels."""
+    def get_imgs_background_subtract(self, tree_dict, tiles, bgs):
+        """
+        Get two dicts of fluorescence images.
+
+        Return images and background-subtracted images for all traps
+        for one time point.
+        """
+        img = {}
         img_bgsub = {}
         for ch, _ in tree_dict["channels_tree"].items():
+            # NB ch != is necessary for threading
             if tiles is not None and len(tiles):
                 # image data for all traps for a particular channel and
                 # time point arranged as (traps, Z, X, Y)
                 # we use 0 here to access the single time point available
-                channel_tile = tiles[:, tree_dict["channels"].index(ch), 0]
+                img[ch] = tiles[:, tree_dict["channels"].index(ch), 0]
                 if (
                     bgs.any()
                     and ch in self.params.sub_bg
-                    and channel_tile is not None
+                    and img[ch] is not None
                 ):
                     # subtract median background
                     bgsub_mapping = map(
@@ -690,7 +676,7 @@ class Extractor(StepABC):
                         lambda img, bgs: np.moveaxis(img, 0, -1)
                         # median of background over all pixels for each Z section
                         - bn.median(img[:, bgs], axis=1),
-                        channel_tile,
+                        img[ch],
                         bgs,
                     )
                     # apply map and convert to array
@@ -699,9 +685,12 @@ class Extractor(StepABC):
                     img_bgsub[ch + "_bgsub"] = np.moveaxis(
                         mapping_result, -1, 1
                     )
-        return img_bgsub
+            else:
+                img[ch] = None
+                img_bgsub[ch] = None
+        return img, img_bgsub
 
-    def get_imgs(self, channel: t.Optional[str], tiles, channels=None):
+    def get_imgs_old(self, channel: t.Optional[str], tiles, channels=None):
         """
         Return image from a correct source, either raw or bgsub.
 
diff --git a/src/extraction/core/functions/cell.py b/src/extraction/core/functions/cell.py
index 21f9545..5de868e 100644
--- a/src/extraction/core/functions/cell.py
+++ b/src/extraction/core/functions/cell.py
@@ -239,9 +239,12 @@ def moment_of_inertia(cell_mask, trap_image):
 def ratio(cell_mask, trap_image):
     """Find the median ratio between two fluorescence channels."""
     if trap_image.ndim == 3 and trap_image.shape[-1] == 2:
-        fl_1 = trap_image[..., 0][cell_mask]
-        fl_2 = trap_image[..., 1][cell_mask]
-        div = np.median(fl_1 / fl_2)
+        fl_0 = trap_image[..., 0][cell_mask]
+        fl_1 = trap_image[..., 1][cell_mask]
+        if np.any(fl_1 == 0):
+            div = np.nan
+        else:
+            div = np.median(fl_0 / fl_1)
     else:
         div = np.nan
     return div
-- 
GitLab