Skip to content
Snippets Groups Projects
Commit ed2a126e authored by Arin Wongprommoon's avatar Arin Wongprommoon
Browse files

[routines/heatamp] Incorporates unit scaling into x-axis ticks

Heatmap added x-ticks based on the time points, not absolute time.

When a DataFrame is passed into matplotlib.imshow, the Axes no longer
cares about the column names.  We want the plot to take into account the
sampling period AND the unit scaling if the user specifies one.

I copied over the method I implemented for boxplot.py -- it essentially
'tricks' matplotlib by redefining the labels rather than change the
time-axis values (see
https://stackoverflow.com/questions/10171618/changing-plot-scale-by-a-factor-in-matplotlib).
matplotlib's xscale does not support simple linear re-scaling; in any
case, having time-axis values does not make sense for imshow.

This commit addresses issue #20.
parent bf18b543
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib import cm, colors from matplotlib import cm, colors, ticker
from postprocessor.core.processes.standardscaler import standardscaler from postprocessor.core.processes.standardscaler import standardscaler
from postprocessor.routines.plottingabc import BasePlotter from postprocessor.routines.plottingabc import BasePlotter
...@@ -17,14 +17,14 @@ class _HeatmapPlotter(BasePlotter): ...@@ -17,14 +17,14 @@ class _HeatmapPlotter(BasePlotter):
trace_name, trace_name,
births_df, births_df,
cmap, cmap,
sampling_period, unit_scaling,
xtick_step, xtick_step,
scale, scale,
robust, robust,
xlabel, xlabel,
plot_title, plot_title,
): ):
super().__init__(trace_name, sampling_period, xlabel, plot_title) super().__init__(trace_name, unit_scaling, xlabel, plot_title)
# Define attributes from arguments # Define attributes from arguments
self.trace_df = trace_df self.trace_df = trace_df
self.births_df = births_df self.births_df = births_df
...@@ -59,12 +59,17 @@ class _HeatmapPlotter(BasePlotter): ...@@ -59,12 +59,17 @@ class _HeatmapPlotter(BasePlotter):
self.vmin = None self.vmin = None
self.vmax = None self.vmax = None
# Define horizontal axis ticks and labels
# hacky! -- redefine column names
trace_df.columns = trace_df.columns * self.unit_scaling
self.fmt = ticker.FuncFormatter(
lambda x, pos: "{0:g}".format(x * self.unit_scaling)
)
def plot(self, ax, cax): def plot(self, ax, cax):
"""Draw the heatmap on the provided Axes.""" """Draw the heatmap on the provided Axes."""
super().plot(ax) super().plot(ax)
# Horizontal axis labels as multiples of xtick_step ax.xaxis.set_major_formatter(self.fmt)
ax.xaxis.set_major_locator(plt.MultipleLocator(self.xtick_step))
# Draw trace heatmap # Draw trace heatmap
trace_heatmap = ax.imshow( trace_heatmap = ax.imshow(
self.trace_scaled, self.trace_scaled,
...@@ -73,7 +78,11 @@ class _HeatmapPlotter(BasePlotter): ...@@ -73,7 +78,11 @@ class _HeatmapPlotter(BasePlotter):
vmin=self.vmin, vmin=self.vmin,
vmax=self.vmax, vmax=self.vmax,
) )
# Horizontal axis labels as multiples of xtick_step, taking
# into account unit scaling
ax.xaxis.set_major_locator(
ticker.MultipleLocator(self.xtick_step / self.unit_scaling)
)
# Overlay births, if present # Overlay births, if present
if self.births_df is not None: if self.births_df is not None:
# Must be masked array for transparency # Must be masked array for transparency
...@@ -86,7 +95,6 @@ class _HeatmapPlotter(BasePlotter): ...@@ -86,7 +95,6 @@ class _HeatmapPlotter(BasePlotter):
births_heatmap_mask, births_heatmap_mask,
interpolation="none", interpolation="none",
) )
# Draw colour bar # Draw colour bar
colorbar = ax.figure.colorbar( colorbar = ax.figure.colorbar(
mappable=trace_heatmap, cax=cax, ax=ax, label=self.colorbarlabel mappable=trace_heatmap, cax=cax, ax=ax, label=self.colorbarlabel
...@@ -98,7 +106,7 @@ def heatmap( ...@@ -98,7 +106,7 @@ def heatmap(
trace_name, trace_name,
births_df=None, births_df=None,
cmap=cm.RdBu, cmap=cm.RdBu,
sampling_period=5, unit_scaling=1,
xtick_step=60, xtick_step=60,
scale=True, scale=True,
robust=True, robust=True,
...@@ -120,8 +128,8 @@ def heatmap( ...@@ -120,8 +128,8 @@ def heatmap(
0 or 1. 0 or 1.
cmap : matplotlib ColorMap cmap : matplotlib ColorMap
Colour map for heatmap. Colour map for heatmap.
sampling_period : int or float unit_scaling : int or float
Sampling period, in unit time. Unit scaling factor, e.g. 1/60 to convert minutes to hours.
xtick_step : int or float xtick_step : int or float
Interval length, in unit time, to draw x axis ticks. Interval length, in unit time, to draw x axis ticks.
scale : bool scale : bool
...@@ -154,7 +162,7 @@ def heatmap( ...@@ -154,7 +162,7 @@ def heatmap(
trace_name, trace_name,
births_df, births_df,
cmap, cmap,
sampling_period, unit_scaling,
xtick_step, xtick_step,
scale, scale,
robust, robust,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment