From 6c13cd70ac044da9070af0b75b076126f3dab063 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk> Date: Fri, 25 Mar 2022 18:48:31 +0000 Subject: [PATCH] revamp template identification algorithm --- aliby/tile/traps.py | 121 ++++++++++++++------------------------------ 1 file changed, 38 insertions(+), 83 deletions(-) diff --git a/aliby/tile/traps.py b/aliby/tile/traps.py index 4bdb1f43..b152261e 100644 --- a/aliby/tile/traps.py +++ b/aliby/tile/traps.py @@ -49,8 +49,14 @@ def segment_traps( # TODO Optimise the hyperparameters disk_radius = int(min([disk_radius_frac * x for x in img.shape])) - min_area = min_frac_tilesize * (tile_size ** 2) - max_area = max_frac_tilesize * (tile_size ** 2) + min_mal = min_frac_tilesize * np.sqrt(2) * tile_size + max_mal = max_frac_tilesize * np.sqrt(2) * tile_size + + def half_floor(x): + return x - tile_size // 2 + + def half_ceil(x): + return x + -(tile_size // -2) if downscale != 1: img = transform.rescale(image, downscale) @@ -69,64 +75,44 @@ def segment_traps( # label image regions label_image = label(cleared) - # areas = [ - # region.area - # for region in regionprops(label_image) - # if region.area > min_area and region.area < max_area - # ] - traps = ( - np.array( - [ - region.centroid - for region in regionprops(label_image) - if region.area > min_area and region.area < max_area - ] - ) - .round() - .astype(int) - ) - ma = ( - np.array( - [ - region.minor_axis_length - for region in regionprops(label_image) - if region.area > min_area and region.area < max_area - ] - ) - .round() - .astype(int) - ) + idx_valid_region = [ + (i, region) + for i, region in enumerate(regionprops(label_image)) + if min_mal < region.major_axis_length < max_mal + and tile_size // 2 < region.centroid[0] < half_floor(image.shape[0]) + and tile_size // 2 < region.centroid[1] < half_floor(image.shape[1]) + ] + idx, valid_region = zip(*idx_valid_region) - maskx = (tile_size // 2 < traps[:, 0]) & ( - traps[:, 0] < image.shape[0] - tile_size // 2 - ) - masky = (tile_size // 2 < traps[:, 1]) & ( - traps[:, 1] < image.shape[1] - tile_size // 2 - ) + valid_templates = copy(label_image) + for i in set(list(range(label_image.max()))).difference(idx): + valid_templates[np.where(valid_templates == i + 1)] = -2 * i - traps = traps[maskx & masky, :] - ma = ma[maskx & masky] + import matplotlib.colors as colors - chosen_trap_coords = np.round(traps[ma.argmin()]).astype(int) + combined = valid_templates + label_image + + centroids = np.array([x.centroid for x in valid_region]).round().astype(int) + minals = [region.minor_axis_length for region in valid_region] + + chosen_trap_coords = np.round(centroids[np.argmin(minals)]).astype(int) x, y = chosen_trap_coords + template = image[ - x - tile_size // 2 : x + tile_size // 2, y - tile_size // 2 : y + tile_size // 2 + half_floor(x) : half_ceil(x), + half_floor(y) : half_ceil(y), + ] + + candidate_templates = [ + image[ + slice(half_floor(x), half_ceil(x)), + slice(half_floor(y), half_ceil(y)), + ] + for x, y in centroids ] # add template as mean of found traps - mean_template = ( - np.dstack( - [ - image[ - x - tile_size // 2 : x + tile_size // 2, - y - tile_size // 2 : y + tile_size // 2, - ] - for x, y in traps - ] - ) - .astype(int) - .mean(axis=-1) - ) + mean_template = np.dstack(candidate_templates).astype(int).mean(axis=-1) traps = identify_trap_locations(image, template, **identify_traps_kwargs) mean_traps = identify_trap_locations(image, mean_template, **identify_traps_kwargs) @@ -141,37 +127,6 @@ def segment_traps( return traps if len(traps_retry) < len(traps) else traps_retry -# def segment_traps(image, tile_size, downscale=0.4): -# # Make image go between 0 and 255 -# img = image # Keep a memory of image in case need to re-run -# image = stretch_image(image) -# # TODO Optimise the hyperparameters -# disk_radius = int(min([0.01 * x for x in img.shape])) -# min_area = 0.1 * (tile_size ** 2) -# if downscale != 1: -# img = transform.rescale(image, downscale) -# entropy_image = entropy(img, disk(disk_radius)) -# if downscale != 1: -# entropy_image = transform.rescale(entropy_image, 1 / downscale) - -# # apply threshold -# thresh = threshold_otsu(entropy_image) -# bw = closing(entropy_image > thresh, square(3)) - -# # remove artifacts connected to image border -# cleared = clear_border(bw) - -# # label image regions -# label_image = label(cleared) -# traps = [ -# region.centroid for region in regionprops(label_image) if region.area > min_area -# ] -# if len(traps) < 10 and downscale != 1: -# print("Trying again.") -# return segment_traps(image, tile_size, downscale=1) -# return traps - - def identify_trap_locations( image, trap_template, optimize_scale=True, downscale=0.35, trap_size=None ): -- GitLab