Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] avoid trace create layer tracker #61858

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,12 @@ def LOAD_ATTR(self, instr: Instruction):
)(obj, attr_name_var)
)

@call_break_graph_decorator(push_n=1)
def LOAD_SUPER_ATTR(self, instr: Instruction):
# This bytecode is for Python 3.12+, and it will break graph in Python 3.11-.
# We align it's behavior with Python 3.11-.
raise BreakGraphError("call super is not supported")

def LOAD_CONST(self, instr: Instruction):
var = self._co_consts[instr.arg]
self.stack.push(var)
Expand Down
32 changes: 32 additions & 0 deletions python/paddle/jit/sot/opcode_translator/executor/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import builtins
import dis
import sys
from itertools import chain
from typing import TYPE_CHECKING

from ...utils import InnerError, NameGenerator
Expand Down Expand Up @@ -432,5 +433,36 @@ def gen_instructions(self, codegen: PyCodeGen):
codegen.gen_build_map(len(self.kwargs))
codegen.gen_call_function_ex(has_kwargs=True)

def trace_value_from_frame(self):
class_tracer = self.layer_class.tracker.trace_value_from_frame()
arg_tracers = [
arg.tracker.trace_value_from_frame() for arg in self.args
]
kwarg_tracers_dict = {
k: v.tracker.trace_value_from_frame()
for k, v in self.kwargs.items()
}
kwarg_tracers = list(kwarg_tracers_dict.values())

expr = "{}("
expr += ", ".join(["{}"] * len(arg_tracers))
if len(arg_tracers) and len(kwarg_tracers) > 0:
expr += ", "
expr += ", ".join(f"{k}={{}}" for k in kwarg_tracers_dict.keys())
expr += ")"

return StringifyExpression(
expr,
[class_tracer] + arg_tracers + kwarg_tracers,
union_free_vars(
*(
tracer.free_vars
for tracer in chain(
[class_tracer], arg_tracers, kwarg_tracers
)
)
),
)

def __repr__(self) -> str:
return f"CreateLayerTracker(Layer={self.layer_class}, args={self.args}, kwargs={self.kwargs})"
19 changes: 19 additions & 0 deletions test/sot/test_simulate_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import paddle
from paddle import nn
from paddle.jit.sot import symbolic_translate
from paddle.jit.sot.utils import strict_mode_guard


class A:
Expand All @@ -43,6 +44,20 @@ def error_foo(x):
return t(x)


class NopLayer(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.weight = None


def created_layer_reconstruct():
x = paddle.to_tensor([1, 2], dtype="float32")
weight = NopLayer().weight
if weight is not None:
x += 1
return x


def bar(x):
a = A(x)
t = paddle.to_tensor(x)
Expand All @@ -66,6 +81,10 @@ def run():

self.assertRaises(paddle.jit.sot.utils.exceptions.InnerError, run)

@strict_mode_guard(False)
def test_created_layer_reconstruct(self):
self.assert_results(created_layer_reconstruct)


if __name__ == "__main__":
unittest.main()