diff --git a/postprocessor/routines/boxplot.py b/postprocessor/routines/boxplot.py index 3a3523b55448e8868bb34def90d182f9f34347e7..ea5e9ba17bbb66bd2c2f403a3a5392a4c8813d6b 100644 --- a/postprocessor/routines/boxplot.py +++ b/postprocessor/routines/boxplot.py @@ -14,13 +14,13 @@ class _BoxplotPlotter(BasePlotter): self, trace_df, trace_name, - sampling_period, + unit_scaling, box_color, xtick_step, xlabel, plot_title, ): - super().__init__(trace_name, sampling_period, xlabel, plot_title) + super().__init__(trace_name, unit_scaling, xlabel, plot_title) # Define attributes from arguments self.trace_df = trace_df self.box_color = box_color @@ -31,9 +31,11 @@ class _BoxplotPlotter(BasePlotter): # Define horizontal axis ticks and labels # hacky! -- redefine column names - trace_df.columns = trace_df.columns * self.sampling_period + trace_df.columns = trace_df.columns * self.unit_scaling self.fmt = ticker.FuncFormatter( - lambda x, pos: "{0:g}".format(x / (self.xtick_step / self.sampling_period)) + lambda x, pos: "{0:g}".format( + x / (self.xtick_step / self.unit_scaling) + ) ) def plot(self, ax): @@ -47,24 +49,59 @@ class _BoxplotPlotter(BasePlotter): ax=ax, ) ax.xaxis.set_major_locator( - ticker.MultipleLocator(self.xtick_step / self.sampling_period) + ticker.MultipleLocator(self.xtick_step / self.unit_scaling) ) def boxplot( trace_df, trace_name, - sampling_period=5, + unit_scaling=1, box_color="b", xtick_step=60, xlabel="Time (min)", plot_title="", ax=None, ): + """Draw series of boxplots from an array of time series of traces + + Draw series of boxplots from an array of time series of traces, showing the + distribution of values at each time point over time. + + Parameters + ---------- + trace_df : pandas.DataFrame + Time series of traces (rows = cells, columns = time points). + trace_name : string + Name of trace being plotted, e.g. 'flavin'. + unit_scaling : int or float + Unit scaling factor, e.g. 1/60 to convert minutes to hours. + box_color : string + matplolib colour string, specifies colour of boxes in boxplot + xtick_step : int or float + Interval length, in unit time, to draw x axis ticks. + xlabel : string + x axis label. + plot_title : string + Plot title. + ax : matplotlib Axes + Axes in which to draw the plot, otherwise use the currently active Axes. + + Returns + ------- + ax : matplotlib Axes + Axes object with the heatmap. + + Examples + -------- + FIXME: Add docs. + + """ + plotter = _BoxplotPlotter( trace_df, trace_name, - sampling_period, + unit_scaling, box_color, xtick_step, xlabel, diff --git a/postprocessor/routines/heatmap.py b/postprocessor/routines/heatmap.py index bacf514998255bb46cba0e8bbeed90aca00e04aa..d2a8927c607150638ca2cfc182bce0a82feeb449 100644 --- a/postprocessor/routines/heatmap.py +++ b/postprocessor/routines/heatmap.py @@ -2,7 +2,7 @@ import numpy as np import matplotlib.pyplot as plt -from matplotlib import cm, colors +from matplotlib import cm, colors, ticker from postprocessor.core.processes.standardscaler import standardscaler from postprocessor.routines.plottingabc import BasePlotter @@ -17,14 +17,14 @@ class _HeatmapPlotter(BasePlotter): trace_name, births_df, cmap, - sampling_period, + unit_scaling, xtick_step, scale, robust, xlabel, plot_title, ): - super().__init__(trace_name, sampling_period, xlabel, plot_title) + super().__init__(trace_name, unit_scaling, xlabel, plot_title) # Define attributes from arguments self.trace_df = trace_df self.births_df = births_df @@ -59,12 +59,17 @@ class _HeatmapPlotter(BasePlotter): self.vmin = None self.vmax = None + # Define horizontal axis ticks and labels + # hacky! -- redefine column names + trace_df.columns = trace_df.columns * self.unit_scaling + self.fmt = ticker.FuncFormatter( + lambda x, pos: "{0:g}".format(x * self.unit_scaling) + ) + def plot(self, ax, cax): """Draw the heatmap on the provided Axes.""" super().plot(ax) - # Horizontal axis labels as multiples of xtick_step - ax.xaxis.set_major_locator(plt.MultipleLocator(self.xtick_step)) - + ax.xaxis.set_major_formatter(self.fmt) # Draw trace heatmap trace_heatmap = ax.imshow( self.trace_scaled, @@ -73,7 +78,11 @@ class _HeatmapPlotter(BasePlotter): vmin=self.vmin, vmax=self.vmax, ) - + # Horizontal axis labels as multiples of xtick_step, taking + # into account unit scaling + ax.xaxis.set_major_locator( + ticker.MultipleLocator(self.xtick_step / self.unit_scaling) + ) # Overlay births, if present if self.births_df is not None: # Must be masked array for transparency @@ -86,7 +95,6 @@ class _HeatmapPlotter(BasePlotter): births_heatmap_mask, interpolation="none", ) - # Draw colour bar colorbar = ax.figure.colorbar( mappable=trace_heatmap, cax=cax, ax=ax, label=self.colorbarlabel @@ -98,7 +106,7 @@ def heatmap( trace_name, births_df=None, cmap=cm.RdBu, - sampling_period=5, + unit_scaling=1, xtick_step=60, scale=True, robust=True, @@ -120,8 +128,8 @@ def heatmap( 0 or 1. cmap : matplotlib ColorMap Colour map for heatmap. - sampling_period : int or float - Sampling period, in unit time. + unit_scaling : int or float + Unit scaling factor, e.g. 1/60 to convert minutes to hours. xtick_step : int or float Interval length, in unit time, to draw x axis ticks. scale : bool @@ -154,7 +162,7 @@ def heatmap( trace_name, births_df, cmap, - sampling_period, + unit_scaling, xtick_step, scale, robust, diff --git a/postprocessor/routines/histogram.py b/postprocessor/routines/histogram.py index 4533d66c1bdcf4ba700dbd5d052abf10590703f2..6d3957463c446e1ebe37a4ff2df50d5ff99688d3 100644 --- a/postprocessor/routines/histogram.py +++ b/postprocessor/routines/histogram.py @@ -12,7 +12,6 @@ class _HistogramPlotter: values, label, color, - sampling_period, binsize, lognormal, lognormal_base, @@ -24,7 +23,6 @@ class _HistogramPlotter: self.values = values self.label = label self.color = color - self.sampling_period = sampling_period self.binsize = binsize self.lognormal = lognormal self.lognormal_base = lognormal_base @@ -39,7 +37,9 @@ class _HistogramPlotter: if self.lognormal: self.bins = np.logspace( 0, - np.ceil(np.log(np.nanmax(values)) / np.log(self.lognormal_base)), + np.ceil( + np.log(np.nanmax(values)) / np.log(self.lognormal_base) + ), base=self.lognormal_base, ) # number of bins will be 50 by default, as it's the default in np.logspace else: @@ -77,7 +77,6 @@ def histogram( values, label, color="b", - sampling_period=5, binsize=5, lognormal=False, lognormal_base=10, @@ -96,8 +95,6 @@ def histogram( Name of value being plotting, e.g. cell division cycle length. color : string Colour of bars. - sampling_period : float - Sampling period, in unit time. binsize : float Bin size. lognormal : bool @@ -127,7 +124,6 @@ def histogram( values, label, color, - sampling_period, binsize, lognormal, lognormal_base, diff --git a/postprocessor/routines/mean_plot.py b/postprocessor/routines/mean_plot.py index a3af66d6ac78729f62692f78c4051eb3e1699ee5..e25944830950affdb64019e554e14322999e4130 100644 --- a/postprocessor/routines/mean_plot.py +++ b/postprocessor/routines/mean_plot.py @@ -13,7 +13,7 @@ class _MeanPlotter(BasePlotter): self, trace_df, trace_name, - sampling_period, + unit_scaling, label, mean_color, error_color, @@ -22,7 +22,7 @@ class _MeanPlotter(BasePlotter): ylabel, plot_title, ): - super().__init__(trace_name, sampling_period, xlabel, plot_title) + super().__init__(trace_name, unit_scaling, xlabel, plot_title) # Define attributes from arguments self.trace_df = trace_df self.label = label @@ -35,7 +35,7 @@ class _MeanPlotter(BasePlotter): self.ylabel = ylabel # Mean and standard error - self.trace_time = np.array(self.trace_df.columns) * self.sampling_period + self.trace_time = np.array(self.trace_df.columns) * self.unit_scaling self.mean_ts = self.trace_df.mean(axis=0) self.stderr = self.trace_df.std(axis=0) / np.sqrt(len(self.trace_df)) @@ -64,7 +64,7 @@ class _MeanPlotter(BasePlotter): def mean_plot( trace_df, trace_name="flavin", - sampling_period=5, + unit_scaling=1, label="wild type", mean_color="b", error_color="lightblue", @@ -82,8 +82,8 @@ def mean_plot( Time series of traces (rows = cells, columns = time points). trace_name : string Name of trace being plotted, e.g. 'flavin'. - sampling_period : int or float - Sampling period, in unit time. + unit_scaling : int or float + Unit scaling factor, e.g. 1/60 to convert minutes to hours. label : string Name of group being plotted, e.g. a strain name. mean_color : string @@ -109,7 +109,7 @@ def mean_plot( plotter = _MeanPlotter( trace_df, trace_name, - sampling_period, + unit_scaling, label, mean_color, error_color, diff --git a/postprocessor/routines/median_plot.py b/postprocessor/routines/median_plot.py index 4b3205c576417f198d62508764dbb433af4919f6..573263b375c92528910a3f30f0bc0ec2f15a86ce 100644 --- a/postprocessor/routines/median_plot.py +++ b/postprocessor/routines/median_plot.py @@ -13,7 +13,7 @@ class _MedianPlotter(BasePlotter): self, trace_df, trace_name, - sampling_period, + unit_scaling, label, median_color, error_color, @@ -22,7 +22,7 @@ class _MedianPlotter(BasePlotter): ylabel, plot_title, ): - super().__init__(trace_name, sampling_period, xlabel, plot_title) + super().__init__(trace_name, unit_scaling, xlabel, plot_title) # Define attributes from arguments self.trace_df = trace_df self.label = label @@ -35,7 +35,7 @@ class _MedianPlotter(BasePlotter): self.ylabel = ylabel # Median and interquartile range - self.trace_time = np.array(self.trace_df.columns) * self.sampling_period + self.trace_time = np.array(self.trace_df.columns) * self.unit_scaling self.median_ts = self.trace_df.median(axis=0) self.quartile1_ts = self.trace_df.quantile(0.25) self.quartile3_ts = self.trace_df.quantile(0.75) @@ -65,7 +65,7 @@ class _MedianPlotter(BasePlotter): def median_plot( trace_df, trace_name="flavin", - sampling_period=5, + unit_scaling=1, label="wild type", median_color="b", error_color="lightblue", @@ -83,8 +83,8 @@ def median_plot( Time series of traces (rows = cells, columns = time points). trace_name : string Name of trace being plotted, e.g. 'flavin'. - sampling_period : int or float - Sampling period, in unit time. + unit_scaling : int or float + Unit scaling factor, e.g. 1/60 to convert minutes to hours. label : string Name of group being plotted, e.g. a strain name. median_color : string @@ -110,7 +110,7 @@ def median_plot( plotter = _MedianPlotter( trace_df, trace_name, - sampling_period, + unit_scaling, label, median_color, error_color, diff --git a/postprocessor/routines/plottingabc.py b/postprocessor/routines/plottingabc.py index 1990df9f8854b7f4880b473b610218ac45cbe046..97b89aa7e95e01e5b937003233f8ba3d37da7f52 100644 --- a/postprocessor/routines/plottingabc.py +++ b/postprocessor/routines/plottingabc.py @@ -6,10 +6,10 @@ from abc import ABC class BasePlotter(ABC): """Base class for plotting handler classes""" - def __init__(self, trace_name, sampling_period, xlabel, plot_title): + def __init__(self, trace_name, unit_scaling, xlabel, plot_title): """Common attributes""" self.trace_name = trace_name - self.sampling_period = sampling_period + self.unit_scaling = unit_scaling self.xlabel = xlabel self.ylabel = None diff --git a/postprocessor/routines/single_birth_plot.py b/postprocessor/routines/single_birth_plot.py index 6d1a405cf3223080db98d670d295962f56ec9ee3..671cfec04d87b5b33b46731de1c86b10b7e96a45 100644 --- a/postprocessor/routines/single_birth_plot.py +++ b/postprocessor/routines/single_birth_plot.py @@ -14,7 +14,7 @@ class _SingleBirthPlotter(_SinglePlotter): trace_values, trace_name, birth_mask, - sampling_period, + unit_scaling, trace_color, birth_color, trace_linestyle, @@ -27,7 +27,7 @@ class _SingleBirthPlotter(_SinglePlotter): trace_timepoints, trace_values, trace_name, - sampling_period, + unit_scaling, trace_color, trace_linestyle, xlabel, @@ -40,7 +40,7 @@ class _SingleBirthPlotter(_SinglePlotter): def plot(self, ax): """Draw the line plots on the provided Axes.""" - trace_time = self.trace_timepoints * self.sampling_period + trace_time = self.trace_timepoints * self.unit_scaling super().plot(ax) birth_mask_bool = self.birth_mask.astype(bool) for occurence, birth_time in enumerate(trace_time[birth_mask_bool]): @@ -62,7 +62,7 @@ def single_birth_plot( trace_values, trace_name="flavin", birth_mask=None, - sampling_period=5, + unit_scaling=1, trace_color="b", birth_color="k", trace_linestyle="-", @@ -84,8 +84,8 @@ def single_birth_plot( birth_mask : array_like Mask to indicate where births are. Expect values of '0' and '1' or 'False' and 'True' in the elements. - sampling_period : int or float - Sampling period, in unit time. + unit_scaling : int or float + Unit scaling factor, e.g. 1/60 to convert minutes to hours. trace_color : string matplotlib colour string for the trace birth_color : string @@ -116,7 +116,7 @@ def single_birth_plot( trace_values, trace_name, birth_mask, - sampling_period, + unit_scaling, trace_color, birth_color, trace_linestyle, diff --git a/postprocessor/routines/single_plot.py b/postprocessor/routines/single_plot.py index 11ee940609e4b6c524ad2f31da494137e628a238..68e7d76019c2bf2f3dd0dc3a6820a96597514494 100644 --- a/postprocessor/routines/single_plot.py +++ b/postprocessor/routines/single_plot.py @@ -13,13 +13,13 @@ class _SinglePlotter(BasePlotter): trace_timepoints, trace_values, trace_name, - sampling_period, + unit_scaling, trace_color, trace_linestyle, xlabel, plot_title, ): - super().__init__(trace_name, sampling_period, xlabel, plot_title) + super().__init__(trace_name, unit_scaling, xlabel, plot_title) # Define attributes from arguments self.trace_timepoints = trace_timepoints self.trace_values = trace_values @@ -33,7 +33,7 @@ class _SinglePlotter(BasePlotter): """Draw the line plot on the provided Axes.""" super().plot(ax) ax.plot( - self.trace_timepoints * self.sampling_period, + self.trace_timepoints * self.unit_scaling, self.trace_values, color=self.trace_color, linestyle=self.trace_linestyle, @@ -45,7 +45,7 @@ def single_plot( trace_timepoints, trace_values, trace_name="flavin", - sampling_period=5, + unit_scaling=1, trace_color="b", trace_linestyle="-", xlabel="Time (min)", @@ -62,8 +62,8 @@ def single_plot( Trace to plot. trace_name : string Name of trace being plotted, e.g. 'flavin'. - sampling_period : int or float - Sampling period, in unit time. + unit_scaling : int or float + Unit scaling factor, e.g. 1/60 to convert minutes to hours. trace_color : string matplotlib colour string, specifies colour of line plot. trace_linestyle : string @@ -89,7 +89,7 @@ def single_plot( trace_timepoints, trace_values, trace_name, - sampling_period, + unit_scaling, trace_color, trace_linestyle, xlabel,