def plot_binned_mean(df, x_signal, y_signal, bins=10, groups=None, fmt="o-"):
    """
    Plot the mean of y_signal found for bins of x_signal against x_signal.

    Use scipy's binned_statistic. For each group, the rows with a valid
    value for both signals are binned by x_signal, and the mean of
    y_signal in each bin is plotted at the bin's midpoint with the
    standard error of the mean (std / sqrt(count)) as the error bar.

    Parameters
    ----------
    df: pd.DataFrame
        Dataframe with the data, typically dl.df. Must have a "group"
        column as well as the x_signal and y_signal columns.
    x_signal: str
        Name of the signal to bin and plot on the x-axis.
    y_signal: str
        Name of the signal to be averaged in bins of x_signal.
    bins: int or sequence of scalars
        Number of equal-width bins, or the bin edges; passed unchanged
        to scipy's binned_statistic.
    groups: list of str (optional)
        Specific groups to plot. If None, all groups in df are plotted.
    fmt: str (optional)
        Formatting for points and lines, passed to plt.errorbar.

    Example
    -------
    >>> plot_binned_mean(dl.df, "median_GFP", "bud_growth_rate",
    ...                  bins=10, groups=["2pc_raf", "2pc_glc"])
    """
    import warnings

    if groups is None:
        groups = df.group.unique()
    curves = []
    for group in groups:
        sdf = df[df.group == group][[x_signal, y_signal]].dropna()
        if sdf.empty:
            # binned_statistic cannot find a bin range for an empty array
            warnings.warn(
                f"plot_binned_mean: no data for group {group!r}; skipping."
            )
            continue
        # use the requested signals rather than hard-coded column names
        x = sdf[x_signal].values
        y = sdf[y_signal].values
        binned = {}
        for stat in ("mean", "std", "count"):
            binned[stat], bin_edges, _ = binned_statistic(
                x, values=y, statistic=stat, bins=bins
            )
        # empty bins have count 0 and std NaN: leave the error bar as NaN
        # without emitting a RuntimeWarning
        with np.errstate(divide="ignore", invalid="ignore"):
            stderr = binned["std"] / np.sqrt(binned["count"])
        bin_midpoints = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        curves.append((group, bin_midpoints, binned["mean"], stderr))
    # plot using errorbar
    plt.figure()
    for group, bin_midpoints, mean, stderr in curves:
        plt.errorbar(bin_midpoints, mean, yerr=stderr, fmt=fmt, label=group)
    plt.xlabel(x_signal.replace("_", " "))
    plt.ylabel(y_signal.replace("_", " "))
    plt.legend()
    plt.show(block=False)