Skip to content
Snippets Groups Projects
Commit 8e16964d authored by Arin Wongprommoon's avatar Arin Wongprommoon
Browse files

[routines/heatmap] Refactor: change how x-ticks are defined

When the user defines xtick_step and sampling_period so that
sampling_period is not divisible by xtick_step, an error is shown.
Specifically:

    int(np.where(time_axis == label)[0].item()) for label in xticklabels
ValueError: can only convert an array of size 1 to a Python scalar

The routine assumed that all instances of label (line 66) can be found
in time_axis (defined in line 61).  This is not true when
sampling_period is not divisible by xtick_step -- as an example:

Suppose time_axis is: [0, 1, 2, 3, 4, 5, 6, 7]
And xtick_step is 0.6.
xtick_min is thus defined as 0 and xtick_max is thus defined as 7.2, as
expected.
xticklabels is then defined as [0, 0.6, 1.2, 1.8, ... 7.2]
In the list comprehension that defines self.xticks, label first acquires
the value of 0.  This results in no errors as time_axis contains a 0.
However, when label then acquires the value of 0.6, the ValueError is
returned because time_axis does not contain a 0.6.

Fortunately, Axes.xaxis.set_major_locator() does all this for me, so I
scrapped the original method of defining horizontal axis for this.  I
had originally written the lines to define the horizontal axis when I
didn't know that set_major_locator() existed, and tried to define the
x-ticks manually.

These changes should not affect the behaviour of heatmap apart from
cases that cause the error.

This commit fixes issue #19.
parent 77ab8e65
No related branches found
No related tags found
No related merge requests found
......@@ -34,7 +34,9 @@ class _HeatmapPlotter(BasePlotter):
self.robust = robust
# Define some labels
self.colorbarlabel = "Normalised " + self.trace_name + " fluorescence (AU)"
self.colorbarlabel = (
"Normalised " + self.trace_name + " fluorescence (AU)"
)
self.ylabel = "Cell"
# Scale
......@@ -57,21 +59,11 @@ class _HeatmapPlotter(BasePlotter):
self.vmin = None
self.vmax = None
# Define horizontal axis ticks and labels
time_axis = self.sampling_period * self.trace_df.columns.to_numpy()
xtick_min = self.xtick_step * np.ceil(np.min(time_axis) / self.xtick_step)
xtick_max = self.xtick_step * np.ceil(np.max(time_axis) / self.xtick_step)
xticklabels = np.arange(xtick_min, xtick_max, self.xtick_step)
self.xticks = [
int(np.where(time_axis == label)[0].item()) for label in xticklabels
]
self.xticklabels = list(map(str, xticklabels.tolist()))
def plot(self, ax, cax):
"""Draw the heatmap on the provided Axes."""
super().plot(ax)
ax.set_xticks(self.xticks)
ax.set_xticklabels(self.xticklabels)
# Horizontal axis labels as multiples of xtick_step
ax.xaxis.set_major_locator(plt.MultipleLocator(self.xtick_step))
# Draw trace heatmap
trace_heatmap = ax.imshow(
......@@ -86,7 +78,9 @@ class _HeatmapPlotter(BasePlotter):
if self.births_df is not None:
# Must be masked array for transparency
births_array = self.births_df.to_numpy()
births_heatmap_mask = np.ma.masked_where(births_array == 0, births_array)
births_heatmap_mask = np.ma.masked_where(
births_array == 0, births_array
)
# Overlay
births_heatmap = ax.imshow(
births_heatmap_mask,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment