1616
1717#include < torch/all.h>
1818
19- #if defined ENABLE_NVFP4 && ENABLE_NVFP4
20- void scaled_fp4_quant_sm100a (torch::Tensor const & output,
19+ #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
20+ (defined (ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
21+ void scaled_fp4_quant_sm1xxa (torch::Tensor const & output,
2122 torch::Tensor const & input,
2223 torch::Tensor const & output_sf,
2324 torch::Tensor const & input_sf);
2425#endif
2526
26- #if defined ENABLE_NVFP4 && ENABLE_NVFP4
27+ #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
2728void scaled_fp4_experts_quant_sm100a (
2829 torch::Tensor& output, torch::Tensor& output_scale,
2930 torch::Tensor const & input, torch::Tensor const & input_global_scale,
@@ -33,8 +34,9 @@ void scaled_fp4_experts_quant_sm100a(
3334
3435void scaled_fp4_quant (torch::Tensor& output, torch::Tensor const & input,
3536 torch::Tensor& output_sf, torch::Tensor const & input_sf) {
36- #if defined ENABLE_NVFP4 && ENABLE_NVFP4
37- return scaled_fp4_quant_sm100a (output, input, output_sf, input_sf);
37+ #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
38+ (defined (ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
39+ return scaled_fp4_quant_sm1xxa (output, input, output_sf, input_sf);
3840#endif
3941 TORCH_CHECK_NOT_IMPLEMENTED (false , " No compiled nvfp4 quantization kernel" );
4042}
@@ -44,7 +46,7 @@ void scaled_fp4_experts_quant(
4446 torch::Tensor const & input, torch::Tensor const & input_global_scale,
4547 torch::Tensor const & input_offset_by_experts,
4648 torch::Tensor const & output_scale_offset_by_experts) {
47- #if defined ENABLE_NVFP4 && ENABLE_NVFP4
49+ #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
4850 return scaled_fp4_experts_quant_sm100a (
4951 output, output_scale, input, input_global_scale, input_offset_by_experts,
5052 output_scale_offset_by_experts);
0 commit comments