From 1ac94283733f07cf05895d17c52571c588cb8547 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Mon, 23 Jan 2023 18:22:09 +0000 Subject: [PATCH] fix(pipe): update calls to server_info data --- src/aliby/pipeline.py | 40 +++++++++++++++++----------------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py index 3924715b..2977cada 100644 --- a/src/aliby/pipeline.py +++ b/src/aliby/pipeline.py @@ -76,7 +76,10 @@ class PipelineParameters(ParametersABC): directory = Path(general.get("directory", "../data")) - with dispatch_dataset(expt_id, **get_server_info(general)) as conn: + with dispatch_dataset( + expt_id, + **{k: general.get(k) for k in ("host", "username", "password")}, + ) as conn: directory = directory / conn.unique_name if not directory.exists(): directory.mkdir(parents=True) @@ -286,15 +289,16 @@ class Pipeline(ProcessABC): """ config = self.parameters.to_dict() - general = config["general"] - expt_id = general["id"] - distributed = general["distributed"] - pos_filter = general["filter"] - root_dir = Path(general["directory"]) - - dispatcher = dispatch_dataset( - expt_id, **self.general.get("server_info", {}) - ) + expt_id = config["general"]["id"] + distributed = config["general"]["distributed"] + pos_filter = config["general"]["filter"] + root_dir = Path(config["general"]["directory"]) + self.server_info = { + k: config["general"].get(k) + for k in ("host", "username", "password") + } + + dispatcher = dispatch_dataset(expt_id, **self.server_info) logging.getLogger("aliby").info( f"Fetching data using {dispatcher.__class__.__name__}" ) @@ -313,7 +317,7 @@ class Pipeline(ProcessABC): # Modify to the configuration self.parameters.general["directory"] = str(directory) - general["directory"] = directory + config["general"]["directory"] = directory self.setLogger(directory) @@ -399,7 +403,7 @@ class Pipeline(ProcessABC): min_process_from = min(process_from.values()) with get_image_class(image_id)( - image_id, **get_server_info(self.general) + image_id, **self.server_info ) as image: # Initialise Steps @@ -652,9 +656,7 @@ class Pipeline(ProcessABC): directory = general_config["directory"] trackers_state: t.List[np.ndarray] = [] - with get_image_class(image_id)( - image_id, **get_server_info(general_config) - ) as image: + with get_image_class(image_id)(image_id, **self.server_info) as image: filename = Path(f"{directory}/{image.name}.h5") meta = MetaData(directory, filename) @@ -730,11 +732,3 @@ class Pipeline(ProcessABC): def _close_session(session): if session: session.close() - - -def get_server_info(general: dict) -> t.Dict[str, int or str]: - return { - k: general[k] - for k in ("host", "username", "password") - if general.get(k) - } -- GitLab