From 1ac94283733f07cf05895d17c52571c588cb8547 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Mon, 23 Jan 2023 18:22:09 +0000
Subject: [PATCH] fix(pipe): update calls to server_info data

---
 src/aliby/pipeline.py | 40 +++++++++++++++++-----------------------
 1 file changed, 17 insertions(+), 23 deletions(-)

diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py
index 3924715b..2977cada 100644
--- a/src/aliby/pipeline.py
+++ b/src/aliby/pipeline.py
@@ -76,7 +76,10 @@ class PipelineParameters(ParametersABC):
 
         directory = Path(general.get("directory", "../data"))
 
-        with dispatch_dataset(expt_id, **get_server_info(general)) as conn:
+        with dispatch_dataset(
+            expt_id,
+            **{k: general.get(k) for k in ("host", "username", "password")},
+        ) as conn:
             directory = directory / conn.unique_name
             if not directory.exists():
                 directory.mkdir(parents=True)
@@ -286,15 +289,16 @@ class Pipeline(ProcessABC):
         """
 
         config = self.parameters.to_dict()
-        general = config["general"]
-        expt_id = general["id"]
-        distributed = general["distributed"]
-        pos_filter = general["filter"]
-        root_dir = Path(general["directory"])
-
-        dispatcher = dispatch_dataset(
-            expt_id, **self.general.get("server_info", {})
-        )
+        expt_id = config["general"]["id"]
+        distributed = config["general"]["distributed"]
+        pos_filter = config["general"]["filter"]
+        root_dir = Path(config["general"]["directory"])
+        self.server_info = {
+            k: config["general"].get(k)
+            for k in ("host", "username", "password")
+        }
+
+        dispatcher = dispatch_dataset(expt_id, **self.server_info)
         logging.getLogger("aliby").info(
             f"Fetching data using {dispatcher.__class__.__name__}"
         )
@@ -313,7 +317,7 @@ class Pipeline(ProcessABC):
 
         # Modify to the configuration
         self.parameters.general["directory"] = str(directory)
-        general["directory"] = directory
+        config["general"]["directory"] = directory
 
         self.setLogger(directory)
 
@@ -399,7 +403,7 @@ class Pipeline(ProcessABC):
             min_process_from = min(process_from.values())
 
             with get_image_class(image_id)(
-                image_id, **get_server_info(self.general)
+                image_id, **self.server_info
             ) as image:
 
                 # Initialise Steps
@@ -652,9 +656,7 @@ class Pipeline(ProcessABC):
         directory = general_config["directory"]
 
         trackers_state: t.List[np.ndarray] = []
-        with get_image_class(image_id)(
-            image_id, **get_server_info(general_config)
-        ) as image:
+        with get_image_class(image_id)(image_id, **self.server_info) as image:
             filename = Path(f"{directory}/{image.name}.h5")
             meta = MetaData(directory, filename)
 
@@ -730,11 +732,3 @@ class Pipeline(ProcessABC):
 def _close_session(session):
     if session:
         session.close()
-
-
-def get_server_info(general: dict) -> t.Dict[str, int or str]:
-    return {
-        k: general[k]
-        for k in ("host", "username", "password")
-        if general.get(k)
-    }
-- 
GitLab