Commit c9f9d5b

[Bugfix][AMD] Update torch_bindings so that scaled_fp4_quant isn't built on ROCm (#13235)
1 parent 0c73026 commit c9f9d5b

3 files changed: +12 -10 lines

csrc/ops.h

Lines changed: 4 additions & 4 deletions
@@ -177,6 +177,10 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                               std::optional<torch::Tensor> const& bias);
 
 std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
+
+void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
+                      torch::Tensor& output_scale,
+                      torch::Tensor const& input_scale);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
@@ -194,10 +198,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
 
 void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
 
-void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
-                      torch::Tensor& output_scale,
-                      torch::Tensor const& input_scale);
-
 void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);
 
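The `#endif` just below the inserted declaration is the tell: the prototype now lives inside the preprocessor region that is compiled only for CUDA builds, whereas the old copy sat below the guard where ROCm builds could still see it. A minimal sketch of the resulting header layout, assuming the enclosing guard is `#ifndef USE_ROCM` (vLLM's usual convention; the opening directive is outside this hunk):

// Sketch of csrc/ops.h after this commit, not the verbatim file.
#include <torch/all.h>

#include <optional>
#include <vector>

#ifndef USE_ROCM
// CUDA-only declarations live in this region.
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);

// Moved inside the guard: ROCm builds no longer see a prototype for a
// kernel that is never compiled for them.
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
                      torch::Tensor& output_scale,
                      torch::Tensor const& input_scale);
#endif

// Declarations from here on are built for both CUDA and ROCm.
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);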

csrc/torch_bindings.cpp

Lines changed: 7 additions & 6 deletions
@@ -385,6 +385,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "bool silu_activation,"
       "int pad_slot_id) -> ()");
   ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
+
+  // Compute NVFP4 block quantized tensor.
+  ops.def(
+      "scaled_fp4_quant(Tensor! output, Tensor input,"
+      " Tensor! output_scale, Tensor input_scale) -> ()");
+  ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
+
 #endif
 
   // Quantized GEMM for GPTQ.
@@ -421,12 +428,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);
 
-  // Compute NVFP4 block quantized tensor.
-  ops.def(
-      "scaled_fp4_quant(Tensor! output, Tensor input,"
-      " Tensor! output_scale, Tensor input_scale) -> ()");
-  ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
-
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"

vllm/_custom_ops.py

Lines changed: 1 addition & 0 deletions
@@ -774,6 +774,7 @@ def scaled_fp4_quant(
     two values are packed into a uint8 and float8_e4m3 scaling factors
     in the swizzled layout.
     """
+    assert not current_platform.is_rocm()
     assert input.ndim >= 1, (
         f'input.ndim needs to be >= 1, but got {input.ndim}.')
     other_dims = 1 if input.ndim == 1 else -1
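On the Python side, the new `assert not current_platform.is_rocm()` is a belt-and-braces check on top of the build-time guard: with the op no longer registered on ROCm, calling the wrapper there would otherwise surface as a less obvious operator-lookup failure from `torch.ops`, whereas the assert fails immediately at the vLLM API boundary.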
