Commit 5c2923a9 authored by pswain

tidy(pipeline): removed haystack; added initialise_tf to pipeline

parent 5ca40821
"""
Neural network initialisation.
"""
from pathlib import Path
from time import perf_counter
import numpy as np
import tensorflow as tf
from agora.io.writer import DynamicWriter
def initialise_tf(version):
# Initialise tensorflow
if version == 1:
core_config = tf.ConfigProto()
core_config.gpu_options.allow_growth = True
session = tf.Session(config=core_config)
return session
# TODO this only works for TF2
if version == 2:
gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.experimental.list_logical_devices("GPU")
print(
len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
)
return None
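
# Hedged usage sketch: under TF1 keep the returned tf.Session and close it
# once processing ends; under TF2 memory growth is set globally and None is
# returned, so there is nothing to close.
#   session = initialise_tf(version=2)
#   ...
#   if session is not None:
#       session.close()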
def timer(func, *args, **kwargs):
    """Call func, print the elapsed time in seconds, and return its result."""
    start = perf_counter()
    result = func(*args, **kwargs)
    print(f"Function {func.__name__}: {perf_counter() - start}s")
    return result
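
# Hypothetical usage sketch: time a NumPy matrix product; timer returns
# whatever the wrapped function returns.
#   a = np.ones((500, 500))
#   product = timer(np.matmul, a, a)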
################## CUSTOM OBJECTS ##################################


class ModelPredictor:
    """
    Generic object that takes a NN and returns its prediction.

    Use for predicting fluorescence or other channels from bright field.
    This does not do instance segmentation of anything.
    """

    def __init__(self, tiler, model, name, bf_channel=0):
        self.tiler = tiler
        self.model = model
        self.name = name
        # index of the bright-field channel in the tiler's data
        self.bf_channel = bf_channel

    def get_data(self, tp):
        # move the Z axis to the end: (tiles, Z, Y, X) -> (tiles, Y, X, Z)
        return (
            self.tiler.get_tp_data(tp, self.bf_channel)
            .swapaxes(1, 3)
            .swapaxes(1, 2)
        )

    def format_result(self, result, tp):
        return {self.name: result, "timepoints": [tp] * len(result)}

    def run_tp(self, tp):
        """Run the model on all tiles for a single time point."""
        prediction = self.model.predict(self.get_data(tp))
        return self.format_result(prediction, tp)


class ModelPredictorWriter(DynamicWriter):
    def __init__(self, file, name, shape, dtype):
        super().__init__(file)
        self.name = name
        self.datatypes = {
            name: (shape, dtype),
            # matches the "timepoints" key produced by ModelPredictor
            "timepoints": ((None,), np.uint16),
        }
        self.group = f"{self.name}_info"
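
For orientation, here is a minimal, self-contained sketch of how the removed module's ModelPredictor can be driven; FakeTiler and FakeModel are stand-ins invented for illustration and are not part of aliby:

import numpy as np

class FakeTiler:
    """Stand-in tiler returning one time point of (tiles, Z, Y, X) data."""
    def get_tp_data(self, tp, channel):
        return np.random.rand(4, 5, 96, 96)

class FakeModel:
    """Stand-in network reducing each tile to a single value."""
    def predict(self, data):
        return data.mean(axis=(1, 2, 3))

predictor = ModelPredictor(FakeTiler(), FakeModel(), name="fluorescence")
out = predictor.run_tp(0)
print(out["timepoints"])  # [0, 0, 0, 0]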
@@ -9,14 +9,14 @@ from importlib.metadata import version
 from pathlib import Path
 from pprint import pprint

+import baby
+import baby.errors
 import h5py
 import numpy as np
 import tensorflow as tf
 from pathos.multiprocessing import Pool
 from tqdm import tqdm
-import baby
-import baby.errors

 try:
     if baby.__version__ == "v0.30.1":
         from aliby.baby_sitter import BabyParameters, BabyRunner
@@ -25,11 +25,10 @@ except AttributeError:
 import aliby.global_parameters as global_parameters
 from agora.abc import ParametersABC, ProcessABC
-from agora.io.metadata import MetaData, parse_logfiles
+from agora.io.metadata import MetaData
 from agora.io.reader import StateReader
 from agora.io.signal import Signal
 from agora.io.writer import LinearBabyWriter, StateWriter, TilerWriter
-from aliby.haystack import initialise_tf
 from aliby.io.dataset import dispatch_dataset
 from aliby.io.image import dispatch_image
 from aliby.tile.tiler import Tiler, TilerParameters
@@ -43,7 +42,6 @@ from postprocessor.core.postprocessing import (
     PostProcessorParameters,
 )

 # stop warnings from TensorFlow
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 logging.getLogger("tensorflow").setLevel(logging.ERROR)
@@ -307,9 +305,9 @@ class Pipeline(ProcessABC):
             pprint(config[step])
         print()
         try:
-            print(f"Using Baby {baby.__version__}")
+            print(f"Using Baby {baby.__version__}.\n")
         except AttributeError:
-            print("Using Baby - no version specified.")
+            print("Using Baby - no version specified.\n")
         # extract from configuration
         expt_id = config["general"]["id"]
         distributed = config["general"]["distributed"]
@@ -321,7 +319,7 @@ class Pipeline(ProcessABC):
         }
         dispatcher = dispatch_dataset(expt_id, **self.server_info)
         logging.getLogger("aliby").info(
-            f"Fetching data using {dispatcher.__class__.__name__}"
+            f"Fetching data using {dispatcher.__class__.__name__}."
         )
         # get log files, either locally or via OMERO
         with dispatcher as conn:
@@ -517,7 +515,7 @@ class Pipeline(ProcessABC):
                 for k in ["masks", "cell_labels"]:
                     run_kwargs[step][k] = None
             # check and report clogging
-            frac_clogged_traps = self.check_earlystop(
+            frac_clogged_traps = check_earlystop(
                 pipe["filename"],
                 pipe["earlystop"],
                 pipe["steps"]["tiler"].tile_size,
@@ -691,50 +689,70 @@ class Pipeline(ProcessABC):
         pipe["tps"] = min(config["general"]["tps"], image.data.shape[0])
         return pipe, session

-    @staticmethod
-    def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
-        """
-        Check recent time points for tiles with too many cells.
-
-        Returns the fraction of clogged tiles, where clogged tiles have
-        too many cells or too much of their area covered by cells.
-
-        Parameters
-        ----------
-        filename: str
-            Name of h5 file.
-        es_parameters: dict
-            Parameters defining when early stopping should happen.
-            For example:
-                {'min_tp': 100,
-                'thresh_pos_clogged': 0.4,
-                'thresh_trap_ncells': 8,
-                'thresh_trap_area': 0.9,
-                'ntps_to_eval': 5}
-        tile_size: int
-            Size of tile.
-        """
-        # get the area of the cells organised by trap and cell number
-        s = Signal(filename)
-        df = s.get_raw("/extraction/general/None/area")
-        # check the latest time points only
-        cells_used = df[
-            df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
-        ].dropna(how="all")
-        # find tiles with too many cells
-        traps_above_nthresh = (
-            cells_used.groupby("trap").count().apply(np.mean, axis=1)
-            > es_parameters["thresh_trap_ncells"]
-        )
-        # find tiles with cells covering too great a fraction of the tiles' area
-        traps_above_athresh = (
-            cells_used.groupby("trap").sum().apply(np.mean, axis=1)
-            / tile_size**2
-            > es_parameters["thresh_trap_area"]
-        )
-        return (traps_above_nthresh & traps_above_athresh).mean()
+def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
+    """
+    Check recent time points for tiles with too many cells.
+
+    Returns the fraction of clogged tiles, where clogged tiles have
+    too many cells or too much of their area covered by cells.
+
+    Parameters
+    ----------
+    filename: str
+        Name of h5 file.
+    es_parameters: dict
+        Parameters defining when early stopping should happen.
+        For example:
+            {'min_tp': 100,
+            'thresh_pos_clogged': 0.4,
+            'thresh_trap_ncells': 8,
+            'thresh_trap_area': 0.9,
+            'ntps_to_eval': 5}
+    tile_size: int
+        Size of tile.
+    """
+    # get the area of the cells organised by trap and cell number
+    s = Signal(filename)
+    df = s.get_raw("/extraction/general/None/area")
+    # check the latest time points only
+    cells_used = df[
+        df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
+    ].dropna(how="all")
+    # find tiles with too many cells
+    traps_above_nthresh = (
+        cells_used.groupby("trap").count().apply(np.mean, axis=1)
+        > es_parameters["thresh_trap_ncells"]
+    )
+    # find tiles with cells covering too great a fraction of the tiles' area
+    traps_above_athresh = (
+        cells_used.groupby("trap").sum().apply(np.mean, axis=1)
+        / tile_size**2
+        > es_parameters["thresh_trap_area"]
+    )
+    return (traps_above_nthresh & traps_above_athresh).mean()
+
+
+def close_session(session):
+    """Close session for multiprocessing."""
+    if session:
+        session.close()
+
+
+def initialise_tf(version):
+    """Initialise tensorflow."""
+    if version == 1:
+        core_config = tf.ConfigProto()
+        core_config.gpu_options.allow_growth = True
+        session = tf.Session(config=core_config)
+        return session
+    if version == 2:
+        gpus = tf.config.experimental.list_physical_devices("GPU")
+        if gpus:
+            for gpu in gpus:
+                tf.config.experimental.set_memory_growth(gpu, True)
+            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
+            print(
+                len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs"
+            )
+        return None
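
To make the early-stop arithmetic concrete, here is a small, self-contained sketch of the same computation on a toy area table; the MultiIndex layout imitates what Signal.get_raw returns, and all numbers and thresholds are invented:

import numpy as np
import pandas as pd

# five cells across two traps, six time points, every cell of area 40 px
index = pd.MultiIndex.from_tuples(
    [(0, 1), (0, 2), (1, 1), (1, 2), (1, 3)], names=["trap", "cell_label"]
)
df = pd.DataFrame(np.full((5, 6), 40.0), index=index)

es = {"thresh_trap_ncells": 2, "thresh_trap_area": 0.5, "ntps_to_eval": 5}
tile_size = 10

# same steps as check_earlystop, minus reading from an h5 file
cells_used = df[df.columns[-1 - es["ntps_to_eval"] : -1]].dropna(how="all")
n_clogged = (
    cells_used.groupby("trap").count().apply(np.mean, axis=1)
    > es["thresh_trap_ncells"]
)
a_clogged = (
    cells_used.groupby("trap").sum().apply(np.mean, axis=1) / tile_size**2
    > es["thresh_trap_area"]
)
# trap 0 exceeds the area threshold (0.8) but not the cell-count threshold,
# while trap 1 exceeds both, so the clogged fraction printed is 0.5
print((n_clogged & a_clogged).mean())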