Skip to content

Commit

Permalink
minor change
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase committed Jun 30, 2021
1 parent f690e65 commit 62f2da9
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,23 +128,13 @@ def test_check_output(self):
self.check_output_with_place(core.CPUPlace())

def test_check_grad(self):
if core.is_compiled_with_cuda():
self.skipTest(
"OneDNN doesn't support bf16 with CUDA, skipping UT" +
self.__class__.__name__)
elif not core.supports_bfloat16():
self.skipTest("Core doesn't support bf16, skipping UT" +
self.__class__.__name__)
else:
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X", "Alpha"],
"Out",
check_dygraph=False,
user_defined_grads=[self.dx, self.dalpha],
user_defined_grad_outputs=[
convert_float_to_uint16(self.dout)
])
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X", "Alpha"],
"Out",
check_dygraph=False,
user_defined_grads=[self.dx, self.dalpha],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])

cls_name = "{0}_{1}".format(parent.__name__, "BF16")
TestPReluBF16OneDNNOp.__name__ = cls_name
Expand Down

0 comments on commit 62f2da9

Please sign in to comment.