2 changes: 1 addition & 1 deletion Dockerfile
@@ -222,7 +222,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist

RUN --mount=type=cache,target=/root/.cache/uv \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
uv pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
uv pip install flashinfer-python==0.2.3 -i https://flashinfer.ai/whl/cu124/torch2.6 ; \
fi
COPY examples examples

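This change drops the hard-pinned v0.2.1.post2 wheel URL and instead installs flashinfer-python 0.2.3 from the FlashInfer wheel index for CUDA 12.4 / torch 2.6. A quick sanity check one could run inside the built image to confirm which version was resolved (not part of this PR, just a hypothetical snippet):

from importlib.metadata import version

# Reports the installed flashinfer-python distribution; expected to print 0.2.3.
print(version("flashinfer-python"))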
28 changes: 28 additions & 0 deletions examples/deepseek-basic.py
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,
    # enforce_eager=True,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
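Each RequestOutput in the list above carries more than the generated text; a small follow-on sketch (attribute names as in vLLM's CompletionOutput, added here only for illustration):

for output in outputs:
    completion = output.outputs[0]
    # Per-completion details: generated token ids and why generation stopped.
    print(len(completion.token_ids), completion.finish_reason)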
99 changes: 99 additions & 0 deletions examples/deepseek-chat.py
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,
)
sampling_params = SamplingParams(temperature=0.5)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# def print_outputs(outputs):
# for output in outputs:
# prompt = output.prompt
# generated_text = output.outputs[0].text
# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# print("-" * 80)

# print("=" * 80)

# # In this script, we demonstrate how to pass input to the chat method:

# conversation = [
# {
# "role": "system",
# "content": "You are a helpful assistant"
# },
# {
# "role": "user",
# "content": "Hello"
# },
# {
# "role": "assistant",
# "content": "Hello! How can I assist you today?"
# },
# {
# "role": "user",
# "content": "Write an essay about the importance of higher education.",
# },
# ]
# outputs = llm.chat(conversation,
# sampling_params=sampling_params,
# use_tqdm=False)
# print_outputs(outputs)

# # You can run batch inference with llm.chat API
# conversation = [
# {
# "role": "system",
# "content": "You are a helpful assistant"
# },
# {
# "role": "user",
# "content": "Hello"
# },
# {
# "role": "assistant",
# "content": "Hello! How can I assist you today?"
# },
# {
# "role": "user",
# "content": "Write an essay about the importance of higher education.",
# },
# ]
# conversations = [conversation for _ in range(10)]

# # We turn on tqdm progress bar to verify it's indeed running batch inference
# outputs = llm.chat(messages=conversations,
# sampling_params=sampling_params,
# use_tqdm=True)
# print_outputs(outputs)

# A chat template can be optionally supplied.
# If not, the model will use its default chat template.

# with open('template_falcon_180b.jinja', "r") as f:
# chat_template = f.read()

# outputs = llm.chat(
# conversations,
# sampling_params=sampling_params,
# use_tqdm=False,
# chat_template=chat_template,
# )
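The commented-out block above mirrors vLLM's generic chat example and is never executed by this script. A minimal runnable variant of the same idea, reusing the llm and sampling_params defined earlier (the conversation text is illustrative, and it assumes the model's tokenizer ships a chat template; otherwise one would pass chat_template= explicitly):

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Write a short note on the importance of higher education."},
]
chat_outputs = llm.chat(conversation,
                        sampling_params=sampling_params,
                        use_tqdm=False)
print(chat_outputs[0].outputs[0].text)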
93 changes: 10 additions & 83 deletions vllm/attention/backends/flashinfer.py
@@ -32,12 +32,14 @@
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
from vllm.attention.backends.utils import (PAD_SLOT_ID, PerLayerParameters,
compute_slot_mapping,
compute_slot_mapping_start_idx,
get_fp8_dtype_for_flashinfer,
infer_global_hyperparameters,
is_block_tables_empty)
from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config import get_current_vllm_config
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)

@@ -96,81 +98,6 @@
def get_supported_head_sizes() -> List[int]:
return [64, 128, 256]

@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")


@dataclass
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
"""

window_left: int
logits_soft_cap: Optional[float]
sm_scale: float


def get_per_layer_parameters(
vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""

layers = vllm_config.compilation_config.static_forward_context
per_layer_params: Dict[str, PerLayerParameters] = {}

for key, layer in layers.items():
assert isinstance(layer, Attention)

impl = layer.impl
assert isinstance(impl, FlashInferImpl)

# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale

per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)

return per_layer_params


def infer_global_hyperparameters(
per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`

So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""

assert len(per_layer_params) > 0, "No attention layers found in the model."

param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")

return global_params


class FlashInferState(AttentionState):

@@ -274,7 +201,7 @@
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores)
if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
kv_cache_dtype = get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
else:
kv_cache_dtype = get_kv_cache_torch_dtype(
@@ -293,8 +220,8 @@
batch_size + 1,
dtype=torch.int32)

global_params = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
global_params = infer_global_hyperparameters(self.vllm_config,
FlashInferImpl)

attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
@@ -652,7 +579,7 @@
# - `logits_soft_cap`
# - `sm_scale`
inferred_params = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
self.vllm_config, FlashInferImpl)
self.global_hyperparameters = inferred_params
self.window_left = inferred_params.window_left
self.logits_soft_cap = inferred_params.logits_soft_cap
@@ -853,7 +780,7 @@
block_table_bound_tensor = None

if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
kv_cache_dtype = get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
else:
kv_cache_dtype = get_kv_cache_torch_dtype(
@@ -970,7 +897,7 @@
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
Check failure on line 900 in vllm/attention/backends/flashinfer.py (GitHub Actions / pre-commit): "type[FlashInferBackend]" has no attribute "get_fp8_dtype_for_flashinfer" [attr-defined]
kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)

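This file's diff removes get_fp8_dtype_for_flashinfer, PerLayerParameters, get_per_layer_parameters, and infer_global_hyperparameters from the FlashInfer backend and imports them from vllm.attention.backends.utils instead, with the call sites now passing (self.vllm_config, FlashInferImpl) directly to infer_global_hyperparameters. The utils side is not shown in this diff, so what follows is only a sketch of what the relocated helpers might look like, reconstructed from the removed code above and the new call sites; the exact signatures are assumptions. Note also the pre-commit annotation above: one remaining call at new line 900 still goes through FlashInferBackend.get_fp8_dtype_for_flashinfer and would need to use the imported helper as well.

# Sketch of vllm/attention/backends/utils.py additions -- reconstructed, not
# taken verbatim from this PR; signatures are inferred from the new call sites.
from dataclasses import dataclass
from typing import Dict, Optional, Type

import torch

from vllm.config import VllmConfig


def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
    # Map vLLM's kv_cache_dtype strings to the torch FP8 dtypes FlashInfer expects.
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    elif kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    else:
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")


@dataclass
class PerLayerParameters:
    """Hyperparameters that the FlashInfer backend currently requires to be
    identical across all attention layers of a model."""

    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float


def get_per_layer_parameters(vllm_config: VllmConfig,
                             cls_: Type) -> Dict[str, PerLayerParameters]:
    """Scan all attention layers and collect the hyperparameters used during
    `plan`, checking that each layer's impl is an instance of `cls_`."""
    layers = vllm_config.compilation_config.static_forward_context
    per_layer_params: Dict[str, PerLayerParameters] = {}
    for key, layer in layers.items():
        impl = layer.impl
        assert isinstance(impl, cls_)
        window_size = impl.sliding_window
        window_left = window_size[0] if window_size is not None else -1
        per_layer_params[key] = PerLayerParameters(window_left,
                                                   impl.logits_soft_cap,
                                                   impl.scale)
    return per_layer_params


def infer_global_hyperparameters(vllm_config: VllmConfig,
                                 cls_: Type) -> PerLayerParameters:
    """Assert that every layer shares the same `window_left`,
    `logits_soft_cap`, and `sm_scale`, and return the shared values."""
    per_layer_params = get_per_layer_parameters(vllm_config, cls_)
    assert len(per_layer_params) > 0, "No attention layers found in the model."
    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]
    for params in param_sets:
        assert params == global_params, (
            "FlashInfer backend currently only supports models in which all "
            "layers share the same values for `window_left`, "
            "`logits_soft_cap`, and `sm_scale`.")
    return global_params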