From 1d82fe1188aacd53e234a5ff373e7997400bde6e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Thu, 23 Jun 2022 13:57:51 +0100
Subject: [PATCH] add static typing and clean separators

---
 aliby/tile/tiler.py | 61 ++++++---------------------------------------
 1 file changed, 7 insertions(+), 54 deletions(-)

diff --git a/aliby/tile/tiler.py b/aliby/tile/tiler.py
index d22c7f77..de383b5a 100644
--- a/aliby/tile/tiler.py
+++ b/aliby/tile/tiler.py
@@ -30,8 +30,6 @@ class Trap:
         self.half_size = size // 2
         self.max_size = max_size
 
-    ###
-
     def padding_required(self, tp):
         """
         Check if we need to pad the trap image for this time point.
@@ -51,8 +49,6 @@ class Trap:
         drifts = self.parent.drifts
         return self.centre - np.sum(drifts[: tp + 1], axis=0)
 
-    ###
-
     def as_tile(self, tp):
         """
         Return trap in the OMERO tile format of x, y, w, h
@@ -65,8 +61,6 @@ class Trap:
         y = int(y - self.half_size)
         return x, y, self.size, self.size
 
-    ###
-
     def as_range(self, tp):
         """
         Return trap in a range format: two slice objects that can
@@ -76,9 +70,6 @@ class Trap:
         return slice(x, x + w), slice(y, y + h)
 
 
-###
-
-
 class TrapLocations:
     """
     Stores each trap as an instance of Trap.
@@ -105,8 +96,6 @@ class TrapLocations:
     def __iter__(self):
         yield from self.traps
 
-    ###
-
     @property
     def shape(self):
         """
@@ -133,8 +122,6 @@ class TrapLocations:
         res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
         return res
 
-    ###
-
     @classmethod
     def from_tiler_init(cls, initial_location, tile_size, max_size=1200):
         """
@@ -158,16 +145,10 @@ class TrapLocations:
         return trap_locs
 
 
-###
-
-
 class TilerParameters(ParametersABC):
     _defaults = {"tile_size": 117, "ref_channel": "Brightfield", "ref_z": 0}
 
 
-####
-
-
 class Tiler(ProcessABC):
     """
     Remote Timelapse Tiler.
@@ -200,14 +181,17 @@ class Tiler(ProcessABC):
         except Exception as e:
             print(f"Warning:Tiler: No z_perchannel data: {e}")
 
-    ###
-
     @classmethod
     def from_image(cls, image: Image, parameters: TilerParameters):
         return cls(image.data, image.metadata, parameters)
 
     @classmethod
-    def from_hdf5(cls, image: Image, filepath, parameters=None):
+    def from_hdf5(
+        cls,
+        image: Union[Image, ImageLocal],
+        filepath: Union[str, PosixPath],
+        parameters: TilerParameters = None,
+    ):
         trap_locs = TrapLocations.read_hdf5(filepath)
         metadata = load_attributes(filepath)
         metadata["channels"] = image.metadata["channels"]
@@ -227,16 +211,12 @@ class Tiler(ProcessABC):
             tiler.n_processed = len(trap_locs.drifts)
         return tiler
 
-    ###
-
     @lru_cache(maxsize=2)
     def get_tc(self, t, c):
         full = self.image[t, c].compute(scheduler="synchronous")
 
         return full
 
-    ###
-
     @property
     def shape(self):
         """
@@ -264,16 +244,7 @@ class Tiler(ProcessABC):
     def n_traps(self):
         return len(self.trap_locs)
 
-    @property
-    def finished(self):
-        """
-        Returns True if all channels have been processed
-        """
-        return self.n_processed == self.image.shape[0]
-
-    ###
-
-    def _initialise_traps(self, tile_size):
+    def _initialise_traps(self, tile_size: int):
         """
         Find initial trap positions.
 
@@ -297,15 +268,11 @@ class Tiler(ProcessABC):
         # store traps in an instance of TrapLocations
         self.trap_locs = TrapLocations.from_tiler_init(trap_locs, tile_size)
 
-    ###
-
     def find_drift(self, tp):
         """
         Find any translational drifts between two images at consecutive
         time points using cross correlation
         """
-        # TODO check that the drift doesn't move any tiles out of
-        # the image, remove them from list if so
         prev_tp = max(0, tp - 1)
         # cross-correlate
         drift, error, _ = phase_cross_correlation(
@@ -318,8 +285,6 @@ class Tiler(ProcessABC):
         else:
             self.trap_locs.drifts.append(drift.tolist())
 
-    ###
-
     def get_tp_data(self, tp, c):
         traps = []
         full = self.get_tc(tp, c)
@@ -330,16 +295,12 @@ class Tiler(ProcessABC):
             traps.append(ndtrap)
         return np.stack(traps)
 
-    ###
-
     def get_trap_data(self, trap_id, tp, c):
         full = self.get_tc(tp, c)
         trap = self.trap_locs.traps[trap_id]
         ndtrap = self.ifoob_pad(full, trap.as_range(tp))
         return ndtrap
 
-    ###
-
     def run_tp(self, tp):
         """
         Find traps if they have not yet been found.
@@ -373,8 +334,6 @@ class Tiler(ProcessABC):
 
         return None
 
-    ###
-
     # The next set of functions are necessary for the extraction object
     def get_traps_timepoint(self, tp, tile_size=None, channels=None, z=None):
         # FIXME we currently ignore the tile size
@@ -390,21 +349,15 @@ class Tiler(ProcessABC):
             res.append(val)
         return np.stack(res, axis=1)
 
-    ###
-
     def get_channel_index(self, item):
         for i, ch in enumerate(self.channels):
             if item in ch:
                 return i
 
-    ###
-
     def get_position_annotation(self):
         # TODO required for matlab support
         return None
 
-    ###
-
     @staticmethod
     def ifoob_pad(full, slices):  # TODO Remove when inheriting TilerABC
         """
-- 
GitLab