diff --git a/src/wela/dataloader.py b/src/wela/dataloader.py
index 9ef2ff06a799be1d845d1f05f86e4490b2dd8dc3..4a66f512e6f2a42cac7e23ccb36e1e3a9f5ec6d8 100644
--- a/src/wela/dataloader.py
+++ b/src/wela/dataloader.py
@@ -600,6 +600,7 @@ class dataloader:
         duration_threshold=None,
         tmin=None,
         tmax=None,
+        group=None,
     ):
         """
         Find a sub data frame of dataloader's main data frame.
@@ -617,8 +618,13 @@ class dataloader:
             Only include data for times greater than tmin
         tmax: float (optional)
             Only include data for times less than tmax
+        group: str (optional)
+            Group to specialise to.
         """
-        sdf = self.df
+        if group is None:
+            sdf = self.df
+        else:
+            sdf = self.df[self.df.group == group]
         selected_ids = []
         # drop signals that are all NaN
         if dropna and signal:
diff --git a/src/wela/plotting.py b/src/wela/plotting.py
index d57ec6b20991f83d72ec0cfb37ed7f11319ecd3a..5c95d8ca4ec8a8e6d70b2309fdb3630c7cfdaa5a 100644
--- a/src/wela/plotting.py
+++ b/src/wela/plotting.py
@@ -1,9 +1,12 @@
+"""Plotting routines to work with dataloader."""
+
 from copy import copy
 
 import matplotlib.cm
 import matplotlib.pylab as plt
 import numpy as np
 import numpy.matlib
+from scipy.stats import binned_statistic
 
 try:
     from sklearn.preprocessing import StandardScaler
@@ -289,7 +292,6 @@ def plot_lineage(
         raise Exception("idx not part of dataframe")
     if isinstance(signals, str):
         signals = [signals]
-    nosubplots = len(signals)
     # show buddings if possible
     if "buddings" in df.columns:
         buddings = df[df.id == idx]["buddings"].to_numpy()
@@ -301,38 +303,34 @@ def plot_lineage(
     # find time
     t = df[df.id == idx]["time"].to_numpy()
     # generate figure
-    fig = plt.figure(figsize=figsize)
+    no_subplots = len(signals)
+    fig, ax = plt.subplots(no_subplots, 1, figsize=figsize, sharex=True)
     # index for subplot
-    splt = 1
     if title is None:
         plt.suptitle(idx)
     else:
         plt.suptitle(title)
     # plot signals
-    for signal in signals:
+    for i, signal in enumerate(signals):
         if "bud_" + signal in df.columns:
             bud_sig = df[df.id == idx]["bud_" + signal].to_numpy()
             if signal == "growth_rate" and "growth_rate" not in df.columns:
                 signal = "mother_" + signal
             m_sig = df[df.id == idx][signal].to_numpy()
-            plt.subplot(nosubplots, 1, splt)
             for start, end in zip(nb_pts, nb_pts[1:]):
                 # plot bud signal
-                plt.plot(t[start:end], bud_sig[start:end], ".-")
+                ax[i].plot(t[start:end], bud_sig[start:end], ".-")
             # plot mother signal
-            plt.plot(t, m_sig, "k-")
-            plt.ylabel(signal)
+            ax[i].plot(t, m_sig, "k-")
+            ax[i].set_ylabel(signal)
             add_shading(shade_times, shade_colour)
-            plt.grid()
-            splt += 1
+            ax[i].grid(True)
         else:
             m_sig = df[df.id == idx][signal].to_numpy()
-            plt.subplot(nosubplots, 1, splt)
-            plt.plot(t, m_sig)
-            plt.ylabel(signal)
+            ax[i].plot(t, m_sig)
+            ax[i].set_ylabel(signal)
             add_shading(shade_times, shade_colour)
-            plt.grid()
-            splt += 1
+            ax[i].grid(True)
         # plot buddings
         if (
             plot_budding_pts
@@ -341,10 +339,7 @@ def plot_lineage(
         ):
             for bpt in b_pts:
                 plt.plot(t[bpt], m_sig[bpt], "k.")
-    plt.xlabel("time (hours)")
-    # share the x axis
-    ax_list = fig.axes
-    ax_list[0].get_shared_x_axes().join(ax_list[0], *ax_list)
+    ax[-1].set_xlabel("time (hours)")
     if show:
         plt.show(block=False)
 
@@ -762,3 +757,67 @@ def get_bud_to_bud_data(
                     local_signals.append(local_data)
                     local_times.append(t[start_tpt_i : end_tpt_i + 1])
     return local_signals, local_times
+
+
+def plot_binned_mean(df, x_signal, y_signal, bins=10, groups=None, fmt="o-"):
+    """
+    Plot the mean of y_signal found for bins of x_signal against x_signal.
+
+    Use scipy's binned_statistic.
+
+    Parameters
+    ----------
+    df: pd.DataFrame
+        Dataframe with the data, typically dl.df.
+    x_signal: str
+        Name of the signal to bin and plot on the x-axis.
+    y_signal: str
+        Name of the signal to be averaged in bins of x_signal.
+    bins: int
+        Number of bins.
+    groups: list of str (optional)
+        Specific groups to plot.
+    fmt: str (optional)
+        Formatting for points and lines, passed to plt.errorbar.
+
+    Example
+    -------
+    >>> plot_binned_mean(dl.df, "median_GFP", "bud_growth_rate",
+            bins=10, groups=["2pc_raf", "2pc_glc"])
+    """
+    stats_dict = {}
+    if groups is None:
+        groups = df.group.unique()
+    for group in groups:
+        sdf = df[df.group == group][[x_signal, y_signal]].dropna()
+        stats = ["mean", "median", "std", "count"]
+        for stat in stats:
+            stats_dict[f"{stat}_{group}"], bin_edges, _ = binned_statistic(
+                sdf.median_GFP.values,
+                values=sdf.bud_growth_rate.values,
+                statistic=stat,
+                bins=bins,
+            )
+        stats_dict[f"stderr_{group}"] = stats_dict[f"std_{group}"] / np.sqrt(
+            stats_dict[f"count_{group}"]
+        )
+        stats_dict[f"bin_midpoints_{group}"] = np.array(
+            [
+                np.mean([bin_edges[i], bin_edges[i + 1]])
+                for i in range(len(bin_edges) - 1)
+            ]
+        )
+    # plot using errorbar
+    plt.figure()
+    for group in groups:
+        plt.errorbar(
+            stats_dict[f"bin_midpoints_{group}"],
+            stats_dict[f"mean_{group}"],
+            yerr=stats_dict[f"stderr_{group}"],
+            fmt=fmt,
+            label=group,
+        )
+    plt.xlabel(x_signal.replace("_", " "))
+    plt.ylabel(y_signal.replace("_", " "))
+    plt.legend()
+    plt.show(block=False)