From 8614df0f62a503fea1bc938cf0c66903239d9d16 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Mon, 26 Sep 2022 18:57:50 +0100
Subject: [PATCH] refactor(extraction): improve cell functions speed

---
 src/extraction/core/functions/cell.py | 161 +++++++-------------------
 1 file changed, 44 insertions(+), 117 deletions(-)

diff --git a/src/extraction/core/functions/cell.py b/src/extraction/core/functions/cell.py
index 8c4ea97e..5b95541a 100644
--- a/src/extraction/core/functions/cell.py
+++ b/src/extraction/core/functions/cell.py
@@ -3,12 +3,16 @@ Base functions to extract information from a single cell
 
 These functions are automatically read by extractor.py, and so can only have the cell_mask and trap_image as inputs and must return only one value.
 """
+import math
+import typing as t
+
+import bottleneck as bn
+import faiss
 import numpy as np
 from scipy import ndimage
-from sklearn.cluster import KMeans
 
 
-def area(cell_mask):
+def area(cell_mask) -> int:
     """
     Find the area of a cell mask
 
@@ -17,10 +21,10 @@ def area(cell_mask):
     cell_mask: 2d array
         Segmentation mask for the cell
     """
-    return np.sum(cell_mask, dtype=int)
+    return bn.nansum(cell_mask)
 
 
-def eccentricity(cell_mask):
+def eccentricity(cell_mask) -> float:
     """
     Find the eccentricity using the approximate major and minor axes
 
@@ -33,7 +37,7 @@ def eccentricity(cell_mask):
     return np.sqrt(maj_ax**2 - min_ax**2) / maj_ax
 
 
-def mean(cell_mask, trap_image):
+def mean(cell_mask, trap_image) -> float:
     """
     Finds the mean of the pixels in the cell.
 
@@ -43,10 +47,10 @@ def mean(cell_mask, trap_image):
         Segmentation mask for the cell
     trap_image: 2d array
     """
-    return np.mean(trap_image[np.where(cell_mask)], dtype=float)
+    return bn.nanmean(trap_image[cell_mask])
 
 
-def median(cell_mask, trap_image):
+def median(cell_mask, trap_image) -> float:
     """
     Finds the median of the pixels in the cell.
 
@@ -56,10 +60,10 @@ def median(cell_mask, trap_image):
         Segmentation mask for the cell
     trap_image: 2d array
     """
-    return np.median(trap_image[np.where(cell_mask)])
+    return bn.nanmedian(trap_image[cell_mask])
 
 
-def max2p5pc(cell_mask, trap_image):
+def max2p5pc(cell_mask, trap_image) -> float:
     """
     Finds the mean of the brightest 2.5% of pixels in the cell.
 
@@ -70,18 +74,15 @@ def max2p5pc(cell_mask, trap_image):
     trap_image: 2d array
     """
     # number of pixels in mask
-    npixels = cell_mask.sum()
+    npixels = bn.nansum(cell_mask)
     top_pixels = int(np.ceil(npixels * 0.025))
-    # sort pixels in cell
-    sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None)
-    # find highest 2.5%
-    top_vals = sorted_vals[-top_pixels:]
+    # keep the brightest 2.5% of the masked pixels (ascending sort, take the tail)
+    top_values = np.sort(trap_image[cell_mask], axis=None)[-top_pixels:]
     # find mean of these highest pixels
-    max2p5pc = np.mean(top_vals, dtype=float)
-    return max2p5pc
+    return bn.nanmean(top_values)
 
 
-def max5px(cell_mask, trap_image):
+def max5px(cell_mask, trap_image) -> float:
     """
     Finds the mean of the five brightest pixels in the cell.
 
@@ -92,63 +93,13 @@ def max5px(cell_mask, trap_image):
     trap_image: 2d array
     """
     # sort pixels in cell
-    sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None)
-    top_vals = sorted_vals[-5:]
+    pixels = np.sort(trap_image[cell_mask], axis=None)
+    top_values = pixels[-5:]
     # find mean of five brightest pixels
-    max5px = np.mean(top_vals, dtype=float)
+    max5px = bn.nanmean(top_values)
     return max5px
 
 
-def max5px_med(cell_mask, trap_image):
-    """
-    Finds the mean of the five brightest pixels in the cell divided by the median pixel value.
-
-    Parameters
-    ----------
-    cell_mask: 2d array
-        Segmentation mask for the cell
-    trap_image: 2d array
-    """
-    # sort pixels in cell
-    sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None)
-    top_vals = sorted_vals[-5:]
-    # find mean of five brightest pixels
-    max5px = np.mean(top_vals, dtype=float)
-    # find the median
-    med = np.median(sorted_vals)
-    if med == 0:
-        return np.nan
-    else:
-        return max5px / med
-
-
-def max2p5pc_med(cell_mask, trap_image):
-    """
-    Finds the mean of the brightest 2.5% of pixels in the cell
-    divided by the median pixel value.
-
-    Parameters
-    ----------
-    cell_mask: 2d array
-        Segmentation mask for the cell
-    trap_image: 2d array
-    """
-    # number of pixels in mask
-    npixels = cell_mask.sum()
-    top_pixels = int(np.ceil(npixels * 0.025))
-    # sort pixels in cell
-    sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None)
-    # find highest 2.5%
-    top_vals = sorted_vals[-top_pixels:]
-    # find mean of these highest pixels
-    max2p5pc = np.mean(top_vals, dtype=float)
-    med = np.median(sorted_vals)
-    if med == 0:
-        return np.nan
-    else:
-        return max2p5pc / med
-
-
 def std(cell_mask, trap_image):
     """
     Finds the standard deviation of the values of the pixels in the cell.
@@ -159,7 +110,7 @@ def std(cell_mask, trap_image):
         Segmentation mask for the cell
     trap_image: 2d array
     """
-    return np.std(trap_image[np.where(cell_mask)], dtype=float)
+    return bn.nanstd(trap_image[cell_mask])
 
 
 def k2_major_median(cell_mask, trap_image):
@@ -177,47 +128,23 @@ def k2_major_median(cell_mask, trap_image):
     median: float
         The median of the major cluster of two clusters
     """
-    if np.any(cell_mask):
-        X = trap_image[np.where(cell_mask)].reshape(-1, 1)
-        # cluster pixels in cell into two clusters
-        kmeans = KMeans(n_clusters=2, random_state=0).fit(X)
-        high_clust_id = kmeans.cluster_centers_.argmax()
-        # find the median of pixels in the largest cluster
-        major_cluster = X[kmeans.predict(X) == high_clust_id]
-        major_median = np.median(major_cluster, axis=None)
-        return major_median
-    else:
-        return np.nan
-
-def k2_minor_median(cell_mask, trap_image):
-    """
-    Finds the median of the minor cluster after clustering the pixels in the cell into two clusters.
+    if not np.any(cell_mask):
+        # nothing to cluster: an empty mask has no pixels
+        return np.nan
+    # build a fresh mask that drops NaN pixels; never modify the caller's cell_mask
+    valid = np.asarray(cell_mask, dtype=bool) & ~np.isnan(trap_image)
+    X = np.ascontiguousarray(trap_image[valid], dtype=np.float32).reshape(-1, 1)
+    if X.shape[0] < 2:
+        return bn.nanmedian(X)  # too few pixels to split into two clusters
+    # cluster pixels in cell into two clusters (fixed seed for reproducibility)
+    kmeans = faiss.Kmeans(d=X.shape[1], k=2, seed=0)
+    kmeans.train(X)
+    _, labels = kmeans.index.search(X, 1)
+    # median of the pixels assigned to the cluster with the highest centroid
+    return bn.nanmedian(X[labels.ravel() == kmeans.centroids.argmax()])
 
-    Parameters
-    ----------
-    cell_mask: 2d array
-        Segmentation mask for the cell
-    trap_image: 2d array
 
-    Returns
-    -------
-    median: float
-        The median of the minor cluster.
-    """
-    if np.any(cell_mask):
-        X = trap_image[np.where(cell_mask)].reshape(-1, 1)
-        # cluster pixels in cell into two clusters
-        kmeans = KMeans(n_clusters=2, random_state=0).fit(X)
-        low_clust_id = kmeans.cluster_centers_.argmin()
-        # find the median of pixels in the smallest cluster
-        minor_cluster = X[kmeans.predict(X) == low_clust_id]
-        minor_median = np.median(minor_cluster, axis=None)
-        return minor_median
-    else:
-        return np.nan
-
-
-def volume(cell_mask):
+def volume(cell_mask) -> float:
     """
     Estimates the volume of the cell assuming it is an ellipsoid with the mask providing a cross-section through the median plane of the ellipsoid.
 
@@ -247,20 +174,20 @@ def conical_volume(cell_mask):
 
 
 def spherical_volume(cell_mask):
-    '''
+    """
     Estimates the volume of the cell assuming it is a sphere with the mask providing a cross-section through the median plane of the sphere.
 
     Parameters
     ----------
     cell_mask: 2d array
         Segmentation mask for the cell
-    '''
-    area = cell_mask.sum()
-    r = np.sqrt(area / np.pi)
+    """
+    total_area = area(cell_mask)
+    r = math.sqrt(total_area / np.pi)
     return (4 * np.pi * r**3) / 3
 
 
-def min_maj_approximation(cell_mask):
+def min_maj_approximation(cell_mask) -> t.Tuple[float, float]:
     """
     Finds the lengths of the minor and major axes of an ellipse from a cell mask.
 
@@ -278,8 +205,8 @@ def min_maj_approximation(cell_mask):
     # get the size of the top of the cone (points that are equally maximal)
     cone_top = ndimage.morphology.distance_transform_edt(dn == 0) * padded
     # minor axis = largest distance from the edge of the ellipse
-    min_ax = np.round(nn.max())
+    min_ax = np.round(bn.nanmax(nn))
     # major axis = largest distance from the cone top
     # + distance from the center of cone top to edge of cone top
-    maj_ax = np.round(dn.max() + cone_top.sum() / 2)
+    maj_ax = np.round(bn.nanmax(dn) + bn.nansum(cone_top) / 2)
     return min_ax, maj_ax
-- 
GitLab