Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Fix PIR Unittest No.6-15】Fix book/* and part of test_comp_* in PIR mode #64124

Merged
merged 4 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/paddle/autograd/ir_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import logging
import warnings

import paddle.pir
from paddle.autograd.backward_utils import (
Expand Down Expand Up @@ -166,8 +167,8 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
% (i, str(grad.shape), i, str(output.shape))
)
if output.dtype != grad.dtype:
raise ValueError(
"The dtype of grad_output[%d] %s should be the same as the dtype of output[%d] %s"
warnings.warn(
"The dtype of grad_output[%d] %s is not same as the dtype of output[%d] %s"
% (i, str(grad.dtype), i, str(output.dtype))
)
feedop = grad.get_defining_op()
Expand Down
9 changes: 5 additions & 4 deletions test/deprecated/book/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
set_tests_properties(${src} PROPERTIES FIXTURES_SETUP ${src}_infer_model)
endforeach()
set_tests_properties(test_word2vec_book PROPERTIES TIMEOUT 120)
set_tests_properties(test_recognize_digits PROPERTIES TIMEOUT 120)
set_tests_properties(test_image_classification PROPERTIES TIMEOUT 200)
set_tests_properties(test_fit_a_line PROPERTIES TIMEOUT 120)
set_tests_properties(test_word2vec_book_deprecated PROPERTIES TIMEOUT 120)
set_tests_properties(test_recognize_digits_deprecated PROPERTIES TIMEOUT 120)
set_tests_properties(test_image_classification_deprecated PROPERTIES TIMEOUT
200)
set_tests_properties(test_fit_a_line_deprecated PROPERTIES TIMEOUT 120)
1 change: 0 additions & 1 deletion test/deprecated/prim/prim/vjp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()

add_subdirectory(eager)
add_subdirectory(static)
10 changes: 0 additions & 10 deletions test/deprecated/prim/prim/vjp/eager/CMakeLists.txt

This file was deleted.

1 change: 0 additions & 1 deletion test/deprecated/prim/prim/vjp/static/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()

set_tests_properties(test_comp_tanh_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_div_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_add_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_sub_grad PROPERTIES TIMEOUT 60)
Expand Down
29 changes: 16 additions & 13 deletions test/legacy_test/op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,7 +1960,9 @@ def check_inplace_output_with_place(
if getattr(self, "no_need_check_inplace", False):
return

if os.getenv("FLAGS_enable_pir_in_executor"):
if os.getenv("FLAGS_enable_pir_in_executor") or os.getenv(
"FLAGS_enable_pir_api"
):
return

has_infer_inplace = base.core.has_infer_inplace(self.op_type)
Expand Down Expand Up @@ -3119,18 +3121,19 @@ def check_grad_with_place(
core._set_prim_all_enabled(False)
core.set_prim_eager_enabled(False)
if check_prim:
self._check_grad_helper()
prim_grad_checker = PrimGradChecker(
self,
place,
inputs_to_check,
output_names,
no_grad_set,
user_defined_grad_outputs,
)
prim_grad_checker.check()
# Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
self.__class__.check_prim = True
with paddle.pir_utils.OldIrGuard():
self._check_grad_helper()
prim_grad_checker = PrimGradChecker(
self,
place,
inputs_to_check,
output_names,
no_grad_set,
user_defined_grad_outputs,
)
prim_grad_checker.check()
# Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
self.__class__.check_prim = True

if check_prim_pir:
with paddle.pir_utils.IrGuard():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,16 @@ def desired(primal, cotangent):

actual = actual(self.primal, self.cotangent)
desired = desired(self.primal, self.cotangent)
from paddle.base.data_feeder import _PADDLE_DTYPE_2_NUMPY_DTYPE
from paddle.base.data_feeder import _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE

self.assertEqual(
_PADDLE_DTYPE_2_NUMPY_DTYPE[actual[0].dtype], desired.dtype
)
if actual[0].dtype in _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE.keys():
TO_NUMPY_DTYPE = _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE
else:
from paddle.base.data_feeder import _PADDLE_DTYPE_2_NUMPY_DTYPE

TO_NUMPY_DTYPE = _PADDLE_DTYPE_2_NUMPY_DTYPE

self.assertEqual(TO_NUMPY_DTYPE[actual[0].dtype], desired.dtype)
np.testing.assert_allclose(
actual=actual[0],
desired=desired,
Expand Down
1 change: 1 addition & 0 deletions test/prim/prim/vjp/static/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ foreach(TEST_OP ${TEST_OPS})
endforeach()

set_tests_properties(test_comp_sum_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_tanh_grad PROPERTIES TIMEOUT 60)
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import parameterized as param

import paddle
from paddle.base import core, framework
from paddle.base import core


def apply_to_static(net, use_cinn):
Expand Down Expand Up @@ -88,27 +88,6 @@ def train(self, use_prim, use_cinn):

return res

def test_cinn(self):
paddle.disable_static()
use_cinn = True
if isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# TODO(jiabin): CINN will crashed in this case open it when fixed
use_cinn = False

dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=use_cinn)

for i in range(len(dy_res)):
np.testing.assert_allclose(
comp_st_cinn_res[i].numpy(),
dy_res[i].numpy(),
rtol=1e-15,
atol=1e-15,
)
paddle.enable_static()

def test_cast_grad_comp(self):
core._set_prim_backward_enabled(True)

Expand All @@ -124,10 +103,14 @@ def actual(primal, cotangent):
x_cotangent = paddle.static.gradients(y, x, v)
exe = paddle.static.Executor()
exe.run(sp)
if paddle.framework.in_pir_mode():
fetch_list = mp.blocks[0].ops[-1].result(0)
else:
fetch_list = mp.blocks[0].ops[-1].output('Out')[0]
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=mp.blocks[0].ops[-1].output('Out')[0],
fetch_list=fetch_list,
)[0]

def desired(primal, cotangent):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,9 @@ def train(self, use_prim, use_cinn):

return res

def test_cinn(self):
paddle.disable_static()
use_cinn = True
if isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# TODO(jiabin): CINN will crashed in this case open it when fixed
use_cinn = False

dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=use_cinn)

for i in range(len(dy_res)):
np.testing.assert_allclose(
comp_st_cinn_res[i].numpy(),
dy_res[i].numpy(),
rtol=1e-7,
atol=1e-7,
)
def test_reshape_grad_comp(self):
paddle.enable_static()

def test_reshape_grad_comp(self):
def actual(primal, shape, cotangent):
core._set_prim_backward_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
Expand All @@ -143,7 +124,7 @@ def actual(primal, shape, cotangent):
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=[x_cotangent[0].name],
fetch_list=[x_cotangent[0]],
)[0]

def desired(primal, shape, cotangent):
Expand All @@ -162,7 +143,7 @@ def desired(primal, shape, cotangent):
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=[x_cotangent[0].name],
fetch_list=[x_cotangent[0]],
)[0]

if (self.dtype == np.float16) and isinstance(
Expand All @@ -178,6 +159,7 @@ def desired(primal, shape, cotangent):
atol=self.rtol,
)
core._set_prim_backward_enabled(False)
paddle.disable_static()


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,9 @@ def train(self, use_prim, use_cinn):

return res

def test_cinn(self):
paddle.disable_static()
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)

for i in range(len(dy_res)):
np.testing.assert_allclose(
comp_st_cinn_res[i].numpy(),
dy_res[i].numpy(),
rtol=1e-7,
atol=1e-7,
)
def test_tanh_grad_comp(self):
paddle.enable_static()

def test_tanh_grad_comp(self):
def actual(primal, cotangent):
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
Expand All @@ -99,7 +87,7 @@ def actual(primal, cotangent):
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=[x_cotangent[0].name],
fetch_list=[x_cotangent[0]],
)[0]

def desired(primal, cotangent):
Expand All @@ -112,6 +100,7 @@ def desired(primal, cotangent):
atol=0,
)
core._set_prim_backward_enabled(False)
paddle.disable_static()


if __name__ == '__main__':
Expand Down