
[Bug]: RuntimeError (expected mat1 and mat2 to have the same dtype) trying to serve gemma-3-4b-it-q4_0.gguf on NVIDIA DGX Spark #27580

@nhira


Your current environment

The output of python collect_env.py
==============================
        System Info
==============================
OS                           : Ubuntu 24.04.3 LTS (aarch64)
GCC version                  : (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version                : Could not collect
CMake version                : version 3.31.6
Libc version                 : glibc-2.39

==============================
       PyTorch Info
==============================
PyTorch version              : 2.9.0a0+50eac811a6.nv25.09
Is debug build               : False
CUDA used to build PyTorch   : 13.0
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] (64-bit runtime)
Python platform              : Linux-6.11.0-1016-nvidia-aarch64-with-glibc2.39

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : 13.0.88
CUDA_MODULE_LOADING set to   : LAZY
GPU models and configuration : GPU 0: NVIDIA GB10
Nvidia driver version        : 580.95.05
cuDNN version                : Probably one of the following:
/usr/lib/aarch64-linux-gnu/libcudnn.so.9.13.0
/usr/lib/aarch64-linux-gnu/libcudnn_adv.so.9.13.0
/usr/lib/aarch64-linux-gnu/libcudnn_cnn.so.9.13.0
/usr/lib/aarch64-linux-gnu/libcudnn_engines_precompiled.so.9.13.0
/usr/lib/aarch64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.13.0
/usr/lib/aarch64-linux-gnu/libcudnn_graph.so.9.13.0
/usr/lib/aarch64-linux-gnu/libcudnn_heuristic.so.9.13.0
/usr/lib/aarch64-linux-gnu/libcudnn_ops.so.9.13.0
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                         aarch64
CPU op-mode(s):                       64-bit
Byte Order:                           Little Endian
CPU(s):                               20
On-line CPU(s) list:                  0-19
Vendor ID:                            ARM
Model name:                           Cortex-X925
Model:                                1
Thread(s) per core:                   1
Core(s) per socket:                   10
Socket(s):                            1
Stepping:                             r0p1
CPU(s) scaling MHz:                   91%
CPU max MHz:                          4004.0000
CPU min MHz:                          1378.0000
BogoMIPS:                             2000.00
Flags:                                fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
Model name:                           Cortex-A725
Model:                                1
Thread(s) per core:                   1
Core(s) per socket:                   10
Socket(s):                            1
Stepping:                             r0p1
CPU(s) scaling MHz:                   87%
CPU max MHz:                          2860.0000
CPU min MHz:                          338.0000
BogoMIPS:                             2000.00
Flags:                                fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
L1d cache:                            1.3 MiB (20 instances)
L1i cache:                            1.3 MiB (20 instances)
L2 cache:                             25 MiB (20 instances)
L3 cache:                             24 MiB (2 instances)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-19
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; __user pointer sanitization
Vulnerability Spectre v2:             Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

==============================
Versions of relevant libraries
==============================
[pip3] flashinfer-python==0.4.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.1.0
[pip3] nvidia-cudnn-frontend==1.14.0
[pip3] nvidia-cutlass-dsl==4.2.1
[pip3] nvidia-dali-cuda130==1.51.2
[pip3] nvidia-ml-py==13.580.82
[pip3] nvidia-modelopt==0.33.0
[pip3] nvidia-modelopt-core==0.33.0
[pip3] nvidia-nvcomp-cu13==5.0.0.6
[pip3] nvidia-nvimgcodec-cu13==0.6.0.32
[pip3] nvidia-nvjpeg-cu13==0.0.0a0
[pip3] nvidia-nvjpeg2k-cu13==0.9.0.43
[pip3] nvidia-nvtiff-cu13==0.5.1.75
[pip3] nvidia-resiliency-ext==0.4.1+cuda13
[pip3] onnx==1.18.0
[pip3] onnx-ir==0.1.9
[pip3] onnxscript==0.3.1
[pip3] optree==0.17.0
[pip3] pynvml==13.0.1
[pip3] pyzmq==27.0.2
[pip3] torch==2.9.0a0+50eac811a6.nv25.9
[pip3] torch_tensorrt==2.9.0a0
[pip3] torchao==0.13.0+git
[pip3] torchprofile==0.0.4
[pip3] torchvision==0.24.0a0+98f8b375
[pip3] transformers==4.57.1
[pip3] triton==3.4.0+gitc817b9b6
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
vLLM Version                 : 0.11.1rc2.dev165+gd31f7844f.d20251019 (git sha: d31f7844f, date: 20251019)
vLLM Build Flags:
  CUDA Archs: 8.0 8.6 9.0 10.0 11.0 12.0+PTX; ROCm: Disabled
GPU Topology:
  	GPU0	NIC0	NIC1	NIC2	NIC3	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	NODE	NODE	NODE	NODE	0-19	0		N/A
NIC0	NODE	 X 	PIX	NODE	NODE				
NIC1	NODE	PIX	 X 	NODE	NODE				
NIC2	NODE	NODE	NODE	 X 	PIX				
NIC3	NODE	NODE	NODE	PIX	 X 				

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: rocep1s0f0
  NIC1: rocep1s0f1
  NIC2: roceP2p1s0f0
  NIC3: roceP2p1s0f1

==============================
     Environment Variables
==============================
NVIDIA_VISIBLE_DEVICES=all
CUBLAS_VERSION=13.0.2.14
PYTORCH_TRITON_VERSION=3.4.0+gitc817b9b6
NVIDIA_REQUIRE_CUDA=cuda>=9.0
TORCH_CUDA_ARCH_LIST=8.0 8.6 9.0 10.0 11.0 12.0+PTX
NCCL_VERSION=2.27.7
NVIDIA_DRIVER_CAPABILITIES=compute,utility,video
TORCH_NCCL_USE_COMM_NONBLOCKING=0
CUDA_ARCH_LIST=8.0 8.6 9.0 10.0 11.0 12.0
NVIDIA_PRODUCT_NAME=vLLM
CUDA_VERSION=13.0.1.012
PYTORCH_VERSION=2.9.0a0+50eac81
PYTORCH_BUILD_NUMBER=0
CUBLASMP_VERSION=0.5.1.65
CUDNN_FRONTEND_VERSION=1.14.0
NVIDIA_VLLM_VERSION=25.09
MAX_JOBS=12
CUDNN_VERSION=9.13.0.50
PYTORCH_HOME=/opt/pytorch/pytorch
LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/torch/lib:/usr/local/lib/python3.12/dist-packages/torch_tensorrt/lib:/usr/local/cuda/compat/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
NVIDIA_BUILD_ID=214638690
CUDA_DRIVER_VERSION=580.82.07
PYTORCH_BUILD_VERSION=2.9.0a0+50eac81
CUDA_HOME=/usr/local/cuda
CUDA_HOME=/usr/local/cuda
CUDA_MODULE_LOADING=LAZY
NVIDIA_REQUIRE_JETPACK_HOST_MOUNTS=
NVIDIA_PYTORCH_VERSION=25.09
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=1
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1

🐛 Describe the bug

I am trying to serve Gemma 3 4B (Q4_0 GGUF) on an NVIDIA DGX Spark, using the quantization-aware trained (QAT) checkpoint from Hugging Face.
Checkpoint: https://huggingface.co/google/gemma-3-4b-it-qat-q4_0-gguf/blob/main/gemma-3-4b-it-q4_0.gguf

Loading the checkpoint with a custom Docker image (see the Dockerfile below) fails with a dtype mismatch in the logits projection. This looks related to the bf16/fp16 issue mentioned in PR 26189.
Example traceback:

  File "/workspace/vllm/vllm/model_executor/models/gemma3.py", line 564, in compute_logits
    logits = self.logits_processor(self.model.embed_tokens, hidden_states)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/model_executor/layers/logits_processor.py", line 60, in forward
    logits = self._get_logits(hidden_states, lm_head, embedding_bias)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/model_executor/layers/logits_processor.py", line 92, in _get_logits
    logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/model_executor/layers/quantization/gguf.py", line 477, in apply
    out = fused_mul_mat_gguf(x, qweight, qweight_type)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1254, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/model_executor/layers/quantization/gguf.py", line 130, in _fused_mul_mat_gguf
    return x @ qweight.T
           ~~^~~~~~~~~~~
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != c10::Half
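
For reference, the mismatch reproduces outside vLLM as a plain mixed-dtype matmul. The sketch below uses arbitrary shapes and any CUDA device; it only illustrates the bfloat16-activation vs. float16-weight combination shown in the trace, not vLLM's actual GGUF path.

import torch

# Arbitrary illustrative shapes; in the trace above, x is the hidden_states
# tensor (model dtype bfloat16) and the weight is an unquantized F16 GGUF tensor.
x = torch.randn(4, 2560, dtype=torch.bfloat16, device="cuda")
w = torch.randn(1024, 2560, dtype=torch.float16, device="cuda")
try:
    x @ w.T
except RuntimeError as e:
    print(e)  # expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != c10::Half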

Command used:

# vLLM API server version 0.11.1rc2.dev165+gd31f7844f.d20251019 (see Dockerfile below)
CHECKPOINT=/root/.cache/huggingface/gemma-3-4b-it-q4_0.gguf
MODEL=google/gemma-3-4b-it
vllm serve ${CHECKPOINT} --tokenizer ${MODEL}

Key log entries:

[model.py:1969] Downcasting torch.float32 to torch.bfloat16.
[__init__.py:225] Automatically detected platform cuda.
[cuda.py:403] Using Flash Attention backend on V1 engine.

Dockerfile used

# Start with the official NVIDIA NGC vLLM image.
# This provides the necessary CUDA, PyTorch, and Python environment.
FROM nvcr.io/nvidia/vllm:25.09-py3

WORKDIR /workspace 

RUN git clone https://github.com/vllm-project/vllm.git && \
  cd vllm && \
  python3 use_existing_torch.py && \
  pip3 install -r requirements/build.txt && \
  MAX_JOBS=12 pip3 install -e . --no-build-isolation

# simple entrypoint for interactive use
CMD ["/bin/bash"]

Why not try the NVIDIA/NGC image?

I cannot use nvcr.io/nvidia/vllm:25.09-py3 as-is because it ships an older version of vLLM that does not support GGUF loading.

Why not use the vLLM:nightly Docker image?

I cannot use the vLLM nightly image because it does not recognize this GPU: the ptxas bundled with its Triton does not accept the sm_121a target (see the error below).

HF_TOKEN=...
DOCKER_IMAGE="vllm/vllm-openai:nightly"
MODEL="google/gemma-3-4b-it"
MODEL_NAME="gemma-3-4b-it-q4_0"
CHECKPOINT="/root/.cache/huggingface/gemma-3-4b-it-q4_0.gguf"

docker run --gpus all --rm -it \
  -v ~/.cache/huggingface:/root/.cache/huggingface  \
  -p 8000:8000 \
  --env "HF_TOKEN=${HF_TOKEN}" \
  --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
  ${DOCKER_IMAGE} --model ${CHECKPOINT} --tokenizer ${MODEL} --hf-overrides '{"architectures": ["Gemma3ForCausalLM"]}' --served-model-name ${MODEL_NAME} 

Error:

  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/wrapper.py", line 1695, in generate_and_run_autotune_block
    raise RuntimeError(f"Failed to run autotuning code block: {e}") from e
torch._inductor.exc.InductorError: RuntimeError: Failed to run autotuning code block: No valid triton configs. PTXASError: PTXAS error: Internal Triton PTX codegen error
`ptxas` stderr:
ptxas fatal   : Value 'sm_121a' is not defined for option 'gpu-name'

Repro command: /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas -lineinfo -v --gpu-name=sm_121a /tmp/tmp_kapypnl.ptx -o /tmp/tmp_kapypnl.ptx.o
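
For context, a quick way to confirm the architecture the nightly image's toolchain would need to target (a minimal check, assuming the container's PyTorch can still initialize CUDA):

import torch

# On this machine the GPU is an NVIDIA GB10; the ptxas failure above suggests
# it reports a compute capability mapping to sm_121(a), which the bundled
# ptxas does not know.
print(torch.cuda.get_device_name(0))
major, minor = torch.cuda.get_device_capability(0)
print(f"compute capability {major}.{minor} -> sm_{major}{minor}")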
