@@ -385,6 +385,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
385385 " bool silu_activation,"
386386 " int pad_slot_id) -> ()" );
387387 ops.impl (" causal_conv1d_fwd" , torch::kCUDA , &causal_conv1d_fwd);
388+
389+ // Compute NVFP4 block quantized tensor.
390+ ops.def (
391+ " scaled_fp4_quant(Tensor! output, Tensor input,"
392+ " Tensor! output_scale, Tensor input_scale) -> ()" );
393+ ops.impl (" scaled_fp4_quant" , torch::kCUDA , &scaled_fp4_quant);
394+
388395#endif
389396
390397 // Quantized GEMM for GPTQ.
@@ -421,12 +428,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);

-  // Compute NVFP4 block quantized tensor.
-  ops.def(
-      "scaled_fp4_quant(Tensor! output, Tensor input,"
-      "                 Tensor! output_scale, Tensor input_scale) -> ()");
-  ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
-
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"