-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jit pre save hook #38186
Jit pre save hook #38186
Changes from 5 commits
9330b1e
0107076
e6ee75e
538c956
3e7c315
fa820bc
9b9b1bb
4677f6e
3400716
dc6f624
52b53cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
|
@@ -20,6 +21,7 @@ | |
import functools | ||
from collections import OrderedDict | ||
import inspect | ||
import threading | ||
|
||
import six | ||
import paddle | ||
|
@@ -42,7 +44,8 @@ | |
|
||
__all__ = [ | ||
'TracedLayer', 'declarative', 'dygraph_to_static_func', 'set_code_level', | ||
'set_verbosity', 'save', 'load', 'not_to_static' | ||
'set_verbosity', 'save', 'load', 'not_to_static', 'register_save_pre_hook', | ||
'clear_save_pre_hooks' | ||
] | ||
|
||
|
||
|
@@ -525,6 +528,58 @@ def _build_load_path_and_config(path, config): | |
return model_path, config | ||
|
||
|
||
_save_pre_hooks_lock = threading.Lock() | ||
_save_pre_hooks = [] | ||
|
||
|
||
class HookRemoveHelper(object): | ||
""" A HookRemoveHelper that can be used to remove hook. """ | ||
|
||
def __init__(self, hook): | ||
self._hook = hook | ||
|
||
def remove(self): | ||
remove_save_pre_hook(self._hook) | ||
|
||
|
||
def register_save_pre_hook(hook): | ||
global _save_pre_hooks_lock | ||
global _save_pre_hooks | ||
_save_pre_hooks_lock.acquire() | ||
if hook not in _save_pre_hooks: | ||
_save_pre_hooks.append(hook) | ||
_save_pre_hooks_lock.release() | ||
return HookRemoveHelper(hook) | ||
|
||
|
||
def remove_save_pre_hook(hook): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, renamed |
||
global _save_pre_hooks_lock | ||
global _save_pre_hooks | ||
_save_pre_hooks_lock.acquire() | ||
if hook in _save_pre_hooks: | ||
_save_pre_hooks.remove(hook) | ||
_save_pre_hooks_lock.release() | ||
|
||
|
||
def clear_save_pre_hooks(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently not, but this is a convenient API to let user clear all hooks in one call. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the specific usage scenario is not clear, we recommended to use it as an internal method first, and then upgrade it to a public API if necessary in the future There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, currently rename |
||
global _save_pre_hooks_lock | ||
global _save_pre_hooks | ||
_save_pre_hooks_lock.acquire() | ||
_save_pre_hooks.clear() | ||
_save_pre_hooks_lock.release() | ||
|
||
|
||
def run_save_pre_hooks(func): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
def wrapper(layer, path, input_spec=None, **configs): | ||
global _save_pre_hooks | ||
for hook in _save_pre_hooks: | ||
hook(layer, input_spec, configs) | ||
func(layer, path, input_spec, **configs) | ||
|
||
return wrapper | ||
|
||
|
||
@run_save_pre_hooks | ||
@switch_to_static_graph | ||
def save(layer, path, input_spec=None, **configs): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import unittest | ||
|
||
import paddle | ||
from paddle.jit import register_save_pre_hook | ||
from paddle.jit import clear_save_pre_hooks | ||
from paddle.fluid.dygraph.jit import run_save_pre_hooks | ||
|
||
_counter = 0 | ||
|
||
|
||
class TestPreSaveHooks(unittest.TestCase): | ||
def test_pre_save_hook_functions(self): | ||
def fake_func(*args, **kwgs): | ||
global _counter | ||
_counter += 1 | ||
|
||
remove_handler = register_save_pre_hook(fake_func) | ||
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 1) | ||
self.assertTrue( | ||
paddle.fluid.dygraph.jit._save_pre_hooks[0] is fake_func) | ||
|
||
# Test of avoiding redundancy hanging | ||
remove_handler = register_save_pre_hook(fake_func) | ||
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 1) | ||
self.assertTrue( | ||
paddle.fluid.dygraph.jit._save_pre_hooks[0] is fake_func) | ||
|
||
remove_handler.remove() | ||
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0) | ||
|
||
remove_handler = register_save_pre_hook(fake_func) | ||
clear_save_pre_hooks() | ||
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0) | ||
|
||
global _counter | ||
_counter = 0 | ||
remove_handler = register_save_pre_hook(fake_func) | ||
func_with_hook = run_save_pre_hooks(fake_func) | ||
func_with_hook(None, None) | ||
self.assertEqual(_counter, 2) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should append api doc like other api
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, added api doc to public functions.