From 54a7513159f0a8f264faa870167766fbf677e6a7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Wed, 7 Sep 2022 11:38:41 +0100
Subject: [PATCH] style(extractor): Improve typing

---
 extraction/core/extractor.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py
index bce0ef85..fca1e1e8 100644
--- a/extraction/core/extractor.py
+++ b/extraction/core/extractor.py
@@ -46,7 +46,7 @@ class ExtractorParameters(ParametersABC):
 
     def __init__(
         self,
-        tree: t.Dict[str, t.Dict[reduction_method, t.Collection[str]]],
+        tree: extraction_tree,
         sub_bg: set = set(),
         multichannel_ops: t.Dict = {},
     ):
@@ -260,12 +260,11 @@ class Extractor(ProcessABC):
             # a subset of channels was specified
             channel_ids = [self.tiler.get_channel_index(ch) for ch in channels]
         else:
-            # oh oh
+            # a list of the indices of the z stacks
             channel_ids = None
-        # a list of the indices of the z stacks
         if z is None:
-            z = list(range(self.tiler.shape[-1]))
-        # gets the data via tiler
+            # gets the tiles data via tiler
+            z: t.List[int] = list(range(self.tiler.shape[-1]))
         tiles = (
             self.tiler.get_tiles_timepoint(
                 tp, channels=channel_ids, z=z, **kwargs
@@ -456,7 +455,7 @@ class Extractor(ProcessABC):
         """
         if tree is None:
             # use default
-            tree = self.params.tree
+            tree: extraction_tree = self.params.tree
         # dictionary with channel: {reduction algorithm : metric}
         ch_tree = {ch: v for ch, v in tree.items() if ch != "general"}
         # tuple of the channels
@@ -484,10 +483,10 @@ class Extractor(ProcessABC):
 
         # find image data at the time point
         # stored as an array arranged as (traps, channels, timepoints, X, Y, Z)
-        # Alan: traps does not appear the best name here!
-        traps = self.get_tiles(tp, tile_shape=tile_size, channels=tree_chs)
+        tiles = self.get_tiles(tp, tile_shape=tile_size, channels=tree_chs)
 
         # generate boolean masks for background as a list with one mask per trap
+        bgs = []
         if self.params.sub_bg:
             bgs = [
                 ~np.sum(m, axis=2).astype(bool)
@@ -501,10 +500,10 @@ class Extractor(ProcessABC):
         self.img_bgsub = {}
         for ch, red_metrics in tree.items():
             # NB ch != is necessary for threading
-            if ch != "general" and traps is not None and len(traps):
+            if ch != "general" and tiles is not None and len(traps):
                 # image data for all traps and z sections for a particular channel
                 # as an array arranged as (no traps, X, Y, no Z channels)
-                img = traps[:, tree_chs.index(ch), 0]
+                img = tiles[:, tree_chs.index(ch), 0]
             else:
                 img = None
             # apply metrics to image data
@@ -516,7 +515,7 @@ class Extractor(ProcessABC):
                 **kwargs,
             )
             # apply metrics to image data with the background subtracted
-            if ch in self.params.sub_bg and img is not None:
+            if bgs and ch in self.params.sub_bg and img is not None:
                 # calculate metrics with subtracted bg
                 ch_bs = ch + "_bgsub"
                 self.img_bgsub[ch_bs] = []
-- 
GitLab