diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py
index 574821ab5b342..54712e1047d80 100644
--- a/python/paddle/jit/dy2static/pir_partial_program.py
+++ b/python/paddle/jit/dy2static/pir_partial_program.py
@@ -623,7 +623,9 @@ def program_id(self):
         Return current train or eval program hash id.
         """
         if _in_amp_guard() or _in_pure_fp16_guard():
-            raise NotImplementedError("not implement error.")
+            raise NotImplementedError(
+                "Currently, AMP is not supported in PIR mode"
+            )
         if self.training:
             return self._train_program_id
         else:
@@ -632,13 +634,17 @@ def program_id(self):
     @cached_property
     def train_program(self):
         if _in_amp_guard() or _in_pure_fp16_guard():
-            raise NotImplementedError("not implement error.")
+            raise NotImplementedError(
+                "Currently, AMP is not supported in PIR mode"
+            )
         return self._create_program()
 
     @cached_property
     def infer_program(self):
         if _in_amp_guard() or _in_pure_fp16_guard():
-            raise NotImplementedError("not implement error.")
+            raise NotImplementedError(
+                "Currently, AMP is not supported in PIR mode"
+            )
         return self._create_program(is_infer_mode=True)
 
     def _verify_program(self, main_program):
diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
index 1d1bd2358f100..2208f61ee3a8b 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
@@ -22,6 +22,8 @@
 import numpy as np
 
 import paddle
+from paddle.framework import use_pir_api
+from paddle.pir.core import vartype_to_datatype
 
 from ....infer_meta import MetaInfo
 from ....symbolic.statement_ir import Symbol
@@ -267,6 +269,20 @@ def make_stringify_guard(self) -> list[StringifyExpression]:
         else:
             return object_equal_stringify_guard(self)
 
+    def get_py_value(self, allow_tensor=False):
+        if use_pir_api() and isinstance(
+            self.value, paddle.base.core.VarDesc.VarType
+        ):
+            return vartype_to_datatype[self.value]
+        return super().get_py_value(allow_tensor)
+
+    def get_py_type(self):
+        if use_pir_api() and isinstance(
+            self.value, paddle.base.core.VarDesc.VarType
+        ):
+            return paddle.pir.core.DataType
+        return super().get_py_type()
+
     @property
     def main_info(self) -> dict[str, Any]:
         return {
diff --git a/test/sot/test_case_base.py b/test/sot/test_case_base.py
index f5a57f66c186b..bb6a13b6709b2 100644
--- a/test/sot/test_case_base.py
+++ b/test/sot/test_case_base.py
@@ -16,10 +16,9 @@
 
 import contextlib
 import copy
-import inspect
-import os
 import types
 import unittest
+from functools import wraps
 
 import numpy as np
 
@@ -38,39 +37,7 @@ def test_instruction_translator_cache_context():
     cache.clear()
 
 
-def github_action_error_msg(msg: str):
-    if 'GITHUB_ACTIONS' in os.environ:
-        frame = inspect.currentframe()
-        if frame is not None:
-            # find the first frame that is in the test folder
-            while frame.f_back is not None:
-                filename = frame.f_code.co_filename
-                if filename.startswith("./"):
-                    filename = f"tests/{filename[2:]}"
-                    lineno = frame.f_lineno
-                    output = f"\n::error file={filename},line={lineno}::{msg}"
-                    return output
-                frame = frame.f_back
-    return None
-
-
 class TestCaseBase(unittest.TestCase):
-    def assertIs(self, x, y, msg=None):
-        super().assertIs(x, y, msg=msg)
-        if msg is None:
-            msg = f"Assert Is, x is {x}, y is {y}"
-        msg = github_action_error_msg(msg)
-        if msg is not None:
-            print(msg)
-
-    def assertEqual(self, x, y, msg=None):
-        super().assertEqual(x, y, msg=msg)
-        if msg is None:
-            msg = f"Assert Equal, x is {x}, y is {y}"
-        msg = github_action_error_msg(msg)
-        if msg is not None:
-            print(msg)
-
     def assert_nest_match(self, x, y):
         cls_x = type(x)
         cls_y = type(y)
@@ -136,3 +103,43 @@ def copy_fn(fn):
                 sym_copied_fn.__globals__[key], paddle_fn.__globals__[key]
             )
         self.assert_nest_match(sym_output, paddle_output)
+
+
+# Some decorators for PIR tests
+def to_pir_test(fn):
+    # NOTE(SigureMo): This function should be kept in sync with test/dygraph_to_static/dygraph_to_static_utils.py
+    @wraps(fn)
+    def impl(*args, **kwargs):
+        in_dygraph_mode = paddle.in_dynamic_mode()
+        with paddle.pir_utils.IrGuard():
+            if in_dygraph_mode:
+                paddle.disable_static()
+            ir_outs = fn(*args, **kwargs)
+        return ir_outs
+
+    return impl
+
+
+def run_in_pir_mode(fn):
+    @wraps(fn)
+    def impl(*args, **kwargs):
+        OpcodeExecutorCache().clear()
+        pir_fn = to_pir_test(fn)
+        return pir_fn(*args, **kwargs)
+
+    return impl
+
+
+def run_in_both_default_and_pir(fn):
+    @wraps(fn)
+    def impl(*args, **kwargs):
+        OpcodeExecutorCache().clear()
+        default_fn = fn
+        pir_fn = to_pir_test(fn)
+        default_outs = default_fn(*args, **kwargs)
+        OpcodeExecutorCache().clear()
+        # The output of the test case should be None, which is not used.
+        _pir_outs = pir_fn(*args, **kwargs)
+        return default_outs
+
+    return impl
diff --git a/test/sot/test_dtype.py b/test/sot/test_dtype.py
new file mode 100644
index 0000000000000..b58678ca1541e
--- /dev/null
+++ b/test/sot/test_dtype.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from test_case_base import (
+    TestCaseBase,
+    run_in_both_default_and_pir,
+    test_instruction_translator_cache_context,
+)
+
+import paddle
+
+
+def tensor_astype(x, y):
+    z = x.astype(y.dtype)
+    return z
+
+
+def tensor_dtype_guard(x):
+    return x + 1
+
+
+class TestTensorAstype(TestCaseBase):
+    @run_in_both_default_and_pir
+    def test_tensor_astype(self):
+        x = paddle.ones([2, 3], dtype="float32")
+        y = paddle.ones([2, 3], dtype="int32")
+        self.assert_results(tensor_astype, x, y)
+
+
+class TestTensorDtypeGuard(TestCaseBase):
+    @run_in_both_default_and_pir
+    def test_tensor_dtype_guard(self):
+        x = paddle.ones([2, 3], dtype="float32")
+        y = paddle.ones([2, 3], dtype="int32")
+        with test_instruction_translator_cache_context() as ctx:
+            self.assert_results(tensor_dtype_guard, x)
+            self.assertEqual(ctx.translate_count, 1)
+            self.assert_results(tensor_dtype_guard, y)
+            self.assertEqual(ctx.translate_count, 2)
+            self.assert_results(tensor_dtype_guard, x)
+            self.assertEqual(ctx.translate_count, 2)
+
+
+if __name__ == "__main__":
+    unittest.main()
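
Note: the new run_in_pir_mode decorator added to test_case_base.py is not exercised by test_dtype.py above. Below is a minimal sketch (not part of this patch) of how it could be applied, following the same pattern as the run_in_both_default_and_pir tests; the function tensor_add and class TestPirOnlyExample are hypothetical names introduced only for illustration.

import unittest

from test_case_base import TestCaseBase, run_in_pir_mode

import paddle


# Hypothetical example function; not part of this patch.
def tensor_add(x, y):
    return x + y


class TestPirOnlyExample(TestCaseBase):
    # run_in_pir_mode clears the OpcodeExecutorCache and runs the test body
    # only under IrGuard (PIR mode), without a default-mode pass.
    @run_in_pir_mode
    def test_tensor_add(self):
        x = paddle.ones([2, 3], dtype="float32")
        y = paddle.ones([2, 3], dtype="float32")
        self.assert_results(tensor_add, x, y)


if __name__ == "__main__":
    unittest.main()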