From cf8d266122c1cd463b42c37ef6cd336a7bdfcacf Mon Sep 17 00:00:00 2001
From: pswain <peter.swain@ed.ac.uk>
Date: Tue, 7 Jun 2022 17:15:01 +0100
Subject: [PATCH] tidied up autocrosscorr; removed plotting

---
 time_series_analysis.py | 114 +++++-----------------------------------
 1 file changed, 13 insertions(+), 101 deletions(-)

diff --git a/time_series_analysis.py b/time_series_analysis.py
index 8ecdea0..5bf9da6 100644
--- a/time_series_analysis.py
+++ b/time_series_analysis.py
@@ -1,11 +1,9 @@
 import numpy as np
-import matplotlib.pylab as plt
 
 
 def autocrosscorr(
     yA,
     yB=None,
-    connected=True,
     normalised=True,
 ):
     """
@@ -22,14 +20,9 @@ def autocrosscorr(
     yB: array (required for cross-correlation only)
         An array of signal values, with each row a replicate measurement
         and each column a time point.
-    connected: boolean
-        If True, find the connected correlation function, which measures the
-        correlation once the population mean has been substracted.
     normalised: boolean
-        If True and connected is True, normalise each time point by the
-        standard deviation across the replicates.
-        If True and connected is False, normalise each time by the root mean
-        square across replicates.
+        If True normalise each time point by the standard deviation across
+        the replicates.
 
     Returns
     -------
@@ -40,21 +33,16 @@ def autocrosscorr(
     # number of replicates
     nr = yA.shape[0]
     # number of time points
-    n = yA.shape[1]
-    # deviation from mean at each time point
-    if connected:
-        dyA = yA - np.nanmean(yA, axis=0).reshape((1, n))
-    else:
-        dyA = yA
+    nt = yA.shape[1]
+    # deviation from the mean, where the mean is calculated over replicates
+    # at each time point, which allows for non-stationary behaviour
+    dyA = yA - np.nanmean(yA, axis=0).reshape((1, nt))
     # standard deviation over time for each replicate
     stdA = np.sqrt(np.nanmean(dyA**2, axis=1).reshape((nr, 1)))
     if np.any(yB):
         # cross correlation
-        if connected:
-            dyB = yB - np.nanmean(yB, axis=0).reshape((1, n))
-        else:
-            dyB = yB
-        stdB = np.sqrt(np.nanmean(dyB**2, axis=0).reshape((1, n)))
+        dyB = yB - np.nanmean(yB, axis=1).reshape((1, nt))
+        stdB = np.sqrt(np.nanmean(dyB**2, axis=1).reshape((nr, 1)))
     else:
         # auto correlation
         yB = yA
@@ -63,89 +51,13 @@ def autocrosscorr(
     # calculate correlation
     corr = np.nan * np.ones(yA.shape)
     # lag r runs over time points
-    for r in np.arange(0, n):
-        prods = [dyA[:, t] * dyB[:, t + r] for t in range(n - r)]
-        corr[:, r] = np.nansum(prods, axis=0) / (n - r)
+    for r in np.arange(0, nt):
+        prods = [dyA[:, t] * dyB[:, t + r] for t in range(nt - r)]
+        corr[:, r] = np.nansum(prods, axis=0) / (nt - r)
     if normalised:
-        if connected:
-            # normalise by standard deviation
-            corr = corr / stdA / stdB
-        else:
-            # normalise by root mean square
-            corr = corr / np.sqrt(np.nanmean(yA**2, axis=1).reshape((nr, 1))
-                                  * np.nanmean(yB**2, axis=1).reshape((nr, 1)))
+        # normalise by standard deviation
+        corr = corr / stdA / stdB
     return corr
 
 
 ###
-
-
-def plot_replicatearray(
-    data,
-    t=None,
-    plotmean=True,
-    xlabel=None,
-    ylabel=None,
-    title=None,
-    grid=True,
-):
-    """
-    Plots summary statistics versus axis 1 for an array of replicates.
-
-    Parameters
-    ----------
-    data: array
-        An array of signal values, with each row a replicate measurement
-        and each column a time point.
-    t : array (optional)
-        An array of time points.
-    plotmean: boolean
-        If True, plot the mean correlation over replicates versus the lag
-        with the standard error.
-        If False, plot the median correlation and the interquartile range.
-    xlabel: string
-        Label for x-axis.
-    ylabel: string
-        Label for y-axis.
-    title: string
-        Title for plot.
-    grid: boolean
-        If True, draw grid on plot.
-    """
-    # number of time points
-    n = data.shape[1]
-    # number of replicates
-    nr = data.shape[0]
-    if not np.any(t):
-        t = np.arange(n)
-    plt.figure()
-    if plotmean:
-        # mean and standard error
-        plt.plot(t, np.nanmean(data, axis=0), "b-")
-        stderr = np.nanstd(data, axis=0) / np.sqrt(nr)
-        plt.fill_between(
-            t,
-            np.nanmean(data, axis=0) + stderr,
-            np.nanmean(data, axis=0) - stderr,
-            color="b",
-            alpha=0.2,
-        )
-    else:
-        # median and interquartile range
-        plt.plot(t, np.nanmedian(data, axis=0), "b-")
-        plt.fill_between(
-            t,
-            np.nanquantile(data, 0.25, axis=0),
-            np.nanquantile(data, 0.75, axis=0),
-            color="b",
-            alpha=0.2,
-        )
-    if xlabel:
-        plt.xlabel(xlabel)
-    if ylabel:
-        plt.ylabel(ylabel)
-    if title:
-        plt.title(title)
-    if grid:
-        plt.grid()
-    plt.show()
-- 
GitLab