import pytest
import torch
-
from torchao.float8.float8_utils import compute_error
- from torchao.ops import mx_fp8_bf16
+ from torchao.ops import mx_fp8_bf16, mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
- from torchao.utils import (
-     TORCH_VERSION_AT_LEAST_2_4,
-     is_sm_at_least_100,
- )
+ from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100

if not TORCH_VERSION_AT_LEAST_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

-
- def run_matrix_test(M: int, K: int, N: int) -> float:
-     """
-     Run matrix multiplication test with given dimensions.
-
-     Args:
-         M, K, N: Matrix dimensions
-
-     Returns:
-         float: SQNR (Signal-to-Quantization-Noise Ratio) value
-     """
+ def run_matrix_test(M: int, K: int, N: int, format: str = "fp8") -> float:
    dtype = torch.bfloat16
    device = torch.device("cuda")
-
-     # Initialize matrices
+
    a = torch.rand((M, K), dtype=dtype, device=device)
    b = torch.rand((N, K), dtype=dtype, device=device)

-     # Convert to MX format
-     a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
-     b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)
-
-     a_fp8 = a_mx._data
-     b_fp8 = b_mx._data
-     assert b_fp8.is_contiguous()
-     b_fp8 = b_fp8.transpose(-1, -2)
-
-     # Get scales
-     a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
-     b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)
-
-     a_scale_block = to_blocked(a_scale_e8)
-     b_scale_block = to_blocked(b_scale_e8)
+     fmt = torch.float8_e4m3fn if format == "fp8" else "fp4_e2m1"
+     mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
+
+     a_mx = MXTensor.to_mx(a, fmt, 32)
+     b_mx = MXTensor.to_mx(b, fmt, 32)

-     # Get reference output
-     out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
-         -1, -2
-     )
+     a_data = a_mx._data
+     b_data = b_mx._data
+     assert b_data.is_contiguous()
+     b_data = b_data.transpose(-1, -2)

-     # Run implementation
-     out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)
+     a_scale = a_mx._scale_e8m0.view(M, K // 32)
+     b_scale = b_mx._scale_e8m0.view(N, K // 32)

-     # Calculate metrics
-     sqnr = compute_error(out_hp, out_e8_fp8)
+     a_scale_block = to_blocked(a_scale)
+     b_scale_block = to_blocked(b_scale)

-     return sqnr.item()
+     out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(-1, -2)
+     out = mx_func(a_data, b_data, a_scale_block, b_scale_block)

+     return compute_error(out_hp, out).item()

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
@@ -68,35 +45,17 @@ def run_matrix_test(M: int, K: int, N: int) -> float:
@pytest.mark.parametrize(
    "size",
    [
-         # Small matrices
-         (128, 128, 128),
-         (256, 256, 256),
-         (384, 384, 384),
-         # Medium matrices
-         (512, 512, 512),
-         (640, 640, 640),
-         (768, 768, 768),
-         # Large matrices
-         (896, 896, 896),
-         (1024, 1024, 1024),
-         # Very large matrices
-         (8192, 8192, 8192),
-         # Non-square matrices
-         (128, 256, 384),
-         (256, 384, 512),
-         (384, 512, 640),
-         # Non-aligned matrices
-         (129, 256, 384),
-         (256, 384, 536),
-         (133, 512, 528),
+         (128, 128, 128), (256, 256, 256), (384, 384, 384),  # Small
+         (512, 512, 512), (768, 768, 768),  # Medium
+         (1024, 1024, 1024), (8192, 8192, 8192),  # Large
+         (128, 256, 384), (256, 384, 512),  # Non-square
+         (129, 256, 384), (133, 512, 528),  # Non-aligned
    ],
    ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
)
- def test_matrix_multiplication(size):
-     """
-     Test matrix multiplication with various dimensions.
-     Verifies that the SQNR meets minimum quality threshold.
-     """
+ @pytest.mark.parametrize("format", ["fp8", "fp4"])
+ def test_matrix_multiplication(size, format):
    M, K, N = size
-     sqnr = run_matrix_test(M, K, N)
-     assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
+     sqnr = run_matrix_test(M, K, N, format)
+     threshold = 80.0
+     assert sqnr >= threshold, f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
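
For anyone exercising the new kernel path outside pytest, below is a minimal standalone sketch of the fp8 case, essentially `run_matrix_test` inlined. It assumes only the APIs the diff itself imports (`MXTensor.to_mx`, `to_blocked`, `mx_fp8_bf16`, `compute_error`) and the hardware the test gates on (CUDA, SM 10.0); the shape is illustrative.

```python
import torch
from torchao.float8.float8_utils import compute_error
from torchao.ops import mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

M = K = N = 256
a = torch.rand((M, K), dtype=torch.bfloat16, device="cuda")
b = torch.rand((N, K), dtype=torch.bfloat16, device="cuda")

# Quantize both operands to MX fp8 with a scaling block size of 32.
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

# The kernel consumes the raw fp8 payloads (B transposed) plus the
# per-block e8m0 scales rearranged into the blocked layout.
out = mx_fp8_bf16(
    a_mx._data,
    b_mx._data.transpose(-1, -2),
    to_blocked(a_mx._scale_e8m0.view(M, K // 32)),
    to_blocked(b_mx._scale_e8m0.view(N, K // 32)),
)

# Reference path: dequantize to bf16, matmul, then compare.
ref = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(-1, -2)
print(compute_error(ref, out).item())
```

Note that `compute_error` reports SQNR in decibels (in torchao's implementation, 20·log10 of the ratio of the reference norm to the error norm), so `threshold = 80.0` in the test is a dB floor rather than a percentage.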