forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_flop_counter.py
380 lines (299 loc) · 14.4 KB
/
test_flop_counter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
# Owner(s): ["module: unknown"]
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_TORCHDYNAMO
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION
import torch.utils.flop_counter
import torch.nn.functional as F
import unittest
import functools
# Optional dependency: torchvision is only needed by the ResNet-based test.
try:
    from torchvision import models as torchvision_models
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Decorator that skips a test when torchvision could not be imported.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
# Evaluated once at import time; gates the CUDA-only SDPA test.
HAS_CUDA = torch.cuda.is_available()
def FlopCounterMode(*args, **kwargs):
return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)
def get_total_flops(mode):
    """Return the total FLOPs recorded under the ``"Global"`` entry as a string.

    Args:
        mode: an object exposing ``flop_counts["Global"]``, a mapping from
            operator to FLOP count (e.g. a ``FlopCounterMode`` after use).

    Returns:
        ``str`` of the summed counts, so it can be fed straight into
        ``assertExpectedInline``.
    """
    # Only the counts matter; iterate the values directly instead of building
    # a throwaway list of (op, count) pairs.
    return str(sum(mode.flop_counts["Global"].values()))
def T(*shape, requires_grad=False):
    """Return a standard-normal random tensor of the given shape.

    ``shape`` is unpacked into ``torch.randn`` exactly as given.
    ``requires_grad`` is forwarded unchanged, so callers can create autograd
    leaves.
    """
    sample = torch.randn(*shape, requires_grad=requires_grad)
    return sample
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now")
class TestFlopCounter(TestCase):
    """Tests for ``torch.utils.flop_counter.FlopCounterMode``.

    Most tests run ops inside a counter context and compare the aggregated
    ``"Global"`` FLOP count (via ``get_total_flops``) against hand-derived or
    recorded expected values.
    """

    def test_flop_counter_variety(self):
        mode = FlopCounterMode()
        mod = torch.nn.Linear(9, 10)
        with mode:
            torch.mm(T(4, 5), T(5, 6))
            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)
            torch.matmul(T(5, 6), T(6, 7))
            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))
            mod(T(8, 9))

        self.assertExpectedInline(get_total_flops(mode), """3012""")

    def test_op(self):
        mode = FlopCounterMode()
        with mode:
            torch.mm(T(4, 5), T(5, 6))
        # 4 * 6 * 2 * 5 = 240
        self.assertExpectedInline(get_total_flops(mode), """240""")

        with mode:
            torch.bmm(T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.addmm(T(4, 6), T(4, 5), T(5, 6))
            torch.addmm(T(4, 1), T(4, 5), T(5, 6))
            torch.addmm(T(6), T(4, 5), T(5, 6))
        # Three addmm calls, each 4 * 6 * 2 * 5 = 240 (the bias add is not
        # counted, whatever its broadcast shape), so 3 * 240 = 720.
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1)
        # out_image_size = 2 * 5 * 5
        # kernel_size = 4 * 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in

        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """28800""")

        with mode:
            torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1)
        # out_image_size = 2 * 5
        # kernel_size = 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in

        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """1440""")

    def test_backward(self):
        mode = FlopCounterMode()
        with mode:
            a = T(4, 5, requires_grad=True)
            a = torch.mm(a, T(5, 6))
            a = a.unsqueeze(0).expand(7, 4, 6)
            a = torch.bmm(a, T(7, 6, 7))
            a.sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """5184""")

    def test_torchscript(self):
        def foo(x):
            return torch.mm(x, x)

        mode = FlopCounterMode()
        with mode:
            foo(T(5, 5))
        unscripted_flops = get_total_flops(mode)
        ts_foo = torch.jit.script(foo)
        with mode:
            ts_foo(T(5, 5))
        # Scripting must not change what the counter sees.
        self.assertEqual(unscripted_flops, get_total_flops(mode))

    def test_autograd_op(self):
        class _CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input: torch.Tensor) -> torch.Tensor:
                return torch.mm(input, input)

            @staticmethod
            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
                return torch.mm(grad_output, grad_output) + torch.mm(grad_output, grad_output)

        a = T(5, 5, requires_grad=True)
        mode = FlopCounterMode()
        with mode:
            a = _CustomOp.apply(a)
            a.sum().backward()

        # One mm in forward plus two in the custom backward, 250 each.
        self.assertExpectedInline(get_total_flops(mode), """750""")

    def test_conv_backwards_as_decomposition(self):
        # [conv backwards decomposition as conv forwards]
        # onlyConvs computes conv gradients using only forward convolutions.
        # Its gradients are checked against autograd's for both conv1d and
        # conv_transpose1d.

        class onlyConvs(torch.autograd.Function):
            @staticmethod
            def forward(inp, weight, transposed):
                if not transposed:
                    return F.conv1d(inp, weight)
                else:
                    return F.conv_transpose1d(inp, weight)

            @staticmethod
            def setup_context(ctx, inputs, output):
                inp, weight, transposed = inputs
                ctx.save_for_backward(inp, weight)
                ctx.transposed = transposed

            @staticmethod
            def backward(ctx, grad_out):
                inp, weight = ctx.saved_tensors
                if not ctx.transposed:
                    grad_inp = F.conv_transpose1d(grad_out, weight)
                    # grad_w[o, c, k] = sum_{n, t} grad_out[n, o, t] * inp[n, c, t + k].
                    # Swap the batch and channel dims so this becomes a
                    # forward conv, then swap back to (C_out, C_in, K).
                    grad_weight = F.conv1d(inp.transpose(1, 0), grad_out.transpose(1, 0))
                    return grad_inp, grad_weight.transpose(1, 0), None
                else:
                    grad_inp = F.conv1d(grad_out, weight)
                    grad_weight = F.conv1d(grad_out.transpose(1, 0), inp.transpose(1, 0))
                    return grad_inp, grad_weight.transpose(1, 0), None

        from torch.func import grad

        x = torch.randn(2, 3, 16, dtype=torch.float64)
        # conv1d weights are (C_out, C_in, K); conv_transpose1d weights are
        # (C_in, C_out, K). Both must match x's 3 input channels.
        weight = torch.randn(4, 3, 4, dtype=torch.float64)
        transposed_weight = torch.randn(3, 4, 4, dtype=torch.float64)

        def boring_conv(x, weight, transposed):
            if not transposed:
                return F.conv1d(x, weight).pow(2).sum()
            else:
                return F.conv_transpose1d(x, weight).pow(2).sum()

        def only_convs(x, weight, transposed):
            return onlyConvs.apply(x, weight, transposed).pow(2).sum()

        for w, transposed in ((weight, False), (transposed_weight, True)):
            boring_grads = grad(boring_conv, argnums=(0, 1))(x, w, transposed)
            fun_grads = grad(only_convs, argnums=(0, 1))(x, w, transposed)
            self.assertEqual(boring_grads, fun_grads)

    def test_convs(self):
        def assert_equivalence(f, expected_forward=None):
            mode = FlopCounterMode()
            with mode:
                f()
            conv_forward_flops = mode.get_flop_counts()['Global'][torch.ops.aten.convolution]
            conv_backward_flops = mode.get_flop_counts()['Global'][torch.ops.aten.convolution_backward]

            # Backward (grad wrt input + grad wrt weight) should cost twice the forward.
            self.assertEqual(conv_forward_flops * 2, conv_backward_flops)
            if expected_forward is not None:
                self.assertEqual(conv_forward_flops, expected_forward)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 2, 2, requires_grad=True)
        assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 1, 1, requires_grad=True)
        assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8)

        # NOTE(review): `groups` is never passed to the convs below, so the
        # "grouped" rows currently exercise ungrouped convolutions with those
        # channel counts. TODO confirm whether grouped convs should be covered.
        for in_channels, out_channels, groups in [
            (1, 1, 1),
            (1, 3, 1),
            (3, 1, 1),
            (3, 7, 1),
            (2, 4, 2),
            (4, 2, 2),
        ]:
            x = torch.rand(1, in_channels, 4, 4, requires_grad=True)
            weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv2d(x, weight).sum().backward())
            transposed_weight = torch.randn(in_channels, out_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv_transpose2d(x, transposed_weight).sum().backward())

    @skipIfNoTorchVision
    def test_module(self):
        resnet18 = torchvision_models.resnet18()
        mode = FlopCounterMode(resnet18)
        with mode:
            a = T(1, 3, 224, 224, requires_grad=True)
            resnet18(a).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """10884440064""")
        layer1_conv_flops = mode.flop_counts['ResNet.layer1'][torch.ops.aten.convolution]
        layer1_conv_back_flops = mode.flop_counts['ResNet.layer1'][torch.ops.aten.convolution_backward]
        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_conv_transpose_loop(self):
        x = torch.rand(1, 4, 30, 2)
        model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2)

        mode = FlopCounterMode(model)
        with mode:
            for _ in range(50):
                out = model(x)
                out.sum().backward()
        self.assertExpectedInline(str(mode.get_total_flops()), """1536000""")

    def test_custom(self):
        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """5""")

        def count(*args, out):
            return out.numel()
        # Receive the raw output tensor (`out`) rather than its shape.
        count._get_raw = True

        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """20""")

    def test_noop(self):
        # Ops without a registered FLOP formula must not break the counter.
        mode = FlopCounterMode()
        with mode:
            T(4, 5).cos()

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
                     "Does not support all SDPA backends (pre-SM80 hardware on CUDA)")
    def test_sdpa(self):
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        torch.manual_seed(0)

        def get_flops(batch_size, n_heads, seq_len_q, seq_len_k, head_dim, head_dim_v, dtype, backend, with_backward=False):
            query = torch.randn(batch_size, n_heads, seq_len_q, head_dim, device='cuda', dtype=dtype, requires_grad=True)
            key = torch.randn(batch_size, n_heads, seq_len_k, head_dim, device='cuda', dtype=dtype, requires_grad=True)
            value = torch.randn(batch_size, n_heads, seq_len_k, head_dim_v, device='cuda', dtype=dtype, requires_grad=True)

            # Restrict SDPA to exactly one kernel so each backend is measured in isolation.
            if backend == "math":
                backend_ctx = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False)
            elif backend == "flash":
                backend_ctx = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
            elif backend == "mem_efficient":
                backend_ctx = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)
            else:
                raise ValueError(f"unknown SDPA backend: {backend!r}")

            mode = FlopCounterMode()
            with backend_ctx, mode:
                out = F.scaled_dot_product_attention(query, key, value, dropout_p=0, is_causal=True)
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))

        # Sets seq_len_q == seq_len_k and dim_q == dim_v
        run_uniform_flops = functools.partial(get_flops, batch_size, n_heads, seq_len_q, seq_len_q, head_dim, head_dim, dtype)

        flops = [run_uniform_flops(backend, with_backward=False) for backend in ["math", "flash", "mem_efficient"]]
        flops_fw_math, flops_fw_flash, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [run_uniform_flops(backend, with_backward=True) for backend in ["math", "flash", "mem_efficient"]]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)

        run_nonuniform_flops = functools.partial(get_flops, batch_size, n_heads, seq_len_q, seq_len_k, head_dim, head_dim_v, dtype)
        # Flash does not support non-uniform attention, i.e. seq_len_q != seq_len_k or dim_q != dim_v
        non_uniform_backends = ["math", "mem_efficient"]
        flops = [run_nonuniform_flops(backend, with_backward=False) for backend in non_uniform_backends]
        flops_fw_math, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [run_nonuniform_flops(backend, with_backward=True) for backend in non_uniform_backends]
        flops_fw_bw_math, flops_fw_bw_efficient = flops
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")

    def test_hook_registration(self):
        model = torch.nn.Linear(100, 100)
        x = torch.randn(3, 100)

        flop_counter = FlopCounterMode(model)
        with flop_counter:
            # Module hooks are installed while the counter is active...
            self.assertEqual(len(model._forward_pre_hooks), 1)
            self.assertEqual(len(model._forward_hooks), 1)
            model(x).sum().backward()

        # ...and removed again on exit.
        self.assertEqual(len(model._forward_pre_hooks), 0)
        self.assertEqual(len(model._forward_hooks), 0)

    def test_pytrees(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x['a'].relu_()
                return {'a': torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = Foo()
                self.b = Foo()

            def forward(self, x):
                return self.b(self.a(x))

        mod = Mod()
        mode = FlopCounterMode(mod)
        with mode:
            mod({'a': torch.randn(10, 10, requires_grad=True).clone()})['a'].sum().backward()
        self.assertExpectedInline(str(mode.flop_counts['Mod'][torch.ops.aten.mm]), """12000""")

        class Mod2(torch.nn.Module):
            def forward(self, x):
                return (torch.mm(x, x),)

        mod = Mod2()
        mode = FlopCounterMode(mod)
        with mode:
            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
        self.assertExpectedInline(str(mode.flop_counts['Mod2'][torch.ops.aten.mm]), """6000""")
# Script entry point: hand control to PyTorch's shared test runner.
if '__main__' == __name__:
    run_tests()