diff --git a/src/wela/plotting.py b/src/wela/plotting.py
index 4e47b063517ecd067e5c66451e1085de55a3ee4b..f7fa0c246be62c41bf62f906ebc36c4ae437a194 100644
--- a/src/wela/plotting.py
+++ b/src/wela/plotting.py
@@ -105,7 +105,7 @@ def kymograph(
     dt = np.min(np.diff(np.sort(df.time.unique())))
     data = wdf.to_numpy()
     if sort_order is not None:
-        data = data[np.argsort(sort_order), :]
+        data = data[sort_order, :]
     if filterfunc is not None:
         data = filterfunc(data)
     if standardscale:
@@ -146,7 +146,7 @@ def kymograph(
     if buddings:
         buddings = df.pivot(index=y, columns=x, values="buddings").to_numpy()
         if sort_order is not None:
-            buddings = buddings[np.argsort(sort_order), :]
+            buddings = buddings[sort_order, :]
         bud_mask = np.ma.masked_where(buddings == 0, buddings)
         ax.imshow(bud_mask, interpolation="none")
     ax.figure.colorbar(
diff --git a/src/wela/sorting.py b/src/wela/sorting.py
index 31616c1dcee42f32e932a7fe7dca7a0a729d9a7c..11694aacf67d2c927fc38528c6c68e0bfd1ad275 100644
--- a/src/wela/sorting.py
+++ b/src/wela/sorting.py
@@ -11,20 +11,21 @@ def sort_by_budding(buddings, bud_number=0):
             sort_order.append(cell_buds[bud_number][1])
         else:
             sort_order.append(np.nan)
-    sort_order = np.array(sort_order)
+    sort_order = np.argsort(sort_order)
     return sort_order
 
 
-def sort_by_maximum(data, byfunction=None):
-    """Return indices of cells sorted by largest value."""
-    sort_order = np.nan * np.ones(data.shape[0])
-    sdata = data[not_all_nan(data), :]
-    if byfunction is not None:
-        sorted_indices = np.argsort(byfunction(sdata, axis=1))
+def sort_by_maximum(data, byfunction, reverse=False):
+    """
+    Return indices of cells sorted by largest value.
+
+    Any cells that have all NaN values are placed last.
+    """
+    sort_order = np.argsort(byfunction(data, axis=1))
+    if reverse:
+        return sort_order[::-1]
     else:
-        sorted_indices = np.nanargmax(sdata, axis=1)
-    sort_order[not_all_nan(data)] = sorted_indices
-    return sort_order
+        return sort_order
 
 
 def not_all_nan(data):