From 6c13cd70ac044da9070af0b75b076126f3dab063 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Fri, 25 Mar 2022 18:48:31 +0000
Subject: [PATCH] revamp template identification algorithm

---
 aliby/tile/traps.py | 121 ++++++++++++++------------------------------
 1 file changed, 38 insertions(+), 83 deletions(-)

diff --git a/aliby/tile/traps.py b/aliby/tile/traps.py
index 4bdb1f43..b152261e 100644
--- a/aliby/tile/traps.py
+++ b/aliby/tile/traps.py
@@ -49,8 +49,14 @@ def segment_traps(
     # TODO Optimise the hyperparameters
 
     disk_radius = int(min([disk_radius_frac * x for x in img.shape]))
-    min_area = min_frac_tilesize * (tile_size ** 2)
-    max_area = max_frac_tilesize * (tile_size ** 2)
+    min_mal = min_frac_tilesize * np.sqrt(2) * tile_size
+    max_mal = max_frac_tilesize * np.sqrt(2) * tile_size
+
+    def half_floor(x):
+        return x - tile_size // 2
+
+    def half_ceil(x):
+        return x + -(tile_size // -2)
 
     if downscale != 1:
         img = transform.rescale(image, downscale)
@@ -69,64 +75,44 @@ def segment_traps(
 
     # label image regions
     label_image = label(cleared)
-    # areas = [
-    #     region.area
-    #     for region in regionprops(label_image)
-    #     if region.area > min_area and region.area < max_area
-    # ]
-    traps = (
-        np.array(
-            [
-                region.centroid
-                for region in regionprops(label_image)
-                if region.area > min_area and region.area < max_area
-            ]
-        )
-        .round()
-        .astype(int)
-    )
-    ma = (
-        np.array(
-            [
-                region.minor_axis_length
-                for region in regionprops(label_image)
-                if region.area > min_area and region.area < max_area
-            ]
-        )
-        .round()
-        .astype(int)
-    )
+    idx_valid_region = [
+        (i, region)
+        for i, region in enumerate(regionprops(label_image))
+        if min_mal < region.major_axis_length < max_mal
+        and tile_size // 2 < region.centroid[0] < half_floor(image.shape[0])
+        and tile_size // 2 < region.centroid[1] < half_floor(image.shape[1])
+    ]
+    idx, valid_region = zip(*idx_valid_region)
 
-    maskx = (tile_size // 2 < traps[:, 0]) & (
-        traps[:, 0] < image.shape[0] - tile_size // 2
-    )
-    masky = (tile_size // 2 < traps[:, 1]) & (
-        traps[:, 1] < image.shape[1] - tile_size // 2
-    )
+    valid_templates = copy(label_image)
+    for i in set(list(range(label_image.max()))).difference(idx):
+        valid_templates[np.where(valid_templates == i + 1)] = -2 * i
 
-    traps = traps[maskx & masky, :]
-    ma = ma[maskx & masky]
+    import matplotlib.colors as colors
 
-    chosen_trap_coords = np.round(traps[ma.argmin()]).astype(int)
+    combined = valid_templates + label_image
+
+    centroids = np.array([x.centroid for x in valid_region]).round().astype(int)
+    minals = [region.minor_axis_length for region in valid_region]
+
+    chosen_trap_coords = np.round(centroids[np.argmin(minals)]).astype(int)
     x, y = chosen_trap_coords
+
     template = image[
-        x - tile_size // 2 : x + tile_size // 2, y - tile_size // 2 : y + tile_size // 2
+        half_floor(x) : half_ceil(x),
+        half_floor(y) : half_ceil(y),
+    ]
+
+    candidate_templates = [
+        image[
+            slice(half_floor(x), half_ceil(x)),
+            slice(half_floor(y), half_ceil(y)),
+        ]
+        for x, y in centroids
     ]
 
     # add template as mean of found traps
-    mean_template = (
-        np.dstack(
-            [
-                image[
-                    x - tile_size // 2 : x + tile_size // 2,
-                    y - tile_size // 2 : y + tile_size // 2,
-                ]
-                for x, y in traps
-            ]
-        )
-        .astype(int)
-        .mean(axis=-1)
-    )
+    mean_template = np.dstack(candidate_templates).astype(int).mean(axis=-1)
 
     traps = identify_trap_locations(image, template, **identify_traps_kwargs)
     mean_traps = identify_trap_locations(image, mean_template, **identify_traps_kwargs)
@@ -141,37 +127,6 @@ def segment_traps(
     return traps if len(traps_retry) < len(traps) else traps_retry
 
 
-# def segment_traps(image, tile_size, downscale=0.4):
-#     # Make image go between 0 and 255
-#     img = image  # Keep a memory of image in case need to re-run
-#     image = stretch_image(image)
-#     # TODO Optimise the hyperparameters
-#     disk_radius = int(min([0.01 * x for x in img.shape]))
-#     min_area = 0.1 * (tile_size ** 2)
-#     if downscale != 1:
-#         img = transform.rescale(image, downscale)
-#     entropy_image = entropy(img, disk(disk_radius))
-#     if downscale != 1:
-#         entropy_image = transform.rescale(entropy_image, 1 / downscale)
-
-#     # apply threshold
-#     thresh = threshold_otsu(entropy_image)
-#     bw = closing(entropy_image > thresh, square(3))
-
-#     # remove artifacts connected to image border
-#     cleared = clear_border(bw)
-
-#     # label image regions
-#     label_image = label(cleared)
-#     traps = [
-#         region.centroid for region in regionprops(label_image) if region.area > min_area
-#     ]
-#     if len(traps) < 10 and downscale != 1:
-#         print("Trying again.")
-#         return segment_traps(image, tile_size, downscale=1)
-#     return traps
-
-
 def identify_trap_locations(
     image, trap_template, optimize_scale=True, downscale=0.35, trap_size=None
 ):
-- 
GitLab