Skip to content

Commit

Permalink
Added support for external drives
Browse files Browse the repository at this point in the history
  • Loading branch information
raj1701 committed May 30, 2023
1 parent 490a187 commit 0f4513e
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 8 deletions.
100 changes: 92 additions & 8 deletions hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def __eq__(self, other):
return False

# Check gid_ranges
if not (self.gid_ranges.keys() == other.cell_types.keys()):
if not (self.gid_ranges.keys() == other.gid_ranges.keys()):
return False
for key in self.gid_ranges.keys():
if not (self.gid_ranges[key] == other.gid_ranges[key]):
Expand All @@ -443,6 +443,13 @@ def __eq__(self, other):
# Check cell_response
self.cell_response == other.cell_response

# Check external drives
if not (self.external_drives.keys() == other.external_drives.keys()):
return False
for key in self.external_drives.keys():
if not (self.external_drives[key] == other.external_drives[key]):
return False

return True

def set_cell_positions(self, *, inplane_distance=None,
Expand Down Expand Up @@ -610,6 +617,11 @@ def add_evoked_drive(self, name, *, mu, sigma, numspikes, location,
drive['conn_seed'] = conn_seed
drive['dynamics'] = dict(mu=mu, sigma=sigma, numspikes=numspikes)
drive['events'] = list()
# Need to save this information
drive['weights_ampa'] = weights_ampa
drive['weights_nmda'] = weights_nmda
drive['synaptic_delays'] = synaptic_delays
drive['probability'] = probability

self._attach_drive(name, drive, weights_ampa, weights_nmda, location,
space_constant, synaptic_delays,
Expand Down Expand Up @@ -719,6 +731,11 @@ def add_poisson_drive(self, name, *, tstart=0, tstop=None, rate_constant,
drive['dynamics'] = dict(tstart=tstart, tstop=tstop,
rate_constant=rate_constant)
drive['events'] = list()
# Need to save this information
drive['weights_ampa'] = weights_ampa
drive['weights_nmda'] = weights_nmda
drive['synaptic_delays'] = synaptic_delays
drive['probability'] = probability

self._attach_drive(name, drive, weights_ampa, weights_nmda, location,
space_constant, synaptic_delays,
Expand Down Expand Up @@ -825,6 +842,11 @@ def add_bursty_drive(self, name, *, tstart=0, tstart_std=0, tstop=None,
burst_rate=burst_rate, burst_std=burst_std,
numspikes=numspikes, spike_isi=spike_isi)
drive['events'] = list()
# Need to save this information
drive['weights_ampa'] = weights_ampa
drive['weights_nmda'] = weights_nmda
drive['synaptic_delays'] = synaptic_delays
drive['probability'] = probability

self._attach_drive(name, drive, weights_ampa, weights_nmda, location,
space_constant, synaptic_delays,
Expand Down Expand Up @@ -1371,7 +1393,6 @@ def plot_cells(self, ax=None, show=True):
def write(self, fname):
net_data = dict()
cell_types_data = dict()
# print(self.cell_types)
for key in self.cell_types:
cell_types_data[key] = _get_cell_as_dict(self.cell_types[key])
net_data['cell_types'] = cell_types_data
Expand All @@ -1391,6 +1412,11 @@ def write(self, fname):
net_data['cell_response'] = (_get_cell_response_as_dict
(self.cell_response))
# Write External drives
external_drives_data = dict()
for key in self.external_drives.keys():
external_drives_data[key] = (_get_external_drive_as_dict
(self.external_drives[key]))
net_data['external_drives'] = external_drives_data
# Write External biases
# Write connectivity
# Write rec arrays
Expand Down Expand Up @@ -1447,18 +1473,25 @@ def _get_cell_response_as_dict(cell_response):
return cell_response_data


def _get_external_drive_as_dict(drive):
drive_data = dict()
for key in drive.keys():
# Cannot store sets with hdf5
if isinstance(drive[key], set):
drive_data[key] = list(drive[key])
else:
drive_data[key] = drive[key]
return drive_data


def _read_cell_types(cell_types_data):
cell_types = dict()
# print(cell_types_data)
for cell_name in cell_types_data:
# print(cell_name)
cell_data = cell_types_data[cell_name]
sections = dict()
sections_data = cell_data['sections']
for section_name in sections_data:
section_data = sections_data[section_name]
# print(section_name)
# print(section_data)
sections[section_name] = Section(L=section_data['L'],
diam=section_data['diam'],
cm=section_data['cm'],
Expand All @@ -1480,7 +1513,7 @@ def _read_cell_types(cell_types_data):
cell_types[cell_name].vsec = cell_data['vsec']
cell_types[cell_name].isec = cell_data['isec']
cell_types[cell_name].tonic_biases = cell_data['tonic_biases']
# print(cell_types)

return cell_types


Expand All @@ -1494,9 +1527,57 @@ def _read_cell_response(cell_response_data):
return cell_response


def _read_external_drive(net, drive_data):
if drive_data['type'] == 'evoked':
# Skipped n_drive_cells here
net.add_evoked_drive(name=drive_data['name'],
mu=drive_data['dynamics']['mu'],
sigma=drive_data['dynamics']['sigma'],
numspikes=drive_data['dynamics']['numspikes'],
location=drive_data['location'],
cell_specific=drive_data['cell_specific'],
weights_ampa=drive_data['weights_ampa'],
weights_nmda=drive_data['weights_nmda'],
synaptic_delays=drive_data['synaptic_delays'],
event_seed=drive_data['event_seed'],
conn_seed=drive_data['conn_seed'])
elif drive_data['type'] == 'poisson':
net.add_poisson_drive(name=drive_data['name'],
tstart=drive_data['dynamics']['tstart'],
tstop=drive_data['dynamics']['tstop'],
rate_constant=(drive_data['dynamics']
['rate_constant']),
location=drive_data['location'],
n_drive_cells=drive_data['n_drive_cells'],
cell_specific=drive_data['cell_specific'],
weights_ampa=drive_data['weights_ampa'],
weights_nmda=drive_data['weights_nmda'],
synaptic_delays=drive_data['synaptic_delays'],
event_seed=drive_data['event_seed'],
conn_seed=drive_data['conn_seed'])
elif drive_data['type'] == 'bursty':
net.add_bursty_drive(name=drive_data['name'],
tstart=drive_data['dynamics']['tstart'],
tstart_std=drive_data['dynamics']['tstart_std'],
tstop=drive_data['dynamics']['tstop'],
burst_rate=drive_data['dynamics']['burst_rate'],
burst_std=drive_data['dynamics']['burst_std'],
num_spikes=drive_data['dynamics']['num_spikes'],
spike_isi=drive_data['dynamics']['spike_isi'],
location=drive_data['location'],
n_drive_cells=drive_data['n_drive_cells'],
cell_specific=drive_data['cell_specific'],
weights_ampa=drive_data['weights_ampa'],
weights_nmda=drive_data['weights_nmda'],
synaptic_delays=drive_data['synaptic_delays'],
event_seed=drive_data['event_seed'],
conn_seed=drive_data['conn_seed'])

net.external_drives[drive_data['name']]['events'] = drive_data['events']


def read_network(fname):
net_data = read_hdf5(fname)
# print(net_data)
params = dict()
params['N_pyr_x'] = 10
params['N_pyr_y'] = 10
Expand All @@ -1516,6 +1597,9 @@ def read_network(fname):
net.pos_dict = net_data['pos_dict']
# Set cell_response
net.cell_response = _read_cell_response(net_data['cell_response'])
# Set external drives
for key in net_data['external_drives'].keys():
_read_external_drive(net, net_data['external_drives'][key])
return net


Expand Down
1 change: 1 addition & 0 deletions hnn_core/tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

def test_network_io(tmpdir):
net_jones = jones_2009_model()
add_erp_drives_to_jones_model(net_jones)
# Writing network
net_jones.write(tmpdir.join('net_jones.hdf5'))
# Reading network
Expand Down

0 comments on commit 0f4513e

Please sign in to comment.