diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 193025b1864ab..1957de4e9e368 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1314,7 +1314,8 @@ def _state_dict_impl(self, destination=None, include_sublayers=True, structured_name_prefix="", - include_non_persistable_buffer=False): + include_non_persistable_buffer=False, + use_hook=True): """ Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict @@ -1322,6 +1323,7 @@ def _state_dict_impl(self, destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False + use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True """ if destination is None: @@ -1345,25 +1347,28 @@ def _state_dict_impl(self, layer_item._state_dict_impl( destination_temp, include_sublayers, structured_name_prefix + layer_name + ".", - include_non_persistable_buffer)) + include_non_persistable_buffer, use_hook)) destination = destination_temp - for state_dict_hook in self._state_dict_hooks.values(): - hook_result = state_dict_hook(destination) - if hook_result is not None: - destination = hook_result + if use_hook: + for state_dict_hook in self._state_dict_hooks.values(): + hook_result = state_dict_hook(destination) + if hook_result is not None: + destination = hook_result return destination def to_static_state_dict(self, destination=None, include_sublayers=True, - structured_name_prefix=""): + structured_name_prefix="", + use_hook=True): ''' Get all parameters and buffers of current layer and its sub-layers. And set them into a dict Parameters: destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True + use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True Retruns: dict: a dict contains all the parameters and persistable buffers. @@ -1383,18 +1388,21 @@ def to_static_state_dict(self, destination=destination, include_sublayers=include_sublayers, structured_name_prefix=structured_name_prefix, - include_non_persistable_buffer=True) + include_non_persistable_buffer=True, + use_hook=use_hook) def state_dict(self, destination=None, include_sublayers=True, - structured_name_prefix=""): + structured_name_prefix="", + use_hook=True): ''' Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict Parameters: destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True + use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True Retruns: dict: a dict contains all the parameters and persistable buffers. @@ -1414,7 +1422,8 @@ def state_dict(self, destination=destination, include_sublayers=include_sublayers, structured_name_prefix=structured_name_prefix, - include_non_persistable_buffer=False) + include_non_persistable_buffer=False, + use_hook=use_hook) @framework.deprecate_stat_dict def set_state_dict(self, state_dict, use_structured_name=True): @@ -1465,7 +1474,7 @@ def _check_match(key, param): return param, state matched_param_state = [] - for key, param in self.state_dict().items(): + for key, param in self.state_dict(use_hook=False).items(): key_name = key if use_structured_name else param.name try: match_res = _check_match(key_name, param) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 18620f55367f6..fcbd7b108e9ed 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -693,6 +693,37 @@ def test_skip_BatchNorm_Layer_norm(self): self.assertEqual((param.dtype == paddle.float32), True) +class TestStateDictHookForAMP(unittest.TestCase): + + def test_state_dict_hook(self): + + def func_isinstance(): + paddle.seed(100) + model = paddle.nn.Linear(2, 4) + model = paddle.amp.decorate(models=model, + level='O2', + save_dtype='float32') + param_value_ori = {} + for param in model.parameters(): + param_value_ori[param.name] = param.numpy() + + state_dict = model.state_dict() + for key, value in state_dict.items(): + state_dict[key] = value.cast("float16") + model.set_state_dict(state_dict) + + param_value_now = {} + for param in model.parameters(): + param_value_now[param.name] = param.numpy() + + for key in param_value_ori.keys(): + print(np.equal(param_value_ori[key], param_value_now[key])) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() + + class TestPureFp16SaveLoad(unittest.TestCase): def test_save_dtype_exception(self): def func():