diff --git a/aliby/baby_client.py b/aliby/baby_client.py index 195eca5c43c1711c589c5315ff31901d0fe8c570..5731967b3a5dd34da2457818de033ce6ceb6d68e 100644 --- a/aliby/baby_client.py +++ b/aliby/baby_client.py @@ -5,6 +5,8 @@ import itertools import logging import re import time +import typing as t +from pathlib import Path, PosixPath from time import perf_counter import baby.errors @@ -63,7 +65,9 @@ def format_segmentation(segmentation, tp): class BabyParameters(ParametersABC): def __init__( self, - model_config, + model_config: t.Dict[ + str, t.Union[str, int, float, bool, t.List[t.Union[int, float]]] + ], tracker_params, clogging_thresh, min_bud_tps, @@ -73,7 +77,7 @@ class BabyParameters(ParametersABC): print_info, suppress_errors, error_dump_dir, - tf_version, + tf_version: int, ): self.model_config = model_config self.tracker_params = tracker_params @@ -104,6 +108,25 @@ class BabyParameters(ParametersABC): tf_version=2, ) + def update_baby_modelset( + self, path: t.Union[str, PosixPath, t.Dict[str, str]] + ): + """ + Replace default BABY model and flattener with another one from a folder outputted + by our standard retraining script. + """ + + if isinstance(path, dict): + weights_flattener = {k: Path(v) for k, v in path.items()} + else: + weights_dir = Path(path) + weights_flattener = { + "flattener_file": weights_dir.parent / "flattener.json", + "morph_model_file": weights_dir / "weights.h5", + } + + self.update("model_config", weights_flattener) + class BabyRunner: """A BabyRunner object for cell segmentation. @@ -113,13 +136,11 @@ class BabyRunner: def __init__(self, tiler, parameters=None, **kwargs): self.tiler = tiler # self.model_config = modelsets()[choose_model_from_params(**kwargs)] - self.model_config = modelsets()[ - ( - parameters.model_config - if parameters is not None - else choose_model_from_params(**kwargs) - ) - ] + self.model_config = ( + choose_model_from_params(**kwargs) + if parameters is None + else parameters.model_config + ) self.brain = BabyBrain(**self.model_config) self.crawler = BabyCrawler(self.brain) self.bf_channel = self.tiler.get_channel_index("Brightfield") @@ -269,4 +290,4 @@ def choose_model_from_params( "No model sets found matching {}".format(", ".join(params)) ) # Pick the first model - return valid_models[0] + return modelsets()[valid_models[0]]