From 551acc7431866d260aa901be1490df9a1cc79b4d Mon Sep 17 00:00:00 2001
From: Peter Swain <pswain@Home-iMac.local>
Date: Tue, 21 May 2024 10:25:31 +0100
Subject: [PATCH] feature(bud_to_bud_plot): more than one bud event

---
 src/wela/plotting.py | 61 ++++++++++++++++++++++++++++++++------------
 1 file changed, 44 insertions(+), 17 deletions(-)

diff --git a/src/wela/plotting.py b/src/wela/plotting.py
index 39ce895..8595309 100644
--- a/src/wela/plotting.py
+++ b/src/wela/plotting.py
@@ -610,6 +610,7 @@ def bud_to_bud_plot(
     colour="b",
     group=None,
     nbins=None,
+    no_future_buddings=2,
     return_signal=False,
     df=None,
     title=None,
@@ -636,9 +637,11 @@ def bud_to_bud_plot(
         Colour of lines.
     group: str, optional
         The name of the group to plot.
-    nbins: integer, optional
+    nbins: int, optional
         The number of time bins to partition the interval between the
         first and the second budding event.
+    no_future_buddings: int, optional
+        The number of future budding events to include. Default is 1.
     return_signal: boolean, optional
         If True, return the signal for each cell interpolated to the time
         bins.
@@ -673,29 +676,24 @@ def bud_to_bud_plot(
         t = t / 60
     tpt_i = np.argmin((t - tpt) ** 2)
     # get data for bud-to-bud around tpt for each cell
-    local_signals, local_times = [], []
-    for i in range(signal_data.shape[0]):
-        future_buddings = np.nonzero(buddings[i, :][tpt_i:])[0]
-        if np.any(future_buddings):
-            end_tpt_i = tpt_i + future_buddings[0]
-            past_buddings = np.nonzero(buddings[i, :][:tpt_i])[0]
-            if np.any(past_buddings):
-                start_tpt_i = past_buddings[-1]
-                local_data = signal_data[i, start_tpt_i : end_tpt_i + 1]
-                if ~np.all(np.isnan(local_data)):
-                    local_signals.append(local_data)
-                    local_times.append(t[start_tpt_i : end_tpt_i + 1])
+    local_signals, local_times = get_bud_to_bud_data(
+        tpt_i,
+        t,
+        signal_data,
+        buddings,
+        no_future_buddings_index=no_future_buddings - 1,
+    )
     if local_times:
-        # find bins for normalised time, between 0 and 1
+        # find bins for normalised time, between 0 and no_future_buddings
         nbins = int(np.median([len(local_time) for local_time in local_times]))
-        ntbins = np.linspace(0, 1, nbins)
+        ntbins = np.linspace(0, no_future_buddings, nbins)
         # interpolate each local signal to make a new signal
         new_signal = np.nan * np.ones((len(local_signals), nbins))
         for i in range(len(local_signals)):
             s = local_signals[i]
-            # normalise time between 0 and 1
+            # normalise time between 0 and no_future_buddings
             nt = local_times[i] - local_times[i][0]
-            nt /= nt[-1]
+            nt = nt / nt[-1] * no_future_buddings
             # interpolate into the bins
             new_signal[i, :] = np.interp(
                 ntbins,
@@ -727,3 +725,32 @@ def bud_to_bud_plot(
             plt.show(block=False)
         if return_signal:
             return new_signal
+
+
+def get_bud_to_bud_data(
+    tpt_i, t, signal_data, buddings, no_future_buddings_index=0
+):
+    """
+    Get data for bud-to-bud around tpt for each cell.
+
+    Parameters
+    ----------
+    no_future_buddings_index: int
+        The index to select future buddings, with 0 representing the next
+        budding event.
+    """
+    local_signals, local_times = [], []
+    for i in range(signal_data.shape[0]):
+        future_buddings = np.nonzero(buddings[i, :][tpt_i:])[0]
+        if np.any(future_buddings) and (
+            future_buddings.size > no_future_buddings_index
+        ):
+            end_tpt_i = tpt_i + future_buddings[no_future_buddings_index]
+            past_buddings = np.nonzero(buddings[i, :][:tpt_i])[0]
+            if np.any(past_buddings):
+                start_tpt_i = past_buddings[-1]
+                local_data = signal_data[i, start_tpt_i : end_tpt_i + 1]
+                if ~np.all(np.isnan(local_data)):
+                    local_signals.append(local_data)
+                    local_times.append(t[start_tpt_i : end_tpt_i + 1])
+    return local_signals, local_times
-- 
GitLab