Add paddle.set_grad_enabled (#31794)
willthefrog authored Apr 22, 2021
1 parent c332828 commit f8ca5a9
Showing 4 changed files with 64 additions and 1 deletion.
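
For orientation before the diffs: a minimal usage sketch of the API this commit adds, assuming a Paddle 2.x dygraph (imperative) session; the tensor names are illustrative only:

    import paddle

    x = paddle.ones([3, 2])
    x.stop_gradient = False  # opt the tensor into gradient tracking

    # Inside the block, new results are detached from the autograd graph.
    with paddle.set_grad_enabled(False):
        y = x * 2

    print(y.stop_gradient)  # True: y was produced with grad disabled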
1 change: 1 addition & 0 deletions python/paddle/__init__.py
@@ -257,6 +257,7 @@
 
 from .framework import grad #DEFINE_ALIAS
 from .framework import no_grad #DEFINE_ALIAS
+from .framework import set_grad_enabled #DEFINE_ALIAS
 from .framework import save #DEFINE_ALIAS
 from .framework import load #DEFINE_ALIAS
 from .framework import DataParallel #DEFINE_ALIAS
22 changes: 22 additions & 0 deletions python/paddle/fluid/tests/unittests/test_imperative_basic.py
@@ -296,6 +296,28 @@ def test_paddle_imperative_no_grad_guard(self):
         self.assertTrue(tmp._grad_ivar() is None)
         self.assertTrue(l0.weight._grad_ivar() is not None)
 
+    def test_paddle_imperative_set_grad_enabled(self):
+        data = np.array([[2, 3], [4, 5]]).astype('float32')
+        with fluid.dygraph.guard():
+            l0 = fluid.Linear(2, 2)
+            self.assertTrue(l0.weight._grad_ivar() is None)
+            l1 = fluid.Linear(2, 2)
+            with paddle.set_grad_enabled(False):
+                self.assertTrue(l1.weight.stop_gradient is False)
+                tmp = l1.weight * 2
+                with paddle.set_grad_enabled(True):
+                    tmp2 = l1.weight * 2
+            self.assertTrue(tmp.stop_gradient)
+            self.assertTrue(tmp2.stop_gradient is False)
+            x = fluid.dygraph.to_variable(data)
+            y = l0(x) + tmp2
+            o = l1(y)
+            o.backward()
+
+            self.assertTrue(tmp._grad_ivar() is None)
+            self.assertTrue(tmp2._grad_ivar() is not None)
+            self.assertTrue(l0.weight._grad_ivar() is not None)
+
     def test_sum_op(self):
         x = np.ones([2, 2], np.float32)
         with fluid.dygraph.guard():
6 changes: 5 additions & 1 deletion python/paddle/framework/__init__.py
@@ -18,12 +18,16 @@
     'NPUPlace', 'get_default_dtype', 'set_default_dtype'
 ]
 
-__all__ += ['grad', 'LayerList', 'load', 'save', 'no_grad', 'DataParallel']
+__all__ += [
+    'grad', 'set_grad_enabled', 'LayerList', 'load', 'save', 'no_grad',
+    'DataParallel'
+]
 
 from . import random
 from .random import seed
 from .framework import get_default_dtype
 from .framework import set_default_dtype
+from .framework import set_grad_enabled
 
 from ..fluid.param_attr import ParamAttr #DEFINE_ALIAS
 # from ..fluid.layers.tensor import create_global_var #DEFINE_ALIAS
36 changes: 36 additions & 0 deletions python/paddle/framework/framework.py
@@ -15,7 +15,9 @@
 # TODO: define framework api
 from paddle.fluid.layer_helper_base import LayerHelperBase
 from paddle.fluid.data_feeder import convert_dtype
+from paddle.fluid.framework import _dygraph_tracer
 import numpy as np
+from contextlib import contextmanager
 
 __all__ = ['set_default_dtype', 'get_default_dtype']
 
@@ -80,3 +82,37 @@ def get_default_dtype():
             paddle.get_default_dtype()
     """
     return LayerHelperBase.get_default_dtype()
+
+
+@contextmanager
+def set_grad_enabled(mode):
+    """
+    :api_attr: imperative
+
+    Create a context which enables or disables dygraph gradient calculation.
+
+    Args:
+        mode(bool): whether to enable (`True`), or disable (`False`) grad.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.ones([3, 2])
+            x.stop_gradient = False
+            with paddle.set_grad_enabled(False):
+                y = x * 2
+                with paddle.set_grad_enabled(True):
+                    z = x * 2
+            print(y.stop_gradient)  # True
+            print(z.stop_gradient)  # False
+    """
+
+    tracer = _dygraph_tracer()
+    if tracer:
+        prev_mode = tracer._has_grad
+        tracer._has_grad = mode
+        try:
+            yield
+        finally:
+            tracer._has_grad = prev_mode
+    else:
+        yield
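
Worth noting about the implementation above: the context manager snapshots the tracer's previous `_has_grad` value and restores it in a `finally` clause, so the global gradient mode cannot leak even when the guarded block raises. Also, because it is a generator-based context manager, calling `set_grad_enabled(False)` has no effect until it is entered via `with`. A minimal sketch of the restore-on-exception guarantee, again assuming a Paddle 2.x dygraph session with illustrative tensor names:

    import paddle

    x = paddle.ones([2, 2])
    x.stop_gradient = False

    try:
        with paddle.set_grad_enabled(False):
            raise RuntimeError("interrupted mid-block")
    except RuntimeError:
        pass

    # The finally clause restored the prior mode, so tracking resumes.
    z = x * 2
    print(z.stop_gradient)  # False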
