From 8dcbf9059055a312ef2d279363a55e02e4cbef3c Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sat, 7 Sep 2024 15:51:58 +0000 Subject: [PATCH 1/5] support fallback all arguments --- .../executor/function_graph.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index 9403175d8cbbe..e29e767607140 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -21,7 +21,7 @@ import inspect from collections import namedtuple from copy import deepcopy -from functools import cached_property +from functools import cached_property, reduce from typing import Any, Callable, Tuple, Union from typing_extensions import TypeAlias, TypeGuard @@ -42,7 +42,6 @@ from ...symbolic.symbolic_context import SymbolicTraceContext from ...utils import ( ENV_SOT_ALLOW_DYNAMIC_SHAPE, - BreakGraphError, NameGenerator, SotUndefinedVar, inner_error_default_handler, @@ -640,13 +639,15 @@ def try_infer_meta_fn(args, kwargs) -> Any: except NotSupportedTensorArgumentError as e: bound_arguments = inspect.signature(func).bind(*args, **kwargs) bound_arguments.apply_defaults() - if e.name not in bound_arguments.arguments: - # TODO(zrr1999): fallback static shape for all symbolic variables - raise BreakGraphError( - f"Can't find {e.name} in bound arguments." + if e.name in bound_arguments.arguments: + original_var = bound_arguments.arguments[e.name] + flatten_vars = original_var.flatten_items() + else: + flatten_vars = reduce( + lambda x, y: x + y.flatten_items(), + bound_arguments.arguments.values(), + [], ) - original_var = bound_arguments.arguments[e.name] - flatten_vars = original_var.flatten_items() if not any( isinstance(arg, SymbolicVariable) for arg in flatten_vars From 6a7878b868312cbc4007ced1de70d7e228d3e8b3 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 9 Sep 2024 16:52:35 +0000 Subject: [PATCH 2/5] fix --- .../executor/function_graph.py | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index e29e767607140..651cc65640579 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -28,7 +28,7 @@ import paddle from paddle.jit.utils import OrderedSet -from paddle.utils import flatten +from paddle.utils import flatten, map_structure from .....utils.layers_utils import NotSupportedTensorArgumentError from ...infer_meta import ( @@ -642,6 +642,20 @@ def try_infer_meta_fn(args, kwargs) -> Any: if e.name in bound_arguments.arguments: original_var = bound_arguments.arguments[e.name] flatten_vars = original_var.flatten_items() + if not any( + isinstance(arg, SymbolicVariable) + for arg in flatten_vars + ): + raise e + + args, kwargs = map_if( + (args, kwargs), + pred=lambda x: x is original_var, + true_fn=lambda x: replace_symbolic_var_with_constant_var( + x + ), + false_fn=lambda x: x, + ) else: flatten_vars = reduce( lambda x, y: x + y.flatten_items(), @@ -649,17 +663,15 @@ def try_infer_meta_fn(args, kwargs) -> Any: [], ) - if not any( - isinstance(arg, SymbolicVariable) for arg in flatten_vars - ): - raise e + if not any( + isinstance(arg, SymbolicVariable) + for arg in flatten_vars + ): + raise e - args, kwargs = map_if( - (args, kwargs), - pred=lambda x: x is original_var, - true_fn=lambda x: replace_symbolic_var_with_constant_var(x), - false_fn=lambda x: x, - ) + args, kwargs = map_structure( + replace_symbolic_var_with_constant_var, (args, kwargs) + ) metas = convert_to_meta(args) kwmetas = convert_to_meta(kwargs) From c807af63819e7fa18687f2adf4bec9c39c7dc1c1 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 9 Sep 2024 17:00:04 +0000 Subject: [PATCH 3/5] add TODO --- .../paddle/jit/sot/opcode_translator/executor/function_graph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index 651cc65640579..5554b6e1a5f11 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -646,6 +646,7 @@ def try_infer_meta_fn(args, kwargs) -> Any: isinstance(arg, SymbolicVariable) for arg in flatten_vars ): + # TODO(zrr1999): maybe we can continue to fallback to all args are constant. raise e args, kwargs = map_if( From 7a0053bf2cbc7f050d3785e2eaef35f1ea4d096c Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Wed, 18 Sep 2024 10:45:51 +0000 Subject: [PATCH 4/5] =?UTF-8?q?=E7=A9=BA=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From ab3386ee7d2209e444c9969a48b23560618567c4 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 19 Sep 2024 03:30:35 +0000 Subject: [PATCH 5/5] =?UTF-8?q?=E7=A9=BA=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit