[PIR AMP] Adapt auto_cast API for PIR AMP (#61859)
0x45f authored Feb 27, 2024
1 parent 579a12c commit 5d07c26
Showing 2 changed files with 176 additions and 146 deletions.
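
The core change: amp_guard loses its @dygraph_only decorator and instead asserts in_dynamic_or_pir_mode(), dispatching to a new PIR branch that drives the global AMP attributes in core. A minimal usage sketch, adapted from the updated test/amp/test_pir_amp.py below — how PIR mode is switched on and the exact 'pd_op.cast' op-name check are assumptions here, not shown verbatim in this diff:

    import paddle

    # Assumes a build/test harness where the PIR static-graph IR is active;
    # this sketch only illustrates the API shape enabled by this commit.
    paddle.enable_static()

    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data('x', [3, 4], 'float32')
        linear = paddle.nn.Linear(4, 5)
        # After this commit, auto_cast can wrap PIR program construction directly.
        with paddle.amp.auto_cast(level='O1', dtype='float16', use_promote=True):
            out1 = linear(x)
            out2 = paddle.mean(out1)

    # Under O1, list-based casting should insert cast ops around the matmul;
    # the 'pd_op.cast' name check is assumed, mirroring the updated test.
    cast_op_count = sum(
        1 for op in main.global_block().ops if op.name() == 'pd_op.cast'
    )
    print(out1.dtype, out2.dtype, cast_op_count)
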
295 changes: 171 additions & 124 deletions python/paddle/amp/auto_cast.py
@@ -17,7 +17,12 @@

import paddle
from paddle.base import core
from paddle.base.framework import _dygraph_tracer, dygraph_only
from paddle.base.framework import (
_dygraph_tracer,
dygraph_only,
in_dynamic_or_pir_mode,
in_pir_mode,
)
from paddle.base.wrapped_decorator import signature_safe_contextmanager

from .amp_lists import black_list, white_list
@@ -271,7 +276,6 @@ def check_optimizers(optimizers):


@signature_safe_contextmanager
@dygraph_only
def amp_guard(
enable=True,
custom_white_list=None,
Expand Down Expand Up @@ -325,6 +329,10 @@ def amp_guard(
paddle.float32
>>> # doctest: -SKIP
"""
assert (
in_dynamic_or_pir_mode()
), "We only support 'amp_guard' in dynamic or pir mode."

amp_state = locals()
global _g_amp_state_
original_state = _g_amp_state_
@@ -343,59 +351,6 @@
"If enable amp, dtype should be 'float16' or 'bfloat16'."
)

# check tracer
tracer = _dygraph_tracer()
if not tracer:
raise ValueError(
"current_tracer is None, maybe it is not in imperative mode."
)

# check device_type:
# NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, npu for float16.
# Maybe we will support cpu for bfloat16.
if enable and not (
tracer._expected_place.is_gpu_place()
or tracer._expected_place.is_xpu_place()
or tracer._expected_place.is_custom_place()
):
warnings.warn(
'amp_guard can only be enabled on CUDAPlace, XPUPlace, and CustomPlace, current place is %s, so it makes no effect.'
% tracer._expected_place
)
enable = False
if enable:
# For xpu:
if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
warnings.warn('XPUPlace only support float16 amp.')
enable = False
# For custom device:
if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'):
warnings.warn('CustomPlace only support float16 amp.')
enable = False
# For gpu float16: Compute Capability should >= 7.
# For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
if tracer._expected_place.is_gpu_place():
if (dtype == 'float16') and not _is_gpu_float16_supported():
prop = paddle.device.cuda.get_device_capability()
warnings.warn(
"For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
% (paddle.device.cuda.get_device_name(), prop[0], prop[1])
)
enable = False
elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
warnings.warn(
"For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
% (
paddle.device.cuda.get_device_name(),
prop[0],
prop[1],
cuda_version,
)
)
enable = False

amp_dtype = dtype
amp_global_state().amp_dtype = amp_dtype

@@ -412,87 +367,179 @@
custom_white_list, custom_black_list, level, dtype
)

if not enable:
amp_level = AMP_LEVEL.O0
amp_dtype = "float32"

# master_grad_hook will run at the end of backward.
# Since backward_final_hook will be cleared once they have been
# done, we should register the hook every step.
if (
amp_global_state().use_master_grad
and not amp_global_state().already_register_final_backward_hook
):
if in_pir_mode():
if not enable:
amp_level = AMP_LEVEL.O0
amp_dtype = "float32"
amp_attrs = core._get_amp_attrs()
# set amp level
original_amp_level = amp_attrs._amp_level
amp_attrs._amp_level = amp_level
# set amp op list
original_white_list, original_black_list = core._get_amp_op_list()
core._set_amp_op_list(_white_list, _black_list)
# set amp dtype
original_amp_dtype = amp_attrs._amp_dtype
amp_attrs._amp_dtype = amp_dtype
# switch promote
if amp_level == AMP_LEVEL.O2:
original_use_promote = amp_attrs._use_promote
amp_attrs._use_promote = use_promote

def master_grad_hook():
# NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
# classify the params of the model into different classes according to their process_mesh.
# Otherwise, fault will occur.
if not amp_global_state().already_classify_params_meshs:
for param in amp_global_state().model_parameters:
if param is not None and param.process_mesh is not None:
if (
param.process_mesh
not in amp_global_state().mesh2params
):
amp_global_state().mesh2params[
param.process_mesh
] = [param]
else:
amp_global_state().mesh2params[
param.process_mesh
].append(param)
amp_global_state().already_classify_params_meshs = True
try:
yield
finally:
_g_amp_state_ = original_state
amp_attrs._amp_level = original_amp_level
core._set_amp_op_list(original_white_list, original_black_list)
amp_attrs._amp_dtype = original_amp_dtype
if amp_level == AMP_LEVEL.O2:
amp_attrs._use_promote = original_use_promote

if len(amp_global_state().mesh2params):
for _, params in amp_global_state().mesh2params.items():
core.eager.set_master_grads(params)
else:
core.eager.set_master_grads(amp_global_state().model_parameters)
else:
# check tracer
tracer = _dygraph_tracer()
if not tracer:
raise ValueError(
"current_tracer is None, maybe it is not in imperative mode."
)
# check device_type:
# NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, npu for float16.
# Maybe we will support cpu for bfloat16.
if enable and not (
tracer._expected_place.is_gpu_place()
or tracer._expected_place.is_xpu_place()
or tracer._expected_place.is_custom_place()
):
warnings.warn(
'amp_guard can only be enabled on CUDAPlace, XPUPlace, and CustomPlace, current place is %s, so it makes no effect.'
% tracer._expected_place
)
enable = False
if enable:
# For xpu:
if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
warnings.warn('XPUPlace only support float16 amp.')
enable = False
# For custom device:
if tracer._expected_place.is_custom_place() and (
dtype == 'bfloat16'
):
warnings.warn('CustomPlace only support float16 amp.')
enable = False
# For gpu float16: Compute Capability should >= 7.
# For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
if tracer._expected_place.is_gpu_place():
if (dtype == 'float16') and not _is_gpu_float16_supported():
prop = paddle.device.cuda.get_device_capability()
warnings.warn(
"For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
% (
paddle.device.cuda.get_device_name(),
prop[0],
prop[1],
)
)
enable = False
elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
warnings.warn(
"For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
% (
paddle.device.cuda.get_device_name(),
prop[0],
prop[1],
cuda_version,
)
)
enable = False

if not enable:
amp_level = AMP_LEVEL.O0
amp_dtype = "float32"

# master_grad_hook will run at the end of backward.
# Since backward_final_hook will be cleared once they have been
# done, we should register the hook every step.
if (
amp_global_state().use_master_grad
and not amp_global_state().already_register_final_backward_hook
):

amp_global_state().already_register_final_backward_hook = False
def master_grad_hook():
# NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
# classify the params of the model into different classes according to their process_mesh.
# Otherwise, fault will occur.
if not amp_global_state().already_classify_params_meshs:
for param in amp_global_state().model_parameters:
if param is not None and param.process_mesh is not None:
if (
param.process_mesh
not in amp_global_state().mesh2params
):
amp_global_state().mesh2params[
param.process_mesh
] = [param]
else:
amp_global_state().mesh2params[
param.process_mesh
].append(param)
amp_global_state().already_classify_params_meshs = True

if len(amp_global_state().mesh2params):
for _, params in amp_global_state().mesh2params.items():
core.eager.set_master_grads(params)
else:
core.eager.set_master_grads(
amp_global_state().model_parameters
)

core.eager._add_backward_final_hook(master_grad_hook)
amp_global_state().already_register_final_backward_hook = True
amp_global_state().already_register_final_backward_hook = False

if tracer:
# enable auto_cast
original_amp_level = tracer._amp_level
tracer._amp_level = amp_level
core.eager._add_backward_final_hook(master_grad_hook)
amp_global_state().already_register_final_backward_hook = True

# set amp op list
original_white_list, original_black_list = tracer._get_amp_op_list()
tracer._set_amp_op_list(_white_list, _black_list)
if tracer:
# enable auto_cast
original_amp_level = tracer._amp_level
tracer._amp_level = amp_level

# TODO(zhiqiu) set amp related flags automatically in this guard
# Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard,
# batch_norm can run in fast mode, but batch_norm_grad can not if backward is not executed inside amp_guard.
# So, users need to set related flags manually.
# set amp op list
original_white_list, original_black_list = tracer._get_amp_op_list()
tracer._set_amp_op_list(_white_list, _black_list)

# original_flags = get_flags(AMP_RELATED_FLAGS)
# set_flags(AMP_RELATED_FLAGS_SETTING)
# TODO(zhiqiu) set amp related flags automatically in this guard
# Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard,
# batch_norm can run in fast mode, but batch_norm_grad can not if backward is not executed inside amp_guard.
# So, users need to set related flags manually.

# set amp dtype
original_amp_dtype = tracer._amp_dtype
tracer._amp_dtype = amp_dtype
# original_flags = get_flags(AMP_RELATED_FLAGS)
# set_flags(AMP_RELATED_FLAGS_SETTING)

# switch promote
if amp_level == AMP_LEVEL.O2:
original_use_promote = tracer._use_promote
tracer._use_promote = use_promote
# set amp dtype
original_amp_dtype = tracer._amp_dtype
tracer._amp_dtype = amp_dtype

# restore status
try:
yield
finally:
if tracer:
_g_amp_state_ = original_state
tracer._amp_level = original_amp_level
tracer._set_amp_op_list(original_white_list, original_black_list)
# set_flags(original_flags)
tracer._amp_dtype = original_amp_dtype
# switch promote
if amp_level == AMP_LEVEL.O2:
tracer._use_promote = original_use_promote
original_use_promote = tracer._use_promote
tracer._use_promote = use_promote

# restore status
try:
yield
finally:
if tracer:
_g_amp_state_ = original_state
tracer._amp_level = original_amp_level
tracer._set_amp_op_list(
original_white_list, original_black_list
)
# set_flags(original_flags)
tracer._amp_dtype = original_amp_dtype
if amp_level == AMP_LEVEL.O2:
tracer._use_promote = original_use_promote


class StateDictHook:
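
Condensed, the new PIR branch of amp_guard saves the process-global AMP attributes, applies the requested level, dtype, and op lists, and restores everything on exit; the device-capability checks remain in the dygraph branch only. A hedged sketch of that flow — the helper name _pir_amp_guard is hypothetical, while core._get_amp_attrs, core._get_amp_op_list, and core._set_amp_op_list are the calls used in the diff above:

    from contextlib import contextmanager

    from paddle.base import core

    @contextmanager
    def _pir_amp_guard(amp_level, amp_dtype, white_list, black_list, use_promote):
        # Hypothetical helper mirroring the PIR branch of amp_guard above.
        attrs = core._get_amp_attrs()
        original_level = attrs._amp_level
        original_dtype = attrs._amp_dtype
        original_white, original_black = core._get_amp_op_list()
        attrs._amp_level = amp_level
        attrs._amp_dtype = amp_dtype
        core._set_amp_op_list(white_list, black_list)
        if amp_level == core.AmpLevel.O2:
            # Promotion is only toggled for O2, matching the diff.
            original_promote = attrs._use_promote
            attrs._use_promote = use_promote
        try:
            yield
        finally:
            attrs._amp_level = original_level
            attrs._amp_dtype = original_dtype
            core._set_amp_op_list(original_white, original_black)
            if amp_level == core.AmpLevel.O2:
                attrs._use_promote = original_promote

The dygraph branch keeps the same bookkeeping but goes through the tracer (tracer._amp_level, tracer._set_amp_op_list, tracer._amp_dtype) and still registers the master_grad backward hook.
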
27 changes: 5 additions & 22 deletions test/amp/test_pir_amp.py
@@ -17,7 +17,6 @@
import numpy as np

import paddle
from paddle.amp.auto_cast import _update_list
from paddle.base import core


@@ -48,22 +47,11 @@ def test_linear_amp_o1(self):
with paddle.static.program_guard(main, startup):
x = paddle.static.data('x', [3, 4], 'float32')
linear = paddle.nn.Linear(4, 5)

amp_attrs = core._get_amp_attrs()
amp_attrs._use_promote = True
amp_attrs._amp_level = core.AmpLevel.O1
amp_attrs._amp_dtype = 'float16'
(
original_white_list,
original_black_list,
) = core._get_amp_op_list()
_white_list, _black_list = _update_list(
None, None, 'O1', 'float16'
)
core._set_amp_op_list(_white_list, _black_list)

out1 = linear(x)
out2 = paddle.mean(out1)
with paddle.amp.auto_cast(
level='O1', dtype='float16', use_promote=True
):
out1 = linear(x)
out2 = paddle.mean(out1)

cast_op_count = 0
for op in main.global_block().ops:
@@ -72,11 +60,6 @@
np.testing.assert_equal(out1.dtype, core.DataType.FLOAT32)
np.testing.assert_equal(out2.dtype, core.DataType.FLOAT32)
np.testing.assert_equal(cast_op_count, 3)

amp_attrs._use_promote = False
amp_attrs._amp_level = core.AmpLevel.O0
amp_attrs._amp_dtype = 'float32'
core._set_amp_op_list(original_white_list, original_black_list)
_white_list, _black_list = core._get_amp_op_list()
np.testing.assert_equal(len(_white_list), 0)
np.testing.assert_equal(len(_black_list), 0)
