diff --git a/aliby/utils/imageViewer.py b/aliby/utils/imageViewer.py index cfdb25421683318859e496b31f6fa8f90934a942..f50f0185768da50eda435f61d2b30b7335011706 100644 --- a/aliby/utils/imageViewer.py +++ b/aliby/utils/imageViewer.py @@ -101,9 +101,12 @@ class remoteImageViewer: return channels - def get_trap_timepoints(self, trap_id, tps, channels=None, server_info=None): + def get_trap_timepoints( + self, trap_id, tps, channels=None, z=None, server_info=None + ): server_info = server_info or self.server_info channels = self.find_channels(channels) + z = z or self.tiler.ref_z ch_tps = set([(channels[0], tp) for tp in tps]) with OImage(self.image_id, **server_info) as image: @@ -111,9 +114,14 @@ class remoteImageViewer: if ch_tps.difference(self.full.keys()): tps = set(tps).difference(self.full.keys()) for ch, tp in ch_tps: - self.full[(ch, tp)] = self.tiler.get_traps_timepoint( - tp, channels=[ch], z=[0] - )[:, 0, 0, ..., 0] + if z is 0: + self.full[(ch, tp)] = self.tiler.get_traps_timepoint( + tp, channels=[ch], z=[z] + )[:, 0, 0, ..., 0] + else: + self.full[(ch, tp)] = self.tiler.get_traps_timepoint( + tp, channels=[ch], z=[z] + )[:, 0, 0, ..., z] requested_trap = {tp: self.full[(ch, tp)] for ch, tp in ch_tps} return requested_trap @@ -135,12 +143,25 @@ class remoteImageViewer: img_concat = np.concatenate(imgs_list, axis=1) return outline_concat, img_concat - def plot_labeled_channelrows(self, trap_id, channels, trange, **kwargs): + def get_images(self, trap_id, trange, channels, **kwargs): + """ + Wrapper to fetch images + """ imgs = {} + for ch in self.find_channels(channels): out, imgs[ch] = self.get_labeled_trap( trap_id, trange, channels=[ch], **kwargs ) + return out, imgs + + def plot_labeled_zstacks(self, trap_id, channels, trange, z=None, **kwargs): + # if z is None: + # z = + out, images = self.get_imgs(trap_id, trange, channels, z=z, **kwargs) + + def plot_labeled_channelrows(self, trap_id, channels, trange, **kwargs): + out, images = self.get_imgs(trap_id, trange, channels, **kwargs) # dilation makes outlines easier to see out = dilation(out).astype(float) @@ -171,7 +192,16 @@ class remoteImageViewer: ) plt.show() - def plot_labeled_traps(self, trap_id, trange, ncols, **kwargs): + def plot_labeled_traps( + self, + trap_id, + trange, + ncols, + remove_axis=False, + savefile=False, + skip_outlines=False, + **kwargs + ): """ Wrapper to plot a single trap over time @@ -202,18 +232,31 @@ class remoteImageViewer: ) plt.imshow( - self.concat_pad(img), + concat_pad(img), interpolation=None, cmap="Greys_r", ) - plt.imshow( - self.concat_pad(out), - # concat_pad(mask), - cmap="Set1", - interpolation=None, - ) - plt.yticks( - ticks=[self.tiler.tile_size * (i + 0.5) for i in range(nrows)], - labels=[trange[0] + ncols * i for i in range(nrows)], - ) - plt.show() + if not skip_outlines: + plt.imshow( + concat_pad(out), + # concat_pad(mask), + cmap="Set1", + interpolation=None, + ) + + bbox_inches = None + if remove_axis: + plt.axis("off") + bbox_inches = "tight" + + else: + plt.yticks( + ticks=[self.tiler.tile_size * (i + 0.5) for i in range(nrows)], + labels=[trange[0] + ncols * i for i in range(nrows)], + ) + + if not savefile: + plt.show() + else: + if np.any(out): + plt.savefig(savefile, bbox_inches=bbox_inches)