From e55aed3dc4af23e4097738ae8b12fcd485260606 Mon Sep 17 00:00:00 2001
From: Peter Swain <peter.swain@ed.ac.uk>
Date: Sat, 3 Sep 2022 18:16:51 +0100
Subject: [PATCH] more extractor

---
 extraction/core/extractor.py | 271 +++++++++++++++++++++--------------
 1 file changed, 164 insertions(+), 107 deletions(-)

diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py
index 773da141..c2406a16 100644
--- a/extraction/core/extractor.py
+++ b/extraction/core/extractor.py
@@ -166,6 +166,7 @@ class Extractor(ProcessABC):
         return self._channels
 
     @property
+    # Alan: does this work? local is not a string.
     def current_position(self):
         return self.local.split("/")[-1][:-3]
 
@@ -270,52 +271,53 @@ class Extractor(ProcessABC):
         traps: List[np.array],
         masks: List[np.array],
         metric: str,
-        labels: List[int] = None,
+        labels: Dict = None,
     ) -> dict:
         """
-        Apply a function for a whole position.
+        Apply a function to a whole position.
 
         Parameters
         ----------
-        traps: List[np.array]
-            List of images
-        masks: List[np.array]
-            List of masks
+        traps: list of arrays
+            List of images.
+        masks: list of arrays
+            List of masks.
         metric: str
-            Metric to extract
-        labels: List[int]
-            Cell labels to use as indices in output dataFrame
+            Metric to extract.
+        labels: dict
+            A dict of cell labels with trap_ids as keys and a list of cell labels as values.
         pos_info: bool
-            Whether to add the position as an index or not
+            Whether to add the position as an index or not.
 
         Returns
         -------
-        d: dict
-            A dictionary of dataframes
+        res_idx: a tuple of tuples
+            A two-tuple of a tuple of results and a tuple with the corresponding trap_id and cell labels
         """
-
         if labels is None:
+            # Alan: it looks like this will crash if labels is None
             raise Warning("No labels given. Sorting cells using index.")
-
         cell_fun = True if metric in self._all_cell_funs else False
-
         idx = []
         results = []
-
         for trap_id, (mask_set, trap, lbl_set) in enumerate(
             zip(masks, traps, labels.values())
         ):
-            if len(mask_set):  # ignore empty traps
+            # ignore empty traps
+            if len(mask_set):
+                # apply metric, which is either a cell function or a background (trap) function
                 result = self._all_funs[metric](mask_set, trap)
                 if cell_fun:
+                    # store results for each cell separately
                     for lbl, val in zip(lbl_set, result):
                         results.append(val)
                         idx.append((trap_id, lbl))
                 else:
+                    # background (trap) function
                     results.append(result)
                     idx.append(trap_id)
-
-        return (tuple(results), tuple(idx))
+        res_idx = (tuple(results), tuple(idx))
+        return res_idx
 
     def extract_funs(
         self,
@@ -406,7 +408,7 @@ class Extractor(ProcessABC):
         masks=None,
         labels=None,
         **kwargs,
-    ) -> t.Dict[str, t.Dict[str, pd.Series]]:
+    ) -> t.Dict[str, t.Dict[str, t.Dict[str, tuple]]]:
         """
         Core extraction method for an individual time-point.
 
@@ -420,17 +422,25 @@ class Extractor(ProcessABC):
             For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
         tile_size : int
             Size of the tile to be extracted.
-        masks : np.ndarray
-            A 3d boolean numpy array with dimensions (ncells, tile_size,
+        masks : list of arrays
+            A list of masks per trap with each mask having dimensions (ncells, tile_size,
             tile_size).
-        labels : t.List[t.List[int]]
-            List of lists of ints indicating the ids of masks.
+        labels : dict
+            A dictionary with trap_ids as keys and cell_labels as values.
         **kwargs : keyword arguments
             Passed to extractor.reduce_extract.
 
         Returns
         -------
-        dict
+        d: dict
+            Dictionary of the results with three levels of dictionaries.
+            The first level has channels as keys.
+            The second level has reduction metrics as keys.
+            The third level has cell or background metrics as keys and a two-tuple as values.
+            The first tuple is the result of applying the metrics to a particular cell or trap; the second tuple is either (trap_id, cell_label) for a metric applied to a cell or a trap_id for a metric applied to a trap.
+
+            An example is d["GFP"]["np_max"]["mean"][0], which gives a tuple of the calculated mean GFP fluorescence for all cells.
+
         """
         if tree is None:
             # use default
@@ -465,26 +475,27 @@ class Extractor(ProcessABC):
         # Alan: traps does not appear the best name here!
         traps = self.get_traps(tp, tile_shape=tile_size, channels=tree_chs)
 
-        self.img_bgsub = {}
+        # generate boolean masks for background as a list with one mask per trap
         if self.params.sub_bg:
-            # generate boolean masks for background as a list with one mask per trap
-            bg = [
+            bgs = [
                 ~np.sum(m, axis=2).astype(bool)
                 if np.any(m)
                 else np.zeros((tile_size, tile_size))
                 for m in masks
             ]
 
+        # perform extraction by applying metrics
         d = {}
+        self.img_bgsub = {}
         for ch, red_metrics in tree.items():
-            # image data for all traps and z sections for a particular channel
-            # as an array arranged as (traps, X, Y, Z)
             # NB ch != is necessary for threading
             if ch != "general" and traps is not None and len(traps):
+                # image data for all traps and z sections for a particular channel
+                # as an array arranged as (number of traps, X, Y, number of Z sections)
                 img = traps[:, tree_chs.index(ch), 0]
             else:
                 img = None
-
+            # apply metrics to image data
             d[ch] = self.reduce_extract(
                 red_metrics=red_metrics,
                 traps=img,
@@ -492,23 +503,21 @@ class Extractor(ProcessABC):
                 labels=labels,
                 **kwargs,
             )
-
-            if (
-                ch in self.params.sub_bg and img is not None
-            ):  # Calculate metrics with subtracted bg
+            # apply metrics to image data with the background subtracted
+            if ch in self.params.sub_bg and img is not None:
+                # calculate metrics with subtracted bg
                 ch_bs = ch + "_bgsub"
-
                 self.img_bgsub[ch_bs] = []
-                for trap, maskset in zip(img, bg):
-
+                for trap, bg in zip(img, bgs):
                     cells_fl = np.zeros_like(trap)
-
-                    is_cell = np.where(maskset)
-                    if len(is_cell[0]):  # skip calculation for empty traps
+                    # Alan: should this not be is_not_cell?
+                    is_cell = np.where(bg)
+                    # skip for empty traps
+                    if len(is_cell[0]):
                         cells_fl = np.median(trap[is_cell], axis=0)
-
+                    # subtract median background
                     self.img_bgsub[ch_bs].append(trap - cells_fl)
-
+                # apply metrics to background-corrected data
                 d[ch_bs] = self.reduce_extract(
                     red_metrics=ch_tree[ch],
                     traps=self.img_bgsub[ch_bs],
@@ -517,7 +526,7 @@ class Extractor(ProcessABC):
                     **kwargs,
                 )
 
-        # Additional operations between multiple channels (e.g. pH calculations)
+        # apply any metrics that use multiple channels (eg pH calculations)
         for name, (
             chs,
             merge_fun,
@@ -544,14 +553,22 @@ class Extractor(ProcessABC):
         """
         Returns the image from a correct source, either raw or bgsub
 
-        :channel: str name of channel to get
-        :img: ndarray (trap_id, channel, tp, tile_size, tile_size, n_zstacks) of standard channels
-        :channels: List of channels
-        """
+        Parameters
+        ----------
+        channel: str
+            Name of channel to get.
+        traps: ndarray
+            An array of the image data having dimensions of (trap_id, channel, tp, tile_size, tile_size, n_zstacks).
+        channels: list of str (optional)
+            List of available channels.
 
+        Returns
+        -------
+        img: ndarray
+            An array of image data with dimensions (number of traps, X, Y, number of Z sections).
+        """
         if channels is None:
             channels = (*self.params.tree,)
-
         if channel in channels:
             return traps[:, channels.index(channel), 0]
         elif channel in self.img_bgsub:
@@ -559,32 +576,53 @@ class Extractor(ProcessABC):
 
     def run_tp(self, tp, **kwargs):
         """
-        Wrapper to add compatiblibility with other pipeline steps
+        Wrapper to add compatibility with other steps of the pipeline.
         """
         return self.run(tps=[tp], **kwargs)
 
     def run(
-        self, tree=None, tps: List[int] = None, save=True, **kwargs
+        self,
+        tree=None,
+        tps: List[int] = None,
+        save=True,
+        **kwargs,
     ) -> dict:
+        """
+        Parameters
+        ----------
+        tree: dict
+            Nested dictionary indicating channels, reduction functions and
+            metrics to be used.
+            For example: {'general': {'None': ['area', 'volume', 'eccentricity']}}
+        tps: list of int (optional)
+            Time points to include.
+        save: boolean (optional)
+            If True, save results to h5 file.
+        kwargs: keyword arguments (optional)
+            Passed to extract_tp.
 
+        Returns
+        -------
+        d: dict
+            A dict of the extracted data with a concatenated string of channel, reduction metric, and cell metric as keys and pd.Series of the extracted data as values.
+        """
         if tree is None:
             tree = self.params.tree
-
         if tps is None:
             tps = list(range(self.meta["time_settings/ntimepoints"][0]))
-
+        # store results in dict
         d = {}
         for tp in tps:
-            new = flatten_nest(
+            # extract for each time point and convert to dict of pd.Series
+            new = flatten_nesteddict(
                 self.extract_tp(tp=tp, tree=tree, **kwargs),
                 to="series",
                 tp=tp,
             )
-
+            # concatenate with data extracted from earlier time points
             for k in new.keys():
-                n = new[k]
-                d[k] = pd.concat((d.get(k, None), n), axis=1)
-
+                d[k] = pd.concat((d.get(k, None), new[k]), axis=1)
+        # add indices to pd.Series containing the extracted data
         for k in d.keys():
             indices = ["experiment", "position", "trap", "cell_label"]
             idx = (
@@ -593,65 +631,73 @@ class Extractor(ProcessABC):
                 else [indices[-2]]
             )
             d[k].index.names = idx
-
-            toreturn = d
-
+        # save
         if save:
-            self.save_to_hdf(toreturn)
+            self.save_to_hdf(d)
+        return d
 
-        return toreturn
+    # Alan: isn't this identical to run?
+    # def extract_pos(
+    #     self, tree=None, tps: List[int] = None, save=True, **kwargs
+    # ) -> dict:
 
-    def extract_pos(
-        self, tree=None, tps: List[int] = None, save=True, **kwargs
-    ) -> dict:
+    #     if tree is None:
+    #         tree = self.params.tree
 
-        if tree is None:
-            tree = self.params.tree
+    #     if tps is None:
+    #         tps = list(range(self.meta["time_settings/ntimepoints"]))
 
-        if tps is None:
-            tps = list(range(self.meta["time_settings/ntimepoints"]))
+    #     d = {}
+    #     for tp in tps:
+    #         new = flatten_nest(
+    #             self.extract_tp(tp=tp, tree=tree, **kwargs),
+    #             to="series",
+    #             tp=tp,
+    #         )
 
-        d = {}
-        for tp in tps:
-            new = flatten_nest(
-                self.extract_tp(tp=tp, tree=tree, **kwargs),
-                to="series",
-                tp=tp,
-            )
+    #         for k in new.keys():
+    #             n = new[k]
+    #             d[k] = pd.concat((d.get(k, None), n), axis=1)
 
-            for k in new.keys():
-                n = new[k]
-                d[k] = pd.concat((d.get(k, None), n), axis=1)
+    #     for k in d.keys():
+    #         indices = ["experiment", "position", "trap", "cell_label"]
+    #         idx = (
+    #             indices[-d[k].index.nlevels :]
+    #             if d[k].index.nlevels > 1
+    #             else [indices[-2]]
+    #         )
+    #         d[k].index.names = idx
 
-        for k in d.keys():
-            indices = ["experiment", "position", "trap", "cell_label"]
-            idx = (
-                indices[-d[k].index.nlevels :]
-                if d[k].index.nlevels > 1
-                else [indices[-2]]
-            )
-            d[k].index.names = idx
+    #         toreturn = d
 
-            toreturn = d
+    #     if save:
+    #         self.save_to_hdf(toreturn)
 
-        if save:
-            self.save_to_hdf(toreturn)
+    #     return toreturn
 
-        return toreturn
+    def save_to_hdf(self, dict_series, path=None):
+        """
+        Save the extracted data to the h5 file.
 
-    def save_to_hdf(self, group_df, path=None):
+        Parameters
+        ----------
+        dict_series: dict
+            A dictionary of the extracted data, created by run.
+        path: Path (optional)
+            To the h5 file.
+        """
         if path is None:
             path = self.local
-
         self.writer = Writer(path)
-        for path, df in group_df.items():
-            dset_path = "/extraction/" + path
-            self.writer.write(dset_path, df)
+        for extract_name, series in dict_series.items():
+            dset_path = "/extraction/" + extract_name
+            self.writer.write(dset_path, series)
         self.writer.id_cache.clear()
 
     def get_meta(self, flds):
+        # Alan: unsure what this is doing. seems to break for "nuc_conv_3d"
+        # make flds a list
         if not hasattr(flds, "__iter__"):
-            # make flds a list
             flds = [flds]
         meta_short = {k.split("/")[-1]: v for k, v in self.meta.items()}
         return {
@@ -660,14 +706,24 @@ class Extractor(ProcessABC):
 
 
 ### Helpers
-def flatten_nest(nest: dict, to="series", tp: int = None) -> dict:
-    """
-    Convert a nested extraction dict into a dict of series
-    :param nest: dict contained the nested results of extraction
-    :param to: str = 'series' Determine output format, either list or  pd.Series
-    :param tp: int timepoint used to name the series
+def flatten_nesteddict(nest: dict, to="series", tp: int = None) -> dict:
     """
+    Converts a nested extraction dict into a dict of pd.Series
 
+    Parameters
+    ----------
+    nest: dict of dicts
+        Contains the nested results of extraction.
+    to: str (optional)
+        Specifies the format of the output, either pd.Series (default) or a list
+    tp: int
+        Timepoint used to name the pd.Series
+
+    Returns
+    -------
+    d: dict
+        A dict with a concatenated string of channel, reduction metric, and cell metric as keys and either a pd.Series or a list of the corresponding extracted data as values.
+    """
     d = {}
     for k0, v0 in nest.items():
         for k1, v1 in v0.items():
@@ -675,10 +731,10 @@ def flatten_nest(nest: dict, to="series", tp: int = None) -> dict:
                 d["/".join((k0, k1, k2))] = (
                     pd.Series(*v2, name=tp) if to == "series" else v2
                 )
-
     return d
 
 
+# Alan: this no longer seems to be used
 def fill_tree(tree):
     if tree is None:
         return None
@@ -693,8 +749,9 @@ def fill_tree(tree):
 
 
 class hollowExtractor(Extractor):
-    """Extractor that only cares about receiving image and masks,
-    used for testing.
+    """
+    Extractor that only cares about receiving images and masks.
+    Used for testing.
     """
 
     def __init__(self, parameters):
-- 
GitLab