Skip to content
Snippets Groups Projects
Commit 2afa7d57 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

add babyParameters being passed

parent c7e762b6
No related branches found
No related tags found
No related merge requests found
...@@ -76,6 +76,7 @@ class BabyParameters(ParametersABC): ...@@ -76,6 +76,7 @@ class BabyParameters(ParametersABC):
print_info, print_info,
suppress_errors, suppress_errors,
error_dump_dir, error_dump_dir,
tf_version,
): ):
self.model_config = model_config self.model_config = model_config
self.tracker_params = tracker_params self.tracker_params = tracker_params
...@@ -87,6 +88,7 @@ class BabyParameters(ParametersABC): ...@@ -87,6 +88,7 @@ class BabyParameters(ParametersABC):
self.print_info = print_info self.print_info = print_info
self.suppress_errors = suppress_errors self.suppress_errors = suppress_errors
self.error_dump_dir = error_dump_dir self.error_dump_dir = error_dump_dir
self.tf_version = tf_version
@classmethod @classmethod
def default(cls, **kwargs): def default(cls, **kwargs):
...@@ -102,6 +104,7 @@ class BabyParameters(ParametersABC): ...@@ -102,6 +104,7 @@ class BabyParameters(ParametersABC):
print_info=False, print_info=False,
suppress_errors=False, suppress_errors=False,
error_dump_dir=None, error_dump_dir=None,
tf_version=2,
) )
...@@ -113,11 +116,13 @@ class BabyRunner: ...@@ -113,11 +116,13 @@ class BabyRunner:
def __init__(self, tiler, parameters=None, *args, **kwargs): def __init__(self, tiler, parameters=None, *args, **kwargs):
self.tiler = tiler self.tiler = tiler
# self.model_config = modelsets()[choose_model_from_params(**kwargs)] # self.model_config = modelsets()[choose_model_from_params(**kwargs)]
self.model_config = ( self.model_config = modelsets()[
parameters.model_config (
if parameters is not None parameters.model_config
else modelsets()[choose_model_from_params(**kwargs)] if parameters is not None
) else choose_model_from_params(**kwargs)
)
]
self.brain = BabyBrain(**self.model_config) self.brain = BabyBrain(**self.model_config)
self.crawler = BabyCrawler(self.brain) self.crawler = BabyCrawler(self.brain)
self.bf_channel = self.tiler.get_channel_index("Brightfield") self.bf_channel = self.tiler.get_channel_index("Brightfield")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment