Skip to content
Snippets Groups Projects
Commit abc8df96 authored by pswain's avatar pswain
Browse files

added baby_sitter to run new version of BABY

parent 4fc31051
No related branches found
No related tags found
No related merge requests found
......@@ -22,18 +22,16 @@ from requests.exceptions import HTTPError, Timeout
################### Dask Methods ################################
def format_segmentation(segmentation, tp):
"""Format a single timepoint into a dictionary.
"""
Format BABY's results from a single time point into a dictionary.
Parameters
------------
segmentation: list
A list of results, each result is the output of the crawler, which is JSON-encoded
A list of results, each result is the output of BABY
crawler, which is JSON-encoded.
tp: int
the time point considered
Returns
--------
A dictionary containing the formatted results of BABY
The time point.
"""
# Segmentation is a list of dictionaries, ordered by trap
# Add trap information
......@@ -179,7 +177,7 @@ class BabyRunner(StepABC):
def choose_model_from_params(
modelset_filter=None,
camera="prime95b",
camera="sCMOS",
channel="brightfield",
zoom="60x",
n_stacks="5z",
......@@ -204,7 +202,9 @@ def choose_model_from_params(
-------
model_name : str
"""
# cameras prime95 has become sCMOS and evolve has EMCCD
valid_models = list(modelsets().keys())
breakpoint()
# Apply modelset filter if specified
if modelset_filter is not None:
......@@ -218,6 +218,7 @@ def choose_model_from_params(
]
params_re = re.compile("^" + "_".join(params) + "$")
valid_models = list(filter(params_re.search, valid_models))
breakpoint()
# Check that there are valid models
if len(valid_models) == 0:
raise KeyError(
......
import itertools
import re
import typing as t
from pathlib import Path
import numpy as np
from baby import BabyCrawler, modelsets
from agora.abc import ParametersABC, StepABC
class BabyParameters(ParametersABC):
"""Parameters used for analysing the results from BABY."""
def __init__(
self,
modelset_name,
tracker_params,
clogging_thresh,
min_bud_tps,
isbud_thresh,
session,
graph,
print_info,
suppress_errors,
error_dump_dir,
tf_version: int,
):
"""Initialise parameters for BABY."""
self.modelset_name = modelset_name
self.tracker_params = tracker_params
self.clogging_thresh = clogging_thresh
self.min_bud_tps = min_bud_tps
self.isbud_thresh = isbud_thresh
self.session = session
self.graph = graph
self.print_info = print_info
self.suppress_errors = suppress_errors
self.error_dump_dir = error_dump_dir
self.tf_version = tf_version
@classmethod
def default(cls, **kwargs):
"""Define default parameters; kwargs choose BABY model set."""
return cls(
modelset_name=get_modelset_name_from_params(**kwargs),
tracker_params=dict(ctrack_params=dict(), budtrack_params=dict()),
clogging_thresh=1,
min_bud_tps=3,
isbud_thresh=0.5,
session=None,
graph=None,
print_info=False,
suppress_errors=False,
error_dump_dir=None,
tf_version=2,
)
def update_baby_modelset(self, path: t.Union[str, Path, t.Dict[str, str]]):
"""
Replace default BABY model and flattener.
Both are saved in a folder by our retraining script.
"""
if isinstance(path, dict):
weights_flattener = {k: Path(v) for k, v in path.items()}
else:
weights_dir = Path(path)
weights_flattener = {
"flattener_file": weights_dir.parent / "flattener.json",
"morph_model_file": weights_dir / "weights.h5",
}
self.update("modelset_name", weights_flattener)
class BabyRunner(StepABC):
"""
A BabyRunner object for cell segmentation.
Segments one time point at a time.
"""
def __init__(self, tiler, parameters=None, **kwargs):
"""Instantiate from a Tiler object."""
self.tiler = tiler
modelset_name = (
get_modelset_name_from_params(**kwargs)
if parameters is None
else parameters.modelset_name
)
tiler_z = self.tiler.shape[-3]
if f"{tiler_z}z" not in modelset_name:
raise KeyError(
f"Tiler z-stack ({tiler_z}) and model"
f" ({modelset_name}) do not match."
)
self.brain = modelsets.get(modelset_name)
self.crawler = BabyCrawler(self.brain)
self.brightfield_channel = self.tiler.ref_channel_index
@classmethod
def from_tiler(cls, parameters: BabyParameters, tiler):
"""Explicitly instantiate from a Tiler object."""
return cls(tiler, parameters)
def get_data(self, tp):
"""Get image and re-arrange axes."""
img_from_tiler = self.tiler.get_tp_data(tp, self.brightfield_channel)
# move z axis to the last axis
img_z_at_end = np.moveaxis(img_from_tiler, 1, destination=-1)
# move y axis before the x axis
img = np.moveaxis(img_z_at_end, 2, destination=1)
return img
def _run_tp(
self,
tp,
refine_outlines=True,
assign_mothers=True,
with_edgemasks=True,
**kwargs,
):
"""Segment data from one time point."""
img = self.get_data(tp)
segmentation = self.crawler.step(
img,
refine_outlines=refine_outlines,
assign_mothers=assign_mothers,
with_edgemasks=with_edgemasks,
**kwargs,
)
return format_segmentation(segmentation, tp)
def get_modelset_name_from_params(
imaging_device="alcatras",
channel="brightfield",
camera="sCMOS",
zoom="60x",
n_stacks="5z",
):
"""Get the appropriate model set from BABY's trained models."""
# list of models - microscopy setups - for which BABY has been trained
# cameras prime95 and evolve have become sCMOS and EMCCD
possible_models = list(modelsets.remote_modelsets()["models"].keys())
# filter possible_models
params = [
str(x) if x is not None else ".+"
for x in [imaging_device, channel.lower(), camera, zoom, n_stacks]
]
params_regex = re.compile("-".join(params) + "$")
valid_models = [
res for res in filter(params_regex.search, possible_models)
]
# check that there are valid models
if len(valid_models) == 1:
return valid_models[0]
else:
raise KeyError(
"Error in finding BABY model sets matching {}".format(
", ".join(params)
)
)
def format_segmentation(segmentation, tp):
"""
Format BABY's results for a single time point into a dict.
The dict has BABY's outputs as keys and lists of the results
for each segmented cell as values.
Parameters
------------
segmentation: list
A list of BABY's results as dicts for each tile.
tp: int
The time point.
"""
# segmentation is a list of dictionaries for each tile
for i, tile_dict in enumerate(segmentation):
# assign the trap ID to each cell identified
tile_dict["trap"] = [i] * len(tile_dict["cell_label"])
# record mothers for each labelled cell
tile_dict["mother_assign_dynamic"] = np.array(
tile_dict["mother_assign"]
)[np.array(tile_dict["cell_label"], dtype=int) - 1]
# merge into a dict with BABY's outputs as keys and
# lists of results for all cells as values
merged = {
output: list(
itertools.chain.from_iterable(
tile_dict[output] for tile_dict in segmentation
)
)
for output in segmentation[0].keys()
}
# remove mother_assign
merged.pop("mother_assign", None)
# ensure that each value is a list of the same length
no_cells = min([len(v) for v in merged.values()])
merged = {k: v[:no_cells] for k, v in merged.items()}
# define time point key
merged["timepoint"] = [tp] * no_cells
return merged
......@@ -14,13 +14,20 @@ import numpy as np
from pathos.multiprocessing import Pool
from tqdm import tqdm
import baby
try:
if baby.__version__ == "v0.30.1":
from aliby.baby_sitter import BabyParameters, BabyRunner
except AttributeError:
from aliby.baby_client import BabyParameters, BabyRunner
import aliby.global_parameters as global_parameters
from agora.abc import ParametersABC, ProcessABC
from agora.io.metadata import MetaData, parse_logfiles
from agora.io.reader import StateReader
from agora.io.signal import Signal
from agora.io.writer import LinearBabyWriter, StateWriter, TilerWriter
from aliby.baby_client import BabyParameters, BabyRunner
from aliby.haystack import initialise_tf
from aliby.io.dataset import dispatch_dataset
from aliby.io.image import dispatch_image
......@@ -35,6 +42,7 @@ from postprocessor.core.postprocessing import (
PostProcessorParameters,
)
# stop warnings from TensorFlow
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger("tensorflow").setLevel(logging.ERROR)
......
......@@ -92,6 +92,8 @@ def max2p5pc(cell_mask, trap_image) -> float:
def max5px_median(cell_mask, trap_image) -> float:
"""
Estimate the degree of localisation.
Find the mean of the five brightest pixels in the cell divided by the
median of all pixels.
......@@ -103,14 +105,17 @@ def max5px_median(cell_mask, trap_image) -> float:
"""
# sort pixels in cell
pixels = trap_image[cell_mask]
top_values = bn.partition(pixels, len(pixels) - 5)[-5:]
# find mean of five brightest pixels
max5px = np.mean(top_values)
med = np.median(pixels)
if med == 0:
return np.nan
if len(pixels) > 5:
top_values = bn.partition(pixels, len(pixels) - 5)[-5:]
# find mean of five brightest pixels
max5px = np.mean(top_values)
med = np.median(pixels)
if med == 0:
return np.nan
else:
return max5px / np.median(pixels)
else:
return max5px / np.median(pixels)
return np.nan
def std(cell_mask, trap_image):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment