Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT][dynamic shape] add symbolic fallback when SymbolicVariable's handler not exists. #67786

Merged
merged 14 commits into from
Aug 30, 2024
7 changes: 3 additions & 4 deletions python/paddle/jit/sot/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,9 @@ def infer_meta(self, func, *args, **kwargs):
if isinstance(func, str):
# TODO(Aurelius84): Is length of args always greater than 0?
# Do we need add condition check here?
out = getattr(args[0], func)(*args[1:], **kwargs)
else:
out = func(*args, **kwargs)

func = getattr(args[0], func)
args = args[1:]
out = func(*args, **kwargs)
return convert_variable_to_meta_info(out)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __init__(self):
self.translate_count = 0
self.code_symbolic_inputs = {}

def get_symbolic_inputs(self, code: types.CodeType):
def get_symbolic_inputs(
self, code: types.CodeType
) -> dict[str, dict[int, int] | None]:
self.code_symbolic_inputs.setdefault(code, {})
return self.code_symbolic_inputs[code]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -640,13 +640,14 @@ def dispatch_reversed(var: ContainerVariable):
# bool
Dispatcher.register(
bool,
("ContainerVariable | SymbolicVariable",),
("ContainerVariable",),
lambda var: var.bool(),
)

Dispatcher.register(
operator.truth,
("ConstantVariable | SymbolicVariable",),
lambda var: var.bool(),
("ConstantVariable",),
lambda var: Dispatcher.call(bool, var),
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以将 VariableBase 的 operator.truth 都转发到 bool 上嘛?会有什么问题吗?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

@SigureMo SigureMo Aug 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是这样

Dispatcher.register(
    operator.truth,
    ("VariableBase",),
    lambda var: Dispatcher.call(bool, var)
)

重新走 bool 这个方法的 dispatch,而不是 VariableBase 的 bool

这样这个文件里 operator.truth 只需要出现一次了,其他的都可以删掉了


# str
Expand Down Expand Up @@ -936,7 +937,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
binary_fn,
),
)
# Tensor
# Tensor and Symbolic
fallback_tensor_unary_method = {
int,
bool,
Expand Down Expand Up @@ -1034,20 +1035,6 @@ def tensor_mod_dispatcher(
magic_method.name,
),
)
# Symbolic
for unary_fn in fallback_tensor_unary_method:
Dispatcher.register(
unary_fn,
("SymbolicVariable",),
partial(
lambda fn, var: VariableFactory.from_value(
fn(var.get_py_value()),
var.graph,
tracker=DummyTracker([var]),
),
unary_fn,
),
)

for binary_fn in BINARY_OPS:
for magic_method in magic_method_builtin_dispatch(binary_fn):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,9 @@ def analyse_dynamic_axes(self, tracker: Tracker):
1,
f"Start analyse dynamic axes for {tracker.trace_value_from_frame().inlined_expr} in {self.graph.pycode_gen._origin_code}\n",
)
for key in symbolic_inputs:
for key, symbolic_input in symbolic_inputs.items():
if key.startswith(tracker_expr):
log(1, f" {key}: {symbolic_inputs[key]}\n")
log(1, f" {key}: {symbolic_input}\n")
log(
1,
f" -> Tensor {tracker_expr} with dynamic axes {dynamic_axes}\n",
Expand Down Expand Up @@ -666,6 +666,7 @@ class SymbolicVariable(VariableBase):

var_name_generator = NameGenerator("symint_")
value: int | SymbolicValue
mutable_attrs = ["need_guard_value"]
zrr1999 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
Expand All @@ -685,6 +686,12 @@ def __init__(
[], paddle.int64, True, self.var_name, False, None, None
)
self.need_guard_value = False
self.graph.side_effects.record_mutable_variable(self)

def to_constant(self):
    """Materialize this symbolic value as a ``ConstantVariable``.

    The concrete value is obtained via ``get_py_value()`` (which also
    marks this variable as needing a value guard) and the new constant
    is tracked back to ``self`` through a ``DummyTracker``.
    """
    tracker = DummyTracker([self])
    concrete_value = self.get_py_value()
    return ConstantVariable(concrete_value, self.graph, tracker)

def get_py_value(self, allow_tensor: bool = False) -> bool | int | float:
self.need_guard_value = True
Expand Down Expand Up @@ -748,11 +755,6 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:

assert frame_value_tracer.inlined_expr in symbolic_inputs

# TODO(zrr1999): Once dynamic shape is used, there will be no new guards
if isinstance(self.value, int):
symbolic_input = symbolic_inputs[frame_value_tracer.inlined_expr]
symbolic_input.setdefault(self.value, 0)
symbolic_input[self.value] += 1
if self.need_guard_value:
return super().make_stringified_guard()
return [
Expand All @@ -765,12 +767,16 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:

@staticmethod
def should_create_symbolic_variable(
value: Any, tracker: Tracker, symbolic_inputs: dict[str, dict[int, int]]
value: Any,
tracker: Tracker,
symbolic_inputs: dict[str, dict[int, int] | None],
):
tracker_expr = tracker.trace_value_from_frame().inlined_expr
symbolic_inputs.setdefault(tracker_expr, {})
if tracker_expr in symbolic_inputs:
symbolic_input = symbolic_inputs[tracker_expr]
if symbolic_input is None:
return False
symbolic_input.setdefault(value, 0)
symbolic_input[value] += 1
if symbolic_input[value] >= STATIC_DIM_FREQ_THRESHOLD:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@
from __future__ import annotations

import inspect
import itertools
import operator
import types
from functools import reduce
from typing import TYPE_CHECKING, Any, Callable
from typing import (
TYPE_CHECKING,
Any,
Callable,
)

import paddle

from .... import psdb
from ....profiler import EventGuard
from ....utils import (
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
ENV_SOT_EXPORT,
get_static_function,
is_break_graph_api,
Expand All @@ -33,8 +39,13 @@
is_not_supported_paddle_layer,
is_paddle_api,
magic_method_builtin_dispatch,
map_if,
)
from ....utils.exceptions import (
BreakGraphError,
FallbackError,
SotErrorBase,
)
from ....utils.exceptions import BreakGraphError, FallbackError, SotErrorBase
from ..dispatcher import Dispatcher
from ..guard import (
StringifiedExpression,
Expand Down Expand Up @@ -626,11 +637,58 @@ def __init__(
self.value = fn

def call_function(self, /, *args, **kwargs):
from .basic import SymbolicVariable

# Lookup the handler from dispatcher
handler = Dispatcher.dispatch(self.value, *args, **kwargs)

if handler is not None:
return handler(*args, **kwargs)

if ENV_SOT_ALLOW_DYNAMIC_SHAPE.get() and any(
isinstance(var, SymbolicVariable)
for var in itertools.chain(args, kwargs.values())
):
fake_args, fake_kwargs = map_if(
(args, kwargs),
pred=lambda x: isinstance(x, SymbolicVariable),
# this is a fake args, we don't need to care about the value of the args
true_fn=lambda x: ConstantVariable.wrap_literal(
None, graph=self.graph
),
false_fn=lambda x: x,
)
handler = Dispatcher.dispatch(self.value, *fake_args, **fake_kwargs)
if handler is not None:
from ..executor_cache import (
OpcodeExecutorCache,
)

symbolic_inputs = OpcodeExecutorCache().get_symbolic_inputs(
self.graph.pycode_gen._origin_code
)

for var in itertools.chain(args, kwargs.values()):
if isinstance(var, SymbolicVariable):
if var.tracker.is_traceable():
tracker_expr = (
var.tracker.trace_value_from_frame().inlined_expr
)
symbolic_inputs[tracker_expr] = None
else:
for traceable_var in var.get_traceable_inputs():
tracker_expr = (
traceable_var.tracker.trace_value_from_frame().inlined_expr
)
symbolic_inputs[tracker_expr] = None
args, kwargs = map_if(
(args, kwargs),
pred=lambda x: isinstance(x, SymbolicVariable),
true_fn=lambda x: x.to_constant(),
false_fn=lambda x: x,
)
return handler(*args, **kwargs)

# Try to inline call the magic function
magic_methods = magic_method_builtin_dispatch(self.value)
for magic_method in magic_methods:
Expand Down
1 change: 1 addition & 0 deletions python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from .exceptions import ( # noqa: F401
BreakGraphError,
DynamicShapeFallbackError,
ExportError,
FallbackError,
InnerError,
Expand Down
7 changes: 7 additions & 0 deletions python/paddle/jit/sot/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import traceback

Expand Down Expand Up @@ -41,6 +42,10 @@ def __init__(self, msg, disable_eval_frame=False):
self.disable_eval_frame = disable_eval_frame


class DynamicShapeFallbackError(SotErrorBase):
    """Error raised by the dynamic-shape machinery to trigger a fallback.

    NOTE(review): no raise sites are visible in this chunk — presumably
    raised when a symbolic (dynamic-shape) value cannot be handled and
    execution must fall back to the static/constant path; confirm at the
    call sites.
    """

    pass


# raise in inline function call strategy.
class BreakGraphError(SotErrorBase):
pass
Expand All @@ -52,6 +57,8 @@ def inner_error_default_handler(func, message_fn):
def impl(*args, **kwargs):
try:
return func(*args, **kwargs)
except SotErrorBase as e:
raise e
except Exception as e:
message = message_fn(*args, **kwargs)
origin_exception_message = "\n".join(
Expand Down
5 changes: 2 additions & 3 deletions test/sot/test_trace_list_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,15 @@ def test_bar_static_shape(self):

@with_allow_dynamic_shape_guard(True)
def test_bar_dynamic_shape(self):
    """Cache behavior of ``bar`` under dynamic-shape mode.

    After the first trace and one break-graph re-trace, further calls
    (including a different-shape first tensor) should hit the cache.
    """
    scalar_inputs = [
        paddle.to_tensor(1),
        paddle.to_tensor(2),
        paddle.to_tensor(3),
    ]
    mixed_inputs = [
        paddle.to_tensor([2, 3]),
        paddle.to_tensor(4),
        paddle.to_tensor(5),
    ]
    with test_instruction_translator_cache_context() as cache:
        self.assert_results(bar, scalar_inputs, 1, 1)
        self.assertEqual(cache.translate_count, 1)
        # Cache hit, but break graph
        self.assert_results(bar, scalar_inputs, 2, 0)
        self.assertEqual(cache.translate_count, 2)
        # Cache hit
        self.assert_results(bar, mixed_inputs, 1, 1)
        self.assertEqual(cache.translate_count, 2)


if __name__ == "__main__":
Expand Down