-
Notifications
You must be signed in to change notification settings - Fork 172
/
test_integration.py
1609 lines (1383 loc) · 62.7 KB
/
test_integration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# mypy: ignore-errors
import copy
import unittest
import itertools
import torch
import torch.nn as nn
from torch._inductor.utils import run_and_get_code
from torch._dynamo import config
import torchao
from torch.ao.quantization import MinMaxObserver, QConfigMapping
from torchao.quantization.dynamic_quant import (
DynamicallyPerAxisQuantizedLinear,
)
from torchao.dtypes import TensorCoreTiledLayout
from torchao.quantization.quant_api import (
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int8_weight,
quantize_,
_replace_with_custom_fn_if_matches_filter,
)
# APIs to be deprecated (used for torch 2.2.2 and 2.3)
from torchao.quantization.quant_api import (
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
)
from torchao.quantization.quant_primitives import (
safe_int_mm,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
MappingType,
)
from torchao.quantization.utils import (
dequantize_per_channel,
dequantize_per_tensor,
dynamically_quantize_per_channel,
quant_int8_dynamic_per_token_linear,
quantize_activation_per_token_absmax,
)
from torchao.quantization.smoothquant import (
get_scale,
smooth_fq_linear_to_inference,
SmoothFakeDynamicallyQuantizedLinear,
swap_linear_with_smooth_fq_linear,
)
from torchao.quantization.subclass import (
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight
)
from torchao.quantization.utils import (
_apply_logging_hook,
compute_error,
compute_error as SQNR,
_fqn_to_op_to_shape_to_count,
LoggingTensorMode,
)
from torchao.quantization.autoquant import (
AQInt8DynamicallyQuantizedLinearWeight,
AQInt8WeightOnlyQuantizedLinearWeight,
AQInt8WeightOnlyQuantizedLinearWeight2,
AQInt8WeightOnlyQuantizedLinearWeight3,
AutoQuantizableLinearWeight,
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
from parameterized import parameterized
import itertools
import logging
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
unwrap_tensor_subclass,
is_fbcode,
benchmark_model
)
# Module-scoped logger, named after this module (not the misleading "INFO").
logger = logging.getLogger(__name__)
# Fixed RNG seed so the randomly generated test inputs are reproducible.
torch.manual_seed(0)
# Tests below compile many shape/dtype variants; raise dynamo's recompile
# cache limit so they are presumably not cut off by the default limit.
config.cache_size_limit = 100
COMMON_DEVICES = ["cpu", "cuda"]
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
# Every (device, dtype) pair; fed to @parameterized.expand on the tests below.
COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES))
# NOTE: despite the name this is true for any CUDA device with compute
# capability >= 8.9 (not only H100); it gates the float8 tests.
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
def _int8wo_api(mod):
    """Apply int8 weight-only quantization to ``mod`` in place.

    torch < 2.4 uses the legacy ``change_linear_weights_to_int8_woqtensors``.
    Otherwise ``quantize_`` is used, and on torch < 2.5 the resulting tensor
    subclasses are unwrapped afterwards.
    """
    if not TORCH_VERSION_AT_LEAST_2_4:
        # Legacy path (APIs slated for deprecation).
        change_linear_weights_to_int8_woqtensors(mod)
        return
    quantize_(mod, int8_weight_only(), set_inductor_config=False)
    if not TORCH_VERSION_AT_LEAST_2_5:
        unwrap_tensor_subclass(mod)
def _int8wo_groupwise_api(mod):
    """Apply int8 weight-only quantization with a group size of 32 to ``mod`` in place."""
    quantize_(
        mod,
        int8_weight_only(group_size=32),
        set_inductor_config=False,
    )
def _int8da_int8w_api(mod):
    """Apply int8 dynamic-activation / int8-weight quantization to ``mod`` in place.

    torch < 2.4 uses the legacy ``change_linear_weights_to_int8_dqtensors``.
    Otherwise ``quantize_`` is used, and on torch < 2.5 the resulting tensor
    subclasses are unwrapped afterwards.
    """
    if not TORCH_VERSION_AT_LEAST_2_4:
        # Legacy path (APIs slated for deprecation).
        change_linear_weights_to_int8_dqtensors(mod)
        return
    quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
    if not TORCH_VERSION_AT_LEAST_2_5:
        unwrap_tensor_subclass(mod)
def _int4wo_api(mod):
    """Apply int4 weight-only quantization to ``mod`` in place.

    torch < 2.4 uses the legacy ``change_linear_weights_to_int4_woqtensors``.
    Otherwise ``quantize_`` is used, and on torch < 2.5 the resulting tensor
    subclasses are unwrapped afterwards.
    """
    if not TORCH_VERSION_AT_LEAST_2_4:
        # Legacy path (APIs slated for deprecation).
        change_linear_weights_to_int4_woqtensors(mod)
        return
    quantize_(mod, int4_weight_only(), set_inductor_config=False)
    if not TORCH_VERSION_AT_LEAST_2_5:
        unwrap_tensor_subclass(mod)
# Quantization entry points that each mutate a module in place.
# TODO: use this to reduce the number of tests
TENSOR_SUBCLASS_APIS = [_int8wo_api, _int8da_int8w_api, _int4wo_api]
def undo_recommended_configs():
torch._inductor.config.coordinate_descent_tuning = False
torch._inductor.config.coordinate_descent_check_all_directions = False
torch._inductor.config.force_fuse_int_mm_with_mul = False
torch._inductor.config.fx_graph_cache = False
torch._inductor.config.triton.unique_kernel_names = False
torch.set_float32_matmul_precision("highest")
def combine_parameters(a, b):
    """Concatenate every (tuple from ``a``, tuple from ``b``) pair.

    Pairs are produced in ``itertools.product(a, b)`` order, and each result
    is ``left + right``.
    """
    return [left + right for left, right in itertools.product(a, b)]
def run_supported_device_dtype(test_method):
    """Decorator that skips a test helper on unsupported device/dtype combos.

    The device is read from the 3rd positional argument (``args[2]``, counting
    ``self``). The dtype is resolved in this order:

    1. the ``test_dtype`` keyword;
    2. the method's own ``test_dtype`` parameter (positional value or default);
    3. for methods that do not declare ``test_dtype``, the 4th positional
       argument (``args[3]``) if present.

    Raises:
        unittest.SkipTest: if fewer than 3 positional args are given, if the
            device is ``"cuda"`` but CUDA is unavailable, or if bfloat16 is
            requested on a CUDA device below compute capability 8.0.
    """
    import functools
    import inspect

    try:
        signature = inspect.signature(test_method)
    except (TypeError, ValueError):
        signature = None
    declares_dtype = signature is not None and "test_dtype" in signature.parameters

    def _resolve_dtype(args, kwargs):
        # Find the dtype the decorated method will actually run with.
        if "test_dtype" in kwargs:
            return kwargs["test_dtype"]
        if declares_dtype:
            try:
                bound = signature.bind_partial(*args, **kwargs)
            except TypeError:
                # Malformed call; let test_method raise the real error.
                return None
            bound.apply_defaults()
            return bound.arguments.get("test_dtype")
        # Legacy contract: dtype is the 4th positional arg, when given.
        return args[3] if len(args) > 3 else None

    @functools.wraps(test_method)
    def wrapper(*args, **kwargs):
        if len(args) < 3:
            raise unittest.SkipTest(f"Not enough args. Expected more than or equal to 3, but got {len(args)}")
        device = args[2]
        dtype = _resolve_dtype(args, kwargs)
        if device == "cuda" and not torch.cuda.is_available():
            raise unittest.SkipTest("Need CUDA available.")
        if device == "cuda" and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
            raise unittest.SkipTest("Need CUDA and SM80+ available.")
        return test_method(*args, **kwargs)

    return wrapper
class SmoothquantUnitTest(unittest.TestCase):
    """Tests for SmoothQuant scale computation and the fake-quant linear module."""

    # first, let's reproduce the graphic from the paper, Figure 4, to ensure
    # we are calculating the scales correctly
    def test_figure_4(self):
        X = torch.FloatTensor([1, -16, 2, 6, -2, 8, -1, -9]).reshape(1, 2, 4)
        W = torch.FloatTensor([2, 1, -2, 1, -1, -1, 2, -1, -2, -1, -1, 1]).reshape(4, 3)
        X_mul_W = torch.matmul(X, W)

        smoothquant_scale = get_scale(
            torch.amax(torch.abs(X), dim=(0, 1)),
            torch.amax(torch.abs(W), dim=1),
            alpha=0.5,
        )

        # reproduce scaled calculation: (X / s) @ (diag(s) @ W) must equal X @ W
        X_scaled = X / smoothquant_scale.reshape(1, 1, -1)
        W_scaled = torch.matmul(torch.diag(smoothquant_scale), W)
        X_scaled_mul_scaled_W = torch.matmul(X_scaled, W_scaled)
        assert torch.allclose(X_mul_W, X_scaled_mul_scaled_W), "not close!"
        assert X_mul_W.shape == X_scaled_mul_scaled_W.shape

    # next, run the above test on a sample of representative inputs
    def test_tensors(self):
        x_shape = (1, 5, 7)
        w_shape = (7, 9)
        for _ in range(3):
            X = torch.randn(x_shape) * 10
            W = torch.randn(w_shape)
            s = get_scale(
                torch.amax(torch.abs(X), dim=(0, 1)),
                torch.amax(torch.abs(W), dim=1),
                alpha=0.5,
            )

            Y = torch.matmul(X, W)
            Y_ref = torch.matmul(
                X / s.reshape(1, 1, -1),
                torch.matmul(torch.diag(s), W),
            )
            assert torch.allclose(Y, Y_ref, atol=1e-3, rtol=1e-3), "not close!"

    def _test_smooth_linear_impl(self, x_shape, lin_shape, device):
        """Compare SmoothQuant fake-quant linear against fp32 and AO dynamic quant.

        Args:
            x_shape: shape of the random input passed to ``torch.randn``.
            lin_shape: ``(in_features, out_features)`` for ``nn.Linear``.
            device: device the fp32 / smoothquant modules run on; the AO
                dynamic-quant reference is built and run on cpu.

        The global quantized engine is switched to qnnpack for the test and is
        always restored afterwards.
        """
        orig_backend = torch.backends.quantized.engine
        # so we can use the full range
        torch.backends.quantized.engine = "qnnpack"
        try:
            x = torch.randn(*x_shape, device=device) * 9 + 10

            lin_fp32 = nn.Linear(*lin_shape, device=device)  # misc: ignore
            lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float(
                copy.deepcopy(lin_fp32), alpha=0.25
            )
            lin_smooth_skip_scaling = SmoothFakeDynamicallyQuantizedLinear.from_float(
                copy.deepcopy(lin_fp32), alpha=0.25
            )

            lin_fp32_copy = copy.deepcopy(lin_fp32)  # assignment: ignore
            lin_fp32_copy.qconfig = torch.ao.quantization.QConfig(  # assignment: ignore
                activation=None,
                weight=torch.ao.quantization.default_per_channel_weight_observer,
            )
            # reference implementation, built from a cpu copy of the weights
            lin_dynamic_q = torch.ao.nn.quantized.dynamic.Linear.from_float(
                lin_fp32_copy.cpu()
            )

            y_ref = lin_fp32(x)

            # calibrate the smoothquant versions
            y_smooth_nocalib = lin_smooth(x)
            _ = lin_smooth_skip_scaling(x)
            lin_smooth.to_inference()
            lin_smooth_skip_scaling.debug_skip_scaling = True
            lin_smooth_skip_scaling.to_inference()

            # verify that with scaling turned off, numerics match quantized version
            y_smooth_fq_only = lin_smooth_skip_scaling(x)
            y_smooth_fq = lin_smooth(x)
            y_dynamic_q = lin_dynamic_q(x.cpu()).to(device)

            sqnr_smooth_fq = compute_error(y_ref, y_smooth_fq)
            sqnr_dynamic_q = compute_error(y_ref, y_dynamic_q)
            sqnr_fq = compute_error(y_smooth_fq_only, y_dynamic_q)

            assert torch.allclose(
                y_ref, y_smooth_nocalib
            ), "y_ref not close to y_smooth_nocalib"
            # after https://github.com/pytorch-labs/ao_benchmarks/pull/32,
            # numerics do not match exactly between production c++ code
            # and this Python code, so y_smooth_fq_only vs y_dynamic_q is
            # checked via SQNR below rather than with allclose
            self.assertTrue(sqnr_smooth_fq.item() >= 40.0, f"got: {sqnr_smooth_fq.item()}")
            self.assertTrue(sqnr_dynamic_q.item() >= 40.0, f"got: {sqnr_dynamic_q.item()}")
            self.assertTrue(sqnr_fq.item() >= 40.0, f"got: {sqnr_fq.item()}")
        finally:
            # Restore backend even when an assertion fails, so qnnpack does
            # not leak into later tests.
            torch.backends.quantized.engine = orig_backend

    def test_smooth_linear_cpu(self):
        self._test_smooth_linear_impl((1, 5, 3), (3, 4), "cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_smooth_linear_cuda(self):
        self._test_smooth_linear_impl((1, 32, 32), (32, 16), "cuda")

    def test_smooth_linear_edge_cases(self):
        orig_backend = torch.backends.quantized.engine
        # so we can use the full range
        torch.backends.quantized.engine = "qnnpack"
        try:
            lin_fp32 = nn.Linear(3, 4)
            lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float(
                lin_fp32, alpha=0.25
            )

            # test different ranks
            x0 = torch.randn(4, 5, 3)
            x1 = torch.randn(1, 8, 5, 3)
            x2 = torch.randn(2, 3, 7, 5, 3)

            # calibrate
            _ = lin_smooth(x0)
            _ = lin_smooth(x1)
            _ = lin_smooth(x2)

            # inference
            lin_smooth.to_inference()
            _ = lin_smooth(x0)
            _ = lin_smooth(x1)
            _ = lin_smooth(x2)
        finally:
            # Restore backend even on failure
            torch.backends.quantized.engine = orig_backend

    def test_swap(self):
        m = nn.Sequential(
            nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4)),
            nn.Linear(4, 4),
        )
        m_copy = copy.deepcopy(m)
        swap_linear_with_smooth_fq_linear(m_copy, skip_fqn_list=["0.2"])

        # verify all linears are swapped
        assert isinstance(m_copy[0][0], SmoothFakeDynamicallyQuantizedLinear)
        assert isinstance(m_copy[0][1], nn.ReLU)
        # this one was skipped
        assert isinstance(m_copy[0][2], nn.Linear)
        assert isinstance(m_copy[1], SmoothFakeDynamicallyQuantizedLinear)

        # verify results do not change without smoothing
        x = torch.randn(4, 4)
        y_ref = m(x)
        y = m_copy(x)
        assert torch.allclose(y_ref, y)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
    def test_weight_t_and_non_t_numerics_match(self):
        # verify that numerics match whether weight is stored
        # in transposed format (for cuBLAS) vs non-transposed format
        # (for torch.compile)
        dtype = torch.half
        device = "cuda"
        lin_ref = nn.Linear(32, 16, dtype=dtype, device=device)
        lin_eager_t = copy.deepcopy(lin_ref)
        lin_opt_t = copy.deepcopy(lin_eager_t)
        lin_opt = copy.deepcopy(lin_eager_t)
        lin_eager_t = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_eager_t)
        lin_opt_t = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_opt_t)
        lin_opt = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_opt)
        lin_opt.store_w_int_repr_t = False

        x = torch.randn(32, 32, dtype=dtype, device=device)

        y_calib_eager_t = lin_eager_t(x)
        y_calib_opt_t = lin_opt_t(x)
        y_calib_opt = lin_opt(x)
        torch.testing.assert_close(y_calib_eager_t, y_calib_opt_t)
        torch.testing.assert_close(y_calib_eager_t, y_calib_opt)

        lin_eager_t.to_inference()
        lin_opt_t.to_inference()
        lin_opt.to_inference()

        torch.testing.assert_close(lin_eager_t.W_int_repr, lin_opt_t.W_int_repr)
        torch.testing.assert_close(lin_eager_t.W_int_repr, lin_opt.W_int_repr)

        lin_opt_t = torch.compile(lin_opt_t, mode="max-autotune")
        lin_opt = torch.compile(lin_opt, mode="max-autotune")

        y_ref = lin_ref(x)
        y_eager = lin_eager_t(x)
        y_opt_t = lin_opt_t(x)
        y_opt = lin_opt(x)

        if not torch.any(torch.isinf(y_ref)) and torch.any(torch.isinf(y_eager)):
            # eager mode torch._int_mm is sometimes buggy, when this happens
            # we can't really compare the compiled version against it properly;
            # report a skip rather than a silent pass
            self.skipTest("eager mode torch._int_mm known bad, test is inconclusive")

        sqnr_eager_opt_t = compute_error(y_eager, y_opt_t)
        sqnr_eager_opt = compute_error(y_eager, y_opt)
        # since torch.compile for a torch.half model can
        # change numerics significantly, we can only test for a high SQNR here
        # and not for closeness
        self.assertTrue(sqnr_eager_opt_t >= 45.0)
        self.assertTrue(sqnr_eager_opt >= 45.0)
        # y_opt_t and y_opt should be equivalent
        torch.testing.assert_close(y_opt_t, y_opt)

    def test_selective_torch_compile(self):
        m = nn.Sequential(
            nn.Linear(4, 4),
            nn.Sequential(
                nn.Linear(4, 4),
                nn.Linear(4, 4),
            ),
            nn.Linear(4, 4),
        )
        x = torch.randn(4, 4)
        y_ref = m(x)

        # compile every Linear except the one at fqn "1.0"
        _replace_with_custom_fn_if_matches_filter(
            m,
            lambda mod: torch.compile(mod),
            lambda mod, fqn: isinstance(mod, nn.Linear) and fqn != "1.0",
        )

        self.assertTrue(isinstance(m[0], torch._dynamo.eval_frame.OptimizedModule))
        self.assertTrue(isinstance(m[1][0], nn.Linear))
        self.assertTrue(isinstance(m[1][1], torch._dynamo.eval_frame.OptimizedModule))
        self.assertTrue(isinstance(m[2], torch._dynamo.eval_frame.OptimizedModule))

        y = m(x)
        torch.testing.assert_close(y, y_ref)

    def test_debug_x_absmax(self):
        m = nn.Sequential(nn.Linear(3, 4))
        x0 = torch.randn(4, 5, 3)
        y0 = m(x0)
        swap_linear_with_smooth_fq_linear(m)
        # no calibration, straight to inference, should not crash
        smooth_fq_linear_to_inference(m, debug_skip_calibration=True)
        y1 = m(x0)
class PythonQuantUtilOpUnitTest(unittest.TestCase):
    """Checks torchao's pure-PyTorch quantization helpers against reference ops."""

    def _test_dynamic_quant_per_channel_numerics_impl(
        self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device
    ):
        """Compare ``dynamically_quantize_per_channel`` with AO's per-channel observer.

        Checks that the scales and zero points match, that the int values are
        within 1 of the reference, and that dequantized values are within
        about one scale step of the reference.
        """
        # verifies that dynamic quant per channel in plain pytorch matches
        # numerics of production AO code
        # TODO(future): test this on cpu-half, need to first make
        # torch.aminmax support half on cpu
        x = torch.randn(16, 32, device=device, dtype=float_dtype)
        y_vals, y_scale, y_zero_point = dynamically_quantize_per_channel(
            x, qmin, qmax, int_dtype
        )

        # reference
        weight_obs = torch.ao.quantization.MovingAveragePerChannelMinMaxObserver(
            dtype=qint_dtype,
            quant_min=qmin,
            quant_max=qmax,
            qscheme=torch.per_channel_symmetric,
            averaging_constant=1.0,  # make it ignore previous iterations
        )
        weight_obs(x)
        y_ref_scale, y_ref_zp = weight_obs.calculate_qparams()
        y_ref_scale = y_ref_scale.to(device)
        y_ref_zp = y_ref_zp.to(device)
        # quantize_per_channel doesn't work for half, so we cast there and back
        x_for_ref = x.half().float() if float_dtype == torch.float16 else x
        y_ref = torch.quantize_per_channel(
            x_for_ref, y_ref_scale, y_ref_zp, 0, qint_dtype
        )

        torch.testing.assert_close(
            y_scale, y_ref.q_per_channel_scales().to(float_dtype)
        )
        assert torch.equal(y_zero_point, y_ref.q_per_channel_zero_points())
        # this test case has one element where the rounding is off by one
        # from Python-only code vs the c++ code, it's easy to repro with
        # various shapes.
        # Discussion here is relevant: https://github.com/pytorch/pytorch/issues/16498
        # TODO(future): figure out what to do about this
        # assert torch.equal(int_vals, q_reference.int_repr())
        assert torch.max(torch.abs(y_vals - y_ref.int_repr())) <= 1

        # dequantize
        x_dq = dequantize_per_channel(y_vals, y_scale, y_zero_point, out_dtype=float_dtype)
        x_ref_dq = y_ref.dequantize().to(float_dtype)
        # off-by-one for scale is okay
        torch.testing.assert_close(
            x_dq, x_ref_dq, atol=torch.max(y_scale).item() * 1.01, rtol=0.0001
        )

    def test_dynamic_quant_per_channel_numerics_cpu(self):
        test_cases = ((-128, 127, torch.int8, torch.qint8, torch.float32, "cpu"),)
        for row in test_cases:
            self._test_dynamic_quant_per_channel_numerics_impl(*row)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skip("AssertionError: Tensor-likes are not close!")
    def test_dynamic_quant_per_channel_numerics_cuda(self):
        test_cases = (
            (-128, 127, torch.int8, torch.qint8, torch.float32, "cuda"),
            (-128, 127, torch.int8, torch.qint8, torch.float16, "cuda"),
        )
        for row in test_cases:
            self._test_dynamic_quant_per_channel_numerics_impl(*row)

    def _test_quantize_per_token_impl(self, device, dtype):
        """Round-trip per-token absmax quantization and require SQNR >= 45."""
        x = torch.randn(3, 3, 3, device=device, dtype=dtype)
        xq, scales = quantize_activation_per_token_absmax(x)
        # one scale per token: block spans the whole last dim (size 3)
        block_size = (1, 1, 3)
        x_dq = dequantize_affine(xq, block_size, scales, None, torch.int8, output_dtype=x.dtype)
        sqnr = compute_error(x, x_dq)
        self.assertTrue(sqnr >= 45.0)

    def test_quantize_per_token_cpu(self):
        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            self._test_quantize_per_token_impl("cpu", dtype)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quantize_per_token_cuda(self):
        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            self._test_quantize_per_token_impl("cuda", dtype)

    def _test_per_token_linear_impl(self, device, dtype):
        """Per-token int8 dynamic linear must match ``x @ w.T`` with SQNR >= 42."""
        x = torch.randn(2, 16, 8, device=device, dtype=dtype)
        w = torch.randn(16, 8, device=device, dtype=dtype)
        wq, w_scales, _w_zp = dynamically_quantize_per_channel(w, -127, 127, torch.int8)
        # Note: need to make the weight contiguous because we are
        # testing in eager mode and cuBlas will not give correct results
        # for a transposed weight
        y = quant_int8_dynamic_per_token_linear(
            x, wq.t().contiguous(), w_scales, None, dtype
        )
        y_ref = torch.matmul(x, w.t())
        sqnr = compute_error(y_ref, y)
        self.assertTrue(sqnr >= 42.0)

    def test_per_token_linear_cpu(self):
        for dtype in (torch.float32,):
            self._test_per_token_linear_impl("cpu", dtype)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_per_token_linear_cuda(self):
        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            self._test_per_token_linear_impl("cuda", dtype)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test__int_mm(self):
        # TODO(future): figure out what here needs to move to PT core,
        # if it's not already tested there

        m, k, n = 32, 32, 16
        # randint's upper bound is exclusive; 128 covers the full int8 range
        x = torch.randint(-128, 128, (m, k), dtype=torch.int8, device="cuda")
        w = torch.randint(-128, 128, (k, n), dtype=torch.int8, device="cuda")

        y_ref = torch.matmul(x.float(), w.float()).to(torch.int32)
        y_raw = safe_int_mm(x, w)

        wrap_in_mm_opt = torch.compile(safe_int_mm, mode="max-autotune")
        # note: triton chokes on the line below on k == 8 and n == 8 with
        # https://www.internalfb.com/phabricator/paste/view/P683467944
        # TODO(future): file an issue
        y_opt = wrap_in_mm_opt(x, w)

        torch.testing.assert_close(y_ref, y_raw, atol=0, rtol=0)
        torch.testing.assert_close(y_ref, y_opt, atol=0, rtol=0)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test__int_mm_eager_and_torch_compile_numerics(self):
        # Single leading underscore: a double-underscore local would be
        # name-mangled inside the class body.
        def _int_mm_ref(x, w):
            # exact int32 reference computed on cpu
            x = x.cpu().to(torch.int32)
            w = w.cpu().to(torch.int32)
            y = torch.matmul(x, w)
            return y.cuda()

        shapes = (
            # minimal test shape
            ((1, 32, 32), (32, 16)),
            # paste of real linear shapes from LLaMa 1.5b
            ((17, 1, 1536), (1536, 1536)),
            ((17, 8, 4096), (4096, 1536)),
            ((17, 1, 1536), (1536, 4096)),
            ((17, 8, 1536), (1536, 1536)),
            ((17, 1, 4096), (4096, 1536)),
            ((17, 8, 1536), (1536, 4096)),
        )

        for x_shape, w_shape in shapes:

            def wrap_torch_int_mm(x, w):
                # flatten (b, n, k) @ (k, m) to 2D for safe_int_mm, then restore
                b, n, k = x.shape
                k, m = w.shape
                x = x.reshape(b * n, k)
                res = safe_int_mm(x, w)
                res = res.reshape(b, n, m)
                return res

            wrap_torch_int_mm_opt = torch.compile(
                wrap_torch_int_mm, mode="max-autotune"
            )

            # randint's upper bound is exclusive; 128 covers the full int8 range
            x = torch.randint(-128, 128, x_shape, dtype=torch.int8, device="cuda")
            w = torch.randint(-128, 128, w_shape, dtype=torch.int8, device="cuda")

            z_ref = _int_mm_ref(x, w)
            z_eager = wrap_torch_int_mm(x, w)
            z_torch_compile = wrap_torch_int_mm_opt(x, w)
            torch.testing.assert_close(z_ref, z_eager, atol=0, rtol=0)
            torch.testing.assert_close(z_ref, z_torch_compile, atol=0, rtol=0)
class TestSubclass(unittest.TestCase):
@run_supported_device_dtype
def _test_dequantize_impl(
    self,
    test_subclass_from_float,
    test_device,
    min_sqnr=35,
    test_dtype=torch.bfloat16,
    test_shape=(32, 64, 64),
):
    """Check that a quantized weight subclass dequantizes close to the float weight.

    Builds ``nn.Linear(k, n)``, where only ``k`` and ``n`` of ``test_shape``
    are used. Its weight is replaced with ``test_subclass_from_float(weight)``.
    The SQNR of ``dequantize()`` against the original weight must exceed
    ``min_sqnr``, both as stored and after ``.t()``.
    """
    _, in_features, out_features = test_shape
    lin = torch.nn.Linear(in_features, out_features, device=test_device).to(test_dtype)
    w_float = lin.weight.detach()
    lin.weight = torch.nn.Parameter(
        test_subclass_from_float(lin.weight), requires_grad=False
    )
    qweight = lin.weight
    subclass_name = qweight.__class__.__name__
    self.assertGreater(
        SQNR(w_float, qweight.dequantize()),
        min_sqnr,
        f"{subclass_name} failed dtype={test_dtype}",
    )
    self.assertGreater(
        SQNR(w_float.t(), qweight.t().dequantize()),
        min_sqnr,
        f"{subclass_name} failed transpose on dtype={test_dtype}",
    )
@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_dequantize_int8_dynamic_quant_subclass(self, device, dtype):
    # Int8 dynamically-quantized weights must dequantize with SQNR above 35.
    from_float = Int8DynamicallyQuantizedLinearWeight.from_float
    self._test_dequantize_impl(from_float, device, 35, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
    # Int8 weight-only weights must dequantize with SQNR above 35.
    from_float = Int8WeightOnlyQuantizedLinearWeight.from_float
    self._test_dequantize_impl(from_float, device, 35, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
    # Only bfloat16 is exercised; the extra (1, 1024, 8) shape is cuda-only.
    if dtype != torch.bfloat16:
        self.skipTest("Currently only supports bfloat16.")
    shapes = [(16, 1024, 16)]
    if device == 'cuda':
        shapes.append((1, 1024, 8))
    for shape in shapes:
        self._test_dequantize_impl(
            Int4WeightOnlyQuantizedLinearWeight.from_float,
            device,
            15,
            test_shape=shape,
            test_dtype=dtype,
        )
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
    # Sweep group size x inner k-tiles x (m, n); extra m/n shapes are cuda-only.
    if dtype != torch.bfloat16:
        self.skipTest("Currently only supports bfloat16.")
    on_cuda = device == "cuda"
    m_shapes = [16, 256] + ([1] if on_cuda else [])
    n_shapes = [16] + ([8, 13] if on_cuda else [])
    for groupsize, inner_k_tiles, m, n in itertools.product(
        [256, 128], [8, 4, 2], m_shapes, n_shapes
    ):
        # the lambda is consumed within this iteration, so late binding is safe
        self._test_dequantize_impl(
            lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles),
            device,
            15,
            test_shape=[m, 256, n],
            test_dtype=dtype,
        )
@run_supported_device_dtype
def _test_lin_weight_subclass_impl(
    self,
    test_subclass_from_float,
    test_device,
    min_sqnr=35,
    test_dtype=torch.bfloat16,
    test_shape=(32, 64, 32),
):
    """Check that a Linear with a quantized weight subclass tracks its float output.

    Builds ``nn.Linear(k, n)`` and an ``(m, k)`` input from ``test_shape``.
    The weight is swapped for ``test_subclass_from_float(weight)``. The SQNR
    against the float output must exceed ``min_sqnr`` in eager mode and
    under ``torch.compile(mode='max-autotune')``. Skips on non-cuda devices.
    """
    if "cuda" not in test_device:
        self.skipTest("test requires cuda")
    with torch.no_grad():
        m, k, n = test_shape
        x = torch.randn(m, k, device=test_device, dtype=test_dtype)
        lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype)
        ref_f = lin(x)

        lin.weight = torch.nn.Parameter(
            test_subclass_from_float(lin.weight), requires_grad=False
        )
        test = lin(x)
        self.assertGreater(
            SQNR(ref_f, test),
            min_sqnr,
            f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}"
        )

        lin_comp = torch.compile(lin, mode='max-autotune')
        test_comp = lin_comp(x)
        self.assertGreater(
            SQNR(ref_f, test_comp),
            min_sqnr,
            f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}"
        )
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen")
def test_int8_dynamic_quant_subclass(self, device, dtype):
    # Int8 dynamic-quant linear must keep SQNR above 35 (eager and compiled).
    from_float = Int8DynamicallyQuantizedLinearWeight.from_float
    self._test_lin_weight_subclass_impl(from_float, device, 35, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_int8_weight_only_quant_subclass(self, device, dtype):
    # Reset inductor flags first; this test expects SQNR above 40.
    undo_recommended_configs()
    from_float = Int8WeightOnlyQuantizedLinearWeight.from_float
    self._test_lin_weight_subclass_impl(from_float, device, 40, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
    # Autoquant int8 dynamic-quant weight subclass, SQNR threshold 35.
    from_float = AQInt8DynamicallyQuantizedLinearWeight.from_float
    self._test_lin_weight_subclass_impl(from_float, device, 35, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skip(
    "This segfaults in CI cuda only, disable to unblock PR, we can investigate "
    "later if needed"
)
def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
    # Autoquant int8 weight-only subclass (currently unconditionally skipped).
    from_float = AQInt8WeightOnlyQuantizedLinearWeight.from_float
    self._test_lin_weight_subclass_impl(from_float, device, 35, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
    # Second autoquant int8 weight-only variant, SQNR threshold 35.
    from_float = AQInt8WeightOnlyQuantizedLinearWeight2.from_float
    self._test_lin_weight_subclass_impl(from_float, device, 35, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
    # Third autoquant int8 weight-only variant, SQNR threshold 35.
    from_float = AQInt8WeightOnlyQuantizedLinearWeight3.from_float
    self._test_lin_weight_subclass_impl(from_float, device, 35, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
    # Autoquant float8 weight-only subclass, looser SQNR threshold of 30.
    from_float = AQFloat8WeightOnlyQuantizedLinearWeight.from_float
    self._test_lin_weight_subclass_impl(from_float, device, 30, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
    # Per-row-scaled float8 dynamic quant; only bfloat16 is exercised here.
    if dtype != torch.bfloat16:
        # f-string so the skip reason names the actual dtype
        self.skipTest(f"Fails for {dtype}")
    self._test_lin_weight_subclass_impl(
        AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
    )
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
def test_int4_weight_only_quant_subclass(self, device, dtype):
    # Only bfloat16 is exercised; the extra (1, 1024, 8) shape is cuda-only.
    if dtype != torch.bfloat16:
        self.skipTest(f"Fails for {dtype}")
    shapes = [(16, 1024, 16)]
    if device == 'cuda':
        shapes.append((1, 1024, 8))
    for shape in shapes:
        self._test_lin_weight_subclass_impl(
            Int4WeightOnlyQuantizedLinearWeight.from_float,
            device,
            10,
            test_shape=shape,
            test_dtype=dtype,
        )
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
    # Sweep group size x inner k-tiles x (m, n); extra m/n shapes are cuda-only.
    if dtype != torch.bfloat16:
        self.skipTest(f"Fails for {dtype}")
    on_cuda = device == "cuda"
    m_shapes = [16, 256] + ([1] if on_cuda else [])
    n_shapes = [16] + ([8, 13] if on_cuda else [])
    for groupsize, inner_k_tiles, m, n in itertools.product(
        [128, 64], [8, 4, 2], m_shapes, n_shapes
    ):
        # the lambda is consumed within this iteration, so late binding is safe
        self._test_lin_weight_subclass_impl(
            lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles),
            device,
            10,
            test_shape=[m, 256, n],
            test_dtype=dtype,
        )
@torch.no_grad()
@run_supported_device_dtype
def _test_lin_weight_subclass_api_impl(
    self,
    api,
    test_device,
    min_sqnr=35,
    test_dtype=torch.bfloat16,
    test_shape=(32, 64, 32)
):
    """Check that a module-level quantization ``api`` preserves a small MLP's output.

    Builds ``Linear(k, n) -> ReLU -> Linear(n, n)`` and an ``(m, k)`` input
    from ``test_shape``, then calls ``api(mod)``. Its return value is ignored,
    so ``api`` is expected to modify ``mod`` in place. The SQNR against the
    float output must exceed ``min_sqnr`` eagerly and under
    ``torch.compile(mode="max-autotune")``.
    """
    rows, in_features, out_features = test_shape
    x = torch.randn(rows, in_features, device=test_device, dtype=test_dtype)
    mod = nn.Sequential(
        nn.Linear(in_features, out_features, device=test_device),
        nn.ReLU(),
        nn.Linear(out_features, out_features, device=test_device),
    ).to(test_dtype)
    reference_out = mod(x)

    api(mod)
    eager_out = mod(x)
    self.assertGreater(
        SQNR(reference_out, eager_out),
        min_sqnr,
        f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
    )

    compiled_out = torch.compile(mod, mode="max-autotune")(x)
    self.assertGreater(
        SQNR(reference_out, compiled_out),
        min_sqnr,
        f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
    )
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen")
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
    # Module-level int8 dynamic-activation/int8-weight API, SQNR threshold 35.
    self._test_lin_weight_subclass_api_impl(_int8da_int8w_api, device, 35, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_int8_weight_only_quant_subclass_api(self, device, dtype):
    """Run the API SQNR check with ``_int8wo_api`` at a threshold of 40."""
    # NOTE(review): presumably resets inductor configs changed elsewhere so
    # this test compiles with defaults; confirm against undo_recommended_configs.
    undo_recommended_configs()
    min_sqnr = 40
    self._test_lin_weight_subclass_api_impl(_int8wo_api, device, min_sqnr, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch._inductor.config.patch({"freezing": True})
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "freeze requires torch 2.4 and after.")
def test_int8_weight_only_quant_with_freeze(self, device, dtype):
    """Run the ``_int8wo_api`` SQNR check (threshold 40) with inductor freezing on."""
    # Drop any cached dynamo state first; presumably so the compile inside the
    # helper picks up the patched ``freezing`` setting. TODO confirm.
    torch._dynamo.reset()
    min_sqnr = 40
    self._test_lin_weight_subclass_api_impl(_int8wo_api, device, min_sqnr, test_dtype=dtype)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
def test_int4_weight_only_quant_subclass_api(self, device, dtype):
    """Run the ``_int4wo_api`` SQNR check (threshold 15); bfloat16 only.

    CUDA additionally exercises a single-row input with a 256-wide output.
    """
    if dtype != torch.bfloat16:
        self.skipTest(f"Fails for {dtype}")
    shapes = [(16, 1024, 16)]
    if device == 'cuda':
        shapes.append((1, 1024, 256))
    for shape in shapes:
        self._test_lin_weight_subclass_api_impl(
            _int4wo_api, device, 15, test_shape=shape, test_dtype=dtype,
        )
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
    """Grouped int4 weight-only quantization via the version-appropriate API.

    On torch >= 2.4 the module goes through
    ``quantize_(mod, int4_weight_only(group_size=..., layout=TensorCoreTiledLayout(...)))``,
    followed by ``unwrap_tensor_subclass`` before 2.5. Older torch uses
    ``change_linear_weights_to_int4_woqtensors(mod, groupsize=..., inner_k_tiles=...)``.
    Each combination of shape, group size (64, 32) and inner k-tiles (4, 2) must
    clear the SQNR threshold of 15. bfloat16 only.
    """
    if dtype != torch.bfloat16:
        self.skipTest(f"Fails for {dtype}")
    shapes = [(256, 256, 16)]
    if device == 'cuda':
        shapes.append((256, 256, 8))
    for shape in shapes:
        for groupsize in [64, 32]:
            for inner_k_tiles in [4, 2]:
                layout = TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)

                def api(mod):
                    if not TORCH_VERSION_AT_LEAST_2_4:
                        change_linear_weights_to_int4_woqtensors(
                            mod, groupsize=groupsize, inner_k_tiles=inner_k_tiles
                        )
                        return
                    quantize_(mod, int4_weight_only(group_size=groupsize, layout=layout))
                    if not TORCH_VERSION_AT_LEAST_2_5:
                        unwrap_tensor_subclass(mod)

                self._test_lin_weight_subclass_api_impl(
                    api, device, 15, test_shape=shape, test_dtype=dtype,
                )
class TestDynamicQuant(unittest.TestCase):
    """Eager-mode check of ``int8_dynamic_activation_int8_weight`` on one Linear."""

    def test_dynamic_quant(self):
        """Quantized output must score above 40 on ``compute_error`` vs. float."""
        rows, in_features, out_features = 8, 16, 8
        inp = torch.randn(rows, in_features)
        model = nn.Sequential(nn.Linear(in_features, out_features))
        expected = model(inp)
        quantize_(model, int8_dynamic_activation_int8_weight())
        actual = model(inp)
        self.assertGreater(compute_error(expected, actual), 40.0)
class TestWeightOnlyInt8Quant(unittest.TestCase):
def test_weight_only_quant(self):
    """``_int8wo_api`` on Linear(4, 5) must score above 43 on ``compute_error``.

    Covers 2-D, 4-D and 3-D inputs, each with a freshly built module.
    """
    for shape in ([2, 4], [5, 5, 5, 4], [1, 4, 4]):
        inp = torch.randn(*shape)
        model = nn.Sequential(nn.Linear(4, 5))
        reference = model(inp)
        _int8wo_api(model)
        quantized = model(inp)
        self.assertGreater(compute_error(reference, quantized), 43.0)
def test_weight_only_groupwise_quant(self):
    """``_int8wo_groupwise_api`` on Linear(512, 32) must score above 45 on ``compute_error``."""
    inp = torch.randn(128, 512)
    model = nn.Sequential(nn.Linear(512, 32))
    reference = model(inp)
    _int8wo_groupwise_api(model)
    quantized = model(inp)
    self.assertGreater(compute_error(reference, quantized), 45.0)
@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch.no_grad()
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weight_only_quant_force_mixed_mm(self, device, dtype):
    """Compiled int8 weight-only Linear must be accurate and emit ``mixed_mm``.

    Patches inductor with ``epilogue_fusion=True`` and forces the mixed-mm path:
    ``mixed_mm_choice="triton"`` on torch >= 2.5, ``force_mixed_mm=True`` before.
    For each input shape it compiles the quantized module with
    ``mode="max-autotune"``. It then asserts SQNR >= 38 against the float
    reference and checks that the generated code contains ``"mixed_mm"``.
    CUDA only; bfloat16 additionally needs SM >= 8.0.
    """
    undo_recommended_configs()
    if device != "cuda":
        self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}")
    if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
        self.skipTest("test requires SM capability of at least (8, 0).")
    from torch._inductor import config
    # The inductor option selecting the mixed-mm implementation was renamed in 2.5.
    mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm", True)
    with config.patch({
        "epilogue_fusion": True,
        mixed_mm_key: mixed_mm_val,
    }):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            torch._dynamo.reset()
            x = torch.randn(*x_shape).to(device).to(dtype)
            m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
            y_ref = m(x)
            _int8wo_api(m)
            # NOTE(review): eager forward on the quantized module whose result is
            # unused; presumably a warm-up before compiling. Confirm it is needed.
            m(x)
            m_c = torch.compile(m, mode="max-autotune")
            y_wo, (code,) = run_and_get_code(m_c, x)
            sqnr = compute_error(y_ref, y_wo)
            self.assertGreaterEqual(sqnr, 38)
            # Non-CUDA devices were skipped above, so this check always applies.
            self.assertTrue("mixed_mm" in code, f"got code: {code}")
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weight_only_quant_use_mixed_mm(self, device, dtype):
    """Compiled int8 weight-only Linear stays accurate with the mixed-mm path selected.

    Uses ``torch.manual_seed(0)`` and patches inductor with
    ``epilogue_fusion=False`` plus ``mixed_mm_choice="triton"`` (torch >= 2.5) or
    ``force_mixed_mm=True`` (older). Each input shape is compiled with
    ``mode="max-autotune"`` and must reach SQNR > 42.75 against the float
    reference. CUDA only; bfloat16 additionally needs SM >= 8.0.
    """
    undo_recommended_configs()
    if device != "cuda":
        # Message previously named the force_mixed_mm test (copy-paste error).
        self.skipTest(f"weight_only_quant_use_mixed_mm can't be constructed on {device}")
    if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
        self.skipTest("test requires SM capability of at least (8, 0).")
    torch.manual_seed(0)
    from torch._inductor import config
    # The inductor option selecting the mixed-mm implementation was renamed in 2.5.
    mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm", True)
    with config.patch({
        "epilogue_fusion": False,
        mixed_mm_key: mixed_mm_val,
    }):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            torch._dynamo.reset()
            x = torch.randn(*x_shape).to(device).to(dtype)
            m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
            y_ref = m(x)
            _int8wo_api(m)
            m_c = torch.compile(m, mode="max-autotune")
            # NOTE(review): unlike the force_mixed_mm test, the generated code is
            # not inspected for "mixed_mm" here. Confirm that is intentional.
            y_wo, (_code,) = run_and_get_code(m_c, x)
            sqnr = compute_error(y_ref, y_wo)
            self.assertGreater(sqnr, 42.75)