diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 9b6e878f7d..57f8f3bb84 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -151,7 +151,7 @@ def _to_mxfp8_dim1_kernel_wrapper( block_size, elem_dtype, hp_dtype, - gemm_kernel_choice, + kernel_preference, cast_kernel_choice, scale_calculation_mode: ScaleCalculationMode, ): @@ -187,7 +187,7 @@ def _to_mxfp8_dim1_kernel_wrapper( elem_dtype, block_size, hp_dtype, - gemm_kernel_choice, + kernel_preference, None, is_swizzled_scales, ) @@ -206,7 +206,7 @@ def _to_mxfp8_dim1_kernel_wrapper( elem_dtype, block_size, hp_dtype, - gemm_kernel_choice, + kernel_preference, None, is_swizzled_scales, )