diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index 3b40633a73e25..b190c56c053bd 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -47,7 +47,7 @@ calc_stack_effect, get_instructions, ) -from ..instruction_utils.opcode_info import JumpDirection, PopJumpCond +from ..instruction_utils.opcode_info import RETURN, JumpDirection, PopJumpCond from .dispatch_functions import ( operator_BAD, operator_exception_match, @@ -1644,8 +1644,10 @@ def FOR_ITER(self, instr): start = self.indexof(instr) end = self.indexof(instr.jump_to) for i in range(start, end): - if self._instructions[i].opname == "RETURN_VALUE": - raise FallbackError("Found RETURN_VALUE in for loop body.") + if self._instructions[i].opname in RETURN: + raise FallbackError( + f"Found {self._instructions[i].opname} in for loop body." + ) self._graph.add_global_guarded_variable(iterator) diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py index 2e8ded5d2ac5e..eb8cb1735bddf 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py @@ -19,7 +19,13 @@ from paddle.jit.utils import OrderedSet from .instruction_utils import Instruction -from .opcode_info import ALL_JUMP, HAS_FREE, HAS_LOCAL, UNCONDITIONAL_JUMP +from .opcode_info import ( + ALL_JUMP, + HAS_FREE, + HAS_LOCAL, + RETURN, + UNCONDITIONAL_JUMP, +) @dataclasses.dataclass @@ -122,7 +128,7 @@ def walk(state: State, start: int) -> OrderedSet[str]: else State(OrderedSet(), OrderedSet(), OrderedSet()) ) return jump_branch | not_jump_branch - elif instr.opname == "RETURN_VALUE": + elif instr.opname in RETURN: return state return state diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py index e9b4af9f03fb0..2dc69b7565672 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py @@ -28,6 +28,9 @@ UNCONDITIONAL_JUMP = {"JUMP_ABSOLUTE", "JUMP_FORWARD"} if sys.version_info >= (3, 11): UNCONDITIONAL_JUMP.add("JUMP_BACKWARD") +RETURN = {"RETURN_VALUE"} +if sys.version_info >= (3, 12): + RETURN.add("RETURN_CONST") class JumpDirection(Enum):