diff --git a/src/wela/dataloader.py b/src/wela/dataloader.py index 9ef2ff06a799be1d845d1f05f86e4490b2dd8dc3..4a66f512e6f2a42cac7e23ccb36e1e3a9f5ec6d8 100644 --- a/src/wela/dataloader.py +++ b/src/wela/dataloader.py @@ -600,6 +600,7 @@ class dataloader: duration_threshold=None, tmin=None, tmax=None, + group=None, ): """ Find a sub data frame of dataloader's main data frame. @@ -617,8 +618,13 @@ class dataloader: Only include data for times greater than tmin tmax: float (optional) Only include data for times less than tmax + group: str (optional) + Group to specialise to. """ - sdf = self.df + if group is None: + sdf = self.df + else: + sdf = self.df[self.df.group == group] selected_ids = [] # drop signals that are all NaN if dropna and signal: diff --git a/src/wela/plotting.py b/src/wela/plotting.py index d57ec6b20991f83d72ec0cfb37ed7f11319ecd3a..5c95d8ca4ec8a8e6d70b2309fdb3630c7cfdaa5a 100644 --- a/src/wela/plotting.py +++ b/src/wela/plotting.py @@ -1,9 +1,12 @@ +"""Plotting routines to work with dataloader.""" + from copy import copy import matplotlib.cm import matplotlib.pylab as plt import numpy as np import numpy.matlib +from scipy.stats import binned_statistic try: from sklearn.preprocessing import StandardScaler @@ -289,7 +292,6 @@ def plot_lineage( raise Exception("idx not part of dataframe") if isinstance(signals, str): signals = [signals] - nosubplots = len(signals) # show buddings if possible if "buddings" in df.columns: buddings = df[df.id == idx]["buddings"].to_numpy() @@ -301,38 +303,34 @@ def plot_lineage( # find time t = df[df.id == idx]["time"].to_numpy() # generate figure - fig = plt.figure(figsize=figsize) + no_subplots = len(signals) + fig, ax = plt.subplots(no_subplots, 1, figsize=figsize, sharex=True) # index for subplot - splt = 1 if title is None: plt.suptitle(idx) else: plt.suptitle(title) # plot signals - for signal in signals: + for i, signal in enumerate(signals): if "bud_" + signal in df.columns: bud_sig = df[df.id == idx]["bud_" + signal].to_numpy() if signal == "growth_rate" and "growth_rate" not in df.columns: signal = "mother_" + signal m_sig = df[df.id == idx][signal].to_numpy() - plt.subplot(nosubplots, 1, splt) for start, end in zip(nb_pts, nb_pts[1:]): # plot bud signal - plt.plot(t[start:end], bud_sig[start:end], ".-") + ax[i].plot(t[start:end], bud_sig[start:end], ".-") # plot mother signal - plt.plot(t, m_sig, "k-") - plt.ylabel(signal) + ax[i].plot(t, m_sig, "k-") + ax[i].set_ylabel(signal) add_shading(shade_times, shade_colour) - plt.grid() - splt += 1 + ax[i].grid(True) else: m_sig = df[df.id == idx][signal].to_numpy() - plt.subplot(nosubplots, 1, splt) - plt.plot(t, m_sig) - plt.ylabel(signal) + ax[i].plot(t, m_sig) + ax[i].set_ylabel(signal) add_shading(shade_times, shade_colour) - plt.grid() - splt += 1 + ax[i].grid(True) # plot buddings if ( plot_budding_pts @@ -341,10 +339,7 @@ def plot_lineage( ): for bpt in b_pts: plt.plot(t[bpt], m_sig[bpt], "k.") - plt.xlabel("time (hours)") - # share the x axis - ax_list = fig.axes - ax_list[0].get_shared_x_axes().join(ax_list[0], *ax_list) + ax[-1].set_xlabel("time (hours)") if show: plt.show(block=False) @@ -762,3 +757,67 @@ def get_bud_to_bud_data( local_signals.append(local_data) local_times.append(t[start_tpt_i : end_tpt_i + 1]) return local_signals, local_times + + +def plot_binned_mean(df, x_signal, y_signal, bins=10, groups=None, fmt="o-"): + """ + Plot the mean of y_signal found for bins of x_signal against x_signal. + + Use scipy's binned_statistic. + + Parameters + ---------- + df: pd.DataFrame + Dataframe with the data, typically dl.df. + x_signal: str + Name of the signal to bin and plot on the x-axis. + y_signal: str + Name of the signal to be averaged in bins of x_signal. + bins: int + Number of bins. + groups: list of str (optional) + Specific groups to plot. + fmt: str (optional) + Formatting for points and lines, passed to plt.errorbar. + + Example + ------- + >>> plot_binned_mean(dl.df, "median_GFP", "bud_growth_rate", + bins=10, groups=["2pc_raf", "2pc_glc"]) + """ + stats_dict = {} + if groups is None: + groups = df.group.unique() + for group in groups: + sdf = df[df.group == group][[x_signal, y_signal]].dropna() + stats = ["mean", "median", "std", "count"] + for stat in stats: + stats_dict[f"{stat}_{group}"], bin_edges, _ = binned_statistic( + sdf.median_GFP.values, + values=sdf.bud_growth_rate.values, + statistic=stat, + bins=bins, + ) + stats_dict[f"stderr_{group}"] = stats_dict[f"std_{group}"] / np.sqrt( + stats_dict[f"count_{group}"] + ) + stats_dict[f"bin_midpoints_{group}"] = np.array( + [ + np.mean([bin_edges[i], bin_edges[i + 1]]) + for i in range(len(bin_edges) - 1) + ] + ) + # plot using errorbar + plt.figure() + for group in groups: + plt.errorbar( + stats_dict[f"bin_midpoints_{group}"], + stats_dict[f"mean_{group}"], + yerr=stats_dict[f"stderr_{group}"], + fmt=fmt, + label=group, + ) + plt.xlabel(x_signal.replace("_", " ")) + plt.ylabel(y_signal.replace("_", " ")) + plt.legend() + plt.show(block=False)