Skip to content

Commit

Permalink
feat: fix io for brainpy.Base (#211)
Browse files Browse the repository at this point in the history
feat: fix `io` for brainpy.Base
  • Loading branch information
chaoming0625 authored May 16, 2022
2 parents c6cfe3f + 6630bd8 commit 3122d03
Show file tree
Hide file tree
Showing 7 changed files with 562 additions and 126 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ publishment.md
.vscode


brainpy/base/tests/io_test_tmp*

development

examples/simulation/data
Expand Down Expand Up @@ -53,7 +55,6 @@ develop/benchmark/CUBA/annarchy*
develop/benchmark/CUBA/brian2*



*~
\#*\#
*.pyc
Expand Down
2 changes: 1 addition & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.1.11"
__version__ = "2.1.12"


try:
Expand Down
31 changes: 16 additions & 15 deletions brainpy/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,49 +208,50 @@ def unique_name(self, name=None, type_=None):
naming.check_name_uniqueness(name=name, obj=self)
return name

def load_states(self, filename, verbose=False, check_missing=False):
def load_states(self, filename, verbose=False):
"""Load the model states.
Parameters
----------
filename : str
The filename which stores the model states.
verbose: bool
check_missing: bool
Whether report the load progress.
"""
if not os.path.exists(filename):
raise errors.BrainPyError(f'Cannot find the file path: {filename}')
elif filename.endswith('.hdf5') or filename.endswith('.h5'):
io.load_h5(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_h5(filename, target=self, verbose=verbose)
elif filename.endswith('.pkl'):
io.load_pkl(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_pkl(filename, target=self, verbose=verbose)
elif filename.endswith('.npz'):
io.load_npz(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_npz(filename, target=self, verbose=verbose)
elif filename.endswith('.mat'):
io.load_mat(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_mat(filename, target=self, verbose=verbose)
else:
raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')

def save_states(self, filename, all_vars=None, **setting):
def save_states(self, filename, variables=None, **setting):
"""Save the model states.
Parameters
----------
filename : str
The file name which to store the model states.
all_vars: optional, dict, TensorCollector
variables: optional, dict, TensorCollector
The variables to save. If not provided, all variables retrieved by ``~.vars()`` will be used.
"""
if all_vars is None:
all_vars = self.vars(method='relative').unique()
if variables is None:
variables = self.vars(method='absolute', level=-1)

if filename.endswith('.hdf5') or filename.endswith('.h5'):
io.save_h5(filename, all_vars=all_vars)
elif filename.endswith('.pkl'):
io.save_pkl(filename, all_vars=all_vars)
io.save_as_h5(filename, variables=variables)
elif filename.endswith('.pkl') or filename.endswith('.pickle'):
io.save_as_pkl(filename, variables=variables)
elif filename.endswith('.npz'):
io.save_npz(filename, all_vars=all_vars, **setting)
io.save_as_npz(filename, variables=variables, **setting)
elif filename.endswith('.mat'):
io.save_mat(filename, all_vars=all_vars)
io.save_as_mat(filename, variables=variables)
else:
raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')

Expand Down
24 changes: 24 additions & 0 deletions brainpy/base/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,35 @@ def update(self, other, **kwargs):
self[key] = value

def __add__(self, other):
"""Merging two dicts.
Parameters
----------
other: dict
The other dict instance.
Returns
-------
gather: Collector
The new collector.
"""
gather = type(self)(self)
gather.update(other)
return gather

def __sub__(self, other):
"""Remove other item in the collector.
Parameters
----------
other: dict
The items to remove.
Returns
-------
gather: Collector
The new collector.
"""
if not isinstance(other, dict):
raise ValueError(f'Only support dict, but we got {type(other)}.')
gather = type(self)()
Expand Down
Loading

0 comments on commit 3122d03

Please sign in to comment.