diff --git a/src/wela/plotting.py b/src/wela/plotting.py
index 56a412e168ebf92e3c4cb2d2b2074e68d3d64e1a..89fb108b006c9c55b8f9ba542d7df492164bb254 100644
--- a/src/wela/plotting.py
+++ b/src/wela/plotting.py
@@ -289,7 +289,6 @@ def plot_lineage(
         raise Exception("idx not part of dataframe")
     if isinstance(signals, str):
         signals = [signals]
-    nosubplots = len(signals)
     # show buddings if possible
     if "buddings" in df.columns:
         buddings = df[df.id == idx]["buddings"].to_numpy()
@@ -301,38 +300,34 @@ def plot_lineage(
     # find time
     t = df[df.id == idx]["time"].to_numpy()
     # generate figure
-    fig = plt.figure(figsize=figsize)
+    no_subplots = len(signals)
+    fig, ax = plt.subplots(no_subplots, 1, figsize=figsize, sharex=True)
     # index for subplot
-    splt = 1
     if title is None:
         plt.suptitle(idx)
     else:
         plt.suptitle(title)
     # plot signals
-    for signal in signals:
+    for i, signal in enumerate(signals):
         if "bud_" + signal in df.columns:
             bud_sig = df[df.id == idx]["bud_" + signal].to_numpy()
             if signal == "growth_rate" and "growth_rate" not in df.columns:
                 signal = "mother_" + signal
             m_sig = df[df.id == idx][signal].to_numpy()
-            plt.subplot(nosubplots, 1, splt)
             for start, end in zip(nb_pts, nb_pts[1:]):
                 # plot bud signal
-                plt.plot(t[start:end], bud_sig[start:end], ".-")
+                ax[i].plot(t[start:end], bud_sig[start:end], ".-")
             # plot mother signal
-            plt.plot(t, m_sig, "k-")
-            plt.ylabel(signal)
+            ax[i].plot(t, m_sig, "k-")
+            ax[i].set_ylabel(signal)
             add_shading(shade_times, shade_colour)
-            plt.grid()
-            splt += 1
+            ax[i].grid(True)
         else:
             m_sig = df[df.id == idx][signal].to_numpy()
-            plt.subplot(nosubplots, 1, splt)
-            plt.plot(t, m_sig)
-            plt.ylabel(signal)
+            ax[i].plot(t, m_sig)
+            ax[i].set_ylabel(signal)
             add_shading(shade_times, shade_colour)
-            plt.grid()
-            splt += 1
+            ax[i].grid(True)
         # plot buddings
         if (
             plot_budding_pts
@@ -341,10 +336,7 @@ def plot_lineage(
         ):
             for bpt in b_pts:
                 plt.plot(t[bpt], m_sig[bpt], "k.")
-    plt.xlabel("time (hours)")
-    # share the x axis
-    ax_list = fig.axes
-    ax_list[0].get_shared_x_axes().join(ax_list[0], *ax_list)
+    ax[-1].set_xlabel("time (hours)")
     if show:
         plt.show(block=False)