diff --git a/src/aliby/haystack.py b/src/aliby/haystack.py
deleted file mode 100644
index d1368ffd7715a4c6285cf6a6b2e2e79edf5d6319..0000000000000000000000000000000000000000
--- a/src/aliby/haystack.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""
-Neural network initialisation.
-"""
-from pathlib import Path
-from time import perf_counter
-
-import numpy as np
-import tensorflow as tf
-from agora.io.writer import DynamicWriter
-
-
-def initialise_tf(version):
-    # Initialise tensorflow
-    if version == 1:
-        core_config = tf.ConfigProto()
-        core_config.gpu_options.allow_growth = True
-        session = tf.Session(config=core_config)
-        return session
-    # TODO this only works for TF2
-    if version == 2:
-        gpus = tf.config.experimental.list_physical_devices("GPU")
-        if gpus:
-            for gpu in gpus:
-                tf.config.experimental.set_memory_growth(gpu, True)
-            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
-            print(
-                len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
-            )
-        return None
-
-
-def timer(func, *args, **kwargs):
-    start = perf_counter()
-    result = func(*args, **kwargs)
-    print(f"Function {func.__name__}: {perf_counter() - start}s")
-    return result
-
-
-################## CUSTOM OBJECTS ##################################
-class ModelPredictor:
-    """Generic object that takes a NN and returns the prediction.
-
-    Use for predicting fluorescence/other from bright field.
-    This does not do instance segmentations of anything.
-    """
-
-    def __init__(self, tiler, model, name):
-        self.tiler = tiler
-        self.model = model
-        self.name = name
-
-    def get_data(self, tp):
-        # Change axes to X,Y,Z rather than Z,Y,X
-        return (
-            self.tiler.get_tp_data(tp, self.bf_channel)
-            .swapaxes(1, 3)
-            .swapaxes(1, 2)
-        )
-
-    def format_result(self, result, tp):
-        return {self.name: result, "timepoints": [tp] * len(result)}
-
-    def run_tp(self, tp):
-        """Simulating processing time with sleep"""
-        # Access the image
-        segmentation = self.model.predict(self.get_data(tp))
-        return self._format_result(segmentation, tp)
-
-
-class ModelPredictorWriter(DynamicWriter):
-    def __init__(self, file, name, shape, dtype):
-        super.__init__(file)
-        self.datatypes = {
-            name: (shape, dtype),
-            "timepoint": ((None,), np.uint16),
-        }
-        self.group = f"{self.name}_info"
diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py
index a9135b72007e1a9ad07def4f947b51fd0152bce4..512c8a0c22c62d60e5abff687bafa27c4967178e 100644
--- a/src/aliby/pipeline.py
+++ b/src/aliby/pipeline.py
@@ -9,14 +9,14 @@ from importlib.metadata import version
 from pathlib import Path
 from pprint import pprint
 
+import baby
+import baby.errors
 import h5py
 import numpy as np
+import tensorflow as tf
 from pathos.multiprocessing import Pool
 from tqdm import tqdm
 
-import baby
-import baby.errors
-
 try:
     if baby.__version__ == "v0.30.1":
         from aliby.baby_sitter import BabyParameters, BabyRunner
@@ -25,11 +25,10 @@ except AttributeError:
 
 import aliby.global_parameters as global_parameters
 from agora.abc import ParametersABC, ProcessABC
-from agora.io.metadata import MetaData, parse_logfiles
+from agora.io.metadata import MetaData
 from agora.io.reader import StateReader
 from agora.io.signal import Signal
 from agora.io.writer import LinearBabyWriter, StateWriter, TilerWriter
-from aliby.haystack import initialise_tf
 from aliby.io.dataset import dispatch_dataset
 from aliby.io.image import dispatch_image
 from aliby.tile.tiler import Tiler, TilerParameters
@@ -43,7 +42,6 @@ from postprocessor.core.postprocessing import (
     PostProcessorParameters,
 )
 
-
 # stop warnings from TensorFlow
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 logging.getLogger("tensorflow").setLevel(logging.ERROR)
@@ -307,9 +305,9 @@ class Pipeline(ProcessABC):
             pprint(config[step])
         print()
         try:
-            print(f"Using Baby {baby.__version__}")
+            print(f"Using Baby {baby.__version__}.\n")
         except AttributeError:
-            print("Using Baby - no version specified.")
+            print("Using Baby - no version specified.\n")
         # extract from configuration
         expt_id = config["general"]["id"]
         distributed = config["general"]["distributed"]
@@ -321,7 +319,7 @@ class Pipeline(ProcessABC):
         }
         dispatcher = dispatch_dataset(expt_id, **self.server_info)
         logging.getLogger("aliby").info(
-            f"Fetching data using {dispatcher.__class__.__name__}"
+            f"Fetching data using {dispatcher.__class__.__name__}."
         )
         # get log files, either locally or via OMERO
         with dispatcher as conn:
@@ -517,7 +515,7 @@ class Pipeline(ProcessABC):
                 for k in ["masks", "cell_labels"]:
                     run_kwargs[step][k] = None
             # check and report clogging
-            frac_clogged_traps = self.check_earlystop(
+            frac_clogged_traps = check_earlystop(
                 pipe["filename"],
                 pipe["earlystop"],
                 pipe["steps"]["tiler"].tile_size,
@@ -691,50 +689,70 @@ class Pipeline(ProcessABC):
         pipe["tps"] = min(config["general"]["tps"], image.data.shape[0])
        return pipe, session
 
-    @staticmethod
-    def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
-        """
-        Check recent time points for tiles with too many cells.
-        Returns the fraction of clogged tiles, where clogged tiles have
-        too many cells or too much of their area covered by cells.
-
-        Parameters
-        ----------
-        filename: str
-            Name of h5 file.
-        es_parameters: dict
-            Parameters defining when early stopping should happen.
-            For example:
-                {'min_tp': 100,
-                'thresh_pos_clogged': 0.4,
-                'thresh_trap_ncells': 8,
-                'thresh_trap_area': 0.9,
-                'ntps_to_eval': 5}
-        tile_size: int
-            Size of tile.
-        """
-        # get the area of the cells organised by trap and cell number
-        s = Signal(filename)
-        df = s.get_raw("/extraction/general/None/area")
-        # check the latest time points only
-        cells_used = df[
-            df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
-        ].dropna(how="all")
-        # find tiles with too many cells
-        traps_above_nthresh = (
-            cells_used.groupby("trap").count().apply(np.mean, axis=1)
-            > es_parameters["thresh_trap_ncells"]
-        )
-        # find tiles with cells covering too great a fraction of the tiles' area
-        traps_above_athresh = (
-            cells_used.groupby("trap").sum().apply(np.mean, axis=1)
-            / tile_size**2
-            > es_parameters["thresh_trap_area"]
-        )
-        return (traps_above_nthresh & traps_above_athresh).mean()
+
+def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
+    """
+    Check recent time points for tiles with too many cells.
+
+    Returns the fraction of clogged tiles, where clogged tiles have
+    too many cells or too much of their area covered by cells.
+
+    Parameters
+    ----------
+    filename: str
+        Name of h5 file.
+    es_parameters: dict
+        Parameters defining when early stopping should happen.
+        For example:
+            {'min_tp': 100,
+            'thresh_pos_clogged': 0.4,
+            'thresh_trap_ncells': 8,
+            'thresh_trap_area': 0.9,
+            'ntps_to_eval': 5}
+    tile_size: int
+        Size of tile.
+    """
+    # get the area of the cells organised by trap and cell number
+    s = Signal(filename)
+    df = s.get_raw("/extraction/general/None/area")
+    # check the latest time points only
+    cells_used = df[
+        df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
+    ].dropna(how="all")
+    # find tiles with too many cells
+    traps_above_nthresh = (
+        cells_used.groupby("trap").count().apply(np.mean, axis=1)
+        > es_parameters["thresh_trap_ncells"]
+    )
+    # find tiles with cells covering too great a fraction of the tiles' area
+    traps_above_athresh = (
+        cells_used.groupby("trap").sum().apply(np.mean, axis=1)
+        / tile_size**2
+        > es_parameters["thresh_trap_area"]
+    )
+    return (traps_above_nthresh & traps_above_athresh).mean()
 
 
 def close_session(session):
+    """Close session for multiprocessing."""
     if session:
         session.close()
+
+
+def initialise_tf(version):
+    """Initialise tensorflow."""
+    if version == 1:
+        core_config = tf.ConfigProto()
+        core_config.gpu_options.allow_growth = True
+        session = tf.Session(config=core_config)
+        return session
+    if version == 2:
+        gpus = tf.config.experimental.list_physical_devices("GPU")
+        if gpus:
+            for gpu in gpus:
+                tf.config.experimental.set_memory_growth(gpu, True)
+            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
+            print(
+                len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
+            )
+        return None