From 614aa578b5975643e0db3f07223646fbbb082e4d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Thu, 20 Jan 2022 11:06:26 +0000
Subject: [PATCH] pass segmentation values to extractor directly

---
 aliby/pipeline.py            | 42 ++++++++++++++++++++---
 extraction/core/extractor.py | 65 +++++++++++++++++++++++-------------
 2 files changed, 80 insertions(+), 27 deletions(-)

diff --git a/aliby/pipeline.py b/aliby/pipeline.py
index 35d5f070..ab695174 100644
--- a/aliby/pipeline.py
+++ b/aliby/pipeline.py
@@ -8,14 +8,15 @@ from typing import List
 from pathlib import Path
 import traceback
 
-import itertools
+from itertools import groupby
 import yaml
 from tqdm import tqdm
 from time import perf_counter
+from pathos.multiprocessing import Pool
 
 import numpy as np
 import pandas as pd
-from pathos.multiprocessing import Pool
+from scipy import ndimage
 
 from aliby.experiment import MetaData
 from aliby.haystack import initialise_tf
@@ -29,6 +30,13 @@ from extraction.core.extractor import Extractor, ExtractorParameters
 from extraction.core.functions.defaults import exparams_from_meta
 from postprocessor.core.processor import PostProcessor, PostProcessorParameters
 
+logging.basicConfig(
+    filename="aliby.log",
+    filemode="w",
+    format="%(name)s - %(levelname)s - %(message)s",
+    level=logging.DEBUG,
+)
+
 
 class PipelineParameters(ParametersABC):
     def __init__(self, general, tiler, baby, extraction, postprocessing):
@@ -184,6 +192,7 @@ class Pipeline(ProcessABC):
                 # if True:  # not Path(filename).exists():
                 meta = MetaData(directory, filename)
                 meta.run()
+                meta.add_omero_id(config["general"]["id"])
                 tiler = Tiler.from_image(
                     image, TilerParameters.from_dict(config["tiler"])
                 )
@@ -256,9 +265,15 @@ class Pipeline(ProcessABC):
                         t = perf_counter()
                         bwriter.write(seg, overwrite=["mother_assign"])
                         logging.debug(f"Timing:Writing-baby:{perf_counter() - t}s")
-                        t = perf_counter()
 
-                        tmp = ext.run(tps=[i])
+                        t = perf_counter()
+                        labels, masks = groupby_traps(
+                            seg["trap"],
+                            seg["cell_label"],
+                            seg["edgemasks"],
+                            tiler.n_traps,
+                        )
+                        tmp = ext.run(tps=[i], masks=masks, labels=labels)
                         logging.debug(f"Timing:Extraction:{perf_counter() - t}s")
                     else:  # Stop if more than X% traps are clogged
                         logging.debug(
@@ -283,6 +298,9 @@ class Pipeline(ProcessABC):
                 PostProcessor(filename, post_proc_params).run()
                 return True
         except Exception as e:  # bug in the trap getting
+            logging.exception(
+                f"Caught exception in worker thread (x = {name}):", exc_info=True
+            )
             print(f"Caught exception in worker thread (x = {name}):")
             # This prints the type, value, and stack trace of the
             # current exception being handled.
@@ -306,3 +324,19 @@ class Pipeline(ProcessABC):
             > es_parameters["thresh_trap_clogged"]
         ).mean()
         return frac_clogged_traps
+
+
+def groupby_traps(traps, labels, edgemasks, ntraps):
+    """Group cell labels and hole-filled edge masks by trap, avoiding an hdf5 re-read.
+
+    ``traps[k]`` names the trap of ``labels[k]``/``edgemasks[k]``; cells of a trap are
+    merged even if ``traps`` is unsorted. Returns ``(labels, masks)`` dicts keyed by
+    ``range(ntraps)``: label lists and ``np.dstack``-ed filled masks, ``[]`` if empty.
+    """
+    label_d, mask_d = {}, {}
+    for trap, label, edgemask in zip(traps, labels, edgemasks):
+        label_d.setdefault(trap, []).append(label)
+        mask_d.setdefault(trap, []).append(ndimage.binary_fill_holes(edgemask))
+    labels = {i: label_d.get(i, []) for i in range(ntraps)}
+    masks = {i: np.dstack(mask_d[i]) if i in mask_d else [] for i in range(ntraps)}
+    return labels, masks
diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py
index b447f04a..b3418224 100644
--- a/extraction/core/extractor.py
+++ b/extraction/core/extractor.py
@@ -1,7 +1,9 @@
 import os
 from pathlib import Path, PosixPath
-import pkg_resources
+
 from collections.abc import Iterable
+import logging
+from time import perf_counter
 
 # from copy import copy
 from typing import Union, List, Dict, Callable
@@ -28,6 +30,8 @@ from agora.io.writer import Writer, load_attributes
 from agora.io.cells import Cells
 from aliby.tile.tiler import Tiler
 
+import matplotlib.pyplot as plt
+
 CELL_FUNS, TRAPFUNS, FUNS = load_funs()
 CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
 RED_FUNS = load_redfuns()
@@ -87,17 +91,17 @@ class ExtractorParameters(ParametersABC):
 
 class Extractor(ProcessABC):
     """
-        Base class to perform feature extraction.
-
-        Parameters
-        ----------
-        parameters: core.extractor Parameters
-            Parameters that include with channels, reduction and
-    b            extraction functions to use.
-        store: str
-            Path to hdf5 storage file. Must contain cell outlines.
-        tiler: pipeline-core.core.segmentation tiler
-            Class that contains or fetches the image to be used for segmentation.
+    Base class to perform feature extraction.
+
+    Parameters
+    ----------
+    parameters: core.extractor Parameters
+        Parameters that include with channels, reduction and
+            extraction functions to use.
+    store: str
+        Path to hdf5 storage file. Must contain cell outlines.
+    tiler: pipeline-core.core.segmentation tiler
+        Class that contains or fetches the image to be used for segmentation.
     """
 
     default_meta = {"pixel_size": 0.236, "z_size": 0.6, "spacing": 0.6}
@@ -300,7 +304,13 @@ class Extractor(ProcessABC):
         return reduce_z(img, method)
 
     def extract_tp(
-        self, tp: int, tree: dict = None, tile_size: int = 117, **kwargs
+        self,
+        tp: int,
+        tree: dict = None,
+        tile_size: int = 117,
+        masks=None,
+        labels=None,
+        **kwargs,
     ) -> dict:
         """
         :param tp: int timepoint from which to extract results
@@ -317,18 +327,27 @@ class Extractor(ProcessABC):
         cells = Cells.hdf(self.local)
 
         # labels
-        raw_labels = cells.labels_at_time(tp)
-        labels = {
-            trap_id: raw_labels.get(trap_id, []) for trap_id in range(cells.ntraps)
-        }
+        if labels is None:
+            raw_labels = cells.labels_at_time(tp)
+            labels = {
+                trap_id: raw_labels.get(trap_id, []) for trap_id in range(cells.ntraps)
+            }
 
         # masks
-        raw_masks = cells.at_time(tp, kind="mask")
-
-        masks = {trap_id: [] for trap_id in range(cells.ntraps)}
-        for trap_id, cells in raw_masks.items():
-            if len(cells):
-                masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
+        t = perf_counter()
+        if masks is None:
+            raw_masks = cells.at_time(tp, kind="mask")
+            nmasks = len([y.shape for x in raw_masks.values() for y in x])
+            # plt.imshow(np.dstack(raw_masks.get(1, [[]])).sum(axis=2))
+            # plt.savefig(f"{tp}.png")
+            # plt.close()
+            logging.debug(f"Timing:nmasks:{nmasks}")
+            logging.debug(f"Timing:MasksFetch:TP_{tp}:{perf_counter() - t}s")
+
+            masks = {trap_id: [] for trap_id in range(cells.ntraps)}
+            for trap_id, cells in raw_masks.items():
+                if len(cells):
+                    masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
 
         masks = [np.array(v) for v in masks.values()]
 
-- 
GitLab