diff --git a/src/wela/plotting.py b/src/wela/plotting.py index 56a412e168ebf92e3c4cb2d2b2074e68d3d64e1a..89fb108b006c9c55b8f9ba542d7df492164bb254 100644 --- a/src/wela/plotting.py +++ b/src/wela/plotting.py @@ -289,7 +289,6 @@ def plot_lineage( raise Exception("idx not part of dataframe") if isinstance(signals, str): signals = [signals] - nosubplots = len(signals) # show buddings if possible if "buddings" in df.columns: buddings = df[df.id == idx]["buddings"].to_numpy() @@ -301,38 +300,34 @@ def plot_lineage( # find time t = df[df.id == idx]["time"].to_numpy() # generate figure - fig = plt.figure(figsize=figsize) + no_subplots = len(signals) + fig, ax = plt.subplots(no_subplots, 1, figsize=figsize, sharex=True) # index for subplot - splt = 1 if title is None: plt.suptitle(idx) else: plt.suptitle(title) # plot signals - for signal in signals: + for i, signal in enumerate(signals): if "bud_" + signal in df.columns: bud_sig = df[df.id == idx]["bud_" + signal].to_numpy() if signal == "growth_rate" and "growth_rate" not in df.columns: signal = "mother_" + signal m_sig = df[df.id == idx][signal].to_numpy() - plt.subplot(nosubplots, 1, splt) for start, end in zip(nb_pts, nb_pts[1:]): # plot bud signal - plt.plot(t[start:end], bud_sig[start:end], ".-") + ax[i].plot(t[start:end], bud_sig[start:end], ".-") # plot mother signal - plt.plot(t, m_sig, "k-") - plt.ylabel(signal) + ax[i].plot(t, m_sig, "k-") + ax[i].set_ylabel(signal) add_shading(shade_times, shade_colour) - plt.grid() - splt += 1 + ax[i].grid(True) else: m_sig = df[df.id == idx][signal].to_numpy() - plt.subplot(nosubplots, 1, splt) - plt.plot(t, m_sig) - plt.ylabel(signal) + ax[i].plot(t, m_sig) + ax[i].set_ylabel(signal) add_shading(shade_times, shade_colour) - plt.grid() - splt += 1 + ax[i].grid(True) # plot buddings if ( plot_budding_pts @@ -341,10 +336,7 @@ def plot_lineage( ): for bpt in b_pts: plt.plot(t[bpt], m_sig[bpt], "k.") - plt.xlabel("time (hours)") - # share the x axis - ax_list = fig.axes - ax_list[0].get_shared_x_axes().join(ax_list[0], *ax_list) + ax[-1].set_xlabel("time (hours)") if show: plt.show(block=False)