Skip to content
Snippets Groups Projects
Commit 8614df0f authored by Alán Muñoz's avatar Alán Muñoz
Browse files

refactor(extraction): improve cell functions speed

parent 436918c9
No related branches found
No related tags found
No related merge requests found
...@@ -3,12 +3,16 @@ Base functions to extract information from a single cell ...@@ -3,12 +3,16 @@ Base functions to extract information from a single cell
These functions are automatically read by extractor.py, and so can only have the cell_mask and trap_image as inputs and must return only one value. These functions are automatically read by extractor.py, and so can only have the cell_mask and trap_image as inputs and must return only one value.
""" """
import math
import typing as t
import bottleneck as bn
import faiss
import numpy as np import numpy as np
from scipy import ndimage from scipy import ndimage
from sklearn.cluster import KMeans
def area(cell_mask): def area(cell_mask) -> int:
""" """
Find the area of a cell mask Find the area of a cell mask
...@@ -17,10 +21,10 @@ def area(cell_mask): ...@@ -17,10 +21,10 @@ def area(cell_mask):
cell_mask: 2d array cell_mask: 2d array
Segmentation mask for the cell Segmentation mask for the cell
""" """
return np.sum(cell_mask, dtype=int) return bn.nansum(cell_mask)
def eccentricity(cell_mask): def eccentricity(cell_mask) -> float:
""" """
Find the eccentricity using the approximate major and minor axes Find the eccentricity using the approximate major and minor axes
...@@ -33,7 +37,7 @@ def eccentricity(cell_mask): ...@@ -33,7 +37,7 @@ def eccentricity(cell_mask):
return np.sqrt(maj_ax**2 - min_ax**2) / maj_ax return np.sqrt(maj_ax**2 - min_ax**2) / maj_ax
def mean(cell_mask, trap_image): def mean(cell_mask, trap_image) -> float:
""" """
Finds the mean of the pixels in the cell. Finds the mean of the pixels in the cell.
...@@ -43,10 +47,10 @@ def mean(cell_mask, trap_image): ...@@ -43,10 +47,10 @@ def mean(cell_mask, trap_image):
Segmentation mask for the cell Segmentation mask for the cell
trap_image: 2d array trap_image: 2d array
""" """
return np.mean(trap_image[np.where(cell_mask)], dtype=float) return bn.nanmean(trap_image[cell_mask])
def median(cell_mask, trap_image): def median(cell_mask, trap_image) -> int:
""" """
Finds the median of the pixels in the cell. Finds the median of the pixels in the cell.
...@@ -56,10 +60,10 @@ def median(cell_mask, trap_image): ...@@ -56,10 +60,10 @@ def median(cell_mask, trap_image):
Segmentation mask for the cell Segmentation mask for the cell
trap_image: 2d array trap_image: 2d array
""" """
return np.median(trap_image[np.where(cell_mask)]) return bn.nanmedian(trap_image[cell_mask])
def max2p5pc(cell_mask, trap_image): def max2p5pc(cell_mask, trap_image) -> float:
""" """
Finds the mean of the brightest 2.5% of pixels in the cell. Finds the mean of the brightest 2.5% of pixels in the cell.
...@@ -70,18 +74,15 @@ def max2p5pc(cell_mask, trap_image): ...@@ -70,18 +74,15 @@ def max2p5pc(cell_mask, trap_image):
trap_image: 2d array trap_image: 2d array
""" """
# number of pixels in mask # number of pixels in mask
npixels = cell_mask.sum() npixels = bn.nansum(cell_mask)
top_pixels = int(np.ceil(npixels * 0.025)) top_pixels = int(np.ceil(npixels * 0.025))
# sort pixels in cell # sort pixels in cell and find highest 2.5%
sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None) top_values = trap_image[bn.rankdata(trap_image[cell_mask])[:top_pixels]]
# find highest 2.5%
top_vals = sorted_vals[-top_pixels:]
# find mean of these highest pixels # find mean of these highest pixels
max2p5pc = np.mean(top_vals, dtype=float) return bn.nanmean(top_values)
return max2p5pc
def max5px(cell_mask, trap_image): def max5px(cell_mask, trap_image) -> float:
""" """
Finds the mean of the five brightest pixels in the cell. Finds the mean of the five brightest pixels in the cell.
...@@ -92,63 +93,13 @@ def max5px(cell_mask, trap_image): ...@@ -92,63 +93,13 @@ def max5px(cell_mask, trap_image):
trap_image: 2d array trap_image: 2d array
""" """
# sort pixels in cell # sort pixels in cell
sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None) pixels = trap_image[cell_mask]
top_vals = sorted_vals[-5:] top_values = pixels[bn.rankdata(pixels)[:5].astype(int) - 1]
# find mean of five brightest pixels # find mean of five brightest pixels
max5px = np.mean(top_vals, dtype=float) max5px = bn.nanmean(top_values)
return max5px return max5px
def max5px_med(cell_mask, trap_image):
"""
Finds the mean of the five brightest pixels in the cell divided by the median pixel value.
Parameters
----------
cell_mask: 2d array
Segmentation mask for the cell
trap_image: 2d array
"""
# sort pixels in cell
sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None)
top_vals = sorted_vals[-5:]
# find mean of five brightest pixels
max5px = np.mean(top_vals, dtype=float)
# find the median
med = np.median(sorted_vals)
if med == 0:
return np.nan
else:
return max5px / med
def max2p5pc_med(cell_mask, trap_image):
"""
Finds the mean of the brightest 2.5% of pixels in the cell
divided by the median pixel value.
Parameters
----------
cell_mask: 2d array
Segmentation mask for the cell
trap_image: 2d array
"""
# number of pixels in mask
npixels = cell_mask.sum()
top_pixels = int(np.ceil(npixels * 0.025))
# sort pixels in cell
sorted_vals = np.sort(trap_image[np.where(cell_mask)], axis=None)
# find highest 2.5%
top_vals = sorted_vals[-top_pixels:]
# find mean of these highest pixels
max2p5pc = np.mean(top_vals, dtype=float)
med = np.median(sorted_vals)
if med == 0:
return np.nan
else:
return max2p5pc / med
def std(cell_mask, trap_image): def std(cell_mask, trap_image):
""" """
Finds the standard deviation of the values of the pixels in the cell. Finds the standard deviation of the values of the pixels in the cell.
...@@ -159,7 +110,7 @@ def std(cell_mask, trap_image): ...@@ -159,7 +110,7 @@ def std(cell_mask, trap_image):
Segmentation mask for the cell Segmentation mask for the cell
trap_image: 2d array trap_image: 2d array
""" """
return np.std(trap_image[np.where(cell_mask)], dtype=float) return bn.std(trap_image[cell_mask])
def k2_major_median(cell_mask, trap_image): def k2_major_median(cell_mask, trap_image):
...@@ -177,47 +128,23 @@ def k2_major_median(cell_mask, trap_image): ...@@ -177,47 +128,23 @@ def k2_major_median(cell_mask, trap_image):
median: float median: float
The median of the major cluster of two clusters The median of the major cluster of two clusters
""" """
if np.any(cell_mask): if bn.anynan(trap_image):
X = trap_image[np.where(cell_mask)].reshape(-1, 1) cell_mask[np.isnan(trap_image)] = False
# cluster pixels in cell into two clusters X = trap_image[cell_mask].reshape(-1, 1).astype(np.float32)
kmeans = KMeans(n_clusters=2, random_state=0).fit(X) # cluster pixels in cell into two clusters
high_clust_id = kmeans.cluster_centers_.argmax() indices = faiss.IndexFlatL2(X.shape[1])
# find the median of pixels in the largest cluster # (n_clusters=2, random_state=0).fit(X)
major_cluster = X[kmeans.predict(X) == high_clust_id] _, indices = indices.search(X, k=2)
major_median = np.median(major_cluster, axis=None) high_indices = bn.nanargmax(indices, axis=1).astype(bool)
return major_median # find the median of pixels in the largest cluster
else: # high_masks = np.logical_xor( # Use casting to obtain masks
return np.nan # high_indices.reshape(-1, 1), np.tile((0, 1), X.shape[0]).reshape(-1, 2)
# )
def k2_minor_median(cell_mask, trap_image): major_median = bn.nanmedian(X[high_indices])
""" return major_median
Finds the median of the minor cluster after clustering the pixels in the cell into two clusters.
Parameters
----------
cell_mask: 2d array
Segmentation mask for the cell
trap_image: 2d array
Returns def volume(cell_mask) -> float:
-------
median: float
The median of the minor cluster.
"""
if np.any(cell_mask):
X = trap_image[np.where(cell_mask)].reshape(-1, 1)
# cluster pixels in cell into two clusters
kmeans = KMeans(n_clusters=2, random_state=0).fit(X)
low_clust_id = kmeans.cluster_centers_.argmin()
# find the median of pixels in the smallest cluster
minor_cluster = X[kmeans.predict(X) == low_clust_id]
minor_median = np.median(minor_cluster, axis=None)
return minor_median
else:
return np.nan
def volume(cell_mask):
""" """
Estimates the volume of the cell assuming it is an ellipsoid with the mask providing a cross-section through the median plane of the ellipsoid. Estimates the volume of the cell assuming it is an ellipsoid with the mask providing a cross-section through the median plane of the ellipsoid.
...@@ -247,20 +174,20 @@ def conical_volume(cell_mask): ...@@ -247,20 +174,20 @@ def conical_volume(cell_mask):
def spherical_volume(cell_mask): def spherical_volume(cell_mask):
''' """
Estimates the volume of the cell assuming it is a sphere with the mask providing a cross-section through the median plane of the sphere. Estimates the volume of the cell assuming it is a sphere with the mask providing a cross-section through the median plane of the sphere.
Parameters Parameters
---------- ----------
cell_mask: 2d array cell_mask: 2d array
Segmentation mask for the cell Segmentation mask for the cell
''' """
area = cell_mask.sum() total_area = area(cell_mask)
r = np.sqrt(area / np.pi) r = math.sqrt(total_area / np.pi)
return (4 * np.pi * r**3) / 3 return (4 * np.pi * r**3) / 3
def min_maj_approximation(cell_mask): def min_maj_approximation(cell_mask) -> t.Tuple[int]:
""" """
Finds the lengths of the minor and major axes of an ellipse from a cell mask. Finds the lengths of the minor and major axes of an ellipse from a cell mask.
...@@ -278,8 +205,8 @@ def min_maj_approximation(cell_mask): ...@@ -278,8 +205,8 @@ def min_maj_approximation(cell_mask):
# get the size of the top of the cone (points that are equally maximal) # get the size of the top of the cone (points that are equally maximal)
cone_top = ndimage.morphology.distance_transform_edt(dn == 0) * padded cone_top = ndimage.morphology.distance_transform_edt(dn == 0) * padded
# minor axis = largest distance from the edge of the ellipse # minor axis = largest distance from the edge of the ellipse
min_ax = np.round(nn.max()) min_ax = np.round(bn.nanmax(nn))
# major axis = largest distance from the cone top # major axis = largest distance from the cone top
# + distance from the center of cone top to edge of cone top # + distance from the center of cone top to edge of cone top
maj_ax = np.round(dn.max() + cone_top.sum() / 2) maj_ax = np.round(bn.nanmax(dn) + bn.nansum(cone_top) / 2)
return min_ax, maj_ax return min_ax, maj_ax
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment