Skip to content

Commit

Permalink
review updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanhammonds committed Feb 6, 2021
1 parent 7e31814 commit 1586d3e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
24 changes: 9 additions & 15 deletions neurodsp/plts/time_series.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Plots for time series."""

from itertools import cycle
from itertools import repeat, cycle

import numpy as np
import numpy.ma as ma
Expand All @@ -19,9 +19,9 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs):
Parameters
----------
times : 1d array, list of 1d array, or 2d array
times : 1d or 2d array, or list of 1d array
Time definition(s) for the time series to be plotted.
sigs : 1d array, list of 1d array, or 2d array
sigs : 1d or 2d array, or list of 1d array
Time series to plot.
labels : list of str, optional
Labels for each time series.
Expand All @@ -47,24 +47,18 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs):

ax = check_ax(ax, (15, 3))

n_repeats = len(sigs) if isinstance(sigs, list) or sigs.ndim == 2 else 1

# Repeat times if needed
if isinstance(times, np.ndarray) and times.ndim != 2:
times = np.tile(times, (n_repeats, 1))

# Make sigs iterable if 1D
sigs = np.reshape(sigs, (1, -1)) if not isinstance(sigs, list) and sigs.ndim == 1 else sigs
times = repeat(times) if (isinstance(times, np.ndarray) and times.ndim == 1) else times
sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs

if labels is not None:
labels = [labels] if not isinstance(labels, list) else labels
else:
labels = np.repeat(labels, n_repeats)
labels = repeat(labels)

# If not provided, default colors for up to two signals to be black & red
if not colors and len(sigs) <= 2:
colors = ['k', 'r']
colors = np.repeat(colors, n_repeats) if not isinstance(colors, list) else cycle(colors)
colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)

for time, sig, color, label in zip(times, sigs, colors, labels):
ax.plot(time, sig, color=color, label=label)
Expand All @@ -80,9 +74,9 @@ def plot_instantaneous_measure(times, sigs, measure='phase', ax=None, **kwargs):
Parameters
----------
times : 1d array or list of 1d array
times : 1d or 2d array, or list of 1d array
Time definition(s) for the time series to be plotted.
sigs : 1d array or list of 1d array
sigs : 1d or 2d array, or list of 1d array
Time series to plot.
measure : {'phase', 'amplitude', 'frequency'}
Which kind of measure is being plotted.
Expand Down
6 changes: 6 additions & 0 deletions neurodsp/tests/plts/test_time_series.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test time series plots."""

from pytest import raises
import numpy as np

from neurodsp.tests.settings import TEST_PLOTS_PATH
from neurodsp.tests.tutils import plot_test
Expand All @@ -23,6 +24,11 @@ def test_plot_time_series(tsig):
colors=['k', 'r'], save_fig=True, file_name='test_plot_time_series.png',
file_path=TEST_PLOTS_PATH)

# Test 2D arrays
times_2d = np.vstack((times, times))
tsig_2d = np.vstack((tsig, tsig))
plot_time_series(times_2d, tsig_2d)

@plot_test
def test_plot_instantaneous_measure(tsig):

Expand Down

0 comments on commit 1586d3e

Please sign in to comment.