diff --git a/test/amp/test_get_autocast_dtype.py b/test/amp/test_get_autocast_dtype.py index dfd3ea2c91cb73..ef8ef989ec24e3 100644 --- a/test/amp/test_get_autocast_dtype.py +++ b/test/amp/test_get_autocast_dtype.py @@ -44,18 +44,30 @@ def test_amp_autocast_fp16(self): self.do_test(device, "float16") self.do_test(device, self.default_dtype) + @unittest.skipIf( + not paddle.amp.is_bfloat16_supported(), + "Skip BF16 test if BF16 is not supported", + ) def test_amp_autocast_bf16(self): for device in self.device_list: with paddle.amp.auto_cast(True, dtype="bfloat16"): self.do_test(device, "bfloat16") self.do_test(device, self.default_dtype) + @unittest.skipIf( + not paddle.amp.is_bfloat16_supported(), + "Skip BF16 test if BF16 is not supported", + ) def test_amp_autocast_false_bf16(self): for device in self.device_list: with paddle.amp.auto_cast(True, dtype="bfloat16"): self.do_test(device, "bfloat16") self.do_test(device, self.default_dtype) + @unittest.skipIf( + not paddle.amp.is_bfloat16_supported(), + "Skip BF16 test if BF16 is not supported", + ) def test_amp_nested_context(self): for device in self.device_list: with paddle.amp.auto_cast(True, dtype="bfloat16"):