diff --git a/aliby/pipeline.py b/aliby/pipeline.py index f124edc4ad78d9b6d0d7f2276c784b1da6d45ed5..1da0f055957a9f84d6a91127027f68421e0c2069 100644 --- a/aliby/pipeline.py +++ b/aliby/pipeline.py @@ -144,7 +144,7 @@ class Pipeline(ProcessABC): config = self.parameters.to_dict() expt_id = config["general"]["id"] distributed = config["general"]["distributed"] - strain_filter = config["general"]["filter"] + pos_filter = config["general"]["filter"] root_dir = config["general"]["directory"] root_dir = Path(root_dir) @@ -163,16 +163,32 @@ class Pipeline(ProcessABC): config["general"]["directory"] = directory # Filter TODO integrate filter onto class and add regex - if isinstance(strain_filter, str): - image_ids = { - k: v for k, v in image_ids.items() if re.search(strain_filter, k) - } - elif isinstance(strain_filter, int): + filt_int = lambda d, filt: { + k: v for i, (k, v) in enumerate(d.items()) if i == filt + } + + filt_str = lambda d, filt: { + k: v for k, v in image_ids.items() if re.search(filt, k) + } + + def pick_filter(image_ids, filt): + if isinstance(filt, str): + image_ids = filt_str(image_ids, filt) + elif isinstance(filt, int): + image_ids = filt_int(image_ids, filt) + return image_ids + + if isinstance(pos_filter, list): image_ids = { - k: v for i, (k, v) in enumerate(image_ids.items()) if i == strain_filter + k: v + for filt in pos_filter + for k, v in pick_filter(image_ids, filt).items() } + else: + image_ids = pick_filter(image_ids, pos_filter) assert len(image_ids), "No images to segment" + if distributed != 0: # Gives the number of simultaneous processes with Pool(distributed) as p: results = p.map(lambda x: self.create_pipeline(x), image_ids.items())