Skip to content
Snippets Groups Projects
Commit 65c46d55 authored by pswain's avatar pswain
Browse files

added assert about default channel

parent 816e53d8
No related branches found
No related tags found
No related merge requests found
"""Segment/segmented pipelines. """
Segment/segmented pipelines.
Includes splitting the image into traps/parts, Includes splitting the image into traps/parts,
cell segmentation, nucleus segmentation.""" cell segmentation, nucleus segmentation
Standard order is (T, C, Z, Y, X)
"""
import warnings import warnings
from functools import lru_cache from functools import lru_cache
import h5py import h5py
...@@ -244,6 +248,9 @@ class Tiler(ProcessABC): ...@@ -244,6 +248,9 @@ class Tiler(ProcessABC):
super().__init__(parameters) super().__init__(parameters)
self.image = image self.image = image
self.channels = metadata["channels"] self.channels = metadata["channels"]
assert (
parameters.ref_channel in self.channels
), "Reference channel not in the available channels"
self.ref_channel = self.get_channel_index(parameters.ref_channel) self.ref_channel = self.get_channel_index(parameters.ref_channel)
self.trap_locs = trap_locs self.trap_locs = trap_locs
try: try:
...@@ -340,9 +347,7 @@ class Tiler(ProcessABC): ...@@ -340,9 +347,7 @@ class Tiler(ProcessABC):
except: except:
print( print(
"Warning: Error ocurred when fetching " "Warning: Error ocurred when fetching "
"images. Attempt {}".format( "images. Attempt {}".format(n_attempts + 1)
n_attempts + 1
)
) )
self.image.conn.connect() self.image.conn.connect()
n_attempts += 1 n_attempts += 1
...@@ -559,7 +564,6 @@ class Tiler(ProcessABC): ...@@ -559,7 +564,6 @@ class Tiler(ProcessABC):
for i, ch in enumerate(self.channels): for i, ch in enumerate(self.channels):
if item in ch: if item in ch:
return i return i
raise Exception(item + " not found. Check parameters sent to Tiler.")
### ###
......
...@@ -300,142 +300,6 @@ def get_xy_tile(img, xmin, xmax, ymin, ymax, xidx=2, yidx=3, pad_val=None): ...@@ -300,142 +300,6 @@ def get_xy_tile(img, xmin, xmax, ymin, ymax, xidx=2, yidx=3, pad_val=None):
return tile return tile
def get_trap_timelapse(
raw_expt, trap_locations, trap_id, tile_size=117, channels=None, z=None
):
"""
Get a timelapse for a given trap by specifying the trap_id
:param trap_id: An integer defining which trap to choose. Counted
between 0 and Tiler.n_traps - 1
:param tile_size: The size of the trap tile (centered around the
trap as much as possible, edge cases exist)
:param channels: Which channels to fetch, indexed from 0.
If None, defaults to [0]
:param z: Which z_stacks to fetch, indexed from 0.
If None, defaults to [0].
:return: A numpy array with the timelapse in (C,T,X,Y,Z) order
"""
# Set the defaults (list is mutable)
channels = channels if channels is not None else [0]
z = z if z is not None else [0]
# Get trap location for that id:
trap_centers = [
trap_locations[i][trap_id] for i in range(len(trap_locations))
]
max_shape = (raw_expt.shape[2], raw_expt.shape[3])
tiles_shapes = [
get_tile_shapes((x[0], x[1]), tile_size, max_shape)
for x in trap_centers
]
timelapse = [
get_xy_tile(
raw_expt[channels, i, :, :, z],
xmin,
xmax,
ymin,
ymax,
pad_val=None,
)
for i, (xmin, xmax, ymin, ymax) in enumerate(tiles_shapes)
]
return np.hstack(timelapse)
def get_trap_timelapse_omero(
raw_expt,
trap_locations,
trap_id,
tile_size=117,
channels=None,
z=None,
t=None,
):
"""
Get a timelapse for a given trap by specifying the trap_id
:param raw_expt: A Timelapse object from which data is obtained
:param trap_id: An integer defining which trap to choose. Counted
between 0 and Tiler.n_traps - 1
:param tile_size: The size of the trap tile (centered around the
trap as much as possible, edge cases exist)
:param channels: Which channels to fetch, indexed from 0.
If None, defaults to [0]
:param z: Which z_stacks to fetch, indexed from 0.
If None, defaults to [0].
:return: A numpy array with the timelapse in (C,T,X,Y,Z) order
"""
# Set the defaults (list is mutable)
channels = channels if channels is not None else [0]
z_positions = z if z is not None else [0]
times = (
t if t is not None else np.arange(raw_expt.shape[1])
) # TODO choose sub-set of time points
shape = (len(channels), len(times), tile_size, tile_size, len(z_positions))
# Get trap location for that id:
zct_tiles, slices, trap_ids = all_tiles(
trap_locations,
shape,
raw_expt,
z_positions,
channels,
times,
[trap_id],
)
# TODO Make this an explicit function in TimelapseOMERO
images = raw_expt.pixels.getTiles(zct_tiles)
timelapse = np.full(shape, np.nan)
total = len(zct_tiles)
for (z, c, t, _), (y, x), image in tqdm(
zip(zct_tiles, slices, images), total=total
):
ch = channels.index(c)
tp = times.tolist().index(t)
z_pos = z_positions.index(z)
timelapse[ch, tp, x[0]: x[1], y[0]: y[1], z_pos] = image
# for x in timelapse: # By channel
# np.nan_to_num(x, nan=np.nanmedian(x), copy=False)
return timelapse
def all_tiles(
trap_locations, shape, raw_expt, z_positions, channels, times, traps
):
_, _, x, y, _ = shape
_, _, MAX_X, MAX_Y, _ = raw_expt.shape
trap_ids = []
zct_tiles = []
slices = []
for z in z_positions:
for ch in channels:
for t in times:
for trap_id in traps:
centre = trap_locations[t][trap_id]
(
xmin,
ymin,
xmax,
ymax,
r_xmin,
r_ymin,
r_xmax,
r_ymax,
) = tile_where(centre, x, y, MAX_X, MAX_Y)
slices.append(
(
(r_ymin - ymin, r_ymax - ymin),
(r_xmin - xmin, r_xmax - xmin),
)
)
tile = (r_ymin, r_xmin, r_ymax - r_ymin, r_xmax - r_xmin)
zct_tiles.append((z, ch, t, tile))
trap_ids.append(trap_id) # So we remember the order!
return zct_tiles, slices, trap_ids
def tile_where(centre, x, y, MAX_X, MAX_Y): def tile_where(centre, x, y, MAX_X, MAX_Y):
# Find the position of the tile # Find the position of the tile
xmin = int(centre[1] - x // 2) xmin = int(centre[1] - x // 2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment