Skip to content
Snippets Groups Projects
Commit cf8d2661 authored by pswain's avatar pswain
Browse files

tidied up autocrosscorr; removed plotting

parent d8c5e2db
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
import matplotlib.pylab as plt
def autocrosscorr( def autocrosscorr(
yA, yA,
yB=None, yB=None,
connected=True,
normalised=True, normalised=True,
): ):
""" """
...@@ -22,14 +20,9 @@ def autocrosscorr( ...@@ -22,14 +20,9 @@ def autocrosscorr(
yB: array (required for cross-correlation only) yB: array (required for cross-correlation only)
An array of signal values, with each row a replicate measurement An array of signal values, with each row a replicate measurement
and each column a time point. and each column a time point.
connected: boolean
If True, find the connected correlation function, which measures the
correlation once the population mean has been substracted.
normalised: boolean normalised: boolean
If True and connected is True, normalise each time point by the If True normalise each time point by the standard deviation across
standard deviation across the replicates. the replicates.
If True and connected is False, normalise each time by the root mean
square across replicates.
Returns Returns
------- -------
...@@ -40,21 +33,16 @@ def autocrosscorr( ...@@ -40,21 +33,16 @@ def autocrosscorr(
# number of replicates # number of replicates
nr = yA.shape[0] nr = yA.shape[0]
# number of time points # number of time points
n = yA.shape[1] nt = yA.shape[1]
# deviation from mean at each time point # deviation from the mean, where the mean is calculated over replicates
if connected: # at each time point, which allows for non-stationary behaviour
dyA = yA - np.nanmean(yA, axis=0).reshape((1, n)) dyA = yA - np.nanmean(yA, axis=0).reshape((1, nt))
else:
dyA = yA
# standard deviation over time for each replicate # standard deviation over time for each replicate
stdA = np.sqrt(np.nanmean(dyA**2, axis=1).reshape((nr, 1))) stdA = np.sqrt(np.nanmean(dyA**2, axis=1).reshape((nr, 1)))
if np.any(yB): if np.any(yB):
# cross correlation # cross correlation
if connected: dyB = yB - np.nanmean(yB, axis=1).reshape((1, nt))
dyB = yB - np.nanmean(yB, axis=0).reshape((1, n)) stdB = np.sqrt(np.nanmean(dyB**2, axis=1).reshape((nr, 1)))
else:
dyB = yB
stdB = np.sqrt(np.nanmean(dyB**2, axis=0).reshape((1, n)))
else: else:
# auto correlation # auto correlation
yB = yA yB = yA
...@@ -63,89 +51,13 @@ def autocrosscorr( ...@@ -63,89 +51,13 @@ def autocrosscorr(
# calculate correlation # calculate correlation
corr = np.nan * np.ones(yA.shape) corr = np.nan * np.ones(yA.shape)
# lag r runs over time points # lag r runs over time points
for r in np.arange(0, n): for r in np.arange(0, nt):
prods = [dyA[:, t] * dyB[:, t + r] for t in range(n - r)] prods = [dyA[:, t] * dyB[:, t + r] for t in range(nt - r)]
corr[:, r] = np.nansum(prods, axis=0) / (n - r) corr[:, r] = np.nansum(prods, axis=0) / (nt - r)
if normalised: if normalised:
if connected: # normalise by standard deviation
# normalise by standard deviation corr = corr / stdA / stdB
corr = corr / stdA / stdB
else:
# normalise by root mean square
corr = corr / np.sqrt(np.nanmean(yA**2, axis=1).reshape((nr, 1))
* np.nanmean(yB**2, axis=1).reshape((nr, 1)))
return corr return corr
### ###
def plot_replicatearray(
data,
t=None,
plotmean=True,
xlabel=None,
ylabel=None,
title=None,
grid=True,
):
"""
Plots summary statistics versus axis 1 for an array of replicates.
Parameters
----------
data: array
An array of signal values, with each row a replicate measurement
and each column a time point.
t : array (optional)
An array of time points.
plotmean: boolean
If True, plot the mean correlation over replicates versus the lag
with the standard error.
If False, plot the median correlation and the interquartile range.
xlabel: string
Label for x-axis.
ylabel: string
Label for y-axis.
title: string
Title for plot.
grid: boolean
If True, draw grid on plot.
"""
# number of time points
n = data.shape[1]
# number of replicates
nr = data.shape[0]
if not np.any(t):
t = np.arange(n)
plt.figure()
if plotmean:
# mean and standard error
plt.plot(t, np.nanmean(data, axis=0), "b-")
stderr = np.nanstd(data, axis=0) / np.sqrt(nr)
plt.fill_between(
t,
np.nanmean(data, axis=0) + stderr,
np.nanmean(data, axis=0) - stderr,
color="b",
alpha=0.2,
)
else:
# median and interquartile range
plt.plot(t, np.nanmedian(data, axis=0), "b-")
plt.fill_between(
t,
np.nanquantile(data, 0.25, axis=0),
np.nanquantile(data, 0.75, axis=0),
color="b",
alpha=0.2,
)
if xlabel:
plt.xlabel(xlabel)
if ylabel:
plt.ylabel(ylabel)
if title:
plt.title(title)
if grid:
plt.grid()
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment