From 614aa578b5975643e0db3f07223646fbbb082e4d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Thu, 20 Jan 2022 11:06:26 +0000
Subject: [PATCH] pass segmentation values to extractor directly

---
Review notes (not part of the commit message; text here is ignored by git am):
* Line breaks and indentation were restored from a whitespace-collapsed copy
  of this patch. Indentation and the position of blank context lines are
  inferred from syntax and hunk line counts, so apply with
  `git am --ignore-whitespace`.
* Amended in groupby_traps (aliby/pipeline.py): the pairs are sorted by trap
  id before `groupby`, `ndimage.binary_fill_holes` replaces the deprecated
  `ndimage.morphology` alias, and a docstring was added. The last pipeline.py
  hunk header and the diffstat reflect this; the post-image id on the
  pipeline.py `index` line still refers to the unamended file.

 aliby/pipeline.py            | 52 ++++++++++++++++++++++++++---
 extraction/core/extractor.py | 65 +++++++++++++++++++++++-------------
 2 files changed, 90 insertions(+), 27 deletions(-)

diff --git a/aliby/pipeline.py b/aliby/pipeline.py
index 35d5f070..ab695174 100644
--- a/aliby/pipeline.py
+++ b/aliby/pipeline.py
@@ -8,14 +8,15 @@ from typing import List
 from pathlib import Path
 import traceback
 
-import itertools
+from itertools import groupby
 import yaml
 from tqdm import tqdm
 from time import perf_counter
+from pathos.multiprocessing import Pool
 
 import numpy as np
 import pandas as pd
-from pathos.multiprocessing import Pool
+from scipy import ndimage
 
 from aliby.experiment import MetaData
 from aliby.haystack import initialise_tf
@@ -29,6 +30,13 @@ from extraction.core.extractor import Extractor, ExtractorParameters
 from extraction.core.functions.defaults import exparams_from_meta
 from postprocessor.core.processor import PostProcessor, PostProcessorParameters
 
+logging.basicConfig(
+    filename="aliby.log",
+    filemode="w",
+    format="%(name)s - %(levelname)s - %(message)s",
+    level=logging.DEBUG,
+)
+
 
 class PipelineParameters(ParametersABC):
     def __init__(self, general, tiler, baby, extraction, postprocessing):
@@ -184,6 +192,7 @@ class Pipeline(ProcessABC):
                 # if True: # not Path(filename).exists():
                 meta = MetaData(directory, filename)
                 meta.run()
+                meta.add_omero_id(config["general"]["id"])
                 tiler = Tiler.from_image(
                     image, TilerParameters.from_dict(config["tiler"])
                 )
@@ -256,9 +265,15 @@ class Pipeline(ProcessABC):
                         t = perf_counter()
                         bwriter.write(seg, overwrite=["mother_assign"])
                         logging.debug(f"Timing:Writing-baby:{perf_counter() - t}s")
-                        t = perf_counter()
 
-                        tmp = ext.run(tps=[i])
+                        t = perf_counter()
+                        labels, masks = groupby_traps(
+                            seg["trap"],
+                            seg["cell_label"],
+                            seg["edgemasks"],
+                            tiler.n_traps,
+                        )
+                        tmp = ext.run(tps=[i], masks=masks, labels=labels)
                         logging.debug(f"Timing:Extraction:{perf_counter() - t}s")
                     else:  # Stop if more than X% traps are clogged
                         logging.debug(
@@ -283,6 +298,9 @@ class Pipeline(ProcessABC):
                 PostProcessor(filename, post_proc_params).run()
                 return True
         except Exception as e:  # bug in the trap getting
+            logging.exception(
+                f"Caught exception in worker thread (x = {name}):", exc_info=True
+            )
             print(f"Caught exception in worker thread (x = {name}):")
             # This prints the type, value, and stack trace of the
             # current exception being handled.
@@ -306,3 +324,29 @@ class Pipeline(ProcessABC):
             > es_parameters["thresh_trap_clogged"]
         ).mean()
         return frac_clogged_traps
+
+
+def groupby_traps(traps, labels, edgemasks, ntraps):
+    """Group labels and hole-filled edgemasks by trap id.
+
+    Lets the extractor receive segmentation output directly instead of
+    re-reading it from the hdf5 file.
+
+    Returns ``(labels, masks)``: two dicts with one key per trap id in
+    ``range(ntraps)``; a trap id with no entries maps to an empty list.
+    """
+    # groupby only merges *consecutive* equal keys, so sort (stably) by trap
+    # id first; otherwise a trap id seen twice would keep only its last run.
+    iterators = [
+        groupby(sorted(zip(traps, dset), key=lambda x: x[0]), lambda x: x[0])
+        for dset in (labels, edgemasks)
+    ]
+    label_d = {key: [x[1] for x in group] for key, group in iterators[0]}
+    mask_d = {
+        key: np.dstack([ndimage.binary_fill_holes(x[1]) for x in group])
+        for key, group in iterators[1]
+    }
+    labels = {i: label_d.get(i, []) for i in range(ntraps)}
+    masks = {i: mask_d.get(i, []) for i in range(ntraps)}
+
+    return labels, masks
diff --git a/extraction/core/extractor.py b/extraction/core/extractor.py
index b447f04a..b3418224 100644
--- a/extraction/core/extractor.py
+++ b/extraction/core/extractor.py
@@ -1,7 +1,9 @@
 import os
 from pathlib import Path, PosixPath
-import pkg_resources
+
 from collections.abc import Iterable
+import logging
+from time import perf_counter
 
 # from copy import copy
 from typing import Union, List, Dict, Callable
@@ -28,6 +30,8 @@ from agora.io.writer import Writer, load_attributes
 from agora.io.cells import Cells
 from aliby.tile.tiler import Tiler
 
+import matplotlib.pyplot as plt
+
 CELL_FUNS, TRAPFUNS, FUNS = load_funs()
 CUSTOM_FUNS, CUSTOM_ARGS = load_custom_args()
 RED_FUNS = load_redfuns()
@@ -87,17 +91,17 @@ class ExtractorParameters(ParametersABC):
 
 class Extractor(ProcessABC):
     """
-    Base class to perform feature extraction.
-
-    Parameters
-    ----------
-        parameters: core.extractor Parameters
-            Parameters that include with channels, reduction and
-    b            extraction functions to use.
-        store: str
-            Path to hdf5 storage file. Must contain cell outlines.
-        tiler: pipeline-core.core.segmentation tiler
-            Class that contains or fetches the image to be used for segmentation.
+        Base class to perform feature extraction.
+
+        Parameters
+        ----------
+            parameters: core.extractor Parameters
+                Parameters that include with channels, reduction and
+                    extraction functions to use.
+            store: str
+                Path to hdf5 storage file. Must contain cell outlines.
+            tiler: pipeline-core.core.segmentation tiler
+                Class that contains or fetches the image to be used for segmentation.
     """
 
     default_meta = {"pixel_size": 0.236, "z_size": 0.6, "spacing": 0.6}
@@ -300,7 +304,13 @@ class Extractor(ProcessABC):
         return reduce_z(img, method)
 
     def extract_tp(
-        self, tp: int, tree: dict = None, tile_size: int = 117, **kwargs
+        self,
+        tp: int,
+        tree: dict = None,
+        tile_size: int = 117,
+        masks=None,
+        labels=None,
+        **kwargs,
     ) -> dict:
         """
         :param tp: int timepoint from which to extract results
@@ -317,18 +327,27 @@ class Extractor(ProcessABC):
         cells = Cells.hdf(self.local)
 
         # labels
-        raw_labels = cells.labels_at_time(tp)
-        labels = {
-            trap_id: raw_labels.get(trap_id, []) for trap_id in range(cells.ntraps)
-        }
+        if labels is None:
+            raw_labels = cells.labels_at_time(tp)
+            labels = {
+                trap_id: raw_labels.get(trap_id, []) for trap_id in range(cells.ntraps)
+            }
 
         # masks
-        raw_masks = cells.at_time(tp, kind="mask")
-
-        masks = {trap_id: [] for trap_id in range(cells.ntraps)}
-        for trap_id, cells in raw_masks.items():
-            if len(cells):
-                masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
+        t = perf_counter()
+        if masks is None:
+            raw_masks = cells.at_time(tp, kind="mask")
+            nmasks = len([y.shape for x in raw_masks.values() for y in x])
+            # plt.imshow(np.dstack(raw_masks.get(1, [[]])).sum(axis=2))
+            # plt.savefig(f"{tp}.png")
+            # plt.close()
+            logging.debug(f"Timing:nmasks:{nmasks}")
+            logging.debug(f"Timing:MasksFetch:TP_{tp}:{perf_counter() - t}s")
+
+            masks = {trap_id: [] for trap_id in range(cells.ntraps)}
+            for trap_id, cells in raw_masks.items():
+                if len(cells):
+                    masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
 
         masks = [np.array(v) for v in masks.values()]
 
-- 
GitLab