
RuntimeError (expected mat1 and mat2 to have the same dtype) trying to serve gemma-3-4b-it-q4_0.gguf on NVIDIA DGX Spark #105


Description

@nhira

I am trying to serve Gemma 3 4B (Q4_0 GGUF) on an NVIDIA DGX Spark using a quantization-aware trained (QAT) checkpoint from HuggingFace.
Checkpoint: https://huggingface.co/google/gemma-3-4b-it-qat-q4_0-gguf/blob/main/gemma-3-4b-it-q4_0.gguf

I get an error when loading the checkpoint with a custom Docker image (see Dockerfile below). It looks related to the bf16/fp16 issue mentioned in PR 26189.
Example error (traceback excerpt):

  File "/workspace/vllm/vllm/model_executor/models/gemma3.py", line 564, in compute_logits
    logits = self.logits_processor(self.model.embed_tokens, hidden_states)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/model_executor/layers/logits_processor.py", line 60, in forward
    logits = self._get_logits(hidden_states, lm_head, embedding_bias)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/model_executor/layers/logits_processor.py", line 92, in _get_logits
    logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/model_executor/layers/quantization/gguf.py", line 477, in apply
    out = fused_mul_mat_gguf(x, qweight, qweight_type)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1254, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/model_executor/layers/quantization/gguf.py", line 130, in _fused_mul_mat_gguf
    return x @ qweight.T
           ~~^~~~~~~~~~~
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != c10::Half
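
As a point of reference (my own minimal sketch, not part of the vLLM code), the failure reduces to a plain PyTorch matmul between tensors of different dtypes, which is what _fused_mul_mat_gguf ends up doing here with bf16 activations and fp16 lm_head weights:

import torch

# Shapes are illustrative; the dtypes mirror the traceback above.
hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)  # engine activations downcast to bf16
qweight = torch.randn(16, 8, dtype=torch.float16)        # GGUF lm_head weights arrive as fp16

try:
    hidden_states @ qweight.T
except RuntimeError as e:
    print(e)  # expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != c10::Half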

Command used:

# vLLM API server version 0.11.1rc2.dev165+gd31f7844f.d20251019 (see Dockerfile below)
CHECKPOINT=/root/.cache/huggingface/gemma-3-4b-it-q4_0.gguf
MODEL=google/gemma-3-4b-it
vllm serve ${CHECKPOINT} --tokenizer ${MODEL}
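
For completeness, the same load expressed through the offline Python API. The dtype="float16" override is only an assumption on my part (matching the engine dtype to the fp16 GGUF tensors might sidestep the mismatch), not something I have verified on the DGX Spark:

from vllm import LLM, SamplingParams

llm = LLM(
    model="/root/.cache/huggingface/gemma-3-4b-it-q4_0.gguf",
    tokenizer="google/gemma-3-4b-it",
    dtype="float16",  # assumption: keep activations in fp16 to match the GGUF lm_head dtype
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)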

Key log entries:

[model.py:1969] Downcasting torch.float32 to torch.bfloat16.
[__init__.py:225] Automatically detected platform cuda.
[cuda.py:403] Using Flash Attention backend on V1 engine.

Dockerfile used:

# Start with the official NVIDIA NGC vLLM image.
# This provides the necessary CUDA, PyTorch, and Python environment.
FROM nvcr.io/nvidia/vllm:25.09-py3

WORKDIR /workspace 

RUN git clone https://github.com/vllm-project/vllm.git && \
  cd vllm && \
  python3 use_existing_torch.py && \
  pip3 install -r requirements/build.txt && \
  MAX_JOBS=12 pip3 install -e . --no-build-isolation

# simple entrypoint for interactive use
CMD ["/bin/bash"]

Why not try the NVIDIA/NGC image?

I cannot use nvcr.io/nvidia/vllm:25.09-py3 as-is because it ships an older version of vLLM that does not support GGUF loading.

Why not use the vLLM:nightly Docker image?

I cannot use the vLLM nightly image because it does not recognize this GPU (see the ptxas error below).

HF_TOKEN=...
DOCKER_IMAGE="vllm/vllm-openai:nightly"
MODEL="google/gemma-3-4b-it"
MODEL_NAME="gemma-3-4b-it-q4_0"
CHECKPOINT="/root/.cache/huggingface/gemma-3-4b-it-q4_0.gguf"

docker run --gpus all --rm -it \
  -v ~/.cache/huggingface:/root/.cache/huggingface  \
  -p 8000:8000 \
  --env "HF_TOKEN=${HF_TOKEN}" \
  --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
  ${DOCKER_IMAGE} --model ${CHECKPOINT} --tokenizer ${MODEL} --hf-overrides '{"architectures": ["Gemma3ForCausalLM"]}' --served-model-name ${MODEL_NAME} 

Error:

  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/wrapper.py", line 1695, in generate_and_run_autotune_block
    raise RuntimeError(f"Failed to run autotuning code block: {e}") from e
torch._inductor.exc.InductorError: RuntimeError: Failed to run autotuning code block: No valid triton configs. PTXASError: PTXAS error: Internal Triton PTX codegen error
`ptxas` stderr:
ptxas fatal   : Value 'sm_121a' is not defined for option 'gpu-name'

Repro command: /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas -lineinfo -v --gpu-name=sm_121a /tmp/tmp_kapypnl.ptx -o /tmp/tmp_kapypnl.ptx.o
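
For context (my own note), ptxas rejecting sm_121a suggests the CUDA/Triton toolchain bundled in the nightly image predates this GPU target. A quick way to confirm what compute capability the container actually sees:

import torch

print(torch.cuda.get_device_name(0))
major, minor = torch.cuda.get_device_capability(0)
print(f"compute capability: sm_{major}{minor}")  # the ptxas error above implies sm_121 on this machine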
