Skip to content

Commit 1c3d1a3

Browse files
add ut
1 parent 73f2da0 commit 1c3d1a3

File tree

3 files changed

+23
-18
lines changed

3 files changed

+23
-18
lines changed

python/paddle/amp/auto_cast.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,6 @@ def clear_grad(self, set_to_zero: bool) -> None: ...
9292
}
9393

9494
AMP_LEVEL = core.AmpLevel
95-
_g_amp_state_ = None
96-
97-
98-
def amp_state():
99-
global _g_amp_state_
100-
return _g_amp_state_
10195

10296

10397
class AMPGlobalState:
@@ -520,11 +514,6 @@ def amp_guard(
520514
in_dynamic_or_pir_mode()
521515
), "We only support 'amp_guard' in dynamic or pir mode."
522516

523-
amp_state = locals()
524-
global _g_amp_state_
525-
original_state = _g_amp_state_
526-
_g_amp_state_ = amp_state
527-
528517
# check amp_level: O0-O2
529518
level = level.upper()
530519
if level not in ['O0', 'OD', 'O1', 'O2']:
@@ -576,10 +565,10 @@ def amp_guard(
576565
try:
577566
yield
578567
finally:
579-
_g_amp_state_ = original_state
580568
amp_attrs._amp_level = original_amp_level
581569
core._set_amp_op_list(original_white_list, original_black_list)
582570
amp_attrs._amp_dtype = original_amp_dtype
571+
amp_global_state().amp_dtype = original_amp_dtype
583572
if amp_level == AMP_LEVEL.O2:
584573
amp_attrs._use_promote = original_use_promote
585574

@@ -780,13 +769,13 @@ def param_hook(tmp_grad):
780769
yield
781770
finally:
782771
if tracer:
783-
_g_amp_state_ = original_state
784772
tracer._amp_level = original_amp_level
785773
tracer._set_amp_op_list(
786774
original_white_list, original_black_list
787775
)
788776
# set_flags(original_flags)
789777
tracer._amp_dtype = original_amp_dtype
778+
amp_global_state().amp_dtype = original_amp_dtype
790779
if amp_level == AMP_LEVEL.O2:
791780
tracer._use_promote = original_use_promote
792781

test/amp/test_get_autocast_dtype.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
@unittest.skipIf(paddle.device.get_device() == "cpu", "Skip AMP test on CPU")
2121
class TestAutocast(unittest.TestCase):
2222
def setUp(self) -> None:
23+
paddle.disable_static()
2324
self.device_list = [None, paddle.device.get_device()]
25+
self.default_dtype = "float32"
2426

2527
def do_test(self, device, expected_type):
2628
self.assertTrue(paddle.get_autocast_dtype(device) == expected_type)
@@ -30,25 +32,25 @@ def do_test(self, device, expected_type):
3032

3133
def test_amp_default(self):
3234
for device in self.device_list:
33-
self.do_test(device, "float16")
35+
self.do_test(device, self.default_dtype)
3436

3537
def test_amp_autocast_fp16(self):
3638
for device in self.device_list:
3739
with paddle.amp.auto_cast(True, dtype="float16"):
3840
self.do_test(device, "float16")
39-
self.do_test(device, "float16")
41+
self.do_test(device, self.default_dtype)
4042

4143
def test_amp_autocast_bf16(self):
4244
for device in self.device_list:
4345
with paddle.amp.auto_cast(True, dtype="bfloat16"):
4446
self.do_test(device, "bfloat16")
45-
self.do_test(device, "float16")
47+
self.do_test(device, self.default_dtype)
4648

4749
def test_amp_autocast_false_bf16(self):
4850
for device in self.device_list:
4951
with paddle.amp.auto_cast(True, dtype="bfloat16"):
5052
self.do_test(device, "bfloat16")
51-
self.do_test(device, "float16")
53+
self.do_test(device, self.default_dtype)
5254

5355
def test_amp_nested_context(self):
5456
for device in self.device_list:
@@ -57,7 +59,14 @@ def test_amp_nested_context(self):
5759
with paddle.amp.auto_cast(False, dtype="float16"):
5860
self.do_test(device, "float16")
5961
self.do_test(device, "bfloat16")
60-
self.do_test(device, "float16")
62+
self.do_test(device, self.default_dtype)
63+
64+
65+
class TestAutocastStatic(TestAutocast):
66+
def setUp(self) -> None:
67+
paddle.enable_static()
68+
self.device_list = [None, paddle.device.get_device()]
69+
self.default_dtype = "float32"
6170

6271

6372
if __name__ == "__main__":

test/amp/test_is_autocast_enabled.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
@unittest.skipIf(paddle.device.get_device() == "cpu", "Skip AMP test on CPU")
2121
class TestAutocast(unittest.TestCase):
2222
def setUp(self) -> None:
23+
paddle.disable_static()
2324
self.device_list = [None, paddle.device.get_device()]
2425

2526
def test_amp_default(self):
@@ -61,5 +62,11 @@ def test_amp_nested_context(self):
6162
self.assertFalse(paddle.amp.is_autocast_enabled(device))
6263

6364

65+
class TestAutocastStatic(TestAutocast):
66+
def setUp(self) -> None:
67+
paddle.enable_static()
68+
self.device_list = [None, paddle.device.get_device()]
69+
70+
6471
if __name__ == "__main__":
6572
unittest.main()

0 commit comments

Comments
 (0)