@@ -259,11 +259,11 @@ def _(func, types, args, kwargs):
     return _float8_linear_impl(input_tensor, weight_tensor, bias)
 
 
-@implements(aten.mm.default)
+@implements(aten.matmul.default)
 @implements_torch_function(torch.matmul)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = args[0], args[1]
-    return _float8_mm_impl(input_tensor, weight_tensor)
+    return _float8_matmul_impl(input_tensor, weight_tensor)
 
 
 @implements(aten.addmm_.default)
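Note on the registration change above: swapping `aten.mm.default` for `aten.matmul.default` means the handler now also sees batched (>2-D) inputs that `mm` would reject. A minimal sketch of the shape difference, using plain float tensors in place of `Float8Tensor`:

```python
import torch

x = torch.randn(4, 8, 16)  # batched activation
w = torch.randn(16, 32)

# torch.matmul broadcasts the matrix multiply over the leading batch dim
y = torch.matmul(x, w)
assert y.shape == (4, 8, 32)

# torch.mm is strictly 2-D; the same call would raise
# torch.mm(x, w)  # RuntimeError: self must be a matrix
```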
@@ -275,24 +275,110 @@ def _(func, types, args, kwargs):
     )
     assert kwargs.get("alpha", 1) == 1, "only alpha=1 is supported"
     assert kwargs.get("beta", 1) == 1, "only beta=1 is supported"
-    out = _float8_mm_impl(input_tensor, weight_tensor)
+    out = _float8_matmul_impl(input_tensor, weight_tensor)
     return bias_tensor.add_(out)
 
 
-def _float8_mm_impl(
-    input_tensor: torch.Tensor,
-    weight_tensor: torch.Tensor,
+def _get_matmul_kernel_choice(weight_tensor: Float8Tensor) -> str:
+    """
+    Return the kernel choice for matmuls, either "fbgemm" or "torch".
+    """
+    if weight_tensor.kernel_preference == KernelPreference.AUTO:
+        kernel_choice = "torch"
+        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
+            kernel_choice = "fbgemm"
+    elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
+        kernel_choice = "fbgemm"
+    else:
+        assert weight_tensor.kernel_preference == KernelPreference.TORCH, (
+            f"{weight_tensor.kernel_preference=} not handled"
+        )
+        kernel_choice = "torch"
+    return kernel_choice
+
+
+def _call_fbgemm_f8f8bf16_matmul(
+    input_tensor: Float8Tensor,
+    weight_tensor: Float8Tensor,
+    bias: Optional[torch.Tensor] = None,
+    weight_is_already_transposed: bool = False,
 ) -> torch.Tensor:
-    assert isinstance(weight_tensor, Float8Tensor), (
-        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    """
+    Call `torch.ops.fbgemm.f8f8bf16*` ops.
+
+    These ops expect the weight tensor to be transposed.
+    If `weight_is_already_transposed=True` (e.g. in the linear case),
+    then we avoid unnecessarily double transposing the weight.
+    """
+    assert _is_fbgemm_gpu_genai_available(), (
+        "Expected fbgemm_gpu_genai package to be installed"
     )
-    is_transposed = weight_tensor.qdata.stride(-2) < weight_tensor.qdata.stride(-1)
-    # For matmul(x, w.t()), just call the linear implementation
-    # For matmul(x, w), just dequantize for now, we can optimize later
-    if is_transposed:
-        return _float8_linear_impl(input_tensor, weight_tensor.t())
+    assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
+    mm_config = weight_tensor.mm_config
+    assert mm_config is not None
+
+    if not weight_is_already_transposed:
+        weight_tensor = weight_tensor.t()
+
+    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
+    xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
+    wq = weight_tensor.qdata
+    x_scale = input_tensor.scale
+    w_scale = weight_tensor.scale
+    if _is_rowwise_scaled(weight_tensor):
+        assert _is_rowwise_scaled(input_tensor), (
+            "Input tensor must be rowwise block size"
+        )
+        res = torch.ops.fbgemm.f8f8bf16_rowwise(
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            bias=bias,
+            use_fast_accum=mm_config.use_fast_accum,
+        ).reshape(out_shape)
     else:
-        return torch.matmul(input_tensor, weight_tensor.dequantize())
+        assert _is_tensorwise_scaled(weight_tensor)
+        assert _is_tensorwise_scaled(input_tensor)
+        res = torch.ops.fbgemm.f8f8bf16(
+            xq,
+            wq,
+            x_scale * w_scale,
+            use_fast_accum=mm_config.use_fast_accum,
+        ).reshape(out_shape)
+    if bias is not None:
+        res = res + bias
+    return res
+
+
+def _call_float8_scaled_mm(
+    input_tensor: Float8Tensor,
+    weight_tensor: Float8Tensor,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    scaled_mm_config = weight_tensor.mm_config
+    assert scaled_mm_config is not None
+    assert weight_tensor.dim() == 2
+    out_shape = (*input_tensor.shape[:-1], weight_tensor.shape[1])
+
+    # Extract tensor data and scales
+    inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
+    w_data = weight_tensor.qdata
+    input_scale = input_tensor.scale
+    w_scale = weight_tensor.scale
+
+    input_scale = preprocess_scale(input_scale, input_tensor.shape)
+    inpt_data, w_data = preprocess_data(inpt_data, w_data, scaled_mm_config)
+
+    return addmm_float8_unwrapped_inference(
+        inpt_data,
+        input_scale,
+        w_data,
+        w_scale,
+        output_dtype=input_tensor.dtype,
+        bias=bias,
+        use_fast_accum=scaled_mm_config.use_fast_accum,
+    ).reshape(out_shape)
 
 
 def _float8_linear_impl(
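For readers skimming the new helpers above: `_get_matmul_kernel_choice` only decides between the fbgemm and torch paths; the actual dispatch happens in the callers. A standalone sketch of that decision logic, with a stand-in enum and stubbed capability flags (the real code queries `_is_fbgemm_gpu_genai_available()` and `is_sm_at_least_90()`):

```python
from enum import Enum


class KernelPreference(Enum):  # stand-in mirroring the names used above
    AUTO = "auto"
    FBGEMM = "fbgemm"
    TORCH = "torch"


def choose_kernel(pref: KernelPreference, fbgemm_ok: bool, sm90: bool) -> str:
    # AUTO prefers the fbgemm kernels, but only when they can actually run
    if pref == KernelPreference.AUTO:
        return "fbgemm" if (fbgemm_ok and sm90) else "torch"
    if pref == KernelPreference.FBGEMM:
        return "fbgemm"
    assert pref == KernelPreference.TORCH, f"{pref=} not handled"
    return "torch"


assert choose_kernel(KernelPreference.AUTO, True, False) == "torch"
assert choose_kernel(KernelPreference.AUTO, True, True) == "fbgemm"
assert choose_kernel(KernelPreference.FBGEMM, False, False) == "fbgemm"
```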
@@ -312,88 +398,17 @@ def _float8_linear_impl(
         )
 
     if isinstance(input_tensor, Float8Tensor):
-        kernel_choice = None
-
-        if weight_tensor.kernel_preference == KernelPreference.AUTO:
-            kernel_choice = "torch"
-            if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
-                kernel_choice = "fbgemm"
-        elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
-            kernel_choice = "fbgemm"
-        else:
-            assert weight_tensor.kernel_preference == KernelPreference.TORCH, (
-                f"{weight_tensor.kernel_preference=} not handled"
-            )
-            kernel_choice = "torch"
-
+        kernel_choice = _get_matmul_kernel_choice(weight_tensor)
         if kernel_choice == "fbgemm":
-            assert _is_fbgemm_gpu_genai_available(), (
-                "Expected fbgemm_gpu_genai package to be installed"
+            return _call_fbgemm_f8f8bf16_matmul(
+                input_tensor,
+                weight_tensor,
+                bias,
+                weight_is_already_transposed=True,
             )
-            assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
-            mm_config = weight_tensor.mm_config
-            assert mm_config is not None
-
-            out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
-            xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
-            wq = weight_tensor.qdata
-            x_scale = input_tensor.scale
-            w_scale = weight_tensor.scale
-            if _is_rowwise_scaled(weight_tensor):
-                assert _is_rowwise_scaled(input_tensor), (
-                    "Input tensor must be rowwise block size"
-                )
-                res = torch.ops.fbgemm.f8f8bf16_rowwise(
-                    xq,
-                    wq,
-                    x_scale,
-                    w_scale,
-                    bias=bias,
-                    use_fast_accum=mm_config.use_fast_accum,
-                ).reshape(out_shape)
-            else:
-                assert _is_tensorwise_scaled(weight_tensor)
-                assert _is_tensorwise_scaled(input_tensor)
-                res = torch.ops.fbgemm.f8f8bf16(
-                    xq,
-                    wq,
-                    x_scale * w_scale,
-                    use_fast_accum=mm_config.use_fast_accum,
-                ).reshape(out_shape)
-            if bias is not None:
-                res = res + bias
-            return res
         else:
             assert kernel_choice == "torch"
-            scaled_mm_config = weight_tensor.mm_config
-            assert scaled_mm_config is not None
-            out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
-
-            # Extract tensor data and scales
-            inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
-            w_data = weight_tensor.qdata
-            input_scale = input_tensor.scale
-            w_scale = weight_tensor.scale
-
-            # Handle rowwise scaling
-            if _is_rowwise_scaled(weight_tensor):
-                assert _is_rowwise_scaled(input_tensor), (
-                    "Input tensor must be rowwise block size"
-                )
-                w_scale = w_scale.transpose(-1, -2)
-
-            input_scale = preprocess_scale(input_scale, input_tensor.shape)
-            inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
-
-            return addmm_float8_unwrapped_inference(
-                inpt_data,
-                input_scale,
-                w_data,
-                w_scale,
-                output_dtype=input_tensor.dtype,
-                bias=bias,
-                use_fast_accum=scaled_mm_config.use_fast_accum,
-            ).reshape(out_shape)
+            return _call_float8_scaled_mm(input_tensor, weight_tensor.t(), bias)
     else:
         assert not isinstance(input_tensor, TorchAOBaseTensor), (
             "Expecting input_tensor to be unquantized"
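The linear path above passes `weight_is_already_transposed=True` because `nn.Linear` weights are stored as `(out_features, in_features)`, which is already the layout the fbgemm kernels expect; a plain `matmul` weight is `(in, out)` and gets transposed inside the helper. A quick sketch of that layout difference with ordinary tensors:

```python
import torch

x = torch.randn(4, 16)
w_linear = torch.randn(32, 16)        # nn.Linear layout: (out, in)

y_linear = torch.nn.functional.linear(x, w_linear)  # computes x @ w_linear.t()

w_matmul = w_linear.t().contiguous()  # matmul layout: (in, out)
y_matmul = torch.matmul(x, w_matmul)

assert torch.allclose(y_linear, y_matmul, atol=1e-6)
```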
@@ -405,6 +420,37 @@ def _float8_linear_impl(
         )
 
 
+def _float8_matmul_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: torch.Tensor,
+) -> torch.Tensor:
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+
+    act_quant_kwargs = weight_tensor.act_quant_kwargs
+    # quantizing activation, if `act_quant_kwargs` is specified
+    if act_quant_kwargs is not None:
+        input_tensor = _choose_quant_func_and_quantize_tensor(
+            input_tensor, act_quant_kwargs
+        )
+
+    if isinstance(input_tensor, Float8Tensor):
+        kernel_choice = _get_matmul_kernel_choice(weight_tensor)
+        if kernel_choice == "fbgemm":
+            return _call_fbgemm_f8f8bf16_matmul(input_tensor, weight_tensor)
+        else:
+            assert kernel_choice == "torch"
+            return _call_float8_scaled_mm(input_tensor, weight_tensor)
+    else:
+        assert not isinstance(input_tensor, TorchAOBaseTensor), (
+            "Expecting input_tensor to be unquantized"
+        )
+        # when input is not `Float8Tensor`, we expect that it is not quantized
+        # so this is float8 weight only quantization
+        return torch.matmul(input_tensor, weight_tensor.dequantize())
+
+
 @implements_torch_function(torch.bmm)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = (
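When the input is not a `Float8Tensor` (no `act_quant_kwargs`), `_float8_matmul_impl` falls back to dequantizing the weight and running a regular matmul, i.e. weight-only quantization. A rough standalone sketch of that fallback, with a hand-rolled per-column fp8 quantization standing in for `Float8Tensor` and its `dequantize()`:

```python
import torch

# Quantize the weight per output column to float8 e4m3 ...
w = torch.randn(16, 32)
scale = w.abs().amax(dim=0, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / scale).to(torch.float8_e4m3fn)

# ... then dequantize it back and run a plain matmul against the
# unquantized activation, which is what the fallback path does.
x = torch.randn(4, 16)
w_dq = w_fp8.to(torch.float32) * scale
out = torch.matmul(x, w_dq)  # approximates x @ w up to fp8 rounding error
```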