Skip to content

Commit

Permalink
Jit pre save hook (#38186)
Browse files Browse the repository at this point in the history
* Pre-save hooks of jit.save

1. Added pre_save_hooks features to jit.save.
2. Added related unittests

* Added jit pre_save_hooks functions's alias to paddle.jit and copyright.

* Make jit.save_pre_hook style be consisent with Paddle's rule.

* Fixed arguments passing bug in run_save_pre_hooks

* Added API Documents

* Move clear and run_pre_save_hooks as internal methonds only.

* Made register_save_pre_hook as an internal function.
  • Loading branch information
mingxu1067 authored Jan 11, 2022
1 parent d368647 commit e91f7c0
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 1 deletion.
101 changes: 101 additions & 0 deletions python/paddle/fluid/dygraph/jit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,6 +21,7 @@
import functools
from collections import OrderedDict
import inspect
import threading

import six
import paddle
Expand Down Expand Up @@ -525,6 +527,105 @@ def _build_load_path_and_config(path, config):
return model_path, config


_save_pre_hooks_lock = threading.Lock()
_save_pre_hooks = []


class HookRemoveHelper(object):
""" A HookRemoveHelper that can be used to remove hook. """

def __init__(self, hook):
self._hook = hook

def remove(self):
_remove_save_pre_hook(self._hook)


def _register_save_pre_hook(hook):
"""
Register a save pre-hook for `paddle.jit.save`.
This hook will be executed before `save` function has been invoked.
hook(layer, input_spec, configs) -> None
- layer (Layer|function): This argument is corresponding to `layer` in `paddle.jit.save`.
- input_spec (list or tuple[InputSpec|Tensor|Python built-in variable]): This argument is corresponding to `input_spec` in `paddle.jit.save`.
- configs (dict): This argument is corresponding to `configs` in `paddle.jit.save`.
Args:
hook(function): a function registered as a save pre-hook
Returns:
HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()`.
Examples:
.. code-block:: python
import numpy as np
import paddle
IMAGE_SIZE = 256
CLASS_NUM = 10
class LinearNet(paddle.nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = paddle.nn.Linear(IMAGE_SIZE, CLASS_NUM)
def forward(self, x):
return self._linear(x)
saving_count = 0
def save_pre_hook(layer, input_spec, configs):
global saving_count
saving_count += 1
remove_handler = paddle.jit.register_save_pre_hook(save_pre_hook)
layer = LinearNet()
paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])])
# saving_count == 1
remove_handler.remove()
paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])])
# saving_count == 1
"""
global _save_pre_hooks_lock
global _save_pre_hooks
_save_pre_hooks_lock.acquire()
if hook not in _save_pre_hooks:
_save_pre_hooks.append(hook)
_save_pre_hooks_lock.release()
return HookRemoveHelper(hook)


def _clear_save_pre_hooks():
global _save_pre_hooks_lock
global _save_pre_hooks
_save_pre_hooks_lock.acquire()
_save_pre_hooks.clear()
_save_pre_hooks_lock.release()


def _remove_save_pre_hook(hook):
global _save_pre_hooks_lock
global _save_pre_hooks
_save_pre_hooks_lock.acquire()
if hook in _save_pre_hooks:
_save_pre_hooks.remove(hook)
_save_pre_hooks_lock.release()


def _run_save_pre_hooks(func):
def wrapper(layer, path, input_spec=None, **configs):
global _save_pre_hooks
for hook in _save_pre_hooks:
hook(layer, input_spec, configs)
func(layer, path, input_spec, **configs)

return wrapper


@_run_save_pre_hooks
@switch_to_static_graph
def save(layer, path, input_spec=None, **configs):
"""
Expand Down
59 changes: 59 additions & 0 deletions python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import paddle
from paddle.fluid.dygraph.jit import _run_save_pre_hooks, _clear_save_pre_hooks, _register_save_pre_hook

_counter = 0


class TestPreSaveHooks(unittest.TestCase):
def test_pre_save_hook_functions(self):
def fake_func(*args, **kwgs):
global _counter
_counter += 1

remove_handler = _register_save_pre_hook(fake_func)
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 1)
self.assertTrue(
paddle.fluid.dygraph.jit._save_pre_hooks[0] is fake_func)

# Test of avoiding redundancy hanging
remove_handler = _register_save_pre_hook(fake_func)
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 1)
self.assertTrue(
paddle.fluid.dygraph.jit._save_pre_hooks[0] is fake_func)

remove_handler.remove()
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0)

remove_handler = _register_save_pre_hook(fake_func)
_clear_save_pre_hooks()
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0)

global _counter
_counter = 0
remove_handler = _register_save_pre_hook(fake_func)
func_with_hook = _run_save_pre_hooks(fake_func)
func_with_hook(None, None)
self.assertEqual(_counter, 2)


if __name__ == '__main__':
unittest.main()
3 changes: 2 additions & 1 deletion python/paddle/jit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down

0 comments on commit e91f7c0

Please sign in to comment.