diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py index 324bac922a052214229f88273d44f707f1dda2ad..3924715b6cb5fe97abc29ad76a3b2fb58672813e 100644 --- a/src/aliby/pipeline.py +++ b/src/aliby/pipeline.py @@ -76,9 +76,7 @@ class PipelineParameters(ParametersABC): directory = Path(general.get("directory", "../data")) - with dispatch_dataset( - expt_id, **general.get("server_info", {}) - ) as conn: + with dispatch_dataset(expt_id, **get_server_info(general)) as conn: directory = directory / conn.unique_name if not directory.exists(): directory.mkdir(parents=True) @@ -288,10 +286,11 @@ class Pipeline(ProcessABC): """ config = self.parameters.to_dict() - expt_id = config["general"]["id"] - distributed = config["general"]["distributed"] - pos_filter = config["general"]["filter"] - root_dir = Path(config["general"]["directory"]) + general = config["general"] + expt_id = general["id"] + distributed = general["distributed"] + pos_filter = general["filter"] + root_dir = Path(general["directory"]) dispatcher = dispatch_dataset( expt_id, **self.general.get("server_info", {}) @@ -314,7 +313,7 @@ class Pipeline(ProcessABC): # Modify to the configuration self.parameters.general["directory"] = str(directory) - config["general"]["directory"] = directory + general["directory"] = directory self.setLogger(directory) @@ -400,7 +399,7 @@ class Pipeline(ProcessABC): min_process_from = min(process_from.values()) with get_image_class(image_id)( - image_id, **self.general.get("server_info", {}) + image_id, **get_server_info(self.general) ) as image: # Initialise Steps @@ -654,7 +653,7 @@ class Pipeline(ProcessABC): trackers_state: t.List[np.ndarray] = [] with get_image_class(image_id)( - image_id, **self.general.get("server_info", {}) + image_id, **get_server_info(general_config) ) as image: filename = Path(f"{directory}/{image.name}.h5") meta = MetaData(directory, filename) @@ -731,3 +730,11 @@ class Pipeline(ProcessABC): def _close_session(session): if session: session.close() + + +def get_server_info(general: dict) -> t.Dict[str, int or str]: + return { + k: general[k] + for k in ("host", "username", "password") + if general.get(k) + }