Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT][dynamic shape] Fallback static shape all symvar #68113

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 33 additions & 19 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
import inspect
from collections import namedtuple
from copy import deepcopy
from functools import cached_property
from functools import cached_property, reduce
from typing import Any, Callable, Tuple, Union

from typing_extensions import TypeAlias, TypeGuard

import paddle
from paddle.jit.utils import OrderedSet
from paddle.utils import flatten
from paddle.utils import flatten, map_structure

from .....utils.layers_utils import NotSupportedTensorArgumentError
from ...infer_meta import (
Expand All @@ -42,7 +42,6 @@
from ...symbolic.symbolic_context import SymbolicTraceContext
from ...utils import (
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
BreakGraphError,
NameGenerator,
SotUndefinedVar,
inner_error_default_handler,
Expand Down Expand Up @@ -640,25 +639,40 @@ def try_infer_meta_fn(args, kwargs) -> Any:
except NotSupportedTensorArgumentError as e:
bound_arguments = inspect.signature(func).bind(*args, **kwargs)
bound_arguments.apply_defaults()
if e.name not in bound_arguments.arguments:
# TODO(zrr1999): fallback static shape for all symbolic variables
raise BreakGraphError(
f"Can't find {e.name} in bound arguments."
if e.name in bound_arguments.arguments:
original_var = bound_arguments.arguments[e.name]
flatten_vars = original_var.flatten_items()
if not any(
isinstance(arg, SymbolicVariable)
for arg in flatten_vars
):
# TODO(zrr1999): maybe we can continue to fallback to all args are constant.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 TODO 是不是可以删了?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块应该是指 e.name in bound_arguments.arguments 但是依然报错了还是可以继续fallback,比如name有人手抖写成另一个参数了

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块应该是指 e.name in bound_arguments.arguments 但是依然报错了还是可以继续fallback,比如name有人手抖写成另一个参数了。不过感觉加上会增加复杂性而且感觉好像也遇不到

raise e

args, kwargs = map_if(
(args, kwargs),
pred=lambda x: x is original_var,
true_fn=lambda x: replace_symbolic_var_with_constant_var(
x
),
false_fn=lambda x: x,
)
else:
flatten_vars = reduce(
lambda x, y: x + y.flatten_items(),
bound_arguments.arguments.values(),
[],
)
original_var = bound_arguments.arguments[e.name]
flatten_vars = original_var.flatten_items()

if not any(
isinstance(arg, SymbolicVariable) for arg in flatten_vars
):
raise e
if not any(
isinstance(arg, SymbolicVariable)
for arg in flatten_vars
):
raise e

args, kwargs = map_if(
(args, kwargs),
pred=lambda x: x is original_var,
true_fn=lambda x: replace_symbolic_var_with_constant_var(x),
false_fn=lambda x: x,
)
args, kwargs = map_structure(
replace_symbolic_var_with_constant_var, (args, kwargs)
)

metas = convert_to_meta(args)
kwmetas = convert_to_meta(kwargs)
Expand Down