diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py index 1769dc506d70259407384b4dea7e179400920966..b0cc310b7b29d6b380b013b9eedb5bc555571fb6 100644 --- a/src/aliby/pipeline.py +++ b/src/aliby/pipeline.py @@ -2,7 +2,6 @@ import logging import os import re -import traceback import typing as t from pathlib import Path from pprint import pprint @@ -329,117 +328,99 @@ class Pipeline(ProcessABC): index: t.Optional[int] = None, ): """Run a pipeline for one position.""" - session = None # required for multiprocessing name, image_id = name_image_id config = self.parameters.to_dict() config["tiler"]["position_name"] = name.split(".")[0] earlystop = config["general"].get("earlystop", None) - try: - out_file = self.generate_h5file(image_id) - # instantiate writers - tiler_writer = TilerWriter(out_file) - baby_writer = LinearBabyWriter(out_file) - babystate_writer = StateWriter(out_file) - # start pipeline - frac_clogged_traps = 0.0 - with dispatch_image(image_id)( - image_id, **self.server_info - ) as image: - # initialise tiler; load local meta data from image - tiler = Tiler.from_image( - image, - TilerParameters.from_dict(config["tiler"]), - ) - # initialise Baby - session = initialise_tf(2) - babyrunner = BabyRunner.from_tiler( - BabyParameters.from_dict(config["baby"]), - tiler=tiler, - ) - # initialise extraction - extraction = Extractor.from_tiler( - ExtractorParameters.from_dict(config["extraction"]), - store=out_file, - tiler=tiler, - ) - # initialise progress bar - tps = min(config["general"]["tps"], image.data.shape[0]) - progress_bar = tqdm(range(tps), desc=image.name) - # run through time points - for i in progress_bar: - if ( - frac_clogged_traps < earlystop["thresh_pos_clogged"] - or i < earlystop["min_tp"] - ): - # run tiler - result = tiler.run_tp(i) - tiler_writer.write( - data=result, - overwrite=[], - tp=i, - meta={"last_processed:": i}, - ) - if i == 0: - logging.getLogger("aliby").info( - f"Found {tiler.no_tiles} traps in {image.name}" - ) - # run Baby - try: - result = babyrunner.run_tp(i) - except baby.errors.Clogging: - logging.getLogger("aliby").warn( - "WARNING:Clogging threshold exceeded in BABY." - ) - baby_writer.write( - data=result, - tp=i, - overwrite="mother_assign", - meta={"last_processed": i}, - ) - babystate_writer.write( - data=babyrunner.crawler.tracker_states, - overwrite=babystate_writer.datatypes.keys(), - tp=i, - ) - # run extraction - result = extraction.run_tp( - i, cell_labels=None, masks=None - ) - # check and report clogging - frac_clogged_traps = check_earlystop( - out_file, - earlystop, - tiler.tile_size, + out_file = self.generate_h5file(image_id) + # instantiate writers + tiler_writer = TilerWriter(out_file) + baby_writer = LinearBabyWriter(out_file) + babystate_writer = StateWriter(out_file) + # start pipeline + frac_clogged_traps = 0.0 + with dispatch_image(image_id)(image_id, **self.server_info) as image: + # initialise tiler; load local meta data from image + tiler = Tiler.from_image( + image, + TilerParameters.from_dict(config["tiler"]), + ) + # initialise Baby + initialise_tf(2) + babyrunner = BabyRunner.from_tiler( + BabyParameters.from_dict(config["baby"]), + tiler=tiler, + ) + # initialise extraction + extraction = Extractor.from_tiler( + ExtractorParameters.from_dict(config["extraction"]), + store=out_file, + tiler=tiler, + ) + # initialise progress bar + tps = min(config["general"]["tps"], image.data.shape[0]) + progress_bar = tqdm(range(tps), desc=image.name) + # run through time points + for i in progress_bar: + if ( + frac_clogged_traps < earlystop["thresh_pos_clogged"] + or i < earlystop["min_tp"] + ): + # run tiler + result = tiler.run_tp(i) + tiler_writer.write( + data=result, + overwrite=[], + tp=i, + meta={"last_processed:": i}, + ) + if i == 0: + logging.getLogger("aliby").info( + f"Found {tiler.no_tiles} traps in {image.name}" ) - if frac_clogged_traps > 0.3: - self._log( - f"{name}:Clogged_traps:{frac_clogged_traps}" - ) - frac = np.round(frac_clogged_traps * 100) - progress_bar.set_postfix_str(f"{frac} Clogged") - else: - # stop if too many traps are clogged - self._log( - f"{name}: Stopped early at time {i} with {frac_clogged_traps} clogged traps." + # run Baby + try: + result = babyrunner.run_tp(i) + except baby.errors.Clogging: + logging.getLogger("aliby").warn( + "WARNING:Clogging threshold exceeded in BABY." ) - break - # run post-processing - PostProcessor( - out_file, - PostProcessorParameters.from_dict( - config["postprocessing"] - ), - ).run() - self._log("Analysis finished successfully.", "info") - return 1 - except Exception as e: - logging.exception(f"{name}: Exception caught.", exc_info=True) - # print the type, value, and stack trace of the exception - traceback.print_exc() - raise e - finally: - # for multiprocessing - close_session(session) + baby_writer.write( + data=result, + tp=i, + overwrite="mother_assign", + meta={"last_processed": i}, + ) + babystate_writer.write( + data=babyrunner.crawler.tracker_states, + overwrite=babystate_writer.datatypes.keys(), + tp=i, + ) + # run extraction + result = extraction.run_tp(i, cell_labels=None, masks=None) + # check and report clogging + frac_clogged_traps = check_earlystop( + out_file, + earlystop, + tiler.tile_size, + ) + if frac_clogged_traps > 0.3: + self._log(f"{name}:Clogged_traps:{frac_clogged_traps}") + frac = np.round(frac_clogged_traps * 100) + progress_bar.set_postfix_str(f"{frac} Clogged") + else: + # stop if too many traps are clogged + self._log( + f"{name}: Stopped early at time {i} with {frac_clogged_traps} clogged traps." + ) + break + # run post-processing + PostProcessor( + out_file, + PostProcessorParameters.from_dict(config["postprocessing"]), + ).run() + self._log("Analysis finished successfully.", "info") + return 1 def check_earlystop(out_file: str, es_parameters: dict, tile_size: int): @@ -485,19 +466,8 @@ def check_earlystop(out_file: str, es_parameters: dict, tile_size: int): return (traps_above_nthresh & traps_above_athresh).mean() -def close_session(session): - """Close session for multiprocessing.""" - if session: - session.close() - - def initialise_tf(version): """Initialise tensorflow.""" - if version == 1: - core_config = tf.ConfigProto() - core_config.gpu_options.allow_growth = True - session = tf.Session(config=core_config) - return session if version == 2: gpus = tf.config.experimental.list_physical_devices("GPU") if gpus: @@ -507,4 +477,3 @@ def initialise_tf(version): print( len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs" ) - return None