diff --git a/neurodsp/plts/time_series.py b/neurodsp/plts/time_series.py index 088b104e..58f9d1ae 100644 --- a/neurodsp/plts/time_series.py +++ b/neurodsp/plts/time_series.py @@ -19,9 +19,9 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs): Parameters ---------- - times : 1d array or list of 1d array + times : 1d or 2d array, or list of 1d array Time definition(s) for the time series to be plotted. - sigs : 1d array or list of 1d array + sigs : 1d or 2d array, or list of 1d array Time series to plot. labels : list of str, optional Labels for each time series. @@ -47,8 +47,8 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs): ax = check_ax(ax, (15, 3)) - times = repeat(times) if isinstance(times, np.ndarray) else times - sigs = [sigs] if isinstance(sigs, np.ndarray) else sigs + times = repeat(times) if (isinstance(times, np.ndarray) and times.ndim == 1) else times + sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs if labels is not None: labels = [labels] if not isinstance(labels, list) else labels @@ -74,9 +74,9 @@ def plot_instantaneous_measure(times, sigs, measure='phase', ax=None, **kwargs): Parameters ---------- - times : 1d array or list of 1d array + times : 1d or 2d array, or list of 1d array Time definition(s) for the time series to be plotted. - sigs : 1d array or list of 1d array + sigs : 1d or 2d array, or list of 1d array Time series to plot. measure : {'phase', 'amplitude', 'frequency'} Which kind of measure is being plotted. diff --git a/neurodsp/tests/plts/test_time_series.py b/neurodsp/tests/plts/test_time_series.py index e1e13762..0e112454 100644 --- a/neurodsp/tests/plts/test_time_series.py +++ b/neurodsp/tests/plts/test_time_series.py @@ -1,6 +1,7 @@ """Test time series plots.""" from pytest import raises +import numpy as np from neurodsp.tests.settings import TEST_PLOTS_PATH from neurodsp.tests.tutils import plot_test @@ -23,6 +24,11 @@ def test_plot_time_series(tsig): colors=['k', 'r'], save_fig=True, file_name='test_plot_time_series.png', file_path=TEST_PLOTS_PATH) + # Test 2D arrays + times_2d = np.vstack((times, times)) + tsig_2d = np.vstack((tsig, tsig)) + plot_time_series(times_2d, tsig_2d) + @plot_test def test_plot_instantaneous_measure(tsig):