diff --git a/aliby/baby_client.py b/aliby/baby_client.py
index 195eca5c43c1711c589c5315ff31901d0fe8c570..5731967b3a5dd34da2457818de033ce6ceb6d68e 100644
--- a/aliby/baby_client.py
+++ b/aliby/baby_client.py
@@ -5,6 +5,8 @@ import itertools
 import logging
 import re
 import time
+import typing as t
+from pathlib import Path, PosixPath
 from time import perf_counter
 
 import baby.errors
@@ -63,7 +65,9 @@ def format_segmentation(segmentation, tp):
 class BabyParameters(ParametersABC):
     def __init__(
         self,
-        model_config,
+        model_config: t.Dict[
+            str, t.Union[str, int, float, bool, t.List[t.Union[int, float]]]
+        ],
         tracker_params,
         clogging_thresh,
         min_bud_tps,
@@ -73,7 +77,7 @@ class BabyParameters(ParametersABC):
         print_info,
         suppress_errors,
         error_dump_dir,
-        tf_version,
+        tf_version: int,
     ):
         self.model_config = model_config
         self.tracker_params = tracker_params
@@ -104,6 +108,25 @@ class BabyParameters(ParametersABC):
             tf_version=2,
         )
 
+    def update_baby_modelset(
+        self, path: t.Union[str, PosixPath, t.Dict[str, str]]
+    ):
+        """
+        Replace default BABY model and flattener with another one from a folder outputted
+        by our standard retraining script.
+        """
+
+        if isinstance(path, dict):
+            weights_flattener = {k: Path(v) for k, v in path.items()}
+        else:
+            weights_dir = Path(path)
+            weights_flattener = {
+                "flattener_file": weights_dir.parent / "flattener.json",
+                "morph_model_file": weights_dir / "weights.h5",
+            }
+
+        self.update("model_config", weights_flattener)
+
 
 class BabyRunner:
     """A BabyRunner object for cell segmentation.
@@ -113,13 +136,11 @@ class BabyRunner:
     def __init__(self, tiler, parameters=None, **kwargs):
         self.tiler = tiler
         # self.model_config = modelsets()[choose_model_from_params(**kwargs)]
-        self.model_config = modelsets()[
-            (
-                parameters.model_config
-                if parameters is not None
-                else choose_model_from_params(**kwargs)
-            )
-        ]
+        self.model_config = (
+            choose_model_from_params(**kwargs)
+            if parameters is None
+            else parameters.model_config
+        )
         self.brain = BabyBrain(**self.model_config)
         self.crawler = BabyCrawler(self.brain)
         self.bf_channel = self.tiler.get_channel_index("Brightfield")
@@ -269,4 +290,4 @@ def choose_model_from_params(
             "No model sets found matching {}".format(", ".join(params))
         )
     # Pick the first model
-    return valid_models[0]
+    return modelsets()[valid_models[0]]