# Waveform Database Software Package (WFDB) for Python 4.1.0
# (6,402 bytes)
import unittest
import matplotlib.pyplot as plt
import numpy as np
import wfdb
from wfdb.plot import plot
class TestPlotWfdb(unittest.TestCase):
    """
    Tests for the wfdb.plot_wfdb function
    """

    def assertAxesMatchSignal(self, axes, signal, t_divisor=1):
        """
        Verify that the limits of `axes` sensibly frame a plotted signal.

        The X limits must enclose the span from 0 to
        ``(len(signal) - 1) / t_divisor`` and the Y limits must enclose the
        NaN-ignoring minimum and maximum of `signal`.  On each axis, the
        margins on the two sides must also be approximately the same size.

        Parameters
        ----------
        axes : matplotlib.axes.Axes
            The Axes object whose limits are inspected.
        signal : numpy.ndarray
            A one-dimensional array of sample values.
        t_divisor : float, optional
            The intended plotting resolution (number of samples of `signal`
            per unit of the X axis.)

        """
        # --- Horizontal (time) axis ---
        xmin, xmax = axes.get_xlim()
        tmin, tmax = 0, (len(signal) - 1) / t_divisor
        x_msg = f"X range is [{xmin}, {xmax}]; expected [{tmin}, {tmax}]"

        # Every sample time has to be visible.
        self.assertLessEqual(xmin, tmin, msg=x_msg)
        self.assertGreaterEqual(xmax, tmax, msg=x_msg)

        # Left and right margins may differ by at most a tenth of the
        # signal duration plus one sample interval.
        left_offset = xmin - tmin
        right_offset = tmax - xmax
        x_tolerance = (tmax - tmin) / 10 + 1 / t_divisor
        self.assertAlmostEqual(
            left_offset, right_offset, delta=x_tolerance, msg=x_msg
        )

        # --- Vertical (amplitude) axis ---
        ymin, ymax = axes.get_ylim()
        vmin, vmax = np.nanmin(signal), np.nanmax(signal)
        y_msg = f"Y range is [{ymin}, {ymax}]; expected [{vmin}, {vmax}]"

        # Every (non-NaN) sample value has to be visible.
        self.assertLessEqual(ymin, vmin, msg=y_msg)
        self.assertGreaterEqual(ymax, vmax, msg=y_msg)

        # Bottom and top margins may differ by at most a tenth of the
        # signal's value range.
        bottom_offset = ymin - vmin
        top_offset = vmax - ymax
        y_tolerance = (vmax - vmin) / 10
        self.assertAlmostEqual(
            bottom_offset, top_offset, delta=y_tolerance, msg=y_msg
        )

    def test_physical_smooth(self):
        """
        Plot physical signals read with smooth_frames=True, together with
        annotations, using sample numbers as the time unit.
        """
        record_name = "sample-data/100"
        n_samples = 1000

        record = wfdb.rdrecord(
            record_name,
            sampto=n_samples,
            physical=True,
            smooth_frames=True,
        )
        self.assertIsNotNone(record.p_signal)

        annotation = wfdb.rdann(record_name, "atr", sampto=n_samples)

        fig = wfdb.plot_wfdb(
            record,
            annotation,
            time_units="samples",
            ecg_grids="all",
            return_fig=True,
        )
        # The Figure object (and its axes) is still inspected after closing.
        plt.close(fig)

        self.assertEqual(len(fig.axes), record.n_sig)
        for ch, ax in enumerate(fig.axes):
            channel_samples = record.p_signal[:, ch]
            self.assertAxesMatchSignal(ax, channel_samples)

    def test_digital_smooth(self):
        """
        Plot digital signals read with smooth_frames=True, using seconds as
        the time unit.
        """
        record = wfdb.rdrecord(
            "sample-data/drive02",
            sampto=1000,
            physical=False,
            smooth_frames=True,
        )
        self.assertIsNotNone(record.d_signal)

        fig = wfdb.plot_wfdb(record, time_units="seconds", return_fig=True)
        # The Figure object (and its axes) is still inspected after closing.
        plt.close(fig)

        self.assertEqual(len(fig.axes), record.n_sig)
        for ch, ax in enumerate(fig.axes):
            channel_samples = record.d_signal[:, ch]
            self.assertAxesMatchSignal(ax, channel_samples, record.fs)

    def test_physical_multifrequency(self):
        """
        Plot physical, multi-frequency signals (read with
        smooth_frames=False), using seconds as the time unit.
        """
        record = wfdb.rdrecord(
            "sample-data/wave_4",
            sampto=10,
            physical=True,
            smooth_frames=False,
        )
        self.assertIsNotNone(record.e_p_signal)

        fig = wfdb.plot_wfdb(record, time_units="seconds", return_fig=True)
        # The Figure object (and its axes) is still inspected after closing.
        plt.close(fig)

        self.assertEqual(len(fig.axes), record.n_sig)
        for ch, ax in enumerate(fig.axes):
            # Each channel has its own number of samples per frame.
            samples_per_second = record.fs * record.samps_per_frame[ch]
            self.assertAxesMatchSignal(
                ax, record.e_p_signal[ch], samples_per_second
            )

    def test_digital_multifrequency(self):
        """
        Plot digital, multi-frequency signals (read with
        smooth_frames=False), using seconds as the time unit.
        """
        record = wfdb.rdrecord(
            "sample-data/multi-segment/041s/041s",
            sampto=1000,
            physical=False,
            smooth_frames=False,
        )
        self.assertIsNotNone(record.e_d_signal)

        fig = wfdb.plot_wfdb(record, time_units="seconds", return_fig=True)
        # The Figure object (and its axes) is still inspected after closing.
        plt.close(fig)

        self.assertEqual(len(fig.axes), record.n_sig)
        for ch, ax in enumerate(fig.axes):
            # Each channel has its own number of samples per frame.
            samples_per_second = record.fs * record.samps_per_frame[ch]
            self.assertAxesMatchSignal(
                ax, record.e_d_signal[ch], samples_per_second
            )
class TestPlotInternal(unittest.TestCase):
    """
    Unit tests for internal wfdb.plot.plot functions
    """

    def _check_create_figure(self, n_subplots):
        """
        Call `plot._create_figure` and verify what it returns.

        Parameters
        ----------
        n_subplots : int
            The number of subplots to request; the returned `axes` must
            have exactly this length.

        """
        fig, axes = plot._create_figure(
            n_subplots, sharex=True, sharey=True, figsize=None
        )
        self.assertIsNotNone(fig)
        # Close the figure once the test finishes, even when a later
        # assertion fails, so figures do not pile up across tests.
        # NOTE(review): presumably the figure is created through pyplot
        # and is therefore kept alive until closed -- confirm.
        self.addCleanup(plt.close, fig)
        self.assertIsNotNone(axes)
        self.assertEqual(len(axes), n_subplots)

    def test_get_plot_dims(self):
        """
        Check the dimensions reported by `plot._get_plot_dims` for a
        physical signal array and a single list of annotation samples.
        """
        sampfrom = 0
        sampto = 3000
        record = wfdb.rdrecord(
            "sample-data/100", physical=True, sampfrom=sampfrom, sampto=sampto
        )
        ann = wfdb.rdann(
            "sample-data/100", "atr", sampfrom=sampfrom, sampto=sampto
        )
        sig_len, n_sig, n_annot, n_subplots = plot._get_plot_dims(
            signal=record.p_signal, ann_samp=[ann.sample]
        )

        # unittest assertions are used instead of bare `assert` so that the
        # checks survive `python -O` and report both values on failure.
        self.assertEqual(sig_len, sampto - sampfrom)
        self.assertEqual(n_sig, record.n_sig)
        self.assertEqual(n_annot, 1)
        self.assertEqual(n_subplots, record.n_sig)

    def test_create_figure_single_subplots(self):
        """
        Create a figure with a single subplot.
        """
        self._check_create_figure(1)

    def test_create_figure_multiple_subplots(self):
        """
        Create a figure with several subplots.
        """
        self._check_create_figure(5)
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()