Skip to content
Snippets Groups Projects
Commit 404ba780 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

workaround for !24

parent 9674bd4f
No related branches found
No related tags found
No related merge requests found
...@@ -14,11 +14,11 @@ from aliby.tile.traps import segment_traps ...@@ -14,11 +14,11 @@ from aliby.tile.traps import segment_traps
class Trap: class Trap:
''' """
Stores a trap's location and size. Stores a trap's location and size.
Allows checks to see if the trap should be padded. Allows checks to see if the trap should be padded.
Can export the trap either in OMERO or numpy formats. Can export the trap either in OMERO or numpy formats.
''' """
def __init__(self, centre, parent, size, max_size): def __init__(self, centre, parent, size, max_size):
self.centre = centre self.centre = centre
...@@ -72,14 +72,15 @@ class Trap: ...@@ -72,14 +72,15 @@ class Trap:
x, y, w, h = self.as_tile(tp) x, y, w, h = self.as_tile(tp)
return slice(x, x + w), slice(y, y + h) return slice(x, x + w), slice(y, y + h)
### ###
class TrapLocations: class TrapLocations:
''' """
Stores each trap as an instance of Trap. Stores each trap as an instance of Trap.
Traps can be iterated. Traps can be iterated.
''' """
def __init__( def __init__(
self, initial_location, tile_size, max_size=1200, drifts=None self, initial_location, tile_size, max_size=1200, drifts=None
...@@ -105,22 +106,22 @@ class TrapLocations: ...@@ -105,22 +106,22 @@ class TrapLocations:
@property @property
def shape(self): def shape(self):
''' """
Returns no of traps and no of drifts Returns no of traps and no of drifts
''' """
return len(self.traps), len(self.drifts) return len(self.traps), len(self.drifts)
def padding_required(self, tp): def padding_required(self, tp):
''' """
Check if any traps need padding Check if any traps need padding
''' """
return any([trap.padding_required(tp) for trap in self.traps]) return any([trap.padding_required(tp) for trap in self.traps])
def to_dict(self, tp): def to_dict(self, tp):
''' """
Export inital locations, tile_size, max_size, and drifts Export inital locations, tile_size, max_size, and drifts
as a dictionary as a dictionary
''' """
res = dict() res = dict()
if tp == 0: if tp == 0:
res["trap_locations"] = self.initial_location res["trap_locations"] = self.initial_location
...@@ -133,16 +134,16 @@ class TrapLocations: ...@@ -133,16 +134,16 @@ class TrapLocations:
@classmethod @classmethod
def from_tiler_init(cls, initial_location, tile_size, max_size=1200): def from_tiler_init(cls, initial_location, tile_size, max_size=1200):
''' """
Instantiate class from an instance of the Tiler class Instantiate class from an instance of the Tiler class
''' """
return cls(initial_location, tile_size, max_size, drifts=[]) return cls(initial_location, tile_size, max_size, drifts=[])
@classmethod @classmethod
def read_hdf5(cls, file): def read_hdf5(cls, file):
''' """
Instantiate class from a hdf5 file Instantiate class from a hdf5 file
''' """
with h5py.File(file, "r") as hfile: with h5py.File(file, "r") as hfile:
trap_info = hfile["trap_info"] trap_info = hfile["trap_info"]
initial_locations = trap_info["trap_locations"][()] initial_locations = trap_info["trap_locations"][()]
...@@ -153,13 +154,12 @@ class TrapLocations: ...@@ -153,13 +154,12 @@ class TrapLocations:
trap_locs.drifts = drifts trap_locs.drifts = drifts
return trap_locs return trap_locs
### ###
class TilerParameters(ParametersABC): class TilerParameters(ParametersABC):
_defaults = {"tile_size": 117, _defaults = {"tile_size": 117, "ref_channel": "Brightfield", "ref_z": 0}
"ref_channel": "Brightfield",
"ref_z": 0}
#### ####
...@@ -227,8 +227,17 @@ class Tiler(ProcessABC): ...@@ -227,8 +227,17 @@ class Tiler(ProcessABC):
@lru_cache(maxsize=2) @lru_cache(maxsize=2)
def get_tc(self, t, c): def get_tc(self, t, c):
# Get image # Get image by forcing loading it into cache. Assumes TCZYX dimensional order.
full = self.image[t, c].compute() # FORCE THE CACHE # WORKADOUND around error (which arose on 2022/06/14) when fetching 3-D data.
full = np.stack(
[
self.image[t, c, z].compute()
for z in range(self.image.shape[2])
],
axis=0,
)
# full = self.image[t, c].compute()
return full return full
### ###
...@@ -280,9 +289,7 @@ class Tiler(ProcessABC): ...@@ -280,9 +289,7 @@ class Tiler(ProcessABC):
# max_size is the minimal no of x or y pixels # max_size is the minimal no of x or y pixels
max_size = min(self.image.shape[-2:]) max_size = min(self.image.shape[-2:])
# first time point, first channel, first z-position # first time point, first channel, first z-position
initial_image = self.image[ initial_image = self.image[0, self.ref_channel, self.ref_z]
0, self.ref_channel, self.ref_z
]
# find the traps # find the traps
trap_locs = segment_traps(initial_image, tile_size) trap_locs = segment_traps(initial_image, tile_size)
# keep only traps that are not near an edge # keep only traps that are not near an edge
...@@ -298,10 +305,10 @@ class Tiler(ProcessABC): ...@@ -298,10 +305,10 @@ class Tiler(ProcessABC):
### ###
def find_drift(self, tp): def find_drift(self, tp):
''' """
Find any translational drifts between two images at consecutive Find any translational drifts between two images at consecutive
time points using cross correlation time points using cross correlation
''' """
# TODO check that the drift doesn't move any tiles out of # TODO check that the drift doesn't move any tiles out of
# the image, remove them from list if so # the image, remove them from list if so
prev_tp = max(0, tp - 1) prev_tp = max(0, tp - 1)
...@@ -339,11 +346,11 @@ class Tiler(ProcessABC): ...@@ -339,11 +346,11 @@ class Tiler(ProcessABC):
### ###
def run_tp(self, tp): def run_tp(self, tp):
''' """
Find traps if they have not yet been found. Find traps if they have not yet been found.
Determine any translational drift of the current image from the Determine any translational drift of the current image from the
previous one. previous one.
''' """
# assert tp >= self.n_processed, "Time point already processed" # assert tp >= self.n_processed, "Time point already processed"
# TODO check contiguity? # TODO check contiguity?
if self.n_processed == 0 or not hasattr(self.trap_locs, "drifts"): if self.n_processed == 0 or not hasattr(self.trap_locs, "drifts"):
...@@ -361,9 +368,9 @@ class Tiler(ProcessABC): ...@@ -361,9 +368,9 @@ class Tiler(ProcessABC):
return self.trap_locs.to_dict(tp) return self.trap_locs.to_dict(tp)
def run(self): def run(self):
''' """
Tile all time points in an experiment at once. Tile all time points in an experiment at once.
''' """
raise NotImplementedError() raise NotImplementedError()
### ###
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment