Skip to content

Commit 86c6239

Browse files
committed
HACK: workaround triton kernel compile failure
With:

```
$ VLLM_LOGGING_LEVEL=DEBUG VLLM_USE_V1=1 VLLM_ENABLE_V1_MULTIPROCESSING=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --speculative-model "[ngram]" --num-speculative-tokens 5 --ngram-prompt-lookup-max 4
```

I'm getting:

```
  File "/home/markmc/vllm-project/vllm/vllm/v1/engine/core.py", line 121, in _initialize_kv_caches
    available_gpu_memory = self.model_executor.determine_available_memory()
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm/vllm/v1/executor/abstract.py", line 66, in determine_available_memory
    output = self.collective_rpc("determine_available_memory")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm/vllm/utils.py", line 2216, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm-venv/lib64/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm/vllm/v1/worker/gpu_worker.py", line 157, in determine_available_memory
    self.model_runner.profile_run()
  File "/home/markmc/vllm-project/vllm/vllm/v1/worker/gpu_model_runner.py", line 1466, in profile_run
    sampler_output = self._dummy_sampler_run(hidden_states)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm-venv/lib64/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm/vllm/v1/worker/gpu_model_runner.py", line 1375, in _dummy_sampler_run
    self.rejection_sampler(
  File "/home/markmc/vllm-project/vllm-venv/lib64/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm-venv/lib64/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm/vllm/v1/sample/rejection_sampler.py", line 92, in forward
    output_token_ids = rejection_sample(
                       ^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm/vllm/v1/sample/rejection_sampler.py", line 166, in rejection_sample
    rejection_greedy_sample_kernel[(batch_size, )](
  File "/home/markmc/vllm-project/vllm-venv/lib64/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm-venv/lib64/python3.12/site-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
             ^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm-venv/lib64/python3.12/site-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markmc/vllm-project/vllm-venv/lib64/python3.12/site-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 17:7:
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    max_spec_len,
):
    req_idx = tl.program_id(0)
    # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
    # re-compilation may happen during runtime when is_greedy_ptr is None.
    if is_greedy_ptr is None:
        is_greedy = True
    else:
        is_greedy = tl.load(is_greedy_ptr + req_idx)
    if not is_greedy:
       ^
ValueError('Cannot bitcast data-type of size 8 to data-type of size 1')
```

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 61f4121 commit 86c6239

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

vllm/v1/sample/rejection_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -375,7 +375,7 @@ def rejection_greedy_sample_kernel(
375 375       if is_greedy_ptr is None:
376 376           is_greedy = True
377 377       else:
378     -         is_greedy = tl.load(is_greedy_ptr + req_idx)
    378 +         is_greedy = tl.load(is_greedy_ptr + req_idx).to(tl.int1)
379 379       if not is_greedy:
380 380           # Early exit for non-greedy sampling requests.
381 381           return

0 commit comments

Comments
 (0)