From 109b10f2938ec2ff03cad7ad3e61120ac7a5b0d9 Mon Sep 17 00:00:00 2001
From: Swainlab <peter.swain@ed.ac.uk>
Date: Sat, 12 Aug 2023 15:56:54 +0100
Subject: [PATCH] renamed variables in extractor

---
 src/extraction/core/extractor.py         | 68 ++++++++++++------------
 src/extraction/core/functions/loaders.py | 22 ++++----
 2 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py
index 4fbd5e7f..24cf7171 100644
--- a/src/extraction/core/extractor.py
+++ b/src/extraction/core/extractor.py
@@ -30,7 +30,7 @@ extraction_result = t.Dict[
 # or their background. These global variables both allow the functions
 # to be stored in a dictionary for access only on demand and to be
 # defined simply in extraction/core/functions.
-CELL_FUNS, TRAPFUNS, FUNS = load_funs()
+CELL_FUNS, TRAP_FUNS, ALL_FUNS = load_funs()
 CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
 RED_FUNS = load_redfuns()
 
@@ -234,7 +234,7 @@ class Extractor(StepABC):
         self.load_custom_funs()
         self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS)
         # merge the two dicts
-        self._all_funs = {**self._custom_funs, **FUNS}
+        self._all_funs = {**self._custom_funs, **ALL_FUNS}
 
     def load_meta(self):
         """Load metadata from h5 file."""
@@ -289,7 +289,7 @@ class Extractor(StepABC):
         self,
         traps: t.List[np.ndarray],
         masks: t.List[np.ndarray],
-        metric: str,
+        cell_property: str,
         cell_labels: t.Dict[int, t.List[int]],
     ) -> t.Tuple[t.Union[t.Tuple[float], t.Tuple[t.Tuple[int]]]]:
         """
@@ -301,8 +301,8 @@ class Extractor(StepABC):
             t.List of images.
         masks: list of arrays
             t.List of masks.
-        metric: str
-            Metric to extract.
+        cell_property: str
+            Property to extract, including imBackground.
         cell_labels: dict
             A dict of cell labels with trap_ids as keys and a list
             of cell labels as values.
@@ -317,21 +317,21 @@ class Extractor(StepABC):
         """
         if cell_labels is None:
             self._log("No cell labels given. Sorting cells using index.")
-        cell_fun = True if metric in self._all_cell_funs else False
+        cell_fun = True if cell_property in self._all_cell_funs else False
         idx = []
         results = []
-        for trap_id, (mask_set, trap, lbl_set) in enumerate(
+        for trap_id, (mask_set, trap, local_cell_labels) in enumerate(
             zip(masks, traps, cell_labels.values())
         ):
             # ignore empty traps
             if len(mask_set):
-                # apply metric either a cell function or otherwise
-                result = self._all_funs[metric](mask_set, trap)
+                # find property from the tile
+                result = self._all_funs[cell_property](mask_set, trap)
                 if cell_fun:
                     # store results for each cell separately
-                    for lbl, val in zip(lbl_set, result):
+                    for cell_label, val in zip(local_cell_labels, result):
                         results.append(val)
-                        idx.append((trap_id, lbl))
+                        idx.append((trap_id, cell_label))
                 else:
                     # background (trap) function
                     results.append(result)
@@ -343,19 +343,19 @@ class Extractor(StepABC):
         self,
         traps: t.List[np.array],
         masks: t.List[np.array],
-        metrics: t.List[str],
+        cell_properties: t.List[str],
         **kwargs,
     ) -> t.Dict[str, pd.Series]:
         """
-        Return dict with metrics as key and metrics applied to data as values.
+        Return dict with metrics as key and cell_properties as values.
 
         Data from one time point is used.
         """
         d = {
-            metric: self.extract_traps(
-                traps=traps, masks=masks, metric=metric, **kwargs
+            cell_property: self.extract_traps(
+                traps=traps, masks=masks, cell_property=cell_property, **kwargs
             )
-            for metric in metrics
+            for cell_property in cell_properties
         }
         return d
 
@@ -363,7 +363,7 @@ class Extractor(StepABC):
         self,
         traps: np.ndarray,
         masks: t.List[np.ndarray],
-        red_metrics: t.Dict[reduction_method, t.Collection[str]],
+        tree_branch: t.Dict[reduction_method, t.Collection[str]],
         **kwargs,
     ) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]:
         """
@@ -375,10 +375,10 @@ class Extractor(StepABC):
             An array of image data arranged as (tiles, X, Y, Z)
         masks: list of arrays
             An array of masks for each trap: one per cell at the trap
-        red_metrics: dict
-            dict for which keys are reduction functions and values are
-            either a list or a set of strings giving the metric
-            functions.
+        tree_branch: dict
+            An upper branch of the extraction tree: a dict for which
+            keys are reduction functions and values are either a list
+            or a set of strings giving the cell properties to be found.
             For example: {'np_max': {'max5px', 'mean', 'median'}}
         **kwargs: dict
             All other arguments passed to Extractor.extract_funs.
@@ -392,21 +392,22 @@ class Extractor(StepABC):
             kwargs["cell_labels"] = kwargs.pop("labels")
         # create dict with keys naming the reduction in the z-direction
         # and the reduced data as values
-        reduced_tiles_data = {}
+        reduced_tiles = {}
         if traps is not None:
-            for red_fun in red_metrics.keys():
-                reduced_tiles_data[red_fun] = [
+            for red_fun in tree_branch.keys():
+                reduced_tiles[red_fun] = [
                     self.reduce_dims(tile_data, method=RED_FUNS[red_fun])
                     for tile_data in traps
                 ]
+        # calculate cell and tile properties
         d = {
             red_fun: self.extract_funs(
-                metrics=metrics,
-                traps=reduced_tiles_data.get(red_fun, [None for _ in masks]),
+                cell_properties=cell_properties,
+                traps=reduced_tiles.get(red_fun, [None for _ in masks]),
                 masks=masks,
                 **kwargs,
             )
-            for red_fun, metrics in red_metrics.items()
+            for red_fun, cell_properties in tree_branch.items()
         }
         return d
 
@@ -505,7 +506,7 @@ class Extractor(StepABC):
         """
         d = {}
         img_bgsub = {}
-        for ch, red_metrics in tree_bits["tree"].items():
+        for ch, tree_branch in tree_bits["tree"].items():
             # NB ch != is necessary for threading
             if ch != "general" and tiles is not None and len(tiles):
                 # image data for all traps for a particular channel and time point
@@ -513,12 +514,13 @@ class Extractor(StepABC):
                 # we use 0 here to access the single time point available
                 img = tiles[:, tree_bits["tree_channels"].index(ch), 0]
             else:
+                # no reduction applied to bright-field images
                 img = None
             # apply metrics to image data
             d[ch] = self.reduce_extract(
                 traps=img,
                 masks=masks,
-                red_metrics=red_metrics,
+                tree_branch=tree_branch,
                 cell_labels=cell_labels,
                 **kwargs,
             )
@@ -528,20 +530,20 @@ class Extractor(StepABC):
                 ch_bs = ch + "_bgsub"
                 # subtract median background
                 bgsub_mapping = map(
-                    # move time to last column to allow subtraction
+                    # move Z to last column to allow subtraction
                     lambda img, bgs: np.moveaxis(img, 0, -1)
-                    # median of background over all pixels for each time point
+                    # median of background over all pixels for each Z section
                     - bn.median(img[:, bgs], axis=1),
                     img,
                     bgs,
                 )
                 # apply map and convert to array
                 mapping_result = np.stack(list(bgsub_mapping))
-                # move time axis back to the second column
+                # move Z axis back to the second column
                 img_bgsub[ch_bs] = np.moveaxis(mapping_result, -1, 1)
                 # apply metrics to background-corrected data
                 d[ch_bs] = self.reduce_extract(
-                    red_metrics=tree_bits["channel_tree"][ch],
+                    tree_branch=tree_bits["channel_tree"][ch],
                     traps=img_bgsub[ch_bs],
                     masks=masks,
                     cell_labels=cell_labels,
diff --git a/src/extraction/core/functions/loaders.py b/src/extraction/core/functions/loaders.py
index ce33f845..0547e4f7 100644
--- a/src/extraction/core/functions/loaders.py
+++ b/src/extraction/core/functions/loaders.py
@@ -11,8 +11,10 @@ from extraction.core.functions.math_utils import div0
 
 """
 Load functions for analysing cells and their background.
-Note that inspect.getmembers returns a list of function names and
-functions, and inspect.getfullargspec returns a function's arguments.
+
+Note that inspect.getmembers returns a list of function names
+and functions, and inspect.getfullargspec returns a
+function's arguments.
 """
 
 
@@ -66,7 +68,7 @@ def load_cellfuns():
     # create dict of the core functions from cell.py - these functions apply to a single mask
     cell_funs = load_cellfuns_core()
     # create a dict of functions that apply the core functions to an array of cell_masks
-    CELLFUNS = {}
+    CELL_FUNS = {}
     for f_name, f in cell_funs.items():
         if isfunction(f):
 
@@ -79,27 +81,27 @@ def load_cellfuns():
                     # function that applies f to m and img, the trap_image
                     return lambda m, img: trap_apply(f, m, img)
 
-            CELLFUNS[f_name] = tmp(f)
-    return CELLFUNS
+            CELL_FUNS[f_name] = tmp(f)
+    return CELL_FUNS
 
 
 def load_trapfuns():
     """Load functions that are applied to an entire tile."""
-    TRAPFUNS = {
+    TRAP_FUNS = {
         f[0]: f[1]
         for f in getmembers(trap)
         if isfunction(f[1])
         and f[1].__module__.startswith("extraction.core.functions")
     }
-    return TRAPFUNS
+    return TRAP_FUNS
 
 
 def load_funs():
     """Combine all automatically loaded functions."""
-    CELLFUNS = load_cellfuns()
-    TRAPFUNS = load_trapfuns()
+    CELL_FUNS = load_cellfuns()
+    TRAP_FUNS = load_trapfuns()
     # return dict of cell funs, dict of trap funs, and dict of both
-    return CELLFUNS, TRAPFUNS, {**TRAPFUNS, **CELLFUNS}
+    return CELL_FUNS, TRAP_FUNS, {**TRAP_FUNS, **CELL_FUNS}
 
 
 def load_redfuns(
-- 
GitLab