@@ -2,63 +2,45 @@
 import torch

 from torchao.float8.float8_utils import compute_error
-from torchao.ops import mx_fp8_bf16
-from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    is_sm_at_least_100,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100

 if not TORCH_VERSION_AT_LEAST_2_4:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


-def run_matrix_test(M: int, K: int, N: int) -> float:
-    """
-    Run matrix multiplication test with given dimensions.
-
-    Args:
-        M, K, N: Matrix dimensions
-
-    Returns:
-        float: SQNR (Signal-to-Quantization-Noise Ratio) value
-    """
+def run_matrix_test(M: int, K: int, N: int, format) -> float:
     dtype = torch.bfloat16
     device = torch.device("cuda")

-    # Initialize matrices
     a = torch.rand((M, K), dtype=dtype, device=device)
     b = torch.rand((N, K), dtype=dtype, device=device)

-    # Convert to MX format
-    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
-    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)
+    fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
+    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16

-    a_fp8 = a_mx._data
-    b_fp8 = b_mx._data
-    assert b_fp8.is_contiguous()
-    b_fp8 = b_fp8.transpose(-1, -2)
+    a_mx = MXTensor.to_mx(a, fmt, 32)
+    b_mx = MXTensor.to_mx(b, fmt, 32)

-    # Get scales
-    a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
-    b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)
+    a_data = a_mx._data
+    b_data = b_mx._data
+    assert b_data.is_contiguous()
+    b_data = b_data.transpose(-1, -2)

-    a_scale_block = to_blocked(a_scale_e8)
-    b_scale_block = to_blocked(b_scale_e8)
+    a_scale = a_mx._scale_e8m0.view(M, K // 32)
+    b_scale = b_mx._scale_e8m0.view(N, K // 32)
+
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)

-    # Get reference output
     out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
         -1, -2
     )
+    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)

-    # Run implementation
-    out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)
-
-    # Calculate metrics
-    sqnr = compute_error(out_hp, out_e8_fp8)
-
-    return sqnr.item()
+    return compute_error(out_hp, out).item()


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -68,35 +50,25 @@ def run_matrix_test(M: int, K: int, N: int) -> float:
 @pytest.mark.parametrize(
     "size",
     [
-        # Small matrices
         (128, 128, 128),
         (256, 256, 256),
-        (384, 384, 384),
-        # Medium matrices
+        (384, 384, 384),  # Small
         (512, 512, 512),
-        (640, 640, 640),
-        (768, 768, 768),
-        # Large matrices
-        (896, 896, 896),
+        (768, 768, 768),  # Medium
         (1024, 1024, 1024),
-        # Very large matrices
-        (8192, 8192, 8192),
-        # Non-square matrices
+        (8192, 8192, 8192),  # Large
         (128, 256, 384),
-        (256, 384, 512),
-        (384, 512, 640),
-        # Non-aligned matrices
+        (256, 384, 512),  # Non-square
         (129, 256, 384),
-        (256, 384, 536),
-        (133, 512, 528),
+        (133, 512, 528),  # Non-aligned
     ],
     ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
 )
-def test_matrix_multiplication(size):
-    """
-    Test matrix multiplication with various dimensions.
-    Verifies that the SQNR meets minimum quality threshold.
-    """
+@pytest.mark.parametrize("format", ["fp8", "fp4"])
+def test_matrix_multiplication(size, format):
     M, K, N = size
-    sqnr = run_matrix_test(M, K, N)
-    assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
+    sqnr = run_matrix_test(M, K, N, format)
+    threshold = 80.0
+    assert (
+        sqnr >= threshold
+    ), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
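
For reference, the fp4 path this diff adds can be exercised outside pytest. Below is a minimal standalone sketch, assuming a CUDA device that supports the MX kernels (the imports above gate on is_sm_at_least_100, i.e. Blackwell-class hardware); it uses only identifiers imported in this diff, and the shapes and block size of 32 mirror run_matrix_test:

import torch

from torchao.float8.float8_utils import compute_error
from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

M, K, N = 256, 256, 256
a = torch.rand((M, K), dtype=torch.bfloat16, device="cuda")
b = torch.rand((N, K), dtype=torch.bfloat16, device="cuda")

# Quantize both operands to MXFP4 with a block size of 32.
a_mx = MXTensor.to_mx(a, DTYPE_FP4, 32)
b_mx = MXTensor.to_mx(b, DTYPE_FP4, 32)

# Raw quantized payloads; b is passed transposed, as in the test.
a_data = a_mx._data
b_data = b_mx._data.transpose(-1, -2)

# One E8M0 scale per 32-element block, rearranged into the blocked
# layout the kernel expects.
a_scale_block = to_blocked(a_mx._scale_e8m0.view(M, K // 32))
b_scale_block = to_blocked(b_mx._scale_e8m0.view(N, K // 32))

out = mx_fp4_bf16(a_data, b_data, a_scale_block, b_scale_block)
out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(-1, -2)
print(compute_error(out_hp, out).item())  # SQNR; the test asserts >= 80.0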