Skip to content
Snippets Groups Projects
Commit b5483679 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

update imageviewer, add example and plot method

parent 902a4664
No related branches found
No related tags found
No related merge requests found
"""This is an example module to show the structure.""" """
from typing import Union ImageViewer class, used to look at individual or multiple traps over time.
Example of usage:
fpath = "/home/alan/Documents/dev/skeletons/scripts/data/16543_2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_01/URA8_young018.h5"
trap_id = 9
trange = list(range(0, 30))
ncols = 8
riv = remoteImageViewer(fpath)
riv.plot_labeled_traps(trap_id, trange, ncols)
"""
import yaml
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from skimage.morphology import dilation
from agora.io.cells import CellsLinear as Cells
from agora.io.writer import load_attributes
from aliby.io.omero import Image as OImage
from aliby.tile.tiler import Tiler
import matplotlib.pyplot as plt
class localImageViewer: class localImageViewer:
...@@ -33,50 +56,27 @@ class localImageViewer: ...@@ -33,50 +56,27 @@ class localImageViewer:
Image.fromarray(pixvals.astype(np.uint8)) Image.fromarray(pixvals.astype(np.uint8))
from aliby.tile.tiler import Tiler, TilerParameters, TrapLocations class remoteImageViewer:
from agora.io.writer import load_attributes def __init__(self, fpath, server_info=None):
attrs = load_attributes(fpath)
import json
with open("/home/alan/Documents/dev/skeletons/server_info.json", "r") as f:
# json.dump(
# {
# "host": "islay.bio.ed.ac.uk",
# "username": "upload",
# "password": "",
# },
# f,
# )
server_info = json.load(f)
import h5py
from aliby.io.omero import Image
from agora.io.cells import CellsLinear as Cells self.image_id = attrs.get("image_id")
assert self.image_id is not None, "No valid image_id found in metadata"
if server_info is None:
server_info = yaml.safe_load(attrs["parameters"])["general"]["server_info"]
self.server_info = server_info
class remoteImageViewer: with OImage(self.image_id, **self.server_info) as image:
def __init__(self, fpath):
with h5py.File(fpath, "r") as f:
self.image_id = f.attrs.get("image_id", None) or 105146
# trap_locs = TrapLocations.from_source(fpath)
with Image(self.image_id, **server_info) as image:
self.tiler = Tiler.from_hdf5(image, fpath) self.tiler = Tiler.from_hdf5(image, fpath)
self.cells = Cells.from_source(fpath)
# if parameters is None:
# parameters = TilerParameters.default()
# with h5py.File(hdf, "r") as f: self.cells = Cells.from_source(fpath)
# # image_id = f.attrs["omero_id"]
# image_id = 16543
def get_position(self): def get_position(self):
pass raise (NotImplementedError)
def get_position_timelapse(self): def get_position_timelapse(self):
pass raise (NotImplementedError)
@property @property
def full(self): def full(self):
...@@ -84,13 +84,17 @@ class remoteImageViewer: ...@@ -84,13 +84,17 @@ class remoteImageViewer:
self._full = {} self._full = {}
return self._full return self._full
def get_tc(self, tp): def get_tc(self, tp, server_info=None):
with Image(self.image_id, **server_info) as image: server_info = server_info or self.server_info
with OImage(self.image_id, **server_info) as image:
self.tiler.image = image.data self.tiler.image = image.data
return self.tiler.get_tc(tp, riv.tiler.ref_channel) return self.tiler.get_tc(tp, riv.tiler.ref_channel)
def get_trap_timepoints(self, trap_id, tps): def get_trap_timepoints(self, trap_id, tps, server_info=None):
with Image(self.image_id, **server_info) as image: server_info = server_info or self.server_info
with OImage(self.image_id, **server_info) as image:
self.tiler.image = image.data self.tiler.image = image.data
if set(tps).difference(self.full.keys()): if set(tps).difference(self.full.keys()):
tps = set(tps).difference(self.full.keys()) tps = set(tps).difference(self.full.keys())
...@@ -119,30 +123,45 @@ class remoteImageViewer: ...@@ -119,30 +123,45 @@ class remoteImageViewer:
img_concat = np.concatenate(imgs_list, axis=1) img_concat = np.concatenate(imgs_list, axis=1)
return outline_concat, img_concat return outline_concat, img_concat
def plot_labeled_traps(self, trap_id, trange, ncols, **kwargs):
"""
Wrapper to plot a single trap over time
import matplotlib.pyplot as plt Parameters
---------
# fpath = "/home/alan/Documents/dev/skeletons/data/2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_01/2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_01/URA8_old007.h5" :trap_id: int trap identification
fpath = "/home/alan/Documents/dev/skeletons/data/2021_11_01_01_Raf_00/2021_11_01_01_Raf_00/d1134002.h5" :trange: list list of time points to fetch
riv = remoteImageViewer(fpath) """
# pos = riv.get_tc(0) nrows = len(trange) // ncols
out, img = riv.get_labeled_trap(9, list(range(0, 30))) width = riv.tiler.tile_size * ncols
out_bak = out out, img = self.get_labeled_trap(trap_id, trange)
out = dilation(out).astype(float)
out[out == 0] = np.nan # dilation makes outlines easier to see
plt.imshow( out = dilation(out).astype(float)
np.concatenate(np.array_split(img, 6, axis=1)), out[out == 0] = np.nan
interpolation=None,
cmap="Greys_r", def concat_pad(array):
) return np.concatenate(
plt.imshow( np.array_split(
np.concatenate(np.array_split(out, 6, axis=1)), np.pad(
cmap="Set1", array,
interpolation=None, ((0, 0), (0, array.shape[1] % width)),
) constant_values=np.nan,
plt.show() ),
nrows,
concat = lambda a: np.concatenate([x for x in a]) axis=1,
add = lambda a: np.sum(a, axis=0) )
# plt.imshow(add(roll(tmp[0], 10), np.roll(roll(tmp[1], 11), 6, axis=0))) )
# plt.show()
plt.imshow(
concat_pad(img),
interpolation=None,
cmap="Greys_r",
)
plt.imshow(
concat_pad(out),
# concat_pad(mask),
cmap="Set1",
interpolation=None,
)
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment