Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT][3.11] Skip frame when code has listcomp or genexpr #59816

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import gc
import sys
import traceback
import types
from typing import List, Tuple
Expand Down Expand Up @@ -189,6 +190,13 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction:
Returns:
GuardedFunction | None: The translated code object and its guard function, or None if translation fails.
"""
if sys.version_info >= (3, 11):
for const in frame.f_code.co_consts:
if isinstance(const, types.CodeType) and const.co_name.startswith(
"<"
):
log(2, f"Found code object {const.co_name}, skip it\n")
return CustomCode(None, False), dummy_guard
simulator = OpcodeExecutor(frame, **kwargs)
try:
simulator.check_code_simulatable()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,11 @@ def code_exist(opname, argval, instrs):
modified = True
else:
idx += 1


def remove_duplicate_resume(instrs, code_options):
resumes = list(filter(lambda instr: instr.opname == "RESUME", instrs))
if not resumes:
return
for resume in resumes[1:]:
instrs.remove(resume)
7 changes: 6 additions & 1 deletion test/sot/test_05_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
# BUILD_MAP (new)
# BUILD_CONST_KEY_MAP (new)

import sys
import unittest

from test_case_base import TestCaseBase

import paddle
from paddle.jit.sot.psdb import check_no_breakgraph
from paddle.jit.sot.utils.envs import strict_mode_guard


@check_no_breakgraph
Expand Down Expand Up @@ -242,7 +244,10 @@ def test_construct(self):
self.assert_results(dict_construct_from_dict)
self.assert_results(dict_construct_from_list)
self.assert_results(dict_construct_from_tuple)
self.assert_results(dict_construct_from_comprehension)
# Temporarily fallback for comprehension in python3.11
use_strict_mode = sys.version_info < (3, 11)
with strict_mode_guard(use_strict_mode):
self.assert_results(dict_construct_from_comprehension)

def test_dict_noargs(self):
self.assert_results(dict_no_arguments)
Expand Down
6 changes: 5 additions & 1 deletion test/sot/test_12_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import sys
import unittest

from test_case_base import TestCaseBase
Expand Down Expand Up @@ -240,7 +241,10 @@ def run_list_comp(x):
class TestListComp(TestCaseBase):
def test_list_comp(self):
x = [paddle.randn([1, 4]), paddle.randn([1, 4])]
self.assert_results(run_list_comp, x)
# Temporarily fallback for comprehension in python3.11
use_strict_mode = sys.version_info < (3, 11)
with strict_mode_guard(use_strict_mode):
self.assert_results(run_list_comp, x)


def for_enumerate_cache(func_list, x):
Expand Down
16 changes: 10 additions & 6 deletions test/sot/test_builtin_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import sys
import unittest
from typing import Iterable

Expand Down Expand Up @@ -103,12 +104,15 @@ def test_map(self):
self.assert_results(test_map_dict, {"a": 1, "b": 2, "c": 3})

def test_map_comprehension(self):
self.assert_results(test_map_list_comprehension, [1, 2, 3, 4])
self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4))
self.assert_results(test_map_range_comprehension, range(5))
self.assert_results(
test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3}
)
# Temporarily fallback for comprehension in python3.11
use_strict_mode = sys.version_info < (3, 11)
with strict_mode_guard(use_strict_mode):
self.assert_results(test_map_list_comprehension, [1, 2, 3, 4])
self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4))
self.assert_results(test_map_range_comprehension, range(5))
self.assert_results(
test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3}
)

def test_map_with_breakgraph(self):
with strict_mode_guard(False):
Expand Down
56 changes: 56 additions & 0 deletions test/sot/test_listcomp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from test_case_base import TestCaseBase

import paddle
from paddle.jit.sot.utils.envs import min_graph_size_guard, strict_mode_guard

# 8 will trigger the warmup in RESUME instruction and cause a segmentation fault
# RUN_N_TIMES should be larger than 8
RUN_N_TIMES = 20


def listcomp_fn():
print(1)
x = [i for i in range(10)] # noqa: C416
return x


def genexpr_fn():
print(1)
x = (i for i in range(10))
return x


class TestListComp(TestCaseBase):
@strict_mode_guard(False)
@min_graph_size_guard(10)
def test_listcomp(self):
for _ in range(RUN_N_TIMES):
paddle.jit.to_static(listcomp_fn)()


class TestGenExpr(TestCaseBase):
@strict_mode_guard(False)
@min_graph_size_guard(10)
def test_genexpr(self):
for _ in range(RUN_N_TIMES):
paddle.jit.to_static(genexpr_fn)()


if __name__ == "__main__":
unittest.main()