Skip to content

Commit

Permalink
Refactored code, fix type validation and parameter grid issues
Browse files Browse the repository at this point in the history
Co-authored-by: Nicholas Tolley <55253912+ntolley@users.noreply.github.com>

Added test for callable params

Signed-off-by: samadpls <abdulsamadsid1@gmail.com>
  • Loading branch information
samadpls authored and ntolley committed Jul 21, 2024
1 parent 2a62446 commit a2e78d7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 16 deletions.
8 changes: 3 additions & 5 deletions examples/howto/plot_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def set_params(param_values, net=None):

# Add an evoked drive to the network.
net.add_evoked_drive('evprox',
mu=param_values['mu'],
sigma=param_values['sigma'],
mu=40,
sigma=5,
numspikes=1,
location='proximal',
weights_ampa=weights_ampa,
Expand All @@ -63,9 +63,7 @@ def set_params(param_values, net=None):

param_grid = {
'weight_basket': np.logspace(-4 - 1, 5),
'weight_pyr': np.logspace(-4, -1, 5),
'mu': np.linspace(20, 80, 5),
'sigma': np.linspace(1, 20, 5)
'weight_pyr': np.logspace(-4, -1, 5)
}

###############################################################################
Expand Down
17 changes: 9 additions & 8 deletions hnn_core/batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class BatchSimulate(object):
def __init__(self, set_params, net_name='jones', tstop=170,
dt=0.025, n_trials=1, record_vsec=False,
record_isec=False, postproc=False, save_outputs=False,
file_path='./sim_results', batch_size=100,
save_folder='./sim_results', batch_size=100,
overwrite=True, summary_func=None):
"""Initialize the BatchSimulate class.
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(self, set_params, net_name='jones', tstop=170,
Default: False.
save_outputs : bool, optional
Whether to save the simulation outputs to files. Default is False.
file_path : str, optional
save_folder : str, optional
The path to save the simulation outputs.
Default is './sim_results'.
batch_size : int, optional
Expand All @@ -73,7 +73,7 @@ def __init__(self, set_params, net_name='jones', tstop=170,
Notes
-----
When `save_output=True`, the saved files will appear as
`sim_run_{start_idx}-{end_idx}.npy` in the specified `file_path`
`sim_run_{start_idx}-{end_idx}.npy` in the specified `save_folder`
directory. The `start_idx` and `end_idx` indicate the range of
simulation indices contained in each file. Each file will contain
a maximum of `batch_size` simulations, split evenly among the
Expand All @@ -88,7 +88,7 @@ def __init__(self, set_params, net_name='jones', tstop=170,
_validate_type(n_trials, types='int', item_name='n_trials')
_check_option('record_vsec', record_vsec, ['all', 'soma', False])
_check_option('record_isec', record_isec, ['all', 'soma', False])
_validate_type(file_path, types='path-like', item_name='file_path')
_validate_type(save_folder, types='path-like', item_name='save_folder')
_validate_type(batch_size, types='int', item_name='batch_size')

if set_params is not None and not callable(set_params):
Expand All @@ -106,7 +106,7 @@ def __init__(self, set_params, net_name='jones', tstop=170,
self.record_isec = record_isec
self.postproc = postproc
self.save_outputs = save_outputs
self.file_path = file_path
self.save_folder = save_folder
self.batch_size = batch_size
self.overwrite = overwrite
self.summary_func = summary_func
Expand Down Expand Up @@ -147,6 +147,7 @@ def run(self, param_grid, return_output=True,
_check_option('backend', backend, ['loky', 'threading',
'multiprocessing', 'dask'])
_validate_type(verbose, types='int', item_name='verbose')
_validate_type(clear_cache, types=(bool,), item_name='clear_cache')

param_combinations = self._generate_param_combinations(
param_grid, combinations)
Expand Down Expand Up @@ -303,12 +304,12 @@ def _save(self, results, start_idx, end_idx):
_validate_type(start_idx, types='int', item_name='start_idx')
_validate_type(end_idx, types='int', item_name='end_idx')

if not os.path.exists(self.file_path):
os.makedirs(self.file_path)
if not os.path.exists(self.save_folder):
os.makedirs(self.save_folder)

sim_data = np.stack([dpl['dpl'][0].data['agg'] for dpl in results])

file_name = os.path.join(self.file_path,
file_name = os.path.join(self.save_folder,
f'sim_run_{start_idx}-{end_idx}.npy')
if os.path.exists(file_name) and not self.overwrite:
raise FileExistsError(
Expand Down
21 changes: 18 additions & 3 deletions hnn_core/tests/test_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def set_params(param_values, net):

return BatchSimulate(set_params=set_params,
tstop=1.,
file_path=tmp_path,
save_folder=tmp_path,
batch_size=3)


Expand All @@ -45,8 +45,8 @@ def param_grid():
return {
'weight_basket': np.logspace(-4 - 1, 2),
'weight_pyr': np.logspace(-4, -1, 2),
'mu': np.linspace(20, 80, 2),
'sigma': np.linspace(1, 20, 2)
'mu': [40],
'sigma': [5]
}


Expand Down Expand Up @@ -149,6 +149,21 @@ def test_run(batch_simulate_instance, param_grid):
with pytest.raises(TypeError, match='verbose must be'):
batch_simulate_instance.run(param_grid, verbose='invalid')

with pytest.raises(TypeError, match='clear_cache must be'):
batch_simulate_instance.run(param_grid, clear_cache='invalid')

# Callables Test
batch_simulate_instance.summary_func = lambda x: x
assert callable(batch_simulate_instance.summary_func)
assert callable(batch_simulate_instance.set_params)

with pytest.raises(TypeError, match='summary_func must be'):
BatchSimulate(set_params=batch_simulate_instance.set_params,
summary_func='invalid')

with pytest.raises(TypeError, match='set_params must be'):
BatchSimulate(set_params='invalid')


def test_save_load_and_overwrite(batch_simulate_instance,
param_grid, tmp_path):
Expand Down

0 comments on commit a2e78d7

Please sign in to comment.