test/amp/test_amp_master_grad_static.py (108 changes: 58 additions & 50 deletions)
@@ -87,9 +87,10 @@ def amp_fp16_o2(self, use_master_grad):
)

def test_amp_fp16_o2(self):
use_master_grad_list = [False, True]
for master_grad in use_master_grad_list:
self.amp_fp16_o2(master_grad)
with paddle.pir_utils.OldIrGuard():
use_master_grad_list = [False, True]
for master_grad in use_master_grad_list:
self.amp_fp16_o2(master_grad)


class TestMasterGradAccuracy(AmpTestBase):
@@ -155,53 +156,60 @@ def _run(
)
return losses

dtype = "float16"
max_iters = 25
x_f32, x_f16 = self._generate_feed_x(dtype)
if paddle.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
elif paddle.device.is_compiled_with_xpu():
place = paddle.device.XPUPlace(0)
else:
raise ValueError("Only support CUDA or XPU Place.")
exe = paddle.static.Executor(place)
use_grad_clip_list = [False, True]
for use_grad_clip in use_grad_clip_list:
losses_o1 = _run(
place, exe, x_f32, max_iters, 'O1', use_grad_clip, dtype=dtype
)
losses_o2_no_master_grad = _run(
place,
exe,
x_f16,
max_iters,
'O2',
use_grad_clip,
dtype=dtype,
use_master_grad=False,
)
losses_o2_master_grad = _run(
place,
exe,
x_f16,
max_iters,
'O2',
use_grad_clip,
dtype=dtype,
use_master_grad=True,
)

self.assertNotEqual(
losses_o1,
losses_o2_no_master_grad,
f"dtype: {dtype}, loss of o1 and o2-wo-master_grad should not be equal, but received loss o1: {losses_o1}, loss o2: {losses_o2_no_master_grad}",
)

self.assertEqual(
losses_o1,
losses_o2_master_grad,
f"dtype: {dtype}, loss of o1 and o2-w-master_grad should be equal, but received loss o1: {losses_o1}, loss o2: {losses_o2_master_grad}",
)
with paddle.pir_utils.OldIrGuard():
dtype = "float16"
max_iters = 25
x_f32, x_f16 = self._generate_feed_x(dtype)
if paddle.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
elif paddle.device.is_compiled_with_xpu():
place = paddle.device.XPUPlace(0)
else:
raise ValueError("Only support CUDA or XPU Place.")
exe = paddle.static.Executor(place)
use_grad_clip_list = [False, True]
for use_grad_clip in use_grad_clip_list:
losses_o1 = _run(
place,
exe,
x_f32,
max_iters,
'O1',
use_grad_clip,
dtype=dtype,
)
losses_o2_no_master_grad = _run(
place,
exe,
x_f16,
max_iters,
'O2',
use_grad_clip,
dtype=dtype,
use_master_grad=False,
)
losses_o2_master_grad = _run(
place,
exe,
x_f16,
max_iters,
'O2',
use_grad_clip,
dtype=dtype,
use_master_grad=True,
)

self.assertNotEqual(
losses_o1,
losses_o2_no_master_grad,
f"dtype: {dtype}, loss of o1 and o2-wo-master_grad should not be equal, but received loss o1: {losses_o1}, loss o2: {losses_o2_no_master_grad}",
)

self.assertEqual(
losses_o1,
losses_o2_master_grad,
f"dtype: {dtype}, loss of o1 and o2-w-master_grad should be equal, but received loss o1: {losses_o1}, loss o2: {losses_o2_master_grad}",
)


if __name__ == '__main__':
test/amp/test_model_cast_to_bf16.py (204 changes: 105 additions & 99 deletions)
@@ -172,29 +172,31 @@ def _graph_common(self, _amp_fun, startup_prog=None):
)

def test_graph_rewrite(self):
self._graph_common(
lambda prog: amp.bf16.rewrite_program_bf16(
prog,
amp.bf16.AutoMixedPrecisionListsBF16(
custom_bf16_list={'elementwise_add'},
custom_fp32_varnames={'elementwise_add_0.tmp_0'},
),
with paddle.pir_utils.OldIrGuard():
self._graph_common(
lambda prog: amp.bf16.rewrite_program_bf16(
prog,
amp.bf16.AutoMixedPrecisionListsBF16(
custom_bf16_list={'elementwise_add'},
custom_fp32_varnames={'elementwise_add_0.tmp_0'},
),
)
)
)

def test_graph_cast(self):
self._graph_common(
lambda prog, startup_prog: amp.bf16.cast_model_to_bf16(
prog,
startup_prog,
amp.bf16.AutoMixedPrecisionListsBF16(
custom_bf16_list={'elementwise_add'},
custom_fp32_list={'elementwise_mul'},
with paddle.pir_utils.OldIrGuard():
self._graph_common(
lambda prog, startup_prog: amp.bf16.cast_model_to_bf16(
prog,
startup_prog,
amp.bf16.AutoMixedPrecisionListsBF16(
custom_bf16_list={'elementwise_add'},
custom_fp32_list={'elementwise_mul'},
),
use_bf16_guard=True,
),
use_bf16_guard=True,
),
startup_prog=base.default_startup_program(),
)
startup_prog=base.default_startup_program(),
)


@unittest.skipIf(
@@ -221,48 +223,50 @@ def _check_optimizer(self, program, expected_num_mp):
)

def test_amp_bf16_o1(self):
main_program, startup_program, _, _, _ = build_embedding_model(
True, "bfloat16", "O1"
)
self.assertEqual(main_program.num_blocks, 1)
self._check_optimizer(main_program, 0)

amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
expected_bf16_calls = {
"matmul_v2": 1,
"elementwise_add": 1,
"dropout": 1,
"lookup_table_v2": 0,
"squared_l2_norm": 0,
"adamw": 0,
}
self._check_op_calls(op_stats_list[0], expected_bf16_calls)
with paddle.pir_utils.OldIrGuard():
main_program, startup_program, _, _, _ = build_embedding_model(
True, "bfloat16", "O1"
)
self.assertEqual(main_program.num_blocks, 1)
self._check_optimizer(main_program, 0)

amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
expected_bf16_calls = {
"matmul_v2": 1,
"elementwise_add": 1,
"dropout": 1,
"lookup_table_v2": 0,
"squared_l2_norm": 0,
"adamw": 0,
}
self._check_op_calls(op_stats_list[0], expected_bf16_calls)

def test_amp_bf16_o2(self):
main_program, startup_program, _, _, _ = build_embedding_model(
True, "bfloat16", "O2"
)
self.assertEqual(main_program.num_blocks, 1)

amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
expected_fp32_calls = {"lookup_table_v2": 1}
expected_bf16_calls = {
"matmul_v2": 1,
"elementwise_add": 1,
"dropout": 1,
"lookup_table_v2": 0,
"squared_l2_norm": 3,
"adamw": 3,
}
self._check_optimizer(
main_program,
expected_bf16_calls["matmul_v2"]
+ expected_bf16_calls["elementwise_add"]
+ expected_fp32_calls["lookup_table_v2"],
)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)
with paddle.pir_utils.OldIrGuard():
main_program, startup_program, _, _, _ = build_embedding_model(
True, "bfloat16", "O2"
)
self.assertEqual(main_program.num_blocks, 1)

amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
expected_fp32_calls = {"lookup_table_v2": 1}
expected_bf16_calls = {
"matmul_v2": 1,
"elementwise_add": 1,
"dropout": 1,
"lookup_table_v2": 0,
"squared_l2_norm": 3,
"adamw": 3,
}
self._check_optimizer(
main_program,
expected_bf16_calls["matmul_v2"]
+ expected_bf16_calls["elementwise_add"]
+ expected_fp32_calls["lookup_table_v2"],
)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)


@unittest.skipIf(
@@ -278,47 +282,49 @@ def _generate_feed_x(self):
return x_fp32, x_bf16

def test_compare_o1_o2(self):
def _run(place, exe, x_np, max_iters, level):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_add_model(True, "bfloat16", level)

losses = self.run_program(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_np,
max_iters,
"bfloat16",
level,
with paddle.pir_utils.OldIrGuard():

def _run(place, exe, x_np, max_iters, level):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_add_model(True, "bfloat16", level)

losses = self.run_program(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_np,
max_iters,
"bfloat16",
level,
)
return losses

max_iters = 2
x_fp32, x_bf16 = self._generate_feed_x()
if paddle.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
elif paddle.is_compiled_with_xpu():
place = paddle.device.XPUPlace(0)
else:
raise ValueError("Only support CUDA or XPU Place.")
exe = paddle.static.Executor(place)
losses_o1 = _run(place, exe, x_fp32, max_iters, 'O1')
losses_o2 = _run(place, exe, x_bf16, max_iters, 'O2')

self.assertEqual(
losses_o1,
losses_o2,
f"loss of o1 and o2 should be equal, but received loss o1: {losses_o1}, loss o2: {losses_o2}",
)
return losses

max_iters = 2
x_fp32, x_bf16 = self._generate_feed_x()
if paddle.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
elif paddle.is_compiled_with_xpu():
place = paddle.device.XPUPlace(0)
else:
raise ValueError("Only support CUDA or XPU Place.")
exe = paddle.static.Executor(place)
losses_o1 = _run(place, exe, x_fp32, max_iters, 'O1')
losses_o2 = _run(place, exe, x_bf16, max_iters, 'O2')

self.assertEqual(
losses_o1,
losses_o2,
f"loss of o1 and o2 should be equal, but received loss o1: {losses_o1}, loss o2: {losses_o2}",
)


if __name__ == '__main__':
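Note on the pattern applied in both files above: each legacy static-graph AMP test body is now wrapped in "with paddle.pir_utils.OldIrGuard():", which appears to pin the test to the old Program IR rather than the newer PIR path. Below is a minimal, self-contained sketch of that guard pattern on a toy network; the network, feed data, and place selection are illustrative assumptions only and are not the helpers (build_embedding_model, build_add_model, _run) used by the actual tests.

import numpy as np
import paddle

paddle.enable_static()

# Build and run a legacy static-graph program under OldIrGuard, mirroring how
# the tests above scope their program construction and execution.
with paddle.pir_utils.OldIrGuard():
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[None, 4], dtype="float32")
        y = paddle.static.nn.fc(x, size=2)
        loss = paddle.mean(y)

    # The tests above require CUDA or XPU; CPU is used here as a fallback only
    # to keep the sketch runnable anywhere.
    place = (
        paddle.CUDAPlace(0)
        if paddle.is_compiled_with_cuda()
        else paddle.CPUPlace()
    )
    exe = paddle.static.Executor(place)
    exe.run(startup_prog)
    (loss_val,) = exe.run(
        main_prog,
        feed={"x": np.random.rand(8, 4).astype("float32")},
        fetch_list=[loss],
    )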