From 747cecdc13a1ec52121bf5c2145360dabf209520 Mon Sep 17 00:00:00 2001
From: Peter Swain <peter.swain@ed.ac.uk>
Date: Fri, 2 Sep 2022 20:45:22 +0100
Subject: [PATCH] more extractor

---
 aliby/tile/tiler.py          | 29 ++++++++++---
 extraction/core/extractor.py | 80 +++++++++++++++++++++++++-----------
 2 files changed, 81 insertions(+), 28 deletions(-)

diff --git a/aliby/tile/tiler.py b/aliby/tile/tiler.py
index a5b5f4fb..8b5a679a 100644
--- a/aliby/tile/tiler.py
+++ b/aliby/tile/tiler.py
@@ -15,7 +15,7 @@ One key method is Tiler.run.
 
 The image-processing is performed by traps/segment_traps.
 
-The experiment is stored as an array wuth a standard indexing order of (Time, Channels, Z-stack, Y, X).
+The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, Y, X).
 """
 import warnings
 from functools import lru_cache
@@ -519,6 +519,24 @@ class Tiler(ProcessABC):
         """
         Get a multidimensional array with all tiles for a set of channels
         and z-stacks.
+
+        Used by extractor.
+
+        Parameters
+        ----------
+        tp: int
+            Index of time point
+        tile_shape: int or tuple of two ints
+            Size of tile in x and y dimensions
+        channels: string or list of strings
+            Names of channels of interest
+        z: list of ints
+            Indices of the z-sections of interest
+
+        Returns
+        -------
+        res: array
+            Data arranged as (traps, channels, timepoints, X, Y, Z)
         """
         # FIXME add support for subtiling trap
         # FIXME can we ignore z(always  give)
@@ -526,12 +544,13 @@ class Tiler(ProcessABC):
             channels = [0]
         elif isinstance(channels, str):
             channels = [channels]
+        # get the data
         res = []
         for c in channels:
-            val = self.get_tp_data(tp, c)[:, z]  # Only return requested z
-            # positions
-            # Starts at traps, z, y, x
-            # Turn to Trap, C, T, X, Y, Z order
+            # only return requested z
+            val = self.get_tp_data(tp, c)[:, z]
+            # starts with the order: traps, z, y, x
+            # returns the order: trap, C, T, X, Y, Z
             val = val.swapaxes(1, 3).swapaxes(1, 2)
             val = np.expand_dims(val, axis=1)
             res.append(val)
diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py
index 527267e7..773da141 100644
--- a/extraction/core/extractor.py
+++ b/extraction/core/extractor.py
@@ -221,7 +221,11 @@ class Extractor(ProcessABC):
         self.meta = load_attributes(self.local)
 
     def get_traps(
-        self, tp: int, channels: list = None, z: list = None, **kwargs
+        self,
+        tp: int,
+        channels: list = None,
+        z: list = None,
+        **kwargs,
     ) -> tuple:
         """
         Finds traps for a given time point and given channels and z-stacks.
@@ -250,7 +254,7 @@ class Extractor(ProcessABC):
         # a list of the indices of the z stacks
         if z is None:
             z = list(range(self.tiler.shape[-1]))
-        # find the appropiate traps from tiler
+        # gets the data via tiler
         traps = (
             self.tiler.get_traps_timepoint(
                 tp, channels=channel_ids, z=z, **kwargs
@@ -258,6 +262,7 @@ class Extractor(ProcessABC):
             if channel_ids
             else None
         )
+        # data arranged as (traps, channels, timepoints, X, Y, Z)
         return traps
 
     def extract_traps(
@@ -270,14 +275,23 @@ class Extractor(ProcessABC):
         """
         Apply a function for a whole position.
 
-        :traps: List[np.array] list of images
-        :masks: List[np.array] list of masks
-        :metric:str metric to extract
-        :labels: List[int] cell Labels to use as indices for output DataFrame
-        :pos_info: bool Whether to add the position as index or not
+        Parameters
+        ----------
+        traps: List[np.array]
+            List of images
+        masks: List[np.array]
+            List of masks
+        metric: str
+            Metric to extract
+        labels: List[int]
+            Cell labels to use as indices in the output DataFrame
+        pos_info: bool
+            Whether to add the position as an index or not
 
-        returns
-        :d: Dictionary of dataframe
+        Returns
+        -------
+        d: dict
+            A dictionary of dataframes
         """
 
         if labels is None:
@@ -311,7 +325,7 @@ class Extractor(ProcessABC):
         **kwargs,
     ) -> dict:
         """
-        Extract multiple metrics from a timepoint
+        Return a dict with metrics as keys and, as values, the results of applying each metric to the data from one timepoint.
         """
         d = {
             metric: self.extract_traps(
@@ -333,16 +347,21 @@ class Extractor(ProcessABC):
 
         Parameters
         ----------
-        param red_metrics: dict
-            dict for which keys are reduction funcions and values are strings indicating the metric function
+        traps: array
+            An array of image data arranged as (traps, X, Y, Z)
+        masks: list of arrays
+            An array of masks for each trap: one per cell at the trap
+        red_metrics: dict
+            dict for which keys are reduction functions and values are either a list or a set of strings giving the metric functions.
+            For example: {'np_max': {'max5px', 'mean', 'median'}}
         **kwargs: dict
-            All other arguments and must include masks and traps.
+            All other arguments and must include masks and traps. Alan: still true?
 
         Returns
         ------
         Dictionary of dataframes with the corresponding reductions and metrics nested.
         """
-        # create dict of traps with reduction in the z-direction
+        # create dict with keys naming the reduction in the z-direction and the reduced data as values
         reduced_traps = {}
         if traps is not None:
             for red_fun in red_metrics.keys():
@@ -364,7 +383,15 @@ class Extractor(ProcessABC):
 
     def reduce_dims(self, img: np.array, method=None) -> np.array:
         """
-        Collapse a z-stack into 2d array. It may perform a null operation.
+        Collapse a z-stack into a 2d array using method.
+        If method is None, return the original data.
+
+        Parameters
+        ----------
+        img: array
+            An array of the image data arranged as (X, Y, Z)
+        method: function
+            The reduction function
         """
         if method is None:
             return img
@@ -390,14 +417,16 @@ class Extractor(ProcessABC):
         tree : dict
             Nested dictionary indicating channels, reduction functions and
             metrics to be used.
+            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
         tile_size : int
-            size of the tile to be extracted.
+            Size of the tile to be extracted.
         masks : np.ndarray
-            A 3-D boolean numpy array with dimensions (ncells, tile_size,
+            A 3d boolean numpy array with dimensions (ncells, tile_size,
             tile_size).
         labels : t.List[t.List[int]]
             List of lists of ints indicating the ids of masks.
-        **kwargs : Additional keyword arguments to be passed to extractor.reduce_extract.
+        **kwargs : keyword arguments
+            Passed to extractor.reduce_extract.
 
         Returns
         -------
@@ -421,7 +450,7 @@ class Extractor(ProcessABC):
                 for trap_id in range(cells.ntraps)
             }
 
-        # find the cell masks as a dict with trap_ids as keys
+        # find the cell masks for a given trap as a dict with trap_ids as keys
         if masks is None:
             raw_masks = cells.at_time(tp, kind="mask")
             masks = {trap_id: [] for trap_id in range(cells.ntraps)}
@@ -431,12 +460,14 @@ class Extractor(ProcessABC):
         # convert to a list of masks
         masks = [np.array(v) for v in masks.values()]
 
-        # find traps at the time point
+        # find image data at the time point
+        # stored as an array arranged as (traps, channels, timepoints, X, Y, Z)
+        # Alan: traps does not appear the best name here!
         traps = self.get_traps(tp, tile_shape=tile_size, channels=tree_chs)
 
         self.img_bgsub = {}
         if self.params.sub_bg:
-            # generate boolean masks for background
+            # generate boolean masks for background as a list with one mask per trap
             bg = [
                 ~np.sum(m, axis=2).astype(bool)
                 if np.any(m)
@@ -446,10 +477,13 @@ class Extractor(ProcessABC):
 
         d = {}
         for ch, red_metrics in tree.items():
-            img = None
-            # ch != is necessary for threading
+            # image data for all traps and z sections for a particular channel
+            # as an array arranged as (traps, X, Y, Z)
+            # NB ch != is necessary for threading
             if ch != "general" and traps is not None and len(traps):
                 img = traps[:, tree_chs.index(ch), 0]
+            else:
+                img = None
 
             d[ch] = self.reduce_extract(
                 red_metrics=red_metrics,
-- 
GitLab