diff --git a/.gitignore b/.gitignore index a5212d0bf..1d40ddd9e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ publishment.md .vscode +brainpy/base/tests/io_test_tmp* + development examples/simulation/data @@ -53,7 +55,6 @@ develop/benchmark/CUBA/annarchy* develop/benchmark/CUBA/brian2* - *~ \#*\# *.pyc diff --git a/brainpy/__init__.py b/brainpy/__init__.py index d9b214c0c..e2ac8336b 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.1.11" +__version__ = "2.1.12" try: diff --git a/brainpy/base/base.py b/brainpy/base/base.py index d4c8c9401..70996bf3d 100644 --- a/brainpy/base/base.py +++ b/brainpy/base/base.py @@ -208,7 +208,7 @@ def unique_name(self, name=None, type_=None): naming.check_name_uniqueness(name=name, obj=self) return name - def load_states(self, filename, verbose=False, check_missing=False): + def load_states(self, filename, verbose=False): """Load the model states. Parameters @@ -216,41 +216,42 @@ def load_states(self, filename, verbose=False, check_missing=False): filename : str The filename which stores the model states. verbose: bool - check_missing: bool + Whether report the load progress. """ if not os.path.exists(filename): raise errors.BrainPyError(f'Cannot find the file path: {filename}') elif filename.endswith('.hdf5') or filename.endswith('.h5'): - io.load_h5(filename, target=self, verbose=verbose, check=check_missing) + io.load_by_h5(filename, target=self, verbose=verbose) elif filename.endswith('.pkl'): - io.load_pkl(filename, target=self, verbose=verbose, check=check_missing) + io.load_by_pkl(filename, target=self, verbose=verbose) elif filename.endswith('.npz'): - io.load_npz(filename, target=self, verbose=verbose, check=check_missing) + io.load_by_npz(filename, target=self, verbose=verbose) elif filename.endswith('.mat'): - io.load_mat(filename, target=self, verbose=verbose, check=check_missing) + io.load_by_mat(filename, target=self, verbose=verbose) else: raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}') - def save_states(self, filename, all_vars=None, **setting): + def save_states(self, filename, variables=None, **setting): """Save the model states. Parameters ---------- filename : str The file name which to store the model states. - all_vars: optional, dict, TensorCollector + variables: optional, dict, TensorCollector + The variables to save. If not provided, all variables retrieved by ``~.vars()`` will be used. """ - if all_vars is None: - all_vars = self.vars(method='relative').unique() + if variables is None: + variables = self.vars(method='absolute', level=-1) if filename.endswith('.hdf5') or filename.endswith('.h5'): - io.save_h5(filename, all_vars=all_vars) - elif filename.endswith('.pkl'): - io.save_pkl(filename, all_vars=all_vars) + io.save_as_h5(filename, variables=variables) + elif filename.endswith('.pkl') or filename.endswith('.pickle'): + io.save_as_pkl(filename, variables=variables) elif filename.endswith('.npz'): - io.save_npz(filename, all_vars=all_vars, **setting) + io.save_as_npz(filename, variables=variables, **setting) elif filename.endswith('.mat'): - io.save_mat(filename, all_vars=all_vars) + io.save_as_mat(filename, variables=variables) else: raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}') diff --git a/brainpy/base/collector.py b/brainpy/base/collector.py index 1b0178bf9..f86ba372a 100644 --- a/brainpy/base/collector.py +++ b/brainpy/base/collector.py @@ -39,11 +39,35 @@ def update(self, other, **kwargs): self[key] = value def __add__(self, other): + """Merging two dicts. + + Parameters + ---------- + other: dict + The other dict instance. + + Returns + ------- + gather: Collector + The new collector. + """ gather = type(self)(self) gather.update(other) return gather def __sub__(self, other): + """Remove other item in the collector. + + Parameters + ---------- + other: dict + The items to remove. + + Returns + ------- + gather: Collector + The new collector. + """ if not isinstance(other, dict): raise ValueError(f'Only support dict, but we got {type(other)}.') gather = type(self)() diff --git a/brainpy/base/io.py b/brainpy/base/io.py index 7e1fcbe8a..97cf03f87 100644 --- a/brainpy/base/io.py +++ b/brainpy/base/io.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- +from typing import Dict, Type, Union, Tuple, List import logging -import os import pickle import numpy as np @@ -9,35 +9,47 @@ from brainpy import errors from brainpy.base.collector import TensorCollector -Base = math = None logger = logging.getLogger('brainpy.base.io') -try: - import h5py -except (ModuleNotFoundError, ImportError): - h5py = None - -try: - import scipy.io as sio -except (ModuleNotFoundError, ImportError): - sio = None - __all__ = [ 'SUPPORTED_FORMATS', - 'save_h5', - 'save_npz', - 'save_pkl', - 'save_mat', - 'load_h5', - 'load_npz', - 'load_pkl', - 'load_mat', + 'save_as_h5', + 'save_as_npz', + 'save_as_pkl', + 'save_as_mat', + 'load_by_h5', + 'load_by_npz', + 'load_by_pkl', + 'load_by_mat', ] SUPPORTED_FORMATS = ['.h5', '.hdf5', '.npz', '.pkl', '.mat'] -def _check(module, module_name, ext): +def check_dict_data( + a_dict: Dict, + key_type: Union[Type, Tuple[Type, ...]] = None, + val_type: Union[Type, Tuple[Type, ...]] = None, + name: str = None +): + """Check the dict data.""" + name = '' if (name is None) else f'"{name}"' + if not isinstance(a_dict, dict): + raise ValueError(f'{name} must be a dict, while we got {type(a_dict)}') + if key_type is not None: + for key, value in a_dict.items(): + if not isinstance(key, key_type): + raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), ' + f'while we got ({type(key)}, {type(value)})') + if val_type is not None: + for key, value in a_dict.items(): + if not isinstance(value, val_type): + raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), ' + f'while we got ({type(key)}, {type(value)})') + + +def _check_module(module, module_name, ext): + """Check whether the required module is installed.""" if module is None: raise errors.PackageMissingError( '"{package}" must be installed when you want to save/load data with {ext} ' @@ -52,104 +64,329 @@ def _check_missing(variables, filename): f'The missed variables are: {list(variables.keys())}.') -def save_h5(filename, all_vars): - _check(h5py, module_name='h5py', ext=os.path.splitext(filename)) - assert isinstance(all_vars, dict) - all_vars = TensorCollector(all_vars).unique() +def _check_target(target): + from .base import Base + if not isinstance(target, Base): + raise TypeError(f'"target" must be instance of "{Base.__name__}", but we got {type(target)}') + + +not_found_msg = ('"{key}" is stored in {filename}. But we does ' + 'not find it is defined as variable in {target}.') +id_dismatch_msg = ('{key1} and {key2} is the same data in {filename}. ' + 'But we found they are different in {target}.') + +DUPLICATE_KEY = 'duplicate_keys' +DUPLICATE_TARGET = 'duplicate_targets' + + +def _load( + target, + verbose: bool, + filename: str, + load_vars: dict, + duplicates: Tuple[List[str], List[str]], + remove_first_axis: bool = False +): + from brainpy import math as bm + + # get variables + _check_target(target) + variables = target.vars(method='absolute', level=-1) + all_names = list(variables.keys()) + + # read data from file + for key in load_vars.keys(): + if verbose: + print(f'Loading {key} ...') + if key not in variables: + raise KeyError(not_found_msg.format(key=key, target=target.name, filename=filename)) + if remove_first_axis: + value = load_vars[key][0] + else: + value = load_vars[key] + variables[key].value = bm.asarray(value) + all_names.remove(key) + + # check duplicate names + duplicate_keys = duplicates[0] + duplicate_targets = duplicates[1] + for key1, key2 in zip(duplicate_keys, duplicate_targets): + if key1 not in all_names: + raise KeyError(not_found_msg.format(key=key1, target=target.name, filename=filename)) + if id(variables[key1]) != id(variables[key2]): + raise ValueError(id_dismatch_msg.format(key1=key1, key2=target, filename=filename, target=target.name)) + all_names.remove(key1) + + # check missing names + if len(all_names): + logger.warning(f'There are variable states missed in {filename}. ' + f'The missed variables are: {all_names}.') + + +def _unique_and_duplicate(collector: dict): + gather = TensorCollector() + id2name = dict() + duplicates = ([], []) + for k, v in collector.items(): + id_ = id(v) + if id_ not in id2name: + gather[k] = v + id2name[id_] = k + else: + k2 = id2name[id_] + duplicates[0].append(k) + duplicates[1].append(k2) + duplicates = (duplicates[0], duplicates[1]) + return gather, duplicates + + +def save_as_h5(filename: str, variables: dict): + """Save variables into a HDF5 file. + + Parameters + ---------- + filename: str + The filename to save. + variables: dict + All variables to save. + """ + if not (filename.endswith('.hdf5') or filename.endswith('.h5')): + raise ValueError(f'Cannot save variables as a HDF5 file. We only support file with ' + f'postfix of ".hdf5" and ".h5". But we got {filename}') + + from brainpy import math as bm + import h5py + + # check variables + check_dict_data(variables, name='variables') + variables, duplicates = _unique_and_duplicate(variables) # save f = h5py.File(filename, "w") - for key, data in all_vars.items(): - f[key] = np.asarray(data.value) + for key, data in variables.items(): + f[key] = bm.as_numpy(data) + if len(duplicates[0]): + f.create_dataset(DUPLICATE_TARGET, data='+'.join(duplicates[1])) + f.create_dataset(DUPLICATE_KEY, data='+'.join(duplicates[0])) f.close() -def load_h5(filename, target, verbose=False, check=False): - global math, Base - if Base is None: from brainpy.base.base import Base - if math is None: from brainpy import math - assert isinstance(target, Base) - _check(h5py, module_name='h5py', ext=os.path.splitext(filename)) +def load_by_h5(filename: str, target, verbose: bool = False): + """Load variables in a HDF5 file. - all_vars = target.vars(method='absolute') - f = h5py.File(filename, "r") - for key in f.keys(): - if verbose: print(f'Loading {key} ...') - var = all_vars.pop(key) - var[:] = math.asarray(f[key][:]) - f.close() - if check: _check_missing(all_vars, filename=filename) + Parameters + ---------- + filename: str + The filename to load variables. + target: Base + The instance of :py:class:`~.brainpy.Base`. + verbose: bool + Whether report the load progress. + """ + if not (filename.endswith('.hdf5') or filename.endswith('.h5')): + raise ValueError(f'Cannot load variables from a HDF5 file. We only support file with ' + f'postfix of ".hdf5" and ".h5". But we got {filename}') + # read data + import h5py + load_vars = dict() + with h5py.File(filename, "r") as f: + for key in f.keys(): + if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue + load_vars[key] = np.asarray(f[key]) + if DUPLICATE_KEY in f: + duplicate_keys = np.asarray(f[DUPLICATE_KEY]).item().decode("utf-8").split('+') + duplicate_targets = np.asarray(f[DUPLICATE_TARGET]).item().decode("utf-8").split('+') + duplicates = (duplicate_keys, duplicate_targets) + else: + duplicates = ([], []) + + # assign values + _load(target, verbose, filename, load_vars, duplicates) + + +def save_as_npz(filename, variables, compressed=False): + """Save variables into a numpy file. + + Parameters + ---------- + filename: str + The filename to store. + variables: dict + Variables to save. + compressed: bool + Whether we use the compressed mode. + """ + if not filename.endswith('.npz'): + raise ValueError(f'Cannot save variables as a .npz file. We only support file with ' + f'postfix of ".npz". But we got {filename}') + + from brainpy import math as bm + check_dict_data(variables, name='variables') + variables, duplicates = _unique_and_duplicate(variables) -def save_npz(filename, all_vars, compressed=False): - assert isinstance(all_vars, dict) - all_vars = TensorCollector(all_vars).unique() - all_vars = {k.replace('.', '--'): np.asarray(v.value) for k, v in all_vars.items()} + # save + variables = {k: bm.as_numpy(v) for k, v in variables.items()} + if len(duplicates[0]): + variables[DUPLICATE_KEY] = np.asarray(duplicates[0]) + variables[DUPLICATE_TARGET] = np.asarray(duplicates[1]) if compressed: - np.savez_compressed(filename, **all_vars) + np.savez_compressed(filename, **variables) else: - np.savez(filename, **all_vars) - - -def load_npz(filename, target, verbose=False, check=False): - global math, Base - if Base is None: from brainpy.base.base import Base - if math is None: from brainpy import math - assert isinstance(target, Base) - - all_vars = target.vars(method='absolute') + np.savez(filename, **variables) + + +def load_by_npz(filename, target, verbose=False): + """Load variables from a numpy file. + + Parameters + ---------- + filename: str + The filename to load variables. + target: Base + The instance of :py:class:`~.brainpy.Base`. + verbose: bool + Whether report the load progress. + """ + if not filename.endswith('.npz'): + raise ValueError(f'Cannot load variables from a .npz file. We only support file with ' + f'postfix of ".npz". But we got {filename}') + + # load data + load_vars = dict() all_data = np.load(filename) for key in all_data.files: - if verbose: print(f'Loading {key} ...') - var = all_vars.pop(key) - var[:] = math.asarray(all_data[key]) - if check: _check_missing(all_vars, filename=filename) - - -def save_pkl(filename, all_vars): - assert isinstance(all_vars, dict) - all_vars = TensorCollector(all_vars).unique() - targets = {k: np.asarray(v) for k, v in all_vars.items()} - f = open(filename, 'wb') - pickle.dump(targets, f, protocol=pickle.HIGHEST_PROTOCOL) - f.close() - - -def load_pkl(filename, target, verbose=False, check=False): - global math, Base - if Base is None: from brainpy.base.base import Base - if math is None: from brainpy import math - assert isinstance(target, Base) - f = open(filename, 'rb') - all_data = pickle.load(f) - f.close() - - all_vars = target.vars(method='absolute') - for key, data in all_data.items(): - if verbose: print(f'Loading {key} ...') - var = all_vars.pop(key) - var[:] = math.asarray(data) - if check: _check_missing(all_vars, filename=filename) - - -def save_mat(filename, all_vars): - assert isinstance(all_vars, dict) - all_vars = TensorCollector(all_vars).unique() - _check(sio, module_name='scipy', ext=os.path.splitext(filename)) - all_vars = {k.replace('.', '--'): np.asarray(v.value) for k, v in all_vars.items()} - sio.savemat(filename, all_vars) + if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue + load_vars[key] = all_data[key] + if DUPLICATE_KEY in all_data: + duplicate_keys = all_data[DUPLICATE_KEY].tolist() + duplicate_targets = all_data[DUPLICATE_TARGET].tolist() + duplicates = (duplicate_keys, duplicate_targets) + else: + duplicates = ([], []) + + # assign values + _load(target, verbose, filename, load_vars, duplicates) + + +def save_as_pkl(filename, variables): + """Save variables into a pickle file. + + Parameters + ---------- + filename: str + The filename to save. + variables: dict + All variables to save. + """ + if not (filename.endswith('.pkl') or filename.endswith('.pickle')): + raise ValueError(f'Cannot save variables into a pickle file. We only support file with ' + f'postfix of ".pkl" and ".pickle". But we got {filename}') + + check_dict_data(variables, name='variables') + variables, duplicates = _unique_and_duplicate(variables) + import brainpy.math as bm + targets = {k: bm.as_numpy(v) for k, v in variables.items()} + if len(duplicates[0]) > 0: + targets[DUPLICATE_KEY] = np.asarray(duplicates[0]) + targets[DUPLICATE_TARGET] = np.asarray(duplicates[1]) + with open(filename, 'wb') as f: + pickle.dump(targets, f, protocol=pickle.HIGHEST_PROTOCOL) + + +def load_by_pkl(filename, target, verbose=False): + """Load variables from a pickle file. + + Parameters + ---------- + filename: str + The filename to load variables. + target: Base + The instance of :py:class:`~.brainpy.Base`. + verbose: bool + Whether report the load progress. + """ + if not (filename.endswith('.pkl') or filename.endswith('.pickle')): + raise ValueError(f'Cannot load variables from a pickle file. We only support file with ' + f'postfix of ".pkl" and ".pickle". But we got {filename}') + + # load variables + load_vars = dict() + with open(filename, 'rb') as f: + all_data = pickle.load(f) + for key, data in all_data.items(): + if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue + load_vars[key] = data + if DUPLICATE_KEY in all_data: + duplicate_keys = all_data[DUPLICATE_KEY].tolist() + duplicate_targets = all_data[DUPLICATE_TARGET].tolist() + duplicates = (duplicate_keys, duplicate_targets) + else: + duplicates = ([], []) + + # assign data + _load(target, verbose, filename, load_vars, duplicates) + + +def save_as_mat(filename, variables): + """Save variables into a HDF5 file. + + Parameters + ---------- + filename: str + The filename to save. + variables: dict + All variables to save. + """ + if not filename.endswith('.mat'): + raise ValueError(f'Cannot save variables into a .mat file. We only support file with ' + f'postfix of ".mat". But we got {filename}') + + from brainpy import math as bm + import scipy.io as sio + check_dict_data(variables, name='variables') + variables, duplicates = _unique_and_duplicate(variables) + variables = {k: np.expand_dims(bm.as_numpy(v), axis=0) for k, v in variables.items()} + if len(duplicates[0]): + variables[DUPLICATE_KEY] = np.expand_dims(np.asarray(duplicates[0]), axis=0) + variables[DUPLICATE_TARGET] = np.expand_dims(np.asarray(duplicates[1]), axis=0) + sio.savemat(filename, variables) + + +def load_by_mat(filename, target, verbose=False): + """Load variables from a numpy file. + + Parameters + ---------- + filename: str + The filename to load variables. + target: Base + The instance of :py:class:`~.brainpy.Base`. + verbose: bool + Whether report the load progress. + """ + if not filename.endswith('.mat'): + raise ValueError(f'Cannot load variables from a .mat file. We only support file with ' + f'postfix of ".mat". But we got {filename}') -def load_mat(filename, target, verbose=False, check=False): - global math, Base - if Base is None: from brainpy.base.base import Base - if math is None: from brainpy import math - assert isinstance(target, Base) + import scipy.io as sio + # load data + load_vars = dict() all_data = sio.loadmat(filename) - all_vars = target.vars(method='absolute') for key, data in all_data.items(): - if verbose: print(f'Loading {key} ...') - var = all_vars.pop(key) - var[:] = math.asarray(data) - if check: _check_missing(all_vars, filename=filename) + if key.startswith('__'): + continue + if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: + continue + load_vars[key] = data[0] + if DUPLICATE_KEY in all_data: + duplicate_keys = [a.strip() for a in all_data[DUPLICATE_KEY].tolist()[0]] + duplicate_targets = [a.strip() for a in all_data[DUPLICATE_TARGET].tolist()[0]] + duplicates = (duplicate_keys, duplicate_targets) + else: + duplicates = ([], []) + + # assign values + _load(target, verbose, filename, load_vars, duplicates) diff --git a/brainpy/base/tests/test_io.py b/brainpy/base/tests/test_io.py new file mode 100644 index 000000000..666482b07 --- /dev/null +++ b/brainpy/base/tests/test_io.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- + + +import brainpy as bp +import brainpy.math as bm +import unittest + + +class TestIO1(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(TestIO1, self).__init__(*args, **kwargs) + + rng = bm.random.RandomState() + + class IO1(bp.dyn.DynamicalSystem): + def __init__(self): + super(IO1, self).__init__() + + self.a = bm.Variable(bm.zeros(1)) + self.b = bm.Variable(bm.ones(3)) + self.c = bm.Variable(bm.ones((3, 4))) + self.d = bm.Variable(bm.ones((2, 3, 4))) + + class IO2(bp.dyn.DynamicalSystem): + def __init__(self): + super(IO2, self).__init__() + + self.a = bm.Variable(rng.rand(3)) + self.b = bm.Variable(rng.randn(10)) + + io1 = IO1() + io2 = IO2() + io1.a2 = io2.a + io1.b2 = io2.b + io2.a2 = io1.a + io2.b2 = io2.b + + self.net = bp.dyn.Container(io1, io2) + + print(self.net.vars().keys()) + print(self.net.vars().unique().keys()) + + def test_h5(self): + bp.base.save_as_h5('io_test_tmp.h5', self.net.vars()) + bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True) + + bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars()) + bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True) + + def test_h5_postfix(self): + with self.assertRaises(ValueError): + bp.base.save_as_h5('io_test_tmp.h52', self.net.vars()) + with self.assertRaises(ValueError): + bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True) + + def test_npz(self): + bp.base.save_as_npz('io_test_tmp.npz', self.net.vars()) + bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True) + + bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True) + bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True) + + def test_npz_postfix(self): + with self.assertRaises(ValueError): + bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars()) + with self.assertRaises(ValueError): + bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True) + + def test_pkl(self): + bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars()) + bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True) + + bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars()) + bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True) + + def test_pkl_postfix(self): + with self.assertRaises(ValueError): + bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars()) + with self.assertRaises(ValueError): + bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True) + + def test_mat(self): + bp.base.save_as_mat('io_test_tmp.mat', self.net.vars()) + bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True) + + def test_mat_postfix(self): + with self.assertRaises(ValueError): + bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars()) + with self.assertRaises(ValueError): + bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True) + + +class TestIO2(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(TestIO2, self).__init__(*args, **kwargs) + + rng = bm.random.RandomState() + + class IO1(bp.dyn.DynamicalSystem): + def __init__(self): + super(IO1, self).__init__() + + self.a = bm.Variable(bm.zeros(1)) + self.b = bm.Variable(bm.ones(3)) + self.c = bm.Variable(bm.ones((3, 4))) + self.d = bm.Variable(bm.ones((2, 3, 4))) + + class IO2(bp.dyn.DynamicalSystem): + def __init__(self): + super(IO2, self).__init__() + + self.a = bm.Variable(rng.rand(3)) + self.b = bm.Variable(rng.randn(10)) + + io1 = IO1() + io2 = IO2() + + self.net = bp.dyn.Container(io1, io2) + + print(self.net.vars().keys()) + print(self.net.vars().unique().keys()) + + def test_h5(self): + bp.base.save_as_h5('io_test_tmp.h5', self.net.vars()) + bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True) + + bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars()) + bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True) + + def test_h5_postfix(self): + with self.assertRaises(ValueError): + bp.base.save_as_h5('io_test_tmp.h52', self.net.vars()) + with self.assertRaises(ValueError): + bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True) + + def test_npz(self): + bp.base.save_as_npz('io_test_tmp.npz', self.net.vars()) + bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True) + + bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True) + bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True) + + def test_npz_postfix(self): + with self.assertRaises(ValueError): + bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars()) + with self.assertRaises(ValueError): + bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True) + + def test_pkl(self): + bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars()) + bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True) + + bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars()) + bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True) + + def test_pkl_postfix(self): + with self.assertRaises(ValueError): + bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars()) + with self.assertRaises(ValueError): + bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True) + + def test_mat(self): + bp.base.save_as_mat('io_test_tmp.mat', self.net.vars()) + bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True) + + def test_mat_postfix(self): + with self.assertRaises(ValueError): + bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars()) + with self.assertRaises(ValueError): + bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True) diff --git a/brainpy/tools/checking.py b/brainpy/tools/checking.py index 39cd5eea1..0d47feb1a 100644 --- a/brainpy/tools/checking.py +++ b/brainpy/tools/checking.py @@ -155,12 +155,15 @@ def check_dict_data(a_dict: Dict, """Check the dictionary data. """ name = '' if (name is None) else f'"{name}"' - assert isinstance(a_dict, dict), f'{name} must be a dict, while we got {type(a_dict)}' + if not isinstance(a_dict, dict): + raise ValueError(f'{name} must be a dict, while we got {type(a_dict)}') for key, value in a_dict.items(): - assert isinstance(key, key_type), (f'{name} must be a dict of ({key_type}, {val_type}), ' - f'while we got ({type(key)}, {type(value)})') - assert isinstance(value, val_type), (f'{name} must be a dict of ({key_type}, {val_type}), ' - f'while we got ({type(key)}, {type(value)})') + if not isinstance(key, key_type): + raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), ' + f'while we got ({type(key)}, {type(value)})') + if not isinstance(value, val_type): + raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), ' + f'while we got ({type(key)}, {type(value)})') def check_initializer(initializer: Union[Callable, init.Initializer, Tensor],