Skip to content
7 changes: 3 additions & 4 deletions python/paddle/jit/sot/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,9 @@ def infer_meta(self, func, *args, **kwargs):
if isinstance(func, str):
# TODO(Aurelius84): Is length of args always greater than 0?
# Do we need add condition check here?
out = getattr(args[0], func)(*args[1:], **kwargs)
else:
out = func(*args, **kwargs)

func = getattr(args[0], func)
args = args[1:]
out = func(*args, **kwargs)
return convert_variable_to_meta_info(out)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __init__(self):
self.translate_count = 0
self.code_symbolic_inputs = {}

def get_symbolic_inputs(self, code: types.CodeType):
def get_symbolic_inputs(
self, code: types.CodeType
) -> dict[str, dict[int, int] | None]:
self.code_symbolic_inputs.setdefault(code, {})
return self.code_symbolic_inputs[code]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -640,13 +640,14 @@ def dispatch_reversed(var: ContainerVariable):
# bool
Dispatcher.register(
bool,
("ContainerVariable | SymbolicVariable",),
("ContainerVariable",),
lambda var: var.bool(),
)

Dispatcher.register(
operator.truth,
("ConstantVariable | SymbolicVariable",),
lambda var: var.bool(),
("ConstantVariable",),
lambda var: Dispatcher.call(bool, var),
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以将 VariableBaseoperator.truth 都转发到 bool 上嘛?会有什么问题吗?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

@SigureMo SigureMo Aug 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是这样

Dispatcher.register(
    operator.truth,
    ("VariableBase",),
    lambda var: Dispatcher.call(bool, var)
)

重新走 bool 这个方法的 dispatch,而不是 VariableBase 的 bool

这样这个文件里 operator.truth 只需要出现一次了,其他的都可以删掉了


# str
Expand Down Expand Up @@ -936,7 +937,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
binary_fn,
),
)
# Tensor
# Tensor and Symbolic
fallback_tensor_unary_method = {
int,
bool,
Expand Down Expand Up @@ -1034,20 +1035,6 @@ def tensor_mod_dispatcher(
magic_method.name,
),
)
# Symbolic
for unary_fn in fallback_tensor_unary_method:
Dispatcher.register(
unary_fn,
("SymbolicVariable",),
partial(
lambda fn, var: VariableFactory.from_value(
fn(var.get_py_value()),
var.graph,
tracker=DummyTracker([var]),
),
unary_fn,
),
)

for binary_fn in BINARY_OPS:
for magic_method in magic_method_builtin_dispatch(binary_fn):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,9 @@ def analyse_dynamic_axes(self, tracker: Tracker):
1,
f"Start analyse dynamic axes for {tracker.trace_value_from_frame().inlined_expr} in {self.graph.pycode_gen._origin_code}\n",
)
for key in symbolic_inputs:
for key, symbolic_input in symbolic_inputs.items():
if key.startswith(tracker_expr):
log(1, f" {key}: {symbolic_inputs[key]}\n")
log(1, f" {key}: {symbolic_input}\n")
log(
1,
f" -> Tensor {tracker_expr} with dynamic axes {dynamic_axes}\n",
Expand Down Expand Up @@ -666,6 +666,7 @@ class SymbolicVariable(VariableBase):

var_name_generator = NameGenerator("symint_")
value: int | SymbolicValue
mutable_attrs = ["need_guard_value"]

def __init__(
self,
Expand All @@ -685,6 +686,12 @@ def __init__(
[], paddle.int64, True, self.var_name, False, None, None
)
self.need_guard_value = False
self.graph.side_effects.record_mutable_variable(self)

def to_constant(self):
return ConstantVariable(
self.get_py_value(), self.graph, DummyTracker([self])
)

def get_py_value(self, allow_tensor: bool = False) -> bool | int | float:
self.need_guard_value = True
Expand Down Expand Up @@ -748,11 +755,6 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:

assert frame_value_tracer.inlined_expr in symbolic_inputs

# TODO(zrr1999): Once dynamic shape is used, there will be no new guards
if isinstance(self.value, int):
symbolic_input = symbolic_inputs[frame_value_tracer.inlined_expr]
symbolic_input.setdefault(self.value, 0)
symbolic_input[self.value] += 1
if self.need_guard_value:
return super().make_stringified_guard()
return [
Expand All @@ -765,12 +767,16 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:

@staticmethod
def should_create_symbolic_variable(
value: Any, tracker: Tracker, symbolic_inputs: dict[str, dict[int, int]]
value: Any,
tracker: Tracker,
symbolic_inputs: dict[str, dict[int, int] | None],
):
tracker_expr = tracker.trace_value_from_frame().inlined_expr
symbolic_inputs.setdefault(tracker_expr, {})
if tracker_expr in symbolic_inputs:
symbolic_input = symbolic_inputs[tracker_expr]
if symbolic_input is None:
return False
symbolic_input.setdefault(value, 0)
symbolic_input[value] += 1
if symbolic_input[value] >= STATIC_DIM_FREQ_THRESHOLD:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@
from __future__ import annotations

import inspect
import itertools
import operator
import types
from functools import reduce
from typing import TYPE_CHECKING, Any, Callable
from typing import (
TYPE_CHECKING,
Any,
Callable,
)

import paddle

from .... import psdb
from ....profiler import EventGuard
from ....utils import (
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
ENV_SOT_EXPORT,
get_static_function,
is_break_graph_api,
Expand All @@ -33,8 +39,13 @@
is_not_supported_paddle_layer,
is_paddle_api,
magic_method_builtin_dispatch,
map_if,
)
from ....utils.exceptions import (
BreakGraphError,
FallbackError,
SotErrorBase,
)
from ....utils.exceptions import BreakGraphError, FallbackError, SotErrorBase
from ..dispatcher import Dispatcher
from ..guard import (
StringifiedExpression,
Expand Down Expand Up @@ -626,11 +637,58 @@ def __init__(
self.value = fn

def call_function(self, /, *args, **kwargs):
from .basic import SymbolicVariable

# Lookup the handler from dispatcher
handler = Dispatcher.dispatch(self.value, *args, **kwargs)

if handler is not None:
return handler(*args, **kwargs)

if ENV_SOT_ALLOW_DYNAMIC_SHAPE.get() and any(
isinstance(var, SymbolicVariable)
for var in itertools.chain(args, kwargs.values())
):
fake_args, fake_kwargs = map_if(
(args, kwargs),
pred=lambda x: isinstance(x, SymbolicVariable),
# this is a fake args, we don't need to care about the value of the args
true_fn=lambda x: ConstantVariable.wrap_literal(
None, graph=self.graph
),
false_fn=lambda x: x,
)
handler = Dispatcher.dispatch(self.value, *fake_args, **fake_kwargs)
if handler is not None:
from ..executor_cache import (
OpcodeExecutorCache,
)

symbolic_inputs = OpcodeExecutorCache().get_symbolic_inputs(
self.graph.pycode_gen._origin_code
)

for var in itertools.chain(args, kwargs.values()):
if isinstance(var, SymbolicVariable):
if var.tracker.is_traceable():
tracker_expr = (
var.tracker.trace_value_from_frame().inlined_expr
)
symbolic_inputs[tracker_expr] = None
else:
for traceable_var in var.get_traceable_inputs():
tracker_expr = (
traceable_var.tracker.trace_value_from_frame().inlined_expr
)
symbolic_inputs[tracker_expr] = None
args, kwargs = map_if(
(args, kwargs),
pred=lambda x: isinstance(x, SymbolicVariable),
true_fn=lambda x: x.to_constant(),
false_fn=lambda x: x,
)
return handler(*args, **kwargs)

# Try to inline call the magic function
magic_methods = magic_method_builtin_dispatch(self.value)
for magic_method in magic_methods:
Expand Down
1 change: 1 addition & 0 deletions python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from .exceptions import ( # noqa: F401
BreakGraphError,
DynamicShapeFallbackError,
ExportError,
FallbackError,
InnerError,
Expand Down
7 changes: 7 additions & 0 deletions python/paddle/jit/sot/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import traceback

Expand Down Expand Up @@ -41,6 +42,10 @@ def __init__(self, msg, disable_eval_frame=False):
self.disable_eval_frame = disable_eval_frame


class DynamicShapeFallbackError(SotErrorBase):
pass


# raise in inline function call strategy.
class BreakGraphError(SotErrorBase):
pass
Expand All @@ -52,6 +57,8 @@ def inner_error_default_handler(func, message_fn):
def impl(*args, **kwargs):
try:
return func(*args, **kwargs)
except SotErrorBase as e:
raise e
except Exception as e:
message = message_fn(*args, **kwargs)
origin_exception_message = "\n".join(
Expand Down
5 changes: 2 additions & 3 deletions test/sot/test_trace_list_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,15 @@ def test_bar_static_shape(self):

@with_allow_dynamic_shape_guard(True)
def test_bar_dynamic_shape(self):
# TODO: Fix this after implement symbolic fallback mechanism
a = [paddle.to_tensor(1), paddle.to_tensor(2), paddle.to_tensor(3)]
b = [paddle.to_tensor([2, 3]), paddle.to_tensor(4), paddle.to_tensor(5)]
with test_instruction_translator_cache_context() as cache:
self.assert_results(bar, a, 1, 1)
self.assertEqual(cache.translate_count, 1)
self.assert_results(bar, a, 2, 0) # Cache hit, but break graph
self.assertEqual(cache.translate_count, 3)
self.assertEqual(cache.translate_count, 2)
self.assert_results(bar, b, 1, 1) # Cache hit
self.assertEqual(cache.translate_count, 3)
self.assertEqual(cache.translate_count, 2)


if __name__ == "__main__":
Expand Down