Commit e8bc6f1

Fixed the problem of reporting bf16 warning in auto_cast(enable=False)
tianhaodongbd authored Oct 18, 2023
1 parent d392526 commit e8bc6f1
Showing 1 changed file with 30 additions and 29 deletions.
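In short: the hardware-capability checks in amp_guard used to run unconditionally, so entering the autocast context with enable=False on a machine without bfloat16 support still emitted the bfloat16 warning. Wrapping the checks in `if enable:` skips them when amp is already disabled. A minimal repro sketch follows (an assumption, not part of the commit: it presumes a CUDA GPU with Compute Capability below 8.0, where bfloat16 amp is unsupported and the warning would otherwise fire):

import warnings

import paddle

# Repro sketch. Assumption: a CUDA GPU with Compute Capability < 8.0,
# so bfloat16 amp is unsupported and the capability warning would fire.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # amp is explicitly disabled; after this fix no capability check runs.
    with paddle.amp.auto_cast(enable=False, dtype='bfloat16'):
        pass

# Before this commit the "For bfloat16, amp only support ..." warning was
# recorded despite enable=False; after it, caught should be empty.
print([str(w.message) for w in caught])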
python/paddle/amp/auto_cast.py (59 changes: 30 additions & 29 deletions)
@@ -359,37 +359,38 @@ def amp_guard(
                 % tracer._expected_place
             )
             enable = False
-    # For xpu:
-    if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
-        warnings.warn('XPUPlace only support float16 amp.')
-        enable = False
-    # For custom device:
-    if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'):
-        warnings.warn('CustomPlace only support float16 amp.')
-        enable = False
-    # For gpu float16: Compute Capability should >= 7.
-    # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
-    if tracer._expected_place.is_gpu_place():
-        if (dtype == 'float16') and not _is_gpu_float16_supported():
-            prop = paddle.device.cuda.get_device_capability()
-            warnings.warn(
-                "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
-                % (paddle.device.cuda.get_device_name(), prop[0], prop[1])
-            )
-        elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
-            prop = paddle.device.cuda.get_device_capability()
-            cuda_version = paddle.version.cuda()
-            warnings.warn(
-                "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
-                % (
-                    paddle.device.cuda.get_device_name(),
-                    prop[0],
-                    prop[1],
-                    cuda_version,
-                )
-            )
+    if enable:
+        # For xpu:
+        if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
+            warnings.warn('XPUPlace only support float16 amp.')
+            enable = False
+        # For custom device:
+        if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'):
+            warnings.warn('CustomPlace only support float16 amp.')
+            enable = False
+        # For gpu float16: Compute Capability should >= 7.
+        # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
+        if tracer._expected_place.is_gpu_place():
+            if (dtype == 'float16') and not _is_gpu_float16_supported():
+                prop = paddle.device.cuda.get_device_capability()
+                warnings.warn(
+                    "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
+                    % (paddle.device.cuda.get_device_name(), prop[0], prop[1])
+                )
+                enable = False
+            elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
+                prop = paddle.device.cuda.get_device_capability()
+                cuda_version = paddle.version.cuda()
+                warnings.warn(
+                    "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
+                    % (
+                        paddle.device.cuda.get_device_name(),
+                        prop[0],
+                        prop[1],
+                        cuda_version,
+                    )
+                )
+                enable = False

     amp_dtype = dtype
     amp_global_state().amp_dtype = amp_dtype
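Relatedly, callers who want to opt into bfloat16 only where the guard would accept it can probe the same signals up front, using the two helpers that already appear in the diff (paddle.device.cuda.get_device_capability() and paddle.version.cuda()). A hedged sketch; the helper name pick_amp_dtype is hypothetical:

import paddle

def pick_amp_dtype():
    # Hypothetical helper: mirrors the guard's checks (Compute Capability
    # >= 8.0 and CUDA >= 11 for bfloat16, otherwise fall back to float16).
    if not paddle.is_compiled_with_cuda():
        return 'float16'
    major, _minor = paddle.device.cuda.get_device_capability()
    cuda_major = int(paddle.version.cuda().split('.')[0])
    if major >= 8 and cuda_major >= 11:
        return 'bfloat16'
    return 'float16'

with paddle.amp.auto_cast(enable=True, dtype=pick_amp_dtype()):
    pass  # forward pass here

This keeps the dtype decision explicit in user code instead of relying on the guard's warn-and-disable fallback.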