forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_autocast.py
286 lines (238 loc) · 11.8 KB
/
test_autocast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
# Owner(s): ["module: unknown"]
import collections
import unittest
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfTorchDynamo
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
from torch.utils._python_dispatch import TorchDispatchMode
class TestAutocastCPU(TestCase):
def setUp(self):
super().setUp()
self.autocast_lists = AutocastCPUTestLists(torch.device('cpu'))
def tearDown(self):
del self.autocast_lists
super().tearDown()
def _run_autocast_outofplace(
self,
op,
args,
run_as_type,
out_type=None,
module=torch,
add_kwargs=None,
amp_dtype=torch.bfloat16,
):
# helper to cast args
def cast(val, to_type):
if isinstance(val, torch.Tensor):
return val.to(to_type) if val.is_floating_point() else val
elif isinstance(val, collections.abc.Iterable):
return type(val)(cast(v, to_type) for v in val)
else:
return val
if add_kwargs is None:
add_kwargs = {}
self.assertFalse(torch.is_autocast_cpu_enabled())
with torch.cpu.amp.autocast(dtype=amp_dtype):
self.assertTrue(torch.is_autocast_cpu_enabled())
out_type = out_type if out_type is not None else run_as_type
output = output_method = None
# Try module.* variant, if requested:
if module is not None and hasattr(module, op):
output = getattr(module, op)(*args, **add_kwargs)
if isinstance(output, torch.Tensor):
self.assertTrue(out_type == output.dtype,
f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}")
# Try Tensor.* variant:
if hasattr(torch.Tensor, op):
output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
if isinstance(output_method, torch.Tensor):
self.assertTrue(out_type == output_method.dtype,
"autocast for torch.{} produced {}, should produce torch.{}"
.format(op, output_method.dtype, out_type))
self.assertTrue((output is not None) or (output_method is not None),
f"{op} not found as an attribute on either Tensor or the requested module {module}")
# Accounts for ops that return Tensors, iterables, and other non-Tensors.
# For example, lstm_cell returns a tuple and equal returns bool.
def compare(first, second):
if isinstance(first, torch.Tensor):
return torch.equal(first, second)
elif isinstance(first, collections.abc.Iterable):
return all(compare(f, s) for f, s in zip(first, second))
else:
return first == second
# If both torch.* and Tensor.* variants were found, check outputs are identical
if (output is not None) and (output_method is not None):
self.assertTrue(type(output) == type(output_method))
comparison = compare(output, output_method)
self.assertTrue(comparison, f"torch.{op} result did not match Tensor.{op} result")
# Compare numerics to Python-side "autocasting" that (we expect) does the same thing
# as the C++-side autocasting, and should be bitwise accurate.
output_to_compare = output if output is not None else output_method
with torch.cpu.amp.autocast(enabled=False):
self.assertFalse(torch.is_autocast_cpu_enabled())
if module is not None and hasattr(module, op):
control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
else:
control = getattr(args[0].to(run_as_type), op)(*cast(args[1:], run_as_type), **add_kwargs)
self.assertTrue(type(output_to_compare) == type(control))
comparison = compare(output_to_compare, control)
self.assertTrue(comparison, f"torch.{op} result did not match control")
self.assertTrue(torch.is_autocast_cpu_enabled())
self.assertFalse(torch.is_autocast_cpu_enabled())
def args_maybe_kwargs(self, op_with_args):
if len(op_with_args) == 2:
return op_with_args[0], op_with_args[1], {}
else:
return op_with_args[0], op_with_args[1], op_with_args[2]
@skipIfTorchDynamo()
def test_autocast_torch_expect_builtin_promote(self):
for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)
@skipIfTorchDynamo()
def test_autocast_methods_expect_builtin_promote(self):
for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)
@skipIfTorchDynamo()
def test_autocast_torch_16(self):
for op_with_args in self.autocast_lists.torch_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
@skipIfTorchDynamo()
def test_autocast_nn_16(self):
for op_with_args in self.autocast_lists.nn_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
)
self._run_autocast_outofplace(
op,
args,
torch.float16,
module=torch._C._nn,
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)
@skipIfTorchDynamo()
def test_autocast_torch_fp32(self):
for op_with_args in self.autocast_lists.torch_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
@skipIfTorchDynamo()
def test_autocast_nn_fp32(self):
for op_with_args in self.autocast_lists.nn_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
)
self._run_autocast_outofplace(
op,
args,
torch.float32,
module=torch._C._nn,
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)
@skipIfTorchDynamo()
def test_autocast_torch_need_autocast_promote(self):
for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args1, torch.float32)
self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
def test_autocast_rnn(self):
if torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported():
x = torch.randn(1, 2, 1)
hx = torch.randn(2, 2, 1)
cx = torch.randn(2, 2, 1)
m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)
# Raise ValueError when autocast is not enabled
with self.assertRaisesRegex(ValueError, "input must have the type"):
m(x, (hx, cx))
# Should be able to run the below case with autocast
with torch.cpu.amp.autocast():
m(x, (hx, cx))
def test_autocast_disabled_with_fp32_dtype(self):
with torch.autocast(device_type='cpu', dtype=torch.float32, enabled=False):
_ = torch.ones(10)
class CustomLinear(torch.autograd.Function):
    """Linear autograd function whose backward opens its own CUDA autocast region."""

    @staticmethod
    def forward(ctx, x, w_t):
        # Both operands are needed again to form the two gradients.
        ctx.save_for_backward(x, w_t)
        return torch.nn.functional.linear(x, w_t)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, weight = ctx.saved_tensors
        with torch.autocast(device_type='cuda'):
            grad_inputs = torch.matmul(grad_output, weight)
            inputs_t = inputs.transpose(0, 1)
            grad_weight = torch.matmul(inputs_t, grad_output).transpose(0, 1)
            return grad_inputs, grad_weight
class WeightDTypeCastCounterMode(TorchDispatchMode):
    """Dispatch mode that counts float16 ``_to_copy`` calls made on one tensor.

    While the mode is entered, ``torch.clear_autocast_cache`` is replaced by a
    no-op; the original function is put back on exit.
    """

    def __init__(self, weight):
        super().__init__()
        # Number of observed ``_to_copy`` calls converting ``weight`` to float16.
        self.dtype_cast_counter = 0
        # Tensor whose casts are tracked; matched by identity, not by value.
        self.weight = weight

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # The signature allows kwargs to be omitted, and a ``_to_copy`` call
        # need not carry a ``dtype`` entry, so neither may be assumed present.
        kwargs = {} if kwargs is None else kwargs
        if (
            func is torch.ops.aten._to_copy.default and
            args and
            args[0] is self.weight and
            kwargs.get('dtype') is torch.float16
        ):
            self.dtype_cast_counter += 1
        return func(*args, **kwargs)

    def __enter__(self):
        self.old_clear_cache = torch.clear_autocast_cache
        torch.clear_autocast_cache = lambda: None
        try:
            return super().__enter__()
        except BaseException:
            # __exit__ will not run if entering fails, so undo the patch here.
            torch.clear_autocast_cache = self.old_clear_cache
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.clear_autocast_cache = self.old_clear_cache
        return super().__exit__(exc_type, exc_val, exc_tb)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    """CUDA-only checks of how often autocast casts a weight to float16."""

    def _count_weight_casts(self, data, weight):
        # Forward runs inside autocast; backward runs outside it but still
        # inside the counting mode (CustomLinear.backward opens its own
        # autocast region).
        with WeightDTypeCastCounterMode(weight) as counter:
            with torch.autocast(device_type='cuda'):
                result = CustomLinear.apply(data, weight)
                total = result.sum()
            total.backward()
        return counter.dtype_cast_counter

    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight only get cast to float16 once.
        """
        inputs = torch.randn(2, 3).cuda()
        param = torch.nn.Parameter(torch.randn(4, 3).cuda())

        casts = self._count_weight_casts(inputs, param)

        self.assertEqual(casts, 1)

    def test_cache_disabled(self):
        """With the weight registered as a cached tensor, its cast is performed twice."""
        inputs = torch.randn(2, 3).cuda()
        param = torch.nn.Parameter(torch.randn(4, 3).cuda())

        try:
            torch._C._set_cached_tensors_enabled(True)
            torch._C._add_cached_tensor(param)

            casts = self._count_weight_casts(inputs, param)

            # we should not have cached the conversion of the weight
            self.assertEqual(casts, 2)
        finally:
            torch._C._set_cached_tensors_enabled(False)
class TestTorchAutocast(TestCase):
def test_autocast_fast_dtype(self):
gpu_fast_dtype = torch.get_autocast_gpu_dtype()
cpu_fast_dtype = torch.get_autocast_cpu_dtype()
self.assertEqual(gpu_fast_dtype, torch.half)
self.assertEqual(cpu_fast_dtype, torch.bfloat16)
def test_invalid_device(self):
dev = 'not a real device'
msg = f'unsupported autocast device_type \'{dev}\''
with self.assertRaisesRegex(RuntimeError, msg):
with torch.autocast(device_type=dev):
_ = torch.tensor(1)
# Run the suite only when this file is executed directly.
if __name__ == "__main__":
    run_tests()