From 84db01b6557fa73fd40682361f42e0c85dfce7b7 Mon Sep 17 00:00:00 2001
From: Peter Swain <peter.swain@ed.ac.uk>
Date: Wed, 7 Aug 2024 10:59:42 +0100
Subject: [PATCH] change(plotting): added plot_binned_mean

---
 src/wela/plotting.py | 67 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)

diff --git a/src/wela/plotting.py b/src/wela/plotting.py
index 64d04c8..5c95d8c 100644
--- a/src/wela/plotting.py
+++ b/src/wela/plotting.py
@@ -1,9 +1,12 @@
+"""Plotting routines to work with dataloader."""
+
 from copy import copy
 
 import matplotlib.cm
 import matplotlib.pylab as plt
 import numpy as np
 import numpy.matlib
+from scipy.stats import binned_statistic
 
 try:
     from sklearn.preprocessing import StandardScaler
@@ -754,3 +757,67 @@ def get_bud_to_bud_data(
                     local_signals.append(local_data)
                     local_times.append(t[start_tpt_i : end_tpt_i + 1])
     return local_signals, local_times
+
+
+def plot_binned_mean(df, x_signal, y_signal, bins=10, groups=None, fmt="o-"):
+    """
+    Plot the mean of y_signal found for bins of x_signal against x_signal.
+
+    Use scipy's binned_statistic. Error bars show the standard error of the
+    mean in each bin; rows with a NaN in either signal are dropped first.
+
+    Parameters
+    ----------
+    df: pd.DataFrame
+        Dataframe with the data, typically dl.df. Must have a "group" column.
+    x_signal: str
+        Name of the signal to bin and plot on the x-axis.
+    y_signal: str
+        Name of the signal to be averaged in bins of x_signal.
+    bins: int
+        Number of bins.
+    groups: list of str (optional)
+        Specific groups to plot. If None, all groups in df are plotted.
+    fmt: str (optional)
+        Formatting for points and lines, passed to plt.errorbar.
+
+    Example
+    -------
+    >>> plot_binned_mean(dl.df, "median_GFP", "bud_growth_rate",
+            bins=10, groups=["2pc_raf", "2pc_glc"])
+    """
+    stats_dict = {}
+    if groups is None:
+        groups = df.group.unique()
+    for group in groups:
+        sdf = df[df.group == group][[x_signal, y_signal]].dropna()
+        stats = ["mean", "median", "std", "count"]
+        for stat in stats:
+            # index by the requested column names, not hard-coded signals
+            stats_dict[f"{stat}_{group}"], bin_edges, _ = binned_statistic(
+                sdf[x_signal].values,
+                values=sdf[y_signal].values,
+                statistic=stat,
+                bins=bins,
+            )
+        stats_dict[f"stderr_{group}"] = stats_dict[f"std_{group}"] / np.sqrt(
+            stats_dict[f"count_{group}"]
+        )
+        # bin_edges is identical for every statistic (same x and bins), so
+        # the last one is reused; centres are the mean of adjacent edges
+        midpoints = (bin_edges[:-1] + bin_edges[1:]) / 2
+        stats_dict[f"bin_midpoints_{group}"] = midpoints
+    # plot using errorbar
+    plt.figure()
+    for group in groups:
+        plt.errorbar(
+            stats_dict[f"bin_midpoints_{group}"],
+            stats_dict[f"mean_{group}"],
+            yerr=stats_dict[f"stderr_{group}"],
+            fmt=fmt,
+            label=group,
+        )
+    plt.xlabel(x_signal.replace("_", " "))
+    plt.ylabel(y_signal.replace("_", " "))
+    plt.legend()
+    plt.show(block=False)
-- 
GitLab