[SOT] Convert dtype to DataType in PIR mode #60627

Merged: 5 commits, Jan 10, 2024
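For context, the core of this change is that, when SOT runs under the PIR API, dtype constants taken from the legacy paddle.base.core.VarDesc.VarType enum are converted to paddle.pir.core.DataType before they are guarded on or passed along. A minimal sketch of that conversion, using the vartype_to_datatype table imported in one of the diffs below (the FP32 member here is only an illustrative key):

import paddle
from paddle.framework import use_pir_api
from paddle.pir.core import vartype_to_datatype

# Sketch: map a legacy dtype enum value to the PIR DataType enum.
# VarType.FP32 is used purely as an example key of the mapping table.
legacy_dtype = paddle.base.core.VarDesc.VarType.FP32
if use_pir_api():
    pir_dtype = vartype_to_datatype[legacy_dtype]  # a paddle.pir.core.DataType value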
12 changes: 9 additions & 3 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -623,7 +623,9 @@ def program_id(self):
         Return current train or eval program hash id.
         """
         if _in_amp_guard() or _in_pure_fp16_guard():
-            raise NotImplementedError("not implement error.")
+            raise NotImplementedError(
+                "Currently, AMP is not supported in PIR mode"
+            )
         if self.training:
             return self._train_program_id
         else:
@@ -632,13 +634,17 @@ def program_id(self):
     @cached_property
     def train_program(self):
         if _in_amp_guard() or _in_pure_fp16_guard():
-            raise NotImplementedError("not implement error.")
+            raise NotImplementedError(
+                "Currently, AMP is not supported in PIR mode"
+            )
         return self._create_program()

     @cached_property
     def infer_program(self):
         if _in_amp_guard() or _in_pure_fp16_guard():
-            raise NotImplementedError("not implement error.")
+            raise NotImplementedError(
+                "Currently, AMP is not supported in PIR mode"
+            )
         return self._create_program(is_infer_mode=True)

     def _verify_program(self, main_program):
@@ -22,6 +22,8 @@
 import numpy as np

 import paddle
+from paddle.framework import use_pir_api
+from paddle.pir.core import vartype_to_datatype

 from ....infer_meta import MetaInfo
 from ....symbolic.statement_ir import Symbol
@@ -267,6 +269,20 @@ def make_stringify_guard(self) -> list[StringifyExpression]:
         else:
             return object_equal_stringify_guard(self)

+    def get_py_value(self, allow_tensor=False):
+        if use_pir_api() and isinstance(
+            self.value, paddle.base.core.VarDesc.VarType
+        ):
+            return vartype_to_datatype[self.value]
+        return super().get_py_value(allow_tensor)
+
+    def get_py_type(self):
+        if use_pir_api() and isinstance(
+            self.value, paddle.base.core.VarDesc.VarType
+        ):
+            return paddle.pir.core.DataType
+        return super().get_py_type()
+
     @property
     def main_info(self) -> dict[str, Any]:
         return {
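The get_py_value and get_py_type overrides added above are what the PR title describes: when SOT runs under the PIR API and the wrapped Python value is still a legacy VarDesc.VarType, the variable reports the converted DataType (and DataType as its Python type), so dtype-based guards compare members of the same enum. A standalone sketch of that branch, where normalize_dtype is a hypothetical helper name rather than part of Paddle's API:

import paddle
from paddle.framework import use_pir_api
from paddle.pir.core import vartype_to_datatype


def normalize_dtype(value):
    # Hypothetical helper mirroring the overrides above: convert only when
    # running under the PIR API and the value is a legacy dtype enum member.
    if use_pir_api() and isinstance(value, paddle.base.core.VarDesc.VarType):
        return vartype_to_datatype[value]
    return value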
75 changes: 41 additions & 34 deletions test/sot/test_case_base.py
@@ -16,10 +16,9 @@

 import contextlib
 import copy
-import inspect
-import os
 import types
 import unittest
+from functools import wraps

 import numpy as np

@@ -38,39 +37,7 @@ def test_instruction_translator_cache_context():
         cache.clear()


-def github_action_error_msg(msg: str):
-    if 'GITHUB_ACTIONS' in os.environ:
-        frame = inspect.currentframe()
-        if frame is not None:
-            # find the first frame that is in the test folder
-            while frame.f_back is not None:
-                filename = frame.f_code.co_filename
-                if filename.startswith("./"):
-                    filename = f"tests/{filename[2:]}"
-                    lineno = frame.f_lineno
-                    output = f"\n::error file={filename},line={lineno}::{msg}"
-                    return output
-                frame = frame.f_back
-    return None
-
-
 class TestCaseBase(unittest.TestCase):
-    def assertIs(self, x, y, msg=None):
-        super().assertIs(x, y, msg=msg)
-        if msg is None:
-            msg = f"Assert Is, x is {x}, y is {y}"
-        msg = github_action_error_msg(msg)
-        if msg is not None:
-            print(msg)
-
-    def assertEqual(self, x, y, msg=None):
-        super().assertEqual(x, y, msg=msg)
-        if msg is None:
-            msg = f"Assert Equal, x is {x}, y is {y}"
-        msg = github_action_error_msg(msg)
-        if msg is not None:
-            print(msg)
-
     def assert_nest_match(self, x, y):
         cls_x = type(x)
         cls_y = type(y)
@@ -136,3 +103,43 @@ def copy_fn(fn):
                 sym_copied_fn.__globals__[key], paddle_fn.__globals__[key]
             )
         self.assert_nest_match(sym_output, paddle_output)
+
+
+# Some decorators for PIR tests
+def to_pir_test(fn):
+    # NOTE(SigureMo): This function should be kept in sync with test/dygraph_to_static/dygraph_to_static_utils.py
+    @wraps(fn)
+    def impl(*args, **kwargs):
+        in_dygraph_mode = paddle.in_dynamic_mode()
+        with paddle.pir_utils.IrGuard():
+            if in_dygraph_mode:
+                paddle.disable_static()
+            ir_outs = fn(*args, **kwargs)
+        return ir_outs
+
+    return impl
+
+
+def run_in_pir_mode(fn):
+    @wraps(fn)
+    def impl(*args, **kwargs):
+        OpcodeExecutorCache().clear()
+        pir_fn = to_pir_test(fn)
+        return pir_fn(*args, **kwargs)
+
+    return impl
+
+
+def run_in_both_default_and_pir(fn):
+    @wraps(fn)
+    def impl(*args, **kwargs):
+        OpcodeExecutorCache().clear()
+        default_fn = fn
+        pir_fn = to_pir_test(fn)
+        default_outs = default_fn(*args, **kwargs)
+        OpcodeExecutorCache().clear()
+        # The return value of a test case should be None, so it is not used.
+        _pir_outs = pir_fn(*args, **kwargs)
+        return default_outs
+
+    return impl
58 changes: 58 additions & 0 deletions test/sot/test_dtype.py
@@ -0,0 +1,58 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from test_case_base import (
    TestCaseBase,
    run_in_both_default_and_pir,
    test_instruction_translator_cache_context,
)

import paddle


def tensor_astype(x, y):
    z = x.astype(y.dtype)
    return z


def tensor_dtype_guard(x):
    return x + 1


class TestTensorAstype(TestCaseBase):
    @run_in_both_default_and_pir
    def test_tensor_astype(self):
        x = paddle.ones([2, 3], dtype="float32")
        y = paddle.ones([2, 3], dtype="int32")
        self.assert_results(tensor_astype, x, y)


class TestTensorDtypeGuard(TestCaseBase):
    @run_in_both_default_and_pir
    def test_tensor_dtype_guard(self):
        x = paddle.ones([2, 3], dtype="float32")
        y = paddle.ones([2, 3], dtype="int32")
        with test_instruction_translator_cache_context() as ctx:
            self.assert_results(tensor_dtype_guard, x)
            self.assertEqual(ctx.translate_count, 1)
            self.assert_results(tensor_dtype_guard, y)
            self.assertEqual(ctx.translate_count, 2)
            self.assert_results(tensor_dtype_guard, x)
            self.assertEqual(ctx.translate_count, 2)


if __name__ == "__main__":
    unittest.main()