Skip to content

Commit

Permalink
[cherry-pick][Dy2Stat]Support non-tensor type in input_spec (#33464) #…
Browse files Browse the repository at this point in the history
…34378

[Dy2Stat]Support non-tensor type in input_spec
  • Loading branch information
2742195759 authored Jul 26, 2021
1 parent dbc54d2 commit 9b48cfd
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 49 deletions.
13 changes: 3 additions & 10 deletions python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,8 @@ def _verify_input_spec(self, input_spec):
raise TypeError(
"The type(input_spec) should be one of (tuple, list), but received {}.".
format(type_name(input_spec)))
input_spec = tuple(input_spec)
for spec in flatten(input_spec):
if not isinstance(spec, paddle.static.InputSpec):
raise ValueError(
"The type(elem) from input_spec should be `InputSpec`, but received {}.".
format(type_name(spec)))

return input_spec
return tuple(input_spec)

def __repr__(self):
return "function: {}({}), input_spec: {}".format(
Expand Down Expand Up @@ -326,9 +320,8 @@ def check_type_and_len(input, spec, check_length=False):
elif isinstance(input_spec, paddle.static.InputSpec):
return input_spec
else:
raise TypeError(
"The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.".
type_name(input_spec))
# NOTE(Aurelius84): Support non-Tensor type as input spec info
return input_spec


def replace_spec_empty_name(args_name, input_with_spec):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import six
import textwrap
import threading
import warnings
import weakref

from paddle.fluid import framework
Expand Down Expand Up @@ -314,7 +313,7 @@ def __call__(self, *args, **kwargs):
# Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
# will show up **only once**. StaticFunction.__call__ will run many times, it is appropriate to
# display this warning message only once.
warnings.warn(
logging_utils.warn(
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. If you would like to get static graph output, please call API "
"ProgramTranslator.enable(True)")
Expand Down Expand Up @@ -481,6 +480,10 @@ def concrete_program_specify_input_spec(self, input_spec=None):
# NOTE(chenweihang): we should always translated program based on the `input_spec`
# decorated on forward if it is valid
desired_input_spec = self._function_spec.input_spec
if input_spec is not None:
logging_utils.warn(
"\n\nYou have specified `input_spec` both in function definition (higher priority) and `paddle.jit.save` (will be ignored.)\n\n\t Using: {}\n\n\t Ignore: {}\n".
format(desired_input_spec, input_spec))

has_input_spec = (desired_input_spec is not None)
if has_input_spec:
Expand Down Expand Up @@ -886,7 +889,7 @@ def func(x):
if not self.enable_to_static:
# Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
# will show up **only once**.
warnings.warn(
logging_utils.warn(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
Expand Down
86 changes: 63 additions & 23 deletions python/paddle/fluid/dygraph/dygraph_to_static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import textwrap
import numpy as np

import paddle
from paddle.fluid import unique_name
from paddle.fluid.data_feeder import convert_dtype

Expand Down Expand Up @@ -141,9 +142,9 @@ def make_hashable(x, error_msg=None):
"""
Makes input `x` hashable.
For some unhashable objects, such as `dict/list/np.ndarray`,applying hash function by using their values.
For some unhashable objects, such as `dict/list/set/np.ndarray`,applying hash function by using their values.
"""
if isinstance(x, (tuple, list)):
if isinstance(x, (tuple, list, set)):
return tuple(map(make_hashable, x))

try:
Expand Down Expand Up @@ -1428,10 +1429,10 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
Returns True if the two input specs are compatible, otherwise False.
args:
src_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of
paddle.static.InputSpec
desired_input_specs (list[InputSpec]|tuple(InputSpec)): list/tuple of
paddle.static.InputSpec
src_input_spec (list or tuple[InputSpec et.al]): list/tuple of
paddle.static.InputSpec or int/str et.al
desired_input_specs (list or tuple[InputSpec et.al]): list/tuple of
paddle.static.InputSpec or int/str et.al
"""
len_specs = len(src_input_specs)
if len_specs != len(desired_input_specs):
Expand All @@ -1440,30 +1441,69 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
for spec in src_input_specs:
if spec not in desired_input_specs:
return False

else:
for i in range(len_specs):
src_shape = src_input_specs[i].shape
other_shape = desired_input_specs[i].shape
len_shape = len(src_shape)
if len_shape != len(other_shape):
return False
for j in range(len_shape):
if src_shape[j] is None or src_shape[j] < 0:
continue
if other_shape[j] is None or other_shape[j] < 0:
continue
if src_shape[j] != other_shape[j]:
for (src_spec, desired_spec) in zip(src_input_specs,
desired_input_specs):
if isinstance(src_spec, paddle.static.InputSpec) or isinstance(
desired_spec, paddle.static.InputSpec):
if not _compatible_tensor_spec(src_spec, desired_spec):
return False
else:
if not _compatible_non_tensor_spec(src_spec, desired_spec):
return False

src_dtype = convert_dtype(src_input_specs[i].dtype)
other_dtype = convert_dtype(desired_input_specs[i].dtype)
if src_dtype != other_dtype:
return False
return True


def _compatible_tensor_spec(src_spec, desired_spec):
"""
Check whether two tensor type spec is compatible.
"""
for spec in [src_spec, desired_spec]:
if not isinstance(spec, paddle.static.InputSpec):
return False
src_shape = src_spec.shape
other_shape = desired_spec.shape
len_shape = len(src_shape)
if len_shape != len(other_shape):
return False
for j in range(len_shape):
if src_shape[j] is None or src_shape[j] < 0:
continue
if other_shape[j] is None or other_shape[j] < 0:
continue
if src_shape[j] != other_shape[j]:
return False

src_dtype = convert_dtype(src_spec.dtype)
other_dtype = convert_dtype(desired_spec.dtype)
if src_dtype != other_dtype:
return False

return True


def _compatible_non_tensor_spec(src_spec, desired_spec):
"""
Check whether two non-tensor type spec is compatible.
"""

def hash_value(spec):
try:
hash_val = make_hashable(spec)
except:
hash_val = None
return hash_val

src_hash_val = hash_value(src_spec)
desired_hash_val = hash_value(desired_spec)

if src_hash_val != desired_hash_val:
return False
else:
return True


def slice_is_num(slice_node):
# A slice_node.slice can be a:
# (1) ast.Index, which is a simple number such as [1], [-2]
Expand Down
23 changes: 15 additions & 8 deletions python/paddle/fluid/dygraph/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,15 @@ def _get_input_var_names(inputs, input_spec):
]
if input_spec is None:
# no prune
result_list = input_var_names
elif input_spec is not None and len(input_spec) == len(input_var_names):
return input_var_names
else:
# fileter out non-tensor type spec infos.
input_spec = [
spec for spec in input_spec
if isinstance(spec, paddle.static.InputSpec)
]

if len(input_spec) == len(input_var_names):
# no prune
result_list = input_var_names
# if input spec name not in input_var_names, only raise warning
Expand Down Expand Up @@ -530,8 +537,9 @@ def save(layer, path, input_spec=None, **configs):
Args:
layer (Layer|function): The Layer or function to be saved.
path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward
method, which can be described by InputSpec or example Tensor. If None, all input variables of
input_spec (list or tuple[InputSpec|Tensor|Python built-in variable], optional): Describes the input of the saved model's forward
method, which can be described by InputSpec or example Tensor. Moreover, we support to specify non-tensor type argument,
such as int, float, string, or list/dict of them.If None, all input variables of
the original Layer's forward method would be the inputs of the saved model. Default None.
**configs (dict, optional): Other save configuration options for compatibility. We do not
recommend using these configurations, they may be removed in the future. If not necessary,
Expand Down Expand Up @@ -698,9 +706,8 @@ def fun(inputs):
inner_input_spec.append(
paddle.static.InputSpec.from_tensor(var))
else:
raise TypeError(
"The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
% type(var))
# NOTE(Aurelius84): Support non-Tensor type in `input_spec`.
inner_input_spec.append(var)

# parse configs
configs = _parse_save_configs(configs)
Expand All @@ -719,7 +726,7 @@ def fun(inputs):
inner_input_spec)
elif 'forward' == attr_func:
# transform in jit.save, if input_spec is incomplete, declarative will throw error
# inner_input_spec is list[InputSpec], it should be packed with same sturcture
# inner_input_spec is list[InputSpec], it should be packed with same structure
# as original input_spec here.
if inner_input_spec:
inner_input_spec = pack_sequence_as(input_spec,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ def test_verify_input_spec(self):
with self.assertRaises(TypeError):
foo_spec = FunctionSpec(foo_func, input_spec=a_spec)

# each element of input_spec should be `InputSpec`
with self.assertRaises(ValueError):
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, 10])

foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
self.assertTrue(len(foo_spec.flat_input_spec) == 2)

Expand Down
Loading

0 comments on commit 9b48cfd

Please sign in to comment.