Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St] pir dy2st unittest verification - Part 1 #58630

Merged
merged 3 commits into from
Nov 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions test/dygraph_to_static/dygraph_to_static_utils_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import numpy as np

import paddle
from paddle import set_flags, static
from paddle.base import core
from paddle.jit.api import sot_mode_guard
Expand All @@ -29,9 +30,9 @@
# Usage:
class MyTest(Dy2StTestBase):
@set_to_static_mode(
ToStaticMode.LEGACY_AST | ToStaticMode.SOT | ToStaticMode.PIR_AST
ToStaticMode.AST | ToStaticMode.SOT
)
@set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR)
@set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_EXE | IrMode.PIR_API)
def test_case1(self):
raise ValueError("MyTest 1")

Expand All @@ -49,8 +50,7 @@ def test_case1(self):


class ToStaticMode(Flag):
LEGACY_AST = auto()
PIR_AST = auto()
AST = auto()
SOT = auto()

def lower_case_name(self):
Expand All @@ -59,13 +59,16 @@ def lower_case_name(self):

class IrMode(Flag):
LEGACY_IR = auto()
PIR = auto()
# pir translator mode, Reference link: https://github.com/PaddlePaddle/community/blob/master/pfcc/paddle-code-reading/IR_Dialect/program_translator.md
PIR_EXE = auto()
# using native pir api mode
PIR_API = auto()

def lower_case_name(self):
return self.name.lower()


DEFAULT_TO_STATIC_MODE = ToStaticMode.LEGACY_AST | ToStaticMode.SOT
DEFAULT_TO_STATIC_MODE = ToStaticMode.AST | ToStaticMode.SOT
DEFAULT_IR_MODE = IrMode.LEGACY_IR


Expand Down Expand Up @@ -98,13 +101,24 @@ def impl(*args, **kwargs):


def to_pir_ast_test(fn):
raise TypeError("Don't enable PIR AST mode now!")
@wraps(fn)
def impl(*args, **kwargs):
logger.info("[PIR][AST] running pir api")
ir_outs = None
try:
with paddle.pir_utils.IrGuard():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以试验尝试一下,先使用Program组网,然后使用pir::Program组网,结果是否正确。

paddle.disable_static()
ir_outs = fn(*args, **kwargs)
finally:
paddle.enable_static()
return ir_outs

return impl


def to_legacy_ir_test(fn):
def impl(*args, **kwargs):
logger.info("[Program] running legacy ir")
# breakpoint()
return fn(*args, **kwargs)

return impl
Expand Down Expand Up @@ -136,13 +150,13 @@ def impl(*args, **kwargs):
class Dy2StTestMeta(type):
TO_STATIC_HANDLER_MAP = {
ToStaticMode.SOT: to_sot_test,
ToStaticMode.LEGACY_AST: to_legacy_ast_test,
ToStaticMode.PIR_AST: to_pir_ast_test,
ToStaticMode.AST: to_legacy_ast_test,
}

IR_HANDLER_MAP = {
IrMode.LEGACY_IR: to_legacy_ir_test,
IrMode.PIR: to_pir_test,
IrMode.PIR_EXE: to_pir_test,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to_pir_exe_test

IrMode.PIR_API: to_pir_ast_test,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to_pir_api_test

}

def __new__(cls, name, bases, attrs):
Expand Down Expand Up @@ -191,11 +205,11 @@ def __new__(cls, name, bases, attrs):
)
# Generate all test cases
for to_static_mode, ir_mode in to_static_with_ir_modes:
# NOTE(gouzil): Temporarily not supported SOT + PIR, link: https://github.com/PaddlePaddle/Paddle/pull/58630
if (
to_static_mode == ToStaticMode.PIR_AST
and ir_mode == IrMode.LEGACY_IR
to_static_mode == ToStaticMode.SOT
and ir_mode == IrMode.PIR_API
):
# PIR with LEGACY_IR is not a valid combination
continue
new_attrs[
Dy2StTestMeta.test_case_name(
Expand Down Expand Up @@ -250,7 +264,7 @@ def decorator(fn):
# Suger decorators
# These decorators can be simply composed by base decorators
def test_ast_only(fn):
fn = set_to_static_mode(ToStaticMode.LEGACY_AST)(fn)
fn = set_to_static_mode(ToStaticMode.AST)(fn)
return fn


Expand All @@ -260,12 +274,22 @@ def test_sot_only(fn):


def test_pir_only(fn):
fn = set_ir_mode(IrMode.PIR)(fn)
fn = set_ir_mode(IrMode.PIR_EXE)(fn)
return fn


def test_legacy_and_pir(fn):
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR)(fn)
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_EXE)(fn)
return fn


def test_legacy_and_pir_api(fn):
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_API)
return fn


def test_legacy_and_pir_api_and_pir_exe(fn):
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_API | IrMode.PIR_EXE)
return fn


Expand Down