From 977ffc4c119c08bf296f000f10e5b76e62ef6d97 Mon Sep 17 00:00:00 2001
From: pswain <peter.swain@ed.ac.uk>
Date: Fri, 17 May 2024 14:11:05 +0100
Subject: [PATCH] fix(kymograph): error in using sort_order

---
 src/wela/plotting.py |  4 ++--
 src/wela/sorting.py  | 21 +++++++++++----------
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/src/wela/plotting.py b/src/wela/plotting.py
index 4e47b06..f7fa0c2 100644
--- a/src/wela/plotting.py
+++ b/src/wela/plotting.py
@@ -105,7 +105,7 @@ def kymograph(
     dt = np.min(np.diff(np.sort(df.time.unique())))
     data = wdf.to_numpy()
     if sort_order is not None:
-        data = data[np.argsort(sort_order), :]
+        data = data[sort_order, :]
     if filterfunc is not None:
         data = filterfunc(data)
     if standardscale:
@@ -146,7 +146,7 @@ def kymograph(
     if buddings:
         buddings = df.pivot(index=y, columns=x, values="buddings").to_numpy()
         if sort_order is not None:
-            buddings = buddings[np.argsort(sort_order), :]
+            buddings = buddings[sort_order, :]
         bud_mask = np.ma.masked_where(buddings == 0, buddings)
         ax.imshow(bud_mask, interpolation="none")
     ax.figure.colorbar(
diff --git a/src/wela/sorting.py b/src/wela/sorting.py
index 31616c1..11694aa 100644
--- a/src/wela/sorting.py
+++ b/src/wela/sorting.py
@@ -11,20 +11,21 @@ def sort_by_budding(buddings, bud_number=0):
             sort_order.append(cell_buds[bud_number][1])
         else:
             sort_order.append(np.nan)
-    sort_order = np.array(sort_order)
+    sort_order = np.argsort(sort_order)
     return sort_order
 
 
-def sort_by_maximum(data, byfunction=None):
-    """Return indices of cells sorted by largest value."""
-    sort_order = np.nan * np.ones(data.shape[0])
-    sdata = data[not_all_nan(data), :]
-    if byfunction is not None:
-        sorted_indices = np.argsort(byfunction(sdata, axis=1))
+def sort_by_maximum(data, byfunction, reverse=False):
+    """
+    Return indices of cells sorted by largest value.
+
+    Any cells that have all NaN values are placed last.
+    """
+    sort_order = np.argsort(byfunction(data, axis=1))
+    if reverse:
+        return sort_order[::-1]
     else:
-        sorted_indices = np.nanargmax(sdata, axis=1)
-    sort_order[not_all_nan(data)] = sorted_indices
-    return sort_order
+        return sort_order
 
 
 def not_all_nan(data):
-- 
GitLab