Skip to content
Snippets Groups Projects
Commit 1d82fe11 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

add static typing and clean separators

parent b0906919
No related branches found
No related tags found
No related merge requests found
......@@ -30,8 +30,6 @@ class Trap:
self.half_size = size // 2
self.max_size = max_size
###
def padding_required(self, tp):
"""
Check if we need to pad the trap image for this time point.
......@@ -51,8 +49,6 @@ class Trap:
drifts = self.parent.drifts
return self.centre - np.sum(drifts[: tp + 1], axis=0)
###
def as_tile(self, tp):
"""
Return trap in the OMERO tile format of x, y, w, h
......@@ -65,8 +61,6 @@ class Trap:
y = int(y - self.half_size)
return x, y, self.size, self.size
###
def as_range(self, tp):
"""
Return trap in a range format: two slice objects that can
......@@ -76,9 +70,6 @@ class Trap:
return slice(x, x + w), slice(y, y + h)
###
class TrapLocations:
"""
Stores each trap as an instance of Trap.
......@@ -105,8 +96,6 @@ class TrapLocations:
def __iter__(self):
yield from self.traps
###
@property
def shape(self):
"""
......@@ -133,8 +122,6 @@ class TrapLocations:
res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
return res
###
@classmethod
def from_tiler_init(cls, initial_location, tile_size, max_size=1200):
"""
......@@ -158,16 +145,10 @@ class TrapLocations:
return trap_locs
###
class TilerParameters(ParametersABC):
_defaults = {"tile_size": 117, "ref_channel": "Brightfield", "ref_z": 0}
####
class Tiler(ProcessABC):
"""
Remote Timelapse Tiler.
......@@ -200,14 +181,17 @@ class Tiler(ProcessABC):
except Exception as e:
print(f"Warning:Tiler: No z_perchannel data: {e}")
###
@classmethod
def from_image(cls, image: Image, parameters: TilerParameters):
return cls(image.data, image.metadata, parameters)
@classmethod
def from_hdf5(cls, image: Image, filepath, parameters=None):
def from_hdf5(
cls,
image: Union[Image, ImageLocal],
filepath: Union[str, PosixPath],
parameters: TilerParameters = None,
):
trap_locs = TrapLocations.read_hdf5(filepath)
metadata = load_attributes(filepath)
metadata["channels"] = image.metadata["channels"]
......@@ -227,16 +211,12 @@ class Tiler(ProcessABC):
tiler.n_processed = len(trap_locs.drifts)
return tiler
###
@lru_cache(maxsize=2)
def get_tc(self, t, c):
full = self.image[t, c].compute(scheduler="synchronous")
return full
###
@property
def shape(self):
"""
......@@ -264,16 +244,7 @@ class Tiler(ProcessABC):
def n_traps(self):
return len(self.trap_locs)
@property
def finished(self):
"""
Returns True if all channels have been processed
"""
return self.n_processed == self.image.shape[0]
###
def _initialise_traps(self, tile_size):
def _initialise_traps(self, tile_size: int):
"""
Find initial trap positions.
......@@ -297,15 +268,11 @@ class Tiler(ProcessABC):
# store traps in an instance of TrapLocations
self.trap_locs = TrapLocations.from_tiler_init(trap_locs, tile_size)
###
def find_drift(self, tp):
"""
Find any translational drifts between two images at consecutive
time points using cross correlation
"""
# TODO check that the drift doesn't move any tiles out of
# the image, remove them from list if so
prev_tp = max(0, tp - 1)
# cross-correlate
drift, error, _ = phase_cross_correlation(
......@@ -318,8 +285,6 @@ class Tiler(ProcessABC):
else:
self.trap_locs.drifts.append(drift.tolist())
###
def get_tp_data(self, tp, c):
traps = []
full = self.get_tc(tp, c)
......@@ -330,16 +295,12 @@ class Tiler(ProcessABC):
traps.append(ndtrap)
return np.stack(traps)
###
def get_trap_data(self, trap_id, tp, c):
full = self.get_tc(tp, c)
trap = self.trap_locs.traps[trap_id]
ndtrap = self.ifoob_pad(full, trap.as_range(tp))
return ndtrap
###
def run_tp(self, tp):
"""
Find traps if they have not yet been found.
......@@ -373,8 +334,6 @@ class Tiler(ProcessABC):
return None
###
# The next set of functions are necessary for the extraction object
def get_traps_timepoint(self, tp, tile_size=None, channels=None, z=None):
# FIXME we currently ignore the tile size
......@@ -390,21 +349,15 @@ class Tiler(ProcessABC):
res.append(val)
return np.stack(res, axis=1)
###
def get_channel_index(self, item):
for i, ch in enumerate(self.channels):
if item in ch:
return i
###
def get_position_annotation(self):
# TODO required for matlab support
return None
###
@staticmethod
def ifoob_pad(full, slices): # TODO Remove when inheriting TilerABC
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment