Commit 4a055044 authored by Alán Muñoz

rename DummyRunner BabyRunner and clean old methods

parent 9496ba04
@@ -13,6 +13,7 @@ import requests
 import tensorflow as tf
 from tqdm import tqdm
 
+from agora.base import ParametersABC, ProcessABC
 import baby.errors
 from baby import modelsets
 from baby.brain import BabyBrain
@@ -22,6 +23,7 @@ from requests_toolbelt.multipart.encoder import MultipartEncoder
 from pcore.utils import Cache, accumulate, get_store_path
 
 ################### Dask Methods ################################
 
 def format_segmentation(segmentation, tp):
     """Format a single timepoint into a dictionary.
@@ -61,69 +63,65 @@ def format_segmentation(segmentation, tp):
     return merged
 
 
-def choose_model_from_params(
-    modelset_filter=None,
-    camera="prime95b",
-    channel="brightfield",
-    zoom="60x",
-    n_stacks="5z",
-    **kwargs,
-):
-    """
-    Define which model to query from the server based on a set of parameters.
-
-    Parameters
-    ----------
-    valid_models: List[str]
-        The names of the models that are available.
-    modelset_filter: str
-        A regex filter to apply on the models to start.
-    camera: str
-        The camera used in the experiment (case insensitive).
-    channel: str
-        The channel used for segmentation (case insensitive).
-    zoom: str
-        The zoom on the channel.
-    n_stacks: str
-        The number of z_stacks to use in segmentation
-
-    Returns
-    -------
-    model_name : str
-    """
-    valid_models = list(modelsets().keys())
-
-    # Apply modelset filter if specified
-    if modelset_filter is not None:
-        msf_regex = re.compile(modelset_filter)
-        valid_models = filter(msf_regex.search, valid_models)
-
-    # Apply parameter filters if specified
-    params = [
-        str(x) if x is not None else ".+"
-        for x in [camera.lower(), channel.lower(), zoom, n_stacks]
-    ]
-    params_re = re.compile("^" + "_".join(params) + "$")
-    valid_models = list(filter(params_re.search, valid_models))
-    # Check that there are valid models
-    if len(valid_models) == 0:
-        raise KeyError("No model sets found matching {}".format(", ".join(params)))
-    # Pick the first model
-    return valid_models[0]
+class BabyParameters(ParametersABC):
+    def __init__(
+        self,
+        model_config,
+        tracker_params,
+        clogging_thresh,
+        min_bud_tps,
+        isbud_thresh,
+        session,
+        graph,
+        print_info,
+        suppress_errors,
+        error_dump_dir,
+    ):
+        self.model_config = model_config
+        self.tracker_params = tracker_params
+        self.clogging_thresh = clogging_thresh
+        self.min_bud_tps = min_bud_tps
+        self.isbud_thresh = isbud_thresh
+        self.session = session
+        self.graph = graph
+        self.print_info = print_info
+        self.suppress_errors = suppress_errors
+        self.error_dump_dir = error_dump_dir
+
+    @classmethod
+    def default(cls, **kwargs):
+        """kwargs passes values to the model chooser"""
+        return cls(
+            model_config=choose_model_from_params(**kwargs),
+            tracker_params=dict(ctrack_params=dict(), budtrack_params=dict()),
+            clogging_thresh=1,
+            min_bud_tps=3,
+            isbud_thresh=0.5,
+            session=None,
+            graph=None,
+            print_info=False,
+            suppress_errors=False,
+            error_dump_dir=None,
+        )
 
 
-class DummyRunner:
+class BabyRunner:
     """A BabyRunner object for cell segmentation.
     Does segmentation one time point at a time."""
 
-    def __init__(self, tiler, *args, **kwargs):
+    def __init__(self, tiler, parameters=None, *args, **kwargs):
         self.tiler = tiler
-        self.model_config = modelsets()[choose_model_from_params(**kwargs)]
+        # self.model_config = modelsets()[choose_model_from_params(**kwargs)]
+        self.model_config = parameters.model_config
         self.brain = BabyBrain(**self.model_config)
         self.crawler = BabyCrawler(self.brain)
         self.bf_channel = self.tiler.get_channel_index("Brightfield")
 
+    @classmethod
+    def from_tiler(cls, parameters: BabyParameters, tiler):
+        return cls(tiler, parameters)
+
     def get_data(self, tp):
         # Swap axes x and z, probably shouldn't swap, just move z
         return self.tiler.get_tp_data(tp, self.bf_channel).swapaxes(1, 3).swapaxes(1, 2)
@@ -138,7 +136,7 @@ class DummyRunner:
         return format_segmentation(segmentation, tp)
 
 
-class DummyClient:
+class BabyClient:
     """A dummy BabyClient object for Dask Demo.
@@ -207,451 +205,52 @@ class DummyClient:
         return format_segmentation(seg, tp)
 
 
+def choose_model_from_params(
+    modelset_filter=None,
+    camera="prime95b",
+    channel="brightfield",
+    zoom="60x",
+    n_stacks="5z",
+    **kwargs,
+):
+    """
+    Define which model to query from the server based on a set of parameters.
+
+    Parameters
+    ----------
+    valid_models: List[str]
+        The names of the models that are available.
+    modelset_filter: str
+        A regex filter to apply on the models to start.
+    camera: str
+        The camera used in the experiment (case insensitive).
+    channel: str
+        The channel used for segmentation (case insensitive).
+    zoom: str
+        The zoom on the channel.
+    n_stacks: str
+        The number of z_stacks to use in segmentation
+
+    Returns
+    -------
+    model_name : str
+    """
+    valid_models = list(modelsets().keys())
+
+    # Apply modelset filter if specified
+    if modelset_filter is not None:
+        msf_regex = re.compile(modelset_filter)
+        valid_models = filter(msf_regex.search, valid_models)
+
+    # Apply parameter filters if specified
+    params = [
+        str(x) if x is not None else ".+"
+        for x in [camera.lower(), channel.lower(), zoom, n_stacks]
+    ]
+    params_re = re.compile("^" + "_".join(params) + "$")
+    valid_models = list(filter(params_re.search, valid_models))
+    # Check that there are valid models
+    if len(valid_models) == 0:
+        raise KeyError("No model sets found matching {}".format(", ".join(params)))
+    # Pick the first model
+    return valid_models[0]
################### Old Methods #################################
# class BabyNoMatches(Exception):
# pass
#
#
# class BabyNoSilent(Exception):
# pass
#
#
# # Todo: add defaults!
#
#
# def create_request(dims, bit_depth, img, **kwargs):
# """
# Construct a multipart/form-data request with the following
# information in the given order:
# :param session_id: the session ID (for tracking)
# :param dims: the dimensions of the images
# :param bit_depth: the bit-depth of the images, must be "8" or "16"
# :param img: the image to segment, flattened in order 'F'
# :return: a MultipartEncoder to use as data for the request.
# """
# fields = collections.OrderedDict([
# ("dims", json.dumps(dims)),
# ("bitdepth", json.dumps(bit_depth)),
# ("img", img.tostring(order='F'))])
# # Add optional arguments
# fields.update({kw: json.dumps(v) for kw, v in kwargs.items()})
# m = MultipartEncoder(
# fields=fields,
# boundary="----BabyFormBoundary"
# )
# return m
#
#
# class BabyClient:
# def __init__(self, tiler, url='http://localhost:5101', **kwargs):
# self.tiler = tiler
# self.url = url
# self._config = kwargs
# r_model_sets = requests.get(self.url + '/models')
# self.valid_models = r_model_sets.json()
# self._model_set = choose_model_from_params(self.valid_models,
# **self.config)
# self.sessions = Cache(load_fn=lambda _: self.get_new_session())
# self.processing = []
# self.z = None
# self.channel = None
# self.__init_properties(self.config)
#
# def __init_properties(self, config):
# n_stacks = int(config.get('n_stacks', '5z').replace('z', ''))
# self.z = list(range(n_stacks))
# self.channel = config.get('channel', 'Brightfield')
#
# @property
# def model_set(self):
# return self._model_set
#
# @model_set.setter
# def model_set(self, model_set):
# if self._model_set != model_set:
# # Need a new session if the model_set has changed
# self.session_id = ""
# self._model_set = model_set
# else:
# pass
#
# @property
# def config(self):
# return self._config
#
# @config.setter
# def config(self, config):
# if self._config is not None and config is not None:
# raise BabyNoSilent("Can only silently set a configuration "
# "to/from None")
# else:
# self._config = config
# self.model_set = choose_model_from_params(**self.config)
#
# def get_new_session(self):
# try:
# r_session = requests.get(self.url +
# '/session/{}'.format(self.model_set))
# r_session.raise_for_status()
# return r_session.json()["sessionid"]
# except KeyError as e:
# raise e
# except HTTPError as e:
# raise e
#
# def queue_image(self, img, session_id, **kwargs):
# # TODO validate image type?
# # TODO character encoding options?
# bit_depth = img.dtype.itemsize * 8 # bit depth = byte_size * 8
# data = create_request(img.shape, bit_depth, img, **kwargs)
# status = requests.post(self.url +
# '/segment?sessionid={}'.format(session_id),
# data=data,
# headers={'Content-Type': data.content_type})
# status.raise_for_status()
# return status
#
# def get_segmentation(self, session_id):
# try:
# seg_response = requests.get(
# self.url + '/segment?sessionid={}'.format(session_id),
# timeout=120)
# seg_response.raise_for_status()
# result = seg_response.json()
# except Timeout as e:
# raise e
# except HTTPError as e:
# raise e
# return result
#
# def process_position(self, position: str, tps: Iterable[int], store,
# save_dir, tile_size=96, **kwargs):
# # Open the store
# store_file = get_store_path(save_dir, store, position)
# with h5py.File(store_file, 'a') as file:
# hfile = file.require_group('cell_info')
# processed = hfile.require_dataset('processed_timepoints',
# maxshape=(None,),
# dtype=np.uint16)
# # reset the time points to avoid double-processing
# tps = [t for t in tps if t not in processed]
# position_results = []
# skipped = []
# for timepoint_id in tps:
# try:
# # Finish processing previously queued images
# self.flush_processing(position_results)
# self.process_timepoint(position, timepoint_id,
# tile_size=tile_size)
# except KeyError as e:
# # TODO log that this will not be processed
# skipped.append(timepoint_id)
# # Flush all processing before moving to the next position
# mother_assign = None
# while len(self.processing) > 0:
# mother_assign = self.flush_processing(position_results)
# store_position(position_results, self.tiler.positions.index(position),
# store, save_dir, position, mother_assign=mother_assign,
# tile_size=tile_size)
# processed_tps = [t for t in tps if t not in skipped]
# with h5py.File(store_file, 'a') as file:
# hfile = file.require_group('cell_info')
# processed = hfile['processed_timepoints']
# if processed.shape[0] < max(processed_tps):
# processed.resize(max(processed_tps), axis=0)
# processed[processed_tps] = processed_tps
# return processed_tps
#
# def process_timepoint(self, pos, timepoint, tile_size=96):
# channel_idx = [self.tiler.get_channel_index(self.channel)]
# traps = self.tiler[pos].get_traps_timepoint(timepoint,
# channels=channel_idx,
# tile_size=tile_size,
# z=self.z)
# traps = np.squeeze(traps)
# timepoint_key = (pos, timepoint)
# session_id = self.sessions[pos]
# self.queue_image(traps, session_id)
# self.processing.append(timepoint_key)
#
# def flush_processing(self, position_results):
# """ Get the results of previously queued images.
#
# :return:
# """
# for pos, tp in self.processing:
# try:
# result = self.get_segmentation(self.sessions[pos])
# tp_dataframe, mother_assign = format_segmentation(result, tp)
# position_results.append(tp_dataframe)
# self.processing.remove((pos, tp))
# except Timeout:
# continue
# except HTTPError:
# continue
# except TypeError as e:
# raise e
# return mother_assign
#
# def run(self, keys, store='store.h5', **kwargs):
# # key are (pos, timepoint) tuples
# run_tps = dict()
# for pos, tps in accumulate(keys):
# run_tps[pos] = self.process_position(pos, tps, store, **kwargs)
# keys = [(pos, tp) for pos in run_tps for tp in run_tps[pos]]
# return keys
#
#
# def format_segmentation(segmentation, tp):
# """ Format a single timepoint into a dataframe and append to the
# position results.
# :param segmentation: A list of results, each result is the output of the
# crawler, which is JSON-encoded
# :param tp: The time point considered
# :return: A pandas dataframe containing the formatted results of BABY
# """
# # Segmentation is a list of dictionaries, ordered by trap
# # Add trap information
# mother_assign = None
# for i, x in enumerate(segmentation):
# x['trap'] = [i] * len(x['cell_label'])
# # Merge into a dictionary of lists, by column
# merged = {k: list(itertools.chain.from_iterable(
# res[k] for res in segmentation))
# for k in segmentation[0].keys()}
# # Special case for mother_assign
# if 'mother_assign' in merged:
# del merged['mother_assign']
# mother_assign = [x['mother_assign'] for x in segmentation]
# # Check that the lists are all of the same length (in case of errors in
# # BABY)
# n_cells = min([len(v) for v in merged.values()])
# merged = {k: v[:n_cells] for k,v in merged.items()}
# tp_dataframe = pd.DataFrame(merged)
# # Set time point value for all traps
# tp_dataframe['timepoint'] = tp
# return tp_dataframe, mother_assign
#
#
# def store_position(position_results, position_index, store, save_dir,
# position_name, mother_assign=None, tile_size=96):
# """Store the results from a set of timepoints for a given position to
# and HDF5 store
# :param position_results: List of timepoint dataframes as returned by
# `format_segmentation`
# :param position_index: The index of the position considered
# :param store: The name of the HDF5 store to use.
# :return:
# """
# # Combine all of the results into one data frame
# position_results = pd.concat(position_results)
# store_file = get_store_path(save_dir, store, position_name)
# df_to_hdf(position_results, store_file, mother_assign=mother_assign,
# tile_size=tile_size)
# return
#
#
# def sparsity(arr):
# """Defines a sparsity score for a matrix based on the percentage of
# zeros."""
# try:
# return 1.0 - np.count_nonzero(arr) / arr.size
# except:
# return 1
#
#
# def df_to_hdf(df, filename, mother_assign=None, tile_size=96):
# """Convert the dataframe of segmentation results into an HDF5 file.
# :param df: The dataframe.
# :param filename: The Name of the HDF5 file to use.
# :return:
# """
# # TODO: Use numpy min_scalar_type!
# datatypes = {
# 'centres': ((None, 2), np.uint16),
# 'position': ((None,), np.uint16),
# 'angles': ((None,), h5py.vlen_dtype(np.float32)),
# 'radii': ((None,), h5py.vlen_dtype(np.float32)),
# 'edgemasks': ((None, tile_size, tile_size), np.bool),
# 'ellipse_dims': ((None, 2), np.float32),
# 'cell_label': ((None,), np.uint16),
# 'trap': ((None,), np.uint16),
# 'timepoint': ((None,), np.uint16),
# 'mother_assign': ((None,), np.uint16)
# }
#
# file = h5py.File(filename, 'a')
# hfile = file.require_group('cell_info')
#
# n = len(df)
# for key in df.columns:
# # We're only saving data that has a pre-defined data-type
# if key not in datatypes:
# raise KeyError(f"No defined data type for key {key}")
# if key not in hfile:
# # TODO Include sparsity check
# max_shape, dtype = datatypes[key]
# shape = (n,) + max_shape[1:]
# data = df[key].to_list()
# hfile.create_dataset(key, shape=shape, maxshape=max_shape,
# dtype=dtype, compression='gzip')
# hfile[key][()] = data
# else:
# # The dataset already exists, expand it
# dset = hfile[key]
# dset.resize(dset.shape[0] + n, axis=0)
# dset[-n:] = df[key].tolist()
# if mother_assign:
# # We do not append to mother_assign; raise error if already saved
# n = len(mother_assign)
# hfile.require_dataset('mother_assign', shape=(n,),
# dtype=h5py.vlen_dtype(np.uint16),
# compression='gzip')
# hfile['mother_assign'][()] = mother_assign
# file.close()
# return
#
#
# class BabyRunner:
# valid_models = modelsets()
# ERROR_DUMP_DIR = 'baby-errors'
#
# def __init__(self, tiler, error_dump_dir=None, **kwargs):
# self.tiler = tiler
# if error_dump_dir is None:
# self.error_dump_dir = self.ERROR_DUMP_DIR
# self._config = kwargs
# model_name = choose_model_from_params(self.valid_models, **self.config)
# self.sessions = Cache(load_fn=lambda _: self.session())
# self.z = None
# self.channel = None
# self.default_image_size = None
# self.__init_properties(self.config)
# # Create tensorflow objects
# self.tf_session = None
# self.tf_graph = None
# # TODO: put the tensorflow initilization in a separate function
# tf_version = tuple(int(v) for v in tf.version.VERSION.split('.'))
# if tf_version[0] == 1:
# config = tf.ConfigProto()
# config.gpu_options.allow_growth = True
# self.tf_session = tf.Session(config=config)
# self.tf_graph = tf.get_default_graph()
# elif tf_version[0] == 2:
# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
# for gpu in gpus:
# tf.config.experimental.set_memory_growth(gpu, True)
# logical_gpus = tf.config.experimental.list_logical_devices(
# 'GPU')
# print(len(gpus), "Physical GPUs,", len(logical_gpus),
# "Logical GPUs")
# # Overriding some of the default model values in baby to avoid errors
# model_config = self.valid_models[model_name]
# default_image_size = self.config.get("default_image_size", None)
# if default_image_size:
# model_config["default_image_size"] = default_image_size
# self.default_image_size = default_image_size
# # Getting the runner
# self.brain = BabyBrain(**model_config,
# session=self.tf_session, graph=self.tf_graph,
# suppress_errors=True,
# error_dump_dir=self.error_dump_dir,
# )
#
# @property
# def config(self):
# return self._config
#
# def __init_properties(self, config):
# n_stacks = int(config.get('n_stacks', '5z').replace('z', ''))
# self.z = slice(0, n_stacks)
# self.channel = self.tiler.get_channel_index(
# config.get('channel', 'Brightfield'))
#
# def session(self):
# return BabyCrawler(self.brain)
#
# def segment(self, img, sessionid, **kwargs):
# # Getting the result for a given image
# crawler = self.sessions[sessionid]
# pred = crawler.step(img, **kwargs)
# return pred
#
# def process_position(self, position, tps, store, save_dir, verbose,
# **kwargs):
# """ Segment the position for the given number of time points and save.
#
# :param save_dir: Directory in which to save results
# :param position: The name of the position to segment
# :param tps: A list of time points on which to run the segmentation
# :param store: The file in which to save the results, as csv. Results
# are appended to this file so make sure not to use a previously used
# file name or you will have hard-to-find duplicates!
# :param verbose: Set to show progression of the time points
# :param kwargs: Additional segmentation parameters, to be given to
# the BABY crawler
# :return: None
# """
# self.tiler.current_position = position
# position_results = []
# skipped = []
# mother_assign = None
# # Open the store
# store_file = get_store_path(save_dir, store, position)
# with h5py.File(store_file, 'a') as file:
# hfile = file.require_group('cell_info')
# if 'processed_timepoints' in hfile:
# processed = hfile['processed_timepoints']
# # reset the time points to avoid double-processing
# tps = [t for t in tps if t not in processed]
# else:
# processed = hfile.create_dataset('processed_timepoints',
# shape=(len(tps),),
# maxshape=(None, ),
# dtype=np.uint16)
# for tp in tqdm(tps, desc=position, disable=not verbose):
# try:
# t = time.perf_counter()
# traps = np.squeeze(
# self.tiler.get_traps_timepoint(tp, channels=[self.channel],
# z=self.z,
# tile_size=self.default_image_size))
# t2 = time.perf_counter()
# print(f"Loading image {position}, {tp} in {t2 - t}s")
# segmentation = self.segment(traps, position, **kwargs)
# print(f"Segmenting image {position}, {tp} in "
# f"{time.perf_counter() - t2}s")
# # Segmentation is a list of dictionaries, ordered by trap
# # Add trap information
# tp_dataframe, mother_assign = format_segmentation(segmentation, tp)
# position_results.append(tp_dataframe)
# except baby.errors.BadOutput as e:
# skipped.append(tp)
# continue
# #raise (e)
# store_position(position_results,
# self.tiler.positions.index(position),
# store, save_dir, position_name=position,
# mother_assign=mother_assign,
# tile_size=self.default_image_size)
# processed_tps = [t for t in tps if t not in skipped]
# with h5py.File(store_file, 'a') as hfile:
# processed = hfile['/cell_info/processed_timepoints']
# if processed.shape[0] < max(processed_tps):
# processed.resize(max(processed_tps), axis=0)
# processed[processed_tps] = processed_tps
# return processed_tps
#
# def run(self, keys, store, clear_cache=False, **kwargs):
# save_dir = self.tiler.expt.root_dir
# # key are (pos, timepoint) tuples
# run_tps = dict()
# for pos, tps in accumulate(keys):
# run_tps[pos] = self.process_position(pos, tps, store, save_dir,
# **kwargs)
# if clear_cache:
# self.tiler[pos].clear_cache()
# keys = [(pos, tp) for pos in run_tps for tp in run_tps[pos]]
# return keys
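
For orientation, here is a minimal usage sketch of the renamed classes, based only on the signatures visible in this diff. The import path and the `tiler` object are assumptions (a pcore Tiler-like object exposing `get_channel_index` and `get_tp_data`, as used inside `BabyRunner`); the method that produces the per-timepoint segmentation is not shown in these hunks.

```python
# Hedged usage sketch -- the import path and `tiler` are assumptions.
from pcore.baby_client import BabyParameters, BabyRunner  # module path assumed

tiler = ...  # a Tiler-like object providing get_channel_index/get_tp_data

# Keyword arguments are forwarded to choose_model_from_params, which joins
# them into a pattern such as "prime95b_brightfield_60x_5z" and picks the
# first matching BABY model set.
parameters = BabyParameters.default(
    camera="prime95b",
    channel="brightfield",
    zoom="60x",
    n_stacks="5z",
)

runner = BabyRunner.from_tiler(parameters, tiler)

# get_data returns the brightfield tiles for one time point, with axes
# reordered by the swapaxes calls shown in the diff above.
first_tp = runner.get_data(tp=0)
```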
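The model selection in `choose_model_from_params` reduces to a single anchored regex built from the camera, channel, zoom and stack parameters. A self-contained toy illustration of that matching (the model-set names below are invented for the example):

```python
import re

# Toy model-set names -- invented for illustration only.
valid_models = [
    "evolve_brightfield_60x_5z",
    "prime95b_brightfield_60x_5z",
    "prime95b_gfp_60x_5z",
]

# Mirrors the filtering in choose_model_from_params: join the (lowercased)
# parameters with "_", anchor the pattern, and keep the first match.
params = ["prime95b", "brightfield", "60x", "5z"]
params_re = re.compile("^" + "_".join(params) + "$")
matches = list(filter(params_re.search, valid_models))
print(matches[0])  # prime95b_brightfield_60x_5z
```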