diff --git a/aliby/pipeline.py b/aliby/pipeline.py
index 7f02704b4d8ed8bbb912a1c4ae2218c3ee4bb575..be6936ed2b04e04c368177ede873d6b2e62d07b0 100644
--- a/aliby/pipeline.py
+++ b/aliby/pipeline.py
@@ -2,17 +2,14 @@
 Pipeline and chaining elements.
 """
 import logging
+import typing as t
 import os
 import re
 import traceback
 from copy import copy
 from itertools import groupby
 from pathlib import Path, PosixPath
-
-# from p_tqdm import p_map
 from time import perf_counter
-
-# from abc import ABC, abstractmethod
 from typing import Union
 
 import h5py
@@ -54,6 +51,8 @@ logging.basicConfig(
 
 
 class PipelineParameters(ParametersABC):
+    _pool_index = None
+
     def __init__(
         self, general, tiler, baby, extraction, postprocessing, reporting
     ):
@@ -72,7 +71,7 @@
         baby={},
         extraction={},
         postprocessing={},
-        reporting={},
+        # reporting={},
     ):
         """
         Load unit test experiment
@@ -137,14 +136,16 @@
                 defaults["general"][k] = v
 
         defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
-        defaults["baby"] = BabyParameters.default(**extraction).to_dict()
-        defaults["extraction"] = exparams_from_meta(meta)
+        defaults["baby"] = BabyParameters.default(**baby).to_dict()
+        defaults["extraction"] = (
+            exparams_from_meta(meta)
+            or BabyParameters.default(**extraction).to_dict()
+        )
         defaults["postprocessing"] = PostProcessorParameters.default(
             **postprocessing
         ).to_dict()
         defaults["reporting"] = {}
-        # for k in defaults.keys():
-        #     exec("defaults[k].update(" + k + ")")
+
         return cls(**{k: v for k, v in defaults.items()})
 
     def load_logs(self):
@@ -161,12 +162,15 @@
     """
 
     iterative_steps = ["tiler", "baby", "extraction"]
+
     step_sequence = [
         "tiler",
         "baby",
        "extraction",
         "postprocessing",
     ]
+
+    # Indicate groupings to perform special operations during step iteration
     writer_groups = {
         "tiler": ["trap_info"],
         "baby": ["cell_info"],
@@ -244,7 +248,6 @@
             pipeline_parameters.general["directory"] = directory
             pipeline_parameters.general["filter"] = Path(fpath).stem
 
-            # Fix legacy postprocessing parameters
             post_process_params = pipeline_parameters.postprocessing.get(
                 "parameters", None
             )
@@ -322,13 +325,6 @@
                 # position=0,
             )
 
-            # results = p_map(
-            #     lambda x: self.create_pipeline(*x),
-            #     [(k, i) for i, k in enumerate(image_ids.items())],
-            #     num_cpus=distributed,
-            #     position=0,
-            # )
-
         else:  # Sequential
             results = []
             for k, v in tqdm(image_ids.items()):
@@ -338,142 +334,42 @@
         return results
 
     def create_pipeline(self, image_id, index=None):
-        config = self.parameters.to_dict()
-        pparams = config
+        self._pool_index = index
         name, image_id = image_id
-        general_config = config["general"]
         session = None
-        earlystop = general_config.get("earlystop", None)
-        process_from = {k: 0 for k in self.iterative_steps}
-        steps = {}
-        ow = {k: 0 for k in self.step_sequence}
-
-        # check overwriting
-        ow_id = general_config.get("overwrite", 0)
-        ow = {step: True for step in self.step_sequence}
-        if ow_id and ow_id != True:
-            ow = {
-                step: self.step_sequence.index(ow_id) < i
-                for i, step in enumerate(self.step_sequence, 1)
-            }
-
+        filename = None
+        run_kwargs = {"extraction": {"labels": None, "masks": None}}
        try:
-            # Set up
-            directory = general_config["directory"]
+            (
+                filename,
+                meta,
+                config,
+                process_from,
+                tps,
+                steps,
+                earlystop,
+                session,
+                trackers_state,
+            ) = self._setup_pipeline(image_id)
+
+            loaded_writers = {
+                name: writer(filename)
+                for k in self.step_sequence
+                if k in self.writers
+                for name, writer in self.writers[k]
+            }
+            writer_ow_kwargs = {
+                "state": loaded_writers["state"].datatypes.keys(),
+                "baby": ["mother_assign"],
+            }
 
-            image_wrapper = get_image_class(image_id)
+            # START PIPELINE
+            frac_clogged_traps = 0
+            min_process_from = min(process_from.values())
 
-            with image_wrapper(
+            with get_image_class(image_id)(
                 image_id, **self.general.get("server_info", {})
             ) as image:
-                filename = f"{directory}/{image.name}.h5"
-                meta = MetaData(directory, filename)
-
-                from_start = True
-                trackers_state = None
-                if Path(filename).exists():  # If no previous segmentation
-
-                    if not ow["tiler"]:  # Try to load config from file
-                        try:
-                            with h5py.File(filename, "r") as f:
-                                steps["tiler"] = Tiler.from_hdf5(
-                                    image, filename
-                                )
-
-                                legacy_get_last_tp = {  # Function to support seg in ver < 0.24
-                                    "tiler": lambda f: f[
-                                        "trap_info/drifts"
-                                    ].shape[0]
-                                    - 1,
-                                    "baby": lambda f: f["cell_info/timepoint"][
-                                        -1
-                                    ],
-                                    "extraction": lambda f: f[
-                                        "extraction/general/None/area/timepoint"
-                                    ][-1],
-                                }
-                                for k, v in process_from.items():
-                                    if not ow[k]:
-                                        process_from[k] = legacy_get_last_tp[
-                                            k
-                                        ](f)
-                                        # process_from[k] = f[
-                                        #     self.writer_groups[k][-1]
-                                        # ].attrs.get(
-                                        #     "last_processed",
-                                        #     legacy_get_last_tp[k](f),
-                                        # )
-                                        process_from[k] += 1
-                                # get state array
-                                if not ow["baby"]:
-                                    trackers_state = StateReader(
-                                        filename
-                                    ).get_formatted_states()
-
-                                config["tiler"] = steps[
-                                    "tiler"
-                                ].parameters.to_dict()
-                            if not np.any(ow.values()):
-                                from_start = False
-                                print(
-                                    f"Existing file {filename} will be used."
-                                )
-                        except Exception as e:
-                            print(e)
-
-                    # Delete datasets to overwrite and update pipeline data
-                    with h5py.File(filename, "r") as f:
-                        pparams = PipelineParameters.from_yaml(
-                            f.attrs["parameters"]
-                        ).to_dict()
-
-                    with h5py.File(filename, "a") as f:
-                        for k, v in ow.items():
-                            if v:
-                                for gname in self.writer_groups[k]:
-                                    if gname in f:
-                                        del f[gname]
-
-                                pparams[k] = config[k]
-                        meta.add_fields(
-                            {
-                                "parameters": PipelineParameters.from_dict(
-                                    pparams
-                                ).to_yaml()
-                            },
-                            overwrite=True,
-                        )
-
-                if from_start:  # New experiment or overwriting
-                    if config.get("overwrite", False) == True or np.all(
-                        list(ow.values())
-                    ):
-                        if Path(filename).exists():
-                            os.remove(filename)
-                    meta.run()
-                    meta.add_fields(  # Add non-logfile metadata
-                        {
-                            "omero_id,": config["general"]["id"],
-                            "image_id": image_id,
-                            "parameters": PipelineParameters.from_dict(
-                                pparams
-                            ).to_yaml(),
-                        }
-                    )
-
-                tps = min(general_config["tps"], image.data.shape[0])
-
-                run_kwargs = {"extraction": {"labels": None, "masks": None}}
-                loaded_writers = {
-                    name: writer(filename)
-                    for k in self.step_sequence
-                    if k in self.writers
-                    for name, writer in self.writers[k]
-                }
-                writer_ow_kwargs = {
-                    "state": loaded_writers["state"].datatypes.keys(),
-                    "baby": ["mother_assign"],
-                }
 
                 # Initialise Steps
                 if "tiler" not in steps:
@@ -508,7 +404,7 @@
                     [c + "_bgsub" for c in config["extraction"]["sub_bg"]]
                 )
                 tmp = copy(config["extraction"]["multichannel_ops"])
-                for op, (input_ch, op_id, red_ext) in tmp.items():
+                for op, (input_ch, _, _) in tmp.items():
                     if not set(input_ch).issubset(av_channels_wsub):
                         del config["extraction"]["multichannel_ops"][op]
 
@@ -518,112 +414,101 @@
                 steps["extraction"] = Extractor.from_tiler(
                     exparams, store=filename, tiler=steps["tiler"]
                 )
 
-                # RUN
-                # Adjust tps based on how many tps are available on the server
-                frac_clogged_traps = 0
-                # print(f"Processing from {process_from}")
-                min_process_from = min(process_from.values())
-                pbar = tqdm(
-                    range(min_process_from, tps),
-                    desc=image.name,
-                    initial=min_process_from,
-                    total=tps,
-                    # position=index + 1,
-                )
-                for i in pbar:
-
-                    if (
-                        frac_clogged_traps < earlystop["thresh_pos_clogged"]
-                        or i < earlystop["min_tp"]
-                    ):
-
-                        for step in self.iterative_steps:
-                            if i >= process_from[step]:
-                                t = perf_counter()
-                                result = steps[step].run_tp(
-                                    i, **run_kwargs.get(step, {})
-                                )
-                                logging.debug(
-                                    f"Timing:{step}:{perf_counter() - t}s"
-                                )
-                                if step in loaded_writers:
-                                    t = perf_counter()
-                                    loaded_writers[step].write(
-                                        data=result,
-                                        overwrite=writer_ow_kwargs.get(
-                                            step, []
-                                        ),
-                                        tp=i,
-                                        meta={"last_processed": i},
-                                    )
-                                    logging.debug(
-                                        f"Timing:Writing-{step}:{perf_counter() - t}s"
-                                    )
-
-                                # Step-specific actions
-
-                                if step == "tiler":
-                                    if i == min_process_from:
-                                        print(
-                                            f"Found {steps['tiler'].n_traps} traps in {image.name}"
-                                        )
-                                elif (
-                                    step == "baby"
-                                ):  # Write state and pass info to ext
-                                    loaded_writers["state"].write(
-                                        data=steps[
-                                            step
-                                        ].crawler.tracker_states,
-                                        overwrite=loaded_writers[
-                                            "state"
-                                        ].datatypes.keys(),
-                                        tp=i,
-                                    )
-                                    labels, masks = groupby_traps(
-                                        result["trap"],
-                                        result["cell_label"],
-                                        result["edgemasks"],
-                                        steps["tiler"].n_traps,
-                                    )
-
-                                elif (
-                                    step == "extraction"
-                                ):  # Remove mask/label after ext
-                                    for k in ["masks", "labels"]:
-                                        run_kwargs[step][k] = None
-
-                        frac_clogged_traps = self.check_earlystop(
-                            filename, earlystop, steps["tiler"].tile_size
-                        )
-                        logging.debug(
-                            f"Quality:Clogged_traps:{frac_clogged_traps}"
-                        )
-
-                        frac = np.round(frac_clogged_traps * 100)
-                        pbar.set_postfix_str(f"{frac} Clogged")
-                    else:  # Stop if more than X% traps are clogged
-                        logging.debug(
-                            f"EarlyStop:{earlystop['thresh_pos_clogged']*100}% traps clogged at time point {i}"
-                        )
-                        print(
-                            f"Stopping analysis at time {i} with {frac_clogged_traps} clogged traps"
-                        )
-                        meta.add_fields({"end_status": "Clogged"})
-                        break
-
-                meta.add_fields({"last_processed": i})
-                # Run post processing
-
-                meta.add_fields({"end_status": "Success"})
-                post_proc_params = PostProcessorParameters.from_dict(
-                    config["postprocessing"]
-                )
-                PostProcessor(filename, post_proc_params).run()
-
-            # return 1
+                pbar = tqdm(
+                    range(min_process_from, tps),
+                    desc=image.name,
+                    initial=min_process_from,
+                    total=tps,
+                    # position=index + 1,
+                )
+                for i in pbar:
+
+                    if (
+                        frac_clogged_traps
+                        < earlystop["thresh_pos_clogged"]
+                        or i < earlystop["min_tp"]
+                    ):
+
+                        for step in self.iterative_steps:
+                            if i >= process_from[step]:
+                                t = perf_counter()
+                                result = steps[step].run_tp(
+                                    i, **run_kwargs.get(step, {})
+                                )
+                                logging.debug(
+                                    f"Timing:{step}:{perf_counter() - t}s"
+                                )
+                                if step in loaded_writers:
+                                    t = perf_counter()
+                                    loaded_writers[step].write(
+                                        data=result,
+                                        overwrite=writer_ow_kwargs.get(
+                                            step, []
+                                        ),
+                                        tp=i,
+                                        meta={"last_processed": i},
+                                    )
+                                    logging.debug(
+                                        f"Timing:Writing-{step}:{perf_counter() - t}s"
+                                    )
+
+                                # Step-specific actions
+                                if (
+                                    step == "tiler"
+                                    and i == min_process_from
+                                ):
+                                    print(
+                                        f"Found {steps['tiler'].n_traps} traps in {image.name}"
+                                    )
+                                elif (
+                                    step == "baby"
+                                ):  # Write state and pass info to ext
+                                    loaded_writers["state"].write(
+                                        data=steps[
+                                            step
+                                        ].crawler.tracker_states,
+                                        overwrite=loaded_writers[
+                                            "state"
+                                        ].datatypes.keys(),
+                                        tp=i,
+                                    )
+                                elif (
+                                    step == "extraction"
+                                ):  # Remove mask/label after ext
+                                    for k in ["masks", "labels"]:
+                                        run_kwargs[step][k] = None
+
+                        frac_clogged_traps = self.check_earlystop(
+                            filename, earlystop, steps["tiler"].tile_size
+                        )
+                        logging.debug(
+                            f"Quality:Clogged_traps:{frac_clogged_traps}"
+                        )
+
+                        frac = np.round(frac_clogged_traps * 100)
+                        pbar.set_postfix_str(f"{frac} Clogged")
Clogged") + else: # Stop if more than X% traps are clogged + logging.debug( + f"EarlyStop:{earlystop['thresh_pos_clogged']*100}% traps clogged at time point {i}" + ) + print( + f"Stopping analysis at time {i} with {frac_clogged_traps} clogged traps" + ) + meta.add_fields({"end_status": "Clogged"}) + break + + meta.add_fields({"last_processed": i}) + # Run post processing + + meta.add_fields({"end_status": "Success"}) + post_proc_params = PostProcessorParameters.from_dict( + config["postprocessing"] + ) + PostProcessor(filename, post_proc_params).run() - # return 1 + return 1 - except Exception as e: # bug in the trap getting + except Exception as e: # bug during setup or runtime logging.exception( f"Caught exception in worker thread (x = {name}):", exc_info=True, @@ -632,11 +517,9 @@ class Pipeline(ProcessABC): # This prints the type, value, and stack trace of the # current exception being handled. traceback.print_exc() - print() raise e finally: - if session: - session.close() + _close_session(session) # try: # compiler = ExperimentCompiler(None, filepath) @@ -666,6 +549,146 @@ class Pipeline(ProcessABC): return (traps_above_nthresh & traps_above_athresh).mean() + def _load_config_from_file( + self, + filename: PosixPath, + process_from: t.Dict[str, int], + trackers_state: t.List, + overwrite: t.Dict[str, bool], + ): + with h5py.File(filename, "r") as f: + for k in process_from.keys(): + if not overwrite[k]: + process_from[k] = self.legacy_get_last_tp[k](f) + process_from[k] += 1 + return process_from, trackers_state, overwrite + + @staticmethod + def legacy_get_last_tp(step: str) -> t.Callable: + """Get last time-point in different ways depending + on which step we are using + + To support segmentation in aliby < v0.24 + TODO Deprecate and replace with State method + """ + switch_case = { + "tiler": lambda f: f["trap_info/drifts"].shape[0] - 1, + "baby": lambda f: f["cell_info/timepoint"][-1], + "extraction": lambda f: f[ + "extraction/general/None/area/timepoint" + ][-1], + } + return switch_case[step] + + def _setup_pipeline(self, image_id: int): + config = self.parameters.to_dict() + pparams = config + image_id = image_id + general_config = config["general"] + session = None + earlystop = general_config.get("earlystop", None) + process_from = {k: 0 for k in self.iterative_steps} + steps = {} + ow = {k: 0 for k in self.step_sequence} + + # check overwriting + ow_id = general_config.get("overwrite", 0) + ow = {step: True for step in self.step_sequence} + if ow_id and ow_id is not True: + ow = { + step: self.step_sequence.index(ow_id) < i + for i, step in enumerate(self.step_sequence, 1) + } + + # Set up + directory = general_config["directory"] + + with get_image_class(image_id)( + image_id, **self.general.get("server_info", {}) + ) as image: + filename = Path(f"{directory}/{image.name}.h5") + meta = MetaData(directory, filename) + + from_start = False if np.any(ow.values()) else True + trackers_state = [] + # If no previous segmentation and keep tiler + if filename.exists(): + if not ow["tiler"]: + steps["tiler"] = Tiler.from_hdf5(image, filename) + try: + ( + process_from, + trackers_state, + ow, + ) = self._load_config_from_file( + filename, process_from, trackers_state, ow + ) + # get state array + trackers_state = ( + [] + if ow["baby"] + else StateReader(filename).get_formatted_states() + ) + + config["tiler"] = steps["tiler"].parameters.to_dict() + except Exception: + pass + + # Delete datasets to overwrite and update pipeline data + # Use existing parameters + with 
+                    pparams = PipelineParameters.from_yaml(
+                        f.attrs["parameters"]
+                    ).to_dict()
+
+                    for k, v in ow.items():
+                        if v:
+                            for gname in self.writer_groups[k]:
+                                if gname in f:
+                                    del f[gname]
+
+                            pparams[k] = config[k]
+                    meta.add_fields(
+                        {
+                            "parameters": PipelineParameters.from_dict(
+                                pparams
+                            ).to_yaml()
+                        },
+                        overwrite=True,
+                    )
+
+            if from_start:  # New experiment or overwriting
+                if (
+                    config.get("overwrite", False) is True
+                    or np.all(list(ow.values()))
+                ) and filename.exists():
+                    os.remove(filename)
+
+                meta.run()
+                meta.add_fields(  # Add non-logfile metadata
+                    {
+                        "omero_id": config["general"]["id"],
+                        "image_id": image_id,
+                        "parameters": PipelineParameters.from_dict(
+                            pparams
+                        ).to_yaml(),
+                    }
+                )
+
+            tps = min(general_config["tps"], image.data.shape[0])
+
+            return (
+                filename,
+                meta,
+                config,
+                process_from,
+                tps,
+                steps,
+                earlystop,
+                session,
+                trackers_state,
+            )
+
 
 def groupby_traps(traps, labels, edgemasks, ntraps):
     # Group data by traps to pass onto extractor without re-reading hdf5
@@ -685,3 +708,8 @@ def groupby_traps(traps, labels, edgemasks, ntraps):
     masks = {i: mask_d.get(i, []) for i in range(ntraps)}
 
     return labels, masks
+
+
+def _close_session(session):
+    if session:
+        session.close()
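Reviewer note: the substance of this patch is splitting the monolithic `create_pipeline` into a `_setup_pipeline` step that returns all loop state as one tuple, a slimmer run loop that consumes it, and a module-level `_close_session` helper used in the `finally` clause. Below is a minimal, self-contained sketch of that control flow; `MiniPipeline`, `Step`, and `_setup` are illustrative stand-ins, not aliby's API.

```python
# Minimal sketch of the setup/run split introduced by this patch.
# All names here are hypothetical stand-ins for aliby's Pipeline and its
# Tiler/BABY/Extractor steps; only the shape of the control flow is mirrored.
from time import perf_counter


class Step:
    """Stand-in for one iterative step (tiler, baby, extraction)."""

    def __init__(self, name):
        self.name = name

    def run_tp(self, tp):
        # A real step would segment/track/extract at time point `tp`.
        return f"{self.name}@{tp}"


class MiniPipeline:
    iterative_steps = ["tiler", "baby", "extraction"]

    def _setup(self, tps):
        # Mirrors _setup_pipeline: build the steps and their per-step
        # starting points, then hand everything back as a single tuple.
        steps = {k: Step(k) for k in self.iterative_steps}
        process_from = {k: 0 for k in self.iterative_steps}
        session = None  # a real pipeline may open a server session here
        return steps, process_from, tps, session

    def run(self, tps=3):
        session = None
        try:
            steps, process_from, tps, session = self._setup(tps)
            # Mirrors the main loop: resume each step from its own offset.
            for i in range(min(process_from.values()), tps):
                for name in self.iterative_steps:
                    if i >= process_from[name]:
                        t = perf_counter()
                        result = steps[name].run_tp(i)
                        print(f"{result} ({perf_counter() - t:.2g}s)")
            return 1
        finally:
            _close_session(session)  # same shape as the new helper above


def _close_session(session):
    if session:
        session.close()


if __name__ == "__main__":
    MiniPipeline().run()
```

Returning the state as one wide tuple keeps the `try/except/finally` in `create_pipeline` short, but the nine-element interface is order-sensitive; a small dataclass or NamedTuple would make the same contract self-documenting.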