From d65cf0c64b5656f20017a5f74aad81cc028a4ea6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Fri, 7 Jan 2022 18:15:24 +0000
Subject: [PATCH] move tiler to agora

---
 aliby/baby_client.py   |   2 +-
 aliby/pipeline.py      |   4 +-
 aliby/segment.py       | 333 -----------------------------------------
 tests/test_pipeline.py |  92 +++++++-----
 tests/test_segment.py  |   7 +-
 tests/test_traps.py    |  14 +-
 6 files changed, 70 insertions(+), 382 deletions(-)
 delete mode 100644 aliby/segment.py

diff --git a/aliby/baby_client.py b/aliby/baby_client.py
index 4dd61079..b3a6ca3d 100644
--- a/aliby/baby_client.py
+++ b/aliby/baby_client.py
@@ -13,7 +13,7 @@ import requests
 import tensorflow as tf
 from tqdm import tqdm
 
-from agora.base import ParametersABC, ProcessABC
+from agora.abc import ParametersABC, ProcessABC
 import baby.errors
 from baby import modelsets
 from baby.brain import BabyBrain
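
Note that the abstract base classes are now imported from agora.abc rather than agora.base. A minimal sketch of declaring a parameters class against the renamed module, assuming ParametersABC keeps the subclassing pattern used by TilerParameters below (attributes set in __init__ plus a default() classmethod); the class name and default value are placeholders:

    from agora.abc import ParametersABC

    class ExampleParameters(ParametersABC):
        """Hypothetical parameters class, for illustration only."""

        def __init__(self, model_set: str = "placeholder_model"):
            self.model_set = model_set

        @classmethod
        def default(cls):
            return cls()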
diff --git a/aliby/pipeline.py b/aliby/pipeline.py
index 8a4f94c2..591dbd95 100644
--- a/aliby/pipeline.py
+++ b/aliby/pipeline.py
@@ -20,11 +20,11 @@ from pathos.multiprocessing import Pool
 from aliby.experiment import MetaData
 from aliby.haystack import initialise_tf
 from aliby.baby_client import BabyRunner, BabyParameters
-from aliby.segment import Tiler, TilerParameters
-from argo.io.omero import Dataset, Image
+from agora.tile.tiler import Tiler, TilerParameters
 from agora.abc import ParametersABC, ProcessABC
 from agora.io.writer import TilerWriter, BabyWriter
 from agora.io.signal import Signal
+from argo.io.omero import Dataset, Image
 from extraction.core.extractor import Extractor, ExtractorParameters
 from extraction.core.functions.defaults import exparams_from_meta
 from postprocessor.core.processor import PostProcessor, PostProcessorParameters
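
With the imports above, the tiler now comes from agora instead of aliby. A minimal usage sketch, assuming the relocated agora.tile.tiler keeps the public API of the removed aliby/segment.py (TilerParameters.default, Tiler.from_image, Tiler.run_tp) and that `image` exposes .data and .metadata like the OMERO Image wrapper:

    from agora.tile.tiler import Tiler, TilerParameters

    def tile_first_timepoint(image):
        # default() was (tile_size=96, ref_channel="Brightfield", ref_z=0) in the removed module
        params = TilerParameters.default()
        # from_image wraps image.data (a 5D array) and image.metadata (a dict with a "channels" entry)
        tiler = Tiler.from_image(image, params)
        # run_tp(0) locates traps, registers drift and returns the dict handed to the tiler writer
        return tiler.run_tp(0)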
diff --git a/aliby/segment.py b/aliby/segment.py
deleted file mode 100644
index 5c7612ca..00000000
--- a/aliby/segment.py
+++ /dev/null
@@ -1,333 +0,0 @@
-"""Segment/segmented pipelines.
-Includes splitting the image into traps/parts,
-cell segmentation, nucleus segmentation."""
-import warnings
-from functools import lru_cache
-
-import h5py
-import numpy as np
-
-from pathlib import Path, PosixPath
-
-from skimage.registration import phase_cross_correlation
-
-from agora.base import ParametersABC, ProcessABC
-from aliby.traps import segment_traps
-
-from agora.io.writer import load_attributes
-
-trap_template_directory = Path(__file__).parent / "trap_templates"
-# TODO do we need multiple templates, one for each setup?
-trap_template = np.array([])  # np.load(trap_template_directory / "trap_prime.npy")
-
-
-def get_tile_shapes(x, tile_size, max_shape):
-    half_size = tile_size // 2
-    xmin = int(x[0] - half_size)
-    ymin = max(0, int(x[1] - half_size))
-    if xmin + tile_size > max_shape[0]:
-        xmin = max_shape[0] - tile_size
-    if ymin + tile_size > max_shape[1]:
-        ymin = max_shape[1] - tile_size
-    return xmin, xmin + tile_size, ymin, ymin + tile_size
-
-
-###################### Dask versions ########################
-class Trap:
-    def __init__(self, centre, parent, size, max_size):
-        self.centre = centre
-        self.parent = parent  # Used to access drifts
-        self.size = size
-        self.half_size = size // 2
-        self.max_size = max_size
-
-    def padding_required(self, tp):
-        """Check if we need to pad the trap image for this time point."""
-        try:
-            assert all(self.at_time(tp) - self.half_size >= 0)
-            assert all(self.at_time(tp) + self.half_size <= self.max_size)
-        except AssertionError:
-            return True
-        return False
-
-    def at_time(self, tp):
-        """Return trap centre at time tp"""
-        drifts = self.parent.drifts
-        return self.centre - np.sum(drifts[:tp], axis=0)
-
-    def as_tile(self, tp):
-        """Return trap in the OMERO tile format of x, y, w, h
-
-        Also returns the padding necessary for this tile.
-        """
-        x, y = self.at_time(tp)
-        # tile bottom corner
-        x = int(x - self.half_size)
-        y = int(y - self.half_size)
-        return x, y, self.size, self.size
-
-    def as_range(self, tp):
-        """Return trap in a range format, two slice objects that can be used in Arrays"""
-        x, y, w, h = self.as_tile(tp)
-        return slice(x, x + w), slice(y, y + h)
-
-
-class TrapLocations:
-    def __init__(self, initial_location, tile_size, max_size=1200, drifts=[]):
-        self.tile_size = tile_size
-        self.max_size = max_size
-        self.initial_location = initial_location
-        self.traps = [
-            Trap(centre, self, tile_size, max_size) for centre in initial_location
-        ]
-        self.drifts = drifts
-
-    @classmethod
-    def from_source(cls, fpath: str):
-        with h5py.File(fpath, "r") as f:
-            # TODO read tile size from file metadata
-            drifts = f["trap_info/drifts"][()]
-            tlocs = cls(f["trap_info/trap_locations"][()], tile_size=96, drifts=drifts)
-
-        return tlocs
-
-    @property
-    def shape(self):
-        return len(self.traps), len(self.drifts)
-
-    def __len__(self):
-        return len(self.traps)
-
-    def __iter__(self):
-        yield from self.traps
-
-    def padding_required(self, tp):
-        return any([trap.padding_required(tp) for trap in self.traps])
-
-    def to_dict(self, tp):
-        res = dict()
-        if tp == 0:
-            res["trap_locations"] = self.initial_location
-            res["attrs/tile_size"] = self.tile_size
-            res["attrs/max_size"] = self.max_size
-        res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
-        # res['processed_timepoints'] = tp
-        return res
-
-    @classmethod
-    def read_hdf5(cls, file):
-        with h5py.File(file, "r") as hfile:
-            trap_info = hfile["trap_info"]
-            initial_locations = trap_info["trap_locations"][()]
-            drifts = trap_info["drifts"][()]
-            max_size = trap_info.attrs["max_size"]
-            tile_size = trap_info.attrs["tile_size"]
-        trap_locs = cls(initial_locations, tile_size, max_size=max_size)
-        trap_locs.drifts = drifts
-        return trap_locs
-
-
-class TilerParameters(ParametersABC):
-    def __init__(
-        self, tile_size: int, ref_channel: str, ref_z: int, template_name: str = None
-    ):
-        self.tile_size = tile_size
-        self.ref_channel = ref_channel
-        self.ref_z = ref_z
-        self.template_name = template_name
-
-    @classmethod
-    def from_template(cls, template_name: str, ref_channel: str, ref_z: int):
-        return cls(template.shape[0], ref_channel, ref_z, template_path=template_name)
-
-    @classmethod
-    def default(cls):
-        return cls(96, "Brightfield", 0)
-
-
-class Tiler(ProcessABC):
-    """A dummy TimelapseTiler object fora Dask Demo.
-
-    Does trap finding and image registration."""
-
-    def __init__(
-        self,
-        image,
-        metadata,
-        parameters: TilerParameters,
-    ):
-        super().__init__(parameters)
-        self.image = image
-        self.channels = metadata["channels"]
-        self.ref_channel = self.get_channel_index(parameters.ref_channel)
-
-    @classmethod
-    def from_image(cls, image, parameters: TilerParameters):
-        return cls(image.data, image.metadata, parameters)
-
-    @classmethod
-    def from_hdf5(cls, image, filepath, tile_size=None):
-        trap_locs = TrapLocations.read_hdf5(filepath)
-        metadata = load_attributes(filepath)
-        metadata["channels"] = metadata["channels/channel"].tolist()
-        if tile_size is None:
-            tile_size = trap_locs.tile_size
-        return Tiler(
-            image=image,
-            metadata=metadata,
-            template=None,
-            tile_size=tile_size,
-            trap_locs=trap_locs,
-        )
-
-    @lru_cache(maxsize=2)
-    def get_tc(self, t, c):
-        # Get image
-        full = self.image[t, c].compute()  # FORCE THE CACHE
-        return full
-
-    @property
-    def shape(self):
-        c, t, z, y, x = self.image.shape
-        return (c, t, x, y, z)
-
-    @property
-    def n_processed(self):
-        if not hasattr(self, "_n_processed"):
-            self._n_processed = 0
-        return self._n_processed
-
-    @n_processed.setter
-    def n_processed(self, value):
-        self._n_processed = value
-
-    @property
-    def n_traps(self):
-        return len(self.trap_locs)
-
-    @property
-    def finished(self):
-        return self.n_processed == self.image.shape[0]
-
-    def _initialise_traps(self, tile_size):
-        """Find initial trap positions.
-
-        Removes all those that are too close to the edge so no padding is necessary.
-        """
-        half_tile = tile_size // 2
-        max_size = min(self.image.shape[-2:])
-        initial_image = self.image[
-            0, self.ref_channel, self.ref_z
-        ]  # First time point, first channel, first z-position
-        trap_locs = segment_traps(initial_image, tile_size)
-        trap_locs = [
-            [x, y]
-            for x, y in trap_locs
-            if half_tile < x < max_size - half_tile
-            and half_tile < y < max_size - half_tile
-        ]
-        self.trap_locs = TrapLocations(trap_locs, tile_size)
-
-    def find_drift(self, tp):
-        # TODO check that the drift doesn't move any tiles out of the image, remove them from list if so
-        prev_tp = max(0, tp - 1)
-        drift, error, _ = phase_cross_correlation(
-            self.image[prev_tp, self.ref_channel, self.ref_z],
-            self.image[tp, self.ref_channel, self.ref_z],
-        )
-        self.trap_locs.drifts.append(drift)
-
-    def get_tp_data(self, tp, c):
-        traps = []
-        full = self.get_tc(tp, c)
-        # if self.trap_locs.padding_required(tp):
-        for trap in self.trap_locs:
-            ndtrap = self.ifoob_pad(full, trap.as_range(tp))
-
-            traps.append(ndtrap)
-        return np.stack(traps)
-
-    def get_trap_data(self, trap_id, tp, c):
-        full = self.get_tc(tp, c)
-        trap = self.trap_locs.traps[trap_id]
-        ndtrap = self.ifoob_pad(full, trap.as_range(tp))
-
-        return ndtrap
-
-    @staticmethod
-    def ifoob_pad(full, slices):
-        """
-        Return the tile for the given slices, padded if the slices fall out of bounds.
-
-        Parameters
-        ----------
-        full: (zstacks, max_size, max_size) ndarray
-            Entire position, with z-stacks as the first axis
-        slices: tuple of two slices
-            Each slice indexes one spatial axis
-
-        Returns
-        -------
-        Tile for the given slices, median-padded if needed, or np.nan-filled if the required padding is too large
-        """
-        max_size = full.shape[-1]
-
-        y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
-        trap = full[:, y, x]
-
-        padding = np.array(
-            [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
-        )
-        if padding.any():
-            tile_size = slices[0].stop - slices[0].start
-            if (padding > tile_size / 4).any():
-                trap = np.full((full.shape[0], tile_size, tile_size), np.nan)
-            else:
-
-                trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median")
-
-        return trap
-
-    def run_tp(self, tp):
-        assert tp >= self.n_processed, "Time point already processed"
-        # TODO check contiguity?
-        if self.n_processed == 0:
-            self._initialise_traps(self.tile_size)
-        self.find_drift(tp)  # Get drift
-        # update n_processed
-        self.n_processed += 1
-        # Return result for writer
-        return self.trap_locs.to_dict(tp)
-
-    def run(self, tp):
-        if self.n_processed == 0:
-            self._initialise_traps(self.tile_size)
-        self.find_drift(tp)  # Get drift
-        # update n_processed
-        self.n_processed += 1
-        # Return result for writer
-        return self.trap_locs.to_dict(tp)
-
-    # The next set of functions are necessary for the extraction object
-    def get_traps_timepoint(self, tp, tile_size=None, channels=None, z=None):
-        # FIXME we currently ignore the tile size
-        # FIXME can we ignore z (always given)?
-        res = []
-        for c in channels:
-            val = self.get_tp_data(tp, c)[:, z]  # Only return requested z
-            # positions
-            # Starts at traps, z, y, x
-            # Turn to Trap, C, T, X, Y, Z order
-            val = val.swapaxes(1, 3).swapaxes(1, 2)
-            val = np.expand_dims(val, axis=1)
-            res.append(val)
-        return np.stack(res, axis=1)
-
-    def get_channel_index(self, item):
-        for i, ch in enumerate(self.channels):
-            if item in ch:
-                return i
-
-    def get_position_annotation(self):
-        # TODO required for matlab support
-        return None
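
The out-of-bounds rule in the removed Tiler.ifoob_pad is the behaviour the relocated code is expected to preserve: tiles that slightly overhang the image are median-padded, while tiles overhanging by more than a quarter of the tile size are returned as NaNs. A standalone sketch of that rule (function name is illustrative; the logic follows the removed method):

    import numpy as np

    def pad_if_out_of_bounds(full, slices):
        """Cut a (z, y, x) tile; median-pad small overhangs, NaN-fill large ones."""
        max_size = full.shape[-1]
        y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
        tile = full[:, y, x]
        padding = np.array(
            [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
        )
        if padding.any():
            tile_size = slices[0].stop - slices[0].start
            if (padding > tile_size / 4).any():
                # more than a quarter of the tile is missing: give up, return NaNs
                return np.full((full.shape[0], tile_size, tile_size), np.nan)
            tile = np.pad(tile, [[0, 0]] + padding.tolist(), "median")
        return tile

    # A 96-pixel tile that overhangs one image edge by 8 pixels is median-padded back to size:
    frame = np.random.rand(5, 1200, 1200)
    tile = pad_if_out_of_bounds(frame, (slice(-8, 88), slice(500, 596)))
    assert tile.shape == (5, 96, 96)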
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index b37c48ad..a89d4bd7 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -2,35 +2,44 @@ import os
 import unittest
 from pathlib import Path
 
-from core.baby_client import BabyRunner
-from core.experiment import ExperimentOMERO
-from core.segment import Tiler
-from core.pipeline import ExperimentLocal
+from aliby.baby_client import BabyRunner
+from aliby.experiment import ExperimentOMERO
+from aliby.pipeline import ExperimentLocal
+from agora.tile.tiler import Tiler
 
 
 class TestLocal(unittest.TestCase):
     def setUp(self) -> None:
-        self.root_dir = '/Users/s1893247/PhD/pipeline-core/data/glclvl_0' \
-                   '.1_mig1_msn2_maf1_sfp1_dot6_03'
+        self.root_dir = (
+            "/Users/s1893247/PhD/pipeline-core/data/glclvl_0"
+            ".1_mig1_msn2_maf1_sfp1_dot6_03"
+        )
         self.raw_expt = ExperimentLocal(self.root_dir, finished=True)
         self.tiler = Tiler(self.raw_expt, finished=False)
 
-        config = {"camera": "evolve",
-                  "channel": "Brightfield",
-                  "zoom": "60x",
-                  "n_stacks": "5z",
-                  "default_image_size": 80}
+        config = {
+            "camera": "evolve",
+            "channel": "Brightfield",
+            "zoom": "60x",
+            "n_stacks": "5z",
+            "default_image_size": 80,
+        }
 
         self.store = "test.hdf5"
         self.baby_runner = BabyRunner(self.tiler, **config)
 
     def test_local(self):
-        steps = [('pos001', 0), ('pos001', 1), ('pos001', 2),
-                           ('pos001', 3), ('pos001', 4)]
-
-        trap_store = self.root_dir + '/traps.csv'
-        drift_store = self.root_dir + '/drifts.csv'
-        baby_store = self.root_dir + '/baby.csv'
+        steps = [
+            ("pos001", 0),
+            ("pos001", 1),
+            ("pos001", 2),
+            ("pos001", 3),
+            ("pos001", 4),
+        ]
+
+        trap_store = self.root_dir + "/traps.csv"
+        drift_store = self.root_dir + "/drifts.csv"
+        baby_store = self.root_dir + "/baby.csv"
 
         self.raw_expt.run(steps)
         self.tiler.run(steps, trap_store=trap_store, drift_store=drift_store)
@@ -43,32 +52,41 @@ class TestLocal(unittest.TestCase):
 
 class TestRemote(unittest.TestCase):
     def setUp(self) -> None:
-        self.root_dir = '/Users/s1893247/PhD/pipeline-core/data/ome_test'
-        self.raw_expt = ExperimentOMERO(51, username='root',
-                                        password='omero-root-password',
-                                        host='localhost',
-                                        save_dir=self.root_dir)
+        self.root_dir = "/Users/s1893247/PhD/pipeline-core/data/ome_test"
+        self.raw_expt = ExperimentOMERO(
+            51,
+            username="root",
+            password="omero-root-password",
+            host="localhost",
+            save_dir=self.root_dir,
+        )
         self.tiler = Tiler(self.raw_expt, finished=False)
 
-        config = {"camera": "evolve",
-                  "channel": "Brightfield",
-                  "zoom": "60x",
-                  "n_stacks": "5z",
-                  "default_image_size": 80}
-
+        config = {
+            "camera": "evolve",
+            "channel": "Brightfield",
+            "zoom": "60x",
+            "n_stacks": "5z",
+            "default_image_size": 80,
+        }
 
         self.baby_runner = BabyRunner(self.tiler, **config)
 
     def test_remote(self):
-        steps = [('pos001', 0), ('pos001', 1), ('pos001', 2),
-                           ('pos002', 0), ('pos002', 1)]
+        steps = [
+            ("pos001", 0),
+            ("pos001", 1),
+            ("pos001", 2),
+            ("pos002", 0),
+            ("pos002", 1),
+        ]
 
         run_config = {"with_edgemasks": True}
 
-        pos_store = self.root_dir + '/positions.csv'
-        trap_store = self.root_dir + '/traps.csv'
-        drift_store = self.root_dir + '/drifts.csv'
-        baby_store = self.root_dir + '/baby.h5'
+        pos_store = self.root_dir + "/positions.csv"
+        trap_store = self.root_dir + "/traps.csv"
+        drift_store = self.root_dir + "/drifts.csv"
+        baby_store = self.root_dir + "/baby.h5"
 
         self.raw_expt.run(steps, pos_store)
         self.tiler.run(steps, trap_store=trap_store, drift_store=drift_store)
@@ -78,14 +96,16 @@ class TestRemote(unittest.TestCase):
         # rm_tree(self.root_dir)
         pass
 
+
 def rm_tree(path):
     path = Path(path)
-    for child in path.glob('*'):
+    for child in path.glob("*"):
         if child.is_file():
             child.unlink()
         else:
             rm_tree(child)
     path.rmdir()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_segment.py b/tests/test_segment.py
index 2d067ae6..f3045d19 100644
--- a/tests/test_segment.py
+++ b/tests/test_segment.py
@@ -1,7 +1,7 @@
 import unittest
 import numpy as np
 
-from core.segment import align_timelapse_images
+from agora.tile.tiler import align_timelapse_images
 
 
 class TestCase(unittest.TestCase):
@@ -11,9 +11,8 @@ class TestCase(unittest.TestCase):
     def test_align_timelapse_images(self):
         drift, references = align_timelapse_images(self.data)
         self.assertEqual(references, [0])
-        self.assertItemsEqual(drift.flatten(),
-                              np.zeros_like(drift.flatten()))
+        self.assertCountEqual(drift.flatten(), np.zeros_like(drift.flatten()))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_traps.py b/tests/test_traps.py
index 9cf519b7..ce60e162 100644
--- a/tests/test_traps.py
+++ b/tests/test_traps.py
@@ -1,19 +1,21 @@
 import unittest
 import numpy as np
 
-from core.traps import identify_trap_locations
+from agora.tile.traps import identify_trap_locations
+
 
 class TestCase(unittest.TestCase):
     def setUp(self):
-        self.data = np.pad(np.ones((5,5)), 10, mode='constant')
-        self.template = np.pad(np.ones((5,5)), 2, mode='constant')
+        self.data = np.pad(np.ones((5, 5)), 10, mode="constant")
+        self.template = np.pad(np.ones((5, 5)), 2, mode="constant")
 
     def test_identify_trap_locations(self):
-        coords = identify_trap_locations(self.data, self.template,
-                                optimize_scale=False, downscale=1)
+        coords = identify_trap_locations(
+            self.data, self.template, optimize_scale=False, downscale=1
+        )
         self.assertEqual(len(coords), 1)
         self.assertItemsEqual(coords[0], [12, 12])
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
-- 
GitLab