From 4dd9423f628b7d7e0fa39626951bd3e1be3052fb Mon Sep 17 00:00:00 2001
From: pswain <peter.swain@ed.ac.uk>
Date: Thu, 23 May 2024 15:26:11 +0100
Subject: [PATCH] fix(imageviewer): checks correctly for empty traps

---
 src/wela/imageviewer.py | 37 +++++++++++++++++++------------------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/src/wela/imageviewer.py b/src/wela/imageviewer.py
index 5a94cb7..84be7cf 100644
--- a/src/wela/imageviewer.py
+++ b/src/wela/imageviewer.py
@@ -24,7 +24,7 @@ def colormap(channel):
 
 
 class BaseImageViewer(ABC):
-    """Base class with routines common to all ImageViewers."""
+    """Base class for all ImageViewers."""
 
     def __init__(self, h5file_path):
         """Initialise from a Path to a h5 file."""
@@ -32,13 +32,18 @@ class BaseImageViewer(ABC):
         print(f"Viewing {str(h5file_path)}")
         self.full = {}
 
-    def print_trap_info(self, cells):
-        """List available traps - those with identified cells."""
-        traps_with_labels = [
-            i for i, labels in enumerate(cells.labels) if labels
-        ]
-        print(f"Traps with labelled cells {traps_with_labels}.")
-        print(f"Maximum number of time points {cells.ntimepoints}.")
+    def find_traps_with_cells(self, tpt_end, tpt_start=0):
+        """List traps with cells."""
+        cells = self.cells
+        tpts = range(tpt_start, tpt_end)
+        traps_with_cells = []
+        for trap_id in cells.traps:
+            for tpt in tpts:
+                if tpt in cells.nonempty_tp_in_trap(trap_id):
+                    traps_with_cells.append(trap_id)
+                    break
+        print(f"Traps with cells {traps_with_cells}")
+        return traps_with_cells
 
     def get_tiles(self, trap_id, tps, channels_to_skip=None, cell_only=True):
         """Get dict of tiles with channel indices as keys."""
@@ -148,7 +153,7 @@ class BaseImageViewer(ABC):
             ts_labels[tp_index, 0, ...] = outlines[tp_index]
         return ts_images, ts_labels, channels
 
-    def view(self, trap_id, tps=10, channels_to_skip=None):
+    def view(self, trap_id, tpt_end=10, tpt_start=0, channels_to_skip=None):
         """
         Use Napari to view all channels and outlines for a particular trap.
 
@@ -164,10 +169,7 @@ class BaseImageViewer(ABC):
             If None, all time points will be viewed, but gathering the images
             will be slow.
         """
-        if tps is None:
-            tps = np.arange(self.cells.ntimepoints)
-        elif type(tps) is int:
-            tps = np.arange(tps)
+        tps = np.arange(tpt_start, tpt_end + 1)
         ts_images, ts_labels, channels = self.get_data_for_viewing(
             trap_id, tps, channels_to_skip
         )
@@ -194,7 +196,7 @@ class LocalImageViewer(BaseImageViewer):
     """
     View images from local files.
 
-    File are either zarr or organised in directories.
+    Files are either zarr or organised in directories.
     """
 
     def __init__(self, h5file: str, image_file: str):
@@ -206,7 +208,6 @@ class LocalImageViewer(BaseImageViewer):
             with dispatch_image(image_file_path)(image_file_path) as image:
                 self.tiler = Tiler.from_h5(image, h5file_path)
             self.cells = Cells.from_source(h5file_path)
-            self.print_trap_info(self.cells)
         else:
             if not h5file_path.exists():
                 print(f" Trouble loading {h5file}.")
@@ -215,7 +216,7 @@ class LocalImageViewer(BaseImageViewer):
 
 
 class RemoteImageViewer(BaseImageViewer):
-    """Fetching remote images with tiling and outline display."""
+    """View images from OMERO."""
 
     def __init__(
         self, h5file: str, server_info: t.Dict[str, str], omero_id: int
@@ -234,12 +235,12 @@ class RemoteImageViewer(BaseImageViewer):
         if image_id is None:
             print("Can't find an image.")
         else:
+            print(f"Using image ID {image_id}.")
             self.image_id = image_id
             image = OImage(image_id, **server_info)
             print("Connected to OMERO.")
             self.tiler = Tiler.from_h5(image, h5file_path)
             self.cells = Cells.from_source(h5file_path)
-            self.print_trap_info(self.cells)
 
 
 def get_files(
@@ -253,4 +254,4 @@ def get_files(
     h5file = [f for f in h5files if position in f][0]
     image_file_name = h5file.split("/")[-1].split(".")[0] + ".zarr"
     image_file = str(Path(aliby_input) / omero_name / image_file_name)
-    return [h5file, image_file]
+    return {"h5file": h5file, "image_file": image_file}
-- 
GitLab