From fb9f017e78bc77a0a8f7908e7df894a4e4452c52 Mon Sep 17 00:00:00 2001
From: Peter Swain <peter.swain@ed.ac.uk>
Date: Wed, 25 Oct 2023 15:41:07 +0100
Subject: [PATCH] changed channels in tiler to be position dependent

---
 src/agora/abc.py        |  3 +--
 src/aliby/pipeline.py   | 37 +++++++++++++++++--------------------
 src/aliby/tile/tiler.py | 27 ++++++++++++++++++++++++---
 3 files changed, 42 insertions(+), 25 deletions(-)

diff --git a/src/agora/abc.py b/src/agora/abc.py
index 928c170..0aaefdf 100644
--- a/src/agora/abc.py
+++ b/src/agora/abc.py
@@ -35,7 +35,7 @@ class ParametersABC(ABC):
         """
         Return a nested dictionary of the attributes of the class instance.
 
-        Uses recursion.
+        Use recursion.
         """
         if isinstance(iterable, dict):
             if any(
@@ -115,7 +115,6 @@ class ParametersABC(ABC):
 
         If a leaf node that is to be changed is a collection, it adds the new elements.
         """
-
         assert name not in (
             "parameters",
             "params",
diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py
index ede77e7..c71fa23 100644
--- a/src/aliby/pipeline.py
+++ b/src/aliby/pipeline.py
@@ -162,20 +162,6 @@ class PipelineParameters(ParametersABC):
         return parsed_flattened
 
 
-def find_channels_by_group(meta):
-    """Parse meta data to find which imaging channels are used for each group."""
-    channels_dict = {group_no: [] for group_no in meta["positions/group"]}
-    imaging_channels = global_parameters.possible_imaging_channels
-    for i, group_no in enumerate(meta["positions/group"]):
-        for imaging_channel in imaging_channels:
-            if (
-                "positions/" + imaging_channel in meta
-                and meta["positions/" + imaging_channel][i]
-            ):
-                channels_dict[group_no].append(imaging_channel)
-    return channels_dict
-
-
 class Pipeline(ProcessABC):
     """
     Initialise and run tiling, segmentation, extraction and post-processing.
@@ -360,13 +346,16 @@ class Pipeline(ProcessABC):
             with Pool(distributed) as p:
                 results = p.map(
                     lambda x: self.run_one_position(*x),
-                    [(k, i) for i, k in enumerate(position_ids.items())],
+                    [
+                        (position_id, i)
+                        for i, position_id in enumerate(position_ids.items())
+                    ],
                 )
         else:
             # single core
             results = [
-                self.run_one_position((k, v), 1)
-                for k, v in tqdm(position_ids.items())
+                self.run_one_position((position_id, position_id_path), 1)
+                for position_id, position_id_path in tqdm(position_ids.items())
             ]
         return results
 
@@ -406,7 +395,7 @@ class Pipeline(ProcessABC):
         session = None
         run_kwargs = {"extraction": {"cell_labels": None, "masks": None}}
         try:
-            pipe, session = self.setup_pipeline(image_id)
+            pipe, session = self.setup_pipeline(image_id, name)
             loaded_writers = {
                 name: writer(pipe["filename"])
                 for k in self.step_sequence
@@ -426,6 +415,9 @@ class Pipeline(ProcessABC):
             ) as image:
                 # initialise steps
                 if "tiler" not in pipe["steps"]:
+                    pipe["config"]["tiler"]["position_name"] = name.split(".")[
+                        0
+                    ]
                     pipe["steps"]["tiler"] = Tiler.from_image(
                         image,
                         TilerParameters.from_dict(pipe["config"]["tiler"]),
@@ -545,7 +537,9 @@ class Pipeline(ProcessABC):
             close_session(session)
 
     def setup_pipeline(
-        self, image_id: int
+        self,
+        image_id: int,
+        name: str,
     ) -> t.Tuple[
         Path,
         MetaData,
@@ -626,8 +620,11 @@ class Pipeline(ProcessABC):
             if pipe["filename"].exists():
                 self._log("Result file exists.", "info")
                 if not overwrite["tiler"]:
+                    tiler_params_dict = TilerParameters.default().to_dict()
+                    tiler_params_dict["position_name"] = name.split(".")[0]
+                    tiler_params = TilerParameters.from_dict(tiler_params_dict)
                     pipe["steps"]["tiler"] = Tiler.from_h5(
-                        image, pipe["filename"]
+                        image, pipe["filename"], tiler_params
                     )
                     try:
                         (
diff --git a/src/aliby/tile/tiler.py b/src/aliby/tile/tiler.py
index f03d2e2..0d77b84 100644
--- a/src/aliby/tile/tiler.py
+++ b/src/aliby/tile/tiler.py
@@ -37,6 +37,7 @@ import h5py
 import numpy as np
 from skimage.registration import phase_cross_correlation
 
+import aliby.global_parameters as global_parameters
 from agora.abc import ParametersABC, StepABC
 from agora.io.writer import BridgeH5
 from aliby.tile.traps import segment_traps
@@ -214,9 +215,26 @@ class TilerParameters(ParametersABC):
         "ref_channel": "Brightfield",
         "ref_z": 0,
         "backup_ref_channel": None,
+        "position_name": None,
     }
 
 
+def find_channels_by_position(meta):
+    """Parse metadata to find the imaging channels used for each group."""
+    channels_dict = {
+        position_name: [] for position_name in meta["positions/posname"]
+    }
+    imaging_channels = meta["channels"]
+    for i, position_name in enumerate(meta["positions/posname"]):
+        for imaging_channel in imaging_channels:
+            if (
+                "positions/" + imaging_channel in meta
+                and meta["positions/" + imaging_channel][i]
+            ):
+                channels_dict[position_name].append(imaging_channel)
+    return channels_dict
+
+
 class Tiler(StepABC):
     """
     Divide images into smaller tiles for faster processing.
@@ -247,11 +265,14 @@ class Tiler(StepABC):
         """
         super().__init__(parameters)
         self.image = image
-        self._metadata = metadata
-        self.channels = metadata.get(
-            "channels",
+        self.position_name = parameters.to_dict()["position_name"]
+        # get channels for this position
+        channel_dict = find_channels_by_position(metadata)
+        self.channels = channel_dict.get(
+            self.position_name,
             list(range(metadata.get("size_c", 0))),
         )
+        # get reference channel - used for segmentation
         self.ref_channel = self.get_channel_index(parameters.ref_channel)
         if self.ref_channel is None:
             self.ref_channel = self.backup_ref_channel
-- 
GitLab