From a6f5b3bdb814181ba5ba0a00fe247fa0346c1b38 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Wed, 29 Jun 2022 18:31:50 +0100
Subject: [PATCH] feat(imageviewer): add visualisation options

---
 aliby/utils/imageViewer.py | 59 ++++++++++++++++++++++++++++----------
 1 file changed, 44 insertions(+), 15 deletions(-)

diff --git a/aliby/utils/imageViewer.py b/aliby/utils/imageViewer.py
index 974180c1..5d7a5d44 100644
--- a/aliby/utils/imageViewer.py
+++ b/aliby/utils/imageViewer.py
@@ -357,23 +357,52 @@ class remoteImageViewer:
         )
         custom_imshow(
             tiled_imgs["cell_labels"],
-            cmap="Set1",
+            cmap=sns.color_palette("Paired", as_cmap=True),
             **lbl_plot_kwargs,
         )
-        plt.yticks(
-            ticks=[
-                (i * self.tiler.tile_size * nrows)
-                + self.tiler.tile_size * nrows / 2
-                for i in range(len(channels))
-            ],
-            labels=channel_labels,
-        )
-        plt.xticks(
-            ticks=[self.tiler.tile_size * (i + 0.5) for i in range(ncols)],
-            labels=["+ {} ".format(i) for i in range(ncols)],
-        )
-        plt.xlabel("Additional time-points")
-        plt.show()
+
+        if remove_axis == True:
+            plt.axis("off")
+        elif remove_axis == "x":
+            plt.tick_params(
+                axis="x",
+                which="both",
+                bottom=False,
+                top=False,
+                labelbottom=False,
+            )
+
+        if remove_axis != "True":
+            plt.yticks(
+                ticks=[
+                    (i * self.tiler.tile_size * nrows)
+                    + self.tiler.tile_size * nrows / 2
+                    for i in range(len(channels))
+                ],
+                labels=channel_labels,
+            )
+
+        if not remove_axis:
+            xlabels = (
+                ["+ {} ".format(i) for i in range(ncols)]
+                if nrows > 1
+                else list(trange)
+            )
+            plt.xlabel("Time-point")
+
+            plt.xticks(
+                ticks=[self.tiler.tile_size * (i + 0.5) for i in range(ncols)],
+                labels=xlabels,
+            )
+
+        if not np.any(out):
+            print("ImageViewer:Warning:No cell outlines found")
+
+        if savefile:
+            plt.savefig(savefile, bbox_inches="tight", dpi=300)
+            plt.close()
+        else:
+            plt.show()
 
     # def plot_labelled_trap(
     #     self,
-- 
GitLab