From 72fd4c4c0123bee123b9a1104c5fda4ebe142fcf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Fri, 28 Jan 2022 14:12:05 +0000
Subject: [PATCH] bugfix: position filter accepts a str, an int, or a list of them

---
 aliby/pipeline.py | 30 +++++++++++++++++++++++-------
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/aliby/pipeline.py b/aliby/pipeline.py
index f124edc4..1da0f055 100644
--- a/aliby/pipeline.py
+++ b/aliby/pipeline.py
@@ -144,7 +144,7 @@ class Pipeline(ProcessABC):
         config = self.parameters.to_dict()
         expt_id = config["general"]["id"]
         distributed = config["general"]["distributed"]
-        strain_filter = config["general"]["filter"]
+        pos_filter = config["general"]["filter"]
         root_dir = config["general"]["directory"]
         root_dir = Path(root_dir)
 
@@ -163,16 +163,32 @@ class Pipeline(ProcessABC):
         config["general"]["directory"] = directory
 
         # Filter TODO integrate filter onto class and add regex
-        if isinstance(strain_filter, str):
-            image_ids = {
-                k: v for k, v in image_ids.items() if re.search(strain_filter, k)
-            }
-        elif isinstance(strain_filter, int):
+        def filt_int(d, filt):
+            """Keep only the entry of d whose insertion position equals filt."""
+            return {k: v for i, (k, v) in enumerate(d.items()) if i == filt}
+
+        def filt_str(d, filt):
+            """Keep the entries of d whose key matches the regex filt."""
+            return {k: v for k, v in d.items() if re.search(filt, k)}
+
+        def pick_filter(image_ids, filt):
+            """Filter by regex (str) or position (int); else pass through."""
+            for kind, apply in ((str, filt_str), (int, filt_int)):
+                if isinstance(filt, kind):
+                    return apply(image_ids, filt)
+            return image_ids
+
+        if isinstance(pos_filter, list):
             image_ids = {
-                k: v for i, (k, v) in enumerate(image_ids.items()) if i == strain_filter
+                k: v
+                for filt in pos_filter
+                for k, v in pick_filter(image_ids, filt).items()
             }
+        else:
+            image_ids = pick_filter(image_ids, pos_filter)
 
         assert len(image_ids), "No images to segment"
+
         if distributed != 0:  # Gives the number of simultaneous processes
             with Pool(distributed) as p:
                 results = p.map(lambda x: self.create_pipeline(x), image_ids.items())
-- 
GitLab