From 55b2787d4ba8c01d74a9683d7292490f4d459e84 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Mon, 27 Feb 2023 10:41:01 +0000
Subject: [PATCH] change(cells): stack cellmasks in dim 0

---
 src/agora/io/cells.py            | 2 +-
 src/aliby/tile/traps.py          | 4 ++--
 src/aliby/utils/imageViewer.py   | 4 ++--
 src/extraction/core/extractor.py | 4 ++--
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py
index e72d32f4..182590b7 100644
--- a/src/agora/io/cells.py
+++ b/src/agora/io/cells.py
@@ -165,7 +165,7 @@ class Cells:
     def at_times(self, timepoints: t.Iterable[int], kind="mask"):
         return [
             [
-                np.dstack(tile_masks) if len(tile_masks) else []
+                np.stack(tile_masks) if len(tile_masks) else []
                 for tile_masks in self.at_time(tp, kind=kind).values()
             ]
             for tp in timepoints
diff --git a/src/aliby/tile/traps.py b/src/aliby/tile/traps.py
index 4eddeb7e..17772892 100644
--- a/src/aliby/tile/traps.py
+++ b/src/aliby/tile/traps.py
@@ -120,7 +120,7 @@ def segment_traps(
         for x, y in centroids
     ]
     # make a mean template from all the found traps
-    mean_template = np.dstack(candidate_templates).astype(int).mean(axis=-1)
+    mean_template = np.stack(candidate_templates).astype(int).mean(axis=0)
 
     # find traps using the mean trap template
     traps = identify_trap_locations(
@@ -243,4 +243,4 @@ def stretch_image(image):
     maxval = np.percentile(image, 98)
     image = np.clip(image, minval, maxval)
     image = (image - minval) / (maxval - minval)
-    return image
\ No newline at end of file
+    return image
diff --git a/src/aliby/utils/imageViewer.py b/src/aliby/utils/imageViewer.py
index 54392a09..b08ebb33 100644
--- a/src/aliby/utils/imageViewer.py
+++ b/src/aliby/utils/imageViewer.py
@@ -264,8 +264,8 @@ class remoteImageViewer(BaseImageViewer):
         ]
         lbls = [self.cells.labels_at_time(tp).get(trap_id, []) for tp in tps]
         lbld_outlines = [
-            np.dstack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max(
-                axis=2
+            np.stack([mask * lbl for mask, lbl in zip(maskset, lblset)]).max(
+                axis=0
             )
             if len(lblset)
             else np.zeros_like(imgs_list[0]).astype(bool)
diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py
index eee3185a..61f21e5d 100644
--- a/src/extraction/core/extractor.py
+++ b/src/extraction/core/extractor.py
@@ -479,7 +479,7 @@ class Extractor(StepABC):
             masks = {trap_id: [] for trap_id in range(cells.ntraps)}
             for trap_id, cells in raw_masks.items():
                 if len(cells):
-                    masks[trap_id] = np.dstack(np.array(cells)).astype(bool)
+                    masks[trap_id] = np.stack(np.array(cells)).astype(bool)
         # convert to a list of masks
         masks = [np.array(v) for v in masks.values()]
 
@@ -491,7 +491,7 @@ class Extractor(StepABC):
         bgs = []
         if self.params.sub_bg:
             bgs = [
-                ~np.sum(m, axis=2).astype(bool)
+                ~np.sum(m, axis=0).astype(bool)
                 if np.any(m)
                 else np.zeros((tile_size, tile_size))
                 for m in masks
-- 
GitLab