Skip to content

Commit

Permalink
[Dy2stat]Allow users to switch eval/train mode when using @to_static …
Browse files Browse the repository at this point in the history
…to decorate a function (#37383)

* Allow users to switch eval/train mode when using @to_static to decorate a function

* refine code for train() and eval()
  • Loading branch information
0x45f authored and 0x45f committed Nov 22, 2021
1 parent 9ffb43b commit e14788e
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,25 @@ def __init__(self, function, input_spec=None, **kwargs):
# Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
self._program_trans = ProgramTranslator()
self._kwargs = kwargs
self._training = True

def train(self):
if isinstance(self._class_instance,
layers.Layer) and self._class_instance.training == False:
raise RuntimeError(
"Failed to switch train mode. {} is a Layer's method, "
"please use Layer.train() to switch train mode.".format(
self.dygraph_function))
self._training = True

def eval(self):
if isinstance(self._class_instance,
layers.Layer) and self._class_instance.training == True:
raise RuntimeError(
"Failed to switch eval mode. {} is a Layer's method, "
"please use Layer.eval() to switch eval mode.".format(
self.dygraph_function))
self._training = False

def __get__(self, instance, owner):
"""
Expand Down Expand Up @@ -340,6 +359,8 @@ def __call__(self, *args, **kwargs):
# 3. synchronize self.training attribute.
if isinstance(self._class_instance, layers.Layer):
partial_program_layer.training = self._class_instance.training
else:
partial_program_layer.training = self._training

# 4. return outputs.
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,5 +297,52 @@ def test_raise_error(self):
self.program_translator.get_program(net.forward, self.x)


class SwitchModeNet(paddle.nn.Layer):
def __init__(self):
super(SwitchModeNet, self).__init__()

@paddle.jit.to_static
def forward(self, x):
return x + 1

@paddle.jit.to_static
def foo(self):
return True


@paddle.jit.to_static
def switch_mode_funciton():
return True


class TestFunctionTrainEvalMode(unittest.TestCase):
def test_switch_mode(self):
paddle.disable_static()
switch_mode_funciton.eval()
switch_mode_funciton()
self.assertEqual(switch_mode_funciton._training, False)
_, partial_layer = switch_mode_funciton.program_cache.last()[-1]
self.assertEqual(partial_layer.training, False)

switch_mode_funciton.train()
switch_mode_funciton()
self.assertEqual(switch_mode_funciton._training, True)
_, partial_layer = switch_mode_funciton.program_cache.last()[-1]
self.assertEqual(partial_layer.training, True)

def test_raise_error(self):
paddle.disable_static()
net = SwitchModeNet()

self.assertEqual(net.training, True)
with self.assertRaises(RuntimeError):
net.forward.eval()

net.eval()
self.assertEqual(net.training, False)
with self.assertRaises(RuntimeError):
net.foo.train()


if __name__ == '__main__':
unittest.main()

1 comment on commit e14788e

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.