@@ -154,51 +154,101 @@ def _scaled_dot_product_int8_op_ref(
         out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255)
         return out.to(torch.uint8)
 
+    def _scaled_dot_product_fp8_op_ref(
+        self,
+        q,
+        k,
+        v,
+        attn_mask=None,
+        dropout_p=0,
+        is_causal=False,
+        q_scale=1.0,
+        k_scale=1.0,
+        v_scale=1.0,
+        a_scale=1.0,
+        o_scale=1.0,
+    ):
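+        # Eager-mode reference for the fp8 path: dequantize q/k/v with their
+        # per-tensor scales, run SDPA in fp32 with a max-subtracted softmax,
+        # fake-quantize the attention probabilities to float8_e4m3fn (clamped
+        # to its +-448 range), and quantize the output to float8_e4m3fn.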
+        q = q.to(torch.float) * q_scale
+        k = k.to(torch.float) * k_scale
+        v = v.to(torch.float) * v_scale
+        scale_factor = 1 / math.sqrt(q.size(-1))
+        attn = q @ k.transpose(-2, -1)
+
+        attn = attn * scale_factor
+        if attn_mask is not None:
+            attn = attn + attn_mask.to(torch.float)
+        attn_max = attn.max(dim=-1, keepdim=True).values
+        attn = attn - attn_max
+        attn = torch.exp(attn)
+        attn_sum = torch.sum(attn, dim=-1, keepdim=True)
+        attn = attn / attn_sum
+        attn = torch.clamp(attn / a_scale, min=-448, max=448)
+        attn = attn.to(torch.float8_e4m3fn).to(torch.float)
+        attn = attn * a_scale
+        out = attn @ v
+        out = torch.clamp(out / o_scale, min=-448, max=448)
+        return out.to(torch.float8_e4m3fn)
+
     @pytest.mark.skipif(
         not torch_version_at_least("2.7.0"),
-        reason="int8 sdpa requires torch 2.7 or later",
+        reason="quantized sdpa requires torch 2.7 or later",
     )
     @pytest.mark.skipif(not IS_LINUX, reason="only support on linux")
     @pytest.mark.skipif(
         "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
         reason="cpp kernels not built",
     )
+    @parametrize("input_dtype", [torch.uint8, torch.float8_e4m3fn])
     @parametrize("batch_size", [56, 120])
     @parametrize("n_head", [2, 16])
     @parametrize("q_seq_len", [18, 89])
     @parametrize("kv_seq_len", [100, 253])
     @parametrize("head_dim", [32, 64])
     @parametrize("mask_dtype", [None, torch.float32, torch.bfloat16])
-    def test_scaled_dot_product_int8_op(
-        self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype
+    def test_quantized_scaled_dot_product_op(
+        self,
+        input_dtype,
+        batch_size,
+        n_head,
+        q_seq_len,
+        kv_seq_len,
+        head_dim,
+        mask_dtype,
     ):
         torch.manual_seed(1234)
         device = "cpu"
-        q_scale = float(1.7907238006591797)
-        q_zp = int(127)
-        k_scale = float(1.8039721250534058)
-        k_zp = int(125)
-        v_scale = float(1.839004635810852)
-        v_zp = int(127)
-        a_scale = float(0.003919653594493866)
-        a_zp = int(120)
-        o_scale = float(1.8191684484481812)
-        o_zp = int(128)
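+        # Fixed, pre-computed per-tensor quantization parameters and comparison
+        # tolerances for each input dtype.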
+        if input_dtype == torch.uint8:
+            q_scale = float(1.7907238006591797)
+            k_scale = float(1.8039721250534058)
+            v_scale = float(1.839004635810852)
+            a_scale = float(0.003919653594493866)
+            o_scale = float(1.8191684484481812)
+            q_zp = int(127)
+            k_zp = int(125)
+            v_zp = int(127)
+            a_zp = int(120)
+            o_zp = int(128)
+            atol, rtol = 1.0, 5e-6
+        else:
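+            # float8_e4m3fn quantization here is symmetric: per-tensor scales
+            # only, no zero points.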
+            q_scale = float(5.96875)
+            k_scale = float(5.78125)
+            v_scale = float(0.98046875)
+            a_scale = float(4.84375)
+            o_scale = float(3.171875)
+            atol, rtol = 0.125, 5e-6
         q_shape = [batch_size, q_seq_len, n_head, head_dim]
         kv_shape = [batch_size, kv_seq_len, n_head, head_dim]
         mask_shape = [batch_size, 1, 1, kv_seq_len]
-        q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100
-        k = (
-            torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
-            * 100
-        )
-        v = (
-            torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
-            * 100
-        )
-        q = q.to(torch.uint8)
-        k = k.to(torch.uint8)
-        v = v.to(torch.uint8)
+        q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2)
+        k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
+        v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
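+        # uint8 path: amplify the float inputs before casting; the fp8 path
+        # casts them at their original scale.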
+        if input_dtype == torch.uint8:
+            q *= 100
+            k *= 100
+            v *= 100
+        q = q.to(input_dtype)
+        k = k.to(input_dtype)
+        v = v.to(input_dtype)
         attn_mask = (
             torch.randn(mask_shape, dtype=mask_dtype, device=device)
             if mask_dtype is not None
@@ -211,44 +261,71 @@ def test_scaled_dot_product_int8_op(
             attn_mask.clone() if mask_dtype is not None else None,
         )
 
-        math_ref = self._scaled_dot_product_int8_op_ref(
-            q2,
-            k2,
-            v2,
-            attn_mask=attn_mask,
-            dropout_p=0.0,
-            is_causal=False,
-            q_scale=q_scale,
-            q_zp=q_zp,
-            k_scale=k_scale,
-            k_zp=k_zp,
-            v_scale=v_scale,
-            v_zp=v_zp,
-            a_scale=a_scale,
-            a_zp=a_zp,
-            o_scale=o_scale,
-            o_zp=o_zp,
-        )
-        actual = torch.ops.torchao.qscaled_dot_product(
-            q,
-            k,
-            v,
-            attn_mask=attn_mask_2,
-            dropout_p=0.0,
-            is_causal=False,
-            q_scale=q_scale,
-            q_zp=q_zp,
-            k_scale=k_scale,
-            k_zp=k_zp,
-            v_scale=v_scale,
-            v_zp=v_zp,
-            a_scale=a_scale,
-            a_zp=a_zp,
-            o_scale=o_scale,
-            o_zp=o_zp,
-        )
-
-        self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6)
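+        # Check the fused kernel against the matching eager reference in fp32,
+        # using the per-dtype tolerances chosen above.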
+        if input_dtype == torch.uint8:
+            math_ref = self._scaled_dot_product_int8_op_ref(
+                q2,
+                k2,
+                v2,
+                attn_mask=attn_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                q_zp=q_zp,
+                k_scale=k_scale,
+                k_zp=k_zp,
+                v_scale=v_scale,
+                v_zp=v_zp,
+                a_scale=a_scale,
+                a_zp=a_zp,
+                o_scale=o_scale,
+                o_zp=o_zp,
+            )
+            actual = torch.ops.torchao.qscaled_dot_product(
+                q,
+                k,
+                v,
+                attn_mask=attn_mask_2,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                q_zp=q_zp,
+                k_scale=k_scale,
+                k_zp=k_zp,
+                v_scale=v_scale,
+                v_zp=v_zp,
+                a_scale=a_scale,
+                a_zp=a_zp,
+                o_scale=o_scale,
+                o_zp=o_zp,
+            )
+        else:
+            math_ref = self._scaled_dot_product_fp8_op_ref(
+                q2,
+                k2,
+                v2,
+                attn_mask=attn_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                a_scale=a_scale,
+                o_scale=o_scale,
+            )
+            actual = torch.ops.torchao.qscaled_dot_product(
+                q,
+                k,
+                v,
+                attn_mask=attn_mask_2,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                a_scale=a_scale,
+                o_scale=o_scale,
+            )
+        self.assertEqual(actual.float(), math_ref.float(), atol=atol, rtol=rtol)
 
 
 instantiate_parametrized_tests(TestOps)