From 8bbeb63250a1faecfd0bfd8d3f5e33f30db83451 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Tue, 29 Mar 2022 11:40:08 +0100
Subject: [PATCH] bugfixes general parameters and post_pro legacy

---
 aliby/pipeline.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/aliby/pipeline.py b/aliby/pipeline.py
index 1c35910c..a3bdec83 100644
--- a/aliby/pipeline.py
+++ b/aliby/pipeline.py
@@ -103,6 +103,16 @@ class PipelineParameters(ParametersABC):
                 ),
             )
         }
+
+        for k, v in general.items():  # Overwrite general parameters
+            if k not in defaults["general"]:
+                defaults["general"][k] = v
+            elif isinstance(v, dict):
+                for k2, v2 in v.items():
+                    defaults["general"][k][k2] = v2
+            else:
+                defaults["general"][k] = v
+
         defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
         defaults["baby"] = BabyParameters.default(**extraction).to_dict()
         defaults["extraction"] = exparams_from_meta(meta)
@@ -111,7 +121,7 @@ class PipelineParameters(ParametersABC):
         ).to_dict()
         # for k in defaults.keys():
         #     exec("defaults[k].update(" + k + ")")
-        # return cls(**{k: v for k, v in defaults.items()})
+        return cls(**{k: v for k, v in defaults.items()})
 
     def load_logs(self):
         parsed_flattened = parse_logfiles(self.log_dir)
@@ -180,6 +190,12 @@ class Pipeline(ProcessABC):
         pipeline_parameters.general["directory"] = dir_path.parent
         pipeline_parameters.general["filter"] = [fpath.stem for fpath in files]
 
+        # Fix legacy postprocessing parameters
+        post_process_params = pipeline_parameters.postprocessing.get("parameters", None)
+        if post_process_params:
+            pipeline_parameters.postprocessing["param_sets"] = copy(post_process_params)
+            del pipeline_parameters.postprocessing["parameters"]
+
         return cls(pipeline_parameters)
 
     @classmethod
@@ -196,6 +212,12 @@ class Pipeline(ProcessABC):
         pipeline_parameters.general["directory"] = directory
         pipeline_parameters.general["filter"] = Path(fpath).stem
 
+        # Fix legacy postprocessing parameters
+        post_process_params = pipeline_parameters.postprocessing.get("parameters", None)
+        if post_process_params:
+            pipeline_parameters.postprocessing["param_sets"] = copy(post_process_params)
+            del pipeline_parameters.postprocessing["parameters"]
+
         return cls(pipeline_parameters, store=directory)
 
     def run(self):
@@ -429,7 +451,7 @@ class Pipeline(ProcessABC):
                     )
                     tmp = copy(config["extraction"]["multichannel_ops"])
                     for op, (input_ch, op_id, red_ext) in tmp.items():
-                        if set(input_ch).difference(av_channels_wsub):
+                        if not set(input_ch).issubset(av_channels_wsub):
                             del config["extraction"]["multichannel_ops"][op]
 
                     exparams = ExtractorParameters.from_dict(config["extraction"])
-- 
GitLab