From d65cf0c64b5656f20017a5f74aad81cc028a4ea6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk> Date: Fri, 7 Jan 2022 18:15:24 +0000 Subject: [PATCH] move tiler to agora --- aliby/baby_client.py | 2 +- aliby/pipeline.py | 4 +- aliby/segment.py | 333 ----------------------------------------- tests/test_pipeline.py | 92 +++++++----- tests/test_segment.py | 7 +- tests/test_traps.py | 14 +- 6 files changed, 70 insertions(+), 382 deletions(-) delete mode 100644 aliby/segment.py diff --git a/aliby/baby_client.py b/aliby/baby_client.py index 4dd61079..b3a6ca3d 100644 --- a/aliby/baby_client.py +++ b/aliby/baby_client.py @@ -13,7 +13,7 @@ import requests import tensorflow as tf from tqdm import tqdm -from agora.base import ParametersABC, ProcessABC +from agora.abc import ParametersABC, ProcessABC import baby.errors from baby import modelsets from baby.brain import BabyBrain diff --git a/aliby/pipeline.py b/aliby/pipeline.py index 8a4f94c2..591dbd95 100644 --- a/aliby/pipeline.py +++ b/aliby/pipeline.py @@ -20,11 +20,11 @@ from pathos.multiprocessing import Pool from aliby.experiment import MetaData from aliby.haystack import initialise_tf from aliby.baby_client import BabyRunner, BabyParameters -from aliby.segment import Tiler, TilerParameters -from argo.io.omero import Dataset, Image +from agora.tile.tiler import Tiler, TilerParameters from agora.abc import ParametersABC, ProcessABC from agora.io.writer import TilerWriter, BabyWriter from agora.io.signal import Signal +from argo.io.omero import Dataset, Image from extraction.core.extractor import Extractor, ExtractorParameters from extraction.core.functions.defaults import exparams_from_meta from postprocessor.core.processor import PostProcessor, PostProcessorParameters diff --git a/aliby/segment.py b/aliby/segment.py deleted file mode 100644 index 5c7612ca..00000000 --- a/aliby/segment.py +++ /dev/null @@ -1,333 +0,0 @@ -"""Segment/segmented pipelines. -Includes splitting the image into traps/parts, -cell segmentation, nucleus segmentation.""" -import warnings -from functools import lru_cache - -import h5py -import numpy as np - -from pathlib import Path, PosixPath - -from skimage.registration import phase_cross_correlation - -from agora.base import ParametersABC, ProcessABC -from aliby.traps import segment_traps - -from agora.io.writer import load_attributes - -trap_template_directory = Path(__file__).parent / "trap_templates" -# TODO do we need multiple templates, one for each setup? -trap_template = np.array([]) # np.load(trap_template_directory / "trap_prime.npy") - - -def get_tile_shapes(x, tile_size, max_shape): - half_size = tile_size // 2 - xmin = int(x[0] - half_size) - ymin = max(0, int(x[1] - half_size)) - if xmin + tile_size > max_shape[0]: - xmin = max_shape[0] - tile_size - if ymin + tile_size > max_shape[1]: - ymin = max_shape[1] - tile_size - return xmin, xmin + tile_size, ymin, ymin + tile_size - - -###################### Dask versions ######################## -class Trap: - def __init__(self, centre, parent, size, max_size): - self.centre = centre - self.parent = parent # Used to access drifts - self.size = size - self.half_size = size // 2 - self.max_size = max_size - - def padding_required(self, tp): - """Check if we need to pad the trap image for this time point.""" - try: - assert all(self.at_time(tp) - self.half_size >= 0) - assert all(self.at_time(tp) + self.half_size <= self.max_size) - except AssertionError: - return True - return False - - def at_time(self, tp): - """Return trap centre at time tp""" - drifts = self.parent.drifts - return self.centre - np.sum(drifts[:tp], axis=0) - - def as_tile(self, tp): - """Return trap in the OMERO tile format of x, y, w, h - - Also returns the padding necessary for this tile. - """ - x, y = self.at_time(tp) - # tile bottom corner - x = int(x - self.half_size) - y = int(y - self.half_size) - return x, y, self.size, self.size - - def as_range(self, tp): - """Return trap in a range format, two slice objects that can be used in Arrays""" - x, y, w, h = self.as_tile(tp) - return slice(x, x + w), slice(y, y + h) - - -class TrapLocations: - def __init__(self, initial_location, tile_size, max_size=1200, drifts=[]): - self.tile_size = tile_size - self.max_size = max_size - self.initial_location = initial_location - self.traps = [ - Trap(centre, self, tile_size, max_size) for centre in initial_location - ] - self.drifts = drifts - - @classmethod - def from_source(cls, fpath: str): - with h5py.File(fpath, "r") as f: - # TODO read tile size from file metadata - drifts = f["trap_info/drifts"][()] - tlocs = cls(f["trap_info/trap_locations"][()], tile_size=96, drifts=drifts) - - return tlocs - - @property - def shape(self): - return len(self.traps), len(self.drifts) - - def __len__(self): - return len(self.traps) - - def __iter__(self): - yield from self.traps - - def padding_required(self, tp): - return any([trap.padding_required(tp) for trap in self.traps]) - - def to_dict(self, tp): - res = dict() - if tp == 0: - res["trap_locations"] = self.initial_location - res["attrs/tile_size"] = self.tile_size - res["attrs/max_size"] = self.max_size - res["drifts"] = np.expand_dims(self.drifts[tp], axis=0) - # res['processed_timepoints'] = tp - return res - - @classmethod - def read_hdf5(cls, file): - with h5py.File(file, "r") as hfile: - trap_info = hfile["trap_info"] - initial_locations = trap_info["trap_locations"][()] - drifts = trap_info["drifts"][()] - max_size = trap_info.attrs["max_size"] - tile_size = trap_info.attrs["tile_size"] - trap_locs = cls(initial_locations, tile_size, max_size=max_size) - trap_locs.drifts = drifts - return trap_locs - - -class TilerParameters(ParametersABC): - def __init__( - self, tile_size: int, ref_channel: str, ref_z: int, template_name: str = None - ): - self.tile_size = tile_size - self.ref_channel = ref_channel - self.ref_z = ref_z - self.template_name = template_name - - @classmethod - def from_template(cls, template_name: str, ref_channel: str, ref_z: int): - return cls(template.shape[0], ref_channel, ref_z, template_path=template_name) - - @classmethod - def default(cls): - return cls(96, "Brightfield", 0) - - -class Tiler(ProcessABC): - """A dummy TimelapseTiler object fora Dask Demo. - - Does trap finding and image registration.""" - - def __init__( - self, - image, - metadata, - parameters: TilerParameters, - ): - super().__init__(parameters) - self.image = image - self.channels = metadata["channels"] - self.ref_channel = self.get_channel_index(parameters.ref_channel) - - @classmethod - def from_image(cls, image, parameters: TilerParameters): - return cls(image.data, image.metadata, parameters) - - @classmethod - def from_hdf5(cls, image, filepath, tile_size=None): - trap_locs = TrapLocations.read_hdf5(filepath) - metadata = load_attributes(filepath) - metadata["channels"] = metadata["channels/channel"].tolist() - if tile_size is None: - tile_size = trap_locs.tile_size - return Tiler( - image=image, - metadata=metadata, - template=None, - tile_size=tile_size, - trap_locs=trap_locs, - ) - - @lru_cache(maxsize=2) - def get_tc(self, t, c): - # Get image - full = self.image[t, c].compute() # FORCE THE CACHE - return full - - @property - def shape(self): - c, t, z, y, x = self.image.shape - return (c, t, x, y, z) - - @property - def n_processed(self): - if not hasattr(self, "_n_processed"): - self._n_processed = 0 - return self._n_processed - - @n_processed.setter - def n_processed(self, value): - self._n_processed = value - - @property - def n_traps(self): - return len(self.trap_locs) - - @property - def finished(self): - return self.n_processed == self.image.shape[0] - - def _initialise_traps(self, tile_size): - """Find initial trap positions. - - Removes all those that are too close to the edge so no padding is necessary. - """ - half_tile = tile_size // 2 - max_size = min(self.image.shape[-2:]) - initial_image = self.image[ - 0, self.ref_channel, self.ref_z - ] # First time point, first channel, first z-position - trap_locs = segment_traps(initial_image, tile_size) - trap_locs = [ - [x, y] - for x, y in trap_locs - if half_tile < x < max_size - half_tile - and half_tile < y < max_size - half_tile - ] - self.trap_locs = TrapLocations(trap_locs, tile_size) - - def find_drift(self, tp): - # TODO check that the drift doesn't move any tiles out of the image, remove them from list if so - prev_tp = max(0, tp - 1) - drift, error, _ = phase_cross_correlation( - self.image[prev_tp, self.ref_channel, self.ref_z], - self.image[tp, self.ref_channel, self.ref_z], - ) - self.trap_locs.drifts.append(drift) - - def get_tp_data(self, tp, c): - traps = [] - full = self.get_tc(tp, c) - # if self.trap_locs.padding_required(tp): - for trap in self.trap_locs: - ndtrap = self.ifoob_pad(full, trap.as_range(tp)) - - traps.append(ndtrap) - return np.stack(traps) - - def get_trap_data(self, trap_id, tp, c): - full = self.get_tc(tp, c) - trap = self.trap_locs.traps[trap_id] - ndtrap = self.ifoob_pad(full, trap.as_range(tp)) - - return ndtrap - - @staticmethod - def ifoob_pad(full, slices): - """ - Returns the slices padded if it is out of bounds - - Parameters: - ---------- - full: (zstacks, max_size, max_size) ndarray - Entire position with zstacks as first axis - slices: tuple of two slices - Each slice indicates an axis to index - - - Returns - Trap for given slices, padded with median if needed, or np.nan if the padding is too much - """ - max_size = full.shape[-1] - - y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices] - trap = full[:, y, x] - - padding = np.array( - [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices] - ) - if padding.any(): - tile_size = slices[0].stop - slices[0].start - if (padding > tile_size / 4).any(): - trap = np.full((full.shape[0], tile_size, tile_size), np.nan) - else: - - trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median") - - return trap - - def run_tp(self, tp): - assert tp >= self.n_processed, "Time point already processed" - # TODO check contiguity? - if self.n_processed == 0: - self._initialise_traps(self.tile_size) - self.find_drift(tp) # Get drift - # update n_processed - self.n_processed += 1 - # Return result for writer - return self.trap_locs.to_dict(tp) - - def run(self, tp): - if self.n_processed == 0: - self._initialise_traps(self.tile_size) - self.find_drift(tp) # Get drift - # update n_processed - self.n_processed += 1 - # Return result for writer - return self.trap_locs.to_dict(tp) - - # The next set of functions are necessary for the extraction object - def get_traps_timepoint(self, tp, tile_size=None, channels=None, z=None): - # FIXME we currently ignore the tile size - # FIXME can we ignore z(always give) - res = [] - for c in channels: - val = self.get_tp_data(tp, c)[:, z] # Only return requested z - # positions - # Starts at traps, z, y, x - # Turn to Trap, C, T, X, Y, Z order - val = val.swapaxes(1, 3).swapaxes(1, 2) - val = np.expand_dims(val, axis=1) - res.append(val) - return np.stack(res, axis=1) - - def get_channel_index(self, item): - for i, ch in enumerate(self.channels): - if item in ch: - return i - - def get_position_annotation(self): - # TODO required for matlab support - return None diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index b37c48ad..a89d4bd7 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -2,35 +2,44 @@ import os import unittest from pathlib import Path -from core.baby_client import BabyRunner -from core.experiment import ExperimentOMERO -from core.segment import Tiler -from core.pipeline import ExperimentLocal +from aliby.baby_client import BabyRunner +from aliby.experiment import ExperimentOMERO +from aliby.pipeline import ExperimentLocal +from agora.tile.tiler import Tiler class TestLocal(unittest.TestCase): def setUp(self) -> None: - self.root_dir = '/Users/s1893247/PhD/pipeline-core/data/glclvl_0' \ - '.1_mig1_msn2_maf1_sfp1_dot6_03' + self.root_dir = ( + "/Users/s1893247/PhD/pipeline-core/data/glclvl_0" + ".1_mig1_msn2_maf1_sfp1_dot6_03" + ) self.raw_expt = ExperimentLocal(self.root_dir, finished=True) self.tiler = Tiler(self.raw_expt, finished=False) - config = {"camera": "evolve", - "channel": "Brightfield", - "zoom": "60x", - "n_stacks": "5z", - "default_image_size": 80} + config = { + "camera": "evolve", + "channel": "Brightfield", + "zoom": "60x", + "n_stacks": "5z", + "default_image_size": 80, + } self.store = "test.hdf5" self.baby_runner = BabyRunner(self.tiler, **config) def test_local(self): - steps = [('pos001', 0), ('pos001', 1), ('pos001', 2), - ('pos001', 3), ('pos001', 4)] - - trap_store = self.root_dir + '/traps.csv' - drift_store = self.root_dir + '/drifts.csv' - baby_store = self.root_dir + '/baby.csv' + steps = [ + ("pos001", 0), + ("pos001", 1), + ("pos001", 2), + ("pos001", 3), + ("pos001", 4), + ] + + trap_store = self.root_dir + "/traps.csv" + drift_store = self.root_dir + "/drifts.csv" + baby_store = self.root_dir + "/baby.csv" self.raw_expt.run(steps) self.tiler.run(steps, trap_store=trap_store, drift_store=drift_store) @@ -43,32 +52,41 @@ class TestLocal(unittest.TestCase): class TestRemote(unittest.TestCase): def setUp(self) -> None: - self.root_dir = '/Users/s1893247/PhD/pipeline-core/data/ome_test' - self.raw_expt = ExperimentOMERO(51, username='root', - password='omero-root-password', - host='localhost', - save_dir=self.root_dir) + self.root_dir = "/Users/s1893247/PhD/pipeline-core/data/ome_test" + self.raw_expt = ExperimentOMERO( + 51, + username="root", + password="omero-root-password", + host="localhost", + save_dir=self.root_dir, + ) self.tiler = Tiler(self.raw_expt, finished=False) - config = {"camera": "evolve", - "channel": "Brightfield", - "zoom": "60x", - "n_stacks": "5z", - "default_image_size": 80} - + config = { + "camera": "evolve", + "channel": "Brightfield", + "zoom": "60x", + "n_stacks": "5z", + "default_image_size": 80, + } self.baby_runner = BabyRunner(self.tiler, **config) def test_remote(self): - steps = [('pos001', 0), ('pos001', 1), ('pos001', 2), - ('pos002', 0), ('pos002', 1)] + steps = [ + ("pos001", 0), + ("pos001", 1), + ("pos001", 2), + ("pos002", 0), + ("pos002", 1), + ] run_config = {"with_edgemasks": True} - pos_store = self.root_dir + '/positions.csv' - trap_store = self.root_dir + '/traps.csv' - drift_store = self.root_dir + '/drifts.csv' - baby_store = self.root_dir + '/baby.h5' + pos_store = self.root_dir + "/positions.csv" + trap_store = self.root_dir + "/traps.csv" + drift_store = self.root_dir + "/drifts.csv" + baby_store = self.root_dir + "/baby.h5" self.raw_expt.run(steps, pos_store) self.tiler.run(steps, trap_store=trap_store, drift_store=drift_store) @@ -78,14 +96,16 @@ class TestRemote(unittest.TestCase): # rm_tree(self.root_dir) pass + def rm_tree(path): path = Path(path) - for child in path.glob('*'): + for child in path.glob("*"): if child.is_file(): child.unlink() else: rm_tree(child) path.rmdir() -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_segment.py b/tests/test_segment.py index 2d067ae6..f3045d19 100644 --- a/tests/test_segment.py +++ b/tests/test_segment.py @@ -1,7 +1,7 @@ import unittest import numpy as np -from core.segment import align_timelapse_images +from agora.tile.tiler import align_timelapse_images class TestCase(unittest.TestCase): @@ -11,9 +11,8 @@ class TestCase(unittest.TestCase): def test_align_timelapse_images(self): drift, references = align_timelapse_images(self.data) self.assertEqual(references, [0]) - self.assertItemsEqual(drift.flatten(), - np.zeros_like(drift.flatten())) + self.assertItemsEqual(drift.flatten(), np.zeros_like(drift.flatten())) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_traps.py b/tests/test_traps.py index 9cf519b7..ce60e162 100644 --- a/tests/test_traps.py +++ b/tests/test_traps.py @@ -1,19 +1,21 @@ import unittest import numpy as np -from core.traps import identify_trap_locations +from agora.tile.traps import identify_trap_locations + class TestCase(unittest.TestCase): def setUp(self): - self.data = np.pad(np.ones((5,5)), 10, mode='constant') - self.template = np.pad(np.ones((5,5)), 2, mode='constant') + self.data = np.pad(np.ones((5, 5)), 10, mode="constant") + self.template = np.pad(np.ones((5, 5)), 2, mode="constant") def test_identify_trap_locations(self): - coords = identify_trap_locations(self.data, self.template, - optimize_scale=False, downscale=1) + coords = identify_trap_locations( + self.data, self.template, optimize_scale=False, downscale=1 + ) self.assertEqual(len(coords), 1) self.assertItemsEqual(coords[0], [12, 12]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() -- GitLab