Skip to content

Commit

Permalink
Modify based on reviewer's comment, test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
zhhsplendid committed Sep 14, 2020
1 parent 896ef8a commit 702d8e8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
22 changes: 11 additions & 11 deletions python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def __init__(self, function, input_spec=None):
self._function_spec = FunctionSpec(function, input_spec)
self._program_cache = ProgramCache()
self._descriptor_cache = weakref.WeakKeyDictionary()
# Note: Hold a reference to ProgramTranslator for switching `enable_static`.
# Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
self._program_trans = ProgramTranslator()

def __get__(self, instance, owner):
Expand Down Expand Up @@ -300,11 +300,11 @@ def __call__(self, *args, **kwargs):
"""

# 1. call dygraph function directly if not enable `declarative`
if not self._program_trans.enable_static:
if not self._program_trans.enable_to_static:
logging_utils.warn(
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. If you would like to get static graph output, please call API "
"ProgramTranslator.enable_static(True)")
"ProgramTranslator.enable(True)")
return self._call_dygraph_function(*args, **kwargs)

if not in_dygraph_mode():
Expand Down Expand Up @@ -729,15 +729,15 @@ def __init__(self):
return
self._initialized = True
self._program_cache = ProgramCache()
self.enable_static = True
self.enable_to_static = True

def enable(self, enable_static):
def enable(self, enable_to_static):
"""
Enable or disable the converting from imperative to declarative by
ProgramTranslator globally.
Args:
enable_static (bool): True or False to enable or disable declarative.
enable_to_static (bool): True or False to enable or disable declarative.
Returns:
None.
Expand Down Expand Up @@ -766,9 +766,9 @@ def func(x):
print(func(x).numpy()) # [[2. 2.]]
"""
check_type(enable_static, "enable_static", bool,
check_type(enable_to_static, "enable_to_static", bool,
"ProgramTranslator.enable")
self.enable_static = enable_static
self.enable_to_static = enable_to_static

def get_output(self, dygraph_func, *args, **kwargs):
"""
Expand Down Expand Up @@ -809,7 +809,7 @@ def func(x):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
if not self.enable_static:
if not self.enable_to_static:
warnings.warn(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. "
Expand Down Expand Up @@ -876,7 +876,7 @@ def func(x):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
if not self.enable_static:
if not self.enable_to_static:
warnings.warn(
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will "
"just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output."
Expand Down Expand Up @@ -930,7 +930,7 @@ def func(x):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
if not self.enable_static:
if not self.enable_to_static:
warnings.warn(
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False."
"We will just return dygraph output. "
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/fluid/dygraph/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def func(x):
# TODO: remove this decorator after we finalize training API
def __impl__(*args, **kwargs):
program_translator = ProgramTranslator()
if in_dygraph_mode() or not program_translator.enable_static:
if in_dygraph_mode() or not program_translator.enable_to_static:
warnings.warn(
"The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode or set ProgramTranslator.enable to False. "
Expand Down Expand Up @@ -775,7 +775,7 @@ def get_inout_spec(all_vars, target_vars, return_name=False):

# 1. input check
prog_translator = ProgramTranslator()
if not prog_translator.enable_static:
if not prog_translator.enable_to_static:
raise RuntimeError(
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
)
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/hapi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,7 +1689,7 @@ def get_inout_spec(all_vars, return_name=False):

# 1. input check
prog_translator = ProgramTranslator()
if not prog_translator.enable_static:
if not prog_translator.enable_to_static:
raise RuntimeError(
"save_inference_model doesn't work when setting ProgramTranslator.enable to False."
)
Expand Down

0 comments on commit 702d8e8

Please sign in to comment.