from vllm import LLM, SamplingParams

model_str = "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
# FP8 weights plus an FP8 KV cache; the FlashInfer attention backend is
# selected at launch through the VLLM_ATTENTION_BACKEND environment variable.
model = LLM(model=model_str, quantization="fp8", kv_cache_dtype="fp8")
params = SamplingParams(temperature=0)  # greedy decoding
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "New york times is politically sided to ",
    "The future holds infinite ",
]
result = model.generate(prompts=prompts, sampling_params=params)
for output in result:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(
        f"\n\n Prompt: {prompt!r}, \nGenerated text: {generated_text!r}, "
        f"\ntoken_ids: {output.outputs[0].token_ids}"
    )
Execution with the FlashInfer backend:
VLLM_ATTENTION_BACKEND=FLASHINFER /bin/python3 /workspace/vllm_github/test_llm.py
root@s4124-0013:/workspace/vllm_github# VLLM_ATTENTION_BACKEND=FLASHINFER /bin/python3 /workspace/vllm_github/test_llm.py
INFO 08-29 19:17:01 config.py:628] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor
INFO 08-29 19:17:01 llm_engine.py:210] Initializing an LLM engine (v0.5.5) with config: model='neuralmagic/Meta-Llama-3-8B-Instruct-FP8', speculative_config=None, tokenizer='neuralmagic/Meta-Llama-3-8B-Instruct-FP8', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=False, kv_cache_dtype=fp8, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=neuralmagic/Meta-Llama-3-8B-Instruct-FP8, use_v2_block_manager=False, num_scheduler_steps=1, enable_prefix_caching=False, use_async_output_proc=True)
INFO 08-29 19:17:08 selector.py:142] Using Flashinfer backend.
INFO 08-29 19:17:12 model_runner.py:906] Starting to load model neuralmagic/Meta-Llama-3-8B-Instruct-FP8...
WARNING 08-29 19:17:13 fp8.py:47] Detected fp8 checkpoint. Please note that the format is experimental and subject to change.
INFO 08-29 19:17:13 selector.py:142] Using Flashinfer backend.
INFO 08-29 19:17:14 weight_utils.py:236] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 50% Completed | 1/2 [00:00<00:00, 2.46it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00, 2.49it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00, 2.49it/s]
WARNING 08-29 19:17:15 utils.py:722] Using KV cache scaling factor 1.0 for fp8_e4m3. This may cause accuracy issues. Please make sure k/v_scale scaling factors are available in the fp8 checkpoint.
INFO 08-29 19:17:15 model_runner.py:917] Loading model weights took 8.4596 GB
INFO 08-29 19:17:15 gpu_executor.py:121] # GPU blocks: 47349, # CPU blocks: 4096
INFO 08-29 19:17:16 model_runner.py:1208] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 08-29 19:17:16 model_runner.py:1212] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 08-29 19:17:25 model_runner.py:1327] Graph capturing finished in 8 secs.
Processed prompts: 0%| | 0/5 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
[rank0]: Traceback (most recent call last):
[rank0]: File "/workspace/vllm_github/test_llm.py", line 21, in <module>
[rank0]: result = model.generate(prompts=prompts, sampling_params=params)
[rank0]: File "/workspace/vllm_github/vllm/utils.py", line 1031, in inner
[rank0]: return fn(*args, **kwargs)
[rank0]: File "/workspace/vllm_github/vllm/entrypoints/llm.py", line 347, in generate
[rank0]: outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]: File "/workspace/vllm_github/vllm/entrypoints/llm.py", line 697, in _run_engine
[rank0]: step_outputs = self.llm_engine.step()
[rank0]: File "/workspace/vllm_github/vllm/engine/llm_engine.py", line 1511, in step
[rank0]: output = self.model_executor.execute_model(
[rank0]: File "/workspace/vllm_github/vllm/executor/gpu_executor.py", line 129, in execute_model
[rank0]: output = self.driver_worker.execute_model(execute_model_req)
[rank0]: File "/workspace/vllm_github/vllm/worker/worker_base.py", line 327, in execute_model
[rank0]: output = self.model_runner.execute_model(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/workspace/vllm_github/vllm/worker/model_runner.py", line 1414, in execute_model
[rank0]: self.attn_state.begin_forward(model_input)
[rank0]: File "/workspace/vllm_github/vllm/attention/backends/flashinfer.py", line 251, in begin_forward
[rank0]: model_input.attn_metadata.begin_forward()
[rank0]: File "/workspace/vllm_github/vllm/attention/backends/flashinfer.py", line 346, in begin_forward
[rank0]: self.decode_wrapper.begin_forward(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/flashinfer/decode.py", line 539, in begin_forward
[rank0]: self._wrapper.begin_forward(
[rank0]: RuntimeError: BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(at::Tensor, at::Tensor, at::Tensor, at::Tensor, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, float, at::Tensor, at::Tensor)::<lambda()> failed to dispatch data type Byte
Processed prompts: 0%|
**Removing the `data_type` argument passed to `decode_wrapper` seems to work around the issue temporarily, until vLLM can reliably determine whether the expected `data_type` is `uint8` or fp8 when building `attn_metadata` for FlashInfer.**
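A minimal sketch of what "predicting the expected data_type" could look like, assuming the decision is made from the user-facing `kv_cache_dtype` string rather than from the cache tensor's storage dtype. (`Byte` in the traceback is PyTorch's name for `torch.uint8`, which suggests the fp8 cache was handed to FlashInfer as raw bytes.) The helper `resolve_flashinfer_kv_dtype` and the table `FP8_KV_DTYPES` are hypothetical, not vLLM API. The returned strings are names of PyTorch dtypes (`torch.float8_e4m3fn`, `torch.float8_e5m2`). Mapping plain `"fp8"` to e4m3 is an assumption that follows the log line above, which reports the scaling factor for `fp8_e4m3`.

```python
# Hypothetical helper (not part of vLLM): map the user-facing
# kv_cache_dtype string to the dtype FlashInfer should be told about,
# instead of the uint8 dtype used to store the fp8 cache.
FP8_KV_DTYPES = {
    "fp8": "float8_e4m3fn",  # assumption: plain "fp8" means e4m3
    "fp8_e4m3": "float8_e4m3fn",
    "fp8_e5m2": "float8_e5m2",
}


def resolve_flashinfer_kv_dtype(kv_cache_dtype: str, model_dtype: str) -> str:
    """Return the name of the PyTorch dtype FlashInfer should receive."""
    if kv_cache_dtype == "auto":
        # Unquantized cache: it shares the model's dtype (e.g. "bfloat16").
        return model_dtype
    try:
        return FP8_KV_DTYPES[kv_cache_dtype]
    except KeyError:
        raise ValueError(f"unsupported kv_cache_dtype: {kv_cache_dtype!r}") from None


if __name__ == "__main__":
    print(resolve_flashinfer_kv_dtype("fp8", "bfloat16"))
```

On the vLLM side, `getattr(torch, resolve_flashinfer_kv_dtype("fp8", "bfloat16"))` would yield `torch.float8_e4m3fn`. That value could then be passed as `data_type` in place of the uint8 storage dtype. Keeping the decision in a pure string mapping avoids inspecting the cache tensor, whose dtype is the source of the ambiguity.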
Previous reference: https://github.com/vllm-project/vllm/pull/7985/files/26904dd78495ad1b18e43d9e52ee62e05cb71d04#r1736922768
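Issue: with the configuration and test script above, running under the FlashInfer attention backend with `kv_cache_dtype="fp8"` fails inside `decode_wrapper.begin_forward` with `failed to dispatch data type Byte`.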