From 8614df0f62a503fea1bc938cf0c66903239d9d16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Mon, 26 Sep 2022 18:57:50 +0100 Subject: [PATCH] refactor(extraction): improve cell functions speed --- src/extraction/core/functions/cell.py | 161 +++++++------------------- 1 file changed, 44 insertions(+), 117 deletions(-) diff --git a/src/extraction/core/functions/cell.py b/src/extraction/core/functions/cell.py index 8c4ea97e..5b95541a 100644 --- a/src/extraction/core/functions/cell.py +++ b/src/extraction/core/functions/cell.py @@ -3,12 +3,16 @@ Base functions to extract information from a single cell These functions are automatically read by extractor.py, and so can only have the cell_mask and trap_image as inputs and must return only one value. """ +import math +import typing as t + +import bottleneck as bn +import faiss import numpy as np from scipy import ndimage -from sklearn.cluster import KMeans -def area(cell_mask): +def area(cell_mask) -> int: """ Find the area of a cell mask @@ -17,10 +21,10 @@ def area(cell_mask): cell_mask: 2d array Segmentation mask for the cell """ - return np.sum(cell_mask, dtype=int) + return bn.nansum(cell_mask) -def eccentricity(cell_mask): +def eccentricity(cell_mask) -> float: """ Find the eccentricity using the approximate major and minor axes @@ -33,7 +37,7 @@ def eccentricity(cell_mask): return np.sqrt(maj_ax**2 - min_ax**2) / maj_ax -def mean(cell_mask, trap_image): +def mean(cell_mask, trap_image) -> float: """ Finds the mean of the pixels in the cell. @@ -43,10 +47,10 @@ def mean(cell_mask, trap_image): Segmentation mask for the cell trap_image: 2d array """ - return np.mean(trap_image[np.where(cell_mask)], dtype=float) + return bn.nanmean(trap_image[cell_mask]) -def median(cell_mask, trap_image): +def median(cell_mask, trap_image) -> int: """ Finds the median of the pixels in the cell. @@ -56,10 +60,10 @@ def median(cell_mask, trap_image): Segmentation mask for the cell trap_image: 2d array """ - return np.median(trap_image[np.where(cell_mask)]) + return bn.nanmedian(trap_image[cell_mask]) -def max2p5pc(cell_mask, trap_image): +def max2p5pc(cell_mask, trap_image) -> float: """ Finds the mean of the brightest 2.5% of pixels in the cell. @@ -70,18 +74,15 @@ def max2p5pc(cell_mask, trap_image): trap_image: 2d array """ # number of pixels in mask - npixels = cell_mask.sum() + npixels = bn.nansum(cell_mask) top_pixels = int(np.ceil(npixels * 0.025)) - # sort pixels in cell - sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None) - # find highest 2.5% - top_vals = sorted_vals[-top_pixels:] + # sort pixels in cell and find highest 2.5% + top_values = trap_image[bn.rankdata(trap_image[cell_mask])[:top_pixels]] # find mean of these highest pixels - max2p5pc = np.mean(top_vals, dtype=float) - return max2p5pc + return bn.nanmean(top_values) -def max5px(cell_mask, trap_image): +def max5px(cell_mask, trap_image) -> float: """ Finds the mean of the five brightest pixels in the cell. @@ -92,63 +93,13 @@ def max5px(cell_mask, trap_image): trap_image: 2d array """ # sort pixels in cell - sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None) - top_vals = sorted_vals[-5:] + pixels = trap_image[cell_mask] + top_values = pixels[bn.rankdata(pixels)[:5].astype(int) - 1] # find mean of five brightest pixels - max5px = np.mean(top_vals, dtype=float) + max5px = bn.nanmean(top_values) return max5px -def max5px_med(cell_mask, trap_image): - """ - Finds the mean of the five brightest pixels in the cell divided by the median pixel value. - - Parameters - ---------- - cell_mask: 2d array - Segmentation mask for the cell - trap_image: 2d array - """ - # sort pixels in cell - sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None) - top_vals = sorted_vals[-5:] - # find mean of five brightest pixels - max5px = np.mean(top_vals, dtype=float) - # find the median - med = np.median(sorted_vals) - if med == 0: - return np.nan - else: - return max5px / med - - -def max2p5pc_med(cell_mask, trap_image): - """ - Finds the mean of the brightest 2.5% of pixels in the cell - divided by the median pixel value. - - Parameters - ---------- - cell_mask: 2d array - Segmentation mask for the cell - trap_image: 2d array - """ - # number of pixels in mask - npixels = cell_mask.sum() - top_pixels = int(np.ceil(npixels * 0.025)) - # sort pixels in cell - sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None) - # find highest 2.5% - top_vals = sorted_vals[-top_pixels:] - # find mean of these highest pixels - max2p5pc = np.mean(top_vals, dtype=float) - med = np.median(sorted_vals) - if med == 0: - return np.nan - else: - return max2p5pc / med - - def std(cell_mask, trap_image): """ Finds the standard deviation of the values of the pixels in the cell. @@ -159,7 +110,7 @@ def std(cell_mask, trap_image): Segmentation mask for the cell trap_image: 2d array """ - return np.std(trap_image[np.where(cell_mask)], dtype=float) + return bn.std(trap_image[cell_mask]) def k2_major_median(cell_mask, trap_image): @@ -177,47 +128,23 @@ def k2_major_median(cell_mask, trap_image): median: float The median of the major cluster of two clusters """ - if np.any(cell_mask): - X = trap_image[np.where(cell_mask)].reshape(-1, 1) - # cluster pixels in cell into two clusters - kmeans = KMeans(n_clusters=2, random_state=0).fit(X) - high_clust_id = kmeans.cluster_centers_.argmax() - # find the median of pixels in the largest cluster - major_cluster = X[kmeans.predict(X) == high_clust_id] - major_median = np.median(major_cluster, axis=None) - return major_median - else: - return np.nan - -def k2_minor_median(cell_mask, trap_image): - """ - Finds the median of the minor cluster after clustering the pixels in the cell into two clusters. + if bn.anynan(trap_image): + cell_mask[np.isnan(trap_image)] = False + X = trap_image[cell_mask].reshape(-1, 1).astype(np.float32) + # cluster pixels in cell into two clusters + indices = faiss.IndexFlatL2(X.shape[1]) + # (n_clusters=2, random_state=0).fit(X) + _, indices = indices.search(X, k=2) + high_indices = bn.nanargmax(indices, axis=1).astype(bool) + # find the median of pixels in the largest cluster + # high_masks = np.logical_xor( # Use casting to obtain masks + # high_indices.reshape(-1, 1), np.tile((0, 1), X.shape[0]).reshape(-1, 2) + # ) + major_median = bn.nanmedian(X[high_indices]) + return major_median - Parameters - ---------- - cell_mask: 2d array - Segmentation mask for the cell - trap_image: 2d array - Returns - ------- - median: float - The median of the minor cluster. - """ - if np.any(cell_mask): - X = trap_image[np.where(cell_mask)].reshape(-1, 1) - # cluster pixels in cell into two clusters - kmeans = KMeans(n_clusters=2, random_state=0).fit(X) - low_clust_id = kmeans.cluster_centers_.argmin() - # find the median of pixels in the smallest cluster - minor_cluster = X[kmeans.predict(X) == low_clust_id] - minor_median = np.median(minor_cluster, axis=None) - return minor_median - else: - return np.nan - - -def volume(cell_mask): +def volume(cell_mask) -> float: """ Estimates the volume of the cell assuming it is an ellipsoid with the mask providing a cross-section through the median plane of the ellipsoid. @@ -247,20 +174,20 @@ def conical_volume(cell_mask): def spherical_volume(cell_mask): - ''' + """ Estimates the volume of the cell assuming it is a sphere with the mask providing a cross-section through the median plane of the sphere. Parameters ---------- cell_mask: 2d array Segmentation mask for the cell - ''' - area = cell_mask.sum() - r = np.sqrt(area / np.pi) + """ + total_area = area(cell_mask) + r = math.sqrt(total_area / np.pi) return (4 * np.pi * r**3) / 3 -def min_maj_approximation(cell_mask): +def min_maj_approximation(cell_mask) -> t.Tuple[int]: """ Finds the lengths of the minor and major axes of an ellipse from a cell mask. @@ -278,8 +205,8 @@ def min_maj_approximation(cell_mask): # get the size of the top of the cone (points that are equally maximal) cone_top = ndimage.morphology.distance_transform_edt(dn == 0) * padded # minor axis = largest distance from the edge of the ellipse - min_ax = np.round(nn.max()) + min_ax = np.round(bn.nanmax(nn)) # major axis = largest distance from the cone top # + distance from the center of cone top to edge of cone top - maj_ax = np.round(dn.max() + cone_top.sum() / 2) + maj_ax = np.round(bn.nanmax(dn) + bn.nansum(cone_top) / 2) return min_ax, maj_ax -- GitLab