From 977ffc4c119c08bf296f000f10e5b76e62ef6d97 Mon Sep 17 00:00:00 2001 From: pswain <peter.swain@ed.ac.uk> Date: Fri, 17 May 2024 14:11:05 +0100 Subject: [PATCH] fix(kymograph): error in using sort_order --- src/wela/plotting.py | 4 ++-- src/wela/sorting.py | 21 +++++++++++---------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/wela/plotting.py b/src/wela/plotting.py index 4e47b06..f7fa0c2 100644 --- a/src/wela/plotting.py +++ b/src/wela/plotting.py @@ -105,7 +105,7 @@ def kymograph( dt = np.min(np.diff(np.sort(df.time.unique()))) data = wdf.to_numpy() if sort_order is not None: - data = data[np.argsort(sort_order), :] + data = data[sort_order, :] if filterfunc is not None: data = filterfunc(data) if standardscale: @@ -146,7 +146,7 @@ def kymograph( if buddings: buddings = df.pivot(index=y, columns=x, values="buddings").to_numpy() if sort_order is not None: - buddings = buddings[np.argsort(sort_order), :] + buddings = buddings[sort_order, :] bud_mask = np.ma.masked_where(buddings == 0, buddings) ax.imshow(bud_mask, interpolation="none") ax.figure.colorbar( diff --git a/src/wela/sorting.py b/src/wela/sorting.py index 31616c1..11694aa 100644 --- a/src/wela/sorting.py +++ b/src/wela/sorting.py @@ -11,20 +11,21 @@ def sort_by_budding(buddings, bud_number=0): sort_order.append(cell_buds[bud_number][1]) else: sort_order.append(np.nan) - sort_order = np.array(sort_order) + sort_order = np.argsort(sort_order) return sort_order -def sort_by_maximum(data, byfunction=None): - """Return indices of cells sorted by largest value.""" - sort_order = np.nan * np.ones(data.shape[0]) - sdata = data[not_all_nan(data), :] - if byfunction is not None: - sorted_indices = np.argsort(byfunction(sdata, axis=1)) +def sort_by_maximum(data, byfunction, reverse=False): + """ + Return indices of cells sorted by largest value. + + Any cells that have all NaN values are placed last. + """ + sort_order = np.argsort(byfunction(data, axis=1)) + if reverse: + return sort_order[::-1] else: - sorted_indices = np.nanargmax(sdata, axis=1) - sort_order[not_all_nan(data)] = sorted_indices - return sort_order + return sort_order def not_all_nan(data): -- GitLab