From 9550a53ef152a6c0214b58fe7cf172982d520d05 Mon Sep 17 00:00:00 2001
From: pswain <peter.swain@ed.ac.uk>
Date: Thu, 15 Feb 2024 18:01:14 +0000
Subject: [PATCH] fix(GPU issue): bug fixed in calling babywriter

babywriter's overwrite variable must be a list.
---
 src/aliby/new_pipeline.py | 475 ----------------------------------
 src/aliby/pipeline.py     | 520 ++++++++++++--------------------------
 2 files changed, 158 insertions(+), 837 deletions(-)
 delete mode 100644 src/aliby/new_pipeline.py

diff --git a/src/aliby/new_pipeline.py b/src/aliby/new_pipeline.py
deleted file mode 100644
index f58c75b..0000000
--- a/src/aliby/new_pipeline.py
+++ /dev/null
@@ -1,475 +0,0 @@
-"""Set up and run pipelines: tiling, segmentation, extraction, and then post-processing."""
-import logging
-import os
-import re
-import typing as t
-from pathlib import Path
-from pprint import pprint
-
-import baby
-import baby.errors
-import numpy as np
-import tensorflow as tf
-from pathos.multiprocessing import Pool
-from tqdm import tqdm
-
-try:
-    if baby.__version__ == "v0.30.1":
-        from aliby.baby_sitter import BabyParameters, BabyRunner
-except AttributeError:
-    from aliby.baby_client import BabyParameters, BabyRunner
-
-import aliby.global_parameters as global_parameters
-from agora.abc import ParametersABC, ProcessABC
-from agora.io.metadata import MetaData
-from agora.io.signal import Signal
-from agora.io.writer import LinearBabyWriter, StateWriter, TilerWriter
-from aliby.io.dataset import dispatch_dataset
-from aliby.io.image import dispatch_image
-from aliby.tile.tiler import Tiler, TilerParameters
-from extraction.core.extractor import (
-    Extractor,
-    ExtractorParameters,
-    extraction_params_from_meta,
-)
-from postprocessor.core.postprocessing import (
-    PostProcessor,
-    PostProcessorParameters,
-)
-
-# stop warnings from TensorFlow
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
-logging.getLogger("tensorflow").setLevel(logging.ERROR)
-
-
-class PipelineParameters(ParametersABC):
-    """Define parameters for the steps of the pipeline."""
-
-    def __init__(
-        self,
-        general,
-        tiler,
-        baby,
-        extraction,
-        postprocessing,
-    ):
-        """Initialise parameter sets using passed dictionaries."""
-        self.general = general
-        self.tiler = tiler
-        self.baby = baby
-        self.extraction = extraction
-        self.postprocessing = postprocessing
-
-    @classmethod
-    def default(
-        cls,
-        general={},
-        tiler={},
-        baby={},
-        extraction={},
-        postprocessing={},
-    ):
-        """
-        Initialise parameters for steps of the pipeline.
-
-        Some parameters are extracted from the log files.
-
-        Parameters
-        ---------
-        general: dict
-            Parameters to set up the pipeline.
-        tiler: dict
-            Parameters for tiler.
-        baby: dict (optional)
-            Parameters for Baby.
-        extraction: dict (optional)
-            Parameters for extraction.
-        postprocessing: dict (optional)
-            Parameters for post-processing.
-        """
-        if (
-            isinstance(general["expt_id"], Path)
-            and general["expt_id"].exists()
-        ):
-            expt_id = str(general["expt_id"])
-        else:
-            expt_id = general["expt_id"]
-        directory = Path(general["directory"])
-        # get metadata from log files either locally or via OMERO
-        with dispatch_dataset(
-            expt_id,
-            **{k: general.get(k) for k in ("host", "username", "password")},
-        ) as conn:
-            directory = directory / conn.unique_name
-            if not directory.exists():
-                directory.mkdir(parents=True)
-            # download logs for metadata
-            conn.cache_logs(directory)
-        try:
-            meta_d = MetaData(directory, None).load_logs()
-        except Exception as e:
-            logging.getLogger("aliby").warn(
-                f"WARNING:Metadata: error when loading: {e}"
-            )
-            minimal_default_meta = {
-                "channels": ["Brightfield"],
-                "ntps": [2000],
-            }
-            # set minimal metadata
-            meta_d = minimal_default_meta
-        # define default values for general parameters
-        tps = meta_d.get("ntps", 2000)
-        defaults = {
-            "general": dict(
-                id=expt_id,
-                distributed=False,
-                tps=tps,
-                directory=str(directory.parent),
-                filter="",
-                earlystop=global_parameters.earlystop,
-                logfile_level="INFO",
-                use_explog=True,
-            )
-        }
-        # update default values for general using inputs
-        for k, v in general.items():
-            if k not in defaults["general"]:
-                defaults["general"][k] = v
-            elif isinstance(v, dict):
-                for k2, v2 in v.items():
-                    defaults["general"][k][k2] = v2
-            else:
-                defaults["general"][k] = v
-        # default Tiler parameters
-        defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
-        # generate a backup channel for when logfile meta is available
-        # but not image metadata.
-        backup_ref_channel = None
-        if "channels" in meta_d and isinstance(
-            defaults["tiler"]["ref_channel"], str
-        ):
-            backup_ref_channel = meta_d["channels"].index(
-                defaults["tiler"]["ref_channel"]
-            )
-        defaults["tiler"]["backup_ref_channel"] = backup_ref_channel
-        # default parameters
-        defaults["baby"] = BabyParameters.default(**baby).to_dict()
-        defaults["extraction"] = extraction_params_from_meta(meta_d)
-        defaults["postprocessing"] = PostProcessorParameters.default(
-            **postprocessing
-        ).to_dict()
-        return cls(**{k: v for k, v in defaults.items()})
-
-
-class Pipeline(ProcessABC):
-    """
-    Initialise and run tiling, segmentation, extraction, and post-processing.
-
-    To customise parameters for any step use the PipelineParameters class.
-    """
-
-    def __init__(self, parameters: PipelineParameters, store=None):
-        """Initialise using Pipeline parameters."""
-        super().__init__(parameters)
-        if store is not None:
-            store = Path(store)
-        # h5 file
-        self.store = store
-
-    @staticmethod
-    def setLogger(
-        folder, file_level: str = "INFO", stream_level: str = "WARNING"
-    ):
-        """Initialise and format logger."""
-        logger = logging.getLogger("aliby")
-        logger.setLevel(getattr(logging, file_level))
-        formatter = logging.Formatter(
-            "%(asctime)s - %(levelname)s:%(message)s",
-            datefmt="%Y-%m-%dT%H:%M:%S%z",
-        )
-        # for streams - stdout, files, etc.
-        ch = logging.StreamHandler()
-        ch.setLevel(getattr(logging, stream_level))
-        ch.setFormatter(formatter)
-        logger.addHandler(ch)
-        # create file handler that logs even debug messages
-        fh = logging.FileHandler(Path(folder) / "aliby.log", "w+")
-        fh.setLevel(getattr(logging, file_level))
-        fh.setFormatter(formatter)
-        logger.addHandler(fh)
-
-    def setup(self):
-        """Get meta data and identify each position."""
-        config = self.parameters.to_dict()
-        # print configuration
-        print("\nalibylite\n")
-        try:
-            logging.getLogger("aliby").info(f"Using Baby {baby.__version__}.")
-        except AttributeError:
-            logging.getLogger("aliby").info("Using original Baby.")
-        for step in config:
-            print("\n---\n" + step + "\n---")
-            pprint(config[step])
-        print()
-        # extract from configuration
-        expt_id = config["general"]["id"]
-        root_dir = Path(config["general"]["directory"])
-        self.server_info = {
-            k: config["general"].get(k)
-            for k in ("host", "username", "password")
-        }
-        dispatcher = dispatch_dataset(expt_id, **self.server_info)
-        logging.getLogger("aliby").info(
-            f"Fetching data using {dispatcher.__class__.__name__}."
-        )
-        # get log files, either locally or via OMERO
-        with dispatcher as conn:
-            position_ids = conn.get_images()
-            directory = self.store or root_dir / conn.unique_name
-            if not directory.exists():
-                directory.mkdir(parents=True)
-            # get logs to use for metadata
-            conn.cache_logs(directory)
-        print("Positions available:")
-        for i, pos in enumerate(position_ids.keys()):
-            print("\t" + f"{i}: " + pos.split(".")[0])
-        # add directory to configuration
-        self.parameters.general["directory"] = str(directory)
-        self.setLogger(directory)
-        return position_ids
-
-    def filter_positions(self, position_filter, position_ids):
-        """Select particular positions."""
-        if isinstance(position_filter, list):
-            selected_ids = {
-                k: v
-                for filt in position_filter
-                for k, v in self.apply_filter(position_ids, filt).items()
-            }
-        else:
-            selected_ids = self.apply_filter(position_ids, position_filter)
-        return selected_ids
-
-    def apply_filter(self, position_ids: dict, position_filter: int or str):
-        """
-        Select positions.
-
-        Either pick a particular position or use a regular expression
-        to parse their file names.
-        """
-        if isinstance(position_filter, str):
-            # pick positions using a regular expression
-            position_ids = {
-                k: v
-                for k, v in position_ids.items()
-                if re.search(position_filter, k)
-            }
-        elif isinstance(position_filter, int):
-            # pick a particular position
-            position_ids = {
-                k: v
-                for i, (k, v) in enumerate(position_ids.items())
-                if i == position_filter
-            }
-        return position_ids
-
-    def run(self):
-        """Run separate pipelines for all positions in an experiment."""
-        initialise_tf(2)
-        config = self.parameters.to_dict()
-        position_ids = self.setup()
-        # pick particular positions if desired
-        position_filter = config["general"]["filter"]
-        if position_filter is not None:
-            position_ids = self.filter_positions(position_filter, position_ids)
-        if not len(position_ids):
-            raise Exception("No images to segment.")
-        else:
-            print("\nPositions selected:")
-            for pos in position_ids:
-                print("\t" + pos.split(".")[0])
-        # create and run pipelines
-        distributed = config["general"]["distributed"]
-        if distributed:
-            # multiple cores
-            with Pool(distributed) as p:
-                results = p.map(
-                    self.run_one_position,
-                    [position_id for position_id in position_ids.items()],
-                )
-        else:
-            # single core
-            results = [
-                self.run_one_position(position_id)
-                for position_id in position_ids.items()
-            ]
-        # results is binary giving the success for each position
-        return results
-
-    def generate_h5file(self, image_id):
-        """Delete any existing and then create h5file for one position."""
-        config = self.parameters.to_dict()
-        out_dir = config["general"]["directory"]
-        with dispatch_image(image_id)(image_id, **self.server_info) as image:
-            out_file = Path(f"{out_dir}/{image.name}.h5")
-        # remove existing h5 file
-        if out_file.exists():
-            os.remove(out_file)
-        meta = MetaData(out_dir, out_file)
-        # generate h5 file using meta data from logs
-        if config["general"]["use_explog"]:
-            meta.run()
-        return out_file
-
-    def run_one_position(
-        self, name_image_id: t.Tuple[str, str or Path or int]
-    ):
-        """Run a pipeline for one position."""
-        name, image_id = name_image_id
-        config = self.parameters.to_dict()
-        config["tiler"]["position_name"] = name.split(".")[0]
-        earlystop = config["general"].get("earlystop", None)
-        out_file = self.generate_h5file(image_id)
-        # instantiate writers
-        tiler_writer = TilerWriter(out_file)
-        baby_writer = LinearBabyWriter(out_file)
-        babystate_writer = StateWriter(out_file)
-        # start pipeline
-        frac_clogged_traps = 0.0
-        with dispatch_image(image_id)(image_id, **self.server_info) as image:
-            # initialise tiler; load local meta data from image
-            tiler = Tiler.from_image(
-                image,
-                TilerParameters.from_dict(config["tiler"]),
-            )
-            # initialise Baby
-            babyrunner = BabyRunner.from_tiler(
-                BabyParameters.from_dict(config["baby"]),
-                tiler=tiler,
-            )
-            # initialise extraction
-            extraction = Extractor.from_tiler(
-                ExtractorParameters.from_dict(config["extraction"]),
-                store=out_file,
-                tiler=tiler,
-            )
-            # initialise progress bar
-            tps = min(config["general"]["tps"], image.data.shape[0])
-            progress_bar = tqdm(range(tps), desc=image.name)
-            # run through time points
-            for i in progress_bar:
-                if (
-                    frac_clogged_traps < earlystop["thresh_pos_clogged"]
-                    or i < earlystop["min_tp"]
-                ):
-                    # run tiler
-                    result = tiler.run_tp(i)
-                    tiler_writer.write(
-                        data=result,
-                        overwrite=[],
-                        tp=i,
-                        meta={"last_processed:": i},
-                    )
-                    if i == 0:
-                        logging.getLogger("aliby").info(
-                            f"Found {tiler.no_tiles} traps in {image.name}"
-                        )
-                    # run Baby
-                    try:
-                        result = babyrunner.run_tp(i)
-                    except baby.errors.Clogging:
-                        logging.getLogger("aliby").warn(
-                            "WARNING:Clogging threshold exceeded in BABY."
-                        )
-                    baby_writer.write(
-                        data=result,
-                        tp=i,
-                        overwrite="mother_assign",
-                        meta={"last_processed": i},
-                    )
-                    babystate_writer.write(
-                        data=babyrunner.crawler.tracker_states,
-                        overwrite=babystate_writer.datatypes.keys(),
-                        tp=i,
-                    )
-                    # run extraction
-                    result = extraction.run_tp(i, cell_labels=None, masks=None)
-                    # check and report clogging
-                    frac_clogged_traps = check_earlystop(
-                        out_file,
-                        earlystop,
-                        tiler.tile_size,
-                    )
-                    if frac_clogged_traps > 0.3:
-                        self._log(f"{name}:Clogged_traps:{frac_clogged_traps}")
-                        frac = np.round(frac_clogged_traps * 100)
-                        progress_bar.set_postfix_str(f"{frac} Clogged")
-                else:
-                    # stop if too many traps are clogged
-                    self._log(
-                        f"{name}: Stopped early at time {i} with {frac_clogged_traps} clogged traps."
-                    )
-                    break
-            # run post-processing
-            PostProcessor(
-                out_file,
-                PostProcessorParameters.from_dict(config["postprocessing"]),
-            ).run()
-            self._log("Analysis finished successfully.", "info")
-            return 1
-
-
-def check_earlystop(out_file: str, es_parameters: dict, tile_size: int):
-    """
-    Check recent time points for tiles with too many cells.
-
-    Returns the fraction of clogged tiles, where clogged tiles have
-    too many cells or too much of their area covered by cells.
-
-    Parameters
-    ----------
-    filename: str
-        Name of h5 file.
-    es_parameters: dict
-        Parameters defining when early stopping should happen.
-        For example:
-                {'min_tp': 100,
-                'thresh_pos_clogged': 0.4,
-                'thresh_trap_ncells': 8,
-                'thresh_trap_area': 0.9,
-                'ntps_to_eval': 5}
-    tile_size: int
-        Size of tile.
-    """
-    # get the area of the cells organised by trap and cell number
-    s = Signal(out_file)
-    df = s.get_raw("/extraction/general/None/area")
-    # check the latest time points only
-    cells_used = df[
-        df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
-    ].dropna(how="all")
-    # find tiles with too many cells
-    traps_above_nthresh = (
-        cells_used.groupby("trap").count().apply(np.mean, axis=1)
-        > es_parameters["thresh_trap_ncells"]
-    )
-    # find tiles with cells covering too great a fraction of the tiles' area
-    traps_above_athresh = (
-        cells_used.groupby("trap").sum().apply(np.mean, axis=1)
-        / tile_size**2
-        > es_parameters["thresh_trap_area"]
-    )
-    return (traps_above_nthresh & traps_above_athresh).mean()
-
-
-def initialise_tf(version):
-    """Initialise tensorflow."""
-    if version == 2:
-        gpus = tf.config.experimental.list_physical_devices("GPU")
-        if gpus:
-            for gpu in gpus:
-                tf.config.experimental.set_memory_growth(gpu, True)
-            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
-            print(
-                len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
-            )
diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py
index 2ee0888..1086318 100644
--- a/src/aliby/pipeline.py
+++ b/src/aliby/pipeline.py
@@ -2,17 +2,14 @@
 import logging
 import os
 import re
-import traceback
 import typing as t
-from copy import copy
-from importlib.metadata import version
 from pathlib import Path
 from pprint import pprint
 
 import baby
 import baby.errors
-import h5py
 import numpy as np
+import multiprocessing
 import tensorflow as tf
 from pathos.multiprocessing import Pool
 from tqdm import tqdm
@@ -26,7 +23,6 @@ except AttributeError:
 import aliby.global_parameters as global_parameters
 from agora.abc import ParametersABC, ProcessABC
 from agora.io.metadata import MetaData
-from agora.io.reader import StateReader
 from agora.io.signal import Signal
 from agora.io.writer import LinearBabyWriter, StateWriter, TilerWriter
 from aliby.io.dataset import dispatch_dataset
@@ -50,8 +46,6 @@ logging.getLogger("tensorflow").setLevel(logging.ERROR)
 class PipelineParameters(ParametersABC):
     """Define parameters for the steps of the pipeline."""
 
-    _pool_index = None
-
     def __init__(
         self,
         general,
@@ -60,7 +54,7 @@ class PipelineParameters(ParametersABC):
         extraction,
         postprocessing,
     ):
-        """Initialise, but called by a class method - not directly."""
+        """Initialise parameter sets using passed dictionaries."""
         self.general = general
         self.tiler = tiler
         self.baby = baby
@@ -159,11 +153,9 @@ class PipelineParameters(ParametersABC):
                 defaults["tiler"]["ref_channel"]
             )
         defaults["tiler"]["backup_ref_channel"] = backup_ref_channel
-        # default BABY parameters
+        # default parameters
         defaults["baby"] = BabyParameters.default(**baby).to_dict()
-        # default Extraction parmeters
         defaults["extraction"] = extraction_params_from_meta(meta_d)
-        # default PostProcessing parameters
         defaults["postprocessing"] = PostProcessorParameters.default(
             **postprocessing
         ).to_dict()
@@ -171,39 +163,14 @@ class PipelineParameters(ParametersABC):
 
 
 class Pipeline(ProcessABC):
-    """
-    Initialise and run tiling, segmentation, extraction and post-processing.
-
-    Each step feeds the next one.
-
-    To customise parameters for any step use the PipelineParameters class.stem
-    """
-
-    pipeline_steps = ["tiler", "baby", "extraction"]
-    step_sequence = [
-        "tiler",
-        "baby",
-        "extraction",
-        "postprocessing",
-    ]
-
-    # specify the group in the h5 files written by each step
-    writer_groups = {
-        "tiler": ["trap_info"],
-        "baby": ["cell_info"],
-        "extraction": ["extraction"],
-        "postprocessing": ["postprocessing", "modifiers"],
-    }
-    writers = {
-        "tiler": [("tiler", TilerWriter)],
-        "baby": [("baby", LinearBabyWriter), ("state", StateWriter)],
-    }
+    """Initialise and run tiling, segmentation, extraction and post-processing."""
 
     def __init__(self, parameters: PipelineParameters, store=None):
-        """Initialise - not usually called directly."""
+        """Initialise using Pipeline parameters."""
         super().__init__(parameters)
         if store is not None:
             store = Path(store)
+        # h5 file
         self.store = store
 
     @staticmethod
@@ -228,11 +195,10 @@ class Pipeline(ProcessABC):
         fh.setFormatter(formatter)
         logger.addHandler(fh)
 
-
-    def run(self):
-        """Run separate pipelines for all positions in an experiment."""
-        # display configuration
+    def setup(self):
+        """Get meta data and identify each position."""
         config = self.parameters.to_dict()
+        # print configuration
         print("\nalibylite\n")
         try:
             logging.getLogger("aliby").info(f"Using Baby {baby.__version__}.")
@@ -244,8 +210,6 @@ class Pipeline(ProcessABC):
         print()
         # extract from configuration
         expt_id = config["general"]["id"]
-        distributed = config["general"]["distributed"]
-        position_filter = config["general"]["filter"]
         root_dir = Path(config["general"]["directory"])
         self.server_info = {
             k: config["general"].get(k)
@@ -266,45 +230,22 @@ class Pipeline(ProcessABC):
         print("Positions available:")
         for i, pos in enumerate(position_ids.keys()):
             print("\t" + f"{i}: " + pos.split(".")[0])
-        # update configuration
+        # add directory to configuration
         self.parameters.general["directory"] = str(directory)
-        config["general"]["directory"] = directory
         self.setLogger(directory)
-        # pick particular positions if desired
-        if position_filter is not None:
-            if isinstance(position_filter, list):
-                position_ids = {
-                    k: v
-                    for filt in position_filter
-                    for k, v in self.apply_filter(position_ids, filt).items()
-                }
-            else:
-                position_ids = self.apply_filter(position_ids, position_filter)
-        if not len(position_ids):
-            raise Exception("No images to segment.")
-        else:
-            print("\nPositions selected:")
-            for pos in position_ids:
-                print("\t" + pos.split(".")[0])
-        # create and run pipelines
-        if distributed != 0:
-            # multiple cores
-            with Pool(distributed) as p:
-                results = p.map(self.run_one_position, [position_id for position_id in position_ids.items()])
-                # results = p.map(
-                #     lambda x: self.run_one_position(*x),
-                #     [
-                #         (position_id, i)
-                #         for i, position_id in enumerate(position_ids.items())
-                #     ],
-                # )
+        return position_ids
+
+    def filter_positions(self, position_filter, position_ids):
+        """Select particular positions."""
+        if isinstance(position_filter, list):
+            selected_ids = {
+                k: v
+                for filt in position_filter
+                for k, v in self.apply_filter(position_ids, filt).items()
+            }
         else:
-            # single core
-            results = [
-                self.run_one_position((position_id, position_id_path), 1)
-                for position_id, position_id_path in tqdm(position_ids.items())
-            ]
-        return results
+            selected_ids = self.apply_filter(position_ids, position_filter)
+        return selected_ids
 
     def apply_filter(self, position_ids: dict, position_filter: int or str):
         """
@@ -329,292 +270,148 @@ class Pipeline(ProcessABC):
             }
         return position_ids
 
+    def run(self):
+        """Run separate pipelines for all positions in an experiment."""
+        config = self.parameters.to_dict()
+        position_ids = self.setup()
+        # pick particular positions if desired
+        position_filter = config["general"]["filter"]
+        if position_filter is not None:
+            position_ids = self.filter_positions(position_filter, position_ids)
+        if not len(position_ids):
+            raise Exception("No images to segment.")
+        else:
+            print("\nPositions selected:")
+            for pos in position_ids:
+                print("\t" + pos.split(".")[0])
+        print(f"Number of CPU cores available: {multiprocessing.cpu_count()}")
+        # create and run pipelines
+        distributed = config["general"]["distributed"]
+        if distributed != 0:
+            # multiple cores
+            with Pool(distributed) as p:
+                results = p.map(
+                    self.run_one_position,
+                    [position_id for position_id in position_ids.items()],
+                )
+        else:
+            # single core
+            results = [
+                self.run_one_position(position_id)
+                for position_id in position_ids.items()
+            ]
+        return results
+
+    def generate_h5file(self, image_id):
+        """Delete any existing and then create h5file for one position."""
+        config = self.parameters.to_dict()
+        out_dir = config["general"]["directory"]
+        with dispatch_image(image_id)(image_id, **self.server_info) as image:
+            out_file = Path(f"{out_dir}/{image.name}.h5")
+        # remove existing h5 file
+        if out_file.exists():
+            os.remove(out_file)
+        meta = MetaData(out_dir, out_file)
+        # generate h5 file using meta data from logs
+        if config["general"]["use_explog"]:
+            meta.run()
+        return out_file
+
     def run_one_position(
-        self,
-        name_image_id: t.Tuple[str, str or Path or int],
+        self, name_image_id: t.Tuple[str, str or Path or int]
     ):
-        """Set up and run a pipeline for one position."""
+        """Run a pipeline for one position."""
         name, image_id = name_image_id
-        run_kwargs = {"extraction": {"cell_labels": None, "masks": None}}
-        try:
-            pipe= self.setup_pipeline(image_id, name)
-            loaded_writers = {
-                name: writer(pipe["filename"])
-                for k in self.step_sequence
-                if k in self.writers
-                for name, writer in self.writers[k]
-            }
-            writer_overwrite_kwargs = {
-                "state": loaded_writers["state"].datatypes.keys(),
-                "baby": ["mother_assign"],
-            }
-
-            # START PIPELINE
-            frac_clogged_traps = 0.0
-            min_process_from = min(pipe["process_from"].values())
-            with dispatch_image(image_id)(
-                image_id, **self.server_info
-            ) as image:
-                # initialise steps
-                if "tiler" not in pipe["steps"]:
-                    pipe["config"]["tiler"]["position_name"] = name.split(".")[
-                        0
-                    ]
-                    # loads local meta data from image
-                    pipe["steps"]["tiler"] = Tiler.from_image(
-                        image,
-                        TilerParameters.from_dict(pipe["config"]["tiler"]),
-                    )
-                if pipe["process_from"]["baby"] < pipe["tps"]:
-                    initialise_tf(2)
-                    pipe["steps"]["baby"] = BabyRunner.from_tiler(
-                        BabyParameters.from_dict(pipe["config"]["baby"]),
-                        pipe["steps"]["tiler"],
-                    )
-                    if pipe["trackers_state"]:
-                        pipe["steps"]["baby"].crawler.tracker_states = pipe[
-                            "trackers_state"
-                        ]
-                if pipe["process_from"]["extraction"] < pipe["tps"]:
-                    exparams = ExtractorParameters.from_dict(
-                        pipe["config"]["extraction"]
-                    )
-                    pipe["steps"]["extraction"] = Extractor.from_tiler(
-                        exparams,
-                        store=pipe["filename"],
-                        tiler=pipe["steps"]["tiler"],
-                    )
-                    # initiate progress bar
-                    progress_bar = tqdm(
-                        range(min_process_from, pipe["tps"]),
-                        desc=image.name,
-                        initial=min_process_from,
-                        total=pipe["tps"],
-                    )
-                    # run through time points
-                    for i in progress_bar:
-                        if (
-                            frac_clogged_traps
-                            < pipe["earlystop"]["thresh_pos_clogged"]
-                            or i < pipe["earlystop"]["min_tp"]
-                        ):
-                            # run through steps
-                            for step in self.pipeline_steps:
-                                if i >= pipe["process_from"][step]:
-                                    # perform step
-                                    try:
-                                        result = pipe["steps"][step].run_tp(
-                                            i, **run_kwargs.get(step, {})
-                                        )
-                                    except baby.errors.Clogging:
-                                        logging.getLogger("aliby").warn(
-                                            "WARNING:Clogging threshold exceeded in BABY."
-                                        )
-                                    # write result to h5 file using writers
-                                    # extractor writes to h5 itself
-                                    if step in loaded_writers:
-                                        loaded_writers[step].write(
-                                            data=result,
-                                            overwrite=writer_overwrite_kwargs.get(
-                                                step, []
-                                            ),
-                                            tp=i,
-                                            meta={"last_processed": i},
-                                        )
-                                    # clean up
-                                    if (
-                                        step == "tiler"
-                                        and i == min_process_from
-                                    ):
-                                        logging.getLogger("aliby").info(
-                                            f"Found {pipe['steps']['tiler'].no_tiles} traps in {image.name}"
-                                        )
-                                    elif step == "baby":
-                                        # write state
-                                        loaded_writers["state"].write(
-                                            data=pipe["steps"][
-                                                step
-                                            ].crawler.tracker_states,
-                                            overwrite=loaded_writers[
-                                                "state"
-                                            ].datatypes.keys(),
-                                            tp=i,
-                                        )
-                                    elif step == "extraction":
-                                        # remove masks and labels after extraction
-                                        for k in ["masks", "cell_labels"]:
-                                            run_kwargs[step][k] = None
-                            # check and report clogging
-                            frac_clogged_traps = check_earlystop(
-                                pipe["filename"],
-                                pipe["earlystop"],
-                                pipe["steps"]["tiler"].tile_size,
-                            )
-                            if frac_clogged_traps > 0.3:
-                                self._log(
-                                    f"{name}:Clogged_traps:{frac_clogged_traps}"
-                                )
-                                frac = np.round(frac_clogged_traps * 100)
-                                progress_bar.set_postfix_str(f"{frac} Clogged")
-                        else:
-                            # stop if too many traps are clogged
-                            self._log(
-                                f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
-                            )
-                            pipe["meta"].add_fields({"end_status": "Clogged"})
-                            break
-                        pipe["meta"].add_fields({"last_processed": i})
-                    pipe["meta"].add_fields({"end_status": "Success"})
-                    # run post-processing
-                    post_proc_params = PostProcessorParameters.from_dict(
-                        pipe["config"]["postprocessing"]
-                    )
-                    PostProcessor(pipe["filename"], post_proc_params).run()
-                    self._log("Analysis finished successfully.", "info")
-                    return 1
-        except Exception as e:
-            # catch bugs during setup or run time
-            logging.exception(
-                f"{name}: Exception caught.",
-                exc_info=True,
-            )
-            # print the type, value, and stack trace of the exception
-            traceback.print_exc()
-            raise e
-        finally:
-            pass
-
-    def setup_pipeline(
-        self,
-        image_id: int,
-        name: str,
-    ) -> t.Tuple[
-        Path,
-        MetaData,
-        t.Dict,
-        int,
-        t.Dict,
-        t.Dict,
-        t.Optional[int],
-        t.List[np.ndarray],
-    ]:
-        """
-        Initialise steps in a pipeline.
-
-        If necessary use a file to re-start experiments already partly run.
-
-        Parameters
-        ----------
-        image_id : int or str
-            Identifier of a data set in an OMERO server or a filename.
-
-        Returns
-        -------
-        pipe: dict
-            With keys
-                filename: str
-                    Path to a h5 file to write to.
-                meta: object
-                    agora.io.metadata.MetaData object
-                config: dict
-                    Configuration parameters.
-                process_from: dict
-                    Gives time points from which each step of the
-                    pipeline should start.
-                tps: int
-                    Number of time points.
-                steps: dict
-                earlystop: dict
-                    Parameters to check whether the pipeline should
-                    be stopped.
-                trackers_state: list
-                    States of any trackers from earlier runs.
-        """
-        pipe = {}
         config = self.parameters.to_dict()
-        pipe["earlystop"] = config["general"].get("earlystop", None)
-        pipe["process_from"] = {k: 0 for k in self.pipeline_steps}
-        pipe["steps"] = {}
-        # check overwriting
-        overwrite_id = config["general"].get("overwrite", 0)
-        overwrite = {step: True for step in self.step_sequence}
-        if overwrite_id and overwrite_id is not True:
-            overwrite = {
-                step: self.step_sequence.index(overwrite_id) < i
-                for i, step in enumerate(self.step_sequence, 1)
-            }
-        # set up
-        directory = config["general"]["directory"]
-        pipe["trackers_state"] = []
+        config["tiler"]["position_name"] = name.split(".")[0]
+        earlystop = config["general"].get("earlystop", None)
+        out_file = self.generate_h5file(image_id)
+        # instantiate writers
+        tiler_writer = TilerWriter(out_file)
+        baby_writer = LinearBabyWriter(out_file)
+        babystate_writer = StateWriter(out_file)
+        # start pipeline
+        initialise_tensorflow()
+        frac_clogged_traps = 0.0
         with dispatch_image(image_id)(image_id, **self.server_info) as image:
-            pipe["filename"] = Path(f"{directory}/{image.name}.h5")
-            # load metadata from h5 file
-            pipe["meta"] = MetaData(directory, pipe["filename"])
-            from_start = True if np.any(overwrite.values()) else False
-            # remove existing h5 file if overwriting
-            if (
-                from_start
-                and (
-                    config["general"].get("overwrite", False)
-                    or np.all(list(overwrite.values()))
-                )
-                and pipe["filename"].exists()
-            ):
-                os.remove(pipe["filename"])
-            # if the file exists with no previous segmentation use its tiler
-            if pipe["filename"].exists():
-                self._log("Result file exists.", "info")
-                if not overwrite["tiler"]:
-                    tiler_params_dict = TilerParameters.default().to_dict()
-                    tiler_params_dict["position_name"] = name.split(".")[0]
-                    tiler_params = TilerParameters.from_dict(tiler_params_dict)
-                    pipe["steps"]["tiler"] = Tiler.from_h5(
-                        image, pipe["filename"], tiler_params
+            # initialise tiler; load local meta data from image
+            tiler = Tiler.from_image(
+                image, TilerParameters.from_dict(config["tiler"])
+            )
+            # initialise Baby
+            babyrunner = BabyRunner.from_tiler(
+                BabyParameters.from_dict(config["baby"]), tiler=tiler
+            )
+            # initialise extraction
+            extraction = Extractor.from_tiler(
+                ExtractorParameters.from_dict(config["extraction"]),
+                store=out_file,
+                tiler=tiler,
+            )
+            # initiate progress bar
+            tps = min(config["general"]["tps"], image.data.shape[0])
+            progress_bar = tqdm(range(tps), desc=image.name)
+            # run through time points
+            for i in progress_bar:
+                if (
+                    frac_clogged_traps < earlystop["thresh_pos_clogged"]
+                    or i < earlystop["min_tp"]
+                ):
+                    # run tiler
+                    result = tiler.run_tp(i)
+                    tiler_writer.write(
+                        data=result,
+                        overwrite=[],
+                        tp=i,
+                        meta={"last_processed:": i},
                     )
-                    try:
-                        (
-                            process_from,
-                            trackers_state,
-                            overwrite,
-                        ) = self._load_config_from_file(
-                            pipe["filename"],
-                            pipe["process_from"],
-                            pipe["trackers_state"],
-                            overwrite,
+                    if i == 0:
+                        logging.getLogger("aliby").info(
+                            f"Found {tiler.no_tiles} traps in {image.name}"
                         )
-                        # get state array
-                        pipe["trackers_state"] = (
-                            []
-                            if overwrite["baby"]
-                            else StateReader(
-                                pipe["filename"]
-                            ).get_formatted_states()
+                    # run Baby
+                    try:
+                        result = babyrunner.run_tp(i)
+                    except baby.errors.Clogging:
+                        logging.getLogger("aliby").warn(
+                            "WARNING:Clogging threshold exceeded in BABY."
                         )
-                        config["tiler"] = pipe["steps"][
-                            "tiler"
-                        ].parameters.to_dict()
-                    except Exception:
-                        self._log("Overwriting tiling data")
-
-            if config["general"]["use_explog"]:
-                pipe["meta"].run()
-            pipe["config"] = config
-            # add metadata not in the log file
-            pipe["meta"].add_fields(
-                {
-                    "aliby_version": version("aliby"),
-                    "baby_version": version("aliby-baby"),
-                    "omero_id": config["general"]["id"],
-                    "image_id": image_id
-                    if isinstance(image_id, int)
-                    else str(image_id),
-                    "parameters": PipelineParameters.from_dict(
-                        config
-                    ).to_yaml(),
-                }
-            )
-            pipe["tps"] = min(config["general"]["tps"], image.data.shape[0])
-            return pipe
+                    baby_writer.write(
+                        data=result,
+                        tp=i,
+                        overwrite=["mother_assign"],
+                        meta={"last_processed": i},
+                    )
+                    babystate_writer.write(
+                        data=babyrunner.crawler.tracker_states,
+                        overwrite=babystate_writer.datatypes.keys(),
+                        tp=i,
+                    )
+                    # run extraction
+                    result = extraction.run_tp(i, cell_labels=None, masks=None)
+                    # check and report clogging
+                    frac_clogged_traps = check_earlystop(
+                        out_file,
+                        earlystop,
+                        tiler.tile_size,
+                    )
+                    if frac_clogged_traps > 0.3:
+                        self._log(f"{name}:Clogged_traps:{frac_clogged_traps}")
+                        frac = np.round(frac_clogged_traps * 100)
+                        progress_bar.set_postfix_str(f"{frac} Clogged")
+                else:
+                    # stop if too many clogged traps
+                    self._log(
+                        f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
+                    )
+                    break
+            # run post-processing
+            PostProcessor(
+                out_file,
+                PostProcessorParameters.from_dict(config["postprocessing"]),
+            ).run()
+            self._log("Analysis finished successfully.", "info")
+            return 1
 
 
 def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
@@ -660,8 +457,7 @@ def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
     return (traps_above_nthresh & traps_above_athresh).mean()
 
 
-
-def initialise_tf(version):
+def initialise_tensorflow(version=2):
     """Initialise tensorflow."""
     if version == 2:
         gpus = tf.config.experimental.list_physical_devices("GPU")
@@ -670,5 +466,5 @@ def initialise_tf(version):
                 tf.config.experimental.set_memory_growth(gpu, True)
             logical_gpus = tf.config.experimental.list_logical_devices("GPU")
             print(
-                len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
+                len(gpus), "physical GPUs,", len(logical_gpus), "logical GPUs"
             )
-- 
GitLab