Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Flip specified drives in input histogram #905

Merged
8 changes: 5 additions & 3 deletions hnn_core/cell_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,8 @@ def plot_spikes_raster(self, trial_idx=None, ax=None, show=True):
cell_response=self, trial_idx=trial_idx, ax=ax, show=show)

def plot_spikes_hist(self, trial_idx=None, ax=None, spike_types=None,
color=None, show=True, **kwargs_hist):
invert_spike_types=None, color=None, show=True,
**kwargs_hist):
"""Plot the histogram of spiking activity across trials.

Parameters
Expand Down Expand Up @@ -346,8 +347,9 @@ def plot_spikes_hist(self, trial_idx=None, ax=None, spike_types=None,
The matplotlib figure handle.
"""
return plot_spikes_hist(self, trial_idx=trial_idx, ax=ax,
spike_types=spike_types, color=color,
show=show, **kwargs_hist)
spike_types=spike_types,
invert_spike_types=invert_spike_types,
color=color, show=show, **kwargs_hist)

def to_dict(self):
"""Return cell response as a dict object.
Expand Down
60 changes: 60 additions & 0 deletions hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,63 @@ def test_network_plotter_export(tmp_path, setup_net):
assert path_out.is_file()

plt.close('all')


def test_invert_spike_types(setup_net):
"""Test plotting a histogram with an inverted external drive"""
net = setup_net

weights_ampa = {'L2_pyramidal': 0.15, 'L5_pyramidal': 0.15}
syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}

net.add_evoked_drive(
'evdist1', mu=63.53, sigma=3.85, numspikes=1,
weights_ampa=weights_ampa, location='distal',
synaptic_delays=syn_delays, event_seed=274
)

net.add_evoked_drive(
'evprox1', mu=26.61, sigma=2.47, numspikes=1,
weights_ampa=weights_ampa, location='proximal',
synaptic_delays=syn_delays, event_seed=274
)

_ = simulate_dipole(net, dt=0.5, tstop=80., n_trials=1)

# test string input
net.cell_response.plot_spikes_hist(
spike_types=['evprox', 'evdist'],
invert_spike_types='evdist',
show=False,
)

# test case where all inputs are flipped
net.cell_response.plot_spikes_hist(
spike_types=['evprox', 'evdist'],
invert_spike_types=['evprox', 'evdist'],
show=False,
)

# test case where some inputs are flipped
fig = net.cell_response.plot_spikes_hist(
spike_types=['evprox', 'evdist'],
invert_spike_types=['evdist'],
show=False,
)

# check that there are 2 y axes
assert len(fig.axes) == 2

# check for equivalency of both y axes
y1 = fig.axes[0]
y2 = fig.axes[1]

y1_max = max(y1.get_ylim())
y2_max = max(y2.get_ylim())

assert y1_max == y2_max

# check that data are plotted
assert y1_max > 1

plt.close('all')
70 changes: 66 additions & 4 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,8 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None,


def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
color=None, show=True, **kwargs_hist):
color=None, invert_spike_types=None, show=True,
**kwargs_hist):
"""Plot the histogram of spiking activity across trials.

Parameters
Expand Down Expand Up @@ -369,6 +370,16 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
Valid strings also include leading characters of spike types

| Ex: ``'ev'`` is equivalent to ``['evdist', 'evprox']``
invert_spike_types: string | list | None
String input of a valid spike type to be mirrored about the y axis

| Ex: ``'evdist'``, ``'evprox'``, ...

List of valid spike types to be mirrored about the y axis

| Ex: ``['evdist', 'evprox']``

If None, all input spike types are plotted on the same y axis
color : str | list of str | dict | None
Input defining colors of plotted histograms. If str, all
histograms plotted with same color. If list of str provided,
Expand Down Expand Up @@ -490,19 +501,70 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,
spike_type_times[spike_label].extend(
spike_times[spike_types_mask[spike_type]])

if invert_spike_types is None:
invert_spike_types = list()
else:
if not isinstance(invert_spike_types, (str, list)):
raise TypeError(
"'invert_spike_types' must be a string or a list of strings")
if isinstance(invert_spike_types, str):
invert_spike_types = [invert_spike_types]

# Check that spike types to invert are correctly specified
unique_inputs = set(spike_labels.values())
unique_invert_inputs = set(invert_spike_types)
check_intersection = unique_invert_inputs.intersection(unique_inputs)
if not check_intersection == unique_invert_inputs:
raise ValueError(
"Elements of 'invert_spike_types' must"
"map to valid input types"
)

# Initialize secondary axis
ax1 = None

# Plot aggregated spike_times
for spike_label, plot_data in spike_type_times.items():
hist_color = spike_color[spike_label]
ax.hist(plot_data, bins,
label=spike_label, color=hist_color, **kwargs_hist)

# Plot on the primary y-axis
if spike_label not in invert_spike_types:
ax.hist(plot_data, bins,
label=spike_label, color=hist_color, **kwargs_hist)
# Plot on secondary y-axis
else:
if ax1 is None:
ax1 = ax.twinx()
ax1.hist(plot_data, bins,
label=spike_label, color=hist_color, **kwargs_hist)

# Set the y-limits based on the maximum across both axes
if ax1 is not None:
ax_ylim = ax.get_ylim()[1]
ax1_ylim = ax1.get_ylim()[1]

y_max = max(ax_ylim, ax1_ylim)
ax.set_ylim(0, y_max)
ax1.set_ylim(0, y_max)
ax1.invert_yaxis()

if len(cell_response.times) > 0:
ax.set_xlim(left=0, right=cell_response.times[-1])
else:
ax.set_xlim(left=0)

ax.set_ylabel("Counts")
ax.legend()

if ax1 is not None:
# Combine legends
handles, labels = ax.get_legend_handles_labels()
handles1, labels1 = ax1.get_legend_handles_labels()
handles.extend(handles1)
labels.extend(labels1)

ax.legend(handles, labels, loc='upper left')
else:
ax.legend()

plt_show(show)
return ax.get_figure()
Expand Down
Loading