From 5a48d417f5f886d823dd93f2229da2566d6bfe14 Mon Sep 17 00:00:00 2001 From: Oleg Mosalov Date: Wed, 20 Nov 2024 15:43:24 +0000 Subject: [PATCH 001/317] A simple test to compare named_modules for a base model before and after loading a LoRA adapter. Signed-off-by: Oleg Mosalov --- tests/lora/test_load_lora_adapter.py | 74 ++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/lora/test_load_lora_adapter.py diff --git a/tests/lora/test_load_lora_adapter.py b/tests/lora/test_load_lora_adapter.py new file mode 100644 index 000000000000..d5958e989173 --- /dev/null +++ b/tests/lora/test_load_lora_adapter.py @@ -0,0 +1,74 @@ +from vllm import LLM +from vllm.lora.request import LoRARequest +import os + +def extract_layer_names(llm): + engine = getattr(llm, "llm_engine") + model_executor = getattr(engine, "model_executor") + driver_worker = getattr(model_executor, "driver_worker") + model_runner = getattr(driver_worker, "model_runner") + return [name for name, _ in model_runner.model.named_modules()] + +def load_base_model(base_model_path): + print(f"Loading base model from {base_model_path}...") + llm = LLM(model=base_model_path, enable_lora=True) + print("Base model loaded.") + return llm + +def load_lora_adapter(llm, lora_path): + print(f"Loading LoRA adapter from {lora_path}...") + lora_request = LoRARequest("lora_adapter", 1, lora_path) + print("LoRA adapter loaded.") + print("Sending a dummy request.") + prompt = "Hi!" + output = llm.generate(prompt, lora_request=lora_request) + print("The request is sent.") + return llm + +def compare_layers(base_layers, lora_layers): + print("Comparing layers...") + base_set = set(base_layers) + lora_set = set(lora_layers) + + added_layers = lora_set - base_set + removed_layers = base_set - lora_set + + #print("Base model layers:") + #for layer in base_set: + # print(f" {layer}") + + if added_layers or removed_layers: + print("Layer differences detected:") + if added_layers: + print(" Layers added by LoRA:") + for layer in added_layers: + print(f" {layer}") + if removed_layers: + print(" Layers removed after LoRA:") + for layer in removed_layers: + print(f" {layer}") + return True + else: + print("No differences in layers detected.") + return False + +def main(): + base_model_path = "/data/llama-3/llama-3-8b" + lora_adapter_path = "/home/oleg/lora_test/Meta-Llama-3-8B-oasst-Adapter" + + if not os.path.exists(base_model_path): + raise FileNotFoundError(f"Base model path not found: {base_model_path}") + if not os.path.exists(lora_adapter_path): + raise FileNotFoundError(f"LoRA adapter path not found: {lora_adapter_path}") + + base_model = load_base_model(base_model_path) + base_layers = extract_layer_names(base_model) + + model_with_lora = load_lora_adapter(base_model, lora_adapter_path) + lora_layers = extract_layer_names(model_with_lora) + + compare_layers(base_layers, lora_layers) + +if __name__ == "__main__": + main() + From 94bfd282fc4cb01af53074ce6a9a27972657e214 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 20 Nov 2024 19:20:21 +0000 Subject: [PATCH 002/317] Added non-triton SGMV and BGMV ops (not kernels yet) Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 1 + vllm/lora/ops/__init__.py | 0 vllm/lora/ops/xla/lora_ops.py | 119 ++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+) delete mode 100644 vllm/lora/ops/__init__.py create mode 100644 vllm/lora/ops/xla/lora_ops.py diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 7f68dae9717c..7804780e015f 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1096,6 +1096,7 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits + print("punica", logits.dtype) # LogitsProcessorWithLoRA always using bgmv self.punica_wrapper.add_lora_logits(logits, hidden_states, self.lora_a_stacked, diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/lora/ops/xla/lora_ops.py b/vllm/lora/ops/xla/lora_ops.py new file mode 100644 index 000000000000..8ad32dd4a77b --- /dev/null +++ b/vllm/lora/ops/xla/lora_ops.py @@ -0,0 +1,119 @@ +import torch + +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_expand( + inputs, + lora_b_weights, + output_tensor, + exploded_indices, + add_inputs + ) + + +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True +): + selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + inputs = inputs.to(dtype=torch.float16) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:] += outputs[:] + else: + output_tensor[:] = outputs[:] + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_shrink( + inputs, + lora_a_weights, + output_tensor, + exploded_indices, + scaling + ) + +def bgmv_shrink( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0 +): + selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + output_tensor[:] = scaling * outputs[:] + +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_expand_slice( + inputs, + lora_b_weights, + output_tensor, + exploded_indices, + slice_offset, + slice_size, + add_inputs + ) + + +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True +): + selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + inputs = inputs.to(dtype=torch.float16) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:, slice_offset:slice_offset+slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset+slice_size] = outputs[:] \ No newline at end of file From cf5b5c5748da19df5b9ff75a3171ae3b13b33d93 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 20 Nov 2024 19:21:03 +0000 Subject: [PATCH 003/317] Made a copy of the layer tests for the TPU. TODO: DRY it out Signed-off-by: Akshat Tripathi --- tests/lora/conftest.py | 2 +- tests/lora/test_layers_tpu.py | 1220 +++++++++++++++++++++++++++++++++ 2 files changed, 1221 insertions(+), 1 deletion(-) create mode 100644 tests/lora/test_layers_tpu.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 5ea66518b411..7baa632f5bff 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -70,7 +70,7 @@ def dist_init(): temp_file = tempfile.mkstemp()[1] backend = "nccl" - if current_platform.is_cpu(): + if current_platform.is_cpu() or current_platform.is_tpu(): backend = "gloo" init_distributed_environment(world_size=1, diff --git a/tests/lora/test_layers_tpu.py b/tests/lora/test_layers_tpu.py new file mode 100644 index 000000000000..29f732c621af --- /dev/null +++ b/tests/lora/test_layers_tpu.py @@ -0,0 +1,1220 @@ +import random +from copy import deepcopy +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple +from unittest.mock import patch + +import pytest +import torch +import torch.nn.functional as F + +from vllm.config import LoRAConfig +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LogitsProcessorWithLoRA, LoRAMapping, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +# yapf: enable +from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, + PackedLoRALayerWeights) +from vllm.lora.punica import PunicaWrapper +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) +from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform + +from .utils import DummyLoRAManager + +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.float32: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} +TPU_DEVICES = [ + f"xla:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +# We will launch different triton kernels between the prefill and decode +# stages, so we need to verify this. prefill stage(True) or decode stage(False) +STAGES = [True, False] + + +def get_random_id_to_index(num_loras: int, + num_slots: int, + log: bool = True) -> List[Optional[int]]: + """Creates a random lora_id_to_index mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. + log: Whether to log the output. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " + "num_loras must be less than or equal to num_slots.") + + slots: List[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots, device="cpu")[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + id_to_index: List[Optional[int]], + layer: BaseLayerWithLoRA, + layer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: + """This method populates the lora layers with lora weights. + + Args: + id_to_index: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + layer: the LoRAlayer to populate. + layer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: Dict[int, LoRALayerWeights] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. + sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() + + for slot_idx, lora_id in enumerate(id_to_index): + if lora_id is not None: + subloras: List[LoRALayerWeights] = [] + sublora_len = layer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager( + layer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[:, (sublora_len * + i):(sublora_len * (i + 1))] + sublora.optimize() + subloras.append(sublora) + + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] + + layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: List[int], + num_inputs: int, + input_size: Tuple[int, ...], + input_range: Tuple[float, float], + input_type: torch.dtype = torch.int, + device: torch.device = "xla" +) -> Tuple[List[torch.Tensor], List[int], List[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. + input_size: the size of each individual input. + input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. + """ + + low, high = input_range + + inputs: List[torch.Tensor] = [] + index_mapping: List[int] = [] + prompt_mapping: List[int] = [] + + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device=device)) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device=device) * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +@pytest.mark.parametrize("stage", STAGES) +def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: + torch.set_default_device(device) + max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(vocab_size, 256) + embedding.weight.data = torch.rand_like(embedding.weight.data) + embedding.weight.data[vocab_size:, :] = 0 + lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + embedding, lora_embedding = create_random_embedding_layer() + lora_embedding.set_mapping(punica_wrapper) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=embedding.weight.T, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, vocab_size), + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) + + lora_result = lora_embedding(torch.cat(inputs)) + + expected_results: List[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = embedding(input_) + after_a = F.embedding( + input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, vocab_size), + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) + + lora_result = lora_embedding(torch.cat(inputs)) + expected_result = embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +# @pytest.mark.skip( +# reason="Fails when loras are in any slot other than the first.") +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +@pytest.mark.parametrize("stage", STAGES) +def test_embeddings_with_new_embeddings(dist_init, num_loras, device, + vocab_size, stage) -> None: + + torch.set_default_device(device) + max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(vocab_size, 256) + embedding_data = torch.rand_like(embedding.weight.data) + embedding.weight.data = embedding_data + embedding.weight.data[vocab_size:, :] = 0 + expanded_embedding = VocabParallelEmbedding( + vocab_size + lora_config.lora_extra_vocab_size * max_loras, + 256, + org_num_embeddings=vocab_size) + expanded_embedding.weight.data[:vocab_size, :] = embedding_data + # We need to deepcopy the embedding as it will be modified + # in place + lora_embedding = VocabParallelEmbeddingWithLoRA( + deepcopy(expanded_embedding)) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return expanded_embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + expanded_embedding, lora_embedding = create_random_embedding_layer() + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=torch.zeros( + (256, vocab_size + lora_config.lora_extra_vocab_size)), + generate_embeddings_tensor=256, + ) + + lora_embedding.set_mapping(punica_wrapper) + # All embeddings tensors have the same shape. + embeddings_tensors = [ + lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) + ] + embeddings_tensor_len = embeddings_tensors[0].shape[0] + + # Add empty embeddings_tensors for unoccupied lora slots. + for _ in range(max_loras - len(embeddings_tensors)): + embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape)) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, vocab_size), + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) + original_inputs = deepcopy(inputs) + + # Force some of the inputs to be in the extended embeddings range + # to guarantee that their behavior is tested. + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + embedding_id = lora_id - 1 + input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) + original_input_[-1] = vocab_size + input_[-2] = vocab_size + ( + (embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = vocab_size + embeddings_tensor_len - 1 + + expanded_embedding.weight[vocab_size:vocab_size + + (embeddings_tensor_len * + max_loras)] = torch.cat(embeddings_tensors) + + lora_result = lora_embedding(torch.cat(original_inputs)) + + expected_results: List[torch.Tensor] = [] + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + lora = lora_dict[lora_id] + result = expanded_embedding(input_) + after_a = F.embedding( + original_input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, vocab_size), + device=device) + original_inputs = deepcopy(inputs) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) + lora_result = lora_embedding(torch.cat(original_inputs)) + expected_result = expanded_embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) +@pytest.mark.parametrize("stage", STAGES) +def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, + stage) -> None: + + torch.set_default_device(device) + max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def _pretest(): + linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, + 1024, + vocab_size, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + linear.weight.data[:, vocab_size:] = 0 + logits_processor = LogitsProcessor( + vocab_size + lora_config.lora_extra_vocab_size, vocab_size) + lora_logits_processor = LogitsProcessorWithLoRA( + logits_processor, 1024, linear.weight.dtype, linear.weight.device, + None) + lora_logits_processor.create_lora_weights(max_loras, lora_config) + + return linear, logits_processor, lora_logits_processor + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, logits_processor, lora_logits_processor = _pretest() + lora_logits_processor.set_mapping(punica_wrapper) + # NOTE: all the generated loras share the same embeddings tensor. + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_logits_processor, + layer_weights=linear.weight, + generate_embeddings_tensor=1024, + ) + embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor + embeddings_tensor_len = embeddings_tensor.shape[0] + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=8 * num_loras, # * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) + input_ = torch.rand(20, 1024, dtype=torch.float16) + + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + lm_head=linear, + embedding_bias=None) + + original_lm_head = deepcopy(linear) + + linear.weight[logits_processor. + org_vocab_size:logits_processor.org_vocab_size + + embeddings_tensor_len] = embeddings_tensor + + logits_processor.org_vocab_size = (vocab_size + + lora_config.lora_extra_vocab_size) + expected_results: List[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = logits_processor._get_logits(hidden_states=input_, + lm_head=linear, + embedding_bias=None) + result[:, vocab_size + embeddings_tensor_len:] = float("-inf") + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + logits_processor.org_vocab_size = vocab_size + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_logits_processor.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=8 * num_loras * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) + + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + lm_head=original_lm_head, + embedding_bias=None)[:, :vocab_size] + expected_result = logits_processor._get_logits( + hidden_states=torch.cat(inputs), + lm_head=original_lm_head, + embedding_bias=None) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("stage", STAGES) +def test_linear_replicated(dist_init, num_loras, device, stage) -> None: + + torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_linear_replicated_layer(): + + linear = ReplicatedLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = ReplicatedLinearWithLoRA(linear) + + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = create_random_linear_replicated_layer() + lora_linear.set_mapping(punica_wrapper) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results: List[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("orientation", ["row", "column"]) +@pytest.mark.parametrize("fully_shard", [True, False]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("stage", STAGES) +def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, + device, stage) -> None: + + torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16) + + def create_random_linear_parallel_layer(): + if orientation == "row": + linear = RowParallelLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard + else RowParallelLinearWithShardedLoRA(linear)) + else: + linear = ColumnParallelLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = (ColumnParallelLinearWithLoRA(linear) + if not fully_shard else + ColumnParallelLinearWithShardedLoRA(linear)) + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = create_random_linear_parallel_layer() + lora_linear.set_mapping(punica_wrapper) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results: List[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("repeats", [1, 2, 3]) +@pytest.mark.parametrize("fully_shard", [True, False]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("stage", STAGES) +def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, + device, stage) -> None: + + torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16) + + def create_column_parallel_packed_layer(): + if repeats == 2: + linear = MergedColumnParallelLinear(4096, [4096] * repeats, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = (MergedColumnParallelLinearWithLoRA(linear) + if not fully_shard else + MergedColumnParallelLinearWithShardedLoRA(linear)) + elif repeats == 3: + linear = QKVParallelLinear(4096, + 64, + 32, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = (MergedQKVParallelLinearWithLora(linear) + if not fully_shard else + MergedQKVParallelLinearWithShardedLora(linear)) + else: + linear = QKVParallelLinear(4096, + 64, + 32, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = QKVParallelLinearWithLora( + linear + ) if not fully_shard else QKVParallelLinearWithShardedLora(linear) + + @dataclass + class FakeConfig: + hidden_size = 4096 + num_key_value_heads = 32 + num_attention_heads = 32 + + lora_linear.create_lora_weights(max_loras, + lora_config, + model_config=FakeConfig()) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + + linear, lora_linear = create_column_parallel_packed_layer() + lora_linear.set_mapping(punica_wrapper) + lora_dict, sublora_dict = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + repeats=repeats, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results: List[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + result = linear(input_)[0] + subloras = sublora_dict[lora_id] + for i, sublora in enumerate(subloras): + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * + (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * + sublora.scaling) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + # lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 8]) +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0), + (6.0, 1.0)]) +@pytest.mark.parametrize("max_position", [11, 4096, 32768]) +@pytest.mark.parametrize("is_neox_style", [True, False]) +@pytest.mark.parametrize("rotary_dim", [None, 32]) +@pytest.mark.parametrize("head_size", [32, 108]) +@pytest.mark.parametrize("seq_len", [11, 1024]) +def test_rotary_embedding_long_context(dist_init, num_loras, device, + scaling_factors, max_position, + is_neox_style, rotary_dim, head_size, + seq_len) -> None: + dtype = torch.float16 + seed = 0 + current_platform.seed_everything(seed) + torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + long_lora_scaling_factors=scaling_factors, + lora_dtype=dtype) + + if rotary_dim is None: + rotary_dim = head_size + base = 10000 + batch_size = 5 * num_loras + num_heads = 7 + + # Verify lora is equivalent to linear scaling rotary embedding. + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + ) + lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) + lora_rope.set_mapping(punica_wrapper) + lora_rope.create_lora_weights(max_loras, lora_config) + linear_rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_style, { + "rope_type": "linear", + "factor": scaling_factors + }) + linear_rope = linear_rope.to(dtype=dtype) + id_to_index = get_random_id_to_index(num_loras, max_loras) + _, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=batch_size, + input_size=(1, max_position), + input_range=(0, lora_config.lora_extra_vocab_size), + input_type=torch.float16, + device=device) + + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + long_lora_context = LongContextLoRAContext(list(scaling_factors), + rotary_dim) + + next_expected_offset = 0 + # Make sure the offset is correct. + scaling_factor_to_offset = lora_rope.scaling_factor_to_offset + for scaling_factor, offset in scaling_factor_to_offset.items(): + assert offset == next_expected_offset + next_expected_offset += scaling_factor * max_position + + for i in range(len(scaling_factors)): + long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( + scaling_factors[i], 0) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + long_lora_context=long_lora_context, + ) + # lora_rope.set_mapping(*mapping_info) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + ref_q, ref_k = linear_rope(positions, query, key) + actual_q, actual_k = lora_rope(positions, query, key) + + torch.allclose(ref_q, actual_q) + torch.allclose(ref_k, actual_k) + + +@pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("seed", list(range(256))) +def test_vocab_parallel_embedding_indices(tp_size, seed): + random.seed(seed) + vocab_size = random.randint(4000, 64000) + added_vocab_size = random.randint(0, 1024) + org_vocab_size = vocab_size - added_vocab_size + last_org_vocab_end_index = 0 + last_added_vocab_end_index = org_vocab_size + computed_vocab_size = 0 + computed_org_vocab_size = 0 + computed_added_vocab_size = 0 + vocab_size_padded = -1 + + all_org_tokens: List[int] = [] + all_added_tokens: List[int] = [] + token_ids: List[int] = [] + + for tp_rank in range(tp_size): + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", + return_value=tp_rank + ), patch( + "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", + return_value=tp_size): + vocab_embedding = VocabParallelEmbedding( + vocab_size, 1, org_num_embeddings=org_vocab_size) + vocab_size_padded = vocab_embedding.num_embeddings_padded + shard_indices = vocab_embedding.shard_indices + # Assert that the ranges are contiguous + assert shard_indices.org_vocab_start_index == last_org_vocab_end_index + assert (shard_indices.added_vocab_start_index == + last_added_vocab_end_index) + + # Ensure that we are not exceeding the vocab size + computed_vocab_size += shard_indices.num_elements_padded + computed_org_vocab_size += shard_indices.num_org_elements + computed_added_vocab_size += shard_indices.num_added_elements + + # Ensure that the ranges are not overlapping + all_org_tokens.extend( + range(shard_indices.org_vocab_start_index, + shard_indices.org_vocab_end_index)) + all_added_tokens.extend( + range(shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index)) + + token_ids.extend( + range(shard_indices.org_vocab_start_index, + shard_indices.org_vocab_end_index)) + token_ids.extend([-1] * (shard_indices.num_org_elements_padded - + shard_indices.num_org_elements)) + token_ids.extend( + range(shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index)) + token_ids.extend([-1] * (shard_indices.num_added_elements_padded - + shard_indices.num_added_elements)) + + last_org_vocab_end_index = shard_indices.org_vocab_end_index + last_added_vocab_end_index = shard_indices.added_vocab_end_index + + assert computed_vocab_size == vocab_size_padded + assert computed_org_vocab_size == org_vocab_size + assert computed_added_vocab_size == added_vocab_size + + # Ensure that the ranges are not overlapping + assert len(all_org_tokens) == len(set(all_org_tokens)) + assert len(all_added_tokens) == len(set(all_added_tokens)) + assert not set(all_org_tokens).intersection(set(all_added_tokens)) + + token_ids_tensor = torch.tensor(token_ids, dtype=torch.long) + reindex_mapping = vocab_embedding.get_sharded_to_full_mapping() + assert reindex_mapping is not None or tp_size == 1 + if reindex_mapping is not None: + reindexed_token_ids = token_ids_tensor[reindex_mapping] + expected = torch.tensor(list(range(0, vocab_size))) + assert reindexed_token_ids[:vocab_size].equal(expected) + assert torch.all(reindexed_token_ids[vocab_size:] == -1) + + +def test_get_masked_input_and_mask(): + x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + + # base tp 1 case, no padding + modified_x, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=0) + assert torch.equal(x, modified_x) + + # tp 2 case, no padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=0) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=8, + added_vocab_start_index=10, + added_vocab_end_index=12, + num_org_vocab_padding=0) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])) + + # tp 4 case, no padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=0) + modified_x_rank_1, _ = get_masked_input_and_mask(x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=0) + modified_x_rank_2, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=6, + added_vocab_start_index=10, + added_vocab_end_index=11, + num_org_vocab_padding=0) + modified_x_rank_3, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=6, + org_vocab_end_index=8, + added_vocab_start_index=11, + added_vocab_end_index=12, + num_org_vocab_padding=0) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])) + assert torch.equal(modified_x_rank_2, + torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0])) + assert torch.equal(modified_x_rank_3, + torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])) + + # base tp 1 case, with padding + modified_x, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=2) + assert torch.equal(modified_x, + torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])) + + # tp 2 case, with padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=2) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=8, + added_vocab_start_index=10, + added_vocab_end_index=12, + num_org_vocab_padding=2) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])) + + # tp 4 case, with padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=2) + modified_x_rank_1, _ = get_masked_input_and_mask(x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=2) + modified_x_rank_2, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=6, + added_vocab_start_index=10, + added_vocab_end_index=11, + num_org_vocab_padding=2) + modified_x_rank_3, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=6, + org_vocab_end_index=8, + added_vocab_start_index=11, + added_vocab_end_index=12, + num_org_vocab_padding=2) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])) + assert torch.equal(modified_x_rank_2, + torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])) + assert torch.equal(modified_x_rank_3, + torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])) From 1c97a908191c3da4f8f51c45e6c5baa35766d9ca Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 21 Nov 2024 11:40:53 +0000 Subject: [PATCH 004/317] Removed extra print Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 7804780e015f..7f68dae9717c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1096,7 +1096,6 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits - print("punica", logits.dtype) # LogitsProcessorWithLoRA always using bgmv self.punica_wrapper.add_lora_logits(logits, hidden_states, self.lora_a_stacked, From e1cdb1d94774e749afe3fc843851dbe9e276728e Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 22 Nov 2024 12:27:02 +0000 Subject: [PATCH 005/317] Made some minor shape-based fixes to the kernels Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla/lora_ops.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/lora/ops/xla/lora_ops.py b/vllm/lora/ops/xla/lora_ops.py index 8ad32dd4a77b..cd12f3659f47 100644 --- a/vllm/lora/ops/xla/lora_ops.py +++ b/vllm/lora/ops/xla/lora_ops.py @@ -30,14 +30,18 @@ def bgmv_expand( lora_indices_tensor: torch.Tensor, add_inputs: bool = True ): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze() - inputs = inputs.to(dtype=torch.float16) + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + if add_inputs: - output_tensor[:] += outputs[:] + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] else: - output_tensor[:] = outputs[:] + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] def sgmv_shrink( inputs: torch.Tensor, @@ -68,10 +72,10 @@ def bgmv_shrink( lora_indices_tensor: torch.Tensor, scaling: float = 1.0 ): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - output_tensor[:] = scaling * outputs[:] + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] def sgmv_expand_slice( inputs: torch.Tensor, @@ -109,8 +113,8 @@ def bgmv_expand_slice( slice_size: int, add_inputs: bool = True ): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze() - inputs = inputs.to(dtype=torch.float16) + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) if add_inputs: From a59451ff0fc202db2a84903d81869123f7e52aa8 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 22 Nov 2024 15:11:44 +0000 Subject: [PATCH 006/317] Added basic lora execution code Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 86 ++++++++++++++++++++++++++++++--- vllm/worker/tpu_worker.py | 20 ++++++-- 2 files changed, 96 insertions(+), 10 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index ecdf7aa88896..aa376529d050 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import enum import time from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Set, Type, Union) from unittest.mock import patch @@ -17,8 +17,12 @@ from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_lora from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceGroupMetadata, SequenceOutput) @@ -62,6 +66,8 @@ class ModelInputForTPU(ModelRunnerInputBase): num_samples: int n: List[int] seq_groups: List[List[int]] + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None is_first_multi_step: bool = True is_last_step: bool = True virtual_engine: int = 0 @@ -72,6 +78,8 @@ def as_broadcastable_tensor_dict( tensor_dict = { "token_ids": self.token_ids, "position_ids": self.position_ids, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, "input_lens": self.input_lens, "t": self.t, "p": self.p, @@ -122,6 +130,9 @@ def __init__( False, ) self.cached_step_outputs: List[torch.Tensor] = [] + + # LoRA support + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None smem_size = 512 * 1024 block_table_size = 4 * self.block_tables.size @@ -154,16 +165,37 @@ def load_model(self) -> None: model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() - model = ModelWrapper(model) + self.model = model + + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + max_pos_embeddings = self.model.config.max_position_embeddings + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.model_config.get_vocab_size(), + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + self.model = ModelWrapper(self.model) self.model = torch.compile(model, - backend="openxla", - fullgraph=True, - dynamic=False) + backend="openxla", + fullgraph=True, + dynamic=False) def get_model(self) -> nn.Module: return self.model.model - def _dummy_run( + def _dummy_run( # KRAI-TODO: Add lora config here self, batch_size: int, seq_len: int, @@ -600,6 +632,15 @@ def execute_model( num_steps: int = 1, ) -> List[SamplerOutput]: assert intermediate_tensors is None + + print(f"\e[0;31m SELF LORA CONFIG {self.lora_config} \033[0m") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + if not model_input.is_first_multi_step: if not model_input.is_last_step: return [] @@ -765,7 +806,38 @@ def execute_model( sampler_output = _make_decode_output(next_token_ids, model_input.seq_groups) return [sampler_output] - + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + print("\e[0;31mSetting active loras\033[0m") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() class ModelWrapper(nn.Module): diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 12f10169f2db..a698040d98e4 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Set import torch import torch_xla.core.xla_model as xm @@ -12,18 +12,19 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerBase, + WorkerBase, WorkerInput) logger = init_logger(__name__) -class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): +class TPUWorker(LocalOrDistributedWorkerBase): def __init__( self, @@ -84,6 +85,7 @@ def init_device(self) -> None: # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and # 30-40 graphs for decode. 128 is an arbitrary safe number. torch._dynamo.config.cache_size_limit = 128 + torch._dynamo.config.reorderable_logging_functions = set([print]) # Use persistent cache to avoid XLA recompilation. # NOTE(woosuk): Set per-rank cache path since different ranks # can have slightly different XLA graphs. @@ -265,6 +267,18 @@ def execute_worker(self, worker_input: WorkerInput) -> None: if src_indices.numel() > 0: attn_backend.copy_blocks(self.tpu_cache, (src_indices, dst_indices)) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() def _make_src_to_dst( From 968ae739d6708e421257fdd4ab95ec75c74f5f5e Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 25 Nov 2024 17:00:04 +0000 Subject: [PATCH 007/317] Replaced einsums with matmuls+reshaping for better xla compilation Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla/lora_ops.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/xla/lora_ops.py b/vllm/lora/ops/xla/lora_ops.py index cd12f3659f47..51167ddf1b6b 100644 --- a/vllm/lora/ops/xla/lora_ops.py +++ b/vllm/lora/ops/xla/lora_ops.py @@ -32,8 +32,11 @@ def bgmv_expand( ): selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: limit = 1 @@ -73,7 +76,10 @@ def bgmv_shrink( scaling: float = 1.0 ): selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] @@ -115,7 +121,11 @@ def bgmv_expand_slice( ): selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + if add_inputs: output_tensor[:, slice_offset:slice_offset+slice_size] += outputs[:] From 8945217cb4e2060706bc595ededfbc817ae6cbc1 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 25 Nov 2024 17:02:26 +0000 Subject: [PATCH 008/317] Replaced inf/-inf with max/min since XLA doesn't allow `nan_to_num_()` to be called with infinities Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 7f68dae9717c..f34262d9f4b1 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1081,12 +1081,18 @@ def _get_logits( lora_logits[-1] = float("-inf") lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded + + # KRAI: Temporary change + neg_inf = torch.finfo(lora_logits.dtype).min + pos_inf = torch.finfo(lora_logits.dtype).max + lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), - posinf=float("inf"), - neginf=float("-inf"))) + ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, + posinf=pos_inf, + neginf=neg_inf)) + print(f"AKSHAT - After index select: {lora_logits.shape}, {indices_padded.shape}") # HPU needs special handling to prune out dummy samples. if current_platform.is_hpu(): From 29be82d95be54f77394ac39e04eb64770de5c53e Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 25 Nov 2024 17:03:37 +0000 Subject: [PATCH 009/317] Added lora config to `_dummy_run()` Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index aa376529d050..c31d7e896ad1 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -44,7 +44,7 @@ # FIXME(woosuk): A temporary hack to support `n > 1`. # This can significantly affect the performance if too large. _MAX_NUM_SAMPLES = 128 - +LORA_WARMUP_RANK = 8 # KRAI: TODO: Should this not be max rank - so we have better startup times? class ExecutionMode(enum.Enum): PREFILL = enum.auto() @@ -54,7 +54,6 @@ class ExecutionMode(enum.Enum): def is_prefill(self) -> bool: return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) - @dataclass(frozen=True) class ModelInputForTPU(ModelRunnerInputBase): token_ids: torch.Tensor @@ -282,6 +281,27 @@ def _dummy_run( # KRAI-TODO: Add lora config here t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 + + # Create a series of dummy loras and requests for them. Make to fill all lora slots. + if self.lora_config: + dummy_lora_requests: Set[LoRARequest] = set() + dummy_lora_mapping: LoRAMapping + + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for lora_id in range(1, self.lora_config.max_loras + 1): + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.add(dummy_lora_request) + dummy_lora_mapping = LoRAMapping( + [lora_id] * seq_len, [lora_id], is_prefill=exec_mode.is_prefill() + ) + self.set_active_loras(dummy_lora_requests, dummy_lora_mapping) # NOTE(woosuk): There are two stages of compilation: torch.compile and # XLA compilation. Using `mark_dynamic` can reduce the torch.compile From f992620f4fa5ee9bde7117cfc95dfa522bc61195 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 25 Nov 2024 17:04:04 +0000 Subject: [PATCH 010/317] Changed torch._dynamo config Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index c31d7e896ad1..7352fb5e4c43 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -310,10 +310,9 @@ def _dummy_run( # KRAI-TODO: Add lora config here # be re-compiled for every different shapes. This overhead is inevitable # in the first run, but can be skipped afterwards as we cache the XLA # graphs in the disk (VLLM_XLA_CACHE_PATH). + torch._dynamo.config.capture_dynamic_output_shape_ops = True if exec_mode.is_prefill(): - # Prefll - torch._dynamo.mark_dynamic(token_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) + # Prefill torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) else: # Decode From b3abfc369c63a3b5709bf2364ef6cc66f2d6b03e Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 25 Nov 2024 17:55:33 +0000 Subject: [PATCH 011/317] Quick patch to allow non lora code to run Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 7352fb5e4c43..efd2f79dc9f1 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -310,9 +310,13 @@ def _dummy_run( # KRAI-TODO: Add lora config here # be re-compiled for every different shapes. This overhead is inevitable # in the first run, but can be skipped afterwards as we cache the XLA # graphs in the disk (VLLM_XLA_CACHE_PATH). - torch._dynamo.config.capture_dynamic_output_shape_ops = True if exec_mode.is_prefill(): # Prefill + if self.lora_config is not None: + torch._dynamo.config.capture_dynamic_output_shape_ops = True + else: + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) else: # Decode From 00f1b4a0304f1b782efe6af8eee1193e1c784d37 Mon Sep 17 00:00:00 2001 From: Oleg Mosalov Date: Fri, 22 Nov 2024 15:14:17 +0000 Subject: [PATCH 012/317] Updated the test for loading a LoRA adapter, now it better shows when the adapter and its weights are loaded. Signed-off-by: Oleg Mosalov --- tests/lora/test_load_lora_adapter.py | 45 ++++++++++++++++++---------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/tests/lora/test_load_lora_adapter.py b/tests/lora/test_load_lora_adapter.py index d5958e989173..2c5088e7be6d 100644 --- a/tests/lora/test_load_lora_adapter.py +++ b/tests/lora/test_load_lora_adapter.py @@ -7,7 +7,15 @@ def extract_layer_names(llm): model_executor = getattr(engine, "model_executor") driver_worker = getattr(model_executor, "driver_worker") model_runner = getattr(driver_worker, "model_runner") - return [name for name, _ in model_runner.model.named_modules()] + list_adapters = list(model_runner.model.lora_manager.list_adapters().values()) + list_layers = [] + for adapter in list_adapters: + loras = adapter.loras + adapter_layers = [] + for k in loras: + adapter_layers.append(loras[k].module_name) + list_layers.append(adapter_layers) + return list_layers def load_base_model(base_model_path): print(f"Loading base model from {base_model_path}...") @@ -19,6 +27,9 @@ def load_lora_adapter(llm, lora_path): print(f"Loading LoRA adapter from {lora_path}...") lora_request = LoRARequest("lora_adapter", 1, lora_path) print("LoRA adapter loaded.") + return llm, lora_request + +def send_request(llm, lora_request): print("Sending a dummy request.") prompt = "Hi!" output = llm.generate(prompt, lora_request=lora_request) @@ -27,26 +38,21 @@ def load_lora_adapter(llm, lora_path): def compare_layers(base_layers, lora_layers): print("Comparing layers...") - base_set = set(base_layers) - lora_set = set(lora_layers) + print(f"There are {len(base_layers)} LoRA layers in the base model.") + print(f"There are {len(lora_layers)} LoRA layers in the LoRA adapter.") + + base_set = set(name for adapter in base_layers for name in adapter) + lora_set = set(name for adapter in lora_layers for name in adapter) added_layers = lora_set - base_set removed_layers = base_set - lora_set - #print("Base model layers:") - #for layer in base_set: - # print(f" {layer}") - if added_layers or removed_layers: print("Layer differences detected:") if added_layers: - print(" Layers added by LoRA:") - for layer in added_layers: - print(f" {layer}") + print(f" Added {len(added_layers)} LoRA layers.") if removed_layers: - print(" Layers removed after LoRA:") - for layer in removed_layers: - print(f" {layer}") + print(f" Removed {len(removed_layers)} LoRA layers.") return True else: print("No differences in layers detected.") @@ -64,10 +70,17 @@ def main(): base_model = load_base_model(base_model_path) base_layers = extract_layer_names(base_model) - model_with_lora = load_lora_adapter(base_model, lora_adapter_path) - lora_layers = extract_layer_names(model_with_lora) + model_with_lora, lora_request = load_lora_adapter(base_model, lora_adapter_path) + lora_layers_before_request = extract_layer_names(model_with_lora) + + model_with_lora_after_request = send_request(model_with_lora, lora_request) + lora_layers_after_request = extract_layer_names(model_with_lora_after_request) + + print("Compare the base model and the model with a loaded LoRA adapter...") + compare_layers(base_layers, lora_layers_before_request) - compare_layers(base_layers, lora_layers) + print("Compare the model with a loaded LoRA adapter before and after sending a request...") + compare_layers(lora_layers_before_request, lora_layers_after_request) if __name__ == "__main__": main() From ca39aec4c2310c4b751e0ad29fbb9e10b6b606ef Mon Sep 17 00:00:00 2001 From: Oleg Mosalov Date: Fri, 22 Nov 2024 15:21:40 +0000 Subject: [PATCH 013/317] Better wording. Signed-off-by: Oleg Mosalov --- tests/lora/test_load_lora_adapter.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/lora/test_load_lora_adapter.py b/tests/lora/test_load_lora_adapter.py index 2c5088e7be6d..fbd9093cdb67 100644 --- a/tests/lora/test_load_lora_adapter.py +++ b/tests/lora/test_load_lora_adapter.py @@ -36,16 +36,16 @@ def send_request(llm, lora_request): print("The request is sent.") return llm -def compare_layers(base_layers, lora_layers): +def compare_layers(first_model_layers, second_model_layers): print("Comparing layers...") - print(f"There are {len(base_layers)} LoRA layers in the base model.") - print(f"There are {len(lora_layers)} LoRA layers in the LoRA adapter.") + print(f"There are {len(first_model_layers)} LoRA adapters in the first model.") + print(f"There are {len(second_model_layers)} LoRA adapters in the second model.") - base_set = set(name for adapter in base_layers for name in adapter) - lora_set = set(name for adapter in lora_layers for name in adapter) + first_set = set(name for adapter in first_model_layers for name in adapter) + second_set = set(name for adapter in second_model_layers for name in adapter) - added_layers = lora_set - base_set - removed_layers = base_set - lora_set + added_layers = second_set - first_set + removed_layers = first_set - second_set if added_layers or removed_layers: print("Layer differences detected:") From 484602d294346651654392cbb7b2aa96cc8ca039 Mon Sep 17 00:00:00 2001 From: Oleg Mosalov Date: Mon, 16 Dec 2024 12:55:37 +0100 Subject: [PATCH 014/317] Added arg_parser to test_load_lora_adapter.py. Signed-off-by: Oleg Mosalov --- tests/lora/test_load_lora_adapter.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/lora/test_load_lora_adapter.py b/tests/lora/test_load_lora_adapter.py index fbd9093cdb67..3616d2c7e1f9 100644 --- a/tests/lora/test_load_lora_adapter.py +++ b/tests/lora/test_load_lora_adapter.py @@ -1,6 +1,7 @@ from vllm import LLM from vllm.lora.request import LoRARequest import os +import argparse def extract_layer_names(llm): engine = getattr(llm, "llm_engine") @@ -17,9 +18,9 @@ def extract_layer_names(llm): list_layers.append(adapter_layers) return list_layers -def load_base_model(base_model_path): +def load_base_model(base_model_path, enable_lora, max_model_len, max_num_seqs, max_loras): print(f"Loading base model from {base_model_path}...") - llm = LLM(model=base_model_path, enable_lora=True) + llm = LLM(model=base_model_path, enable_lora=enable_lora, max_model_len=max_model_len, max_num_seqs=max_num_seqs, max_loras=max_loras) print("Base model loaded.") return llm @@ -58,16 +59,14 @@ def compare_layers(first_model_layers, second_model_layers): print("No differences in layers detected.") return False -def main(): - base_model_path = "/data/llama-3/llama-3-8b" - lora_adapter_path = "/home/oleg/lora_test/Meta-Llama-3-8B-oasst-Adapter" +def main(base_model_path, lora_adapter_path, enable_lora, max_model_len, max_num_seqs, max_loras): if not os.path.exists(base_model_path): raise FileNotFoundError(f"Base model path not found: {base_model_path}") if not os.path.exists(lora_adapter_path): raise FileNotFoundError(f"LoRA adapter path not found: {lora_adapter_path}") - base_model = load_base_model(base_model_path) + base_model = load_base_model(base_model_path, enable_lora, max_model_len, max_num_seqs, max_loras) base_layers = extract_layer_names(base_model) model_with_lora, lora_request = load_lora_adapter(base_model, lora_adapter_path) @@ -83,5 +82,16 @@ def main(): compare_layers(lora_layers_before_request, lora_layers_after_request) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + + parser.add_argument('-m', '--base-model-path', dest='base_model_path', type=str, required=True, help="The path of the base model") + parser.add_argument('-l', '--lora-adapter-path', dest='lora_adapter_path', type=str, required=True, help="The path of the base model") + parser.add_argument('--enable-lora', dest='enable_lora', action='store_true', default=True) + parser.add_argument('--max-model-len', dest='max_model_len', type=int, default=2048) + parser.add_argument('--max-num-seqs', dest='max_num_seqs', type=int, default=16) + parser.add_argument('--max-loras', dest='max_loras', type=int, default=4) + + args = parser.parse_args() + + main(base_model_path=args.base_model_path, lora_adapter_path=args.lora_adapter_path, enable_lora=args.enable_lora, max_model_len=args.max_model_len, max_num_seqs=args.max_num_seqs, max_loras=args.max_loras) From 98d7c34fedef597deb576c9b6930e28a39ae40cd Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 17 Jan 2025 15:23:58 +0000 Subject: [PATCH 015/317] Minor fixes Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 346 +++++++++++++++++++++++++ vllm/platforms/tpu.py | 9 + 2 files changed, 355 insertions(+) create mode 100644 vllm/lora/punica_wrapper/punica_tpu.py diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py new file mode 100644 index 000000000000..ffac5b2c362e --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -0,0 +1,346 @@ +from typing import Callable, Optional, Tuple, Union + +import torch + +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) + +from .punica_base import PunicaWrapperBase + + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperTPU(PunicaWrapperBase): + """ + PunicaWrapperTPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index cdf835a52c0c..2032a77d8221 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -57,6 +57,15 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return not envs.VLLM_USE_V1 + @classmethod + def is_pin_memory_available(cls): + logger.warning("Pin memory is not supported on TPU.") + return False + + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" + @classmethod def inference_mode(cls): return torch.no_grad() From b39550539453406ce3f6267710bddf94e2ce81e0 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 22 Jan 2025 11:11:26 +0000 Subject: [PATCH 016/317] Replaced einsums with matmuls to allow xla compilation Signed-off-by: Akshat Tripathi --- vllm/lora/ops/torch_ops/lora_ops.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py index af79f98415cb..30240c5e0bc9 100644 --- a/vllm/lora/ops/torch_ops/lora_ops.py +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -30,7 +30,9 @@ def bgmv_expand(inputs: torch.Tensor, if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -71,7 +73,9 @@ def bgmv_shrink(inputs: torch.Tensor, if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] @@ -107,7 +111,8 @@ def bgmv_expand_slice(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) if add_inputs: output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] From b64f7005104c950e76bf60f0978898c54dfa01de Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 12:01:16 +0000 Subject: [PATCH 017/317] Removed xla ops for torch ops Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla/lora_ops.py | 133 ---------------------------------- 1 file changed, 133 deletions(-) delete mode 100644 vllm/lora/ops/xla/lora_ops.py diff --git a/vllm/lora/ops/xla/lora_ops.py b/vllm/lora/ops/xla/lora_ops.py deleted file mode 100644 index 51167ddf1b6b..000000000000 --- a/vllm/lora/ops/xla/lora_ops.py +++ /dev/null @@ -1,133 +0,0 @@ -import torch - -def sgmv_expand( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - add_inputs: bool = False -): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - - bgmv_expand( - inputs, - lora_b_weights, - output_tensor, - exploded_indices, - add_inputs - ) - - -def bgmv_expand( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True -): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) - inputs = inputs.to(dtype=output_tensor.dtype) - - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - limit = output_tensor.shape[0] - if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: - limit = 1 - - if add_inputs: - output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] - else: - output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] - -def sgmv_shrink( - inputs: torch.Tensor, - lora_a_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - scaling: float, -): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - - bgmv_shrink( - inputs, - lora_a_weights, - output_tensor, - exploded_indices, - scaling - ) - -def bgmv_shrink( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0 -): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) - - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] - -def sgmv_expand_slice( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - slice_offset: int, - slice_size: int, - add_inputs: bool = False -): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - - bgmv_expand_slice( - inputs, - lora_b_weights, - output_tensor, - exploded_indices, - slice_offset, - slice_size, - add_inputs - ) - - -def bgmv_expand_slice( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True -): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) - inputs = inputs.to(dtype=output_tensor.dtype) - - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - - if add_inputs: - output_tensor[:, slice_offset:slice_offset+slice_size] += outputs[:] - else: - output_tensor[:, slice_offset:slice_offset+slice_size] = outputs[:] \ No newline at end of file From 2812d20db1ed540522f5d6c7c327b25a1a65ab8c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 12:02:11 +0000 Subject: [PATCH 018/317] Removed old debug log points Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 1 - vllm/worker/tpu_model_runner.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index f34262d9f4b1..0b2da870291d 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1092,7 +1092,6 @@ def _get_logits( ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf)) - print(f"AKSHAT - After index select: {lora_logits.shape}, {indices_padded.shape}") # HPU needs special handling to prune out dummy samples. if current_platform.is_hpu(): diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index efd2f79dc9f1..25e45782eb99 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -656,8 +656,6 @@ def execute_model( ) -> List[SamplerOutput]: assert intermediate_tensors is None - print(f"\e[0;31m SELF LORA CONFIG {self.lora_config} \033[0m") - if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None @@ -839,7 +837,6 @@ def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - print("\e[0;31mSetting active loras\033[0m") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: From 1a08d2757297d5484c789107506b170c5657f48c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 12:02:35 +0000 Subject: [PATCH 019/317] Fixed bgmv/sgmv shape error Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index ffac5b2c362e..cd8349889ffd 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -270,7 +270,7 @@ def add_lora_linear(self, Args: y (torch.Tensor): Output tensor. Will be changed in-place. - x (torch.Tensor): Input tensor + x (torch.Tensor): Input tensor (B, S, E) lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. @@ -289,10 +289,11 @@ def add_lora_linear(self, r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default, consistent with the # triton op - buffer = tuple( - torch.zeros( - (x.size(0), r), dtype=torch.float32, device=x.device) - for _ in range(len(output_slices))) + buffer = torch.zeros( + (len(output_slices), x.size(1), r), + dtype=torch.float32, + device=x.device, + ) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) self.add_expand(y, buffer, From 6a75cab72b2cc1e4a7ff370bb2b28c06b90b24db Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 14:41:22 +0000 Subject: [PATCH 020/317] Fixed lora batching crash in warmup Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 25e45782eb99..fdb5e38e7935 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -298,9 +298,9 @@ def _dummy_run( # KRAI-TODO: Add lora config here self.lora_manager.add_dummy_lora(dummy_lora_request, rank=LORA_WARMUP_RANK) dummy_lora_requests.add(dummy_lora_request) - dummy_lora_mapping = LoRAMapping( - [lora_id] * seq_len, [lora_id], is_prefill=exec_mode.is_prefill() - ) + dummy_lora_mapping = LoRAMapping( + [lora_id] * batch_size * seq_len, [lora_id] * batch_size, is_prefill=exec_mode.is_prefill() + ) self.set_active_loras(dummy_lora_requests, dummy_lora_mapping) # NOTE(woosuk): There are two stages of compilation: torch.compile and @@ -384,7 +384,7 @@ def warmup_model( # Decode start = time.time() seq_len = 1 - batch_size = 8 # Must be in sync with _get_padded_batch_size() + batch_size = _get_padded_batch_size(1) while True: self._dummy_run(batch_size, seq_len, From 013a038ad76d1892d5257704cea62630a7a5e405 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 15:02:13 +0000 Subject: [PATCH 021/317] Fixed shape issue in add_lora_linear() Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index cd8349889ffd..b40da63a3517 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -289,8 +289,9 @@ def add_lora_linear(self, r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default, consistent with the # triton op + batch_size, seq_len, _ = x.shape buffer = torch.zeros( - (len(output_slices), x.size(1), r), + (len(output_slices), batch_size * seq_len, r), dtype=torch.float32, device=x.device, ) From 82595f55ccb95cf072fa883a70fb0704ca1cf61d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 15:02:59 +0000 Subject: [PATCH 022/317] Fixed dynamic lora tensor shapes Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index fdb5e38e7935..0c358115d3d6 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -320,13 +320,17 @@ def _dummy_run( # KRAI-TODO: Add lora config here torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) else: # Decode - torch._dynamo.mark_dynamic(token_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(input_lens, 0) + if self.lora_config is not None: + torch._dynamo.config.capture_dynamic_output_shape_ops = True + else: + pass + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(input_lens, 0) + torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. with set_forward_context(attn_metadata, self.vllm_config, 0): From 3fd5e48d7b9874b253a74e9a2553e2fb4154fdfd Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 16:15:23 +0000 Subject: [PATCH 023/317] Fixed lora_input preparation for actual execution Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 101 ++++++++++++++++++++++++++++---- 1 file changed, 89 insertions(+), 12 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 0c358115d3d6..807db4c45d35 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -65,8 +65,7 @@ class ModelInputForTPU(ModelRunnerInputBase): num_samples: int n: List[int] seq_groups: List[List[int]] - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None + lora_inputs: List[Tuple[Set[LoRARequest], LoRAMapping]] is_first_multi_step: bool = True is_last_step: bool = True virtual_engine: int = 0 @@ -77,8 +76,7 @@ def as_broadcastable_tensor_dict( tensor_dict = { "token_ids": self.token_ids, "position_ids": self.position_ids, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, + "lora_inputs": self.lora_inputs, "input_lens": self.input_lens, "t": self.t, "p": self.p, @@ -641,8 +639,81 @@ def prepare_model_input( list(metadata.seq_data.keys()) for metadata in seq_group_metadata_list ] - return ModelInputForTPU(input_tokens, input_positions, attn_metadata, - input_lens, t, p, num_samples, n, seq_groups) + + lora_inputs = [] + if self.load_config is not None: + lora_inputs = self._prepare_lora_input(seq_group_metadata_list, is_prompt, padded_batch_size) + + return ModelInputForTPU( + token_ids=input_tokens, + position_ids=input_positions, + attn_metadata=attn_metadata, + input_lens=input_lens, + t=t, + p=p, + num_samples=num_samples, + n=n, + seq_groups=seq_groups, + lora_inputs=lora_inputs + ) + + def _prepare_lora_input( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + is_prefill: bool, + padded_batch_size: int) -> List[Tuple[Set[LoRARequest], LoRAMapping]]: + """ + Prepares a list of LoRA inputs. If we're decoding then the list will only have 1 item, + otherwise there'll be an item for each sequence + """ + + lora_input = [] + if is_prefill: + for seq in seq_group_metadata_list: + lora_id = seq.lora_int_id + query_len = seq.token_chunk_size + padded_query_len = _get_padded_prefill_len(query_len) + + index_mapping = [lora_id] * padded_query_len + prompt_mapping = [lora_id] + + lora_request = set() + if seq.lora_request is not None: + lora_request.add(seq.lora_request) + + lora_input.append(( + lora_request, + LoRAMapping( + index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=True + ) + )) + else: + lora_request = set() + index_mapping = [] + prompt_mapping = [] + for seq in seq_group_metadata_list: + lora_id = seq.lora_int_id + + index_mapping += [lora_id] + prompt_mapping += [lora_id] + + if seq.lora_request is not None: + lora_request.add(seq.lora_request) + + index_mapping += [0] * (padded_batch_size - len(seq_group_metadata_list)) + prompt_mapping += [0] * (padded_batch_size - len(seq_group_metadata_list)) + + lora_input.append(( + lora_request, + LoRAMapping( + index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=False + ) + )) + + return lora_input def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: @@ -660,12 +731,6 @@ def execute_model( ) -> List[SamplerOutput]: assert intermediate_tensors is None - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - if not model_input.is_first_multi_step: if not model_input.is_last_step: return [] @@ -741,6 +806,12 @@ def execute_model( input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) + + if self.lora_config is not None: + assert len(model_input.lora_inputs) == batch_size + lora_requests, lora_mapping = model_input.lora_inputs[i] + self.set_active_loras(lora_requests, lora_mapping) + with set_forward_context(model_input.attn_metadata, self.vllm_config, model_input.virtual_engine): @@ -790,6 +861,12 @@ def execute_model( t = model_input.t.to(self.device) p = model_input.p.to(self.device) input_lens = model_input.input_lens.to(self.device) + + if self.lora_config is not None: + assert len(model_input.lora_inputs) == 1 + lora_requests, lora_mapping = model_input.lora_inputs[0] + self.set_active_loras(lora_requests, lora_mapping) + for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping with set_forward_context(model_input.attn_metadata, From 1d89e0d72b4b97f9ba73a92137d8fa1b70bac6db Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 24 Jan 2025 16:42:22 +0000 Subject: [PATCH 024/317] Fixed wrong model bug Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 807db4c45d35..db6117485dc6 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -184,7 +184,7 @@ def load_model(self) -> None: self.model = self.lora_manager.create_lora_manager(self.model) self.model = ModelWrapper(self.model) - self.model = torch.compile(model, + self.model = torch.compile(self.model, backend="openxla", fullgraph=True, dynamic=False) @@ -321,7 +321,6 @@ def _dummy_run( # KRAI-TODO: Add lora config here if self.lora_config is not None: torch._dynamo.config.capture_dynamic_output_shape_ops = True else: - pass torch._dynamo.mark_dynamic(token_ids, 0) torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(input_lens, 0) From 6acc4953d53451c7fcfad7bdfe87f473c41451d9 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 24 Jan 2025 16:51:49 +0000 Subject: [PATCH 025/317] Moved if statements outside of for loops in PunicaWrapperTPU Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 54 +++++++------------------- 1 file changed, 14 insertions(+), 40 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index b40da63a3517..b0a8149d5b7d 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -111,43 +111,6 @@ def _expand_slice_decode( bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) - def _apply_expand( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: int, - y_slice_size: int, - add_inputs: bool = True, - ): - """ - Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` - computation, which is suitable for the - GEMM of lora'b. - """ - - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) - - def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, scale: float): - """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the - GEMM of lora'a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) - shrink_fun(y, x, w_t_all, scale) - y = y.view_as(y_org) - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs): @@ -170,10 +133,19 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], """ x = x.view(-1, x.shape[-1]) + + shrink_fun: Callable = (self._shrink_prefill if self.is_prefill else self._shrink_decode) + # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) + y_s = y[slice_idx] + lora_s = lora_a_stacked[slice_idx] + + y_org = y_s + y_s = y_s.view(-1, y_s.shape[-1]) + + shrink_fun(y_s, x, lora_s, scale) + y_s = y_s.view_as(y_org) def add_expand(self, y: torch.Tensor, @@ -203,6 +175,8 @@ def add_expand(self, output_slices (Tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ + expand_slice_fun: Callable = (self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode) + y_org = y y = y.view(-1, y.shape[-1]) offset_left = offset_start @@ -210,7 +184,7 @@ def add_expand(self, self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - self._apply_expand( + expand_slice_fun( y, x[slice_idx], lora_b_stacked[slice_idx], From 6b14c458b4ff0c359a0bf87f5fa16a8f34223c26 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 28 Jan 2025 14:50:58 +0000 Subject: [PATCH 026/317] Added early exits to PunicaWrapperTPU lora functions Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index b0a8149d5b7d..64bd8fa16917 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -48,6 +48,8 @@ def _shrink_decode( w_t_all: torch.Tensor, scale: float, ): + if self.no_lora: + return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) def _expand_prefill( @@ -75,6 +77,8 @@ def _expand_decode( w_t_all: torch.Tensor, add_inputs: bool, ): + if self.no_lora: + return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) def _expand_slice_prefill( @@ -108,6 +112,8 @@ def _expand_slice_decode( y_slice_size: int, add_inputs: bool, ): + if self.no_lora: + return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) From 0ae5d851c7c4afea1ec322c554b0909c73382d48 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 30 Jan 2025 12:33:39 +0000 Subject: [PATCH 027/317] Added torch ops for tpu (Static prefill sizes) Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/__init__.py | 13 +++ vllm/lora/ops/xla_ops/lora_ops.py | 118 +++++++++++++++++++++++++ vllm/lora/punica_wrapper/punica_tpu.py | 2 +- 3 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 vllm/lora/ops/xla_ops/__init__.py create mode 100644 vllm/lora/ops/xla_ops/lora_ops.py diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py new file mode 100644 index 000000000000..4785af8520d3 --- /dev/null +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -0,0 +1,13 @@ +from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand # noqa: F401 +from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink", +] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py new file mode 100644 index 000000000000..5dc0c98bbb48 --- /dev/null +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -0,0 +1,118 @@ +import torch + + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + inputs.size(0)) + + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + if add_inputs: + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] + else: + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + inputs.size(0)) + + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + inputs.size(0)) + + bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, + slice_offset, slice_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + inputs = inputs.to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + + if add_inputs: + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 64bd8fa16917..b6739bd97bdb 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -2,7 +2,7 @@ import torch -from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, +from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) From d51c151b27b8bae50cad32f76fd79581c83528b9 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 30 Jan 2025 17:34:42 +0000 Subject: [PATCH 028/317] XLA bgmv operations are now imported from the default torch_ops Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 65 +------------------------------ 1 file changed, 1 insertion(+), 64 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 5dc0c98bbb48..d6c630880644 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,5 +1,5 @@ import torch - +from ..torch_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink def sgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -17,31 +17,6 @@ def sgmv_expand(inputs: torch.Tensor, bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs) - -def bgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) - inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - limit = output_tensor.shape[0] - if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: - limit = 1 - - if add_inputs: - output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] - else: - output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] - - def sgmv_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, @@ -61,23 +36,6 @@ def sgmv_shrink( scaling) -def bgmv_shrink(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) - inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] - - def sgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, @@ -95,24 +53,3 @@ def sgmv_expand_slice(inputs: torch.Tensor, bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, slice_offset, slice_size, add_inputs) - - -def bgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - inputs = inputs.to(dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - if add_inputs: - output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] - else: - output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] From be3dfecf926fe6519211cc13813b31cdc753d1bf Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 31 Jan 2025 16:15:22 +0000 Subject: [PATCH 029/317] Removed TODOs Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 10 +++++++--- vllm/worker/tpu_model_runner.py | 5 ++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 0b2da870291d..1971ebe6c238 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1082,9 +1082,13 @@ def _get_logits( lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded - # KRAI: Temporary change - neg_inf = torch.finfo(lora_logits.dtype).min - pos_inf = torch.finfo(lora_logits.dtype).max + if current_platform.is_tpu(): + # Because nan_to_num_ doesn't work with actual -inf values on TPU + neg_inf = torch.finfo(lora_logits.dtype).min + pos_inf = torch.finfo(lora_logits.dtype).max + else: + neg_inf = float("-inf") + pos_inf = float("inf") lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index db6117485dc6..85cc6f4dedae 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -44,7 +44,6 @@ # FIXME(woosuk): A temporary hack to support `n > 1`. # This can significantly affect the performance if too large. _MAX_NUM_SAMPLES = 128 -LORA_WARMUP_RANK = 8 # KRAI: TODO: Should this not be max rank - so we have better startup times? class ExecutionMode(enum.Enum): PREFILL = enum.auto() @@ -192,7 +191,7 @@ def load_model(self) -> None: def get_model(self) -> nn.Module: return self.model.model - def _dummy_run( # KRAI-TODO: Add lora config here + def _dummy_run( self, batch_size: int, seq_len: int, @@ -294,7 +293,7 @@ def _dummy_run( # KRAI-TODO: Add lora config here lora_path="/not/a/real/path", ) self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) + rank=self.lora_config.max_lora_rank) dummy_lora_requests.add(dummy_lora_request) dummy_lora_mapping = LoRAMapping( [lora_id] * batch_size * seq_len, [lora_id] * batch_size, is_prefill=exec_mode.is_prefill() From c5ce23912b58311a24c47236b367f06438b39795 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 31 Jan 2025 17:44:21 +0000 Subject: [PATCH 030/317] Removed old code Signed-off-by: Akshat Tripathi --- tests/lora/test_layers_tpu.py | 1220 -------------------------- tests/lora/test_load_lora_adapter.py | 97 -- 2 files changed, 1317 deletions(-) delete mode 100644 tests/lora/test_layers_tpu.py delete mode 100644 tests/lora/test_load_lora_adapter.py diff --git a/tests/lora/test_layers_tpu.py b/tests/lora/test_layers_tpu.py deleted file mode 100644 index 29f732c621af..000000000000 --- a/tests/lora/test_layers_tpu.py +++ /dev/null @@ -1,1220 +0,0 @@ -import random -from copy import deepcopy -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple -from unittest.mock import patch - -import pytest -import torch -import torch.nn.functional as F - -from vllm.config import LoRAConfig -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, - RowParallelLinearWithShardedLoRA) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - LinearScalingRotaryEmbeddingWithLora, - LogitsProcessorWithLoRA, LoRAMapping, - MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLora, - QKVParallelLinearWithLora, - ReplicatedLinearWithLoRA, - RowParallelLinearWithLoRA, - VocabParallelEmbeddingWithLoRA) -# yapf: enable -from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, - PackedLoRALayerWeights) -from vllm.lora.punica import PunicaWrapper -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) -from vllm.model_executor.utils import set_random_seed -from vllm.platforms import current_platform - -from .utils import DummyLoRAManager - -TOLERANCES = { - torch.float16: (5e-3, 5e-3), - torch.float32: (5e-3, 5e-3), - torch.bfloat16: (3e-2, 2e-2), -} -TPU_DEVICES = [ - f"xla:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - -# We will launch different triton kernels between the prefill and decode -# stages, so we need to verify this. prefill stage(True) or decode stage(False) -STAGES = [True, False] - - -def get_random_id_to_index(num_loras: int, - num_slots: int, - log: bool = True) -> List[Optional[int]]: - """Creates a random lora_id_to_index mapping. - - Args: - num_loras: The number of active loras in the mapping. - num_slots: The number of slots in the mapping. Must be larger - than num_loras. - log: Whether to log the output. - """ - - if num_loras > num_slots: - raise ValueError( - f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " - "num_loras must be less than or equal to num_slots.") - - slots: List[Optional[int]] = [None] * num_slots - random_slot_selections = (torch.randperm(num_slots, device="cpu")[:num_loras]).tolist() - for lora_id, slot_idx in enumerate(random_slot_selections, start=1): - slots[slot_idx] = lora_id - - if log: - print(f"Created lora_id_to_index mapping: {slots}.") - - return slots - - -def populate_loras( - id_to_index: List[Optional[int]], - layer: BaseLayerWithLoRA, - layer_weights: torch.Tensor, - generate_embeddings_tensor: int = 0, - repeats: int = 1, -) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: - """This method populates the lora layers with lora weights. - - Args: - id_to_index: a list of lora ids. The index of the lora id - represents which memory slot the lora matrices are - stored in. A None value indicates a free slot. - layer: the LoRAlayer to populate. - layer_weights: the PyTorch tensor containing the layer's - weights. - generate_embeddings_tensor: whether to generate an - embeddings tensor for each LoRA. - repeats: must only be set for column parallel packed - layers. Indicates the number of loras to compose - together to create a single lora layer. - """ - - # Dictionary that maps the lora ID to the - # corresponding lora weights. - lora_dict: Dict[int, LoRALayerWeights] = dict() - - # Dictionary that maps the lora ID to the - # corresponding subloras. - sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() - - for slot_idx, lora_id in enumerate(id_to_index): - if lora_id is not None: - subloras: List[LoRALayerWeights] = [] - sublora_len = layer_weights.shape[0] // repeats - for i in range(repeats): - sublora = DummyLoRAManager( - layer_weights.device).init_random_lora( - module_name=f"fake_{i}", - weight=layer_weights, - generate_embeddings_tensor=generate_embeddings_tensor, - ) - sublora.lora_b = sublora.lora_b[:, (sublora_len * - i):(sublora_len * (i + 1))] - sublora.optimize() - subloras.append(sublora) - - lora = PackedLoRALayerWeights.pack( - subloras) if repeats > 1 else subloras[0] - - layer.set_lora( - slot_idx, - lora_a=lora.lora_a, - lora_b=lora.lora_b, - embeddings_tensor=lora.embeddings_tensor, - ) - - lora_dict[lora_id] = lora - sublora_dict[lora_id] = subloras - - return lora_dict, sublora_dict - - -def create_random_inputs( - active_lora_ids: List[int], - num_inputs: int, - input_size: Tuple[int, ...], - input_range: Tuple[float, float], - input_type: torch.dtype = torch.int, - device: torch.device = "xla" -) -> Tuple[List[torch.Tensor], List[int], List[int]]: - """Creates random inputs. - - Args: - active_lora_ids: lora IDs of active lora weights. - num_inputs: the number of inputs to create. - input_size: the size of each individual input. - input_range: the range of values to include in the input. - input_range[0] <= possible input values < input_range[1] - input_type: the type of values in the input. - """ - - low, high = input_range - - inputs: List[torch.Tensor] = [] - index_mapping: List[int] = [] - prompt_mapping: List[int] = [] - - for _ in range(num_inputs): - if input_type == torch.int: - inputs.append( - torch.randint(low=int(low), - high=int(high), - size=input_size, - device=device)) - else: - inputs.append( - torch.rand(size=input_size, dtype=input_type, device=device) * - high + low) - - lora_id = random.choice(active_lora_ids) - index_mapping += [lora_id] * input_size[0] - prompt_mapping += [lora_id] - - return inputs, index_mapping, prompt_mapping - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) -@pytest.mark.parametrize("stage", STAGES) -def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: - torch.set_default_device(device) - max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) - - def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(vocab_size, 256) - embedding.weight.data = torch.rand_like(embedding.weight.data) - embedding.weight.data[vocab_size:, :] = 0 - lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) - lora_embedding.create_lora_weights(max_loras, lora_config) - - return embedding, lora_embedding - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - embedding, lora_embedding = create_random_embedding_layer() - lora_embedding.set_mapping(punica_wrapper) - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_embedding, - layer_weights=embedding.weight.T, - ) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=num_loras * 3, - input_size=(200, ), - input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - - lora_result = lora_embedding(torch.cat(inputs)) - - expected_results: List[torch.Tensor] = [] - for input_, lora_id in zip(inputs, prompt_mapping): - lora = lora_dict[lora_id] - result = embedding(input_) - after_a = F.embedding( - input_, - lora.lora_a, - ) - result += (after_a @ lora.lora_b) - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_embedding.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=num_loras * 3, - input_size=(200, ), - input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - - lora_result = lora_embedding(torch.cat(inputs)) - expected_result = embedding(torch.cat(inputs)) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -# @pytest.mark.skip( -# reason="Fails when loras are in any slot other than the first.") -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) -@pytest.mark.parametrize("stage", STAGES) -def test_embeddings_with_new_embeddings(dist_init, num_loras, device, - vocab_size, stage) -> None: - - torch.set_default_device(device) - max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) - - def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(vocab_size, 256) - embedding_data = torch.rand_like(embedding.weight.data) - embedding.weight.data = embedding_data - embedding.weight.data[vocab_size:, :] = 0 - expanded_embedding = VocabParallelEmbedding( - vocab_size + lora_config.lora_extra_vocab_size * max_loras, - 256, - org_num_embeddings=vocab_size) - expanded_embedding.weight.data[:vocab_size, :] = embedding_data - # We need to deepcopy the embedding as it will be modified - # in place - lora_embedding = VocabParallelEmbeddingWithLoRA( - deepcopy(expanded_embedding)) - lora_embedding.create_lora_weights(max_loras, lora_config) - - return expanded_embedding, lora_embedding - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - expanded_embedding, lora_embedding = create_random_embedding_layer() - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_embedding, - layer_weights=torch.zeros( - (256, vocab_size + lora_config.lora_extra_vocab_size)), - generate_embeddings_tensor=256, - ) - - lora_embedding.set_mapping(punica_wrapper) - # All embeddings tensors have the same shape. - embeddings_tensors = [ - lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) - ] - embeddings_tensor_len = embeddings_tensors[0].shape[0] - - # Add empty embeddings_tensors for unoccupied lora slots. - for _ in range(max_loras - len(embeddings_tensors)): - embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape)) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=num_loras * 3, - input_size=(200, ), - input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - original_inputs = deepcopy(inputs) - - # Force some of the inputs to be in the extended embeddings range - # to guarantee that their behavior is tested. - for input_, original_input_, lora_id in zip(inputs, original_inputs, - prompt_mapping): - embedding_id = lora_id - 1 - input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) - original_input_[-1] = vocab_size - input_[-2] = vocab_size + ( - (embedding_id + 1) * embeddings_tensor_len - 1) - original_input_[-2] = vocab_size + embeddings_tensor_len - 1 - - expanded_embedding.weight[vocab_size:vocab_size + - (embeddings_tensor_len * - max_loras)] = torch.cat(embeddings_tensors) - - lora_result = lora_embedding(torch.cat(original_inputs)) - - expected_results: List[torch.Tensor] = [] - for input_, original_input_, lora_id in zip(inputs, original_inputs, - prompt_mapping): - lora = lora_dict[lora_id] - result = expanded_embedding(input_) - after_a = F.embedding( - original_input_, - lora.lora_a, - ) - result += (after_a @ lora.lora_b) - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_embedding.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=num_loras * 3, - input_size=(200, ), - input_range=(1, vocab_size), - device=device) - original_inputs = deepcopy(inputs) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - lora_result = lora_embedding(torch.cat(original_inputs)) - expected_result = expanded_embedding(torch.cat(inputs)) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) -@pytest.mark.parametrize("stage", STAGES) -def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, - stage) -> None: - - torch.set_default_device(device) - max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) - - def _pretest(): - linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, - 1024, - vocab_size, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - linear.weight.data[:, vocab_size:] = 0 - logits_processor = LogitsProcessor( - vocab_size + lora_config.lora_extra_vocab_size, vocab_size) - lora_logits_processor = LogitsProcessorWithLoRA( - logits_processor, 1024, linear.weight.dtype, linear.weight.device, - None) - lora_logits_processor.create_lora_weights(max_loras, lora_config) - - return linear, logits_processor, lora_logits_processor - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - linear, logits_processor, lora_logits_processor = _pretest() - lora_logits_processor.set_mapping(punica_wrapper) - # NOTE: all the generated loras share the same embeddings tensor. - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_logits_processor, - layer_weights=linear.weight, - generate_embeddings_tensor=1024, - ) - embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor - embeddings_tensor_len = embeddings_tensor.shape[0] - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=8 * num_loras, # * 3, - input_size=(1, 1024), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - vocab_size, - lora_config.lora_extra_vocab_size, - ) - input_ = torch.rand(20, 1024, dtype=torch.float16) - - lora_result = lora_logits_processor._get_logits( - hidden_states=torch.cat(inputs), - lm_head=linear, - embedding_bias=None) - - original_lm_head = deepcopy(linear) - - linear.weight[logits_processor. - org_vocab_size:logits_processor.org_vocab_size + - embeddings_tensor_len] = embeddings_tensor - - logits_processor.org_vocab_size = (vocab_size + - lora_config.lora_extra_vocab_size) - expected_results: List[torch.Tensor] = [] - for input_, lora_id in zip(inputs, prompt_mapping): - lora = lora_dict[lora_id] - result = logits_processor._get_logits(hidden_states=input_, - lm_head=linear, - embedding_bias=None) - result[:, vocab_size + embeddings_tensor_len:] = float("-inf") - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling - expected_results.append(result) - expected_result = torch.cat(expected_results) - logits_processor.org_vocab_size = vocab_size - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_logits_processor.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=8 * num_loras * 3, - input_size=(1, 1024), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - vocab_size, - lora_config.lora_extra_vocab_size, - ) - - lora_result = lora_logits_processor._get_logits( - hidden_states=torch.cat(inputs), - lm_head=original_lm_head, - embedding_bias=None)[:, :vocab_size] - expected_result = logits_processor._get_logits( - hidden_states=torch.cat(inputs), - lm_head=original_lm_head, - embedding_bias=None) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("stage", STAGES) -def test_linear_replicated(dist_init, num_loras, device, stage) -> None: - - torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) - max_loras = 8 - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) - - def create_random_linear_replicated_layer(): - - linear = ReplicatedLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = ReplicatedLinearWithLoRA(linear) - - lora_linear.create_lora_weights(max_loras, lora_config) - - return linear, lora_linear - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - linear, lora_linear = create_random_linear_replicated_layer() - lora_linear.set_mapping(punica_wrapper) - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_linear, - layer_weights=linear.weight, - ) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) - - lora_result = lora_linear(torch.cat(inputs))[0] - - expected_results: List[torch.Tensor] = [] - for input_, lora_id in zip(inputs, prompt_mapping): - lora = lora_dict[lora_id] - result = linear(input_)[0] - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_linear.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) - - lora_result = lora_linear(torch.cat(inputs))[0] - expected_result = linear(torch.cat(inputs))[0] - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("orientation", ["row", "column"]) -@pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("stage", STAGES) -def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage) -> None: - - torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) - max_loras = 8 - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16) - - def create_random_linear_parallel_layer(): - if orientation == "row": - linear = RowParallelLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard - else RowParallelLinearWithShardedLoRA(linear)) - else: - linear = ColumnParallelLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (ColumnParallelLinearWithLoRA(linear) - if not fully_shard else - ColumnParallelLinearWithShardedLoRA(linear)) - lora_linear.create_lora_weights(max_loras, lora_config) - - return linear, lora_linear - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - linear, lora_linear = create_random_linear_parallel_layer() - lora_linear.set_mapping(punica_wrapper) - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_linear, - layer_weights=linear.weight, - ) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) - - lora_result = lora_linear(torch.cat(inputs))[0] - - expected_results: List[torch.Tensor] = [] - for input_, lora_id in zip(inputs, prompt_mapping): - lora = lora_dict[lora_id] - result = linear(input_)[0] - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_linear.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) - - lora_result = lora_linear(torch.cat(inputs))[0] - expected_result = linear(torch.cat(inputs))[0] - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("repeats", [1, 2, 3]) -@pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("stage", STAGES) -def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage) -> None: - - torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) - max_loras = 8 - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16) - - def create_column_parallel_packed_layer(): - if repeats == 2: - linear = MergedColumnParallelLinear(4096, [4096] * repeats, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedColumnParallelLinearWithLoRA(linear) - if not fully_shard else - MergedColumnParallelLinearWithShardedLoRA(linear)) - elif repeats == 3: - linear = QKVParallelLinear(4096, - 64, - 32, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedQKVParallelLinearWithLora(linear) - if not fully_shard else - MergedQKVParallelLinearWithShardedLora(linear)) - else: - linear = QKVParallelLinear(4096, - 64, - 32, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = QKVParallelLinearWithLora( - linear - ) if not fully_shard else QKVParallelLinearWithShardedLora(linear) - - @dataclass - class FakeConfig: - hidden_size = 4096 - num_key_value_heads = 32 - num_attention_heads = 32 - - lora_linear.create_lora_weights(max_loras, - lora_config, - model_config=FakeConfig()) - - return linear, lora_linear - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - - linear, lora_linear = create_column_parallel_packed_layer() - lora_linear.set_mapping(punica_wrapper) - lora_dict, sublora_dict = populate_loras( - id_to_index, - layer=lora_linear, - layer_weights=linear.weight, - repeats=repeats, - ) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) - - lora_result = lora_linear(torch.cat(inputs))[0] - - expected_results: List[torch.Tensor] = [] - for input_, lora_id in zip(inputs, prompt_mapping): - result = linear(input_)[0] - subloras = sublora_dict[lora_id] - for i, sublora in enumerate(subloras): - result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * - (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * - sublora.scaling) - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - for slot_idx in range(max_loras): - lora_linear.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) - # lora_linear.set_mapping(*mapping_info) - - lora_result = lora_linear(torch.cat(inputs))[0] - expected_result = linear(torch.cat(inputs))[0] - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 8]) -@pytest.mark.parametrize("device", ["cuda"]) -@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0), - (6.0, 1.0)]) -@pytest.mark.parametrize("max_position", [11, 4096, 32768]) -@pytest.mark.parametrize("is_neox_style", [True, False]) -@pytest.mark.parametrize("rotary_dim", [None, 32]) -@pytest.mark.parametrize("head_size", [32, 108]) -@pytest.mark.parametrize("seq_len", [11, 1024]) -def test_rotary_embedding_long_context(dist_init, num_loras, device, - scaling_factors, max_position, - is_neox_style, rotary_dim, head_size, - seq_len) -> None: - dtype = torch.float16 - seed = 0 - current_platform.seed_everything(seed) - torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) - max_loras = 8 - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - long_lora_scaling_factors=scaling_factors, - lora_dtype=dtype) - - if rotary_dim is None: - rotary_dim = head_size - base = 10000 - batch_size = 5 * num_loras - num_heads = 7 - - # Verify lora is equivalent to linear scaling rotary embedding. - rope = get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - ) - lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) - lora_rope.set_mapping(punica_wrapper) - lora_rope.create_lora_weights(max_loras, lora_config) - linear_rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_style, { - "rope_type": "linear", - "factor": scaling_factors - }) - linear_rope = linear_rope.to(dtype=dtype) - id_to_index = get_random_id_to_index(num_loras, max_loras) - _, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=batch_size, - input_size=(1, max_position), - input_range=(0, lora_config.lora_extra_vocab_size), - input_type=torch.float16, - device=device) - - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - long_lora_context = LongContextLoRAContext(list(scaling_factors), - rotary_dim) - - next_expected_offset = 0 - # Make sure the offset is correct. - scaling_factor_to_offset = lora_rope.scaling_factor_to_offset - for scaling_factor, offset in scaling_factor_to_offset.items(): - assert offset == next_expected_offset - next_expected_offset += scaling_factor * max_position - - for i in range(len(scaling_factors)): - long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( - scaling_factors[i], 0) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - long_lora_context=long_lora_context, - ) - # lora_rope.set_mapping(*mapping_info) - - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) - key = torch.randn_like(query) - ref_q, ref_k = linear_rope(positions, query, key) - actual_q, actual_k = lora_rope(positions, query, key) - - torch.allclose(ref_q, actual_q) - torch.allclose(ref_k, actual_k) - - -@pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) -@pytest.mark.parametrize("seed", list(range(256))) -def test_vocab_parallel_embedding_indices(tp_size, seed): - random.seed(seed) - vocab_size = random.randint(4000, 64000) - added_vocab_size = random.randint(0, 1024) - org_vocab_size = vocab_size - added_vocab_size - last_org_vocab_end_index = 0 - last_added_vocab_end_index = org_vocab_size - computed_vocab_size = 0 - computed_org_vocab_size = 0 - computed_added_vocab_size = 0 - vocab_size_padded = -1 - - all_org_tokens: List[int] = [] - all_added_tokens: List[int] = [] - token_ids: List[int] = [] - - for tp_rank in range(tp_size): - with patch( - "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", - return_value=tp_rank - ), patch( - "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", - return_value=tp_size): - vocab_embedding = VocabParallelEmbedding( - vocab_size, 1, org_num_embeddings=org_vocab_size) - vocab_size_padded = vocab_embedding.num_embeddings_padded - shard_indices = vocab_embedding.shard_indices - # Assert that the ranges are contiguous - assert shard_indices.org_vocab_start_index == last_org_vocab_end_index - assert (shard_indices.added_vocab_start_index == - last_added_vocab_end_index) - - # Ensure that we are not exceeding the vocab size - computed_vocab_size += shard_indices.num_elements_padded - computed_org_vocab_size += shard_indices.num_org_elements - computed_added_vocab_size += shard_indices.num_added_elements - - # Ensure that the ranges are not overlapping - all_org_tokens.extend( - range(shard_indices.org_vocab_start_index, - shard_indices.org_vocab_end_index)) - all_added_tokens.extend( - range(shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index)) - - token_ids.extend( - range(shard_indices.org_vocab_start_index, - shard_indices.org_vocab_end_index)) - token_ids.extend([-1] * (shard_indices.num_org_elements_padded - - shard_indices.num_org_elements)) - token_ids.extend( - range(shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index)) - token_ids.extend([-1] * (shard_indices.num_added_elements_padded - - shard_indices.num_added_elements)) - - last_org_vocab_end_index = shard_indices.org_vocab_end_index - last_added_vocab_end_index = shard_indices.added_vocab_end_index - - assert computed_vocab_size == vocab_size_padded - assert computed_org_vocab_size == org_vocab_size - assert computed_added_vocab_size == added_vocab_size - - # Ensure that the ranges are not overlapping - assert len(all_org_tokens) == len(set(all_org_tokens)) - assert len(all_added_tokens) == len(set(all_added_tokens)) - assert not set(all_org_tokens).intersection(set(all_added_tokens)) - - token_ids_tensor = torch.tensor(token_ids, dtype=torch.long) - reindex_mapping = vocab_embedding.get_sharded_to_full_mapping() - assert reindex_mapping is not None or tp_size == 1 - if reindex_mapping is not None: - reindexed_token_ids = token_ids_tensor[reindex_mapping] - expected = torch.tensor(list(range(0, vocab_size))) - assert reindexed_token_ids[:vocab_size].equal(expected) - assert torch.all(reindexed_token_ids[vocab_size:] == -1) - - -def test_get_masked_input_and_mask(): - x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) - - # base tp 1 case, no padding - modified_x, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(x, modified_x) - - # tp 2 case, no padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=0) - modified_x_rank_1, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=4, - org_vocab_end_index=8, - added_vocab_start_index=10, - added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])) - - # tp 4 case, no padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=0) - modified_x_rank_1, _ = get_masked_input_and_mask(x, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=0) - modified_x_rank_2, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=4, - org_vocab_end_index=6, - added_vocab_start_index=10, - added_vocab_end_index=11, - num_org_vocab_padding=0) - modified_x_rank_3, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=6, - org_vocab_end_index=8, - added_vocab_start_index=11, - added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])) - assert torch.equal(modified_x_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0])) - assert torch.equal(modified_x_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])) - - # base tp 1 case, with padding - modified_x, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x, - torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])) - - # tp 2 case, with padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=2) - modified_x_rank_1, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=4, - org_vocab_end_index=8, - added_vocab_start_index=10, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])) - - # tp 4 case, with padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=2) - modified_x_rank_1, _ = get_masked_input_and_mask(x, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=2) - modified_x_rank_2, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=4, - org_vocab_end_index=6, - added_vocab_start_index=10, - added_vocab_end_index=11, - num_org_vocab_padding=2) - modified_x_rank_3, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=6, - org_vocab_end_index=8, - added_vocab_start_index=11, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])) - assert torch.equal(modified_x_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])) - assert torch.equal(modified_x_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])) diff --git a/tests/lora/test_load_lora_adapter.py b/tests/lora/test_load_lora_adapter.py deleted file mode 100644 index 3616d2c7e1f9..000000000000 --- a/tests/lora/test_load_lora_adapter.py +++ /dev/null @@ -1,97 +0,0 @@ -from vllm import LLM -from vllm.lora.request import LoRARequest -import os -import argparse - -def extract_layer_names(llm): - engine = getattr(llm, "llm_engine") - model_executor = getattr(engine, "model_executor") - driver_worker = getattr(model_executor, "driver_worker") - model_runner = getattr(driver_worker, "model_runner") - list_adapters = list(model_runner.model.lora_manager.list_adapters().values()) - list_layers = [] - for adapter in list_adapters: - loras = adapter.loras - adapter_layers = [] - for k in loras: - adapter_layers.append(loras[k].module_name) - list_layers.append(adapter_layers) - return list_layers - -def load_base_model(base_model_path, enable_lora, max_model_len, max_num_seqs, max_loras): - print(f"Loading base model from {base_model_path}...") - llm = LLM(model=base_model_path, enable_lora=enable_lora, max_model_len=max_model_len, max_num_seqs=max_num_seqs, max_loras=max_loras) - print("Base model loaded.") - return llm - -def load_lora_adapter(llm, lora_path): - print(f"Loading LoRA adapter from {lora_path}...") - lora_request = LoRARequest("lora_adapter", 1, lora_path) - print("LoRA adapter loaded.") - return llm, lora_request - -def send_request(llm, lora_request): - print("Sending a dummy request.") - prompt = "Hi!" - output = llm.generate(prompt, lora_request=lora_request) - print("The request is sent.") - return llm - -def compare_layers(first_model_layers, second_model_layers): - print("Comparing layers...") - print(f"There are {len(first_model_layers)} LoRA adapters in the first model.") - print(f"There are {len(second_model_layers)} LoRA adapters in the second model.") - - first_set = set(name for adapter in first_model_layers for name in adapter) - second_set = set(name for adapter in second_model_layers for name in adapter) - - added_layers = second_set - first_set - removed_layers = first_set - second_set - - if added_layers or removed_layers: - print("Layer differences detected:") - if added_layers: - print(f" Added {len(added_layers)} LoRA layers.") - if removed_layers: - print(f" Removed {len(removed_layers)} LoRA layers.") - return True - else: - print("No differences in layers detected.") - return False - -def main(base_model_path, lora_adapter_path, enable_lora, max_model_len, max_num_seqs, max_loras): - - if not os.path.exists(base_model_path): - raise FileNotFoundError(f"Base model path not found: {base_model_path}") - if not os.path.exists(lora_adapter_path): - raise FileNotFoundError(f"LoRA adapter path not found: {lora_adapter_path}") - - base_model = load_base_model(base_model_path, enable_lora, max_model_len, max_num_seqs, max_loras) - base_layers = extract_layer_names(base_model) - - model_with_lora, lora_request = load_lora_adapter(base_model, lora_adapter_path) - lora_layers_before_request = extract_layer_names(model_with_lora) - - model_with_lora_after_request = send_request(model_with_lora, lora_request) - lora_layers_after_request = extract_layer_names(model_with_lora_after_request) - - print("Compare the base model and the model with a loaded LoRA adapter...") - compare_layers(base_layers, lora_layers_before_request) - - print("Compare the model with a loaded LoRA adapter before and after sending a request...") - compare_layers(lora_layers_before_request, lora_layers_after_request) - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument('-m', '--base-model-path', dest='base_model_path', type=str, required=True, help="The path of the base model") - parser.add_argument('-l', '--lora-adapter-path', dest='lora_adapter_path', type=str, required=True, help="The path of the base model") - parser.add_argument('--enable-lora', dest='enable_lora', action='store_true', default=True) - parser.add_argument('--max-model-len', dest='max_model_len', type=int, default=2048) - parser.add_argument('--max-num-seqs', dest='max_num_seqs', type=int, default=16) - parser.add_argument('--max-loras', dest='max_loras', type=int, default=4) - - args = parser.parse_args() - - main(base_model_path=args.base_model_path, lora_adapter_path=args.lora_adapter_path, enable_lora=args.enable_lora, max_model_len=args.max_model_len, max_num_seqs=args.max_num_seqs, max_loras=args.max_loras) - From 1c2246fb68f1890ca81330fc55e9d8de4c80a90d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 31 Jan 2025 17:58:01 +0000 Subject: [PATCH 031/317] Linting Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 4 +- vllm/lora/ops/torch_ops/lora_ops.py | 9 +- vllm/lora/ops/xla_ops/__init__.py | 4 +- vllm/lora/ops/xla_ops/lora_ops.py | 3 + vllm/lora/punica_wrapper/punica_tpu.py | 22 ++-- vllm/platforms/tpu.py | 2 +- vllm/worker/tpu_model_runner.py | 137 +++++++++++++------------ vllm/worker/tpu_worker.py | 9 +- 8 files changed, 99 insertions(+), 91 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 1971ebe6c238..5a1e3c60f43a 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1081,7 +1081,7 @@ def _get_logits( lora_logits[-1] = float("-inf") lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded - + if current_platform.is_tpu(): # Because nan_to_num_ doesn't work with actual -inf values on TPU neg_inf = torch.finfo(lora_logits.dtype).min @@ -1089,7 +1089,7 @@ def _get_logits( else: neg_inf = float("-inf") pos_inf = float("inf") - + lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py index 30240c5e0bc9..1a43f22215e2 100644 --- a/vllm/lora/ops/torch_ops/lora_ops.py +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -32,7 +32,8 @@ def bgmv_expand(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -75,7 +76,8 @@ def bgmv_shrink(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] @@ -112,7 +114,8 @@ def bgmv_expand_slice(inputs: torch.Tensor, if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) if add_inputs: output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 4785af8520d3..632a5d0274b0 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -1,7 +1,7 @@ from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand # noqa: F401 from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, - sgmv_shrink) + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) __all__ = [ "bgmv_expand", diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index d6c630880644..a52ac51b43c9 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,6 +1,8 @@ import torch + from ..torch_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink + def sgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, @@ -17,6 +19,7 @@ def sgmv_expand(inputs: torch.Tensor, bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs) + def sgmv_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index b6739bd97bdb..b831b4878b02 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -2,9 +2,8 @@ import torch -from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, sgmv_shrink) from .punica_base import PunicaWrapperBase @@ -139,17 +138,18 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], """ x = x.view(-1, x.shape[-1]) - - shrink_fun: Callable = (self._shrink_prefill if self.is_prefill else self._shrink_decode) - + + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): y_s = y[slice_idx] lora_s = lora_a_stacked[slice_idx] - + y_org = y_s y_s = y_s.view(-1, y_s.shape[-1]) - + shrink_fun(y_s, x, lora_s, scale) y_s = y_s.view_as(y_org) @@ -181,8 +181,10 @@ def add_expand(self, output_slices (Tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ - expand_slice_fun: Callable = (self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode) - + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + y_org = y y = y.view(-1, y.shape[-1]) offset_left = offset_start diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 2032a77d8221..9c5aba463e91 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -61,7 +61,7 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: def is_pin_memory_available(cls): logger.warning("Pin memory is not supported on TPU.") return False - + @classmethod def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 85cc6f4dedae..1051fa1b74c7 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -3,8 +3,8 @@ import enum import time from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Set, - Type, Union) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple, Type, Union) from unittest.mock import patch import numpy as np @@ -45,6 +45,7 @@ # This can significantly affect the performance if too large. _MAX_NUM_SAMPLES = 128 + class ExecutionMode(enum.Enum): PREFILL = enum.auto() DECODE = enum.auto() @@ -53,6 +54,7 @@ class ExecutionMode(enum.Enum): def is_prefill(self) -> bool: return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + @dataclass(frozen=True) class ModelInputForTPU(ModelRunnerInputBase): token_ids: torch.Tensor @@ -126,7 +128,7 @@ def __init__( False, ) self.cached_step_outputs: List[torch.Tensor] = [] - + # LoRA support self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None @@ -162,14 +164,14 @@ def load_model(self) -> None: model = model.eval() xm.wait_device_ops() self.model = model - + if self.lora_config: assert supports_lora( self.model ), f"{self.model.__class__.__name__} does not support LoRA yet." max_pos_embeddings = self.model.config.max_position_embeddings - + self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, @@ -181,12 +183,12 @@ def load_model(self) -> None: max_position_embeddings=max_pos_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) - + self.model = ModelWrapper(self.model) self.model = torch.compile(self.model, - backend="openxla", - fullgraph=True, - dynamic=False) + backend="openxla", + fullgraph=True, + dynamic=False) def get_model(self) -> nn.Module: return self.model.model @@ -278,12 +280,13 @@ def _dummy_run( t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 - - # Create a series of dummy loras and requests for them. Make to fill all lora slots. + + # Create a series of dummy loras and requests for them. + # Make to fill all lora slots. if self.lora_config: dummy_lora_requests: Set[LoRARequest] = set() dummy_lora_mapping: LoRAMapping - + assert self.lora_manager is not None with self.lora_manager.dummy_lora_cache(): for lora_id in range(1, self.lora_config.max_loras + 1): @@ -292,12 +295,13 @@ def _dummy_run( lora_int_id=lora_id, lora_path="/not/a/real/path", ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=self.lora_config.max_lora_rank) + self.lora_manager.add_dummy_lora( + dummy_lora_request, + rank=self.lora_config.max_lora_rank) dummy_lora_requests.add(dummy_lora_request) dummy_lora_mapping = LoRAMapping( - [lora_id] * batch_size * seq_len, [lora_id] * batch_size, is_prefill=exec_mode.is_prefill() - ) + [lora_id] * batch_size * seq_len, [lora_id] * batch_size, + is_prefill=exec_mode.is_prefill()) self.set_active_loras(dummy_lora_requests, dummy_lora_mapping) # NOTE(woosuk): There are two stages of compilation: torch.compile and @@ -637,55 +641,52 @@ def prepare_model_input( list(metadata.seq_data.keys()) for metadata in seq_group_metadata_list ] - + lora_inputs = [] if self.load_config is not None: - lora_inputs = self._prepare_lora_input(seq_group_metadata_list, is_prompt, padded_batch_size) - - return ModelInputForTPU( - token_ids=input_tokens, - position_ids=input_positions, - attn_metadata=attn_metadata, - input_lens=input_lens, - t=t, - p=p, - num_samples=num_samples, - n=n, - seq_groups=seq_groups, - lora_inputs=lora_inputs - ) - + lora_inputs = self._prepare_lora_input(seq_group_metadata_list, + is_prompt, + padded_batch_size) + + return ModelInputForTPU(token_ids=input_tokens, + position_ids=input_positions, + attn_metadata=attn_metadata, + input_lens=input_lens, + t=t, + p=p, + num_samples=num_samples, + n=n, + seq_groups=seq_groups, + lora_inputs=lora_inputs) + def _prepare_lora_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - is_prefill: bool, - padded_batch_size: int) -> List[Tuple[Set[LoRARequest], LoRAMapping]]: + is_prefill: bool, padded_batch_size: int + ) -> List[Tuple[Set[LoRARequest], LoRAMapping]]: """ - Prepares a list of LoRA inputs. If we're decoding then the list will only have 1 item, - otherwise there'll be an item for each sequence + Prepares a list of LoRA inputs. If we're decoding then the list will + only have 1 item, otherwise there'll be an item for each sequence """ - + lora_input = [] if is_prefill: for seq in seq_group_metadata_list: lora_id = seq.lora_int_id query_len = seq.token_chunk_size padded_query_len = _get_padded_prefill_len(query_len) - + index_mapping = [lora_id] * padded_query_len prompt_mapping = [lora_id] - + lora_request = set() if seq.lora_request is not None: lora_request.add(seq.lora_request) - - lora_input.append(( - lora_request, - LoRAMapping( - index_mapping=tuple(index_mapping), - prompt_mapping=tuple(prompt_mapping), - is_prefill=True - ) - )) + + lora_input.append( + (lora_request, + LoRAMapping(index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=True))) else: lora_request = set() index_mapping = [] @@ -695,22 +696,21 @@ def _prepare_lora_input( index_mapping += [lora_id] prompt_mapping += [lora_id] - + if seq.lora_request is not None: lora_request.add(seq.lora_request) - - index_mapping += [0] * (padded_batch_size - len(seq_group_metadata_list)) - prompt_mapping += [0] * (padded_batch_size - len(seq_group_metadata_list)) - - lora_input.append(( - lora_request, - LoRAMapping( - index_mapping=tuple(index_mapping), - prompt_mapping=tuple(prompt_mapping), - is_prefill=False - ) - )) - + + index_mapping += [0] * (padded_batch_size - + len(seq_group_metadata_list)) + prompt_mapping += [0] * (padded_batch_size - + len(seq_group_metadata_list)) + + lora_input.append( + (lora_request, + LoRAMapping(index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=False))) + return lora_input def make_model_input_from_broadcasted_tensor_dict( @@ -728,7 +728,7 @@ def execute_model( num_steps: int = 1, ) -> List[SamplerOutput]: assert intermediate_tensors is None - + if not model_input.is_first_multi_step: if not model_input.is_last_step: return [] @@ -804,12 +804,12 @@ def execute_model( input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) - + if self.lora_config is not None: assert len(model_input.lora_inputs) == batch_size lora_requests, lora_mapping = model_input.lora_inputs[i] self.set_active_loras(lora_requests, lora_mapping) - + with set_forward_context(model_input.attn_metadata, self.vllm_config, model_input.virtual_engine): @@ -859,12 +859,12 @@ def execute_model( t = model_input.t.to(self.device) p = model_input.p.to(self.device) input_lens = model_input.input_lens.to(self.device) - + if self.lora_config is not None: assert len(model_input.lora_inputs) == 1 lora_requests, lora_mapping = model_input.lora_inputs[0] self.set_active_loras(lora_requests, lora_mapping) - + for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping with set_forward_context(model_input.attn_metadata, @@ -906,7 +906,7 @@ def execute_model( sampler_output = _make_decode_output(next_token_ids, model_input.seq_groups) return [sampler_output] - + def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") @@ -938,6 +938,7 @@ def list_loras(self) -> Set[int]: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.list_adapters() + class ModelWrapper(nn.Module): def __init__(self, model: nn.Module): diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index a698040d98e4..8e892bfe232c 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import List, Optional, Tuple, Union, Set +from typing import List, Optional, Set, Tuple, Union import torch import torch_xla.core.xla_model as xm @@ -17,8 +17,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - WorkerBase, +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) logger = init_logger(__name__) @@ -267,10 +266,10 @@ def execute_worker(self, worker_input: WorkerInput) -> None: if src_indices.numel() > 0: attn_backend.copy_blocks(self.tpu_cache, (src_indices, dst_indices)) - + def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) - + def remove_lora(self, lora_id: int) -> bool: return self.model_runner.remove_lora(lora_id) From abab4df806f3087edeb56e03d0471595a0683bc7 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 3 Feb 2025 17:51:52 +0000 Subject: [PATCH 032/317] Fixed import error Signed-off-by: Akshat Tripathi --- vllm/lora/ops/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 vllm/lora/ops/__init__.py diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From b2b3dadc13a7ffac494628bdc684cc06750ca5d3 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 4 Feb 2025 14:15:54 +0000 Subject: [PATCH 033/317] lint Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/__init__.py | 2 ++ vllm/lora/ops/xla_ops/lora_ops.py | 3 +++ vllm/lora/punica_wrapper/punica_tpu.py | 4 +++- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 632a5d0274b0..67ffde460755 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand # noqa: F401 from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index a52ac51b43c9..b664b93fbf6f 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + import torch from ..torch_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink @@ -35,6 +37,7 @@ def sgmv_shrink( exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) + print("SGMV", lora_indices_tensor, lora_a_weights) bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index b831b4878b02..84245e82eb8a 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + from typing import Callable, Optional, Tuple, Union import torch @@ -222,7 +224,7 @@ def add_lora_embedding(self, add_inputs (bool): Default to True. """ - # Embedding layer only need expand op + # Embedding layer only needs the expand op expand_fun: Callable = (self._expand_prefill if self.is_prefill else self._expand_decode) expand_fun(y, x, lora_b_stacked, add_inputs) From 62b7f4b0227d8af6544c7e5b0077b2c3143b1212 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 3 Mar 2025 18:04:32 +0000 Subject: [PATCH 034/317] Abstracted out infinity values Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 14 +++++--------- vllm/platforms/interface.py | 7 +++++++ vllm/platforms/tpu.py | 6 +++++- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 5a1e3c60f43a..1e7d3e410cef 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1078,18 +1078,14 @@ def _get_logits( torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1]) - lora_logits[-1] = float("-inf") + + neg_inf, pos_inf = current_platform.get_infinity_values( + lora_logits.dtype) + + lora_logits[-1] = neg_inf lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded - if current_platform.is_tpu(): - # Because nan_to_num_ doesn't work with actual -inf values on TPU - neg_inf = torch.finfo(lora_logits.dtype).min - pos_inf = torch.finfo(lora_logits.dtype).max - else: - neg_inf = float("-inf") - pos_inf = float("inf") - lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d6dae2e526dc..f4497b14a5e7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -322,6 +322,13 @@ def get_punica_wrapper(cls) -> str: """ raise NotImplementedError + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + """ + Return the platform specific values for (-inf, inf) + """ + return float("-inf"), float("inf") + @classmethod def get_device_communicator_cls(cls) -> str: """ diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 9c5aba463e91..c9b18b000031 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Tuple import torch @@ -66,6 +66,10 @@ def is_pin_memory_available(cls): def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + return torch.finfo(dtype).min, torch.finfo(dtype).max + @classmethod def inference_mode(cls): return torch.no_grad() From 9e95c6636dd43ec1decf1aeedae76b103e0c3720 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Feb 2025 18:48:08 +0000 Subject: [PATCH 035/317] Moved and modified bgmv ops from the cpu backend to the tpu backend, because xla doesn't allow partial updates Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 90 ++++++++++++++++++++++++-- vllm/lora/punica_wrapper/punica_tpu.py | 9 ++- 2 files changed, 92 insertions(+), 7 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index b664b93fbf6f..308d361fe7eb 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -2,8 +2,6 @@ import torch -from ..torch_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink - def sgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -22,6 +20,37 @@ def sgmv_expand(inputs: torch.Tensor, add_inputs) +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + outputs = torch.cat( + (outputs, + torch.zeros((batch_size, output_tensor.shape[1] - outputs.shape[1]), + device=outputs.device)), + dim=1) + + if add_inputs: + output_tensor += outputs[:limit, :] + else: + output_tensor = outputs[:limit, :] + + def sgmv_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, @@ -37,11 +66,28 @@ def sgmv_shrink( exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) - print("SGMV", lora_indices_tensor, lora_a_weights) bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling) +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + output_tensor = scaling * outputs[:] + + def sgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, @@ -53,9 +99,45 @@ def sgmv_expand_slice(inputs: torch.Tensor, token_nums: int, slice_offset: int, slice_size: int, + total_size: int, add_inputs: bool = False): exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, - slice_offset, slice_size, add_inputs) + slice_offset, slice_size, total_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + total_size: int, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + + inputs = inputs.to(dtype=output_tensor.dtype) + + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + + batch_size, output_size, input_size = selected_loras.shape + + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + outputs = torch.cat(( + torch.zeros((batch_size, slice_offset), device=outputs.device), + outputs, + torch.zeros((batch_size, total_size - (slice_offset + slice_size)), + device=outputs.device), + ), + dim=1) + + if add_inputs: + output_tensor += outputs + else: + output_tensor = outputs diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 84245e82eb8a..920aacfbf8e8 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -89,6 +89,7 @@ def _expand_slice_prefill( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + y_total_size: int, add_inputs: bool, ): #No LoRA request, so return directly @@ -101,6 +102,7 @@ def _expand_slice_prefill( *self.prefill_metadata, y_offset, y_slice_size, + y_total_size, add_inputs, ) @@ -111,12 +113,13 @@ def _expand_slice_decode( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + y_total_size: int, add_inputs: bool, ): if self.no_lora: return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_inputs) + y_slice_size, y_total_size, add_inputs) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], @@ -161,7 +164,6 @@ def add_expand(self, lora_b_stacked: Tuple[torch.Tensor, ...], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], output_slices: Tuple[int, ...], - offset_start: int = 0, add_inputs=True, **kwargs) -> None: """ @@ -189,7 +191,7 @@ def add_expand(self, y_org = y y = y.view(-1, y.shape[-1]) - offset_left = offset_start + offset_left = 0 if lora_bias_stacked is not None: self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) @@ -200,6 +202,7 @@ def add_expand(self, lora_b_stacked[slice_idx], offset_left, output_slices[slice_idx], + y_total_size=sum(output_slices), add_inputs=add_inputs, ) offset_left += output_slices[slice_idx] From a76f6bdce78c89569dc7e24cf4ec8738932a9663 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Feb 2025 18:55:49 +0000 Subject: [PATCH 036/317] Removed total_size for linting Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 11 +++++------ vllm/lora/punica_wrapper/punica_tpu.py | 6 +----- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 308d361fe7eb..e494a2fed52d 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -85,7 +85,7 @@ def bgmv_shrink(inputs: torch.Tensor, outputs = (selected_loras @ inputs.reshape( (batch_size, input_size, 1))).reshape((batch_size, output_size)) - output_tensor = scaling * outputs[:] + output_tensor = scaling * outputs def sgmv_expand_slice(inputs: torch.Tensor, @@ -99,13 +99,12 @@ def sgmv_expand_slice(inputs: torch.Tensor, token_nums: int, slice_offset: int, slice_size: int, - total_size: int, add_inputs: bool = False): exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, - slice_offset, slice_size, total_size, add_inputs) + slice_offset, slice_size, add_inputs) def bgmv_expand_slice(inputs: torch.Tensor, @@ -114,7 +113,6 @@ def bgmv_expand_slice(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, slice_offset: int, slice_size: int, - total_size: int, add_inputs: bool = True): selected_loras = lora_b_weights[lora_indices_tensor].to( dtype=output_tensor.dtype) @@ -132,8 +130,9 @@ def bgmv_expand_slice(inputs: torch.Tensor, outputs = torch.cat(( torch.zeros((batch_size, slice_offset), device=outputs.device), outputs, - torch.zeros((batch_size, total_size - (slice_offset + slice_size)), - device=outputs.device), + torch.zeros( + (batch_size, output_tensor.shape[1] - (slice_offset + slice_size)), + device=outputs.device), ), dim=1) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 920aacfbf8e8..4b5642033ff7 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -89,7 +89,6 @@ def _expand_slice_prefill( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - y_total_size: int, add_inputs: bool, ): #No LoRA request, so return directly @@ -102,7 +101,6 @@ def _expand_slice_prefill( *self.prefill_metadata, y_offset, y_slice_size, - y_total_size, add_inputs, ) @@ -113,13 +111,12 @@ def _expand_slice_decode( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - y_total_size: int, add_inputs: bool, ): if self.no_lora: return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, y_total_size, add_inputs) + y_slice_size, add_inputs) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], @@ -202,7 +199,6 @@ def add_expand(self, lora_b_stacked[slice_idx], offset_left, output_slices[slice_idx], - y_total_size=sum(output_slices), add_inputs=add_inputs, ) offset_left += output_slices[slice_idx] From d3e5ce0004f24da5e8e4f98eae6239ba5bfdc415 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Feb 2025 19:04:58 +0000 Subject: [PATCH 037/317] Reverted changes to torch_ops Signed-off-by: Akshat Tripathi --- vllm/lora/ops/torch_ops/lora_ops.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py index 1a43f22215e2..af79f98415cb 100644 --- a/vllm/lora/ops/torch_ops/lora_ops.py +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -30,10 +30,7 @@ def bgmv_expand(inputs: torch.Tensor, if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -74,10 +71,7 @@ def bgmv_shrink(inputs: torch.Tensor, if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] @@ -113,9 +107,7 @@ def bgmv_expand_slice(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) if add_inputs: output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] From 355c621634f5c4d1522ba8bce7d677d8b29dcf6f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Feb 2025 19:11:20 +0000 Subject: [PATCH 038/317] Lint Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 4b5642033ff7..3b7a6dad035d 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -161,6 +161,7 @@ def add_expand(self, lora_b_stacked: Tuple[torch.Tensor, ...], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], output_slices: Tuple[int, ...], + offset_start: int = 0, add_inputs=True, **kwargs) -> None: """ From 64902b300e0c291faa132268ef86dff74c8515ec Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 3 Mar 2025 18:04:49 +0000 Subject: [PATCH 039/317] Replaced in-place buffer updates with direct returning Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 32 +++++++---- vllm/lora/ops/xla_ops/lora_ops.py | 25 +++++---- vllm/lora/punica_wrapper/punica_base.py | 10 ++-- vllm/lora/punica_wrapper/punica_tpu.py | 74 +++++++++++++------------ vllm/platforms/interface.py | 5 ++ vllm/platforms/tpu.py | 4 ++ 6 files changed, 88 insertions(+), 62 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 1e7d3e410cef..5e700d2e10d2 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -258,10 +258,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings.shape[1], -1, ) - self.punica_wrapper.add_lora_embedding(full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) + + lora_output = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + if not current_platform.can_update_inplace(): + full_output = lora_output + return full_output.view_as(full_output_org) @classmethod @@ -395,10 +400,12 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, - self.lora_b_stacked, - self.lora_bias_stacked, 1.0, - self.output_slices) + lora_output = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) + if not current_platform.can_update_inplace(): + output = lora_output + return output @@ -1102,9 +1109,12 @@ def _get_logits( lora_logits.shape[1]] = lora_logits # LogitsProcessorWithLoRA always using bgmv - self.punica_wrapper.add_lora_logits(logits, hidden_states, - self.lora_a_stacked, - self.lora_b_stacked, 1.0) + lora_output = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, + 1.0) + + if not current_platform.can_update_inplace(): + logits = lora_output # Remove paddings in vocab (if any). logits = logits[:, :self.base_layer.vocab_size] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index e494a2fed52d..7ac7d16fbf88 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -16,8 +16,8 @@ def sgmv_expand(inputs: torch.Tensor, exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) - bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, - add_inputs) + return bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) def bgmv_expand(inputs: torch.Tensor, @@ -46,9 +46,9 @@ def bgmv_expand(inputs: torch.Tensor, dim=1) if add_inputs: - output_tensor += outputs[:limit, :] + return output_tensor + outputs[:limit, :] else: - output_tensor = outputs[:limit, :] + return outputs[:limit, :] def sgmv_shrink( @@ -66,8 +66,8 @@ def sgmv_shrink( exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) - bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, - scaling) + return bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) def bgmv_shrink(inputs: torch.Tensor, @@ -75,6 +75,7 @@ def bgmv_shrink(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: @@ -85,7 +86,7 @@ def bgmv_shrink(inputs: torch.Tensor, outputs = (selected_loras @ inputs.reshape( (batch_size, input_size, 1))).reshape((batch_size, output_size)) - output_tensor = scaling * outputs + return scaling * outputs def sgmv_expand_slice(inputs: torch.Tensor, @@ -103,8 +104,9 @@ def sgmv_expand_slice(inputs: torch.Tensor, exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) - bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, - slice_offset, slice_size, add_inputs) + return bgmv_expand_slice(inputs, lora_b_weights, output_tensor, + exploded_indices, slice_offset, slice_size, + add_inputs) def bgmv_expand_slice(inputs: torch.Tensor, @@ -114,6 +116,7 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( dtype=output_tensor.dtype) @@ -137,6 +140,6 @@ def bgmv_expand_slice(inputs: torch.Tensor, dim=1) if add_inputs: - output_tensor += outputs + return output_tensor + outputs else: - output_tensor = outputs + return outputs diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index dad98f8e2122..d160b2739bc7 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -48,7 +48,7 @@ def add_shrink( lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. """ @@ -66,7 +66,7 @@ def add_expand( offset_start: int = 0, add_inputs=True, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Performs GEMM and bias addition for multiple slices of lora_b. """ @@ -80,7 +80,7 @@ def add_lora_embedding( lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA, and this layer only requires the expand operation. @@ -98,7 +98,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. """ @@ -114,7 +114,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. """ diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 3b7a6dad035d..602ec824853b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -34,7 +34,7 @@ def _shrink_prefill( #No LoRA request, so return directly if self.no_lora: return - sgmv_shrink( + return sgmv_shrink( x, w_t_all, y, @@ -51,7 +51,7 @@ def _shrink_decode( ): if self.no_lora: return - bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) def _expand_prefill( self, @@ -63,7 +63,7 @@ def _expand_prefill( #No LoRA request, so return directly if self.no_lora: return - sgmv_expand( + return sgmv_expand( x, w_t_all, y, @@ -80,7 +80,7 @@ def _expand_decode( ): if self.no_lora: return - bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) def _expand_slice_prefill( self, @@ -90,11 +90,11 @@ def _expand_slice_prefill( y_offset: int, y_slice_size: int, add_inputs: bool, - ): + ) -> torch.Tensor: #No LoRA request, so return directly if self.no_lora: return - sgmv_expand_slice( + return sgmv_expand_slice( x, w_t_all, y, @@ -112,15 +112,15 @@ def _expand_slice_decode( y_offset: int, y_slice_size: int, add_inputs: bool, - ): + ) -> torch.Tensor: if self.no_lora: return - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_inputs) + return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, + y_offset, y_slice_size, add_inputs) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, **kwargs): + scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. When `is_prefill is` true, it indicates that it is currently the @@ -144,6 +144,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], shrink_fun: Callable = (self._shrink_prefill if self.is_prefill else self._shrink_decode) + new_y = [] # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): y_s = y[slice_idx] @@ -152,8 +153,10 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y_org = y_s y_s = y_s.view(-1, y_s.shape[-1]) - shrink_fun(y_s, x, lora_s, scale) + y_s = shrink_fun(y_s, x, lora_s, scale) y_s = y_s.view_as(y_org) + new_y.append(y_s) + return tuple(new_y) def add_expand(self, y: torch.Tensor, @@ -163,7 +166,7 @@ def add_expand(self, output_slices: Tuple[int, ...], offset_start: int = 0, add_inputs=True, - **kwargs) -> None: + **kwargs) -> torch.Tensor: """ Performs GEMM and bias addition for multiple slices of lora_b. @@ -191,10 +194,10 @@ def add_expand(self, y = y.view(-1, y.shape[-1]) offset_left = 0 if lora_bias_stacked is not None: - self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - expand_slice_fun( + y = expand_slice_fun( y, x[slice_idx], lora_b_stacked[slice_idx], @@ -203,14 +206,14 @@ def add_expand(self, add_inputs=add_inputs, ) offset_left += output_slices[slice_idx] - y = y.view_as(y_org) + return y.view_as(y_org) def add_lora_embedding(self, y: torch.Tensor, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, - **kwargs) -> None: + **kwargs) -> torch.Tensor: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -227,7 +230,7 @@ def add_lora_embedding(self, # Embedding layer only needs the expand op expand_fun: Callable = (self._expand_prefill if self.is_prefill else self._expand_decode) - expand_fun(y, x, lora_b_stacked, add_inputs) + return expand_fun(y, x, lora_b_stacked, add_inputs) def add_lora_linear(self, y: torch.Tensor, @@ -239,7 +242,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> torch.Tensor: """ Applicable to linear-related lora. @@ -279,14 +282,14 @@ def add_lora_linear(self, dtype=torch.float32, device=x.device, ) - self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_inputs=True, - **kwargs) + buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + return self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) def add_lora_logits(self, y: torch.Tensor, @@ -296,7 +299,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> torch.Tensor: """ Applies lora specifically for LogitsProcessorWithLoRA. @@ -323,10 +326,11 @@ def add_lora_logits(self, dtype=torch.float32, device=x.device) # LogitsProcessorWithLoRA always using bgmv. - bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) - y = y.view_as(y_org) + buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, + scale) + y = bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + return y.view_as(y_org) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f4497b14a5e7..d764000c363c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -329,6 +329,11 @@ def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: """ return float("-inf"), float("inf") + @classmethod + def can_update_inplace(cls) -> bool: + """Checks if the platform allows inplace memory updates""" + return True + @classmethod def get_device_communicator_cls(cls) -> str: """ diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index c9b18b000031..4864173b2f0e 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -70,6 +70,10 @@ def get_punica_wrapper(cls) -> str: def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: return torch.finfo(dtype).min, torch.finfo(dtype).max + @classmethod + def can_update_inplace(cls): + return False + @classmethod def inference_mode(cls): return torch.no_grad() From 60fe398deb1e1443ffe88ca09859a1d05c93faf9 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 11 Feb 2025 14:51:29 +0000 Subject: [PATCH 040/317] PunicaWrapperTPU now returns unchanged buffer if no loras are needed Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 602ec824853b..90058cd404d0 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -33,7 +33,7 @@ def _shrink_prefill( ): #No LoRA request, so return directly if self.no_lora: - return + return y return sgmv_shrink( x, w_t_all, @@ -50,7 +50,7 @@ def _shrink_decode( scale: float, ): if self.no_lora: - return + return y return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) def _expand_prefill( @@ -62,7 +62,7 @@ def _expand_prefill( ): #No LoRA request, so return directly if self.no_lora: - return + return y return sgmv_expand( x, w_t_all, @@ -79,7 +79,7 @@ def _expand_decode( add_inputs: bool, ): if self.no_lora: - return + return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) def _expand_slice_prefill( @@ -93,7 +93,7 @@ def _expand_slice_prefill( ) -> torch.Tensor: #No LoRA request, so return directly if self.no_lora: - return + return y return sgmv_expand_slice( x, w_t_all, @@ -114,7 +114,7 @@ def _expand_slice_decode( add_inputs: bool, ) -> torch.Tensor: if self.no_lora: - return + return y return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) From 40fa7e3f790d76feb0f19ce845be0e52570a1978 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Feb 2025 17:46:03 +0000 Subject: [PATCH 041/317] Simplified TPU prefill Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 15 ---------- vllm/lora/punica_wrapper/punica_tpu.py | 39 +++++++++++++++++++------- vllm/worker/tpu_model_runner.py | 1 + 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 7ac7d16fbf88..69449981b89a 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -6,12 +6,7 @@ def sgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, add_inputs: bool = False): exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) @@ -55,12 +50,7 @@ def sgmv_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, scaling: float, ): exploded_indices = torch.repeat_interleave(lora_indices_tensor, @@ -92,12 +82,7 @@ def bgmv_shrink(inputs: torch.Tensor, def sgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, slice_offset: int, slice_size: int, add_inputs: bool = False): diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 90058cd404d0..847cdd75a76c 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -141,8 +141,8 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x = x.view(-1, x.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) + # shrink_fun: Callable = (self._shrink_prefill + # if self.is_prefill else self._shrink_decode) new_y = [] # TODO fuse these kernels @@ -153,7 +153,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y_org = y_s y_s = y_s.view(-1, y_s.shape[-1]) - y_s = shrink_fun(y_s, x, lora_s, scale) + y_s = self._shrink_decode(y_s, x, lora_s, scale) y_s = y_s.view_as(y_org) new_y.append(y_s) return tuple(new_y) @@ -186,9 +186,9 @@ def add_expand(self, output_slices (Tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._expand_slice_decode) + # expand_slice_fun: Callable = (self._expand_slice_prefill + # if self.is_prefill else + # self._expand_slice_decode) y_org = y y = y.view(-1, y.shape[-1]) @@ -197,7 +197,7 @@ def add_expand(self, y = self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - y = expand_slice_fun( + y = self._expand_slice_decode( y, x[slice_idx], lora_b_stacked[slice_idx], @@ -228,9 +228,9 @@ def add_lora_embedding(self, """ # Embedding layer only needs the expand op - expand_fun: Callable = (self._expand_prefill - if self.is_prefill else self._expand_decode) - return expand_fun(y, x, lora_b_stacked, add_inputs) + # expand_fun: Callable = (self._expand_prefill + # if self.is_prefill else self._expand_decode) + return self._expand_decode(y, x, lora_b_stacked, add_inputs) def add_lora_linear(self, y: torch.Tensor, @@ -334,3 +334,22 @@ def add_lora_logits(self, self.sampler_indices, add_inputs=True) return y.view_as(y_org) + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + self.batch_size = 1 + self._lora_indices_per_batch[:self.batch_size].copy_(token_lora_tensor[:self.batch_size]) + + def set_no_lora(self, no_lora: bool): + self.no_lora = no_lora + + @property + def prefill_metadata( + self + ) -> Tuple[torch.Tensor, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + """ + return (self._lora_indices_per_batch[:self.batch_size],) \ No newline at end of file diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 1051fa1b74c7..46266375d45e 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -917,6 +917,7 @@ def set_active_loras(self, lora_requests: Set[LoRARequest], if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + self.lora_manager._adapter_manager.punica_wrapper.set_no_lora(len(lora_requests) == 0) def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: From 4d99b5a69227f8fbc27661aff650281f53567d33 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Feb 2025 17:48:13 +0000 Subject: [PATCH 042/317] Removed sgmv kernels from TPU implementation Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 44 ------------ vllm/lora/punica_wrapper/punica_tpu.py | 95 +++----------------------- 2 files changed, 8 insertions(+), 131 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 69449981b89a..483bef186185 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -2,19 +2,6 @@ import torch - -def sgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = False): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - inputs.size(0)) - - return bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, - add_inputs) - - def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, @@ -45,21 +32,6 @@ def bgmv_expand(inputs: torch.Tensor, else: return outputs[:limit, :] - -def sgmv_shrink( - inputs: torch.Tensor, - lora_a_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float, -): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - inputs.size(0)) - - return bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, - scaling) - - def bgmv_shrink(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, @@ -78,22 +50,6 @@ def bgmv_shrink(inputs: torch.Tensor, return scaling * outputs - -def sgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = False): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - inputs.size(0)) - - return bgmv_expand_slice(inputs, lora_b_weights, output_tensor, - exploded_indices, slice_offset, slice_size, - add_inputs) - - def bgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 847cdd75a76c..1b8e8ed30e5b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -4,8 +4,7 @@ import torch -from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink) from .punica_base import PunicaWrapperBase @@ -24,25 +23,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - def _shrink_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - #No LoRA request, so return directly - if self.no_lora: - return y - return sgmv_shrink( - x, - w_t_all, - y, - *self.prefill_metadata, - scale, - ) - - def _shrink_decode( + def shrink( self, y: torch.Tensor, x: torch.Tensor, @@ -53,25 +34,7 @@ def _shrink_decode( return y return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - def _expand_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_inputs: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return y - return sgmv_expand( - x, - w_t_all, - y, - *self.prefill_metadata, - add_inputs, - ) - - def _expand_decode( + def expand( self, y: torch.Tensor, x: torch.Tensor, @@ -82,29 +45,8 @@ def _expand_decode( return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) - def _expand_slice_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: int, - y_slice_size: int, - add_inputs: bool, - ) -> torch.Tensor: - #No LoRA request, so return directly - if self.no_lora: - return y - return sgmv_expand_slice( - x, - w_t_all, - y, - *self.prefill_metadata, - y_offset, - y_slice_size, - add_inputs, - ) - def _expand_slice_decode( + def expand_slice( self, y: torch.Tensor, x: torch.Tensor, @@ -141,9 +83,6 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x = x.view(-1, x.shape[-1]) - # shrink_fun: Callable = (self._shrink_prefill - # if self.is_prefill else self._shrink_decode) - new_y = [] # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): @@ -153,7 +92,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y_org = y_s y_s = y_s.view(-1, y_s.shape[-1]) - y_s = self._shrink_decode(y_s, x, lora_s, scale) + y_s = self.shrink(y_s, x, lora_s, scale) y_s = y_s.view_as(y_org) new_y.append(y_s) return tuple(new_y) @@ -186,10 +125,6 @@ def add_expand(self, output_slices (Tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ - # expand_slice_fun: Callable = (self._expand_slice_prefill - # if self.is_prefill else - # self._expand_slice_decode) - y_org = y y = y.view(-1, y.shape[-1]) offset_left = 0 @@ -197,7 +132,7 @@ def add_expand(self, y = self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - y = self._expand_slice_decode( + y = self.expand_slice( y, x[slice_idx], lora_b_stacked[slice_idx], @@ -228,9 +163,7 @@ def add_lora_embedding(self, """ # Embedding layer only needs the expand op - # expand_fun: Callable = (self._expand_prefill - # if self.is_prefill else self._expand_decode) - return self._expand_decode(y, x, lora_b_stacked, add_inputs) + return self.expand(y, x, lora_b_stacked, add_inputs) def add_lora_linear(self, y: torch.Tensor, @@ -340,16 +273,4 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self._lora_indices_per_batch[:self.batch_size].copy_(token_lora_tensor[:self.batch_size]) def set_no_lora(self, no_lora: bool): - self.no_lora = no_lora - - @property - def prefill_metadata( - self - ) -> Tuple[torch.Tensor, int, int]: - """ - This property provides a convenient way to access the necessary - metadata for prefill-related kernel computations. - 1. lora_indices_per_batch: Tensor of lora indices, and an index of - -1 means no lora should be applied. - """ - return (self._lora_indices_per_batch[:self.batch_size],) \ No newline at end of file + self.no_lora = no_lora \ No newline at end of file From 726699b5efd1699ed3495f1e6f177d2d6cb4bd91 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Feb 2025 17:51:19 +0000 Subject: [PATCH 043/317] Fix bug Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/__init__.py | 10 ++-------- vllm/worker/tpu_model_runner.py | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 67ffde460755..04c399954d14 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -1,15 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand # noqa: F401 -from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, - sgmv_shrink) +from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink) __all__ = [ "bgmv_expand", "bgmv_expand_slice", - "bgmv_shrink", - "sgmv_expand", - "sgmv_expand_slice", - "sgmv_shrink", + "bgmv_shrink" ] diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 46266375d45e..7e9ca916b2aa 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -917,7 +917,7 @@ def set_active_loras(self, lora_requests: Set[LoRARequest], if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - self.lora_manager._adapter_manager.punica_wrapper.set_no_lora(len(lora_requests) == 0) + self.lora_manager._adapter_manager.punica_wrapper.set_no_lora(len(lora_requests) == 0) # TODO: Cleanup def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: From 6c39a3177d91addd7b5a6f1e3d7d583cc0960dae Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Feb 2025 17:55:00 +0000 Subject: [PATCH 044/317] Added torch.compiles to PunicaWrapperTPU functions Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 1b8e8ed30e5b..f29ac59c5c4b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -22,7 +22,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - + @torch.compile(backend="openxla") def shrink( self, y: torch.Tensor, @@ -34,6 +34,7 @@ def shrink( return y return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + @torch.compile(backend="openxla") def expand( self, y: torch.Tensor, @@ -45,7 +46,7 @@ def expand( return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) - + @torch.compile(backend="openxla") def expand_slice( self, y: torch.Tensor, @@ -60,6 +61,7 @@ def expand_slice( return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) + @torch.compile(backend="openxla") def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs) -> Optional[torch.Tensor]: @@ -97,6 +99,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], new_y.append(y_s) return tuple(new_y) + @torch.compile(backend="openxla") def add_expand(self, y: torch.Tensor, x: Union[Tuple[torch.Tensor, ...], torch.Tensor], @@ -143,6 +146,7 @@ def add_expand(self, offset_left += output_slices[slice_idx] return y.view_as(y_org) + @torch.compile(backend="openxla") def add_lora_embedding(self, y: torch.Tensor, x: torch.Tensor, @@ -165,6 +169,7 @@ def add_lora_embedding(self, # Embedding layer only needs the expand op return self.expand(y, x, lora_b_stacked, add_inputs) + @torch.compile(backend="openxla") def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, @@ -223,7 +228,8 @@ def add_lora_linear(self, output_slices, add_inputs=True, **kwargs) - + + @torch.compile(backend="openxla") def add_lora_logits(self, y: torch.Tensor, x: torch.Tensor, From e71a3abfc81c9228cb28439c08becd5e875b61ed Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 14 Feb 2025 11:02:58 +0000 Subject: [PATCH 045/317] Replaced "x[x==-1] = y" with "x = torch.where(x == - 1, y)" Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index dbc2d27c597f..00c3689ef462 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -125,11 +125,11 @@ def convert_mapping( indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size), ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 + embeddings_indices = torch.where(embeddings_indices == -1, embeddings_indices, max_loras - 1) base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.where(sampler_indices_padded == -1, sampler_indices_padded, max_loras - 1) sampler_indices_padded = torch.arange( 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( sampler_indices_padded * len(sampler_indices_padded)) From abc3aa3d22fe6238921a85158ac87c09bccb9d1d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 14 Feb 2025 11:06:34 +0000 Subject: [PATCH 046/317] Revert "Added torch.compiles to PunicaWrapperTPU functions" This reverts commit b78b08898dddcb592480d4179e8d346f78eaabd5. Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index f29ac59c5c4b..1b8e8ed30e5b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -22,7 +22,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - @torch.compile(backend="openxla") + def shrink( self, y: torch.Tensor, @@ -34,7 +34,6 @@ def shrink( return y return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - @torch.compile(backend="openxla") def expand( self, y: torch.Tensor, @@ -46,7 +45,7 @@ def expand( return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) - @torch.compile(backend="openxla") + def expand_slice( self, y: torch.Tensor, @@ -61,7 +60,6 @@ def expand_slice( return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) - @torch.compile(backend="openxla") def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs) -> Optional[torch.Tensor]: @@ -99,7 +97,6 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], new_y.append(y_s) return tuple(new_y) - @torch.compile(backend="openxla") def add_expand(self, y: torch.Tensor, x: Union[Tuple[torch.Tensor, ...], torch.Tensor], @@ -146,7 +143,6 @@ def add_expand(self, offset_left += output_slices[slice_idx] return y.view_as(y_org) - @torch.compile(backend="openxla") def add_lora_embedding(self, y: torch.Tensor, x: torch.Tensor, @@ -169,7 +165,6 @@ def add_lora_embedding(self, # Embedding layer only needs the expand op return self.expand(y, x, lora_b_stacked, add_inputs) - @torch.compile(backend="openxla") def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, @@ -228,8 +223,7 @@ def add_lora_linear(self, output_slices, add_inputs=True, **kwargs) - - @torch.compile(backend="openxla") + def add_lora_logits(self, y: torch.Tensor, x: torch.Tensor, From 61fd5ade8dcf2da15b201e9661663bc9502a47fa Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 14 Feb 2025 11:14:25 +0000 Subject: [PATCH 047/317] Fix linting Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/__init__.py | 9 +++------ vllm/lora/punica_wrapper/punica_tpu.py | 12 ++++++------ vllm/worker/tpu_model_runner.py | 3 ++- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 04c399954d14..94062b05d916 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -1,9 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink) +from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink) -__all__ = [ - "bgmv_expand", - "bgmv_expand_slice", - "bgmv_shrink" -] +__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 1b8e8ed30e5b..fdbf9cb96ddb 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch -from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink) +from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink from .punica_base import PunicaWrapperBase @@ -45,7 +45,6 @@ def expand( return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) - def expand_slice( self, y: torch.Tensor, @@ -270,7 +269,8 @@ def add_lora_logits(self, def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 - self._lora_indices_per_batch[:self.batch_size].copy_(token_lora_tensor[:self.batch_size]) - + self._lora_indices_per_batch[:self.batch_size].copy_( + token_lora_tensor[:self.batch_size]) + def set_no_lora(self, no_lora: bool): - self.no_lora = no_lora \ No newline at end of file + self.no_lora = no_lora diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 7e9ca916b2aa..ec4aa8f38c24 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -917,7 +917,8 @@ def set_active_loras(self, lora_requests: Set[LoRARequest], if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - self.lora_manager._adapter_manager.punica_wrapper.set_no_lora(len(lora_requests) == 0) # TODO: Cleanup + self.lora_manager._adapter_manager.punica_wrapper.set_no_lora( + len(lora_requests) == 0) # TODO: Cleanup def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: From 63ffc3ee538c005e395168da2a85a70a3aa61400 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Feb 2025 12:21:58 +0000 Subject: [PATCH 048/317] Added lora hotswapping test Signed-off-by: Akshat Tripathi --- tests/tpu/test_lora.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/tpu/test_lora.py diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py new file mode 100644 index 000000000000..ed1553fbb6f2 --- /dev/null +++ b/tests/tpu/test_lora.py @@ -0,0 +1,29 @@ +import vllm +import sys + +from vllm.lora.request import LoRARequest + +def test_lora_hotswapping(): + lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_requests = [ + LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) + for i in range(1, 5) + ] + + llm = vllm.LLM( + model="Qwen/Qwen2.5-3B-Instruct", + num_scheduler_steps=1, + max_model_len=256, + max_seq_len_to_capture=256, + max_num_seqs=8, + enable_lora=True, + max_loras=2, + max_lora_rank=8 + ) + + prompt = "What is 1+1?" + + for _ in range(10): + for i, req in enumerate(lora_requests): + output = llm.generate(prompt, sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), lora_request=req)[0].outputs[0].text + assert output.strip()[0] == i + 1 From f9085f324291e76812153bfd54453282952ae69b Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Feb 2025 12:45:43 +0000 Subject: [PATCH 049/317] Fixed hotswapping test prompt Signed-off-by: Akshat Tripathi --- tests/tpu/test_lora.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py index ed1553fbb6f2..d32adda9fe48 100644 --- a/tests/tpu/test_lora.py +++ b/tests/tpu/test_lora.py @@ -1,5 +1,4 @@ import vllm -import sys from vllm.lora.request import LoRARequest @@ -21,7 +20,7 @@ def test_lora_hotswapping(): max_lora_rank=8 ) - prompt = "What is 1+1?" + prompt = "What is 1+1? \n" for _ in range(10): for i, req in enumerate(lora_requests): From db785c8d42630216fcb3dbe94a8d06f95565d35f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Feb 2025 13:12:40 +0000 Subject: [PATCH 050/317] Fixed bug in tpu lora test Signed-off-by: Akshat Tripathi --- tests/tpu/test_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py index d32adda9fe48..d3d2c7eb2e1d 100644 --- a/tests/tpu/test_lora.py +++ b/tests/tpu/test_lora.py @@ -25,4 +25,4 @@ def test_lora_hotswapping(): for _ in range(10): for i, req in enumerate(lora_requests): output = llm.generate(prompt, sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), lora_request=req)[0].outputs[0].text - assert output.strip()[0] == i + 1 + assert int(output.strip()[0]) == i + 1 \ No newline at end of file From 886340165e197ed73ea763fcef5421546ed5f91a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 14 Feb 2025 17:21:36 +0000 Subject: [PATCH 051/317] Merged set_no_lora() functionality with _udpate_prefill_metada Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 5 ++--- vllm/worker/tpu_model_runner.py | 2 -- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index fdbf9cb96ddb..1fcedfb61a93 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -271,6 +271,5 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 self._lora_indices_per_batch[:self.batch_size].copy_( token_lora_tensor[:self.batch_size]) - - def set_no_lora(self, no_lora: bool): - self.no_lora = no_lora + # TODO: .item() is extremely inefficient on TPU, so find a way around it + self.no_lora = torch.all(token_lora_tensor == -1).item() diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index ec4aa8f38c24..1051fa1b74c7 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -917,8 +917,6 @@ def set_active_loras(self, lora_requests: Set[LoRARequest], if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - self.lora_manager._adapter_manager.punica_wrapper.set_no_lora( - len(lora_requests) == 0) # TODO: Cleanup def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: From 93372c0901aa721234bc1635deb60e4fac49aeca Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 14 Feb 2025 17:22:50 +0000 Subject: [PATCH 052/317] Added Multi-LoRA functionality to TPU V1 Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 77 ++++++++++++++++++++++-------- vllm/v1/worker/tpu_worker.py | 4 ++ 2 files changed, 61 insertions(+), 20 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 8635ffce7027..18362d86cf3b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +from copy import deepcopy import enum import time from dataclasses import dataclass @@ -29,6 +30,7 @@ from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -70,7 +72,7 @@ class DecodeData: attn_metadata: Optional[PallasMetadata] = None -class TPUModelRunner: +class TPUModelRunner(LoRAModelRunnerMixin): def __init__( self, @@ -394,6 +396,17 @@ def _get_prompts_and_decodes( return PromptDecodeInfo(prompt_req_ids, decode_req_ids, prompt_scheduled_tokens) + + def _get_input_batch_subset(self, req_idxs: List[int]) -> InputBatch: + req_idxs = set(req_idxs) + all_req_idxs = set(self.input_batch.req_id_to_index.values()) + + req_idxs_to_remove = all_req_idxs.difference(req_idxs) + + subset_batch = deepcopy(self.input_batch) + subset_batch.condense(list(req_idxs_to_remove)) + return subset_batch + def _prepare_prompt(self, req_index: int, num_scheduled_tokens: int) -> PromptData: @@ -469,6 +482,10 @@ def _prepare_prompt(self, req_index: int, self.device) effective_query_lens = self.prompt_effective_query_lens_cpu[ self.cur_swap_id].to(self.device) + + if self.lora_config: + prompt_input_batch = self._get_input_batch_subset(req_idxs=[req_index]) + self.set_active_loras(prompt_input_batch, np.array([padded_prompt_len], dtype=np.int32)) self.swap_step() @@ -559,6 +576,11 @@ def _prepare_decode( block_table = block_table_cpu.to(self.device) context_lens = self.decode_context_lens_cpu[ self.cur_swap_id][:padded_batch_size].to(self.device) + + if self.lora_config: + req_idxs = list(map(self.input_batch.req_id_to_index.get, decode_req_ids)) + decode_input_batch = self._get_input_batch_subset(req_idxs) + self.set_active_loras(decode_input_batch, np.array([padded_batch_size], dtype=np.int32)) self.swap_step() @@ -720,6 +742,12 @@ def load_model(self) -> None: "get_tensor_model_parallel_rank", return_value=xm_tp_rank): model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + model = self.load_lora_model(model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) model = model.eval() xm.mark_step() xm.wait_device_ops() @@ -825,13 +853,19 @@ def dummy_run( # graphs in the disk (VLLM_XLA_CACHE_PATH). if exec_mode.is_prefill(): # Prefll - torch._dynamo.mark_dynamic(token_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) + if self.lora_config is not None: # TODO: Remove this condition + torch._dynamo.config.capture_dynamic_output_shape_ops = True + else: + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) else: # Decode - torch._dynamo.mark_dynamic(token_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) + if self.lora_config is not None: # TODO: Remove this condition + torch._dynamo.config.capture_dynamic_output_shape_ops = True + else: + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) @@ -850,11 +884,12 @@ def capture_model(self) -> None: for batch_size in [1]: seq_len = 16 while seq_len <= self.model_config.max_model_len: - self.dummy_run(self.kv_caches, - batch_size, - seq_len, - exec_mode=ExecutionMode.PREFILL) - xm.wait_device_ops() + with self.maybe_profile_with_lora(self.lora_config, np.array([seq_len] * batch_size, dtype=np.int32)): + self.dummy_run(self.kv_caches, + batch_size, + seq_len, + exec_mode=ExecutionMode.PREFILL) + xm.wait_device_ops() logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len) num_tokens = batch_size * seq_len @@ -874,11 +909,12 @@ def capture_model(self) -> None: for batch_size in [1]: seq_len = 16 while seq_len <= self.model_config.max_model_len: - self.dummy_run(self.kv_caches, - batch_size, - seq_len, - exec_mode=ExecutionMode.PREFIX_PREFILL) - xm.wait_device_ops() + with self.maybe_profile_with_lora(self.lora_config, np.array([seq_len] * batch_size, dtype=np.int32)): + self.dummy_run(self.kv_caches, + batch_size, + seq_len, + exec_mode=ExecutionMode.PREFIX_PREFILL) + xm.wait_device_ops() logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len) num_tokens = batch_size * seq_len @@ -898,11 +934,12 @@ def capture_model(self) -> None: seq_len = 1 batch_size = 8 # Must be in sync with _get_padded_batch_size() while True: - self.dummy_run(self.kv_caches, - batch_size, - seq_len, - exec_mode=ExecutionMode.DECODE) - xm.wait_device_ops() + with self.maybe_profile_with_lora(self.lora_config, np.array([seq_len] * batch_size, dtype=np.int32)): + self.dummy_run(self.kv_caches, + batch_size, + seq_len, + exec_mode=ExecutionMode.DECODE) + xm.wait_device_ops() logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len) if batch_size >= self.scheduler_config.max_num_seqs: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index f29edd34ede3..af614cfa2843 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -14,6 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.scheduler import SchedulerOutput @@ -153,6 +154,9 @@ def execute_model( output = self.model_runner.execute_model(scheduler_output) return output if self.rank == 0 else None + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + def load_model(self) -> None: self.model_runner.load_model() From 5d8f22da2cff976dffb26148cabfb02b17986ed2 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 17 Feb 2025 16:49:27 +0000 Subject: [PATCH 053/317] Added test that verifies switching Signed-off-by: Akshat Tripathi --- test_switching.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 test_switching.py diff --git a/test_switching.py b/test_switching.py new file mode 100644 index 000000000000..ad84d47d3b8e --- /dev/null +++ b/test_switching.py @@ -0,0 +1,36 @@ +import vllm + +import torch_xla.debug.profiler as xp + +from vllm.lora.request import LoRARequest + +lora_paths = ["/mnt/ssd0/adapters/1", "/mnt/ssd0/adapters/2", "/mnt/ssd0/adapters/3", "/mnt/ssd0/adapters/4"] + +lora_requests = [ + LoRARequest("lora_adapter", i+1, lora_path) + for i, lora_path in enumerate(lora_paths) +] + +llm = vllm.LLM( + model="/mnt/ssd0/work_collection/downloaded_Qwen2.5-3b-Instruct_model/", + num_scheduler_steps=1, + swap_space=16, + max_model_len=256, + max_seq_len_to_capture=256, + max_num_seqs=8, + enable_lora=True, + # enforce_eager=True, + max_loras=2, + max_lora_rank=8 +) + +for _ in range(2): + for i, req in enumerate(lora_requests): + print(i, llm.generate( + "What's 1+1?", + sampling_params=vllm.SamplingParams( + max_tokens=256, + temperature=0 + ), + lora_request=req + )) \ No newline at end of file From 1c49acb2e94b52ebaf77abeaa2bb891343f1aa99 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 4 Feb 2025 15:56:55 +0000 Subject: [PATCH 054/317] Added bgmv kernel test code Signed-off-by: Akshat Tripathi --- bgmv.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 bgmv.py diff --git a/bgmv.py b/bgmv.py new file mode 100644 index 000000000000..a72448485769 --- /dev/null +++ b/bgmv.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 + +import jax +from jax import numpy as jnp +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + + +def create_tensors(T, D, L, N): + """ + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: jax.Array - shape (T, D) + lora: jax.Array - shape (L, D) + + ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T + + Ignored: + idxs: jax.Array - shape (T, ) - all values must be in [0, N) + loras: jax.Array - shape (N, 1, L, D) + """ + inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) + lora = jax.random.normal(jax.random.PRNGKey(1), (L, D)) + ref_output = inputs @ lora.T + + return inputs, lora, ref_output + + +def bgmv_kernel(inp_ref, lora_ref, out_ref, acc_ref): + + @pl.when(pl.program_id(2) == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + + acc_ref[...] += jax.lax.dot_general(inp_ref[...], + lora_ref[...], + (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) + + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + def _(): + out_ref[...] = acc_ref[...].astype(out_ref.dtype) + + +@jax.jit +def bgmv(inputs: jax.Array, lora: jax.Array): + T, D = inputs.shape + L, _ = lora.shape + + # TODO: Tune + # Also figure out how to make bT % 128 instead of bL, + # or pick block sizes based off dims + bT = 8 + bL = 128 + bD = 128 + + return pl.pallas_call( + kernel=bgmv_kernel, + out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + grid=(T // bT, L // bL, D // bD), + in_specs=[ + pl.BlockSpec((bT, bD), lambda i, j, k: (i, k)), + pl.BlockSpec((bL, bD), lambda i, j, k: (j, k)), + ], + out_specs=pl.BlockSpec((bT, bL), lambda i, j, k: (i, j)), + scratch_shapes=[pltpu.VMEM((bT, bL), jnp.float32)]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), + interpret=True)(inputs, lora) + + +if __name__ == "__main__": + T, D, L, N = 128, 3072, 128, 8 + inputs, lora, ref_output = create_tensors(T, D, L, N) + + print(lora.shape, inputs.shape, ref_output.shape) + + output1 = bgmv(inputs, lora) + + print(jnp.isnan(output1).sum(), "NaN values") + + # np.testing.assert_allclose(ref_output, output1) + # print("Success") From f68f69e5b5774201a10d0f2528cd41dc6285e484 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 4 Feb 2025 16:09:48 +0000 Subject: [PATCH 055/317] Readded 1 lora dim Signed-off-by: Akshat Tripathi --- bgmv.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/bgmv.py b/bgmv.py index a72448485769..09ac60bfbd00 100644 --- a/bgmv.py +++ b/bgmv.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import jax +import numpy as np from jax import numpy as jnp from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu @@ -25,8 +26,8 @@ def create_tensors(T, D, L, N): loras: jax.Array - shape (N, 1, L, D) """ inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) - lora = jax.random.normal(jax.random.PRNGKey(1), (L, D)) - ref_output = inputs @ lora.T + lora = jax.random.normal(jax.random.PRNGKey(1), (1, L, D)) + ref_output = inputs @ lora.squeeze(0).T return inputs, lora, ref_output @@ -38,7 +39,7 @@ def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) acc_ref[...] += jax.lax.dot_general(inp_ref[...], - lora_ref[...], + lora_ref[0, ...], (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) @@ -50,7 +51,7 @@ def _(): @jax.jit def bgmv(inputs: jax.Array, lora: jax.Array): T, D = inputs.shape - L, _ = lora.shape + _, L, _ = lora.shape # TODO: Tune # Also figure out how to make bT % 128 instead of bL, @@ -67,7 +68,7 @@ def bgmv(inputs: jax.Array, lora: jax.Array): grid=(T // bT, L // bL, D // bD), in_specs=[ pl.BlockSpec((bT, bD), lambda i, j, k: (i, k)), - pl.BlockSpec((bL, bD), lambda i, j, k: (j, k)), + pl.BlockSpec((1, bL, bD), lambda i, j, k: (0, j, k)), ], out_specs=pl.BlockSpec((bT, bL), lambda i, j, k: (i, j)), scratch_shapes=[pltpu.VMEM((bT, bL), jnp.float32)]), @@ -86,5 +87,5 @@ def bgmv(inputs: jax.Array, lora: jax.Array): print(jnp.isnan(output1).sum(), "NaN values") - # np.testing.assert_allclose(ref_output, output1) - # print("Success") + np.testing.assert_allclose(ref_output, output1, rtol=1e-2) + print("Success") From 68f4b40e23241a4c778b0c19a809bca5ab3c245d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Feb 2025 16:29:35 +0000 Subject: [PATCH 056/317] Added scratchpad for debugging Signed-off-by: Akshat Tripathi --- scratch.py | 104 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 scratch.py diff --git a/scratch.py b/scratch.py new file mode 100644 index 000000000000..6abcea8767d4 --- /dev/null +++ b/scratch.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +from itertools import product + +import numpy as np + + +class Blocked: + + def __init__(self, arr: np.ndarray, sizes, index_map): + assert len(arr.shape) == len(sizes) + self.arr = arr + self.sizes = sizes + self.index_map = index_map + + def get_block(self, idxs): + return eval(f"self.arr[{self._get_index_str(idxs)}]") + + def set_block(self, idxs, val): + exec(f"self.arr[{self._get_index_str(idxs)}] = val") + + def _get_index_str(self, idxs): + return ", ".join(f"{idx*self.sizes[i]}:{(idx+1)*self.sizes[i]}" + for i, idx in enumerate(self.index_map(*idxs))) + + +# np.random.seed(4) +T, D, L, N = 128, 3072, 16, 8 + +D1 = 8 +D2 = D // D1 + +bT = 1 +bL = 16 +bD1 = 8 +bD2 = 128 +bD = bD1 * bD2 + +inputs = np.random.randn(T, D) +loras = np.random.randn(1, L, D) + +print("ref1", (inputs @ loras.squeeze(0).T).sum()) + +inputs_1 = inputs.reshape((T, D1, D2)) +loras_1 = loras.reshape((1, L, D1, D2)) + +print("ref2", np.einsum("tdD,ondD->tn", inputs_1, loras_1).sum()) + + +def fast_bgmv(inputs, loras): + out = np.zeros((T, L)) + grid = (T // bT, L // bL, D1 // bD1, D2 // bD2) + x_b = Blocked(inputs, (bT, bD1, bD2), lambda i, j, k1, k2: (i, k1, k2)) + l_b = Blocked(loras, (1, bL, bD1, bD2), lambda i, j, k1, k2: + (0, j, k1, k2)) + out_b = Blocked(out, (bT, bL), lambda i, j, k1, k2: (i, j)) + acc_ref = np.zeros((bT, bL)) + + for idxs in product(*list(map(range, grid))): + x_ref = x_b.get_block(idxs) + l_ref = l_b.get_block(idxs) + + if idxs[2] == 0 and idxs[3] == 0: + acc_ref = np.zeros_like(acc_ref) + + acc_ref += (x_ref * l_ref[0]).sum(-1).sum(-1) + + if idxs[2] == grid[2] - 1 and idxs[3] == grid[3] - 1: + out_b.set_block(idxs, acc_ref) + return out + + +def slow_bgmv(inputs, loras): + out = np.zeros((T, L)) + grid = (T // bT, L // bL, D // bD) + x_b = Blocked(inputs, (bT, bD), lambda i, j, k: (i, k)) + l_b = Blocked(loras, (1, bL, bD), lambda i, j, k: (0, j, k)) + out_b = Blocked(out, (bT, bL), lambda i, j, k: (i, j)) + acc_ref = np.zeros((bT, bL)) + + for idxs in product(*list(map(range, grid))): + x_ref = x_b.get_block(idxs) + l_ref = l_b.get_block(idxs) + + if idxs[2] == 0: + acc_ref = np.zeros_like(acc_ref) + + acc_ref += x_ref @ l_ref[0].T + + if idxs[2] == grid[2] - 1: + out_b.set_block(idxs, acc_ref) + return out + + +# result = np.zeros((5, 64)) + +# for t in range(5): + +# result[t] = (x1[t] * l1[:]).sum(axis=1).sum(axis=1) +# # for n in range(64): +# # print((x1[t] * l1[n]).shape) +# # result[t, n] = (x1[t] * l1[n]).sum() + +print("test slow", slow_bgmv(inputs, loras).sum()) +print("test fast", fast_bgmv(inputs_1, loras_1).sum()) From 4601d10176d8ed842098d84e26e74b3d71f1aecf Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Feb 2025 17:37:30 +0000 Subject: [PATCH 057/317] Added some dynamic lora selection Signed-off-by: Akshat Tripathi --- bgmv.py | 85 +++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 62 insertions(+), 23 deletions(-) diff --git a/bgmv.py b/bgmv.py index 09ac60bfbd00..e33c356a46c6 100644 --- a/bgmv.py +++ b/bgmv.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import jax -import numpy as np from jax import numpy as jnp from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu @@ -17,29 +16,59 @@ def create_tensors(T, D, L, N): Outputs: inputs: jax.Array - shape (T, D) - lora: jax.Array - shape (L, D) + loras: jax.Array - shape (N, 1, L, D) + idxs: jax.Array - shape (T, ) - all values must be in [0, N) ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T - - Ignored: - idxs: jax.Array - shape (T, ) - all values must be in [0, N) - loras: jax.Array - shape (N, 1, L, D) """ inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) - lora = jax.random.normal(jax.random.PRNGKey(1), (1, L, D)) - ref_output = inputs @ lora.squeeze(0).T + loras = jax.random.normal(jax.random.PRNGKey(1), (N, 1, L, D)) + idxs = jax.random.randint(jax.random.PRNGKey(2), + shape=(T, ), + minval=0, + maxval=N) + + ref_output = jnp.einsum("td,__ld->tl", inputs, loras[idxs]) + + return inputs, loras, idxs, ref_output + + +def create_debug_tensors(T, D, L, N): + """ + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: jax.Array - shape (T, D) + loras: jax.Array - shape (N, 1, L, D) + idxs: jax.Array - shape (T, ) - all values must be in [0, N) + + ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T + """ + inputs = jnp.ones((T, D)) + loras = jnp.ones((N, 1, L, D)) * jnp.arange(0, N)[:, None, None, None] + idxs = jax.random.randint(jax.random.PRNGKey(2), + shape=(T, ), + minval=0, + maxval=N) - return inputs, lora, ref_output + ref_output = jnp.einsum("td,t_ld->tl", inputs, loras[idxs]) + return inputs, loras, idxs, ref_output -def bgmv_kernel(inp_ref, lora_ref, out_ref, acc_ref): + +def bgmv_kernel(idx_ref, inp_ref, lora_ref, out_ref, acc_ref): + del idx_ref @pl.when(pl.program_id(2) == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) acc_ref[...] += jax.lax.dot_general(inp_ref[...], - lora_ref[0, ...], + lora_ref[0, 0, ...], (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) @@ -49,9 +78,9 @@ def _(): @jax.jit -def bgmv(inputs: jax.Array, lora: jax.Array): +def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): T, D = inputs.shape - _, L, _ = lora.shape + N, _, L, _ = lora.shape # TODO: Tune # Also figure out how to make bT % 128 instead of bL, @@ -64,28 +93,38 @@ def bgmv(inputs: jax.Array, lora: jax.Array): kernel=bgmv_kernel, out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=0, + num_scalar_prefetch=1, grid=(T // bT, L // bL, D // bD), in_specs=[ - pl.BlockSpec((bT, bD), lambda i, j, k: (i, k)), - pl.BlockSpec((1, bL, bD), lambda i, j, k: (0, j, k)), + pl.BlockSpec((bT, bD), lambda i, j, k, block_idx: (i, k)), + pl.BlockSpec((1, 1, bL, bD), lambda i, j, k, block_idx: + (block_idx[i * bT], 0, j, k)), ], - out_specs=pl.BlockSpec((bT, bL), lambda i, j, k: (i, j)), + out_specs=pl.BlockSpec((bT, bL), lambda i, j, k, block_idx: + (i, j)), scratch_shapes=[pltpu.VMEM((bT, bL), jnp.float32)]), compiler_params=pltpu.TPUCompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), - interpret=True)(inputs, lora) + interpret=True)(idxs, inputs, lora) if __name__ == "__main__": T, D, L, N = 128, 3072, 128, 8 - inputs, lora, ref_output = create_tensors(T, D, L, N) + inputs, lora, idxs, ref_output = create_debug_tensors(T, D, L, N) + print(idxs) + # breakpoint() print(lora.shape, inputs.shape, ref_output.shape) - output1 = bgmv(inputs, lora) + output = bgmv(inputs, lora, idxs) + + print(jnp.isnan(output).sum(), "NaN values") + + print("Err", jnp.max(jnp.abs(ref_output - output))) - print(jnp.isnan(output1).sum(), "NaN values") + output_idxs = (output / D)[:, 0] + print(output_idxs) + print(output_idxs == idxs) - np.testing.assert_allclose(ref_output, output1, rtol=1e-2) - print("Success") + breakpoint() + # np.testing.assert_allclose(ref_output, output1, rtol=1e-2) From 26c6aabf0c9241f7d3a0a558f6797192e47d242f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Feb 2025 18:48:08 +0000 Subject: [PATCH 058/317] Moved and modified bgmv ops from the cpu backend to the tpu backend, because xla doesn't allow partial updates Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 1fcedfb61a93..88458ed433f8 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -52,6 +52,7 @@ def expand_slice( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + y_total_size: int, add_inputs: bool, ) -> torch.Tensor: if self.no_lora: @@ -102,7 +103,6 @@ def add_expand(self, lora_b_stacked: Tuple[torch.Tensor, ...], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], output_slices: Tuple[int, ...], - offset_start: int = 0, add_inputs=True, **kwargs) -> torch.Tensor: """ @@ -137,6 +137,7 @@ def add_expand(self, lora_b_stacked[slice_idx], offset_left, output_slices[slice_idx], + y_total_size=sum(output_slices), add_inputs=add_inputs, ) offset_left += output_slices[slice_idx] From 2faad5acefb882a2f9e900b34f7d082f5f5c3b38 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 10 Feb 2025 15:55:23 +0000 Subject: [PATCH 059/317] Added bgmv kernel test Signed-off-by: Akshat Tripathi --- tests/lora/tpu/__init__.py | 0 tests/lora/tpu/test_pallas_kernels.py | 58 +++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 tests/lora/tpu/__init__.py create mode 100644 tests/lora/tpu/test_pallas_kernels.py diff --git a/tests/lora/tpu/__init__.py b/tests/lora/tpu/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py new file mode 100644 index 000000000000..27be3be804e5 --- /dev/null +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from bgmv import bgmv + +N_TOKENS = [ + 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, + 131072 +] +HIDDEN_SIZES = [128, 256, 512, 896, 1024, 2048, 4096, 8192, 8320] + +DTYPES = [jnp.float16, jnp.bfloat16] +NUM_LORA = [1, 2, 4, 8, 16, 32] +RANKS = [8, 16, 32, 64, 128] + + +def generate_test_data(T, D, L, N, seed, dtype=jnp.float32): + """ + Generates debug tensors for testing. + """ + inputs = jax.random.normal(jax.random.PRNGKey(seed), (T, D)) + loras = jax.random.normal(jax.random.PRNGKey(seed), (N, 1, L, D)) + idxs = jax.random.randint(jax.random.PRNGKey(seed), + shape=(T, ), + minval=0, + maxval=N) + + ref_output = jnp.einsum("td,t_ld->tl", inputs, loras[idxs]) + return inputs, loras, idxs, ref_output + + +# Parameterize tests with various shapes and dtypes +@pytest.mark.parametrize("T", N_TOKENS) +@pytest.mark.parametrize("D", HIDDEN_SIZES) +@pytest.mark.parametrize("L", RANKS) +@pytest.mark.parametrize("N", NUM_LORA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", [0]) +def test_bgmv(T, D, L, N, dtype, op_type, seed): + inputs, loras, idxs, ref_output = generate_test_data( + T, D, L, N, seed, dtype) + + # Run bgmv + match op_type: + case "expand": + output = bgmv(inputs, loras, idxs) # TODO: Specialise + case "shrink": + output = bgmv(inputs, loras, idxs) + + # Make sure we have no NaNs + assert jnp.isnan(output).sum() == 0 + + # Compare with reference output + np.testing.assert_allclose(output, ref_output, rtol=1e-3, atol=1e-3) From 01f8c8f289145b615f535f64967ba13f4381d7cf Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 10 Feb 2025 16:59:20 +0000 Subject: [PATCH 060/317] Made bgmv kernel fully functional (WIP on supporting smaller ranks) (WIP on perf) Signed-off-by: Akshat Tripathi --- bgmv.py | 70 +++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/bgmv.py b/bgmv.py index e33c356a46c6..0959dae351a8 100644 --- a/bgmv.py +++ b/bgmv.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import functools + import jax from jax import numpy as jnp from jax.experimental import pallas as pl @@ -22,8 +24,8 @@ def create_tensors(T, D, L, N): ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T """ inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) - loras = jax.random.normal(jax.random.PRNGKey(1), (N, 1, L, D)) - idxs = jax.random.randint(jax.random.PRNGKey(2), + loras = jax.random.normal(jax.random.PRNGKey(0), (N, 1, L, D)) + idxs = jax.random.randint(jax.random.PRNGKey(0), shape=(T, ), minval=0, maxval=N) @@ -50,7 +52,7 @@ def create_debug_tensors(T, D, L, N): """ inputs = jnp.ones((T, D)) loras = jnp.ones((N, 1, L, D)) * jnp.arange(0, N)[:, None, None, None] - idxs = jax.random.randint(jax.random.PRNGKey(2), + idxs = jax.random.randint(jax.random.PRNGKey(0), shape=(T, ), minval=0, maxval=N) @@ -60,17 +62,24 @@ def create_debug_tensors(T, D, L, N): return inputs, loras, idxs, ref_output -def bgmv_kernel(idx_ref, inp_ref, lora_ref, out_ref, acc_ref): - del idx_ref +def bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, + mask_ref): @pl.when(pl.program_id(2) == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) - acc_ref[...] += jax.lax.dot_general(inp_ref[...], - lora_ref[0, 0, ...], - (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) + t = pl.program_id(0) + + for i in range(bT): + idx = idx_ref[i + bT * t] + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32) + + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[idx, 0, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[...] @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) def _(): @@ -89,23 +98,30 @@ def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): bL = 128 bD = 128 - return pl.pallas_call( - kernel=bgmv_kernel, - out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - grid=(T // bT, L // bL, D // bD), - in_specs=[ - pl.BlockSpec((bT, bD), lambda i, j, k, block_idx: (i, k)), - pl.BlockSpec((1, 1, bL, bD), lambda i, j, k, block_idx: - (block_idx[i * bT], 0, j, k)), - ], - out_specs=pl.BlockSpec((bT, bL), lambda i, j, k, block_idx: - (i, j)), - scratch_shapes=[pltpu.VMEM((bT, bL), jnp.float32)]), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=("parallel", "parallel", "arbitrary")), - interpret=True)(idxs, inputs, lora) + return pl.pallas_call(kernel=functools.partial(bgmv_kernel, bT, bL), + out_shape=jax.ShapeDtypeStruct((T, L), + dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // bT, L // bL, D // bD), + in_specs=[ + pl.BlockSpec((bT, bD), + lambda i, j, k, block_idx: + (i, k)), + pl.BlockSpec((N, 1, bL, bD), + lambda i, j, k, block_idx: + (0, 0, j, k)), + ], + out_specs=pl.BlockSpec( + (bT, bL), lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((bT, bL), jnp.float32), + pltpu.VMEM((bT, bL), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", + "arbitrary")), + interpret=True)(idxs, inputs, lora) if __name__ == "__main__": @@ -126,5 +142,5 @@ def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): print(output_idxs) print(output_idxs == idxs) - breakpoint() + # breakpoint() # np.testing.assert_allclose(ref_output, output1, rtol=1e-2) From db859b0c335220e57883f9d467748247c434a836 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 17 Feb 2025 16:48:54 +0000 Subject: [PATCH 061/317] Updated bgmv_kernel to work with ranks that aren't exact multiples of 128 Signed-off-by: Akshat Tripathi --- bgmv.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/bgmv.py b/bgmv.py index 0959dae351a8..ef2125263d4b 100644 --- a/bgmv.py +++ b/bgmv.py @@ -90,20 +90,24 @@ def _(): def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): T, D = inputs.shape N, _, L, _ = lora.shape + + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU register + L1 = L + if L < 128 or L % 128 != 0: + L1 = (L // 128 + 1) * 128 + lora = jnp.pad(lora, ((0,0), (0,0), (0,L1-L), (0,0))) - # TODO: Tune - # Also figure out how to make bT % 128 instead of bL, - # or pick block sizes based off dims + # TODO: Tune these bT = 8 bL = 128 bD = 128 return pl.pallas_call(kernel=functools.partial(bgmv_kernel, bT, bL), - out_shape=jax.ShapeDtypeStruct((T, L), + out_shape=jax.ShapeDtypeStruct((T, L1), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, - grid=(T // bT, L // bL, D // bD), + grid=(T // bT, L1 // bL, D // bD), in_specs=[ pl.BlockSpec((bT, bD), lambda i, j, k, block_idx: @@ -121,11 +125,11 @@ def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): compiler_params=pltpu.TPUCompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), - interpret=True)(idxs, inputs, lora) + interpret=True)(idxs, inputs, lora)[:, :L] if __name__ == "__main__": - T, D, L, N = 128, 3072, 128, 8 + T, D, L, N = 16, 3072, 8, 8 inputs, lora, idxs, ref_output = create_debug_tensors(T, D, L, N) print(idxs) # breakpoint() From 8c34edf7d5d54d4d608899bade8ea4eb6d1a78f9 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Feb 2025 12:00:37 +0000 Subject: [PATCH 062/317] Removed interpreted mode on kernel Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 12 ++--- vllm/lora/ops/xla_ops/pallas.py | 84 +++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 9 deletions(-) create mode 100644 vllm/lora/ops/xla_ops/pallas.py diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 483bef186185..cc541a8a8de5 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch +from .pallas import bgmv def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -38,17 +39,10 @@ def bgmv_shrink(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) - return scaling * outputs + return scaling * bgmv(inputs, lora_b_weights, lora_indices_tensor) def bgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py new file mode 100644 index 000000000000..2889d21e774e --- /dev/null +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -0,0 +1,84 @@ +import functools +from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas +jax_import_guard() + +import jax +from jax.experimental import pallas as pl +import jax.numpy as jnp +from jax.experimental.pallas import tpu as pltpu + + +def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, + mask_ref): + + @pl.when(pl.program_id(2) == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + + t = pl.program_id(0) + + for i in range(bT): + idx = idx_ref[i + bT * t] + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32) + + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[idx, 0, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[...] + + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + def _(): + out_ref[...] = acc_ref[...].astype(out_ref.dtype) + + +@jax.jit +def _bgmv(inputs: jax.Array, loras: jax.Array, idxs: jax.Array): + T, D = inputs.shape + N, _, L, _ = loras.shape + + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU register + L1 = L + if L < 128 or L % 128 != 0: + L1 = (L // 128 + 1) * 128 + loras = jnp.pad(loras, ((0,0), (0,0), (0,L1-L), (0,0))) + + # TODO: Tune these + bT = 8 + bL = 128 + bD = 128 + + return pl.pallas_call(kernel=functools.partial(_bgmv_kernel, bT, bL), + out_shape=jax.ShapeDtypeStruct((T, L1), + dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // bT, L1 // bL, D // bD), + in_specs=[ + pl.BlockSpec((bT, bD), + lambda i, j, k, block_idx: + (i, k)), + pl.BlockSpec((N, 1, bL, bD), + lambda i, j, k, block_idx: + (0, 0, j, k)), + ], + out_specs=pl.BlockSpec( + (bT, bL), lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((bT, bL), jnp.float32), + pltpu.VMEM((bT, bL), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", + "arbitrary")))(idxs, inputs, loras)[:, :L] + +def bgmv_shape_function(inputs, loras, idxs): + T, _ = inputs.shape + _, _, L, _ = loras.shape + + return [((T, L), inputs.dtype)] + +def bgmv(inputs, loras, idxs): + kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) + + return kernel(inputs, loras, idxs) \ No newline at end of file From b8f1a5749e4bd0c199f289ca8c635f6159ea1657 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Feb 2025 13:14:28 +0000 Subject: [PATCH 063/317] Added pallas kernel benchmarking script Signed-off-by: Akshat Tripathi --- bmark_kernels.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 bmark_kernels.py diff --git a/bmark_kernels.py b/bmark_kernels.py new file mode 100644 index 000000000000..744c335dc449 --- /dev/null +++ b/bmark_kernels.py @@ -0,0 +1,47 @@ +import itertools +import pytest + +import jax +from jax import numpy as jnp +from vllm.lora.ops.xla_ops.pallas import _bgmv + +def create_tensors(T, D, L, N): + """ + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: jax.Array - shape (T, D) + loras: jax.Array - shape (N, 1, L, D) + idxs: jax.Array - shape (T, ) - all values must be in [0, N) + + ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T + """ + inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) + loras = jax.random.normal(jax.random.PRNGKey(0), (N, 1, L, D)) + idxs = jax.random.randint(jax.random.PRNGKey(0), + shape=(T, ), + minval=0, + maxval=N) + + + return inputs, loras, idxs + +# SEQ_LENS = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] +# HIDDEN_DIM = [1024, 2048, 3072, 4096] +# LORA_RANKS = [8, 16, 32, 64, 128, 256] +# N_LORAS = [1, 2, 4, 8, 16, 32] +SEQ_LENS = [16, 8192] +HIDDEN_DIM = [1024, 4096] +LORA_RANKS = [8, 256] +N_LORAS = [1, 32] + +@pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) +def test_bgmv_benchmark(benchmark, T, D, L, N): + inputs, loras, idxs = create_tensors(T, D, L, N) + + benchmark(_bgmv, inputs, loras, idxs) + From c32ed71759f5d0cfd28f1020be8f69512651a93b Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Feb 2025 13:34:13 +0000 Subject: [PATCH 064/317] Fixed mosaic kernel compilation issue Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 2889d21e774e..9c91f4b0df7d 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -33,7 +33,11 @@ def _(): @jax.jit -def _bgmv(inputs: jax.Array, loras: jax.Array, idxs: jax.Array): +def _bgmv( + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras: jax.Array # (N, 1, L, D) model dtype +) -> jax.Array: # (T, L) model dtype T, D = inputs.shape N, _, L, _ = loras.shape @@ -72,7 +76,7 @@ def _bgmv(inputs: jax.Array, loras: jax.Array, idxs: jax.Array): dimension_semantics=("parallel", "parallel", "arbitrary")))(idxs, inputs, loras)[:, :L] -def bgmv_shape_function(inputs, loras, idxs): +def bgmv_shape_function(idxs, inputs, loras): T, _ = inputs.shape _, _, L, _ = loras.shape @@ -81,4 +85,4 @@ def bgmv_shape_function(inputs, loras, idxs): def bgmv(inputs, loras, idxs): kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) - return kernel(inputs, loras, idxs) \ No newline at end of file + return kernel(idxs, inputs, loras) \ No newline at end of file From d8860a35d2bcbcd1e4e6ce172ac25d27637e42b0 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Feb 2025 13:43:17 +0000 Subject: [PATCH 065/317] Added reference kernel benchmarking Signed-off-by: Akshat Tripathi --- bmark_kernels.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/bmark_kernels.py b/bmark_kernels.py index 744c335dc449..23356ac152a1 100644 --- a/bmark_kernels.py +++ b/bmark_kernels.py @@ -3,7 +3,7 @@ import jax from jax import numpy as jnp -from vllm.lora.ops.xla_ops.pallas import _bgmv +from vllm.lora.ops.xla_ops.pallas import bgmv def create_tensors(T, D, L, N): """ @@ -30,18 +30,18 @@ def create_tensors(T, D, L, N): return inputs, loras, idxs -# SEQ_LENS = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] -# HIDDEN_DIM = [1024, 2048, 3072, 4096] -# LORA_RANKS = [8, 16, 32, 64, 128, 256] -# N_LORAS = [1, 2, 4, 8, 16, 32] -SEQ_LENS = [16, 8192] -HIDDEN_DIM = [1024, 4096] -LORA_RANKS = [8, 256] -N_LORAS = [1, 32] +def ref_bgmv(inputs, loras, idxs): + return jnp.einsum("td,__ld->tl", inputs, loras[idxs]) + +SEQ_LENS = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] +HIDDEN_DIM = [1024, 2048, 3072, 4096] +LORA_RANKS = [8, 16, 32, 64, 128, 256] +N_LORAS = [1, 2, 4, 8, 16, 32] + @pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) +@pytest.mark.parametrize("func", [bgmv, ref_bgmv]) def test_bgmv_benchmark(benchmark, T, D, L, N): inputs, loras, idxs = create_tensors(T, D, L, N) - benchmark(_bgmv, inputs, loras, idxs) - + benchmark.pedantic(ref_bgmv, args=(inputs, loras, idxs), rounds=10, warmup_rounds=5, iterations=10) From 0663384a9fc5bf236ce57ff357d0c71130f012de Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Feb 2025 14:46:40 +0000 Subject: [PATCH 066/317] Registered the custom op Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 4 ++-- vllm/lora/ops/xla_ops/pallas.py | 22 ++++++++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index cc541a8a8de5..8473180108fc 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from .pallas import bgmv +import vllm.lora.ops.xla_ops.pallas # Required to register the custom ops def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -42,7 +42,7 @@ def bgmv_shrink(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) - return scaling * bgmv(inputs, lora_b_weights, lora_indices_tensor) + return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) def bgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 9c91f4b0df7d..f7abbe5e187c 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -1,6 +1,7 @@ import functools -from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas -jax_import_guard() +import torch +from torch.library import impl +from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas, XLA_LIB import jax from jax.experimental import pallas as pl @@ -82,7 +83,20 @@ def bgmv_shape_function(idxs, inputs, loras): return [((T, L), inputs.dtype)] -def bgmv(inputs, loras, idxs): +XLA_LIB.define( + "bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", +) + +@impl(XLA_LIB, "bgmv", "XLA") +def bgmv_xla(inputs, loras, idxs): + jax_import_guard() kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) - return kernel(idxs, inputs, loras) \ No newline at end of file + return kernel(idxs, inputs, loras) + +@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") +def bgmv_non_xla(inputs, loras, idxs): + T, _ = inputs.shape + _, _, L, _ = loras.shape + + return torch.empty((T, L), device=inputs.device) From 76317b65ccdf20a28554e387e277082f55add4dc Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Feb 2025 14:53:06 +0000 Subject: [PATCH 067/317] Integrated bgmv kernel Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 8473180108fc..aced0aa34c69 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -8,15 +8,10 @@ def bgmv_expand(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + batch_size = outputs.size(0) + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -52,18 +47,10 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_size: int, add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - inputs = inputs.to(dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) - - batch_size, output_size, input_size = selected_loras.shape - - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) + batch_size = outputs.size(0) + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) outputs = torch.cat(( torch.zeros((batch_size, slice_offset), device=outputs.device), From 742dad0496b42e014839a8284b5cd7675e097933 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Feb 2025 15:31:36 +0000 Subject: [PATCH 068/317] Fixed model compilation bugs Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 4 ++-- vllm/lora/punica_wrapper/punica_tpu.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index aced0aa34c69..3ef86ea0854f 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -10,8 +10,8 @@ def bgmv_expand(inputs: torch.Tensor, add_inputs: bool = True): inputs = inputs.to(dtype=output_tensor.dtype) - batch_size = outputs.size(0) outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + batch_size = outputs.size(0) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -49,8 +49,8 @@ def bgmv_expand_slice(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) - batch_size = outputs.size(0) outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + batch_size = outputs.size(0) outputs = torch.cat(( torch.zeros((batch_size, slice_offset), device=outputs.device), diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 88458ed433f8..6c244913e5dd 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -22,6 +22,11 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) + + # PunicaWrapperBase defines some tensors with dtype=torch.int64, which isn't supported by the TPU. + # So convert those tensors to int32. + # Not all of them are used by the TPU so only convert the useful ones. + self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32) def shrink( self, From 43efe69c7fd1601a0b2163547fcd815ebed14947 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 25 Feb 2025 16:20:00 +0000 Subject: [PATCH 069/317] Minor changes Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 2 +- vllm/lora/punica_wrapper/punica_tpu.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index f7abbe5e187c..7c4716505743 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -91,7 +91,7 @@ def bgmv_shape_function(idxs, inputs, loras): def bgmv_xla(inputs, loras, idxs): jax_import_guard() kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) - + return kernel(idxs, inputs, loras) @impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 6c244913e5dd..2037b131488a 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -69,11 +69,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs) -> Optional[torch.Tensor]: """ - Performs GEMM for multiple slices of lora_a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. + Performs GEMM for multiple slices of lora_a. Semantics: for i in range(len(lora_a_stacked)): From 2eae3844299125d804402249977e1631c8d1babd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Feb 2025 11:01:07 -0800 Subject: [PATCH 070/317] [V1] Get input tokens from scheduler (#13339) Signed-off-by: Woosuk Kwon --- tests/v1/worker/test_gpu_model_runner.py | 1 + vllm/v1/core/scheduler.py | 43 +++-- vllm/v1/core/scheduler_output.py | 15 +- vllm/v1/worker/gpu_model_runner.py | 219 +++++++++++------------ 4 files changed, 139 insertions(+), 139 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 576d906fa749..c655b0fded6e 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -154,6 +154,7 @@ def test_update_states_request_resumed(model_runner): cached_req_data = CachedRequestData( req_id=req_id, resumed_from_preemption=False, + new_token_ids=[], new_block_ids=[], num_computed_tokens=0, ) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 82c4b307d48b..e5c60afeb492 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -121,6 +121,8 @@ def schedule(self) -> "SchedulerOutput": encoder_budget = self.max_num_encoder_input_tokens # Spec decode-related. scheduled_spec_decode_tokens: Dict[str, List[int]] = {} + + # For logging. scheduled_timestamp = time.monotonic() # First, schedule the RUNNING requests. @@ -187,6 +189,15 @@ def schedule(self) -> "SchedulerOutput": token_budget -= num_new_tokens req_index += 1 + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids[:num_scheduled_spec_tokens]) + # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( @@ -196,11 +207,6 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget - # Speculative decode related. - if request.spec_token_ids: - scheduled_spec_decode_tokens[ - request.request_id] = request.spec_token_ids - # Record the LoRAs in scheduled_running_reqs requested_loras: Set[int] = set() if self.lora_config: @@ -324,23 +330,24 @@ def schedule(self) -> "SchedulerOutput": # Construct the scheduler output. new_reqs_data = [ NewRequestData.from_request(req, - req_to_new_block_ids[req.request_id], - req.num_computed_tokens) + req_to_new_block_ids[req.request_id]) for req in scheduled_new_reqs ] resumed_reqs_data = [ self._make_cached_request_data( req, + num_scheduled_tokens[req.request_id], + len(scheduled_spec_decode_tokens.get(req.request_id, ())), req_to_new_block_ids[req.request_id], - req.num_computed_tokens, resumed_from_preemption=True, ) for req in scheduled_resumed_reqs ] running_reqs_data = [ self._make_cached_request_data( req, + num_scheduled_tokens[req.request_id], + len(scheduled_spec_decode_tokens.get(req.request_id, ())), req_to_new_block_ids[req.request_id], - req.num_computed_tokens, resumed_from_preemption=False, ) for req in scheduled_running_reqs ] @@ -349,8 +356,8 @@ def schedule(self) -> "SchedulerOutput": scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, - scheduled_encoder_inputs=scheduled_encoder_inputs, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. @@ -366,22 +373,28 @@ def schedule(self) -> "SchedulerOutput": def _make_cached_request_data( self, request: Request, + num_scheduled_tokens: int, + num_scheduled_spec_tokens: int, new_block_ids: List[int], - num_computed_tokens: int, resumed_from_preemption: bool, ) -> "CachedRequestData": # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. - if request.request_id in self._cached_reqs_data: - req_data = self._cached_reqs_data[request.request_id] + num_computed_tokens = request.num_computed_tokens + num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens + new_token_ids = request.all_token_ids[ + num_computed_tokens:num_computed_tokens + num_regular_tokens] + req_data = self._cached_reqs_data.get(request.request_id) + if req_data is not None: req_data.resumed_from_preemption = resumed_from_preemption + req_data.new_token_ids = new_token_ids req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens else: req_data = CachedRequestData.from_request(request, resumed_from_preemption, - new_block_ids, - num_computed_tokens) + new_token_ids, + new_block_ids) self._cached_reqs_data[request.request_id] = req_data return req_data diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index 2ca8526936e6..47413527c32f 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -30,7 +30,6 @@ def from_request( cls, request: "Request", block_ids: List[int], - num_computed_tokens: int, ) -> "NewRequestData": return cls( req_id=request.request_id, @@ -41,7 +40,7 @@ def from_request( mm_positions=request.mm_positions, sampling_params=request.sampling_params, block_ids=block_ids, - num_computed_tokens=num_computed_tokens, + num_computed_tokens=request.num_computed_tokens, lora_request=request.lora_request, ) @@ -54,6 +53,7 @@ class CachedRequestData: # the request's block IDs. If True, new_block_ids will be used as the # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool + new_token_ids: List[int] new_block_ids: List[int] num_computed_tokens: int @@ -62,14 +62,15 @@ def from_request( cls, request: "Request", resumed_from_preemption: bool, + new_token_ids: List[int], new_block_ids: List[int], - num_computed_tokens: int, ) -> "CachedRequestData": return cls( req_id=request.request_id, resumed_from_preemption=resumed_from_preemption, + new_token_ids=new_token_ids, new_block_ids=new_block_ids, - num_computed_tokens=num_computed_tokens, + num_computed_tokens=request.num_computed_tokens, ) @@ -91,9 +92,9 @@ class SchedulerOutput: # Total number of tokens scheduled for all requests. # Equal to sum(num_scheduled_tokens.values()) total_num_scheduled_tokens: int - # req_id -> spec_decode_tokens - # If a request does not have any spec decode tokens, it will - # not be included in the dictionary. + # req_id -> spec_token_ids + # If a request does not have any spec decode tokens, it will not be + # included in the dictionary. scheduled_spec_decode_tokens: Dict[str, List[int]] # req_id -> encoder input indices that need processing. # E.g., if a request has [0, 1], it could mean the vision encoder needs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f1212c3554b6..e1d1e43427b8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2,7 +2,7 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -184,7 +184,6 @@ def __init__( self.max_model_len, self.max_num_tokens), dtype=np.int32) - self.arange_cpu = torch.from_numpy(self.arange_np) # NOTE(woosuk): These tensors are "stateless", i.e., they are literally # a faster version of creating a new tensor every time. Thus, we should # not make any assumptions about the values in these tensors. @@ -327,7 +326,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: req_state = self.requests[req_id] # Update the cached states. - req_state.num_computed_tokens = req_data.num_computed_tokens + num_computed_tokens = req_data.num_computed_tokens + req_state.num_computed_tokens = num_computed_tokens + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec decode tokens. + num_new_tokens = (num_computed_tokens + + len(req_data.new_token_ids) - + req_state.num_tokens) + new_token_ids = (req_data.new_token_ids[-num_new_tokens:] + if num_new_tokens > 0 else []) + req_state.output_token_ids.extend(new_token_ids) + # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. req_state.block_ids.extend(req_data.new_block_ids) @@ -346,12 +355,30 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - start_index = len(req_state.block_ids) - len( - req_data.new_block_ids) + num_computed_tokens) + start_index = (len(req_state.block_ids) - + len(req_data.new_block_ids)) self.input_batch.block_table.append_row(req_index, start_index, req_data.new_block_ids) + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(req_data.new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, + start_token_index:end_token_index] = req_data.new_token_ids + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, []) + if spec_token_ids: + start_index = end_token_index + end_token_index += len(spec_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec decode tokens. + self.input_batch.num_tokens[req_index] = end_token_index + # Check if the batch has changed. If not, we can skip copying the + # sampling metadata from CPU to GPU. batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0 # Add the new or resumed requests to the persistent batch. @@ -374,7 +401,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: return batch_changed def _prepare_inputs( - self, scheduler_output: "SchedulerOutput" + self, + scheduler_output: "SchedulerOutput", ) -> Tuple[FlashAttentionMetadata, torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -387,24 +415,14 @@ def _prepare_inputs( # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. - num_scheduled_tokens_list: List[int] = [] + num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) max_num_scheduled_tokens = 0 - all_spec_token_ids: List[int] = [] - num_spec_tokens_list: List[int] = [] for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens_list.append(num_tokens) + num_scheduled_tokens[i] = num_tokens max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, []) - all_spec_token_ids.extend(spec_token_ids) - num_spec_tokens_list.append(len(spec_token_ids)) - - num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, - dtype=np.int32) - assert max_num_scheduled_tokens > 0 # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] @@ -441,78 +459,6 @@ def _prepare_inputs( token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) - use_spec_decode = len(all_spec_token_ids) > 0 - if use_spec_decode: - - # 1. Write spec_token_ids to input batch. - # Step 1. Get req indices that perform spec decode and repeat - # the req indices by the number of spec tokens. Note - # for requests that don't perform spec decode, the - # number of spec tokens is 0 and the req index is - # repeated 0 times. - # E.g., num_spec_tokens_list: [3, 0, 2, 0, 1] - # spec_req_indices: [0, 0, 0, 2, 2, 4] - spec_req_indices = np.repeat(self.arange_np[:num_reqs], - num_spec_tokens_list) - # spec_offsets: offsets within each spec token list. - # E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here - spec_offsets = np.concatenate( - [self.arange_np[1:val + 1] for val in num_spec_tokens_list]) - # spec_seq_offsets: offsets within each sequence. - # E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2] - # after repeating: [1, 1, 1, 3, 3, 2] - # spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1] - # = [2, 3, 4, 4, 5, 3] - spec_seq_offsets = np.repeat( - self.input_batch.num_computed_tokens_cpu[:num_reqs], - num_spec_tokens_list) + spec_offsets - # cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3] - cumsums_spec_offsets = ( - spec_seq_offsets + - spec_req_indices * self.input_batch.token_ids_cpu.shape[1]) - cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to( - torch.int64) - all_spec_token_ids = torch.tensor(all_spec_token_ids, - device="cpu", - dtype=self.input_ids_cpu.dtype) - - # Step 2. Write spec token ids to input_ids_cpu. - self.input_batch.token_ids_cpu_tensor.flatten().scatter_( - 0, cumsums_spec_offsets, all_spec_token_ids) - - # 2. Get spec decode logits indices. - # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2] - # cu_num_tokens: [4, 104, 107, 207, 209] - # num_spec_tokens_list: [3, 0, 2, 0, 1] - # num_sampled_tokens: [4, 1, 3, 1, 2] - # spec_decode_logits_indices: - # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32) - num_sampled_tokens = num_spec_tokens_np + 1 - # logits_start_loc: [0, 103, 104, 206, 207] - logits_start_loc = cu_num_tokens - num_sampled_tokens - # [0, 103, 104, 206, 207] -> - # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] - logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens) - # The following three lines: - # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11] - cu_num_sampled_tokens = np.cumsum(num_sampled_tokens) - # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9] - # -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - cumsums_sampled_offsets = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens) - # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - # - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - # -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - total_num_sampled_tokens = num_sampled_tokens.sum() - sampled_arange = (self.arange_np[:total_num_sampled_tokens] - - cumsums_sampled_offsets) - - # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] -> - # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - spec_decode_logits_indices = logits_start_loc + sampled_arange - # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. @@ -606,9 +552,11 @@ def _prepare_inputs( suffix_kv_lens=suffix_kv_lens, ) + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 if use_spec_decode: - logits_indices = torch.from_numpy(spec_decode_logits_indices).to( - self.device, non_blocking=True) + logits_indices = self._calc_spec_decode_metadata( + scheduler_output, cu_num_tokens) else: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token @@ -762,6 +710,53 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr += completion_part_len + def _calc_spec_decode_metadata( + self, + scheduler_output: "SchedulerOutput", + cu_num_tokens: np.ndarray, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Get the number of spec decode tokens for each request. + num_reqs = self.input_batch.num_reqs + num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32) + for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): + assert req_id is not None + num_spec_decode_tokens[i] = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + + # Get spec decode logits indices. + # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2] + # cu_num_tokens: [4, 104, 107, 207, 209] + # num_spec_tokens_list: [3, 0, 2, 0, 1] + # num_sampled_tokens: [4, 1, 3, 1, 2] + # spec_decode_logits_indices: + # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + num_sampled_tokens = num_spec_decode_tokens + 1 + # logits_start_loc: [0, 103, 104, 206, 207] + logits_start_loc = cu_num_tokens - num_sampled_tokens + # [0, 103, 104, 206, 207] -> + # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens) + # The following three lines: + # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11] + cu_num_sampled_tokens = np.cumsum(num_sampled_tokens) + # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9] + # -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + cumsums_sampled_offsets = np.repeat( + cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens) + # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + # - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + # -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + total_num_sampled_tokens = num_sampled_tokens.sum() + sampled_arange = (self.arange_np[:total_num_sampled_tokens] - + cumsums_sampled_offsets) + + # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] -> + # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + spec_decode_logits_indices = logits_start_loc + sampled_arange + return torch.from_numpy(spec_decode_logits_indices).to( + self.device, non_blocking=True) + def _prepare_sampling( self, batch_changed: bool, @@ -773,7 +768,9 @@ def _prepare_sampling( for req_id, req in self.requests.items()} sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, req_to_spec_token_ids, not batch_changed) + req_id_output_token_ids, + req_to_spec_token_ids, + skip_copy=not batch_changed) return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): @@ -960,28 +957,24 @@ def execute_model( # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. num_reqs = self.input_batch.num_reqs - request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] + req_ids: List[str] = [] + # Because `input_batch.req_ids` is a list of length `max_num_reqs`, + # we need to stop at `num_reqs`. + # FIXME(woosuk): This is hacky. Refactor. for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None + req_ids.append(req_id) req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) - if seq_len >= req_state.num_tokens: - request_seq_lens.append((i, req_state, seq_len)) - else: - # Ignore the sampled token from the partial request. + if seq_len < req_state.num_tokens: + # Ignore the sampled token. # Rewind the generator state as if the token was not sampled. generator = self.input_batch.generators.get(i) if generator is not None: # This relies on cuda-specific torch-internal impl details generator.set_offset(generator.get_offset() - 4) - # num_reqs entries should be non-None - assert all( - req_id is not None for req_id in - self.input_batch.req_ids[:num_reqs]), "req_ids contains None" - req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors @@ -994,29 +987,21 @@ def execute_model( scheduler_output, ) - # Update batch with the valid generated tokens. + # Get the valid generated tokens. sampled_token_ids = sampler_output.sampled_token_ids max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: + # No spec decode tokens. valid_sampled_token_ids = sampled_token_ids.tolist() - for i, req_state, seq_len in request_seq_lens: - token_id = valid_sampled_token_ids[i][0] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - req_state.output_token_ids.append(token_id) - self.input_batch.num_tokens[i] += 1 else: + # Includes spec decode tokens. valid_mask = sampled_token_ids != INVALID_TOKEN_ID gen_lens = valid_mask.sum(dim=1).tolist() + # TODO(woosuk): Optimize this. valid_sampled_token_ids = [ seq.tolist() for seq in sampled_token_ids[valid_mask].split(gen_lens) ] - self.input_batch.num_tokens[:num_reqs] += gen_lens - for i, req_state, seq_len in request_seq_lens: - target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) - self.input_batch.token_ids_cpu[ - i, target_slice] = valid_sampled_token_ids[i] - req_state.output_token_ids.extend(valid_sampled_token_ids[i]) model_runner_output = ModelRunnerOutput( req_ids=req_ids, From 70837d670ed58d7cbcb4f5ac597e279613bf1384 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 17 Feb 2025 13:37:45 -0800 Subject: [PATCH 071/317] [V1][PP] Fix intermediate tensor values (#13417) Signed-off-by: Cody Yu --- vllm/sequence.py | 3 +++ vllm/v1/worker/gpu_model_runner.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 98578ee04d58..45d0e5bc7680 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1137,6 +1137,9 @@ def __getitem__(self, key: Union[str, slice]): def __setitem__(self, key: str, value: torch.Tensor): self.tensors[key] = value + def items(self): + return self.tensors.items() + def __len__(self): return len(self.tensors) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e1d1e43427b8..1119d53b493c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -151,7 +151,8 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) - # self.intermediate_tensors # Set after load_model + # None in the first PP rank. The rest are set after load_model. + self.intermediate_tensors: Optional[IntermediateTensors] = None # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -922,6 +923,11 @@ def execute_model( if get_pp_group().is_first_rank: intermediate_tensors = None else: + assert intermediate_tensors is not None + assert self.intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + self.intermediate_tensors[k][:num_input_tokens].copy_( + v[:num_input_tokens], non_blocking=True) intermediate_tensors = IntermediateTensors({ k: v[:num_input_tokens] for k, v in self.intermediate_tensors.items() @@ -1120,7 +1126,7 @@ def _dummy_run( if get_pp_group().is_first_rank: intermediate_tensors = None else: - if not hasattr(self, "intermediate_tensors"): + if self.intermediate_tensors is None: self.intermediate_tensors = ( self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, From 0bbf7dbe734c334631290c52e988ca2f060a9b39 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Feb 2025 15:40:12 -0800 Subject: [PATCH 072/317] [V1][Spec decode] Move drafter to model runner (#13363) Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 7 ++++ vllm/v1/core/scheduler.py | 11 +++---- vllm/v1/engine/core.py | 30 ----------------- vllm/v1/outputs.py | 3 ++ vllm/v1/request.py | 12 ------- vllm/v1/spec_decode/ngram_proposer.py | 23 ++++++++----- vllm/v1/worker/gpu_input_batch.py | 7 ++++ vllm/v1/worker/gpu_model_runner.py | 47 +++++++++++++++++++++++++++ vllm/v1/worker/tpu_model_runner.py | 1 + 9 files changed, 84 insertions(+), 57 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index e39a7f9f40bd..eb730973c946 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -203,6 +203,7 @@ def test_schedule_partial_requests(): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[0] for _ in range(len(requests))], + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, ) @@ -259,6 +260,7 @@ def test_stop_via_update_from_output(): sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]], # First request hits EOS, second continues + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}) @@ -307,6 +309,7 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}) @@ -354,6 +357,7 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}) @@ -394,6 +398,7 @@ def test_stop_via_update_from_output(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}) @@ -434,6 +439,7 @@ def test_schedule_concurrent_batches(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, ) @@ -450,6 +456,7 @@ def test_schedule_concurrent_batches(): req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[0]], + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, ) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index e5c60afeb492..8f10834251c1 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -474,6 +474,7 @@ def update_from_output( model_runner_output: "ModelRunnerOutput", ) -> EngineCoreOutputs: sampled_token_ids = model_runner_output.sampled_token_ids + spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens @@ -530,13 +531,9 @@ def update_from_output( self.encoder_cache_manager.free_encoder_input( request, input_id) - if request.num_computed_tokens >= request.num_tokens: - # Clear the spec tokens as the request has generated - # a new token. Here, We assume all spec tokens are verified - # if we perform speculative decoding for this request. - # Therefore, we can clear all spec tokens after - # the generation step. - request.clear_spec_tokens() + # Add newly generated spec token ids to the request. + if spec_token_ids is not None: + request.spec_token_ids = spec_token_ids[req_index] # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c7ea7b1a94d8..6718a5f7b02d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -27,7 +27,6 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder -from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -86,15 +85,6 @@ def __init__( self.batch_queue_size) self.batch_queue = queue.Queue(self.batch_queue_size) - # Setup speculative decode. - # TODO: find a better way to check if we are using ngram. - self.use_spec_decode = False - if self.scheduler.speculative_config: - assert self.scheduler.speculative_config.ngram_prompt_lookup_min \ - , "Only ngram spec decode is supported in V1." - self.proposer = NgramProposer() - self.use_spec_decode = True - def _initialize_kv_caches(self, vllm_config: VllmConfig) -> Tuple[int, int]: start = time.time() @@ -158,9 +148,6 @@ def step(self) -> EngineCoreOutputs: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) - if self.use_spec_decode: - self.propose_tokens() - scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( @@ -221,23 +208,6 @@ def shutdown(self): def profile(self, is_start: bool = True): self.model_executor.profile(is_start) - def propose_tokens(self): - assert self.scheduler.speculative_config is not None - for req in self.scheduler.running: - # Ignore requests that are doing chunked prefill. - if req.num_computed_tokens < req.num_tokens - 1: - continue - # Ignore requests that already have spec tokens. - if req.spec_token_ids: - continue - spec_tokens = self.proposer.propose( - req.all_token_ids, - self.scheduler.speculative_config.ngram_prompt_lookup_min, - self.scheduler.speculative_config.num_speculative_tokens, - ) - if spec_tokens: - req.append_spec_token_ids(spec_tokens) - def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index fb6c4051e9a6..0c8eca38ade7 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -67,6 +67,9 @@ class ModelRunnerOutput: # each request due to speculative/jump decoding. sampled_token_ids: List[List[int]] + # num_reqs x num_spec_tokens + spec_token_ids: Optional[List[List[int]]] + # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] # [num_reqs] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index a1bcc2d0393c..52d7faeeb066 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -104,18 +104,6 @@ def append_output_token_ids( self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) - def append_spec_token_ids( - self, - token_ids: Union[int, List[int]], - ) -> None: - if isinstance(token_ids, int): - self.spec_token_ids.append(token_ids) - else: - self.spec_token_ids.extend(token_ids) - - def clear_spec_tokens(self) -> None: - self.spec_token_ids.clear() - @property def num_tokens(self) -> int: return len(self._all_token_ids) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 8eee99506b1f..9b116e00af97 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from typing import List, Optional -from vllm.v1.utils import ConstantList +import numpy as np class NgramProposer: @@ -9,8 +9,12 @@ class NgramProposer: def __init__(self): pass - def propose(self, context_token_ids: ConstantList[int], n: int, - k: int) -> Optional[List[int]]: + def propose( + self, + context_token_ids: np.ndarray, + n: int, + k: int, + ) -> Optional[np.ndarray]: """Proposes the next sequence of tokens based on n-gram pattern matching in the context. The function finds matches of the last n tokens in the previous context, and returns k tokens that followed @@ -25,8 +29,8 @@ def propose(self, context_token_ids: ConstantList[int], n: int, the maximum amount of tokens until the end. Returns: - List[int]: The sequence of tokens that followed - the matched n-gram in the context. + np.ndarray: The sequence of tokens that followed + the matched n-gram in the context. None: If no matching n-gram pattern is found. Example: @@ -66,9 +70,12 @@ def _kmp_lps_array(pattern: List[int]) -> List[int]: return lps @staticmethod - def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int, - k: int) -> Optional[List[int]]: - context_len = len(context_token_ids) + def _find_subarray_kmp( + context_token_ids: np.ndarray, + n: int, + k: int, + ) -> Optional[np.ndarray]: + context_len = context_token_ids.shape[0] assert n > 0 pattern = context_token_ids[-n:] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 805d8f618d2e..cb7411a44e2f 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -78,6 +78,7 @@ def __init__( ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) @@ -217,7 +218,11 @@ def add_request( end_idx = start_idx + len(request.output_token_ids) self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids + # Number of token ids in token_ids_cpu. + # NOTE(woosuk): This may include spec decode tokens. self.num_tokens[req_index] = request.num_tokens + # Number of tokens without spec decode tokens. + self.num_tokens_no_spec[req_index] = request.num_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens self.block_table.add_row(req_index, request.block_ids) @@ -356,6 +361,8 @@ def condense(self, empty_req_indices: List[int]) -> None: self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ last_req_index, :num_tokens] self.num_tokens[empty_index] = num_tokens + self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ + last_req_index] self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ last_req_index] self.num_computed_tokens_cpu[ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1119d53b493c..5754422cb1f7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -33,6 +33,7 @@ from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID +from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -117,6 +118,15 @@ def __init__( # req_id -> (input_id -> encoder_output) self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} + # Set up speculative decoding. + self.use_spec_decode = False + if self.speculative_config: + # TODO: find a better way to check if we are using ngram. + assert self.speculative_config.ngram_prompt_lookup_min, \ + "Currently, only ngram spec decode is supported in V1." + self.drafter = NgramProposer() + self.use_spec_decode = True + # Request states. self.requests: Dict[str, CachedRequestState] = {} # Persistent batch. @@ -367,6 +377,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.input_batch.token_ids_cpu[ req_index, start_token_index:end_token_index] = req_data.new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( req_id, []) @@ -1009,15 +1020,51 @@ def execute_model( for seq in sampled_token_ids[valid_mask].split(gen_lens) ] + if not self.use_spec_decode: + spec_token_ids = None + else: + spec_token_ids = self.generate_draft_token_ids( + valid_sampled_token_ids) + model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, ) return model_runner_output + def generate_draft_token_ids( + self, + sampled_token_ids: List[List[int]], + ) -> List[List[int]]: + # TODO(woosuk): Optimize. + num_reqs = len(sampled_token_ids) + draft_token_ids: List[List[int]] = [] + for i in range(num_reqs): + if len(sampled_token_ids[i]) == 0: + # Skip speculative decoding. + draft_token_ids.append([]) + continue + + # Add sampled_token_ids to token_ids_cpu. + start_idx = self.input_batch.num_tokens_no_spec[i] + end_idx = start_idx + len(sampled_token_ids[i]) + self.input_batch.token_ids_cpu[ + i, start_idx:end_idx] = sampled_token_ids[i] + drafter_output = self.drafter.propose( + self.input_batch.token_ids_cpu[i, :end_idx], + self.speculative_config.ngram_prompt_lookup_min, + self.speculative_config.num_speculative_tokens, + ) + if drafter_output is None or len(drafter_output) == 0: + draft_token_ids.append([]) + else: + draft_token_ids.append(drafter_output.tolist()) + return draft_token_ids + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 18362d86cf3b..255c6cef2f30 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -718,6 +718,7 @@ def execute_model( req_ids=all_req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[[token_id] for token_id in sampled_token_ids], + spec_token_ids=None, logprobs=None, prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type] ) From b7b9248f3113443214f1045c9a59d63b7d48975d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 17 Feb 2025 19:32:48 -0500 Subject: [PATCH 073/317] [Bugfix][CI][V1] Work around V1 + CUDA Graph + torch._scaled_mm fallback issue (#13425) Signed-off-by: Tyler Michael Smith --- .../schemes/compressed_tensors_w8a8_fp8.py | 6 ++++-- .../layers/quantization/fbgemm_fp8.py | 4 +++- vllm/model_executor/layers/quantization/fp8.py | 6 ++++-- .../layers/quantization/utils/w8a8_utils.py | 14 ++++++++------ 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 5dcc41a9e5da..32072e9fa570 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -9,8 +9,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, - requantize_with_max_scale) + apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity, + normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -93,6 +93,8 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): + maybe_create_device_identity() + output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 3bb8188f725c..20f2c3da600d 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -17,7 +17,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz) + apply_fp8_linear, maybe_create_device_identity, + normalize_e4m3fn_to_e4m3fnuz) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter) from vllm.platforms import current_platform @@ -84,6 +85,7 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + maybe_create_device_identity() weight_loader = extra_weight_attrs.get("weight_loader") del input_size, output_size output_size_per_partition = sum(output_partition_sizes) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f928ea7e23ca..fe8ff7ca5e12 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, apply_fp8_linear, convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported, - normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, - requantize_with_max_scale) + maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, + per_tensor_dequantize, requantize_with_max_scale) from vllm.model_executor.parameter import (BlockQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -162,6 +162,8 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + maybe_create_device_identity() + output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index bea6390f71ff..0f93b7f6c45b 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -9,7 +9,7 @@ # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale -TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) +TORCH_DEVICE_IDENTITY = None # The condition to determine if it is on a platform that supports # torch._scaled_mm rowwise feature. @@ -113,6 +113,13 @@ def requantize_with_max_scale( return max_w_scale, weight +def maybe_create_device_identity(): + # Allocate dummy ones tensor for torch._scaled_mm + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY is None: + TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) + + def apply_fp8_linear( input: torch.Tensor, weight: torch.Tensor, @@ -215,11 +222,6 @@ def apply_fp8_linear( # For the scaled_mm fallback case, we break this down, since it # does not support s_w being a vector. - # Making sure the dummy tensor is on the same device as the weight - global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY.device != weight.device: - TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) - # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place From f551ab584f855b82c82e9e418d8692813544841e Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 18 Feb 2025 03:52:35 +0000 Subject: [PATCH 074/317] [Misc] Remove dangling references to `SamplingType.BEAM` (#13402) --- vllm/model_executor/layers/sampler.py | 78 --------------------------- 1 file changed, 78 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 0fcb78691325..07ee75593f7b 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -68,7 +68,6 @@ class SampleResultArgsType: sample_results_dict: SampleResultsDictType sampling_metadata: SamplingMetadata greedy_samples: Optional[torch.Tensor] - beam_search_logprobs: Optional[torch.Tensor] # Union of non-deferred (single-step scheduling) @@ -510,74 +509,6 @@ def _random_sample( return results -def _beam_search_sample( - selected_seq_groups: List[SequenceGroupToSample], - logprobs: torch.Tensor, -) -> SampleResultType: - """Run beam sampling on a given samples. - - Args: - selected_seq_groups: A list of sequence groups batched. - logprobs: (num_selected_samples, vocab_size,) A tensor of logprob - on selected sample indices. - Returns: - Tuple of (next_token_ids, parent_ids). The length of returned list is - same as the length of selected_seq_groups. If the corresponding - seq_group has do_sample=False, tuple contains ([], []) - """ - # We sample 2 * beam_width candidates to make sure that with high - # probability we can get `beam_width` candidates in addition to - # the finished sequences for the next iteration. See - # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563 - # for details. See also HF reference: - # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065 - # - # NOTE: Beam search is not vectorized, so its speed can be slower than - # other sampling methods. - sample_idx = 0 - results: SampleResultType = [] - for seq_group in selected_seq_groups: - if not seq_group.do_sample: - results.append(([], [])) - continue - - is_prompt = seq_group.is_prompt - seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params - num_parent_seqs = len(seq_ids) - beam_width = sampling_params.n - seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] - if is_prompt: - # Prompt phase. - assert num_parent_seqs == 1, ( - "Prompt input should have only one seq.") - parent_ids = [0] * (2 * beam_width) - _, next_token_ids = torch.topk(seq_group_logprobs[0], - 2 * beam_width) - next_token_ids = next_token_ids.tolist() - else: - # Generation phase. - cumulative_logprobs: List[float] = [ - seq_group.seq_data[seq_id].cumulative_logprob - for seq_id in seq_ids - ] - cumulative_logprobs_tensor = torch.tensor( - cumulative_logprobs, - dtype=torch.float, - device=seq_group_logprobs.device) - seq_group_logprobs = (seq_group_logprobs + - cumulative_logprobs_tensor.unsqueeze(dim=1)) - _, topk_ids = torch.topk(seq_group_logprobs.flatten(), - 2 * beam_width) - topk_ids = topk_ids.tolist() - vocab_size = seq_group_logprobs.size(-1) - parent_ids = [i // vocab_size for i in topk_ids] - next_token_ids = [i % vocab_size for i in topk_ids] - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) - return results - - # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead. # Note that we always sample with replacement. @@ -666,14 +597,12 @@ def get_pythonized_sample_results( sampling_metadata, greedy_samples, multinomial_samples, - beam_search_logprobs, sample_results_dict, ) = ( sample_result_args.sample_metadata, sample_result_args.sampling_metadata, sample_result_args.greedy_samples, sample_result_args.multinomial_samples, - sample_result_args.beam_search_logprobs, sample_result_args.sample_results_dict, ) @@ -686,9 +615,6 @@ def get_pythonized_sample_results( elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): sample_results = _random_sample(seq_groups, multinomial_samples[sampling_type]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) sample_results_dict.update(zip(seq_group_id, sample_results)) return [ @@ -731,7 +657,6 @@ def _sample_with_torch( sample_metadata: SampleMetadataType = {} multinomial_samples: MultinomialSamplesType = {} greedy_samples: Optional[torch.Tensor] = None - beam_search_logprobs: Optional[torch.Tensor] = None # Create output tensor for sampled token ids. if include_gpu_probs_tensor: @@ -800,8 +725,6 @@ def _sample_with_torch( sampled_token_ids_tensor[long_sample_indices] = \ multinomial_samples[sampling_type].to(torch.long) - elif sampling_type == SamplingType.BEAM: - beam_search_logprobs = logprobs[sample_indices] else: raise ValueError(f"Unsupported sampling type: {sampling_type}") @@ -812,7 +735,6 @@ def _sample_with_torch( sample_metadata=sample_metadata, multinomial_samples=multinomial_samples, greedy_samples=greedy_samples, - beam_search_logprobs=beam_search_logprobs, sample_results_dict=sample_results_dict) if not sampling_metadata.skip_sampler_cpu_output: From d32dd01467f82f7d2823c6f6c8f8731d4a9b71a9 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 18 Feb 2025 11:52:47 +0800 Subject: [PATCH 075/317] [Model] Enable quantization support for `transformers` backend (#12960) --- docs/source/models/supported_models.md | 10 ++-- tests/models/test_transformers.py | 54 ++++++++++++++++++++-- vllm/model_executor/models/transformers.py | 25 ++++------ 3 files changed, 66 insertions(+), 23 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index b046ccfd1555..a1a28986b8a9 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -42,7 +42,7 @@ Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project ### Transformers fallback -After the merge of , `vllm` can fallback to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned! +`vllm` can fallback to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned! To check if the backend is `transformers`, you can simply do this: @@ -56,9 +56,13 @@ If it is `TransformersModel` then it means it's based on `transformers`! #### Supported features -##### LORA and quantization +##### Quantization -Both are not supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team! +Transformers fallback has supported most of available quantization in vLLM (except GGUF). See [Quantization page](#quantization-index) for more information about supported quantization in vllm. + +##### LoRA + +LoRA hasn't supported on transformers fallback yet! Make sure to open an issue and we'll work on this together with the `transformers` team! Usually `transformers` model load weights via the `load_adapters` API, that depends on PEFT. We need to work a bit to either use this api (for now this would result in some weights not being marked as loaded) or replace modules accordingly. diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 1d5d9729df85..31e3c1f7b987 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -45,10 +45,14 @@ def check_implementation( ("meta-llama/Llama-3.2-1B-Instruct", "transformers"), ("openai-community/gpt2", "transformers"), ("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE - ("meta-llama/Llama-3.2-1B-Instruct", "auto"), ]) # trust_remote_code=True by default -def test_models(hf_runner, vllm_runner, example_prompts, model, - model_impl) -> None: +def test_models( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + example_prompts: list[str], + model: str, + model_impl: str, +) -> None: maybe_raises = nullcontext() if model == "openai-community/gpt2" and model_impl == "transformers": @@ -67,10 +71,50 @@ def test_models(hf_runner, vllm_runner, example_prompts, model, @multi_gpu_test(num_gpus=2) def test_distributed( - hf_runner, - vllm_runner, + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], example_prompts, ): kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2} check_implementation(hf_runner, vllm_runner, example_prompts, "meta-llama/Llama-3.2-1B-Instruct", **kwargs) + + +@pytest.mark.parametrize("model, quantization_kwargs", [ + ( + "meta-llama/Llama-3.2-1B-Instruct", + { + "quantization": "bitsandbytes", + "load_format": "bitsandbytes", + }, + ), +]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_quantization( + vllm_runner: Type[VllmRunner], + example_prompts: list[str], + model: str, + quantization_kwargs: dict[str, str], + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner( + model, model_impl="auto", enforce_eager=True, + **quantization_kwargs) as vllm_model: # type: ignore[arg-type] + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs) + + with vllm_runner( + model, + model_impl="transformers", + enforce_eager=True, + **quantization_kwargs) as vllm_model: # type: ignore[arg-type] + transformers_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs) + check_logprobs_close( + outputs_0_lst=transformers_outputs, + outputs_1_lst=vllm_outputs, + name_0="transformers", + name_1="vllm", + ) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 1605467bc3dd..9b456b248952 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -28,6 +28,7 @@ from vllm.distributed.utils import divide from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -37,6 +38,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsQuant from .utils import maybe_prefix logger = init_logger(__name__) @@ -50,10 +52,10 @@ def vllm_flash_attention_forward( value: torch.Tensor, attention_mask: torch.Tensor, # Transformers kwargs - scaling: float = None, + scaling: Optional[float] = None, # vLLM kwargs - attn_metadata: AttentionMetadata = None, - attention_instances: list[Attention] = None, + attn_metadata: Optional[AttentionMetadata] = None, + attention_instances: Optional[list[Attention]] = None, **kwargs): self_attn = attention_instances[module.layer_idx] if scaling is not None: @@ -99,13 +101,7 @@ def replace_linear_class( vllm_linear_cls = { "colwise": ColumnParallelLinear, "rowwise": RowParallelLinear, - }.get(style) - - if vllm_linear_cls is None: - logger.warning( - "Unsupported parallel style value: %s. " - "This layer will not be tensor parallelized.", style) - return linear + }.get(style, ReplicatedLinear) class HFCompatibleLinear(vllm_linear_cls): """ @@ -119,10 +115,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input_size=linear.in_features, output_size=linear.out_features, bias=linear.bias is not None, + quant_config=quant_config, ) -class TransformersModel(nn.Module): +class TransformersModel(nn.Module, SupportsQuant): embedding_padding_modules = ["lm_head"] embedding_modules = ["embed_tokens" ] # TODO transformers will have a util to get it @@ -133,10 +130,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config self.config = config - self.quant_config = quant_config self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size @@ -162,7 +157,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: scale=config.head_dim**-0.5, num_kv_heads=divide(config.num_key_value_heads, tp_size), cache_config=cache_config, - quant_config=None, + quant_config=self.quant_config, prefix=f"{i}.attn") for i in range(config.num_hidden_layers) ] @@ -172,7 +167,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # ForCausalLM modifications self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=None, + quant_config=self.quant_config, prefix=maybe_prefix(prefix, "lm_head")) if config.tie_word_embeddings: self.lm_head.weight = self.model.get_input_embeddings().weight From cfcb3f28bb3200d676d15e54375c9fa127544256 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Mon, 17 Feb 2025 22:07:12 -0600 Subject: [PATCH 076/317] [ROCm] fix get_device_name for rocm (#13438) Signed-off-by: Divakar Verma --- vllm/platforms/rocm.py | 49 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 393b8a18527f..e506689dc33c 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from functools import lru_cache +import os +from functools import lru_cache, wraps from typing import TYPE_CHECKING, Dict, List, Optional import torch +from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles, + amdsmi_init, amdsmi_shut_down) import vllm.envs as envs from vllm.logger import init_logger @@ -53,6 +56,41 @@ "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") } +# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`` +if "HIP_VISIBLE_DEVICES" in os.environ: + val = os.environ["HIP_VISIBLE_DEVICES"] + if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None): + assert val == cuda_val + else: + os.environ["CUDA_VISIBLE_DEVICES"] = val + +# AMDSMI utils +# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. +# the major benefit of using AMDSMI is that it will not initialize CUDA + + +def with_amdsmi_context(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + amdsmi_init() + try: + return fn(*args, **kwargs) + finally: + amdsmi_shut_down() + + return wrapper + + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + class RocmPlatform(Platform): _enum = PlatformEnum.ROCM @@ -96,13 +134,12 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: return DeviceCapability(major=major, minor=minor) @classmethod + @with_amdsmi_context @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: - # NOTE: When using V1 this function is called when overriding the - # engine args. Calling torch.cuda.get_device_name(device_id) here - # will result in the ROCm context being initialized before other - # processes can be created. - return "AMD" + physical_device_id = device_id_to_physical_device_id(device_id) + handle = amdsmi_get_processor_handles()[physical_device_id] + return amdsmi_get_gpu_asic_info(handle)["market_name"] @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: From cabcc6e09aa8163300bbac8b92ec7d2835fa4df6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 18 Feb 2025 12:33:45 +0800 Subject: [PATCH 077/317] [v1] fix parallel config rank (#13445) Signed-off-by: youkaichao --- vllm/v1/worker/worker_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index bc7e76c38aed..51d2da2344b8 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -41,6 +41,7 @@ def __init__( # Configuration storage super().__init__(vllm_config=vllm_config) + self.parallel_config.rank = rank self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method From 79dd067631051fc8a3b50bce707c149d55298fef Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 00:34:47 -0500 Subject: [PATCH 078/317] [Quant] Molmo SupportsQuant (#13336) --- vllm/model_executor/models/molmo.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index feb585022317..b2154ef54af3 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -52,7 +52,8 @@ from vllm.sequence import IntermediateTensors from vllm.utils import JSONTree, json_map_leaves -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, + SupportsQuant) from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -633,7 +634,8 @@ def forward( return hidden_states, residual -class MolmoVisionBackbone(nn.Module): +class MolmoVisionBackbone(nn.Module, SupportsQuant): + packed_modules_mapping = {"merged_linear": ["gate_proj", "up_proj"]} def __init__( self, @@ -794,7 +796,7 @@ def load_weights(self, weights: Iterable[Tuple[str, @support_torch_compile -class MolmoModel(nn.Module): +class MolmoModel(nn.Module, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -1402,8 +1404,8 @@ def get_replacement_molmo(item_idx: int): @MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor, info=MolmoProcessingInfo, dummy_inputs=MolmoDummyInputsBuilder) -class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, - SupportsLoRA): +class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, + SupportsQuant): hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ # vision backbone mapping From 682099f5f907c8c5f82b74fbe9f4a7f9e8600b0b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 00:35:09 -0500 Subject: [PATCH 079/317] [Quant] Arctic SupportsQuant (#13366) --- vllm/model_executor/models/arctic.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index d015682aab47..27df448e63f7 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -33,7 +33,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.arctic import ArcticConfig -from .interfaces import SupportsPP +from .interfaces import SupportsPP, SupportsQuant from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -423,7 +423,8 @@ def forward( return hidden_states -class ArcticForCausalLM(nn.Module, SupportsPP): +class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() From ec2ec124c073e1680328ddab6cc90d311e6681fa Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Tue, 18 Feb 2025 00:43:31 -0500 Subject: [PATCH 080/317] [Bugfix] Only print out chat template when supplied (#13444) --- vllm/entrypoints/openai/api_server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ad391d6737bf..da5383e790f5 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -797,7 +797,9 @@ async def init_app_state( state.log_stats = not args.disable_log_stats resolved_chat_template = load_chat_template(args.chat_template) - logger.info("Using supplied chat template:\n%s", resolved_chat_template) + if resolved_chat_template is not None: + logger.info("Using supplied chat template:\n%s", + resolved_chat_template) state.openai_serving_models = OpenAIServingModels( engine_client=engine_client, From f87ec52ef55144b010c8bcf6f8e87aeeea4dc818 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 18 Feb 2025 13:48:10 +0800 Subject: [PATCH 081/317] [core] fix sleep mode in pytorch 2.6 (#13456) Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index f74ad9ac3385..7f63fc143787 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -9,7 +9,7 @@ # the only successful approach is to call cuda driver API in C. import dataclasses from contextlib import contextmanager -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -97,7 +97,7 @@ def use_memory_pool_with_allocator( new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) with torch.cuda.memory.use_mem_pool(mem_pool): - yield mem_pool + yield mem_pool, new_alloc class CuMemAllocator: @@ -142,6 +142,7 @@ def get_instance() -> "CuMemAllocator": def __init__(self): self.pointer_to_data: Dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag + self.allocator_and_pools: Dict[str, Any] = {} def python_malloc_callback(self, allocation_handle: HandleType) -> None: """ @@ -231,7 +232,13 @@ def use_memory_pool(self, tag: Optional[str] = None): old_tag = self.current_tag self.current_tag = tag with use_memory_pool_with_allocator(self.python_malloc_callback, - self.python_free_callback): + self.python_free_callback) as data: + # start to hit another PyTorch bug in PyTorch 2.6, + # possibly because of gc-related issue w.r.t. the allocator and + # the memory pool. + # to avoid the issue, we keep a reference of the data. + # see https://github.com/pytorch/pytorch/issues/146431 . + self.allocator_and_pools[tag] = data yield # PyTorch's bug, calling torch.cuda.empty_cache() will error # when using pluggable allocator, see From 6329db4a972f0bd9ada8e912046da401a840689e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 00:51:09 -0500 Subject: [PATCH 082/317] [Quant] Aria SupportsQuant (#13416) --- vllm/model_executor/models/aria.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 98df532aa0a8..df73a3b76b1f 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -36,7 +36,7 @@ from .idefics2_vision_model import ( Idefics2VisionTransformer as Idefics3VisionTransformer) # yapf: enable -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsQuant from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, is_pp_missing_parameter, maybe_prefix, @@ -53,7 +53,8 @@ class AriaImagePixelInputs(TypedDict): """ -class AriaVisionTransformer(Idefics3VisionTransformer): +class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} def __init__( self, @@ -304,11 +305,17 @@ def __init__( self.mlp = AriaTextMoELayer(config, quant_config=quant_config) -class AriaTextModel(LlamaModel): +class AriaTextModel(LlamaModel, SupportsQuant): """ Custom LlamaModel for the AriaMoE model which modifies the standard LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`. """ + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + "experts.w13_weight": ["experts.fc1.weight"], + "experts.w2_weight": ["experts.fc2.weight"], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, From 770eab0295650788ec5367164dd2db921fa8ec75 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Feb 2025 21:58:06 -0800 Subject: [PATCH 083/317] [V1][PP] Fix & Pin Ray version in requirements-cuda.txt (#13436) Signed-off-by: Woosuk Kwon --- requirements-cuda.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 0e7217fb3769..44b56422e3ab 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -2,7 +2,7 @@ -r requirements-common.txt # Dependencies for NVIDIA GPUs -ray[default] >= 2.9 +ray[adag] == 2.41.0 # Required for pipeline parallelism in V1. torch == 2.5.1 torchaudio==2.5.1 # These must be updated alongside torch From 49d081902b66ef239f144ea27a1d584bc31b85d8 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 18 Feb 2025 01:49:41 -0500 Subject: [PATCH 084/317] Add outlines fallback when JSON schema has enum (#13449) Signed-off-by: mgoin --- tests/entrypoints/conftest.py | 41 +++++++++++++++++++ tests/entrypoints/llm/test_guided_generate.py | 41 +++++++++++++++++++ vllm/model_executor/guided_decoding/utils.py | 4 ++ 3 files changed, 86 insertions(+) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index b00e168db9d3..3b596ea3e6a0 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -141,6 +141,47 @@ def sample_definition_json_schema(): } +@pytest.fixture +def sample_enum_json_schema(): + return { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", + "pending"] # Literal values using enum + }, + "priority": { + "type": "string", + "enum": ["low", "medium", "high", "critical"] + }, + "category": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["bug", "feature", "improvement"] + }, + "severity": { + "type": "integer", + "enum": [1, 2, 3, 4, + 5] # Enum can also contain numbers + } + }, + "required": ["type", "severity"] + }, + "flags": { + "type": "array", + "items": { + "type": "string", + "enum": ["urgent", "blocked", "needs_review", "approved"] + } + } + }, + "required": ["status", "priority", "category", "flags"] + } + + @pytest.fixture def sample_guided_choice(): return [ diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 932a35a9950e..01d2c1709b49 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -146,6 +146,47 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm, schema=sample_definition_json_schema) +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_enum_json_completion(sample_enum_json_schema, llm, + guided_decoding_backend: str): + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_enum_json_schema, + backend=guided_decoding_backend)) + outputs = llm.generate(prompts=[ + "Create a bug report JSON that fits this schema: " + f"{sample_enum_json_schema}. Make it for a high priority critical bug." + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, + schema=sample_enum_json_schema) + + # Additional assertions to verify enum values + assert output_json["status"] in ["active", "inactive", "pending"] + assert output_json["priority"] in ["low", "medium", "high", "critical"] + assert output_json["category"]["type"] in [ + "bug", "feature", "improvement" + ] + assert output_json["category"]["severity"] in [1, 2, 3, 4, 5] + for flag in output_json["flags"]: + assert flag in ["urgent", "blocked", "needs_review", "approved"] + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) def test_guided_choice_completion(sample_guided_choice, llm, diff --git a/vllm/model_executor/guided_decoding/utils.py b/vllm/model_executor/guided_decoding/utils.py index 87ef45358457..c3c0378ea952 100644 --- a/vllm/model_executor/guided_decoding/utils.py +++ b/vllm/model_executor/guided_decoding/utils.py @@ -14,6 +14,10 @@ def check_object(obj: dict) -> bool: if "pattern" in obj: return True + # Check for enum restrictions + if "enum" in obj: + return True + # Check for numeric ranges if obj.get("type") in ("integer", "number") and any( key in obj for key in [ From 13c152e2afc2418c5ecf3ca7b4d9101843fb5a46 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Tue, 18 Feb 2025 03:19:15 -0500 Subject: [PATCH 085/317] [Bugfix] Ensure LoRA path from the request can be included in err msg (#13450) Signed-off-by: Yuan Tang --- vllm/lora/worker_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index f33a7b88cc35..b103acefe4aa 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -133,7 +133,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: # For NotFoundError raise ValueError( f"Loading lora {lora_request.lora_name} failed: No adapter " - f"found for {lora_path}") from e + f"found for {lora_request.lora_path}") from e except Exception as e: # For BadRequestError raise e From 64ebfa364b650a0cd8f6689f2665459f2b64450c Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 18 Feb 2025 18:25:53 +0800 Subject: [PATCH 086/317] [Bugfix] Fix failing transformers dynamic module resolving with spawn multiproc method (#13403) Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/model_loader/utils.py | 25 +++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index dc620d4984a7..9686231fb4bd 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -47,22 +47,31 @@ def resolve_transformers_fallback(model_config: ModelConfig, for i, arch in enumerate(architectures): if arch == "TransformersModel": continue - custom_module = None - auto_map = getattr(model_config.hf_config, "auto_map", None) - if auto_map is not None and "AutoModel" in auto_map: - custom_module = get_class_from_dynamic_module( - model_config.hf_config.auto_map["AutoModel"], - model_config.model) + auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", + None) or dict() + # Make sure that config class is always initialized before model class, + # otherwise the model class won't be able to access the config class, + # the expected auto_map should have correct order like: + # "auto_map": { + # "AutoConfig": "--", + # "AutoModel": "--", + # "AutoModelFor": "--", + # }, + auto_modules = { + name: get_class_from_dynamic_module(module, model_config.model) + for name, module in sorted(auto_map.items(), key=lambda x: x[0]) + } + custom_model_module = auto_modules.get("AutoModel") # TODO(Isotr0py): Further clean up these raises. # perhaps handled them in _ModelRegistry._raise_for_unsupported? if model_config.model_impl == ModelImpl.TRANSFORMERS: - if not is_transformers_impl_compatible(arch, custom_module): + if not is_transformers_impl_compatible(arch, custom_model_module): raise ValueError( f"The Transformers implementation of {arch} is not " "compatible with vLLM.") architectures[i] = "TransformersModel" if model_config.model_impl == ModelImpl.AUTO: - if not is_transformers_impl_compatible(arch, custom_module): + if not is_transformers_impl_compatible(arch, custom_model_module): raise ValueError( f"{arch} has no vLLM implementation and the Transformers " "implementation is not compatible with vLLM.") From 70fc779b33dd0bbccfa291a1cc8449a9a7c7b666 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 18 Feb 2025 10:52:39 +0000 Subject: [PATCH 087/317] [Doc]: Improve feature tables (#13224) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/source/_static/custom.css | 8 + docs/source/conf.py | 9 +- docs/source/features/compatibility_matrix.md | 151 +++++++++--------- .../quantization/supported_hardware.md | 94 +++++------ docs/source/models/pooling_models.md | 8 +- 5 files changed, 142 insertions(+), 128 deletions(-) create mode 100644 docs/source/_static/custom.css diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css new file mode 100644 index 000000000000..79bd2082b49e --- /dev/null +++ b/docs/source/_static/custom.css @@ -0,0 +1,8 @@ +.vertical-table-header th.head:not(.stub) { + writing-mode: sideways-lr; + white-space: nowrap; + max-width: 0; + p { + margin: 0; + } +} diff --git a/docs/source/conf.py b/docs/source/conf.py index f4e8c8b94910..84c9a27be3bf 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -78,8 +78,12 @@ 'use_repository_button': True, 'use_edit_page_button': True, } +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] html_js_files = ["custom.js"] +html_css_files = ["custom.css"] myst_url_schemes = { 'http': None, @@ -121,11 +125,6 @@ if os.path.exists(header_file): os.remove(header_file) -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ['_static'] - # Generate additional rst documentation here. def setup(app): diff --git a/docs/source/features/compatibility_matrix.md b/docs/source/features/compatibility_matrix.md index ee5db70c7d5c..6056ca0d366b 100644 --- a/docs/source/features/compatibility_matrix.md +++ b/docs/source/features/compatibility_matrix.md @@ -4,8 +4,14 @@ The tables below show mutually exclusive features and the support on some hardware. +The symbols used have the following meanings: + +- ✅ = Full compatibility +- 🟠 = Partial compatibility +- ❌ = No compatibility + :::{note} -Check the '✗' with links to see tracking issue for unsupported feature/hardware combination. +Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/hardware combination. ::: ## Feature x Feature @@ -29,6 +35,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar :header-rows: 1 :stub-columns: 1 :widths: auto +:class: vertical-table-header - * Feature * [CP](#chunked-prefill) @@ -48,7 +55,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * beam-search * guided dec - * [CP](#chunked-prefill) - * + * ✅ * * * @@ -66,7 +73,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * - * [APC](#automatic-prefix-caching) * ✅ - * + * ✅ * * * @@ -82,9 +89,9 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * * - * [LoRA](#lora-adapter) - * [✗](gh-pr:9057) * ✅ - * + * ✅ + * ✅ * * * @@ -102,7 +109,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * + * ✅ * * * @@ -118,9 +125,9 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar - * [SD](#spec_decode) * ✅ * ✅ - * ✗ + * ❌ + * ✅ * ✅ - * * * * @@ -138,7 +145,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * + * ✅ * * * @@ -150,13 +157,13 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * * - * pooling - * ✗ - * ✗ - * ✗ - * ✗ - * ✗ - * ✗ - * + * ❌ + * ❌ + * ❌ + * ❌ + * ❌ + * ❌ + * ✅ * * * @@ -167,14 +174,14 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * * - * enc-dec - * ✗ - * [✗](gh-issue:7366) - * ✗ - * ✗ - * [✗](gh-issue:7366) + * ❌ + * [❌](gh-issue:7366) + * ❌ + * ❌ + * [❌](gh-issue:7366) + * ✅ * ✅ * ✅ - * * * * @@ -190,9 +197,9 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * ✗ + * ❌ + * ✅ * ✅ - * * * * @@ -205,12 +212,12 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * [✗](gh-pr:8199) * ✅ - * ✗ + * ✅ + * ❌ + * ✅ * ✅ * ✅ - * * * * @@ -222,49 +229,49 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * ✗ + * ❌ + * ✅ + * ❌ + * ❌ * ✅ - * ✗ - * ✗ * ✅ * ✅ - * * * * * * - * multi-step - * ✗ + * ❌ * ✅ - * ✗ + * ❌ + * ✅ + * ❌ + * ✅ + * ❌ + * ❌ * ✅ - * ✗ * ✅ - * ✗ - * ✗ * ✅ - * [✗](gh-issue:8198) * ✅ - * * * * * - * mm * ✅ - * [✗](gh-pr:8348) - * [✗](gh-pr:7199) - * ? - * ? + * [🟠](gh-pr:8348) + * [🟠](gh-pr:4194) + * ❔ + * ❔ * ✅ * ✅ * ✅ * ✅ * ✅ * ✅ - * ? - * + * ❔ + * ✅ * * * @@ -273,16 +280,16 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * [✗](gh-issue:6137) + * [❌](gh-issue:6137) * ✅ - * ✗ + * ❌ * ✅ * ✅ * ✅ - * ? - * [✗](gh-issue:7968) + * ❔ + * [❌](gh-issue:7968) + * ✅ * ✅ - * * * - * beam-search @@ -290,35 +297,35 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * [✗](gh-issue:6137) + * [❌](gh-issue:6137) * ✅ - * ✗ + * ❌ * ✅ * ✅ * ✅ - * ? - * [✗](gh-issue:7968) - * ? + * ❔ + * [❌](gh-issue:7968) + * ❔ + * ✅ * ✅ - * * - * guided dec * ✅ * ✅ - * ? - * ? - * [✗](gh-issue:11484) + * ❔ + * ❔ + * [❌](gh-issue:11484) * ✅ - * ✗ - * ? + * ❌ + * ❔ * ✅ * ✅ * ✅ - * [✗](gh-issue:9893) - * ? + * [❌](gh-issue:9893) + * ❔ + * ✅ * ✅ * ✅ - * ::: (feature-x-hardware)= @@ -339,7 +346,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * CPU * AMD - * [CP](#chunked-prefill) - * [✗](gh-issue:2729) + * [❌](gh-issue:2729) * ✅ * ✅ * ✅ @@ -347,7 +354,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ - * [APC](#automatic-prefix-caching) - * [✗](gh-issue:3687) + * [❌](gh-issue:3687) * ✅ * ✅ * ✅ @@ -368,7 +375,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * [✗](gh-issue:8475) + * [❌](gh-issue:8475) * ✅ - * [SD](#spec_decode) * ✅ @@ -384,7 +391,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * ✗ + * ❌ * ✅ - * pooling * ✅ @@ -393,7 +400,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * ? + * ❔ - * enc-dec * ✅ * ✅ @@ -401,7 +408,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * ✗ + * ❌ - * mm * ✅ * ✅ @@ -432,15 +439,15 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ✅ - * ✗ - * ✗ + * ❌ + * ❌ - * multi-step * ✅ * ✅ * ✅ * ✅ * ✅ - * [✗](gh-issue:8477) + * [❌](gh-issue:8477) * ✅ - * best-of * ✅ diff --git a/docs/source/features/quantization/supported_hardware.md b/docs/source/features/quantization/supported_hardware.md index 555ed4ce4c8d..a5bd8caf77cd 100644 --- a/docs/source/features/quantization/supported_hardware.md +++ b/docs/source/features/quantization/supported_hardware.md @@ -20,93 +20,93 @@ The table below shows the compatibility of various quantization implementations * AWS Inferentia * Google TPU - * AWQ - * ✗ + * ❌ * ✅︎ * ✅︎ * ✅︎ * ✅︎ - * ✗ + * ❌ * ✅︎ * ✅︎ - * ✗ - * ✗ + * ❌ + * ❌ - * GPTQ * ✅︎ * ✅︎ * ✅︎ * ✅︎ * ✅︎ - * ✗ + * ❌ * ✅︎ * ✅︎ - * ✗ - * ✗ + * ❌ + * ❌ - * Marlin (GPTQ/AWQ/FP8) - * ✗ - * ✗ + * ❌ + * ❌ * ✅︎ * ✅︎ * ✅︎ - * ✗ - * ✗ - * ✗ - * ✗ - * ✗ + * ❌ + * ❌ + * ❌ + * ❌ + * ❌ - * INT8 (W8A8) - * ✗ + * ❌ * ✅︎ * ✅︎ * ✅︎ * ✅︎ - * ✗ - * ✗ + * ❌ + * ❌ * ✅︎ - * ✗ - * ✗ + * ❌ + * ❌ - * FP8 (W8A8) - * ✗ - * ✗ - * ✗ + * ❌ + * ❌ + * ❌ * ✅︎ * ✅︎ * ✅︎ - * ✗ - * ✗ - * ✗ - * ✗ + * ❌ + * ❌ + * ❌ + * ❌ - * AQLM * ✅︎ * ✅︎ * ✅︎ * ✅︎ * ✅︎ - * ✗ - * ✗ - * ✗ - * ✗ - * ✗ + * ❌ + * ❌ + * ❌ + * ❌ + * ❌ - * bitsandbytes * ✅︎ * ✅︎ * ✅︎ * ✅︎ * ✅︎ - * ✗ - * ✗ - * ✗ - * ✗ - * ✗ + * ❌ + * ❌ + * ❌ + * ❌ + * ❌ - * DeepSpeedFP * ✅︎ * ✅︎ * ✅︎ * ✅︎ * ✅︎ - * ✗ - * ✗ - * ✗ - * ✗ - * ✗ + * ❌ + * ❌ + * ❌ + * ❌ + * ❌ - * GGUF * ✅︎ * ✅︎ @@ -114,16 +114,16 @@ The table below shows the compatibility of various quantization implementations * ✅︎ * ✅︎ * ✅︎ - * ✗ - * ✗ - * ✗ - * ✗ + * ❌ + * ❌ + * ❌ + * ❌ ::: - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. -- "✅︎" indicates that the quantization method is supported on the specified hardware. -- "✗" indicates that the quantization method is not supported on the specified hardware. +- ✅︎ indicates that the quantization method is supported on the specified hardware. +- ❌ indicates that the quantization method is not supported on the specified hardware. :::{note} This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. diff --git a/docs/source/models/pooling_models.md b/docs/source/models/pooling_models.md index 9704ccee745c..764b67241999 100644 --- a/docs/source/models/pooling_models.md +++ b/docs/source/models/pooling_models.md @@ -28,10 +28,10 @@ The selected option sets the default pooler used to extract the final hidden sta - * Embedding (`embed`) * `LAST` * ✅︎ - * ✗ + * ❌ - * Classification (`classify`) * `LAST` - * ✗ + * ❌ * ✅︎ - * Sentence Pair Scoring (`score`) * \* @@ -39,8 +39,8 @@ The selected option sets the default pooler used to extract the final hidden sta * \* - * Reward Modeling (`reward`) * `ALL` - * ✗ - * ✗ + * ❌ + * ❌ ::: \*The default pooler is always defined by the model. From 053e304e924a61b805b8067d9bc535c05a9d5b7d Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 18 Feb 2025 19:15:48 +0800 Subject: [PATCH 088/317] [Bugfix] Remove noisy error logging during local model loading (#13458) --- vllm/transformers_utils/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 360b457a19a8..2fed5d743e8e 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -504,8 +504,7 @@ def get_sentence_transformer_tokenizer_config(model: str, repo_files = list_repo_files(model, revision=revision, token=HF_TOKEN) - except Exception as e: - logger.error("Error getting repo files", e) + except Exception: repo_files = [] for config_name in sentence_transformer_config_files: From 151fdfc427517f6898e0f673ededaf5ee6787daf Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 18 Feb 2025 19:15:56 +0800 Subject: [PATCH 089/317] [ROCm] Make amdsmi import optional for other platforms (#13460) --- vllm/platforms/rocm.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index e506689dc33c..a4f18cbfc587 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -5,8 +5,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional import torch -from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles, - amdsmi_init, amdsmi_shut_down) import vllm.envs as envs from vllm.logger import init_logger @@ -20,6 +18,12 @@ logger = init_logger(__name__) +try: + from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles, + amdsmi_init, amdsmi_shut_down) +except ImportError as e: + logger.warning("Failed to import from amdsmi with %r", e) + try: import vllm._C # noqa: F401 except ImportError as e: From 7482ce2c9485a63f788bcb9447bf7f0238fc85ce Mon Sep 17 00:00:00 2001 From: zifeitong Date: Tue, 18 Feb 2025 03:29:13 -0800 Subject: [PATCH 090/317] [Bugfix] Handle content type with optional parameters (#13383) Signed-off-by: Zifei Tong --- vllm/entrypoints/openai/api_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index da5383e790f5..0de7e2392691 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -258,7 +258,8 @@ def _cleanup_ipc_path(): async def validate_json_request(raw_request: Request): content_type = raw_request.headers.get("content-type", "").lower() - if content_type != "application/json": + media_type = content_type.split(";", maxsplit=1)[0] + if media_type != "application/json": raise HTTPException( status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, detail="Unsupported Media Type: Only 'application/json' is allowed" From bbfee9a5a1a65081371fad9145bb928853ec5d3a Mon Sep 17 00:00:00 2001 From: Liangfu Chen Date: Tue, 18 Feb 2025 03:52:03 -0800 Subject: [PATCH 091/317] [Bugfix] Fix invalid rotary embedding unit test (#13431) Signed-off-by: Liangfu Chen --- tests/kernels/test_rotary_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_rotary_embedding.py b/tests/kernels/test_rotary_embedding.py index 362bcb35ceab..c497dd90edda 100644 --- a/tests/kernels/test_rotary_embedding.py +++ b/tests/kernels/test_rotary_embedding.py @@ -41,7 +41,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, is_neox_style, rotary_dim, head_size, seq_len): batch_size = 1 - base = 0 + base = 10000 num_heads = 7 rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, torch.float32) From 168a9ff6f3fa9ac068a26820c3515facb5c18be4 Mon Sep 17 00:00:00 2001 From: Daniele <36171005+dtrifiro@users.noreply.github.com> Date: Tue, 18 Feb 2025 17:02:49 +0100 Subject: [PATCH 092/317] [CI/Build] migrate static project metadata from setup.py to pyproject.toml (#8772) --- pyproject.toml | 36 ++++++++++++++++++++++++++++++++++- setup.py | 51 ++++---------------------------------------------- 2 files changed, 39 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 96d4aa149abf..ac155116ccde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,8 +12,42 @@ requires = [ ] build-backend = "setuptools.build_meta" +[project] +name = "vllm" +authors = [{name = "vLLM Team"}] +license = { "file"= "LICENSE" } +readme = "README.md" +description = "A high-throughput and memory-efficient inference and serving engine for LLMs" +classifiers = [ + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "License :: OSI Approved :: Apache Software License", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Information Analysis", +] +requires-python = ">=3.9" +dynamic = [ "version", "dependencies", "optional-dependencies"] + +[project.urls] +Homepage="https://github.com/vllm-project/vllm" +Documentation="https://vllm.readthedocs.io/en/latest/" +Slack="http://slack.vllm.ai/" + +[project.scripts] +vllm = "vllm.entrypoints.cli.main:main" + [tool.setuptools_scm] -# version_file = "vllm/_version.py" # currently handled by `setup.py:get_version()` +version_file = "vllm/_version.py" + +[tool.setuptools.packages.find] +where = ["."] +exclude = ["benchmarks", "csrc", "docs", "examples", "tests*"] +namespaces = false [tool.yapfignore] ignore_patterns = [ diff --git a/setup.py b/setup.py index 7243a2ab30aa..d09ae4b3810d 100755 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ import torch from packaging.version import Version, parse -from setuptools import Extension, find_packages, setup +from setuptools import Extension, setup from setuptools.command.build_ext import build_ext from setuptools_scm import get_version from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME @@ -499,9 +499,7 @@ def get_gaudi_sw_version(): def get_vllm_version() -> str: - version = get_version( - write_to="vllm/_version.py", # TODO: move this to pyproject.toml - ) + version = get_version() sep = "+" if "+" not in version else "." # dev versions might contain + if _no_device(): @@ -549,16 +547,6 @@ def get_vllm_version() -> str: return version -def read_readme() -> str: - """Read the README file if present.""" - p = get_path("README.md") - if os.path.isfile(p): - with open(get_path("README.md"), encoding="utf-8") as f: - return f.read() - else: - return "" - - def get_requirements() -> List[str]: """Get Python package dependencies from requirements.txt.""" @@ -649,36 +637,10 @@ def _read_requirements(filename: str) -> List[str]: } setup( - name="vllm", + # static metadata should rather go in pyproject.toml version=get_vllm_version(), - author="vLLM Team", - license="Apache 2.0", - description=("A high-throughput and memory-efficient inference and " - "serving engine for LLMs"), - long_description=read_readme(), - long_description_content_type="text/markdown", - url="https://github.com/vllm-project/vllm", - project_urls={ - "Homepage": "https://github.com/vllm-project/vllm", - "Documentation": "https://vllm.readthedocs.io/en/latest/", - }, - classifiers=[ - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "License :: OSI Approved :: Apache Software License", - "Intended Audience :: Developers", - "Intended Audience :: Information Technology", - "Intended Audience :: Science/Research", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Scientific/Engineering :: Information Analysis", - ], - packages=find_packages(exclude=("benchmarks", "csrc", "docs", "examples", - "tests*")), - python_requires=">=3.9", - install_requires=get_requirements(), ext_modules=ext_modules, + install_requires=get_requirements(), extras_require={ "tensorizer": ["tensorizer>=2.9.0"], "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"], @@ -687,9 +649,4 @@ def _read_requirements(filename: str) -> List[str]: }, cmdclass=cmdclass, package_data=package_data, - entry_points={ - "console_scripts": [ - "vllm=vllm.entrypoints.cli.main:main", - ], - }, ) From 80f3dc341fb7d902a19ad3382faf808938402b0d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Feb 2025 09:15:32 -0800 Subject: [PATCH 093/317] [V1][PP] Enable true PP with Ray executor (#13472) Signed-off-by: Woosuk Kwon --- vllm/v1/executor/ray_distributed_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index 53548610adf6..320ebfd37ae3 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -32,7 +32,7 @@ def max_concurrent_batches(self) -> int: """Ray distributed executor supports pipeline parallelism, meaning that it allows PP size batches to be executed concurrently. """ - return 1 #self.vllm_config.parallel_config.pipeline_parallel_size + return self.parallel_config.pipeline_parallel_size def execute_model( self, From 5b612c4123336aa4d502e25eaf6fe2070a2dc8cd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 19 Feb 2025 01:37:11 +0800 Subject: [PATCH 094/317] [misc] fix debugging code (#13487) Signed-off-by: youkaichao --- docs/source/getting_started/troubleshooting.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/getting_started/troubleshooting.md b/docs/source/getting_started/troubleshooting.md index 2f41fa3b6b19..92103e65bbbb 100644 --- a/docs/source/getting_started/troubleshooting.md +++ b/docs/source/getting_started/troubleshooting.md @@ -94,20 +94,20 @@ pynccl.disabled = False s = torch.cuda.Stream() with torch.cuda.stream(s): data.fill_(1) - pynccl.all_reduce(data, stream=s) - value = data.mean().item() + out = pynccl.all_reduce(data, stream=s) + value = out.mean().item() assert value == world_size, f"Expected {world_size}, got {value}" print("vLLM NCCL is successful!") g = torch.cuda.CUDAGraph() with torch.cuda.graph(cuda_graph=g, stream=s): - pynccl.all_reduce(data, stream=torch.cuda.current_stream()) + out = pynccl.all_reduce(data, stream=torch.cuda.current_stream()) data.fill_(1) g.replay() torch.cuda.current_stream().synchronize() -value = data.mean().item() +value = out.mean().item() assert value == world_size, f"Expected {world_size}, got {value}" print("vLLM NCCL with cuda graph is successful!") From d15cbbe2d29166f8b9d4e265b3640c8e8e13f820 Mon Sep 17 00:00:00 2001 From: Murali Andoorveedu <37849411+andoorve@users.noreply.github.com> Date: Tue, 18 Feb 2025 09:53:14 -0800 Subject: [PATCH 095/317] [V1][Tests] Adding additional testing for multimodal models to V1 (#13308) Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- tests/v1/engine/test_async_llm.py | 60 ++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 05197f44f93b..d864cb2af23e 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -8,7 +8,9 @@ from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG from vllm import SamplingParams +from vllm.assets.image import ImageAsset from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.inputs import PromptType from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.v1.engine.async_llm import AsyncLLM @@ -17,13 +19,32 @@ pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) -ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - disable_log_requests=True) +TEXT_ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + disable_log_requests=True) + +VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct", + enforce_eager=True, + disable_log_requests=True) + +TEXT_PROMPT = "Hello my name is Robert and" + +VISION_PROMPT_TEMPLATE = ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" + "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + "What is in the image?<|im_end|>\n" + "<|im_start|>assistant\n") +VISION_PROMPT = { + "prompt": VISION_PROMPT_TEMPLATE, + "multi_modal_data": { + "image": ImageAsset("stop_sign").pil_image + } +} async def generate(engine: AsyncLLM, request_id: str, + prompt: PromptType, output_kind: RequestOutputKind, max_tokens: int, prompt_logprobs: Optional[int] = None) -> Tuple[int, str]: @@ -32,11 +53,12 @@ async def generate(engine: AsyncLLM, count = 0 sampling_params = SamplingParams(max_tokens=max_tokens, + ignore_eos=True, output_kind=output_kind, temperature=0, prompt_logprobs=prompt_logprobs) async for out in engine.generate(request_id=request_id, - prompt="Hello my name is Robert and", + prompt=prompt, sampling_params=sampling_params): num_tokens = len(out.outputs[0].token_ids) @@ -74,6 +96,7 @@ async def test_async_llm_refuses_prompt_logprobs_with_apc( await asyncio.create_task( generate(engine, "request-0", + TEXT_PROMPT, output_kind, 10, prompt_logprobs=5)) @@ -86,18 +109,24 @@ async def test_async_llm_refuses_prompt_logprobs_with_apc( @pytest.mark.parametrize( "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.parametrize("engine_args_and_prompt", + [(TEXT_ENGINE_ARGS, TEXT_PROMPT), + (VISION_ENGINE_ARGS, VISION_PROMPT)]) @pytest.mark.asyncio -async def test_load(monkeypatch, output_kind: RequestOutputKind): +async def test_load(monkeypatch, output_kind: RequestOutputKind, + engine_args_and_prompt: Tuple[AsyncEngineArgs, + PromptType]): # TODO(rickyx): Remove monkeypatch once we have a better way to test V1 # so that in the future when we switch, we don't have to change all the # tests. with monkeypatch.context() as m, ExitStack() as after: m.setenv("VLLM_USE_V1", "1") + engine_args, prompt = engine_args_and_prompt - engine = AsyncLLM.from_engine_args(ENGINE_ARGS) + engine = AsyncLLM.from_engine_args(engine_args) after.callback(engine.shutdown) - NUM_REQUESTS = 10000 + NUM_REQUESTS = 100 NUM_EXPECTED_TOKENS = 10 request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] @@ -107,7 +136,7 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind): for request_id in request_ids: tasks.append( asyncio.create_task( - generate(engine, request_id, output_kind, + generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS))) # Confirm that we got all the EXPECTED tokens from the requests. @@ -126,13 +155,19 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind): @pytest.mark.parametrize( "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.parametrize("engine_args_and_prompt", + [(TEXT_ENGINE_ARGS, TEXT_PROMPT), + (VISION_ENGINE_ARGS, VISION_PROMPT)]) @pytest.mark.asyncio -async def test_abort(monkeypatch, output_kind: RequestOutputKind): +async def test_abort(monkeypatch, output_kind: RequestOutputKind, + engine_args_and_prompt: Tuple[AsyncEngineArgs, + PromptType]): with monkeypatch.context() as m, ExitStack() as after: m.setenv("VLLM_USE_V1", "1") + engine_args, prompt = engine_args_and_prompt - engine = AsyncLLM.from_engine_args(ENGINE_ARGS) + engine = AsyncLLM.from_engine_args(engine_args) after.callback(engine.shutdown) NUM_REQUESTS = 100 @@ -146,7 +181,7 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind): for request_id in request_ids: tasks.append( asyncio.create_task( - generate(engine, request_id, output_kind, + generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS))) # API server cancels requests when they disconnect. @@ -172,7 +207,8 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind): # Confirm we can do another generation. request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}" task = asyncio.create_task( - generate(engine, request_id, output_kind, NUM_EXPECTED_TOKENS)) + generate(engine, request_id, prompt, output_kind, + NUM_EXPECTED_TOKENS)) num_generated_tokens, request_id = await task assert num_generated_tokens == NUM_EXPECTED_TOKENS assert not engine.output_processor.has_unfinished_requests() From d9b7062a1107171da96c014ddf55415b30742ed0 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 18 Feb 2025 12:15:33 -0800 Subject: [PATCH 096/317] [V1] Optimize handling of sampling metadata and req_ids list (#13244) Signed-off-by: Nick Hill --- tests/v1/sample/test_rejection_sampler.py | 9 +- tests/v1/sample/test_sampler.py | 44 ++--- tests/v1/worker/test_gpu_input_batch.py | 47 +++-- tests/v1/worker/test_gpu_model_runner.py | 33 ++-- vllm/model_executor/layers/utils.py | 6 +- vllm/v1/core/scheduler.py | 6 +- vllm/v1/sample/metadata.py | 21 +-- vllm/v1/sample/ops/penalties.py | 13 +- vllm/v1/sample/ops/topk_topp_sampler.py | 48 ++--- vllm/v1/sample/rejection_sampler.py | 2 + vllm/v1/sample/sampler.py | 13 +- vllm/v1/utils.py | 11 ++ vllm/v1/worker/gpu_input_batch.py | 213 +++++++++++----------- vllm/v1/worker/gpu_model_runner.py | 85 +++------ vllm/v1/worker/tpu_model_runner.py | 2 - 15 files changed, 255 insertions(+), 298 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 8bc33e84194c..3e810e525e1c 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int], def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: batch_size = len(spec_tokens) return SamplingMetadata( - temperature=0.0, + temperature=torch.tensor([]), all_greedy=True, all_random=False, - rejection_sampling=True, spec_token_ids=spec_tokens, top_p=None, top_k=None, - no_top_p=False, - no_top_k=False, min_p=torch.empty(batch_size, ), - no_min_p=True, generators={}, max_num_logprobs=0, no_penalties=False, @@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: presence_penalties=torch.tensor([]), repetition_penalties=torch.tensor([]), output_token_ids=[], - min_tokens=[], - stop_token_ids=[], + min_tokens={}, logit_bias=[None] * batch_size, ) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index a4bd651f8224..3f6301c54267 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -77,25 +77,20 @@ def _create_default_sampling_metadata( temperature=torch.full((batch_size, ), 0.0), all_greedy=True, all_random=False, - rejection_sampling=False, - top_p=torch.empty(batch_size, ), - top_k=torch.empty(batch_size, ), - no_top_p=True, - no_top_k=True, - min_p=torch.empty(batch_size, ), - no_min_p=True, + top_p=None, + top_k=None, + min_p=None, generators={}, max_num_logprobs=0, prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, vocab_size, device), output_token_ids=output_token_ids, - spec_token_ids=[], + spec_token_ids=None, frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), no_penalties=True, - min_tokens=[], - stop_token_ids=[], + min_tokens={}, logit_bias=[None] * batch_size, ) return fake_sampling_metadata @@ -104,10 +99,10 @@ def _create_default_sampling_metadata( def _generate_min_token_penalties_and_stop_tokens( num_output_tokens: int, batch_size: int, vocab_size: int, batch_indices_for_min_token_penalty: List[int] -) -> Tuple[List[int], List[Set[int]]]: +) -> Dict[int, Tuple[int, Set[int]]]: """ - Generates and returns a list of minimum token penalties (`min_tokens`) - and a corresponding list of stop token IDs (`stop_token_ids`) for each + Generates and returns a dict of minimum token penalties and + corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each batch. If a batch index is included in `batch_indices_for_min_token_penalty`, @@ -115,22 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens( and a random set of stop token IDs is created. Otherwise, a lower `min_tokens` value is assigned, and the stop token IDs set is empty. """ - stop_token_ids: List[Set[int]] = [] - min_tokens: List[int] = [] + min_tokens: Dict[int, Tuple[int, Set[int]]] = {} for index in range(batch_size): if index in batch_indices_for_min_token_penalty: - min_tokens.append( + min_tokens[index] = ( np.random.randint(num_output_tokens + 1, - 2 * num_output_tokens)) - stop_token_ids.append( + 2 * num_output_tokens), set( np.random.randint(0, vocab_size - 1) for _ in range(np.random.randint(0, vocab_size)))) - else: - min_tokens.append(np.random.randint(0, num_output_tokens)) - stop_token_ids.append(set()) - return (min_tokens, stop_token_ids) + min_tokens[index] = (np.random.randint(0, + num_output_tokens), set()) + return min_tokens def _create_weighted_output_token_list( @@ -165,7 +157,7 @@ def _create_weighted_output_token_list( output_token_ids_for_batch.extend( [token_id for _ in range(index + 1)]) output_token_ids.append(output_token_ids_for_batch) - return (output_token_ids, sorted_token_ids_in_output) + return output_token_ids, sorted_token_ids_in_output @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int): NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) batch_indices_for_min_token_penalty = np.random.randint( 0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist() - min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens( + min_tokens = _generate_min_token_penalties_and_stop_tokens( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, batch_indices_for_min_token_penalty) sampling_metadata.min_tokens = min_tokens - sampling_metadata.stop_token_ids = stop_token_ids sampler = Sampler() logits = sampler.apply_penalties(fake_logits, sampling_metadata) logits = logits.cpu() for batch_idx in range(batch_size): for token_id in range(VOCAB_SIZE): - if token_id in stop_token_ids[batch_idx]: + _, stop_token_ids = min_tokens.get(batch_idx, (0, set())) + if token_id in stop_token_ids: assert logits[batch_idx][token_id] == -float("inf") else: assert logits[batch_idx][token_id] != -float("inf") diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index c0ab356f5c93..cb3b3d21fbb3 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple import numpy as np import pytest @@ -41,7 +41,7 @@ def _remove_requests( for index in req_indices_to_remove: input_batch.remove_request(reqs[index].req_id) req_ids_to_remove.add(reqs[index].req_id) - return (req_ids_to_remove, req_indices_to_remove_list) + return req_ids_to_remove, req_indices_to_remove_list def _construct_expected_sampling_metadata( @@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata( top_p = [0.0 for _ in range(num_reqs)] min_p = [0.0 for _ in range(num_reqs)] temperature = [0.0 for _ in range(num_reqs)] - stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)] - min_tokens = [0 for _ in range(num_reqs)] + min_tokens = {} logit_bias = [None] * num_reqs for req in reqs: if req.req_id not in req_ids_retained: @@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata( top_p[index_in_input_batch] = req.sampling_params.top_p min_p[index_in_input_batch] = req.sampling_params.min_p temperature[index_in_input_batch] = req.sampling_params.temperature - stop_token_ids[ - index_in_input_batch] = req.sampling_params.all_stop_token_ids - min_tokens[index_in_input_batch] = req.sampling_params.min_tokens + min_tokens[index_in_input_batch] = ( + req.sampling_params.min_tokens, + req.sampling_params.all_stop_token_ids) logit_bias[index_in_input_batch] = req.sampling_params.logit_bias return SamplingMetadata( temperature=torch.tensor(temperature, dtype=torch.float, device=device), all_greedy=False, all_random=True, - rejection_sampling=False, - top_p=torch.tensor(top_p, dtype=torch.float, device=device), - top_k=torch.tensor(top_k, dtype=torch.int, device=device), - no_top_p=all(x == 1.0 for x in top_p), - no_top_k=all(x == 0 for x in top_k), - min_p=torch.tensor(min_p, dtype=torch.float, device=device), - no_min_p=all(x == 0.0 for x in min_p), + top_p=None if all(x == 1.0 for x in top_p) else torch.tensor( + top_p, dtype=torch.float, device=device), + top_k=None if all(x == 0 for x in top_k) else torch.tensor( + top_k, dtype=torch.int, device=device), + min_p=None if all(x == 0.0 for x in min_p) else torch.tensor( + min_p, dtype=torch.float, device=device), generators={}, max_num_logprobs=0, prompt_token_ids=make_tensor_with_pad( @@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata( dtype=torch.float, device=device), output_token_ids=output_token_ids, - spec_token_ids=[], + spec_token_ids=None, min_tokens=min_tokens, - stop_token_ids=stop_token_ids, no_penalties=(all(x == 0 for x in presence_penalties) and all(x == 0 for x in frequency_penalties) and all(x == 1 for x in repetition_penalties)), @@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch.condense(req_indices_to_remove) # Generate the sampling metadata - sampling_metadata = input_batch.make_sampling_metadata( - req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False) + sampling_metadata = input_batch._make_sampling_metadata() # Create expected output. expected_sampling_metadata = _construct_expected_sampling_metadata( @@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch.req_id_to_index, device=torch.device(device)) + def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: + return (t1 is None + and t2 is None) or (t1 is not None and t2 is not None + and torch.allclose(t1, t2)) + # Assert the actual and expected output. assert torch.allclose(expected_sampling_metadata.temperature, sampling_metadata.temperature) - assert torch.allclose(expected_sampling_metadata.top_p, - sampling_metadata.top_p) - assert torch.allclose(expected_sampling_metadata.top_k, - sampling_metadata.top_k) + assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p) + assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k) assert torch.allclose( expected_sampling_metadata.frequency_penalties, sampling_metadata.frequency_penalties, @@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): assert (expected_sampling_metadata.output_token_ids == sampling_metadata.output_token_ids) assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens - assert expected_sampling_metadata.stop_token_ids == \ - sampling_metadata.stop_token_ids assert expected_sampling_metadata.no_penalties == \ sampling_metadata.no_penalties - assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p - assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index c655b0fded6e..973efcbf8e50 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -5,6 +5,7 @@ from vllm.sampling_params import SamplingParams from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, SchedulerOutput) +from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool: return req_id in model_runner.requests +def _is_sampling_metadata_changed(model_runner, + sampling_metadata_before: SamplingMetadata): + return model_runner.input_batch.sampling_metadata is not ( + sampling_metadata_before) + + def test_update_states_new_request(model_runner): req_id = "req_0" # new req scheduler_output = _schedule_new_request(req_id) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + metadata_before = model_runner.input_batch.sampling_metadata + model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) @@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + metadata_before = model_runner.input_batch.sampling_metadata + model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert not _is_req_added(model_runner, req_id) assert not _is_req_scheduled(model_runner, req_id) @@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner): scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=0, - finished_req_ids={}, + finished_req_ids=set(), free_encoder_input_ids=[], ) @@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + metadata_before = model_runner.input_batch.sampling_metadata + model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) @@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is False + metadata_before = model_runner.input_batch.sampling_metadata + model_runner._update_states(scheduler_output) + assert not _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) @@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + metadata_before = model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_ids[0]) assert _is_req_scheduled(model_runner, req_ids[0]) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index dfe71028c1bc..a9ef973917e1 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -45,7 +45,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, vocab_size, num_seqs) output_bin_counts, output_mask = get_token_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) - repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat( + repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( 1, vocab_size) logits[logits > 0] /= torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)[logits > 0] @@ -53,6 +53,6 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, repetition_penalties, 1.0)[logits <= 0] # We follow the definition in OpenAI API. # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze_(dim=1) * output_mask + logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 8f10834251c1..535aa644c53c 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -195,8 +195,10 @@ def schedule(self) -> "SchedulerOutput": request.num_computed_tokens - request.num_tokens) if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids[:num_scheduled_spec_tokens]) + request.spec_token_ids) # Encoder-related. if encoder_inputs_to_schedule: @@ -567,7 +569,7 @@ def update_from_output( outputs.append( EngineCoreOutput( request_id=req_id, - new_token_ids=new_token_ids or [], + new_token_ids=new_token_ids, finish_reason=request.get_finished_reason(), new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index ea64181c0aeb..2184a1866ff5 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple import torch @@ -12,15 +12,13 @@ class SamplingMetadata: temperature: torch.Tensor all_greedy: bool all_random: bool - rejection_sampling: bool - spec_token_ids: List[List[int]] - top_p: torch.Tensor - top_k: torch.Tensor - no_top_p: bool - no_top_k: bool - min_p: torch.Tensor - no_min_p: bool + # None when there are no speculated tokens. + spec_token_ids: Optional[List[List[int]]] + + top_p: Optional[torch.Tensor] + top_k: Optional[torch.Tensor] + min_p: Optional[torch.Tensor] generators: Dict[int, torch.Generator] @@ -34,7 +32,8 @@ class SamplingMetadata: repetition_penalties: torch.Tensor output_token_ids: List[List[int]] - min_tokens: List[int] - stop_token_ids: List[Set[int]] + + # req_index -> (min_tokens, stop_token_ids) + min_tokens: Dict[int, Tuple[int, Set[int]]] logit_bias: List[Optional[Dict[int, float]]] diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index ba368b44ab9c..8d9f6529fa0b 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Set, Tuple +from typing import Dict, List, Set, Tuple import torch @@ -8,18 +8,17 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad -def apply_min_token_penalties(logits: torch.Tensor, - output_token_ids: List[List[int]], - stop_token_ids: List[Set[int]], - min_tokens: List[int]) -> None: +def apply_min_token_penalties( + logits: torch.Tensor, output_token_ids: List[List[int]], + min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None: """ Applies minimum token penalty by setting the logits of the stop tokens to -inf. """ min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] - for index, min_token in enumerate(min_tokens): + for index, (min_token, stop_token_ids) in min_tokens.items(): if len(output_token_ids[index]) < min_token: - for stop_token_id in stop_token_ids[index]: + for stop_token_id in stop_token_ids: min_tokens_logits_to_penalize.append((index, stop_token_id)) if min_tokens_logits_to_penalize: logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 27431001e3e7..78c88ad8b830 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict +from typing import Dict, Optional import torch import torch.nn as nn @@ -55,13 +55,11 @@ def forward_native( self, logits: torch.Tensor, generators: Dict[int, torch.Generator], - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], ) -> torch.Tensor: """PyTorch-native implementation of top-k and top-p sampling.""" - logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p) + logits = apply_top_k_top_p(logits, k, p) probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators) @@ -69,37 +67,33 @@ def forward_cuda( self, logits: torch.Tensor, generators: Dict[int, torch.Generator], - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], ) -> torch.Tensor: """More optimized implementation for top-k and top-p sampling.""" probs = logits.softmax(dim=-1, dtype=torch.float32) - if no_top_k and no_top_p: + if k is None and p is None: # We prefer `random_sample` over `flashinfer_sample` when sorting is # not needed. This is because `random_sample` does not require # CPU-GPU synchronization while `flashinfer_sample` does. return random_sample(probs, generators) - return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators) + return flashinfer_sample(probs, k, p, generators) def apply_top_k_top_p( logits: torch.Tensor, - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], ) -> torch.Tensor: """Apply top-k and top-p masks to the logits. This function sorts the logits tensor, which can be slow for large batches. """ - if no_top_k and no_top_p: + if k is None and p is None: return logits logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - if not no_top_k: + if k is not None: # Apply top-k. top_k_mask = logits_sort.size(1) - k.to(torch.long) # Get all the top_k values. @@ -107,7 +101,7 @@ def apply_top_k_top_p( top_k_mask = logits_sort < top_k_mask logits_sort.masked_fill_(top_k_mask, -float("inf")) - if not no_top_p: + if p is not None: # Apply top-p. probs_sort = logits_sort.softmax(dim=-1) probs_sum = probs_sort.cumsum(dim=-1) @@ -147,10 +141,8 @@ def random_sample( def flashinfer_sample( probs: torch.Tensor, - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], generators: Dict[int, torch.Generator], ) -> torch.Tensor: """Sample from the probabilities using FlashInfer. @@ -167,7 +159,7 @@ def flashinfer_sample( does not. Call this function at the end of the forward pass to minimize the synchronization overhead. """ - assert not (no_top_k and no_top_p) + assert not (k is None and p is None) max_top_k_round = 32 batch_size = probs.shape[0] uniform_samples = torch.empty((max_top_k_round, batch_size), @@ -178,11 +170,11 @@ def flashinfer_sample( for i, generator in generators.items(): uniform_samples[:, i].uniform_(generator=generator) - if no_top_k: + if k is None: # Top-p only. next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( probs, uniform_samples, p, deterministic=True) - elif no_top_p: + elif p is None: # Top-k only. next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs( probs, uniform_samples, k, deterministic=True) @@ -194,9 +186,9 @@ def flashinfer_sample( # NOTE: CPU-GPU synchronization happens here. if not success.all(): - if not no_top_k: + if k is not None: probs = flashinfer.sampling.top_k_renorm_prob(probs, k) - if not no_top_p: + if p is not None: probs = flashinfer.sampling.top_p_renorm_prob(probs, p) next_token_ids = flashinfer.sampling.sampling_from_probs( probs, uniform_samples[0], deterministic=True) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index df1da8930211..580ad44297aa 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -68,6 +68,7 @@ def flashinfer_sample( # NOTE: The following input preparationg can be moved # to the model runner with a persistent manner for better # performance. + assert sampling_metadata.spec_token_ids is not None spec_token_ids = sampling_metadata.spec_token_ids max_spec_len = max(len(s) for s in spec_token_ids) batch_size = len(spec_token_ids) @@ -119,6 +120,7 @@ def forward_native( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: + assert sampling_metadata.spec_token_ids is not None spec_lens = [len(x) for x in sampling_metadata.spec_token_ids] # Add 1 to include the 'bonus' token. sample_lens = [x + 1 for x in spec_lens] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index ec6374d12b17..8e2533eefab0 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -26,7 +26,7 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - if sampling_metadata.rejection_sampling: + if sampling_metadata.spec_token_ids: if sampling_metadata.max_num_logprobs: raise NotImplementedError( "Rejection sampling does not support logprobs.") @@ -104,16 +104,14 @@ def sample( logits = self.apply_temperature(logits, sampling_metadata.temperature) # Apply min_p. - if not sampling_metadata.no_min_p: + if sampling_metadata.min_p is not None: logits = self.apply_min_p(logits, sampling_metadata.min_p) # Apply top_k and/or top_p. random_sampled = self.topk_topp_sampler( logits, sampling_metadata.generators, - sampling_metadata.no_top_k, sampling_metadata.top_k, - sampling_metadata.no_top_p, sampling_metadata.top_p, ) @@ -179,9 +177,10 @@ def apply_penalties( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - apply_min_token_penalties(logits, sampling_metadata.output_token_ids, - sampling_metadata.stop_token_ids, - sampling_metadata.min_tokens) + if sampling_metadata.min_tokens: + apply_min_token_penalties(logits, + sampling_metadata.output_token_ids, + sampling_metadata.min_tokens) if not sampling_metadata.no_penalties: assert sampling_metadata.prompt_token_ids is not None logits = apply_all_penalties( diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 5494542c181d..5be465014242 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -188,3 +188,14 @@ def bind_kv_cache( for layer_name, kv_cache in kv_caches.items(): # NOTE: Use list because of v0 PP virtual engine. forward_context[layer_name].kv_cache = [kv_cache] + + +def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, + length: int) -> None: + """ + Copy the first length elements of a tensor into another tensor in a + non-blocking manner. + + Used to copy pinned CPU tensor data to pre-allocated GPU tensors. + """ + to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index cb7411a44e2f..ccafc325b53f 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 - # Datastructures defining an input batch from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast import numpy as np import torch @@ -12,6 +11,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.utils import copy_slice from vllm.v1.worker.block_table import BlockTable _SAMPLING_EPS = 1e-5 @@ -63,7 +63,7 @@ def __init__( self.pin_memory = pin_memory self.vocab_size = vocab_size - self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self._req_ids: List[Optional[str]] = [] self.req_id_to_index: Dict[str, int] = {} # TODO(woosuk): This buffer could be too large if max_model_len is big. @@ -171,11 +171,8 @@ def __init__( self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: Set[str] = set() - self.min_tokens: List[int] = [0] * max_num_reqs - self.stop_token_ids: List[Set[int]] = [ - set() for _ in range(max_num_reqs) - ] - self.prompt_token_ids: Optional[torch.Tensor] = None + # req_index -> (min_tokens, stop_token_ids) + self.min_tokens: Dict[int, Tuple[int, Set[int]]] = {} # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), @@ -196,6 +193,17 @@ def __init__( self.logit_bias: List[Optional[Dict[int, float]]] = [None] * max_num_reqs + self.req_output_token_ids: List[Optional[List[int]]] = [] + + # This is updated each time the batch constituents change. + self.sampling_metadata = self._make_sampling_metadata() + + @property + def req_ids(self) -> List[str]: + # None elements should only be present transiently + # while performing state updates to the batch. + return cast(List[str], self._req_ids) + def add_request( self, request: "CachedRequestState", @@ -206,7 +214,13 @@ def add_request( assert req_index < self.max_num_reqs req_id = request.req_id - self.req_ids[req_index] = req_id + if req_index == len(self._req_ids): + self._req_ids.append(req_id) + self.req_output_token_ids.append(request.output_token_ids) + else: + self._req_ids[req_index] = req_id + self.req_output_token_ids[req_index] = request.output_token_ids + self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. @@ -255,8 +269,9 @@ def add_request( req_index] = sampling_params.repetition_penalty if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) - self.min_tokens[req_index] = sampling_params.min_tokens - self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids + if sampling_params.min_tokens: + self.min_tokens[req_index] = (sampling_params.min_tokens, + sampling_params.all_stop_token_ids) # NOTE(woosuk): self.generators should not include the requests that # do not have their own generator. @@ -284,16 +299,20 @@ def add_request( self.request_lora_mapping[req_index] = 0 def remove_request(self, req_id: str) -> Optional[int]: + """This method must always be followed by a call to condense().""" + req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: return None - self.req_ids[req_index] = None + self._req_ids[req_index] = None + self.req_output_token_ids[req_index] = None self.greedy_reqs.discard(req_id) self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) self.min_p_reqs.discard(req_id) + self.min_tokens.pop(req_index, None) self.frequency_penalties_reqs.discard(req_id) self.presence_penalties_reqs.discard(req_id) self.repetition_penalties_reqs.discard(req_id) @@ -313,33 +332,17 @@ def remove_request(self, req_id: str) -> Optional[int]: self.logit_bias[req_index] = None return req_index - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.min_p_reqs.clear() - self.frequency_penalties_reqs.clear() - self.presence_penalties_reqs.clear() - self.repetition_penalties_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.num_prompt_logprobs.clear() - self.request_lora_mapping.fill(0) - self.lora_id_to_lora_request.clear() - self.lora_id_to_request_ids.clear() - self.logit_bias = [None] * self.max_num_reqs - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: + num_reqs = self.num_reqs + if num_reqs == 0: # The batched states are empty. + self._req_ids.clear() + self.req_output_token_ids.clear() return # NOTE(woosuk): This function assumes that the empty_req_indices # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 + last_req_index = num_reqs + len(empty_req_indices) - 1 while empty_req_indices: # Find the largest non-empty index. while last_req_index in empty_req_indices: @@ -351,10 +354,13 @@ def condense(self, empty_req_indices: List[int]) -> None: break # Swap the states. - req_id = self.req_ids[last_req_index] + req_id = self._req_ids[last_req_index] + output_token_ids = self.req_output_token_ids[last_req_index] assert req_id is not None - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None + self._req_ids[empty_index] = req_id + self._req_ids[last_req_index] = None + self.req_output_token_ids[empty_index] = output_token_ids + self.req_output_token_ids[last_req_index] = None self.req_id_to_index[req_id] = empty_index num_tokens = self.num_tokens[last_req_index] @@ -379,13 +385,14 @@ def condense(self, empty_req_indices: List[int]) -> None: self.repetition_penalties_cpu[ empty_index] = self.repetition_penalties_cpu[last_req_index] self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index] - self.min_tokens[empty_index] = self.min_tokens[last_req_index] - self.stop_token_ids[empty_index] = self.stop_token_ids[ - last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator + min_token = self.min_tokens.pop(last_req_index, None) + if min_token is not None: + self.min_tokens[empty_index] = min_token + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ last_req_index] @@ -394,87 +401,71 @@ def condense(self, empty_req_indices: List[int]) -> None: # Decrement last_req_index since it is now empty. last_req_index -= 1 - def make_sampling_metadata( - self, - req_id_output_token_ids: Dict[str, List[int]], - req_id_to_spec_token_ids: Dict[str, List[int]], - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - self.min_p[:self.num_reqs].copy_( - self.min_p_cpu_tensor[:self.num_reqs], non_blocking=True) - if not self.no_penalties: - # Since syncing these tensors is expensive only copy them - # if necessary i.e. if there are requests which require - # penalties to be applied during sampling. - self.frequency_penalties[:self.num_reqs].copy_( - self.frequency_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True, - ) - self.presence_penalties[:self.num_reqs].copy_( - self.presence_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True, - ) - self.repetition_penalties[:self.num_reqs].copy_( - self.repetition_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True, - ) - # The prompt tokens are used only for applying penalties during - # the sampling process. Hence copy these tensors only when - # there are requests which need penalties to be applied. - self.prompt_token_ids = self._make_prompt_token_ids_tensor() - - output_token_ids: List[List[int]] = [] - spec_token_ids: List[List[int]] = [] - rejection_sampling = False - for req_id in self.req_ids[:self.num_reqs]: - assert req_id is not None - # Currently we create a tensor for output_token_ids from scratch - # at each step. However, for the penalties computation what we - # need is stats about the token ids present in the output. This - # stats can be maintained incrementally instead of computing it - # from scratch at each step. - # TODO - Replace this with incremental update to output token - # statistics. - output_token_ids.append(req_id_output_token_ids[req_id]) - req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, []) - spec_token_ids.append(req_spec_token_ids) - if req_spec_token_ids: - # If any of the requests require speculative decoding, set the - # flag to True. - rejection_sampling = True + # Trim lists to the batch size. + del self._req_ids[self.num_reqs:] + del self.req_output_token_ids[self.num_reqs:] + + def refresh_sampling_metadata(self): + self.sampling_metadata = self._make_sampling_metadata() + + def _make_sampling_metadata(self) -> SamplingMetadata: + num_reqs = self.num_reqs + copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs) + if not self.no_top_p: + copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) + if not self.no_top_k: + copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) + if not self.no_min_p: + copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs) + + if not self.no_penalties: + # Since syncing these tensors is expensive only copy them + # if necessary i.e. if there are requests which require + # penalties to be applied during sampling. + copy_slice(self.frequency_penalties_cpu_tensor, + self.frequency_penalties, num_reqs) + copy_slice(self.presence_penalties_cpu_tensor, + self.presence_penalties, num_reqs) + copy_slice(self.repetition_penalties_cpu_tensor, + self.repetition_penalties, num_reqs) + + # The prompt tokens are used only for applying penalties during + # the sampling process. Hence copy these tensors only when + # there are requests which need penalties to be applied. + prompt_token_ids = self._make_prompt_token_ids_tensor() + else: + prompt_token_ids = None return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], + temperature=self.temperature[:num_reqs], all_greedy=self.all_greedy, all_random=self.all_random, - rejection_sampling=rejection_sampling, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - min_p=self.min_p[:self.num_reqs], - no_min_p=self.no_min_p, - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, + top_p=None if self.no_top_p else self.top_p[:num_reqs], + top_k=None if self.no_top_k else self.top_k[:num_reqs], + min_p=None if self.no_min_p else self.min_p[:num_reqs], generators=self.generators, max_num_logprobs=self.max_num_logprobs, - prompt_token_ids=self.prompt_token_ids, - frequency_penalties=self.frequency_penalties[:self.num_reqs], - presence_penalties=self.presence_penalties[:self.num_reqs], - repetition_penalties=self.repetition_penalties[:self.num_reqs], - output_token_ids=output_token_ids, - spec_token_ids=spec_token_ids, - min_tokens=self.min_tokens[:self.num_reqs], - stop_token_ids=self.stop_token_ids[:self.num_reqs], + prompt_token_ids=prompt_token_ids, + frequency_penalties=self.frequency_penalties[:num_reqs], + presence_penalties=self.presence_penalties[:num_reqs], + repetition_penalties=self.repetition_penalties[:num_reqs], + output_token_ids=cast(List[List[int]], self.req_output_token_ids), + spec_token_ids=None, + min_tokens=self.min_tokens, no_penalties=self.no_penalties, - logit_bias=self.logit_bias[:self.num_reqs], + logit_bias=self.logit_bias[:num_reqs], ) + def get_sampling_metadata( + self, + req_id_to_spec_token_ids: Dict[str, List[int]], + ) -> SamplingMetadata: + # Set the new spec token ids in the cached sampling metadata. + self.sampling_metadata.spec_token_ids = [ + req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids + ] if req_id_to_spec_token_ids else None + return self.sampling_metadata + def _make_prompt_token_ids_tensor(self) -> torch.Tensor: max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5754422cb1f7..0ecc00acc790 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -31,7 +31,6 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput -from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.utils import bind_kv_cache @@ -224,16 +223,15 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. The updated states are used by the `_prepare_inputs` function to create the input GPU tensors for the model. - Returns: - True if there is a new/resumed/paused/finished request in the batch. - If False, we can skip copying SamplingMetadata to the GPU. + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. """ # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: @@ -344,9 +342,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: num_new_tokens = (num_computed_tokens + len(req_data.new_token_ids) - req_state.num_tokens) - new_token_ids = (req_data.new_token_ids[-num_new_tokens:] - if num_new_tokens > 0 else []) - req_state.output_token_ids.extend(new_token_ids) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(req_data.new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + req_data.new_token_ids[-num_new_tokens:]) # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. @@ -380,7 +381,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.input_batch.num_tokens_no_spec[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, []) + req_id, ()) if spec_token_ids: start_index = end_token_index end_token_index += len(spec_token_ids) @@ -410,7 +411,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: if removed_req_indices: self.input_batch.condense(removed_req_indices) - return batch_changed + if batch_changed: + self.input_batch.refresh_sampling_metadata() def _prepare_inputs( self, @@ -429,8 +431,7 @@ def _prepare_inputs( # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) max_num_scheduled_tokens = 0 - for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): - assert req_id is not None + for i, req_id in enumerate(self.input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens[i] = num_tokens max_num_scheduled_tokens = max(max_num_scheduled_tokens, @@ -669,10 +670,7 @@ def _compute_cascade_attn_prefix_len( def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr = 0 - num_reqs = self.input_batch.num_reqs - for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - assert req_id is not None - + for index, req_id in enumerate(self.input_batch.req_ids): req = self.requests[req_id] assert req.mrope_positions is not None @@ -726,12 +724,11 @@ def _calc_spec_decode_metadata( self, scheduler_output: "SchedulerOutput", cu_num_tokens: np.ndarray, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: # Get the number of spec decode tokens for each request. num_reqs = self.input_batch.num_reqs num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32) - for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): - assert req_id is not None + for i, req_id in enumerate(self.input_batch.req_ids): num_spec_decode_tokens[i] = len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) @@ -769,22 +766,6 @@ def _calc_spec_decode_metadata( return torch.from_numpy(spec_decode_logits_indices).to( self.device, non_blocking=True) - def _prepare_sampling( - self, - batch_changed: bool, - req_to_spec_token_ids: Dict[str, List[int]], - ) -> SamplingMetadata: - # Create the sampling metadata. - req_id_output_token_ids: Dict[str, List[int]] = \ - {req_id: req.output_token_ids \ - for req_id, req in self.requests.items()} - - sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, - req_to_spec_token_ids, - skip_copy=not batch_changed) - return sampling_metadata - def _execute_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: @@ -838,9 +819,7 @@ def _gather_encoder_outputs( scheduler_output: "SchedulerOutput", ) -> List[torch.Tensor]: encoder_outputs: List[torch.Tensor] = [] - num_reqs = self.input_batch.num_reqs - for req_id in self.input_batch.req_ids[:num_reqs]: - assert req_id is not None + for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] @@ -882,7 +861,7 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: - batch_changed = self._update_states(scheduler_output) + self._update_states(scheduler_output) if self.is_multimodal_model: # Run the multimodal encoder if any. @@ -964,8 +943,8 @@ def execute_model( logits = self.model.compute_logits(sample_hidden_states, None) # Sample the next token and get logprobs if needed. - sampling_metadata = self._prepare_sampling( - batch_changed, scheduler_output.scheduled_spec_decode_tokens) + sampling_metadata = self.input_batch.get_sampling_metadata( + scheduler_output.scheduled_spec_decode_tokens) sampler_output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, @@ -973,14 +952,7 @@ def execute_model( # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. - num_reqs = self.input_batch.num_reqs - req_ids: List[str] = [] - # Because `input_batch.req_ids` is a list of length `max_num_reqs`, - # we need to stop at `num_reqs`. - # FIXME(woosuk): This is hacky. Refactor. - for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): - assert req_id is not None - req_ids.append(req_id) + for i, req_id in enumerate(self.input_batch.req_ids): req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) @@ -1027,7 +999,7 @@ def execute_model( valid_sampled_token_ids) model_runner_output = ModelRunnerOutput( - req_ids=req_ids, + req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, spec_token_ids=spec_token_ids, @@ -1041,19 +1013,18 @@ def generate_draft_token_ids( sampled_token_ids: List[List[int]], ) -> List[List[int]]: # TODO(woosuk): Optimize. - num_reqs = len(sampled_token_ids) draft_token_ids: List[List[int]] = [] - for i in range(num_reqs): - if len(sampled_token_ids[i]) == 0: + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: # Skip speculative decoding. draft_token_ids.append([]) continue # Add sampled_token_ids to token_ids_cpu. start_idx = self.input_batch.num_tokens_no_spec[i] - end_idx = start_idx + len(sampled_token_ids[i]) - self.input_batch.token_ids_cpu[ - i, start_idx:end_idx] = sampled_token_ids[i] + end_idx = start_idx + num_sampled_ids + self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids drafter_output = self.drafter.propose( self.input_batch.token_ids_cpu[i, :end_idx], self.speculative_config.ngram_prompt_lookup_min, @@ -1204,7 +1175,7 @@ def profile_run(self) -> None: # multiplying the list, to avoid Dynamo from treating them as # tensor aliasing. dummy_kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) + torch.tensor((), dtype=torch.float32, device=self.device) for _ in range(self.num_attn_layers) ] diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 255c6cef2f30..9aa74ddee81b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1085,8 +1085,6 @@ def swap_positions(b: InputBatch, id_1, id_2): b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[ id_1] - b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[ - id_2], b.stop_token_ids[id_1] gen_1 = b.generators.pop(id_1, None) gen_2 = b.generators.pop(id_2, None) From 15906174d69abe47428655f101d8b7b138e89b6b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Feb 2025 12:50:31 -0800 Subject: [PATCH 097/317] Pin Ray version to 2.40.0 (#13490) Signed-off-by: Woosuk Kwon --- requirements-cuda.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 44b56422e3ab..bc670b8511fd 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -2,7 +2,7 @@ -r requirements-common.txt # Dependencies for NVIDIA GPUs -ray[adag] == 2.41.0 # Required for pipeline parallelism in V1. +ray[adag] == 2.40.0 # Required for pipeline parallelism in V1. torch == 2.5.1 torchaudio==2.5.1 # These must be updated alongside torch From 1104f29a9796f2f0a9a6e76bc1f8738e1a9db7a1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Feb 2025 13:19:58 -0800 Subject: [PATCH 098/317] [V1][Spec Decode] Optimize N-gram matching with Numba (#13365) Signed-off-by: Woosuk Kwon --- requirements-common.txt | 1 + vllm/v1/spec_decode/ngram_proposer.py | 113 +++++++++++++------------- vllm/v1/worker/gpu_model_runner.py | 13 ++- 3 files changed, 67 insertions(+), 60 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index b7c94cbdba8b..c52980bc7df7 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -1,6 +1,7 @@ psutil sentencepiece # Required for LLaMA tokenizer. numpy < 2.0.0 +numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding. requests >= 2.26.0 tqdm blake3 diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 9b116e00af97..33289d05dabd 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import Optional import numpy as np +from numba import jit class NgramProposer: - def __init__(self): - pass - def propose( self, context_token_ids: np.ndarray, @@ -21,7 +19,7 @@ def propose( that match. Args: - context_token_ids: List of token IDs representing the + context_token_ids: Numpy array of token IDs representing the context sequence. n: Length of the n-gram to match. k: Number of tokens follow the match. If there are less @@ -41,66 +39,65 @@ def propose( followed that pattern. Here we will return [4,2,3] because we only have three tokens after the match. """ - # TODO: Use c++ to implement the _find_subarray_kmp to - # improve the efficiency - return self._find_subarray_kmp(context_token_ids, n, k) + return _find_subarray_kmp(context_token_ids, n, k) - @staticmethod - def _kmp_lps_array(pattern: List[int]) -> List[int]: - """ - Build the lps (longest proper prefix which is also suffix) - array for the pattern. - """ - lps = [0] * len(pattern) - prev_lps = 0 # length of the previous longest prefix suffix - i = 1 - while i < len(pattern): - if pattern[i] == pattern[prev_lps]: - prev_lps += 1 - lps[i] = prev_lps - i += 1 +@jit(nopython=True) +def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray: + """ + Build the lps (longest proper prefix which is also suffix) + array for the pattern. + """ + lps = np.zeros(len(pattern), dtype=np.int32) + prev_lps = 0 # length of the previous longest prefix suffix + i = 1 + + while i < len(pattern): + if pattern[i] == pattern[prev_lps]: + prev_lps += 1 + lps[i] = prev_lps + i += 1 + else: + if prev_lps != 0: + prev_lps = lps[prev_lps - 1] else: - if prev_lps != 0: - prev_lps = lps[prev_lps - 1] - else: - lps[i] = 0 - i += 1 + lps[i] = 0 + i += 1 + return lps - return lps - @staticmethod - def _find_subarray_kmp( - context_token_ids: np.ndarray, - n: int, - k: int, - ) -> Optional[np.ndarray]: - context_len = context_token_ids.shape[0] - assert n > 0 +@jit(nopython=True) +def _find_subarray_kmp( + context_token_ids: np.ndarray, + n: int, + k: int, +) -> Optional[np.ndarray]: + context_len = context_token_ids.shape[0] + assert n > 0 - pattern = context_token_ids[-n:] - # Precompute lps array for Y - lps = NgramProposer._kmp_lps_array(pattern) + pattern = context_token_ids[-n:] + # Precompute lps array for Y + lps = _kmp_lps_array(pattern) - i = 0 - j = 0 - # -n because the last n tokens are used as pattern - while i < context_len - n: - if context_token_ids[i] == pattern[j]: - i += 1 - j += 1 + i = 0 + j = 0 + # -n because the last n tokens are used as pattern + while i < context_len - n: + if context_token_ids[i] == pattern[j]: + i += 1 + j += 1 - # If we have matched the entire Y - if j == n: - # Found pattern in context, gather the next K elements - return context_token_ids[i:i + k] + # If we have matched the entire Y + if j == n: + # Found pattern in context, gather the next K elements + return context_token_ids[i:i + k] + else: + # Mismatch + if j != 0: + # Use the lps array to avoid re-checking elements + j = lps[j - 1] else: - # Mismatch - if j != 0: - # Use the lps array to avoid re-checking elements - j = lps[j - 1] - else: - i += 1 + i += 1 - # Y not found - return None + # Y not found + return None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0ecc00acc790..31fe095a91bc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -120,11 +120,20 @@ def __init__( # Set up speculative decoding. self.use_spec_decode = False if self.speculative_config: + self.use_spec_decode = True + # TODO: find a better way to check if we are using ngram. assert self.speculative_config.ngram_prompt_lookup_min, \ "Currently, only ngram spec decode is supported in V1." - self.drafter = NgramProposer() - self.use_spec_decode = True + if get_pp_group().is_last_rank: + self.drafter = NgramProposer() + # Trigger Numba JIT compilation for N-gram proposer. + # This usually takes less than 1 second. + self.drafter.propose( + np.zeros(1024, dtype=np.int32), + self.speculative_config.ngram_prompt_lookup_min, + self.speculative_config.num_speculative_tokens, + ) # Request states. self.requests: Dict[str, CachedRequestState] = {} From ca4b020da63f9eec8d568469fe1ee0ae06bf3346 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 19 Feb 2025 03:37:26 +0000 Subject: [PATCH 099/317] [Misc] Remove dangling references to `--use-v2-block-manager` (#13492) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .buildkite/nightly-benchmarks/tests/serving-tests.json | 3 +-- docs/source/features/spec_decode.md | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests.json b/.buildkite/nightly-benchmarks/tests/serving-tests.json index facb0eac749c..415171e268b0 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests.json @@ -66,8 +66,7 @@ "swap_space": 16, "speculative_model": "turboderp/Qwama-0.5B-Instruct", "num_speculative_tokens": 4, - "speculative_draft_tensor_parallel_size": 1, - "use_v2_block_manager": "" + "speculative_draft_tensor_parallel_size": 1 }, "client_parameters": { "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", diff --git a/docs/source/features/spec_decode.md b/docs/source/features/spec_decode.md index d2255eff608b..cc8d6fceb7d6 100644 --- a/docs/source/features/spec_decode.md +++ b/docs/source/features/spec_decode.md @@ -45,7 +45,7 @@ To perform the same with an online mode launch the server: ```bash python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \ - --seed 42 -tp 1 --speculative_model facebook/opt-125m --use-v2-block-manager \ + --seed 42 -tp 1 --speculative_model facebook/opt-125m \ --num_speculative_tokens 5 --gpu_memory_utilization 0.8 ``` From 8e711a69f7da3ae47c444e3e943cc8affd78df45 Mon Sep 17 00:00:00 2001 From: Yu-Zhou Date: Wed, 19 Feb 2025 11:40:19 +0800 Subject: [PATCH 100/317] [Hardware][Gaudi][Feature] Support Contiguous Cache Fetch (#12139) Signed-off-by: yuzhou Signed-off-by: zhouyu5 Co-authored-by: Cody Yu --- vllm/attention/backends/hpu_attn.py | 6 +- vllm/attention/ops/hpu_paged_attn.py | 1 + vllm/envs.py | 8 ++ vllm/worker/hpu_model_runner.py | 114 +++++++++++++++++---------- 4 files changed, 81 insertions(+), 48 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 1ad5e6e8e4e1..9eb533685dbd 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -118,12 +118,8 @@ def __init__( self.matmul_av = Matmul() self.batch2block_matmul = Matmul() self.block2batch_matmul = Matmul() - # NOTE(kzawora): Contiguous PA is off until model runner supports it self.k_cache = VLLMKVCache() - self.k_cache.use_contiguous_pa = False self.v_cache = VLLMKVCache() - self.v_cache.use_contiguous_pa = False - # NOTE(kzawora): Pipelined PA is off until model runner supports it ops.pa_impl = ops.pa self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads @@ -249,7 +245,7 @@ def forward( block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, block_scales=attn_metadata.block_scales, - block_groups=None, + block_groups=attn_metadata.block_groups, scale=self.scale, matmul_qk_op=self.matmul_qk, matmul_av_op=self.matmul_av, diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 8bb536343ed8..49ea420d092c 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -23,6 +23,7 @@ class HPUPagedAttentionMetadata: block_indices: Optional[torch.Tensor] block_offsets: Optional[torch.Tensor] block_scales: Optional[torch.Tensor] + block_groups: Optional[torch.Tensor] class HPUPagedAttention: diff --git a/vllm/envs.py b/vllm/envs.py index f8a18cc662ab..45547416314f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -89,6 +89,7 @@ VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_CUDART_SO_PATH: Optional[str] = None + VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True def get_default_cache_root(): @@ -585,6 +586,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # specify the path through environment variable VLLM_CUDART_SO_PATH. "VLLM_CUDART_SO_PATH": lambda: os.getenv("VLLM_CUDART_SO_PATH", None), + + # Contiguous cache fetching to avoid using costly gather operation on + # Gaudi3. This is only applicable to HPU contiguous cache. If set to true, + # contiguous cache fetch will be used. + "VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH": + lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in + ("1", "true"), } # end-env-vars-definition diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 774049a5281e..fe7c776d0a23 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -25,9 +25,11 @@ import torch import torch.nn as nn from vllm_hpu_extension.ops import LoraMask as LoraMask +from vllm_hpu_extension.ops import batch2block, block2batch from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, HabanaMemoryProfiler, format_bytes) +import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import DeviceConfig, VllmConfig from vllm.distributed.parallel_state import get_world_group @@ -260,10 +262,19 @@ def setup_profiler(): return profiler -def pad_list(list, k, v): - target_len = round_up(len(list), k) - padding = target_len - len(list) - return list + [v] * padding +def pad_list(input, k, v): + input_len = len(input) + target_len = round_up(input_len, k) + padding = target_len - input_len + return input + [v] * padding + + +def gather_list(input, indices, v): + return [input[i] if i is not None else v for i in indices] + + +def flatten(in_list): + return list(itertools.chain(*in_list)) def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt): @@ -334,13 +345,23 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): mask = mask >= metadata.block_usage.unsqueeze(-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) - block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, + block_mapping = torch.nn.functional.one_hot(metadata.block_groups, num_classes=batch_size) block_mapping = block_mapping.to(dtype) metadata = metadata._replace(block_mapping=block_mapping, attn_bias=attn_bias) return metadata + def _set_block_scales(self, metadata, device): + block_mapping = metadata.block_mapping + ones = torch.ones((block_mapping.size(0), ), + device=device, + dtype=block_mapping.dtype) + sums = batch2block(block2batch(ones, block_mapping), block_mapping) + block_scales = torch.reciprocal(torch.maximum(ones, sums)) + metadata = metadata._replace(block_scales=block_scales) + return metadata + def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): if attn_metadata.is_prompt: @@ -351,6 +372,7 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device, meta = attn_metadata attn_metadata = self._set_block_mapping(meta, batch_size, device, dtype) + attn_metadata = self._set_block_scales(attn_metadata, device) return attn_metadata def forward(self, *args, **kwargs): @@ -586,6 +608,7 @@ def __init__( self.bucketing_global_state = HPUBucketingGlobalState() self._setup_buckets() self._set_gc_threshold() + self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH def _set_gc_threshold(self) -> None: # Read https://docs.python.org/3/library/gc.html#gc.set_threshold @@ -911,6 +934,7 @@ def _prepare_prompt( block_indices=block_indices, block_offsets=block_offsets, block_scales=None, + block_groups=None, attn_bias=None, seq_lens_tensor=seq_lens_tensor, num_prefills=real_num_seqs, @@ -1008,65 +1032,69 @@ def _prepare_decode( num_decode_tokens = sum(seq_lens) - blocks_used = [len(bt) for bt in block_tables if bt] - block_list = [] - block_scales = [] - for i, bt in enumerate(block_tables): - block_list.extend(bt) - blocks_in_group = len(bt) - if blocks_in_group > 0: - scale = 1.0 / blocks_in_group - block_scales.extend([scale] * blocks_in_group) - - block_mapping_nested: List[List[int]] = [ - [i] * b_u for i, b_u in enumerate(blocks_used) + last_block_usage = [ + slot[0] % self.block_size + 1 for slot in slot_mapping ] - block_mapping: List[int] = list( - itertools.chain.from_iterable(block_mapping_nested)) + block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] + block_usage = [[self.block_size] * (len(bt) - 1) + [lbu] + for bt, lbu in zip(block_tables, last_block_usage) + if bt] + + block_list = flatten(block_tables) + block_groups = flatten(block_groups) + block_usage = flatten(block_usage) + + assert len(block_list) == len(block_groups) + assert len(block_list) == len(block_usage) + + padding_fn = None + if self.use_contiguous_pa: + block_bucket_size = max(max(block_list) + 1, len(block_list)) + block_bucket_size = find_bucket( + block_bucket_size, + self.bucketing_global_state.decode_block_bucket_cfg) + indices: List[Any] + indices = [None] * block_bucket_size + for i, bid in enumerate(block_list): + indices[bid] = i + padding_fn = lambda tensor, pad_value: gather_list( + tensor, indices, pad_value) + else: + block_bucket_size = find_bucket( + len(block_list), + self.bucketing_global_state.decode_block_bucket_cfg) + padding_fn = lambda tensor, pad_value: pad_list( + tensor, block_bucket_size, pad_value) - last_block = [ - sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping) - ] - block_usage = [[self.block_size] * (b_u - 1) + [lb] - for b_u, lb in zip(blocks_used, last_block)] - block_usage = list(itertools.chain(*block_usage)) - - block_bucket_size = find_bucket( - len(block_list), - self.bucketing_global_state.decode_block_bucket_cfg) - block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID) - block_mapping = pad_list(block_mapping, block_bucket_size, -1) - block_usage = pad_list(block_usage, block_bucket_size, 1) - block_scales = pad_list(block_scales, block_bucket_size, 0.0) + block_list = padding_fn(block_list, _PAD_BLOCK_ID) + block_groups = padding_fn(block_groups, -1) + block_usage = padding_fn(block_usage, 1) block_list = torch.tensor(block_list, dtype=torch.int, device=self.device) - block_mapping = torch.tensor(block_mapping, - dtype=torch.long, - device=self.device) + block_groups = torch.tensor(block_groups, + dtype=torch.int, + device=self.device) block_usage = torch.tensor(block_usage, dtype=self.model_config.dtype, device=self.device) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) block_indices, block_offsets = precompute_indices_and_offsets( self.block_size, slot_mapping, False) - block_scales = torch.tensor(block_scales, - dtype=self.model_config.dtype, - device=self.device) attn_metadata = self.attn_backend.make_metadata( is_prompt=False, block_list=block_list, - block_mapping=block_mapping, + block_mapping=None, block_usage=block_usage, block_indices=block_indices, block_offsets=block_offsets, - block_scales=block_scales, + block_scales=None, + block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, num_prefills=0, @@ -1280,7 +1308,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', 'block_usage', 'slot_mapping', 'is_prompt', 'block_indices', - 'block_offsets', 'block_scales' + 'block_offsets', 'block_scales', 'block_groups' ]) return attention_metadata From 52b319209509a7b70c6bb45c9d815736aca3718f Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Tue, 18 Feb 2025 21:13:41 -0800 Subject: [PATCH 101/317] [perf-benchmark] Allow premerge ECR (#13509) Signed-off-by: <> Co-authored-by: EC2 Default User --- .buildkite/nightly-benchmarks/benchmark-pipeline.yaml | 6 +++--- .buildkite/nightly-benchmarks/scripts/wait-for-image.sh | 6 +++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index df95e46d6dd6..d1c08de7c47c 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -21,7 +21,7 @@ steps: podSpec: priorityClassName: perf-benchmark containers: - - image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT + - image: public.ecr.aws/q9t5s3a7/${BUILDKITE_BRANCH:-main} == "main" && "vllm-ci-postmerge-repo" || "vllm-ci-test-repo"}:$BUILDKITE_COMMIT command: - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh resources: @@ -52,7 +52,7 @@ steps: depends_on: wait-for-container-image plugins: - docker#v5.12.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT + image: public.ecr.aws/q9t5s3a7/${BUILDKITE_BRANCH:-main} == "main" && "vllm-ci-postmerge-repo" || "vllm-ci-test-repo"}:$BUILDKITE_COMMIT command: - bash - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -83,7 +83,7 @@ steps: depends_on: wait-for-container-image plugins: - docker#v5.12.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT + image: public.ecr.aws/q9t5s3a7/${BUILDKITE_BRANCH:-main} == "main" && "vllm-ci-postmerge-repo" || "vllm-ci-test-repo"}:$BUILDKITE_COMMIT command: - bash - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh diff --git a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh index aa0f7ade808e..50e1ab024220 100644 --- a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh +++ b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh @@ -1,6 +1,10 @@ #!/bin/sh TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-postmerge-repo:pull" | jq -r .token) -URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-postmerge-repo/manifests/$BUILDKITE_COMMIT" +if [[ "$BUILDKITE_BRANCH" == "main" ]]; then + URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-postmerge-repo/manifests/$BUILDKITE_COMMIT" +else + URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT" +fi TIMEOUT_SECONDS=10 From ec0057f5a3197d189020f33d784f68ac4121b391 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Wed, 19 Feb 2025 00:23:24 -0600 Subject: [PATCH 102/317] [ROCm][MoE configs] mi325 mixtral & mi300 qwen_moe (#13503) --- ...=1408,device_name=AMD_Instinct_MI300X.json | 200 ++++++++++++++++++ ...N=176,device_name=AMD_Instinct_MI300X.json | 200 ++++++++++++++++++ ...N=352,device_name=AMD_Instinct_MI300X.json | 200 ++++++++++++++++++ ...N=704,device_name=AMD_Instinct_MI300X.json | 200 ++++++++++++++++++ ...me=AMD_Instinct_MI325X,dtype=fp8_w8a8.json | 164 ++++++++++++++ ...14336,device_name=AMD_Instinct_MI325X.json | 200 ++++++++++++++++++ ...me=AMD_Instinct_MI325X,dtype=fp8_w8a8.json | 164 ++++++++++++++ ...16384,device_name=AMD_Instinct_MI325X.json | 200 ++++++++++++++++++ ...me=AMD_Instinct_MI325X,dtype=fp8_w8a8.json | 164 ++++++++++++++ ...=1792,device_name=AMD_Instinct_MI325X.json | 200 ++++++++++++++++++ ...me=AMD_Instinct_MI325X,dtype=fp8_w8a8.json | 164 ++++++++++++++ ...=2048,device_name=AMD_Instinct_MI325X.json | 200 ++++++++++++++++++ ...me=AMD_Instinct_MI325X,dtype=fp8_w8a8.json | 164 ++++++++++++++ ...=3584,device_name=AMD_Instinct_MI325X.json | 200 ++++++++++++++++++ ...me=AMD_Instinct_MI325X,dtype=fp8_w8a8.json | 164 ++++++++++++++ ...=4096,device_name=AMD_Instinct_MI325X.json | 200 ++++++++++++++++++ ...me=AMD_Instinct_MI325X,dtype=fp8_w8a8.json | 164 ++++++++++++++ ...=7168,device_name=AMD_Instinct_MI325X.json | 200 ++++++++++++++++++ ...me=AMD_Instinct_MI325X,dtype=fp8_w8a8.json | 164 ++++++++++++++ ...=8192,device_name=AMD_Instinct_MI325X.json | 200 ++++++++++++++++++ 20 files changed, 3712 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 000000000000..d09508b31729 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 000000000000..746463af4d56 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 000000000000..bbdb9ad09645 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 000000000000..43584b1eb6b6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..f245285bd821 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 000000000000..3918c93b160a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..16e0a91baf31 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 000000000000..d766fc062ddc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..6d5b1ae5b15f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 000000000000..ffc1b23ea90d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..2758e48fc406 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 000000000000..fc31215cbae8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..6cb80f48329f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 000000000000..de9d0aba75a7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..2c49f359c22a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 000000000000..c7db6c0cbd3f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..7a07bbf41419 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 000000000000..3a3268cc17a8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..c27ca0a36594 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 000000000000..da477b1fb15e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} From 85c2fd66961189e39a87e5923bd0e45b22785671 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Tue, 18 Feb 2025 22:24:03 -0800 Subject: [PATCH 103/317] [Doc] Add clarification note regarding paligemma (#13511) --- docs/source/models/supported_models.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index a1a28986b8a9..5497b5dba76e 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -808,7 +808,7 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ -- * `PaliGemmaForConditionalGeneration` +- * `PaliGemmaForConditionalGeneration`\* * PaliGemma, PaliGemma 2 * T + IE * `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. @@ -885,6 +885,10 @@ The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (` For more details, please see: ::: +:::{note} +Currently the PaliGemma model series is implemented without PrefixLM attention mask. This model series may be deprecated in a future release. +::: + :::{note} `mistral-community/pixtral-12b` does not support V1 yet. ::: From aabd2638b02c006f027c1879e370cdcaaf1d4942 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Tue, 18 Feb 2025 23:34:59 -0800 Subject: [PATCH 104/317] [1/n][CI] Load models in CI from S3 instead of HF (#13205) Signed-off-by: <> Co-authored-by: EC2 Default User --- requirements-test.in | 2 ++ requirements-test.txt | 8 +++++ .../test_basic_correctness.py | 19 ++++++------ tests/basic_correctness/test_cumem.py | 13 ++++++-- tests/basic_correctness/test_preemption.py | 2 +- tests/conftest.py | 25 ++++++++++++++- tests/engine/test_computed_prefix_blocks.py | 6 +++- tests/engine/test_detokenization.py | 7 +++-- tests/engine/test_executor.py | 17 +++++++--- tests/engine/test_skip_tokenizer_init.py | 9 ++++-- tests/engine/test_stop_reason.py | 2 +- tests/entrypoints/llm/test_chat.py | 13 ++++++-- tests/entrypoints/llm/test_collective_rpc.py | 2 +- tests/entrypoints/llm/test_encode.py | 4 ++- tests/entrypoints/llm/test_generate.py | 4 ++- .../llm/test_generate_multiple_loras.py | 4 ++- tests/entrypoints/llm/test_guided_generate.py | 7 +++-- tests/entrypoints/llm/test_lazy_outlines.py | 31 +++++++++++++++++-- .../entrypoints/llm/test_prompt_validation.py | 9 ++++-- tests/entrypoints/openai/test_rerank.py | 2 +- tests/metrics/test_metrics.py | 21 +++++++++---- tests/models/registry.py | 3 +- tests/models/test_initialization.py | 6 +++- tests/mq_llm_engine/test_abort.py | 4 +-- tests/mq_llm_engine/test_error_handling.py | 6 ++-- tests/mq_llm_engine/test_load.py | 6 ++-- tests/multimodal/test_processing.py | 8 +++-- .../__init__.py | 0 .../test_runai_model_streamer_loader.py | 0 .../test_weight_utils.py | 0 tests/samplers/test_ignore_eos.py | 2 +- tests/samplers/test_logits_processor.py | 2 +- tests/samplers/test_logprobs.py | 2 +- tests/samplers/test_no_bad_words.py | 2 +- tests/samplers/test_ranks.py | 2 +- tests/test_config.py | 13 +++++--- tests/test_regression.py | 13 ++++++-- tests/worker/test_swap.py | 2 +- vllm/config.py | 3 +- vllm/model_executor/model_loader/loader.py | 2 +- .../model_loader/weight_utils.py | 4 +-- vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/s3_utils.py | 11 +++++-- 43 files changed, 225 insertions(+), 76 deletions(-) rename tests/{runai_model_streamer => runai_model_streamer_test}/__init__.py (100%) rename tests/{runai_model_streamer => runai_model_streamer_test}/test_runai_model_streamer_loader.py (100%) rename tests/{runai_model_streamer => runai_model_streamer_test}/test_weight_utils.py (100%) diff --git a/requirements-test.in b/requirements-test.in index ecf874ecc50f..53c531360d87 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -37,3 +37,5 @@ genai_perf==0.0.8 tritonclient==2.51.0 numpy < 2.0.0 +runai-model-streamer==0.11.0 +runai-model-streamer-s3==0.11.0 \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index 648a2626c857..f91586419148 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -171,6 +171,8 @@ huggingface-hub==0.26.2 # tokenizers # transformers # vocos +humanize==4.11.0 + # via runai-model-streamer idna==3.10 # via # anyio @@ -290,6 +292,7 @@ numpy==1.26.4 # patsy # peft # rouge-score + # runai-model-streamer # sacrebleu # scikit-learn # scipy @@ -514,6 +517,10 @@ rpds-py==0.20.1 # referencing rsa==4.7.2 # via awscli +runai-model-streamer==0.11.0 + # via -r requirements-test.in +runai-model-streamer-s3==0.11.0 + # via -r requirements-test.in s3transfer==0.10.3 # via # awscli @@ -594,6 +601,7 @@ torch==2.5.1 # encodec # lm-eval # peft + # runai-model-streamer # sentence-transformers # tensorizer # timm diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index bd97dd945fed..cc25c8792aa9 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -9,6 +9,7 @@ import pytest from vllm import LLM +from vllm.config import LoadFormat from vllm.platforms import current_platform from ..conftest import VllmRunner @@ -33,7 +34,7 @@ def v1(run_with_both_engines): def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" - llm = LLM("facebook/opt-125m") + llm = LLM("distilbert/distilgpt2", load_format=LoadFormat.RUNAI_STREAMER) weak_llm = weakref.ref(llm) del llm # If there's any circular reference to vllm, this fails @@ -94,14 +95,14 @@ def test_models( @pytest.mark.parametrize( "model, distributed_executor_backend, attention_backend, " "test_suite", [ - ("facebook/opt-125m", "ray", "", "L4"), - ("facebook/opt-125m", "mp", "", "L4"), - ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4"), - ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4"), - ("facebook/opt-125m", "ray", "", "A100"), - ("facebook/opt-125m", "mp", "", "A100"), - ("facebook/opt-125m", "mp", "FLASHINFER", "A100"), - ("meta-llama/Llama-3.2-1B-Instruct", "ray", "FLASHINFER", "A100"), + ("distilbert/distilgpt2", "ray", "", "L4"), + ("distilbert/distilgpt2", "mp", "", "L4"), + ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), + ("meta-llama/Llama-2-7b-hf", "mp", "", "L4"), + ("distilbert/distilgpt2", "ray", "", "A100"), + ("distilbert/distilgpt2", "mp", "", "A100"), + ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"), + ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), ]) def test_models_distributed( hf_runner, diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index f16b8007a742..24ed5d392839 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -4,9 +4,11 @@ import torch from vllm import LLM, SamplingParams +from vllm.config import LoadFormat from vllm.device_allocator.cumem import CuMemAllocator from vllm.utils import GiB_bytes +from ..conftest import MODEL_WEIGHTS_S3_BUCKET from ..utils import fork_new_process_for_each_test @@ -118,13 +120,18 @@ def model(x): @pytest.mark.parametrize( "model", [ - "meta-llama/Llama-3.2-1B-Instruct", # sleep mode with safetensors - "facebook/opt-125m" # sleep mode with pytorch checkpoint + # sleep mode with safetensors + f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B", + # sleep mode with pytorch checkpoint + "facebook/opt-125m" ]) def test_end_to_end(model): free, total = torch.cuda.mem_get_info() used_bytes_baseline = total - free # in case other process is running - llm = LLM(model, enable_sleep_mode=True) + load_format = LoadFormat.AUTO + if "Llama" in model: + load_format = LoadFormat.RUNAI_STREAMER + llm = LLM(model, load_format=load_format, enable_sleep_mode=True) prompt = "How are you?" sampling_params = SamplingParams(temperature=0, max_tokens=10) output = llm.generate(prompt, sampling_params) diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 6aaec6eef9de..a32b7cac080b 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -17,7 +17,7 @@ from ..models.utils import check_outputs_equal MODELS = [ - "facebook/opt-125m", + "distilbert/distilgpt2", ] diff --git a/tests/conftest.py b/tests/conftest.py index 02105900f30d..74219e40026c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import TaskOption, TokenizerPoolConfig +from vllm.config import LoadFormat, TaskOption, TokenizerPoolConfig from vllm.connections import global_http_connection from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, @@ -46,6 +46,21 @@ _SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt") _M = TypeVar("_M") + +MODELS_ON_S3 = [ + "distilbert/distilgpt2", + "meta-llama/Llama-2-7b-hf", + "meta-llama/Meta-Llama-3-8B", + "meta-llama/Llama-3.2-1B", + "meta-llama/Llama-3.2-1B-Instruct", + "openai-community/gpt2", + "ArthurZ/Ilama-3.2-1B", + "llava-hf/llava-1.5-7b-hf", + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", +] + +MODEL_WEIGHTS_S3_BUCKET = "s3://vllm-ci-model-weights" + _PromptMultiModalInput = Union[List[_M], List[List[_M]]] PromptImageInput = _PromptMultiModalInput[Image.Image] @@ -677,8 +692,15 @@ def __init__( enable_chunked_prefill: bool = False, swap_space: int = 4, enforce_eager: Optional[bool] = False, + load_format: Optional[LoadFormat] = None, **kwargs, ) -> None: + if model_name in MODELS_ON_S3 and not load_format: + model_name = (f"s3://vllm-ci-model-weights/" + f"{model_name.split('/')[-1]}") + load_format = LoadFormat.RUNAI_STREAMER + if not load_format: + load_format = LoadFormat.AUTO self.model = LLM( model=model_name, task=task, @@ -693,6 +715,7 @@ def __init__( max_model_len=max_model_len, block_size=block_size, enable_chunked_prefill=enable_chunked_prefill, + load_format=load_format, **kwargs, ) diff --git a/tests/engine/test_computed_prefix_blocks.py b/tests/engine/test_computed_prefix_blocks.py index dca8fa6026ab..93907ecae554 100644 --- a/tests/engine/test_computed_prefix_blocks.py +++ b/tests/engine/test_computed_prefix_blocks.py @@ -2,12 +2,15 @@ import pytest +from vllm.config import LoadFormat from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.sampling_params import SamplingParams +from ..conftest import MODEL_WEIGHTS_S3_BUCKET -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) + +@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) @pytest.mark.parametrize("block_size", [16]) def test_computed_prefix_blocks(model: str, block_size: int): # This test checks if we are able to run the engine to completion @@ -24,6 +27,7 @@ def test_computed_prefix_blocks(model: str, block_size: int): "decoration.") engine_args = EngineArgs(model=model, + load_format=LoadFormat.RUNAI_STREAMER, block_size=block_size, enable_prefix_caching=True) diff --git a/tests/engine/test_detokenization.py b/tests/engine/test_detokenization.py index 742176ea8b60..ab594aeee40d 100644 --- a/tests/engine/test_detokenization.py +++ b/tests/engine/test_detokenization.py @@ -2,11 +2,14 @@ import pytest +from vllm.config import LoadFormat from vllm.entrypoints.llm import LLM from vllm.sampling_params import SamplingParams +from ..conftest import MODEL_WEIGHTS_S3_BUCKET -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) + +@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) def test_computed_prefix_blocks(model: str): # This test checks if the engine generates completions both with and # without optional detokenization, that detokenization includes text @@ -17,7 +20,7 @@ def test_computed_prefix_blocks(model: str): "paper clips? Is there an easy to follow video tutorial available " "online for free?") - llm = LLM(model=model) + llm = LLM(model=model, load_format=LoadFormat.RUNAI_STREAMER) sampling_params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False) diff --git a/tests/engine/test_executor.py b/tests/engine/test_executor.py index 84cc3ed63bb9..31c07e709bd9 100644 --- a/tests/engine/test_executor.py +++ b/tests/engine/test_executor.py @@ -6,12 +6,17 @@ import pytest +from vllm.config import LoadFormat from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine from vllm.executor.uniproc_executor import UniProcExecutor from vllm.sampling_params import SamplingParams +from ..conftest import MODEL_WEIGHTS_S3_BUCKET + +RUNAI_STREAMER_LOAD_FORMAT = LoadFormat.RUNAI_STREAMER + class Mock: ... @@ -33,10 +38,11 @@ def collective_rpc(self, CustomUniExecutorAsync = CustomUniExecutor -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) def test_custom_executor_type_checking(model): with pytest.raises(ValueError): engine_args = EngineArgs(model=model, + load_format=RUNAI_STREAMER_LOAD_FORMAT, distributed_executor_backend=Mock) LLMEngine.from_engine_args(engine_args) with pytest.raises(ValueError): @@ -45,7 +51,7 @@ def test_custom_executor_type_checking(model): AsyncLLMEngine.from_engine_args(engine_args) -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) def test_custom_executor(model, tmp_path): cwd = os.path.abspath(".") os.chdir(tmp_path) @@ -54,6 +60,7 @@ def test_custom_executor(model, tmp_path): engine_args = EngineArgs( model=model, + load_format=RUNAI_STREAMER_LOAD_FORMAT, distributed_executor_backend=CustomUniExecutor, enforce_eager=True, # reduce test time ) @@ -68,7 +75,7 @@ def test_custom_executor(model, tmp_path): os.chdir(cwd) -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) def test_custom_executor_async(model, tmp_path): cwd = os.path.abspath(".") os.chdir(tmp_path) @@ -77,6 +84,7 @@ def test_custom_executor_async(model, tmp_path): engine_args = AsyncEngineArgs( model=model, + load_format=RUNAI_STREAMER_LOAD_FORMAT, distributed_executor_backend=CustomUniExecutorAsync, enforce_eager=True, # reduce test time ) @@ -95,7 +103,7 @@ async def t(): os.chdir(cwd) -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) def test_respect_ray(model): # even for TP=1 and PP=1, # if users specify ray, we should use ray. @@ -104,6 +112,7 @@ def test_respect_ray(model): engine_args = EngineArgs( model=model, distributed_executor_backend="ray", + load_format=RUNAI_STREAMER_LOAD_FORMAT, enforce_eager=True, # reduce test time ) engine = LLMEngine.from_engine_args(engine_args) diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index 655c8232ac77..fee7fd3f6aad 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -2,16 +2,21 @@ import pytest +from vllm.config import LoadFormat from vllm.entrypoints.llm import LLM from vllm.sampling_params import SamplingParams +from ..conftest import MODEL_WEIGHTS_S3_BUCKET -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) + +@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) def test_skip_tokenizer_initialization(model: str): # This test checks if the flag skip_tokenizer_init skips the initialization # of tokenizer and detokenizer. The generated output is expected to contain # token ids. - llm = LLM(model=model, skip_tokenizer_init=True) + llm = LLM(model=model, + skip_tokenizer_init=True, + load_format=LoadFormat.RUNAI_STREAMER) sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) with pytest.raises(ValueError, match="cannot pass text prompts when"): diff --git a/tests/engine/test_stop_reason.py b/tests/engine/test_stop_reason.py index a50b388048c9..4b1e4f5cf45e 100644 --- a/tests/engine/test_stop_reason.py +++ b/tests/engine/test_stop_reason.py @@ -12,7 +12,7 @@ from vllm import SamplingParams -MODEL = "facebook/opt-350m" +MODEL = "distilbert/distilgpt2" STOP_STR = "." SEED = 42 MAX_TOKENS = 1024 diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 77c80b2f8944..f6fda5120d9e 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -5,12 +5,17 @@ import pytest from vllm import LLM +from vllm.config import LoadFormat +from ...conftest import MODEL_WEIGHTS_S3_BUCKET from ..openai.test_vision import TEST_IMAGE_URLS +RUNAI_STREAMER_LOAD_FORMAT = LoadFormat.RUNAI_STREAMER + def test_chat(): - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct") + llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B-Instruct", + load_format=RUNAI_STREAMER_LOAD_FORMAT) prompt1 = "Explain the concept of entropy." messages = [ @@ -28,7 +33,8 @@ def test_chat(): def test_multi_chat(): - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct") + llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B-Instruct", + load_format=RUNAI_STREAMER_LOAD_FORMAT) prompt1 = "Explain the concept of entropy." prompt2 = "Explain what among us is." @@ -65,7 +71,8 @@ def test_multi_chat(): [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) def test_chat_multi_image(image_urls: List[str]): llm = LLM( - model="microsoft/Phi-3.5-vision-instruct", + model=f"{MODEL_WEIGHTS_S3_BUCKET}/Phi-3.5-vision-instruct", + load_format=RUNAI_STREAMER_LOAD_FORMAT, dtype="bfloat16", max_model_len=4096, max_num_seqs=5, diff --git a/tests/entrypoints/llm/test_collective_rpc.py b/tests/entrypoints/llm/test_collective_rpc.py index 39d4810de9e7..69c60bbe6e8a 100644 --- a/tests/entrypoints/llm/test_collective_rpc.py +++ b/tests/entrypoints/llm/test_collective_rpc.py @@ -28,7 +28,7 @@ class MyWorker(Worker): def echo_rank(self): return self.rank - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + llm = LLM(model="s3://vllm-ci-model-weights/Llama-3.2-1B-Instruct", enforce_eager=True, load_format="dummy", tensor_parallel_size=tp_size, diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index ebec8baba38d..61085bf43d1b 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -6,9 +6,10 @@ import pytest from vllm import LLM, PoolingParams, PoolingRequestOutput +from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory -MODEL_NAME = "intfloat/e5-mistral-7b-instruct" +MODEL_NAME = "s3://vllm-ci-model-weights/e5-mistral-7b-instruct" PROMPTS = [ "Hello, my name is", @@ -32,6 +33,7 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection llm = LLM(model=MODEL_NAME, + load_format=LoadFormat.RUNAI_STREAMER, max_num_batched_tokens=32768, tensor_parallel_size=1, gpu_memory_utilization=0.75, diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 4c78c2c8ee2e..f1bad876be46 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -6,9 +6,10 @@ import pytest from vllm import LLM, RequestOutput, SamplingParams +from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory -MODEL_NAME = "facebook/opt-125m" +MODEL_NAME = "s3://vllm-ci-model-weights/distilgpt2" PROMPTS = [ "Hello, my name is", @@ -30,6 +31,7 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection llm = LLM(model=MODEL_NAME, + load_format=LoadFormat.RUNAI_STREAMER, max_num_batched_tokens=4096, tensor_parallel_size=1, gpu_memory_utilization=0.10, diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py index 90e1d5814137..487c00460a63 100644 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ b/tests/entrypoints/llm/test_generate_multiple_loras.py @@ -7,10 +7,11 @@ from huggingface_hub import snapshot_download from vllm import LLM +from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory from vllm.lora.request import LoRARequest -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +MODEL_NAME = "s3://vllm-ci-model-weights/zephyr-7b-beta" PROMPTS = [ "Hello, my name is", @@ -27,6 +28,7 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection llm = LLM(model=MODEL_NAME, + load_format=LoadFormat.RUNAI_STREAMER, tensor_parallel_size=1, max_model_len=8192, enable_lora=True, diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 01d2c1709b49..70252471cc24 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -7,12 +7,13 @@ import jsonschema import pytest +from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.sampling_params import GuidedDecodingParams, SamplingParams -MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct" +MODEL_NAME = "s3://vllm-ci-model-weights/Qwen2.5-1.5B-Instruct" GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] @@ -20,7 +21,9 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, max_model_len=1024) + llm = LLM(model=MODEL_NAME, + load_format=LoadFormat.RUNAI_STREAMER, + max_model_len=1024) with llm.deprecate_legacy_api(): yield weakref.proxy(llm) diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py index b1f9ae14da07..07608e15fe92 100644 --- a/tests/entrypoints/llm/test_lazy_outlines.py +++ b/tests/entrypoints/llm/test_lazy_outlines.py @@ -6,10 +6,11 @@ from vllm_test_utils import BlameResult, blame from vllm import LLM, SamplingParams +from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory -def run_normal(): +def run_normal_opt125m(): prompts = [ "Hello, my name is", "The president of the United States is", @@ -33,9 +34,35 @@ def run_normal(): cleanup_dist_env_and_memory() +def run_normal(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + # Create an LLM without guided decoding as a baseline. + llm = LLM(model="s3://vllm-ci-model-weights/distilgpt2", + load_format=LoadFormat.RUNAI_STREAMER, + enforce_eager=True, + gpu_memory_utilization=0.3) + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + # Destroy the LLM object and free up the GPU memory. + del llm + cleanup_dist_env_and_memory() + + def run_lmfe(sample_regex): # Create an LLM with guided decoding enabled. - llm = LLM(model="facebook/opt-125m", + llm = LLM(model="s3://vllm-ci-model-weights/distilgpt2", + load_format=LoadFormat.RUNAI_STREAMER, enforce_eager=True, guided_decoding_backend="lm-format-enforcer", gpu_memory_utilization=0.3) diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py index f2c145fa3c2b..04848131dfc8 100644 --- a/tests/entrypoints/llm/test_prompt_validation.py +++ b/tests/entrypoints/llm/test_prompt_validation.py @@ -3,6 +3,7 @@ import pytest from vllm import LLM +from vllm.config import LoadFormat @pytest.fixture(autouse=True) @@ -14,13 +15,17 @@ def v1(run_with_both_engines): def test_empty_prompt(): - llm = LLM(model="gpt2", enforce_eager=True) + llm = LLM(model="s3://vllm-ci-model-weights/gpt2", + load_format=LoadFormat.RUNAI_STREAMER, + enforce_eager=True) with pytest.raises(ValueError, match='Prompt cannot be empty'): llm.generate([""]) @pytest.mark.skip_v1 def test_out_of_vocab_token(): - llm = LLM(model="gpt2", enforce_eager=True) + llm = LLM(model="s3://vllm-ci-model-weights/gpt2", + load_format=LoadFormat.RUNAI_STREAMER, + enforce_eager=True) with pytest.raises(ValueError, match='out of vocabulary'): llm.generate({"prompt_token_ids": [999999]}) diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 4c9774a7397d..cf114f0641db 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -86,4 +86,4 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str): assert rerank_response.status_code == 400 # Assert just a small fragments of the response assert "Please reduce the length of the input." in \ - rerank_response.text \ No newline at end of file + rerank_response.text diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 0942c8eed344..1a9063bc2dc3 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -8,16 +8,21 @@ from prometheus_client import REGISTRY from vllm import EngineArgs, LLMEngine +from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.metrics import RayPrometheusStatLogger from vllm.sampling_params import SamplingParams +from ..conftest import MODEL_WEIGHTS_S3_BUCKET + MODELS = [ - "facebook/opt-125m", + "distilbert/distilgpt2", ] +RUNAI_STREAMER_LOAD_FORMAT = LoadFormat.RUNAI_STREAMER + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @@ -141,8 +146,9 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, metrics_tag_content = stat_logger.labels["model_name"] if served_model_name is None or served_model_name == []: - assert metrics_tag_content == model, ( - f"Metrics tag model_name is wrong! expect: {model!r}\n" + actual_model_name = f"{MODEL_WEIGHTS_S3_BUCKET}/{model.split('/')[-1]}" + assert metrics_tag_content == actual_model_name, ( + f"Metrics tag model_name is wrong! expect: {actual_model_name!r}\n" f"actual: {metrics_tag_content!r}") else: assert metrics_tag_content == served_model_name[0], ( @@ -170,7 +176,8 @@ async def test_async_engine_log_metrics_regression( """ engine_args = AsyncEngineArgs(model=model, dtype=dtype, - disable_log_stats=disable_log_stats) + disable_log_stats=disable_log_stats, + load_format=RUNAI_STREAMER_LOAD_FORMAT) async_engine = AsyncLLMEngine.from_engine_args(engine_args) for i, prompt in enumerate(example_prompts): results = async_engine.generate( @@ -199,7 +206,8 @@ def test_engine_log_metrics_regression( ) -> None: engine_args = EngineArgs(model=model, dtype=dtype, - disable_log_stats=disable_log_stats) + disable_log_stats=disable_log_stats, + load_format=RUNAI_STREAMER_LOAD_FORMAT) engine = LLMEngine.from_engine_args(engine_args) for i, prompt in enumerate(example_prompts): engine.add_request( @@ -283,7 +291,8 @@ def test_metric_spec_decode_interval( gpu_memory_utilization=0.4, speculative_model=model, num_speculative_tokens=k, - enforce_eager=True) + enforce_eager=True, + load_format=RUNAI_STREAMER_LOAD_FORMAT) engine = LLMEngine.from_engine_args(engine_args) diff --git a/tests/models/registry.py b/tests/models/registry.py index 17bfe1d21e4a..8b0ece161009 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -173,7 +173,8 @@ def check_available_online( trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", trust_remote_code=True), - "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct"), + "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct", + extras={"2.5": "Qwen/Qwen2.5-7B-Instruct"}), # noqa: E501 "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b", is_available_online=False), diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index c58c63723168..e0d5e0032275 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -7,6 +7,7 @@ from vllm import LLM +from ..conftest import MODELS_ON_S3 from .registry import HF_EXAMPLE_MODELS @@ -42,8 +43,11 @@ def _initialize_kv_caches(self) -> None: with patch.object(LLM.get_engine_class(), "_initialize_kv_caches", _initialize_kv_caches): + model_name = model_info.default + if model_name in MODELS_ON_S3: + model_name = f"s3://vllm-ci-model-weights/{model_name.split('/')[-1]}" LLM( - model_info.default, + model_name, tokenizer=model_info.tokenizer, tokenizer_mode=model_info.tokenizer_mode, speculative_model=model_info.speculative_model, diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py index 808346b5e58d..b0ac0fb327f4 100644 --- a/tests/mq_llm_engine/test_abort.py +++ b/tests/mq_llm_engine/test_abort.py @@ -10,8 +10,8 @@ from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate from vllm.engine.arg_utils import AsyncEngineArgs -MODEL = "google/gemma-1.1-2b-it" -ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +MODEL = "s3://vllm-ci-model-weights/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, load_format="runai_streamer") RAISED_ERROR = KeyError RAISED_VALUE = "foo" EXPECTED_TOKENS = 250 diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 35d001781110..4eac73417ad7 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -21,8 +21,10 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser -MODEL = "google/gemma-1.1-2b-it" -ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True) +MODEL = "s3://vllm-ci-model-weights/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, + load_format="runai_streamer", + enforce_eager=True) RAISED_ERROR = KeyError RAISED_VALUE = "foo" diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py index 2069ff987f2f..3162d56c6d4e 100644 --- a/tests/mq_llm_engine/test_load.py +++ b/tests/mq_llm_engine/test_load.py @@ -10,12 +10,14 @@ from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate from vllm.engine.arg_utils import AsyncEngineArgs -MODEL = "google/gemma-1.1-2b-it" +MODEL = "s3://vllm-ci-model-weights/gemma-1.1-2b-it" NUM_EXPECTED_TOKENS = 10 NUM_REQUESTS = 10000 # Scenarios to test for num generated token. -ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True) +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, + load_format="runai_streamer", + disable_log_requests=True) @pytest.fixture(scope="function") diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 459c0d9d113f..7bbe5c53562d 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -553,7 +553,8 @@ def test_find_mm_placeholders( assert result == expected -@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +@pytest.mark.parametrize( + "model_id", ["s3://vllm-ci-model-weights/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("limit", "num_supported", "is_valid"), [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), @@ -592,7 +593,8 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): profiler.get_dummy_data(model_config.max_model_len) -@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +@pytest.mark.parametrize( + "model_id", ["s3://vllm-ci-model-weights/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("num_images", "limit", "is_valid"), [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), @@ -661,7 +663,7 @@ def __call__( return dict(exists=exists) -@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-7B-Instruct"]) # Dummy +@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy # yapf: disable @pytest.mark.parametrize( ("call_kwargs", "expected_kwargs"), diff --git a/tests/runai_model_streamer/__init__.py b/tests/runai_model_streamer_test/__init__.py similarity index 100% rename from tests/runai_model_streamer/__init__.py rename to tests/runai_model_streamer_test/__init__.py diff --git a/tests/runai_model_streamer/test_runai_model_streamer_loader.py b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py similarity index 100% rename from tests/runai_model_streamer/test_runai_model_streamer_loader.py rename to tests/runai_model_streamer_test/test_runai_model_streamer_loader.py diff --git a/tests/runai_model_streamer/test_weight_utils.py b/tests/runai_model_streamer_test/test_weight_utils.py similarity index 100% rename from tests/runai_model_streamer/test_weight_utils.py rename to tests/runai_model_streamer_test/test_weight_utils.py diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py index 9a92b08ff3ff..673d1b9a7ef6 100644 --- a/tests/samplers/test_ignore_eos.py +++ b/tests/samplers/test_ignore_eos.py @@ -10,7 +10,7 @@ # We also test with llama because it has generation_config to specify EOS # (past regression). -MODELS = ["facebook/opt-125m", "meta-llama/Llama-3.2-1B-Instruct"] +MODELS = ["distilbert/distilgpt2", "meta-llama/Llama-3.2-1B"] @pytest.mark.parametrize("model", MODELS) diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 3b95b038979f..f237b616077b 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -5,7 +5,7 @@ from vllm import SamplingParams -MODELS = ["facebook/opt-125m"] +MODELS = ["distilbert/distilgpt2"] @pytest.mark.parametrize("model", MODELS) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 59d36099c650..78bdd9b0b958 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -9,7 +9,7 @@ from ..conftest import VllmRunner -MODELS = ["facebook/opt-125m"] +MODELS = ["distilbert/distilgpt2"] @pytest.mark.parametrize("model", MODELS) diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py index cc6557694c6c..143f52999415 100644 --- a/tests/samplers/test_no_bad_words.py +++ b/tests/samplers/test_no_bad_words.py @@ -76,7 +76,7 @@ def _encode(self, class TestTwoTokenBadWord: # Another model (with a different tokenizer behaviour) - MODEL = "openai-community/gpt2" + MODEL = "distilbert/distilgpt2" PROMPT = "How old are you? I am 10" TARGET_TOKEN1 = "years" diff --git a/tests/samplers/test_ranks.py b/tests/samplers/test_ranks.py index c74c1c02c247..66779d97a92c 100644 --- a/tests/samplers/test_ranks.py +++ b/tests/samplers/test_ranks.py @@ -4,7 +4,7 @@ from vllm import SamplingParams -MODELS = ["facebook/opt-125m"] +MODELS = ["distilbert/distilgpt2"] @pytest.mark.parametrize("model", MODELS) diff --git a/tests/test_config.py b/tests/test_config.py index 746ca7295a8e..4a1718613302 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,14 +8,19 @@ from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform +from .conftest import MODEL_WEIGHTS_S3_BUCKET + @pytest.mark.parametrize( ("model_id", "expected_runner_type", "expected_task"), [ - ("facebook/opt-125m", "generate", "generate"), - ("intfloat/e5-mistral-7b-instruct", "pooling", "embed"), - ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), - ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"), + (f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2", "generate", "generate"), + (f"{MODEL_WEIGHTS_S3_BUCKET}/e5-mistral-7b-instruct", "pooling", + "embed"), + (f"{MODEL_WEIGHTS_S3_BUCKET}/Qwen2.5-1.5B-apeach", "pooling", + "classify"), + (f"{MODEL_WEIGHTS_S3_BUCKET}/ms-marco-MiniLM-L-6-v2", "pooling", + "score"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), ("openai/whisper-small", "transcription", "transcription"), ], diff --git a/tests/test_regression.py b/tests/test_regression.py index f781b3113b4c..e9b21e1a7232 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -10,6 +10,9 @@ import torch from vllm import LLM, SamplingParams +from vllm.config import LoadFormat + +from .conftest import MODEL_WEIGHTS_S3_BUCKET def test_duplicated_ignored_sequence_group(): @@ -18,7 +21,8 @@ def test_duplicated_ignored_sequence_group(): sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=256) - llm = LLM(model="facebook/opt-125m", + llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2", + load_format=LoadFormat.RUNAI_STREAMER, max_num_batched_tokens=4096, tensor_parallel_size=1) prompts = ["This is a short prompt", "This is a very long prompt " * 1000] @@ -31,7 +35,8 @@ def test_max_tokens_none(): sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) - llm = LLM(model="facebook/opt-125m", + llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2", + load_format=LoadFormat.RUNAI_STREAMER, max_num_batched_tokens=4096, tensor_parallel_size=1) prompts = ["Just say hello!"] @@ -41,7 +46,9 @@ def test_max_tokens_none(): def test_gc(): - llm = LLM("facebook/opt-125m", enforce_eager=True) + llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2", + load_format=LoadFormat.RUNAI_STREAMER, + enforce_eager=True) del llm gc.collect() diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 7ae0f4bb8e80..2c337cc9fed2 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -10,7 +10,7 @@ def test_swap() -> None: # Configure the engine. - engine_args = EngineArgs(model="facebook/opt-125m", + engine_args = EngineArgs(model="s3://vllm-ci-model-weights/distilgpt2", dtype="half", load_format="dummy") engine_config = engine_args.create_engine_config() diff --git a/vllm/config.py b/vllm/config.py index 5c220ed13630..54227dda0441 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -409,7 +409,8 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str, if is_s3(model) or is_s3(tokenizer): if is_s3(model): s3_model = S3Model() - s3_model.pull_files(model, allow_pattern=["*config.json"]) + s3_model.pull_files( + model, allow_pattern=["*.model", "*.py", "*.json"]) self.model_weights = self.model self.model = s3_model.dir diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 230484a36dec..df957cfca3c0 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1327,6 +1327,7 @@ def _prepare_weights(self, model_name_or_path: str, """Prepare weights for the model. If the model is not local, it will be downloaded.""" + is_s3_path = is_s3(model_name_or_path) is_local = os.path.isdir(model_name_or_path) safetensors_pattern = "*.safetensors" @@ -1340,7 +1341,6 @@ def _prepare_weights(self, model_name_or_path: str, revision, ignore_patterns=self.load_config.ignore_patterns, )) - if is_s3_path: hf_weights_files = s3_glob(path=hf_folder, allow_pattern=[safetensors_pattern]) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 18f6f40b32f0..ac1be383c15b 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -27,6 +27,8 @@ from vllm.platforms import current_platform from vllm.utils import PlaceholderModule +logger = init_logger(__name__) + try: from runai_model_streamer import SafetensorsStreamer except (ImportError, OSError): @@ -37,8 +39,6 @@ SafetensorsStreamer = runai_model_streamer.placeholder_attr( "SafetensorsStreamer") -logger = init_logger(__name__) - # use system-level temp directory for file locks, so that multiple users # can share the same lock without error. # lock files in the temp directory will be automatically deleted when the diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 2fed5d743e8e..4768226f9a03 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -144,7 +144,6 @@ def file_exists( revision: Optional[str] = None, token: Union[str, bool, None] = None, ) -> bool: - file_list = list_repo_files(repo_id, repo_type=repo_type, revision=revision, @@ -498,7 +497,7 @@ def get_sentence_transformer_tokenizer_config(model: str, if encoder_dict: break - if not encoder_dict: + if not encoder_dict and not model.startswith("/"): try: # If model is on HuggingfaceHub, get the repo files repo_files = list_repo_files(model, diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py index 4fe744d285d3..1c3520bcfb27 100644 --- a/vllm/transformers_utils/s3_utils.py +++ b/vllm/transformers_utils/s3_utils.py @@ -46,6 +46,8 @@ def glob(s3=None, """ if s3 is None: s3 = boto3.client("s3") + if not path.endswith("/"): + path = path + "/" bucket_name, _, paths = list_files(s3, path=path, allow_pattern=allow_pattern) @@ -109,6 +111,7 @@ def __init__(self) -> None: for sig in (signal.SIGINT, signal.SIGTERM): existing_handler = signal.getsignal(sig) signal.signal(sig, self._close_by_signal(existing_handler)) + self.dir = tempfile.mkdtemp() def __del__(self): @@ -140,6 +143,9 @@ def pull_files(self, ignore_pattern: A list of patterns of which files not to pull. """ + if not s3_model_path.endswith("/"): + s3_model_path = s3_model_path + "/" + bucket_name, base_dir, files = list_files(self.s3, s3_model_path, allow_pattern, ignore_pattern) @@ -147,8 +153,9 @@ def pull_files(self, return for file in files: - destination_file = os.path.join(self.dir, - file.removeprefix(base_dir)) + destination_file = os.path.join( + self.dir, + file.removeprefix(base_dir).lstrip("/")) local_dir = Path(destination_file).parent os.makedirs(local_dir, exist_ok=True) self.s3.download_file(bucket_name, file, destination_file) From 66a5a164abd63a3b4d1d6e97c57c68320123302b Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Tue, 18 Feb 2025 23:56:11 -0800 Subject: [PATCH 105/317] [perf-benchmark] Fix ECR path for premerge benchmark (#13512) Signed-off-by: <> Co-authored-by: EC2 Default User --- .../benchmark-pipeline.yaml | 100 ++++++++++++++++-- 1 file changed, 93 insertions(+), 7 deletions(-) diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index d1c08de7c47c..4259514940d3 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -10,18 +10,24 @@ steps: - image: badouralix/curl-jq command: - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh - + - label: "Cleanup H100" + agents: + queue: H100 + depends_on: ~ + command: docker system prune -a --volumes --force + - label: "A100" # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: A100 depends_on: wait-for-container-image + if: build.branch == "main" plugins: - kubernetes: podSpec: priorityClassName: perf-benchmark containers: - - image: public.ecr.aws/q9t5s3a7/${BUILDKITE_BRANCH:-main} == "main" && "vllm-ci-postmerge-repo" || "vllm-ci-test-repo"}:$BUILDKITE_COMMIT + - image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT command: - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh resources: @@ -50,9 +56,10 @@ steps: agents: queue: H200 depends_on: wait-for-container-image + if: build.branch == "main" plugins: - docker#v5.12.0: - image: public.ecr.aws/q9t5s3a7/${BUILDKITE_BRANCH:-main} == "main" && "vllm-ci-postmerge-repo" || "vllm-ci-test-repo"}:$BUILDKITE_COMMIT + image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT command: - bash - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -70,20 +77,99 @@ steps: #key: block-h100 #depends_on: ~ - - label: "Cleanup H100" + - label: "H100" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: H100 - depends_on: ~ - command: docker system prune -a --volumes --force + depends_on: wait-for-container-image + if: build.branch == "main" + plugins: + - docker#v5.12.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT + command: + - bash + - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + mount-buildkite-agent: true + propagate-environment: true + ipc: host + gpus: all # see CUDA_VISIBLE_DEVICES for actual GPUs used + volumes: + - /data/benchmark-hf-cache:/root/.cache/huggingface + environment: + - VLLM_USAGE_SOURCE + - HF_TOKEN + + # Premerge benchmark + - label: "A100" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" + agents: + queue: A100 + depends_on: wait-for-container-image + if: build.branch != "main" + plugins: + - kubernetes: + podSpec: + priorityClassName: perf-benchmark + containers: + - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + command: + - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + resources: + limits: + nvidia.com/gpu: 8 + volumeMounts: + - name: devshm + mountPath: /dev/shm + env: + - name: VLLM_USAGE_SOURCE + value: ci-test + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB + volumes: + - name: devshm + emptyDir: + medium: Memory + + - label: "H200" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" + agents: + queue: H200 + depends_on: wait-for-container-image + if: build.branch != "main" + plugins: + - docker#v5.12.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + command: + - bash + - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + mount-buildkite-agent: true + propagate-environment: true + ipc: host + gpus: 4,5,6,7 + volumes: + - /data/benchmark-hf-cache:/root/.cache/huggingface + environment: + - VLLM_USAGE_SOURCE + - HF_TOKEN + + #- block: "Run H100 Benchmark" + #key: block-h100 + #depends_on: ~ - label: "H100" # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: H100 depends_on: wait-for-container-image + if: build.branch != "main" plugins: - docker#v5.12.0: - image: public.ecr.aws/q9t5s3a7/${BUILDKITE_BRANCH:-main} == "main" && "vllm-ci-postmerge-repo" || "vllm-ci-test-repo"}:$BUILDKITE_COMMIT + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT command: - bash - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh From ea489084c4dccd15a6610fea3f0353f0243227de Mon Sep 17 00:00:00 2001 From: Zhe Zhang <2631992879@qq.com> Date: Wed, 19 Feb 2025 16:05:02 +0800 Subject: [PATCH 106/317] use device param in load_model method (#13037) --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c7814f17375b..78cc352b1630 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1107,7 +1107,7 @@ def __init__( def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) - with DeviceMemoryProfiler() as m: + with DeviceMemoryProfiler(self.device) as m: self.model = get_model(vllm_config=self.vllm_config) self.model_memory_usage = m.consumed_memory From 761512863e2a77684479b9727de2501fd10707ee Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 19 Feb 2025 01:50:07 -0700 Subject: [PATCH 107/317] [Bugfix] Fix Positive Feature Layers in Llava Models (#13514) Signed-off-by: Alex-Brooks --- tests/models/test_vision.py | 34 +++++++++++++++++++++++++++ vllm/model_executor/models/clip.py | 2 +- vllm/model_executor/models/llava.py | 4 ++-- vllm/model_executor/models/pixtral.py | 2 +- vllm/model_executor/models/siglip.py | 2 +- vllm/model_executor/models/vision.py | 9 +++---- 6 files changed, 44 insertions(+), 9 deletions(-) create mode 100644 tests/models/test_vision.py diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py new file mode 100644 index 000000000000..d64c0e6d4e43 --- /dev/null +++ b/tests/models/test_vision.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from vllm.model_executor.models.vision import resolve_visual_encoder_outputs + + +@pytest.mark.parametrize( + ("feature_sample_layers", "num_layers_loaded", "max_possible_layers", + "expected_features"), + [ + # All layers loaded + ([1, 10], 10, 10, [1, 10]), + ([-10, -1], 10, 10, [1, 10]), + # Some layers not loaded + ([1, 10], 10, 20, [1, 10]), + ([-20, -11], 10, 20, [1, 10]), + ]) +def test_resolve_visual_encoder_outputs(feature_sample_layers, + num_layers_loaded, max_possible_layers, + expected_features): + """ + Test that offsets are correctly handled for vision feature layers. + """ + encoder_outputs = [ + torch.tensor([idx]) for idx in range(num_layers_loaded + 1) + ] + output_tensor = resolve_visual_encoder_outputs( + encoder_outputs=encoder_outputs, + feature_sample_layers=feature_sample_layers, + post_layer_norm=None, + max_possible_layers=max_possible_layers) + assert torch.equal(torch.tensor(expected_features), output_tensor) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 73c109a27ac7..dc3aa9cbe86b 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -251,7 +251,7 @@ def __init__( def forward( self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool ) -> Union[torch.Tensor, list[torch.Tensor]]: - hidden_states_pool = [] + hidden_states_pool = [inputs_embeds] hidden_states = inputs_embeds for encoder_layer in self.layers: diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index dcd90474e936..6a4277adb6bf 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -428,7 +428,7 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: - """Given an signed vision feature layer, get the number of hidden layers + """Given a signed vision feature layer, get the number of hidden layers needed to leverage it. Args: @@ -438,7 +438,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: """ if feature_layer_index < 0: return num_hidden_layers + feature_layer_index + 1 - return feature_layer_index + 1 + return feature_layer_index def init_vision_tower_for_llava( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index e78e8d62cc47..44fca852805a 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -969,7 +969,7 @@ def forward( position_embeddings: torch.Tensor, return_all_hidden_states: bool, ) -> torch.Tensor: - hidden_states_pool = [] + hidden_states_pool = [x] for layer in self.layers: x = layer(x, attention_mask, position_embeddings) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index ddae78d7739e..2892f696107b 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -378,7 +378,7 @@ def forward( inputs_embeds: torch.Tensor, return_all_hidden_states: bool, ) -> Union[torch.Tensor, list[torch.Tensor]]: - hidden_states_pool = [] + hidden_states_pool = [inputs_embeds] hidden_states = inputs_embeds for encoder_layer in self.layers: diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 0d67ee7bb5dd..9a6fac2eec56 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -132,10 +132,11 @@ def resolve_visual_encoder_outputs( # Get the hidden states corresponding to the layer indices. # Negative values are relative to the full visual encoder, # so offset them depending on how many layers were loaded. - # NOTE: this assumes that encoder_outputs contains a list - # of hidden states in the same order as the encoder layers - # that produced them. - offset = max_possible_layers - len(encoder_outputs) + # NOTE: this assumes that encoder_outputs is a list containing + # the inputs to the visual encoder, followed by the hidden states + # of each layer. + num_loaded_layers = len(encoder_outputs) - 1 + offset = max_possible_layers - num_loaded_layers hs_pool = [ encoder_outputs[layer_idx] if layer_idx >= 0 else encoder_outputs[layer_idx + offset] From 25f0d8742ecc261b6afe11369201a64e4838c9c7 Mon Sep 17 00:00:00 2001 From: Lucia Fang <116399278+luccafong@users.noreply.github.com> Date: Wed, 19 Feb 2025 01:06:23 -0800 Subject: [PATCH 108/317] [Model][Speculative Decoding] DeepSeek MTP spec decode (#12755) Signed-off-by: Lu Fang Co-authored-by: LiuXiaoxuanPKU --- .buildkite/test-pipeline.yaml | 22 +- tests/models/registry.py | 3 + tests/spec_decode/e2e/test_mtp_correctness.py | 318 ++++++++++++++++++ vllm/config.py | 43 ++- vllm/model_executor/models/deepseek_mtp.py | 284 ++++++++++++++++ vllm/model_executor/models/deepseek_v2.py | 22 +- vllm/model_executor/models/registry.py | 1 + vllm/sequence.py | 2 + vllm/spec_decode/draft_model_runner.py | 20 +- vllm/spec_decode/spec_decode_worker.py | 25 +- vllm/worker/model_runner.py | 20 +- vllm/worker/model_runner_base.py | 5 +- vllm/worker/worker.py | 6 +- vllm/worker/worker_base.py | 2 + 14 files changed, 727 insertions(+), 46 deletions(-) create mode 100644 tests/spec_decode/e2e/test_mtp_correctness.py create mode 100644 vllm/model_executor/models/deepseek_mtp.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9991060a3162..3918e3e86769 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -2,7 +2,7 @@ # adding a new command to an existing step. See different options here for examples. # This script will be feed into Jinja template in `test-template-aws.j2` at -# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 +# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 # to generate the final pipeline yaml file. # Documentation @@ -15,7 +15,7 @@ # mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd] # gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100 # num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4. -# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host, +# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host, # in this case, commands must be specified. the first command runs on first host, the second # command runs on the second host. # working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests @@ -24,8 +24,8 @@ # When adding a test # - If the test belong to an existing group, add it there # - If the test is short, add to any existing step -# - If the test takes more than 10min, then it is okay to create a new step. -# Note that all steps execute in parallel. +# - If the test takes more than 10min, then it is okay to create a new step. +# Note that all steps execute in parallel. steps: ##### fast check tests ##### @@ -145,14 +145,14 @@ steps: - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py - label: Metrics, Tracing Test # 10min - num_gpus: 2 + num_gpus: 2 fast_check: true source_file_dependencies: - vllm/ - tests/metrics - tests/tracing commands: - - pytest -v -s metrics + - pytest -v -s metrics - "pip install \ 'opentelemetry-sdk>=1.26.0,<1.27.0' \ 'opentelemetry-api>=1.26.0,<1.27.0' \ @@ -254,7 +254,7 @@ steps: - vllm/model_executor/guided_decoding - tests/test_logits_processor - tests/model_executor/test_guided_processors - commands: + commands: - pytest -v -s test_logits_processor.py - pytest -v -s model_executor/test_guided_processors.py @@ -265,7 +265,7 @@ steps: - vllm/model_executor/models/eagle.py commands: - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_mtp_correctness.py - pytest -v -s spec_decode/e2e/test_eagle_correctness.py - label: LoRA Test %N # 15min each @@ -580,7 +580,7 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn # This test runs llama 13B, so it is required to run on 4 GPUs. - pytest -v -s -x lora/test_long_context.py - # There is some Tensor Parallelism related processing logic in LoRA that + # There is some Tensor Parallelism related processing logic in LoRA that # requires multi-GPU testing for validation. - pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py @@ -605,7 +605,7 @@ steps: - vllm/ - tests/weight_loading commands: - - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt ##### multi gpus test ##### @@ -617,7 +617,7 @@ steps: num_gpus: 4 source_file_dependencies: - vllm/ - commands: + commands: # NOTE: don't test llama model here, it seems hf implementation is buggy # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py diff --git a/tests/models/registry.py b/tests/models/registry.py index 8b0ece161009..d89a41dae3aa 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -296,6 +296,9 @@ def check_available_online( speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501 "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m", speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501 + "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", + speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 + trust_remote_code=True), } _FALLBACK_MODEL = { diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py new file mode 100644 index 000000000000..0bad19f61d30 --- /dev/null +++ b/tests/spec_decode/e2e/test_mtp_correctness.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: Apache-2.0 +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various number of speculative tokens. + +With those tests, we can say at least, mtp would not break the +correctess for the target model outputs. +""" + +import pytest + +from .conftest import run_equality_correctness_test + +# main model +MAIN_MODEL = "luccafong/deepseek_mtp_main_random" + +# max. number of speculative tokens: this corresponds to +# num_nextn_predict_layers in the config.json of the speculator model. +MAX_SPEC_TOKENS = 1 + +# precision +PRECISION = "bfloat16" + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int): + + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size: int, + output_len: int, seed: int): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 128, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness_with_preemption( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "num_speculative_tokens": k, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_different_k(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that mtp speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_disable_queue(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that mtp speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/vllm/config.py b/vllm/config.py index 54227dda0441..59fa60fd8b0c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -763,7 +763,7 @@ def get_hidden_size(self) -> int: def is_deepseek_mla(self) -> bool: return (hasattr(self.hf_text_config, "model_type")) \ and (self.hf_text_config.model_type in \ - ('deepseek_v2', 'deepseek_v3'))\ + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\ and (self.hf_text_config.kv_lora_rank is not None) def get_head_size(self) -> int: @@ -856,8 +856,12 @@ def get_num_attention_heads(self, def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> Tuple[int, int]: from vllm.distributed.utils import get_pp_indices - total_num_hidden_layers = getattr(self.hf_text_config, - "num_hidden_layers", 0) + if self.hf_text_config.model_type == "deepseek_mtp": + total_num_hidden_layers = getattr(self.hf_text_config, + "num_nextn_predict_layers", 0) + else: + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size pp_size = parallel_config.pipeline_parallel_size start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) @@ -1689,6 +1693,18 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str + @staticmethod + def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + if hf_config.model_type == "deepseek_v3": + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["DeepSeekMTPModel"] + }) + return hf_config + @staticmethod def maybe_create_spec_config( target_model_config: ModelConfig, @@ -1771,12 +1787,18 @@ def maybe_create_spec_config( Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. """ - if speculative_model is None: if num_speculative_tokens is not None: - raise ValueError("num_speculative_tokens was provided without " - "speculative_model.") - return None + if target_model_config.hf_text_config.model_type \ + == "deepseek_v3": + # use the draft model from the same model: + speculative_model = target_model_config.model + else: + raise ValueError( + "num_speculative_tokens was provided without " + "speculative_model.") + else: + return None if (speculative_disable_by_batch_size is not None and speculative_disable_by_batch_size < 2): @@ -1830,6 +1852,7 @@ def maybe_create_spec_config( max_seq_len_to_capture=target_model_config. max_seq_len_to_capture, max_logprobs=target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, ) draft_hf_config = draft_model_config.hf_config @@ -1846,7 +1869,6 @@ def maybe_create_spec_config( if (num_speculative_tokens is not None and hasattr(draft_hf_config, "num_lookahead_tokens")): draft_hf_config.num_lookahead_tokens = num_speculative_tokens - n_predict = getattr(draft_hf_config, "n_predict", None) if n_predict is not None: if num_speculative_tokens is None: @@ -1960,8 +1982,9 @@ def _verify_and_get_draft_model_tensor_parallel_size( speculative_draft_tensor_parallel_size = 1 if target_parallel_config.tensor_parallel_size > 1: logger.warning( - "MLPSpeculator cannot currently be run with tp>1; " - "setting speculative_draft_tensor_parallel_size=1") + "%s cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1", + draft_hf_config.model_type) else: speculative_draft_tensor_parallel_size = \ target_parallel_config.tensor_parallel_size diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py new file mode 100644 index 000000000000..1a051992a306 --- /dev/null +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Iterable, List, Optional, Set, Tuple + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .deepseek_v2 import (DeepseekV2DecoderLayer, + get_spec_layer_idx_from_weight_name) +from .utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class DeepSeekMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config, + cache_config, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=None) + hidden_states = residual + hidden_states + return self.shared_head(hidden_states) + + +class DeepSeekMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + DeepSeekMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( + input_ids, + positions, + kv_caches[spec_step_idx], + attn_metadata, + previous_hidden_states, + inputs_embeds, + spec_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + hidden_states, sampling_metadata) + return logits + + +class DeepSeekMTP(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, previous_hidden_states, + inputs_embeds, spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, sampling_metadata, + spec_step_idx) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name(spec_layer, name) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + Add .mtp_block for modules in transformer layer block for spec layer + """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + spec_layer_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") + return name diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index fd0e58fa1458..a4d52c613b3e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -732,13 +732,9 @@ def load_weights(self, weights: Iterable[Tuple[str, if "rotary_emb.inv_freq" in name: continue - # TODO(simon): support nextn predict layers - if hasattr(self.config, "num_nextn_predict_layers" - ) and self.config.num_nextn_predict_layers > 0: - assert self.config.num_nextn_predict_layers == 1 - layer_idx = self.config.num_hidden_layers - if name.startswith(f"model.layers.{layer_idx}"): - continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -805,3 +801,15 @@ def load_weights(self, weights: Iterable[Tuple[str, class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 775398e003cd..81623defd337 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -187,6 +187,7 @@ _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), + "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } diff --git a/vllm/sequence.py b/vllm/sequence.py index 45d0e5bc7680..c0425ba33c9a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1307,6 +1307,8 @@ class ExecuteModelRequest( previous_hidden_states: Optional[HiddenStates] = None # The number of forward steps to run. num_steps: int = 1 + # The step index for spec model input. + spec_step_idx: Optional[int] = None # Finished request ids since last step. finished_requests_ids: List[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3948298db40c..7353d3c53ae9 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -153,7 +153,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): return False # TODO: Add support for other attn backends - if self.attn_backend.get_name() != "FLASH_ATTN": + if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"): return False # TODO: Add support for LORA @@ -175,6 +175,7 @@ def execute_model( previous_hidden_states: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs, ) -> Optional[List[SamplerOutput]]: """Executes num_steps forward passes with advacement of input tensors on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. @@ -271,10 +272,17 @@ def execute_model( for step in range(num_steps): multi_modal_kwargs = model_input.multi_modal_kwargs or {} - kwargs = {"previous_hidden_states": hidden_states} \ + model_execute_kwargs = {"previous_hidden_states": hidden_states} \ if previous_hidden_states is not None else {} + compute_logits_kwargs = {} # Run model + if hasattr(self.model.config, "num_nextn_predict_layers"): + # for DeepSeek MTP only to use the corresponding layer for + # each step + spec_step_idx = kwargs.get("spec_step_idx", step) + model_execute_kwargs["spec_step_idx"] = spec_step_idx + compute_logits_kwargs["spec_step_idx"] = spec_step_idx with set_forward_context(model_input.attn_metadata, self.vllm_config): hidden_states = model_executable( @@ -285,13 +293,15 @@ def execute_model( intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), - **kwargs, + **model_execute_kwargs, ) # Compute the logits. logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) - + model_input.sampling_metadata, + **compute_logits_kwargs) + if not self.is_driver_worker: + return [] # Sample the next token. output = self.model.sample( logits=logits, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 33b1be54c8b3..fce06a81ff04 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -108,6 +108,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": typical_acceptance_sampler_posterior_alpha, disable_logprobs=speculative_config.disable_logprobs, disable_log_stats=speculative_config.disable_log_stats, + num_speculative_tokens=speculative_config.num_speculative_tokens, ) return spec_decode_worker @@ -153,10 +154,12 @@ def create_worker( typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, disable_log_stats: bool, + num_speculative_tokens: int, ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True enable_lm_head_weight_load = False + num_spec_prefill_steps = 1 ngram_prompt_lookup_max = ( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( @@ -179,14 +182,16 @@ def create_worker( elif draft_model_config.hf_config.model_type == "medusa": proposer_worker = MedusaWorker(**draft_worker_kwargs) else: - if draft_tp == 1: + if draft_tp == 1 or draft_model_config.hf_config.model_type ==\ + "deepseek_mtp": if current_platform.is_cuda_alike(): draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( - "EAGLE does not support TP > 1 yet") + f"{draft_model_config.hf_config.model_type} " + "does not support TP > 1 yet") allow_zero_draft_token_step = False @@ -195,6 +200,8 @@ def create_worker( enable_lm_head_weight_load = True proposer_worker = MultiStepWorker(**draft_worker_kwargs) + if draft_model_config.hf_config.model_type == "deepseek_mtp": + num_spec_prefill_steps = num_speculative_tokens proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( proposer_worker, draft_tp, target_tp) @@ -247,7 +254,8 @@ def create_worker( disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, allow_zero_draft_token_step=allow_zero_draft_token_step, - enable_lm_head_weight_load=enable_lm_head_weight_load) + enable_lm_head_weight_load=enable_lm_head_weight_load, + num_spec_prefill_steps=num_spec_prefill_steps) def __init__( self, @@ -261,6 +269,7 @@ def __init__( disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, enable_lm_head_weight_load: Optional[bool] = False, + num_spec_prefill_steps: int = 1, ): """ Create a SpecDecodeWorker. @@ -293,6 +302,10 @@ def __init__( draft model is larger than 1 (TODO: #5814) enable_lm_head_weight_load: whether to load lm_head weight for draft models like eagle. + num_spec_prefill_steps: number of speculative prefill steps to run + before the speculative decoding starts. This is only used when + the draft model is a deepseek_mtp model that requires prefill + kv cache separately for each MTP layer. """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker @@ -326,6 +339,7 @@ def __init__( self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs self._disable_log_stats = disable_log_stats + self._num_spec_prefill_steps = num_spec_prefill_steps def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -685,8 +699,9 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, execute_model_req.previous_hidden_states = \ prepare_prefill_hidden_states( sampler_output.prefill_hidden_states) - - self.proposer_worker.execute_model(execute_model_req) + for i in range(self._num_spec_prefill_steps): + execute_model_req.spec_step_idx = i + self.proposer_worker.execute_model(execute_model_req) sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( execute_model_req=execute_model_req, sampler_output=sampler_output) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 78cc352b1630..67d175c373d8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -99,6 +99,7 @@ class ModelInputForGPU(ModelRunnerInputBase): virtual_engine: int = 0 async_callback: Optional[Callable] = None scheduler_outputs: Optional[SchedulerOutputs] = None + previous_hidden_states: Optional[torch.Tensor] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -1649,6 +1650,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") @@ -1706,6 +1708,10 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_inner_state else {} + previous_hidden_states = kwargs.get("previous_hidden_states") + model_kwargs = {} + if previous_hidden_states is not None: + model_kwargs["previous_hidden_states"] = previous_hidden_states if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_start = torch.cuda.Event(enable_timing=True) @@ -1723,7 +1729,9 @@ def execute_model( intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), - **seqlen_agnostic_kwargs) + **seqlen_agnostic_kwargs, + **model_kwargs, + ) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): @@ -1815,7 +1823,7 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: 1. current vLLM instance is KV cache consumer/decode vLLM instance 2. this batch is not a profiling run 3. this batch is a prefill run - + Args: model_input: input to the model executable kv_caches: vLLM's paged memory @@ -1840,7 +1848,7 @@ def need_send_kv(self, model_input, kv_caches) -> bool: 1. current vLLM instance is KV cache producer/prefill vLLM instance 2. this batch is not a profiling run 3. this batch is a prefill run - + Args: model_input: input to the model executable kv_caches: vLLM's paged memory @@ -1976,7 +1984,11 @@ def forward( # Copy the input tensors to the input buffers. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) if positions is not None: - self.input_buffers["positions"].copy_(positions, non_blocking=True) + # in some case like MLA, it will reuse positions in metadata + # but truncate them to the original size + # so the shape is not padded, we need to copy partial only + self.input_buffers["positions"][:positions.shape[0]].copy_( + positions, non_blocking=True) if self.backend_name != "NO_ATTENTION": self.input_buffers["slot_mapping"].copy_( diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 38d2b712eff5..bae37cb7155f 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -46,7 +46,10 @@ def _init_attn_metadata_from_tensor_dict( valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): if field.name in tensor_dict: - valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) + if field.name == "input_positions": + valid_attn_kwargs[field.name] = tensor_dict[field.name] + else: + valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) tensor_dict["attn_metadata"] = attn_metadata diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 582aa460eb4f..ff38e3bfc207 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -68,10 +68,10 @@ def __init__( speculative_config = self.speculative_config model_config = self.model_config speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.model == - model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type == + model_config.hf_config.model_type) \ or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator", "eagle"]) \ + not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 83fcf0865ae1..190429074d56 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -397,6 +397,8 @@ def execute_model( model_input, worker_input, kwargs = inputs num_steps = worker_input.num_steps + if (execute_model_req is not None and execute_model_req.spec_step_idx): + kwargs["spec_step_idx"] = execute_model_req.spec_step_idx self.execute_worker(worker_input) From 9288c5868f9b57906c3d285762e731f91deb27bf Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 19 Feb 2025 01:09:22 -0800 Subject: [PATCH 109/317] [V1][Core] Generic mechanism for handling engine utility (#13060) Signed-off-by: Nick Hill --- tests/lora/test_add_lora.py | 2 +- tests/v1/engine/test_engine_core_client.py | 57 ++++++++-- vllm/v1/engine/__init__.py | 24 +++- vllm/v1/engine/core.py | 49 ++++++--- vllm/v1/engine/core_client.py | 121 ++++++++++++++++----- 5 files changed, 197 insertions(+), 56 deletions(-) diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index df8031cba687..2b421bfd9eb8 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -41,7 +41,7 @@ def download_and_prepare_lora_module(): ] for tokenizer_file in tokenizer_files: del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file - del_path.unlink() + del_path.unlink(missing_ok=True) @pytest.fixture(autouse=True) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 45080be8e8ce..828d7eed309f 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -3,7 +3,8 @@ import asyncio import time import uuid -from typing import Dict, List +from contextlib import ExitStack +from typing import Dict, List, Optional import pytest from transformers import AutoTokenizer @@ -14,7 +15,9 @@ from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.core import EngineCore +from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, + SyncMPClient) from vllm.v1.executor.abstract import Executor if not current_platform.is_cuda(): @@ -63,7 +66,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict): async def loop_until_done_async(client: EngineCoreClient, outputs: Dict): while True: - engine_core_outputs = await client.get_output_async().outputs + engine_core_outputs = (await client.get_output_async()).outputs if len(engine_core_outputs) == 0: break @@ -78,6 +81,14 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict): break +# Dummy utility function to monkey-patch into engine core. +def echo(self, msg: str, err_msg: Optional[str] = None) -> str: + print(f"echo util function called: {msg}, {err_msg}") + if err_msg is not None: + raise ValueError(err_msg) + return msg + + @fork_new_process_for_each_test @pytest.mark.parametrize("multiprocessing_mode", [True, False]) def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): @@ -85,7 +96,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3) + # Monkey-patch core engine utility function to test. + m.setattr(EngineCore, "echo", echo, raising=False) + + engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) vllm_config = engine_args.create_engine_config( UsageContext.UNKNOWN_CONTEXT) executor_class = Executor.get_class(vllm_config) @@ -147,15 +161,30 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): client.abort_requests([request.request_id]) + if multiprocessing_mode: + """Utility method invocation""" -@fork_new_process_for_each_test -@pytest.mark.asyncio + core_client: SyncMPClient = client + + result = core_client._call_utility("echo", "testarg") + assert result == "testarg" + + with pytest.raises(Exception) as e_info: + core_client._call_utility("echo", None, "help!") + + assert str(e_info.value) == "Call to echo method failed: help!" + + +@pytest.mark.asyncio(loop_scope="function") async def test_engine_core_client_asyncio(monkeypatch): - with monkeypatch.context() as m: + with monkeypatch.context() as m, ExitStack() as after: m.setenv("VLLM_USE_V1", "1") - engine_args = EngineArgs(model=MODEL_NAME) + # Monkey-patch core engine utility function to test. + m.setattr(EngineCore, "echo", echo, raising=False) + + engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) vllm_config = engine_args.create_engine_config( usage_context=UsageContext.UNKNOWN_CONTEXT) executor_class = Executor.get_class(vllm_config) @@ -166,6 +195,7 @@ async def test_engine_core_client_asyncio(monkeypatch): executor_class=executor_class, log_stats=True, ) + after.callback(client.shutdown) MAX_TOKENS = 20 params = SamplingParams(max_tokens=MAX_TOKENS) @@ -204,3 +234,14 @@ async def test_engine_core_client_asyncio(monkeypatch): else: assert len(outputs[req_id]) == MAX_TOKENS, ( f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + """Utility method invocation""" + + core_client: AsyncMPClient = client + + result = await core_client._call_utility_async("echo", "testarg") + assert result == "testarg" + + with pytest.raises(Exception) as e_info: + await core_client._call_utility_async("echo", None, "help!") + + assert str(e_info.value) == "Call to echo method failed: help!" diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index dee7102bb47b..7420dde1f7e4 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -2,7 +2,7 @@ import enum import time -from typing import List, Optional, Union +from typing import Any, List, Optional, Union import msgspec @@ -106,6 +106,18 @@ def finished(self) -> bool: return self.finish_reason is not None +class UtilityOutput( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] + + call_id: int + + # Non-None implies the call failed, result should be None. + failure_message: Optional[str] = None + result: Any = None + + class EngineCoreOutputs( msgspec.Struct, array_like=True, # type: ignore[call-arg] @@ -116,10 +128,12 @@ class EngineCoreOutputs( # e.g. columnwise layout # [num_reqs] - outputs: List[EngineCoreOutput] - scheduler_stats: Optional[SchedulerStats] + outputs: List[EngineCoreOutput] = [] + scheduler_stats: Optional[SchedulerStats] = None timestamp: float = 0.0 + utility_output: Optional[UtilityOutput] = None + def __post_init__(self): if self.timestamp == 0.0: self.timestamp = time.monotonic() @@ -132,6 +146,4 @@ class EngineCoreRequestType(enum.Enum): """ ADD = b'\x00' ABORT = b'\x01' - PROFILE = b'\x02' - RESET_PREFIX_CACHE = b'\x03' - ADD_LORA = b'\x04' + UTILITY = b'\x02' diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 6718a5f7b02d..66e252b7ccb0 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -5,9 +5,11 @@ import threading import time from concurrent.futures import Future +from inspect import isclass, signature from multiprocessing.connection import Connection from typing import Any, List, Optional, Tuple, Type +import msgspec import psutil import zmq import zmq.asyncio @@ -21,7 +23,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_configs from vllm.v1.core.scheduler import Scheduler, SchedulerOutput from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType) + EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MMInputCacheServer from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput @@ -330,19 +332,39 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, self.add_request(request) elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) - elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE: - self.reset_prefix_cache() - elif request_type == EngineCoreRequestType.PROFILE: - self.model_executor.profile(request) - elif request_type == EngineCoreRequestType.ADD_LORA: - self.model_executor.add_lora(request) + elif request_type == EngineCoreRequestType.UTILITY: + call_id, method_name, args = request + output = UtilityOutput(call_id) + try: + method = getattr(self, method_name) + output.result = method( + *self._convert_msgspec_args(method, args)) + except BaseException as e: + logger.exception("Invocation of %s method failed", method_name) + output.failure_message = (f"Call to {method_name} method" + f" failed: {str(e)}") + self.output_queue.put_nowait( + EngineCoreOutputs(utility_output=output)) + + @staticmethod + def _convert_msgspec_args(method, args): + """If a provided arg type doesn't match corresponding target method + arg type, try converting to msgspec object.""" + if not args: + return args + arg_types = signature(method).parameters.values() + assert len(args) <= len(arg_types) + return tuple( + msgspec.convert(v, type=p.annotation) if isclass(p.annotation) + and issubclass(p.annotation, msgspec.Struct) + and not isinstance(v, p.annotation) else v + for v, p in zip(args, arg_types)) def process_input_socket(self, input_path: str): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) - add_lora_decoder = MsgpackDecoder(LoRARequest) generic_decoder = MsgpackDecoder() with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket: @@ -352,14 +374,9 @@ def process_input_socket(self, input_path: str): request_type = EngineCoreRequestType(bytes(type_frame.buffer)) # Deserialize the request data. - decoder = None - if request_type == EngineCoreRequestType.ADD: - decoder = add_request_decoder - elif request_type == EngineCoreRequestType.ADD_LORA: - decoder = add_lora_decoder - else: - decoder = generic_decoder - + decoder = add_request_decoder if ( + request_type + == EngineCoreRequestType.ADD) else generic_decoder request = decoder.decode(data_frame.buffer) # Push to input queue for core busy loop. diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 07176629e949..8641833e438b 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -2,10 +2,14 @@ import asyncio import os +import queue import signal +import uuid import weakref from abc import ABC, abstractmethod -from typing import Any, List, Optional, Type +from concurrent.futures import Future +from threading import Thread +from typing import Any, Dict, List, Optional, Type, Union import zmq import zmq.asyncio @@ -16,7 +20,7 @@ from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree, make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType) + EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -24,6 +28,8 @@ logger = init_logger(__name__) +AnyFuture = Union[asyncio.Future[Any], Future[Any]] + class EngineCoreClient(ABC): """ @@ -204,6 +210,8 @@ def sigusr1_handler(signum, frame): "log_stats": log_stats, }) + self.utility_results: Dict[int, AnyFuture] = {} + def shutdown(self): """Clean up background resources.""" if hasattr(self, "proc_handle"): @@ -212,6 +220,16 @@ def shutdown(self): self._finalizer() +def _process_utility_output(output: UtilityOutput, + utility_results: Dict[int, AnyFuture]): + """Set the result from a utility method in the waiting future""" + future = utility_results.pop(output.call_id) + if output.failure_message is not None: + future.set_exception(Exception(output.failure_message)) + else: + future.set_result(output.result) + + class SyncMPClient(MPClient): """Synchronous client for multi-proc EngineCore.""" @@ -224,10 +242,30 @@ def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor], log_stats=log_stats, ) - def get_output(self) -> EngineCoreOutputs: + self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() - (frame, ) = self.output_socket.recv_multipart(copy=False) - return self.decoder.decode(frame.buffer) + # Ensure that the outputs socket processing thread does not have + # a ref to the client which prevents gc. + output_socket = self.output_socket + decoder = self.decoder + utility_results = self.utility_results + outputs_queue = self.outputs_queue + + def process_outputs_socket(): + while True: + (frame, ) = output_socket.recv_multipart(copy=False) + outputs = decoder.decode(frame.buffer) + if outputs.utility_output: + _process_utility_output(outputs.utility_output, + utility_results) + else: + outputs_queue.put_nowait(outputs) + + # Process outputs from engine in separate thread. + Thread(target=process_outputs_socket, daemon=True).start() + + def get_output(self) -> EngineCoreOutputs: + return self.outputs_queue.get() def _send_input(self, request_type: EngineCoreRequestType, request: Any) -> None: @@ -236,6 +274,16 @@ def _send_input(self, request_type: EngineCoreRequestType, msg = (request_type.value, self.encoder.encode(request)) self.input_socket.send_multipart(msg, copy=False) + def _call_utility(self, method: str, *args) -> Any: + call_id = uuid.uuid1().int >> 64 + future: Future[Any] = Future() + self.utility_results[call_id] = future + + self._send_input(EngineCoreRequestType.UTILITY, + (call_id, method, args)) + + return future.result() + def add_request(self, request: EngineCoreRequest) -> None: # NOTE: text prompt is not needed in the core engine as it has been # tokenized. @@ -247,13 +295,13 @@ def abort_requests(self, request_ids: List[str]) -> None: self._send_input(EngineCoreRequestType.ABORT, request_ids) def profile(self, is_start: bool = True) -> None: - self._send_input(EngineCoreRequestType.PROFILE, is_start) + self._call_utility("profile", is_start) def reset_prefix_cache(self) -> None: - self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None) + self._call_utility("reset_prefix_cache") def add_lora(self, lora_request: LoRARequest) -> None: - self._send_input(EngineCoreRequestType.ADD_LORA, lora_request) + self._call_utility("add_lora", lora_request) class AsyncMPClient(MPClient): @@ -268,24 +316,35 @@ def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor], log_stats=log_stats, ) - self.outputs_queue: Optional[asyncio.Queue[bytes]] = None + self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None self.queue_task: Optional[asyncio.Task] = None + async def _start_output_queue_task(self): + # Perform IO in separate task to parallelize as much as possible. + # Avoid task having direct reference back to the client. + self.outputs_queue = asyncio.Queue() + output_socket = self.output_socket + decoder = self.decoder + utility_results = self.utility_results + outputs_queue = self.outputs_queue + + async def process_outputs_socket(): + while True: + (frame, ) = await output_socket.recv_multipart(copy=False) + outputs: EngineCoreOutputs = decoder.decode(frame.buffer) + if outputs.utility_output: + _process_utility_output(outputs.utility_output, + utility_results) + else: + outputs_queue.put_nowait(outputs) + + self.queue_task = asyncio.create_task(process_outputs_socket()) + async def get_output_async(self) -> EngineCoreOutputs: if self.outputs_queue is None: - # Perform IO in separate task to parallelize as much as possible - self.outputs_queue = asyncio.Queue() - - async def process_outputs_socket(): - assert self.outputs_queue is not None - while True: - (frame, ) = await self.output_socket.recv_multipart( - copy=False) - self.outputs_queue.put_nowait(frame.buffer) - - self.queue_task = asyncio.create_task(process_outputs_socket()) - - return self.decoder.decode(await self.outputs_queue.get()) + await self._start_output_queue_task() + assert self.outputs_queue is not None + return await self.outputs_queue.get() async def _send_input(self, request_type: EngineCoreRequestType, request: Any) -> None: @@ -293,6 +352,18 @@ async def _send_input(self, request_type: EngineCoreRequestType, msg = (request_type.value, self.encoder.encode(request)) await self.input_socket.send_multipart(msg, copy=False) + if self.outputs_queue is None: + await self._start_output_queue_task() + + async def _call_utility_async(self, method: str, *args) -> Any: + call_id = uuid.uuid1().int >> 64 + future = asyncio.get_running_loop().create_future() + self.utility_results[call_id] = future + await self._send_input(EngineCoreRequestType.UTILITY, + (call_id, method, args)) + + return await future + async def add_request_async(self, request: EngineCoreRequest) -> None: # NOTE: text prompt is not needed in the core engine as it has been # tokenized. @@ -304,10 +375,10 @@ async def abort_requests_async(self, request_ids: List[str]) -> None: await self._send_input(EngineCoreRequestType.ABORT, request_ids) async def profile_async(self, is_start: bool = True) -> None: - await self._send_input(EngineCoreRequestType.PROFILE, is_start) + await self._call_utility_async("profile", is_start) async def reset_prefix_cache_async(self) -> None: - await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None) + await self._call_utility_async("reset_prefix_cache") async def add_lora_async(self, lora_request: LoRARequest) -> None: - await self._send_input(EngineCoreRequestType.ADD_LORA, lora_request) + await self._call_utility_async("add_lora", lora_request) From b5c9857b711c8528435ffe759bc98f9cfe83775b Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Wed, 19 Feb 2025 10:16:38 +0100 Subject: [PATCH 110/317] [Feature] Pluggable platform-specific scheduler (#13161) Signed-off-by: Yannick Schnider Signed-off-by: Yannick Schnider --- .buildkite/test-pipeline.yaml | 1 + tests/plugins_tests/test_scheduler_plugins.py | 33 +++++++++++++++++++ vllm/config.py | 4 +++ vllm/engine/arg_utils.py | 10 ++++++ vllm/engine/llm_engine.py | 11 +++++-- 5 files changed, 56 insertions(+), 3 deletions(-) create mode 100644 tests/plugins_tests/test_scheduler_plugins.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 3918e3e86769..9d05ff4c2cfd 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -531,6 +531,7 @@ steps: - pip uninstall vllm_add_dummy_platform -y # end platform plugin tests # other tests continue here: + - pytest -v -s plugins_tests/test_scheduler_plugins.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process diff --git a/tests/plugins_tests/test_scheduler_plugins.py b/tests/plugins_tests/test_scheduler_plugins.py new file mode 100644 index 000000000000..84688cee9660 --- /dev/null +++ b/tests/plugins_tests/test_scheduler_plugins.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.core.scheduler import Scheduler + + +class DummyScheduler(Scheduler): + + def schedule(self): + raise Exception("Exception raised by DummyScheduler") + + +def test_scheduler_plugins(): + import pytest + + from vllm.engine.arg_utils import EngineArgs + from vllm.engine.llm_engine import LLMEngine + from vllm.sampling_params import SamplingParams + + with pytest.raises(Exception) as exception_info: + + engine_args = EngineArgs( + model="facebook/opt-125m", + enforce_eager=True, # reduce test time + scheduler_cls=DummyScheduler, + ) + + engine = LLMEngine.from_engine_args(engine_args=engine_args) + + sampling_params = SamplingParams(max_tokens=1) + engine.add_request("0", "foo", sampling_params) + engine.step() + + assert str(exception_info.value) == "Exception raised by DummyScheduler" diff --git a/vllm/config.py b/vllm/config.py index 59fa60fd8b0c..56315aacbe51 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1495,6 +1495,10 @@ class SchedulerConfig: chunked_prefill_enabled: bool = field(init=False) + # scheduler class or path. "vllm.core.scheduler.Scheduler" (default) + # or "mod.custom_class". + scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5f076f05d046..78681008b62e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -192,6 +192,7 @@ class EngineArgs: collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False scheduling_policy: Literal["fcfs", "priority"] = "fcfs" + scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler" override_neuron_config: Optional[Dict[str, Any]] = None override_pooler_config: Optional[PoolerConfig] = None @@ -938,6 +939,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'priority (lower value means earlier handling) and time of ' 'arrival deciding any ties).') + parser.add_argument( + '--scheduler-cls', + default=EngineArgs.scheduler_cls, + help='The scheduler class to use. "vllm.core.scheduler.Scheduler" ' + 'is the default scheduler. Can be a class directly or the path to ' + 'a class of form "mod.custom_class".') + parser.add_argument( '--override-neuron-config', type=json.loads, @@ -1273,10 +1281,12 @@ def create_engine_config(self, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, + scheduler_cls=self.scheduler_cls, max_num_partial_prefills=self.max_num_partial_prefills, max_long_partial_prefills=self.max_long_partial_prefills, long_prefill_token_threshold=self.long_prefill_token_threshold, ) + lora_config = LoRAConfig( bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2e5bc75c6db3..3ce9a0461368 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -19,8 +19,7 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, SchedulerConfig, VllmConfig) -from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, - SchedulerOutputs) +from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase, Stats from vllm.engine.output_processor.interfaces import ( @@ -58,7 +57,8 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind +from vllm.utils import (Counter, Device, deprecate_kwargs, + resolve_obj_by_qualname, weak_bind) from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -346,6 +346,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # Create the scheduler. # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. + if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str): + Scheduler = resolve_obj_by_qualname( + self.vllm_config.scheduler_config.scheduler_cls) + else: + Scheduler = self.vllm_config.scheduler_config.scheduler_cls self.scheduler = [ Scheduler( self.scheduler_config, self.cache_config, self.lora_config, From 7862a7c559aca20ecc8eb8fdb15859799c3b5bce Mon Sep 17 00:00:00 2001 From: Daniele <36171005+dtrifiro@users.noreply.github.com> Date: Wed, 19 Feb 2025 11:48:03 +0100 Subject: [PATCH 111/317] [CI/Build] force writing version file (#13544) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniele Trifirò --- pyproject.toml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ac155116ccde..1c03e9e17be5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ Slack="http://slack.vllm.ai/" vllm = "vllm.entrypoints.cli.main:main" [tool.setuptools_scm] -version_file = "vllm/_version.py" +# no extra settings needed, presence enables setuptools-scm [tool.setuptools.packages.find] where = ["."] diff --git a/setup.py b/setup.py index d09ae4b3810d..d8a336c2d426 100755 --- a/setup.py +++ b/setup.py @@ -499,7 +499,7 @@ def get_gaudi_sw_version(): def get_vllm_version() -> str: - version = get_version() + version = get_version(write_to="vllm/_version.py") sep = "+" if "+" not in version else "." # dev versions might contain + if _no_device(): From 81c6d1482784b5bf8fcbd3628a54525098e3fe00 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 19 Feb 2025 20:55:58 +0800 Subject: [PATCH 112/317] [doc] clarify profiling is only for developers (#13554) Signed-off-by: youkaichao --- docs/source/contributing/profiling/profiling_index.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/contributing/profiling/profiling_index.md b/docs/source/contributing/profiling/profiling_index.md index 79aeb292a9b7..3d044f890382 100644 --- a/docs/source/contributing/profiling/profiling_index.md +++ b/docs/source/contributing/profiling/profiling_index.md @@ -1,15 +1,15 @@ # Profiling vLLM +:::{warning} +Profiling is only intended for vLLM developers and maintainers to understand the proportion of time spent in different parts of the codebase. **vLLM end-users should never turn on profiling** as it will significantly slow down the inference. +::: + We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/` The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set. When using `benchmarks/benchmark_serving.py`, you can enable profiling by passing the `--profile` flag. -:::{warning} -Only enable profiling in a development environment. -::: - Traces can be visualized using . :::{tip} From b6e30576deb9b405181dfef61d99f21f87f46cef Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 19 Feb 2025 21:13:50 +0800 Subject: [PATCH 113/317] [VLM][Bugfix] Pass processor kwargs properly on init (#13516) Signed-off-by: DarkLight1337 --- .../vision_language_multi_image.py | 1 + .../multimodal/processing/test_common.py | 7 +- .../multimodal/processing/test_h2ovl.py | 225 ++++++++++-------- .../multimodal/processing/test_idefics3.py | 24 +- .../multimodal/processing/test_internvl.py | 142 ++++++++--- .../multimodal/processing/test_llava_next.py | 17 +- .../processing/test_llava_onevision.py | 17 +- .../multimodal/processing/test_phi3v.py | 13 +- .../multimodal/processing/test_qwen2_vl.py | 16 +- tests/models/utils.py | 18 +- tests/multimodal/test_processing.py | 10 +- vllm/inputs/registry.py | 77 +++--- vllm/model_executor/models/aria.py | 4 +- vllm/model_executor/models/chameleon.py | 4 +- vllm/model_executor/models/deepseek_vl2.py | 15 +- vllm/model_executor/models/fuyu.py | 4 +- vllm/model_executor/models/glm4v.py | 15 +- vllm/model_executor/models/gritlm.py | 9 +- vllm/model_executor/models/h2ovl.py | 41 +++- vllm/model_executor/models/idefics3.py | 12 +- vllm/model_executor/models/internvl.py | 45 +++- vllm/model_executor/models/llava.py | 27 ++- vllm/model_executor/models/llava_next.py | 4 +- .../model_executor/models/llava_next_video.py | 4 +- vllm/model_executor/models/llava_onevision.py | 4 +- vllm/model_executor/models/minicpmv.py | 7 +- vllm/model_executor/models/mllama.py | 4 +- vllm/model_executor/models/molmo.py | 4 +- vllm/model_executor/models/nvlm_d.py | 19 +- vllm/model_executor/models/paligemma.py | 4 +- vllm/model_executor/models/phi3v.py | 5 +- vllm/model_executor/models/pixtral.py | 20 +- vllm/model_executor/models/qwen2_5_vl.py | 47 +--- vllm/model_executor/models/qwen2_audio.py | 3 +- vllm/model_executor/models/qwen2_vl.py | 94 +++++--- vllm/model_executor/models/qwen_vl.py | 9 +- vllm/model_executor/models/ultravox.py | 3 +- vllm/model_executor/models/whisper.py | 8 +- vllm/multimodal/image.py | 5 +- vllm/multimodal/registry.py | 14 +- vllm/multimodal/utils.py | 5 +- vllm/multimodal/video.py | 8 +- vllm/transformers_utils/processor.py | 92 ++++++- vllm/transformers_utils/tokenizer.py | 22 +- 44 files changed, 675 insertions(+), 453 deletions(-) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index b2821966cf12..5dc6a936d1c1 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -85,6 +85,7 @@ def load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData: trust_remote_code=True, max_model_len=8192, limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"max_dynamic_patch": 4}, ) placeholders = "\n".join(f"Image-{i}: \n" diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 88dcc32f44f5..331ffe82ec85 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -10,7 +10,7 @@ from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.processing import ProcessingCache -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from ....multimodal.utils import random_audio, random_image, random_video from ...registry import HF_EXAMPLE_MODELS @@ -42,10 +42,7 @@ def _test_processing_correctness( factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] ctx = InputProcessingContext( model_config, - tokenizer=cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_info.trust_remote_code, - ), + tokenizer=cached_tokenizer_from_config(model_config), ) # Ensure that it can fit all of the data cache = ProcessingCache(capacity=1 << 30) diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 767ac5eb9ef9..5c43e4eed787 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -1,17 +1,118 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for H2OVL's multimodal preprocessing kwargs.""" -from typing import Optional +from typing import Mapping, Optional import pytest +from PIL import Image +from transformers import PretrainedConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import rescale_image_size -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from ....conftest import _ImageAssets from ...utils import build_model_context +def _get_expected_num_patches( + config: PretrainedConfig, + image: Image.Image, + num_imgs: int, + min_num: int, + max_num: int, +): + from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets, + get_h2ovl_target_ratios) + + width, height = image.size + + # Calculate the expected number of blocks + if num_imgs == 1 and config.use_msac: + # First pass + blocks1, _, _, aspect_ratio = calculate_h2ovl_targets( + orig_width=width, + orig_height=height, + target_ratios=get_h2ovl_target_ratios( + min_num=1, + max_num=max_num, + prior_aspect_ratio=None, + ), + image_size=config.vision_config.image_size, + use_thumbnail=False, # Thumbnail is handled separately + ) + + # Second pass + blocks2, _, _, _ = calculate_h2ovl_targets( + orig_width=width, + orig_height=height, + target_ratios=get_h2ovl_target_ratios( + min_num=3, + max_num=max_num, + prior_aspect_ratio=aspect_ratio, + ), + image_size=config.vision_config.image_size, + use_thumbnail=False, + ) + + # Add thumbnail if use_thumbnail is True and total_blocks > 1 + if config.use_thumbnail: + blocks1 += 1 if blocks1 > 1 else 0 + blocks2 += 1 if blocks2 > 1 else 0 + + # Total blocks is the sum of blocks from both passes minus + # overlapping + total_blocks = blocks1 + blocks2 - 1 + + return total_blocks + + blocks, _, _, _ = calculate_h2ovl_targets( + orig_width=width, + orig_height=height, + target_ratios=get_h2ovl_target_ratios( + min_num, + max_num, + prior_aspect_ratio=None, + ), + image_size=config.vision_config.image_size, + use_thumbnail=False, + ) + expected_num_patches = blocks + + if config.use_thumbnail and expected_num_patches > 1: + expected_num_patches += 1 + + return expected_num_patches + + +def _run_check( + processor: BaseMultiModalProcessor, + images: list[Image.Image], + min_num: int, + max_num: int, + mm_processor_kwargs: Mapping[str, object], +): + tokenizer = processor.info.get_tokenizer() + config = processor.info.get_hf_config() + + mm_data = {"image": images} + + total_expected_num_patches = sum( + _get_expected_num_patches(config, image, len(images), min_num, max_num) + for image in images) + + processed_inputs = processor.apply("" * len(images), mm_data, + mm_processor_kwargs) + + # Ensure we have the right number of placeholders per num_crops size + image_token_id = tokenizer.convert_tokens_to_ids("") + img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) + pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + + assert img_tok_count == 256 * total_expected_num_patches + assert pixel_shape[0] == total_expected_num_patches + + @pytest.mark.parametrize("model_id", [ "h2oai/h2ovl-mississippi-800m", "h2oai/h2ovl-mississippi-2b", @@ -25,118 +126,54 @@ [1.0, 1.0, 1.0], # Multi-scale [0.25, 0.5, 1.0], + [4.0, 2.0, 1.0], ], ) -@pytest.mark.parametrize("max_dynamic_patch", [1, 2, 4, 8]) +@pytest.mark.parametrize( + ("min_dynamic_patch", "max_dynamic_patch"), + [(1, 1), (1, 2), (1, 4), (1, 8), (2, 4), (4, 8)], +) @pytest.mark.parametrize("dynamic_image_size", [True, False]) -@pytest.mark.parametrize("num_imgs", [1, 2]) +@pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( model_id: str, image_assets: _ImageAssets, size_factors: list[int], + min_dynamic_patch: int, max_dynamic_patch: int, dynamic_image_size: Optional[bool], - num_imgs: int, + kwargs_on_init: bool, ): - from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets, - get_h2ovl_target_ratios) + mm_processor_kwargs = { + "min_dynamic_patch": min_dynamic_patch, + "max_dynamic_patch": max_dynamic_patch, + "dynamic_image_size": dynamic_image_size, + } ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, trust_remote_code=True, - mm_processor_kwargs=None, - limit_mm_per_prompt={"image": num_imgs}, - ) - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code, + mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None, + limit_mm_per_prompt={"image": len(size_factors)}, ) + tokenizer = cached_tokenizer_from_config(ctx.model_config) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, tokenizer=tokenizer, ) + hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs - config = processor.info.get_hf_config() - use_msac = config.use_msac - - mm_processor_kwargs = { - "max_dynamic_patch": max_dynamic_patch, - } - if dynamic_image_size is not None: - mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size - - min_num = config.min_dynamic_patch + min_num = min_dynamic_patch if dynamic_image_size else 1 max_num = max_dynamic_patch if dynamic_image_size else 1 - # Build the image str / prompt based on the number of images we pass - prompt = "" * num_imgs - - for asset in image_assets: - for factor in size_factors: - image = rescale_image_size(asset.pil_image, factor) - mm_data = {"image": [image] * num_imgs} - - width, height = image.size - - # Calculate the expected number of blocks - if num_imgs == 1 and use_msac: - # First pass - blocks1, _, _, aspect_ratio = calculate_h2ovl_targets( - orig_width=width, - orig_height=height, - target_ratios=get_h2ovl_target_ratios( - min_num, - max_num, - prior_aspect_ratio=None, - ), - image_size=config.vision_config.image_size, - use_thumbnail=False, # Thumbnail is handled separately - ) - - # Second pass - blocks2, _, _, _ = calculate_h2ovl_targets( - orig_width=width, - orig_height=height, - target_ratios=get_h2ovl_target_ratios( - min_num, - max_num, - prior_aspect_ratio=aspect_ratio, - ), - image_size=config.vision_config.image_size, - use_thumbnail=False, - ) - - # Add thumbnail if use_thumbnail is True and total_blocks > 1 - if config.use_thumbnail: - blocks1 += 1 if blocks1 > 1 else 0 - blocks2 += 1 if blocks2 > 1 else 0 - - # Total blocks is the sum of blocks from both passes minus - # overlapping - total_blocks = blocks1 + blocks2 - 1 - - expected_num_patches = total_blocks - else: - blocks, _, _, _ = calculate_h2ovl_targets( - orig_width=width, - orig_height=height, - target_ratios=get_h2ovl_target_ratios( - min_num, - max_num, - prior_aspect_ratio=None, - ), - image_size=config.vision_config.image_size, - use_thumbnail=False, - ) - expected_num_patches = blocks - - if config.use_thumbnail and expected_num_patches != 1: - expected_num_patches += 1 - - processed_inputs = processor.apply(prompt, mm_data, - mm_processor_kwargs) - pixel_shape = ( - processed_inputs["mm_kwargs"]["pixel_values_flat"].shape) - - assert pixel_shape[0] == expected_num_patches * num_imgs + _run_check( + processor, + [ + rescale_image_size(image_assets[0].pil_image, f) + for f in size_factors + ], + min_num, + max_num, + hf_processor_mm_kwargs, + ) diff --git a/tests/models/multimodal/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py index 07ab1bbd4b5e..0a0f1cb38938 100644 --- a/tests/models/multimodal/processing/test_idefics3.py +++ b/tests/models/multimodal/processing/test_idefics3.py @@ -4,7 +4,7 @@ from transformers import Idefics3Config from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from ....conftest import _ImageAssets from ...utils import build_model_context @@ -22,9 +22,15 @@ ]) # yapf: enable @pytest.mark.parametrize("num_imgs", [1, 2]) -def test_processor_override(image_assets: _ImageAssets, model: str, - mm_processor_kwargs: dict[str, object], - expected_toks_per_img: int, num_imgs: int): +@pytest.mark.parametrize("kwargs_on_init", [True, False]) +def test_processor_override( + image_assets: _ImageAssets, + model: str, + mm_processor_kwargs: dict[str, object], + expected_toks_per_img: int, + num_imgs: int, + kwargs_on_init: bool, +): """Ensure input_processor_for_idefics3 handles num_crops properly.""" # Same as the previous test - don't initialize mm_processor_kwargs # in this test and assume that the kwargs will be correctly expanded by @@ -33,15 +39,15 @@ def test_processor_override(image_assets: _ImageAssets, model: str, model_name=model, tokenizer_name=model, trust_remote_code=True, - mm_processor_kwargs=None, + mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + tokenizer = cached_tokenizer_from_config(ctx.model_config) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, tokenizer=tokenizer, ) - hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs) + hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs # Build the image str / prompt based on the number of images we pass placeholders = "" if num_imgs == 1 else "\n".join( @@ -54,8 +60,10 @@ def test_processor_override(image_assets: _ImageAssets, model: str, dummy_image = image_assets[0].pil_image.resize(dummy_image_size) mm_data = {"image": [dummy_image] * num_imgs} - processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + # Ensure the placeholders format are correct + hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"]) assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[ "input_ids"][0] diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index ede961225be7..cc777fdf57b3 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -1,64 +1,136 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for InternVL's multimodal preprocessing kwargs.""" -from typing import Optional +from typing import Mapping, Optional import pytest +from PIL import Image +from transformers import PretrainedConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.multimodal.image import rescale_image_size +from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from ....conftest import _ImageAssets from ...utils import build_model_context +def _get_expected_num_patches( + config: PretrainedConfig, + image: Image.Image, + num_imgs: int, + min_num: int, + max_num: int, +): + from vllm.model_executor.models.internvl import ( + calculate_internvl_targets, get_internvl_target_ratios) + + width, height = image.size + + blocks, _, _ = calculate_internvl_targets( + orig_width=width, + orig_height=height, + target_ratios=get_internvl_target_ratios( + min_num, + max_num, + ), + image_size=config.vision_config.image_size, + use_thumbnail=False, + ) + expected_num_patches = blocks + + if config.use_thumbnail and expected_num_patches > 1: + expected_num_patches += 1 + + return expected_num_patches + + +def _run_check( + processor: BaseMultiModalProcessor, + images: list[Image.Image], + min_num: int, + max_num: int, + mm_processor_kwargs: Mapping[str, object], +): + tokenizer = processor.info.get_tokenizer() + config = processor.info.get_hf_config() + + mm_data = {"image": images} + + total_expected_num_patches = sum( + _get_expected_num_patches(config, image, len(images), min_num, max_num) + for image in images) + + processed_inputs = processor.apply("" * len(images), mm_data, + mm_processor_kwargs) + + # Ensure we have the right number of placeholders per num_crops size + image_token_id = tokenizer.convert_tokens_to_ids("") + img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) + pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + + assert img_tok_count == 256 * total_expected_num_patches + assert pixel_shape[0] == total_expected_num_patches + + @pytest.mark.parametrize("model_id", ["OpenGVLab/InternVL2-2B"]) -@pytest.mark.parametrize("max_dynamic_patch", [1, 4]) -@pytest.mark.parametrize("dynamic_image_size", [True, False, None]) -@pytest.mark.parametrize("num_imgs", [1, 2]) +@pytest.mark.parametrize( + "size_factors", + [ + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + [4.0, 2.0, 1.0], + ], +) +@pytest.mark.parametrize( + ("min_dynamic_patch", "max_dynamic_patch"), + [(1, 1), (1, 2), (1, 4), (1, 8), (2, 4), (4, 8)], +) +@pytest.mark.parametrize("dynamic_image_size", [True, False]) +@pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( model_id: str, image_assets: _ImageAssets, + size_factors: list[int], + min_dynamic_patch: int, max_dynamic_patch: int, dynamic_image_size: Optional[bool], - num_imgs: int, + kwargs_on_init: bool, ): + mm_processor_kwargs = { + "min_dynamic_patch": min_dynamic_patch, + "max_dynamic_patch": max_dynamic_patch, + "dynamic_image_size": dynamic_image_size, + } + ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, trust_remote_code=True, - mm_processor_kwargs=None, - limit_mm_per_prompt={"image": num_imgs}, - ) - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code, + mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None, + limit_mm_per_prompt={"image": len(size_factors)}, ) + tokenizer = cached_tokenizer_from_config(ctx.model_config) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, tokenizer=tokenizer, ) + hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs - mm_processor_kwargs = { - "max_dynamic_patch": max_dynamic_patch, - } - if dynamic_image_size is not None: - mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size + min_num = min_dynamic_patch if dynamic_image_size else 1 + max_num = max_dynamic_patch if dynamic_image_size else 1 - # Build the image str / prompt based on the number of images we pass - prompt = "" * num_imgs - image = image_assets[0].pil_image.resize((448 * 2, 448 * 2)) - mm_data = {"image": [image] * num_imgs} - - expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1 - if dynamic_image_size is False: - expected_num_patches = 1 - - processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) - - # Ensure we have the right number of placeholders per num_crops size - image_token_id = tokenizer.convert_tokens_to_ids("") - img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape - - assert img_tok_count == 256 * expected_num_patches * num_imgs - assert pixel_shape[0] == expected_num_patches * num_imgs + _run_check( + processor, + [ + rescale_image_size(image_assets[0].pil_image, f) + for f in size_factors + ], + min_num, + max_num, + hf_processor_mm_kwargs, + ) diff --git a/tests/models/multimodal/processing/test_llava_next.py b/tests/models/multimodal/processing/test_llava_next.py index fe4754c2ef6f..dca25e5d4c4c 100644 --- a/tests/models/multimodal/processing/test_llava_next.py +++ b/tests/models/multimodal/processing/test_llava_next.py @@ -10,7 +10,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import ImageSize from vllm.multimodal.processing import BaseMultiModalProcessor -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from ...utils import build_model_context @@ -43,10 +43,7 @@ def test_processor_max_tokens(model_id): ) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, - tokenizer=cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code, - ), + tokenizer=cached_tokenizer_from_config(ctx.model_config), ) info = processor.info @@ -146,10 +143,7 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, - tokenizer=cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code, - ), + tokenizer=cached_tokenizer_from_config(ctx.model_config), ) image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), @@ -179,10 +173,7 @@ def test_processor_prompt_replacements_all(model_id, num_imgs): ) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, - tokenizer=cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code, - ), + tokenizer=cached_tokenizer_from_config(ctx.model_config), ) seen_aspect_ratios = set[float]() diff --git a/tests/models/multimodal/processing/test_llava_onevision.py b/tests/models/multimodal/processing/test_llava_onevision.py index fb650d9e0995..96abc840f052 100644 --- a/tests/models/multimodal/processing/test_llava_onevision.py +++ b/tests/models/multimodal/processing/test_llava_onevision.py @@ -10,7 +10,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import ImageSize from vllm.multimodal.processing import BaseMultiModalProcessor -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from ...utils import build_model_context @@ -44,10 +44,7 @@ def test_processor_max_tokens(model_id): ) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, - tokenizer=cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code, - ), + tokenizer=cached_tokenizer_from_config(ctx.model_config), ) info = processor.info @@ -146,10 +143,7 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, - tokenizer=cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code, - ), + tokenizer=cached_tokenizer_from_config(ctx.model_config), ) image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), @@ -180,10 +174,7 @@ def test_processor_prompt_replacements_all(model_id, num_imgs): ) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, - tokenizer=cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code, - ), + tokenizer=cached_tokenizer_from_config(ctx.model_config), ) seen_aspect_ratios = set[float]() diff --git a/tests/models/multimodal/processing/test_phi3v.py b/tests/models/multimodal/processing/test_phi3v.py index dde8904f2ef6..420644f70842 100644 --- a/tests/models/multimodal/processing/test_phi3v.py +++ b/tests/models/multimodal/processing/test_phi3v.py @@ -3,7 +3,7 @@ import pytest from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from ....conftest import _ImageAssets from ...utils import build_model_context @@ -21,12 +21,14 @@ ]) # yapf: enable @pytest.mark.parametrize("num_imgs", [1, 2]) +@pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( image_assets: _ImageAssets, model_id: str, mm_processor_kwargs: dict[str, int], expected_toks_per_img: int, num_imgs: int, + kwargs_on_init: bool, ): """Ensure input_processor_for_phi3v handles num_crops properly.""" # Avoid initializing CUDA early @@ -36,23 +38,22 @@ def test_processor_override( model_name=model_id, tokenizer_name=model_id, trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code, - ) + tokenizer = cached_tokenizer_from_config(ctx.model_config) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, tokenizer=tokenizer, ) + hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs # Build the image str / prompt based on the number of images we pass img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" mm_data = {"image": [image_assets[0].pil_image] * num_imgs} - processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) # Ensure we have the right number of placeholders per num_crops size img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py index ef8e97f82d0b..b882528aafb9 100644 --- a/tests/models/multimodal/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -3,7 +3,7 @@ import pytest from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from ....conftest import _ImageAssets from ...utils import build_model_context @@ -18,6 +18,7 @@ ]) # yapf: enable @pytest.mark.parametrize("num_imgs", [1, 2]) +@pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( image_assets: _ImageAssets, model_id: str, @@ -25,31 +26,30 @@ def test_processor_override( expected_toks_per_img: int, expected_pixels_shape: tuple[int, int], num_imgs: int, + kwargs_on_init: bool, ): """Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly.""" ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, - mm_processor_kwargs=None, + mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - trust_remote_code=ctx.model_config.trust_remote_code, - ) + tokenizer = cached_tokenizer_from_config(ctx.model_config) processor = MULTIMODAL_REGISTRY.create_processor( ctx.model_config, tokenizer=tokenizer, ) + hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs # Build the image str / prompt based on the number of images we pass prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs mm_data = {"image": [image_assets[0].pil_image] * num_imgs} - processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) # Ensure we have the right number of placeholders per num_crops size - hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs) + hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape diff --git a/tests/models/utils.py b/tests/models/utils.py index e2be43c12667..a90efb176722 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -248,13 +248,16 @@ def check_logprobs_close( warnings.warn(fail_msg, stacklevel=2) -def build_model_context(model_name: str, - task: TaskOption = "auto", - tokenizer_name: Optional[str] = None, - trust_remote_code: bool = False, - dtype: Optional[Union[str, torch.dtype]] = None, - mm_processor_kwargs: Optional[Dict] = None, - limit_mm_per_prompt: Optional[Dict] = None): +def build_model_context( + model_name: str, + task: TaskOption = "auto", + tokenizer_name: Optional[str] = None, + trust_remote_code: bool = False, + dtype: Optional[Union[str, torch.dtype]] = None, + mm_processor_kwargs: Optional[Dict] = None, + limit_mm_per_prompt: Optional[Dict] = None, + disable_mm_preprocessor_cache: bool = True, +): """Creates an InputContext for a given model. Args: @@ -283,5 +286,6 @@ def build_model_context(model_name: str, seed=0, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt=limit_mm_per_prompt, + disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, ) return InputContext(model_config) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 7bbe5c53562d..b247321ebb2f 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -22,8 +22,8 @@ replace_token_matches) # yapf: enable from vllm.multimodal.profiling import MultiModalProfiler -from vllm.multimodal.utils import cached_get_tokenizer -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + cached_tokenizer_from_config) from vllm.utils import full_groupby from .utils import random_image @@ -577,7 +577,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): processor = MULTIMODAL_REGISTRY.create_processor( model_config, - tokenizer=cached_get_tokenizer(model_config.tokenizer), + tokenizer=cached_tokenizer_from_config(model_config), ) profiler = MultiModalProfiler(processor) @@ -617,7 +617,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): processor = MULTIMODAL_REGISTRY.create_processor( model_config, - tokenizer=cached_get_tokenizer(model_config.tokenizer), + tokenizer=cached_tokenizer_from_config(model_config), ) rng = np.random.RandomState(0) @@ -689,7 +689,7 @@ def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs): processor = MULTIMODAL_REGISTRY.create_processor( model_config, - tokenizer=cached_get_tokenizer(model_config.tokenizer), + tokenizer=cached_tokenizer_from_config(model_config), ) orig_get_hf_processor = processor.info.get_hf_processor diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 87b7a7631e42..691fcd7dc53f 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -11,8 +11,9 @@ from typing_extensions import TypeVar, assert_never from vllm.logger import init_logger -from vllm.transformers_utils.processor import cached_get_processor -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + cached_tokenizer_from_config) from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, resolve_mm_processor_kwargs) @@ -27,19 +28,9 @@ logger = init_logger(__name__) -C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) -P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin) - - -class HashableDict(dict): - """ - A dictionary that can be hashed by lru_cache. - """ - - # NOTE: pythonic dict is not hashable, - # we override on it directly for simplicity - def __hash__(self) -> int: # type: ignore[override] - return hash(frozenset(self.items())) +_T = TypeVar("_T") +_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig) +_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) @dataclass(frozen=True) @@ -54,9 +45,9 @@ class InputContext: def get_hf_config( self, - typ: Union[type[C], tuple[type[C], ...]] = PretrainedConfig, + typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig, /, - ) -> C: + ) -> _C: """ Get the HuggingFace configuration (:class:`transformers.PretrainedConfig`) of the model, @@ -94,10 +85,10 @@ def get_mm_config(self): def get_hf_processor( self, - typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin, + typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, /, **kwargs: object, - ) -> P: + ) -> _P: """ Get the HuggingFace processor (:class:`transformers.ProcessorMixin`) of the model, @@ -106,33 +97,29 @@ def get_hf_processor( Raises: TypeError: If the processor is not of the specified type. """ + return cached_processor_from_config( + self.model_config, + processor_cls=typ, + **kwargs, + ) + + def init_processor( + self, + typ: type[_T], + /, + **kwargs: object, + ) -> _T: + """ + Initialize a HuggingFace-like processor class, merging the + keyword arguments with those in the model's configuration. + """ base_kwargs = self.model_config.mm_processor_kwargs if base_kwargs is None: base_kwargs = {} merged_kwargs = {**base_kwargs, **kwargs} - if isinstance(typ, type): - merged_kwargs["processor_cls"] = typ - - # NOTE: Pythonic dict is not hashable and will raise unhashable type - # error when calling `cached_get_processor`, therefore we need to - # wrap it to a hashable dict. - for key, value in merged_kwargs.items(): - if isinstance(value, dict): - merged_kwargs[key] = HashableDict(value) - - hf_processor = cached_get_processor( - self.model_config.model, - trust_remote_code=self.model_config.trust_remote_code, - **merged_kwargs, - ) - if not isinstance(hf_processor, typ): - raise TypeError("Invalid type of HuggingFace processor. " - f"Expected type: {typ}, but " - f"found type: {type(hf_processor)}") - - return hf_processor + return typ(**merged_kwargs) @dataclass(frozen=True) @@ -142,10 +129,10 @@ class InputProcessingContext(InputContext): def get_hf_processor( self, - typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin, + typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, /, **kwargs: object, - ) -> P: + ) -> _P: return super().get_hf_processor( typ, tokenizer=self.tokenizer, @@ -341,13 +328,9 @@ def dummy_data_for_profiling( from vllm.model_executor.model_loader import get_model_architecture from vllm.multimodal import MultiModalKwargs from vllm.multimodal.profiling import MultiModalProfiler - from vllm.multimodal.utils import cached_get_tokenizer if mm_registry.has_processor(model_config): - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code, - ) + tokenizer = cached_tokenizer_from_config(model_config) processor = mm_registry.create_processor(model_config, tokenizer) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_dummy_data( diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index df73a3b76b1f..bff4100a1dee 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -400,8 +400,8 @@ def get_hf_config(self): def get_vision_config(self): return self.get_hf_config().vision_config - def get_hf_processor(self): - return self.ctx.get_hf_processor(AriaProcessor) + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(AriaProcessor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index b29dd65a8e35..2d4dfab60730 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -58,8 +58,8 @@ class ChameleonProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(ChameleonConfig) - def get_hf_processor(self): - return self.ctx.get_hf_processor(ChameleonProcessor) + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(ChameleonProcessor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 0eaf3a6201f6..5f684fa295ad 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -28,13 +28,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs -from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, MlpProjectorConfig, VisionEncoderConfig) from vllm.transformers_utils.processors.deepseek_vl2 import ( DeepseekVLV2Processor) +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP @@ -133,8 +133,8 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(DeepseekVLV2Config) - def get_hf_processor(self) -> DeepseekVLV2Processor: - return self.ctx.get_hf_processor(DeepseekVLV2Processor) + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(DeepseekVLV2Processor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} @@ -308,13 +308,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.text_config = config.text_config model_config = vllm_config.model_config - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - tokenizer_mode=model_config.tokenizer_mode, - tokenizer_revision=model_config.tokenizer_revision, - trust_remote_code=model_config.trust_remote_code, - ) - self.image_token_id = tokenizer.vocab.get(_IMAGE_TOKEN) + tokenizer = cached_tokenizer_from_config(model_config) + self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] self.vision = self._init_vision_module(self.vision_config, quant_config, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 4e0ee6364f86..42a6aa979427 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -71,8 +71,8 @@ class FuyuProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(FuyuConfig) - def get_hf_processor(self): - return self.ctx.get_hf_processor(FuyuProcessor) + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(FuyuProcessor, **kwargs) def get_image_processor(self) -> FuyuImageProcessor: return self.get_hf_processor().image_processor diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 450421302a19..40010ec55906 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -416,18 +416,15 @@ def __call__( class GLM4VProcessingInfo(BaseProcessingInfo): - def get_tokenizer(self): - tokenizer = self.ctx.tokenizer - assert isinstance(tokenizer, PreTrainedTokenizer) - return tokenizer - def get_hf_config(self): return self.ctx.get_hf_config(ChatGLMConfig) - def get_hf_processor(self) -> GLM4VProcessor: - return GLM4VProcessor( - self.get_hf_config(), - self.get_tokenizer(), + def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor: + return self.ctx.init_processor( + GLM4VProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, ) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 7bda54ea7689..0f3a2ffe9a13 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -15,9 +15,9 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) -from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (IntermediateTensors, PoolerOutput, PoolingSequenceGroupOutput) +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config logger = init_logger(__name__) @@ -29,12 +29,7 @@ def __init__(self, model_config: ModelConfig): self.model_config = model_config - tokenizer = cached_get_tokenizer( - self.model_config.tokenizer, - tokenizer_mode=self.model_config.tokenizer_mode, - tokenizer_revision=self.model_config.tokenizer_revision, - trust_remote_code=self.model_config.trust_remote_code, - ) + tokenizer = cached_tokenizer_from_config(self.model_config) # Collect the tokens needed for pattern matching. # "▁<" is different from "_<". The former uses "▁" to indicate that diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index cf3e777a2027..01b721fa79e1 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -41,6 +41,7 @@ def resolve_h2ovl_min_max_num( dynamic_image_size: bool, use_thumbnail: bool, ) -> tuple[int, int]: + min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if use_thumbnail and max_dynamic_patch != 1: @@ -190,7 +191,7 @@ def image_to_pixel_values_h2ovl( pixel_values1, aspect_ratio1 = _preprocess_image( image, input_size=input_size, - min_num=min_num, + min_num=1, max_num=max_num, use_thumbnail=True, prior_aspect_ratio=None, @@ -199,7 +200,7 @@ def image_to_pixel_values_h2ovl( pixel_values2, _ = _preprocess_image( image, input_size=input_size, - min_num=3, # Hardcoded value + min_num=3, max_num=max_num, use_thumbnail=True, prior_aspect_ratio=aspect_ratio1, @@ -228,6 +229,7 @@ def __init__( config: PretrainedConfig, tokenizer: AnyTokenizer, *, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, use_msac: Optional[bool] = None, @@ -235,6 +237,7 @@ def __init__( super().__init__( config, tokenizer, + min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) @@ -267,11 +270,13 @@ def get_image_repl_full( def resolve_min_max_num( self, *, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> tuple[int, int]: - min_dynamic_patch = self.min_dynamic_patch + min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch + is None else min_dynamic_patch) max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch) dynamic_image_size = (self.dynamic_image_size if dynamic_image_size @@ -289,18 +294,21 @@ def resolve_min_max_num( def resolve_target_ratios( self, *, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, prior_aspect_ratio: Optional[tuple[int, int]] = None, + override_min_num: Optional[int] = None, ) -> list[tuple[int, int]]: min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, use_thumbnail=use_thumbnail, ) - if prior_aspect_ratio: # hardcoded value for second pass of use_msac - min_num = 3 + if override_min_num is not None: + min_num = override_min_num return get_h2ovl_target_ratios( min_num, @@ -322,6 +330,7 @@ def get_num_image_tokens( if use_msac: target_ratios_1 = self.resolve_target_ratios( use_thumbnail=False, # Applied in calculate_targets + override_min_num=1, ) num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets( orig_width=image_width, @@ -334,6 +343,7 @@ def get_num_image_tokens( target_ratios_2 = self.resolve_target_ratios( use_thumbnail=False, # Applied in calculate_targets prior_aspect_ratio=aspect_ratio_1, + override_min_num=3, ) num_patches_2, _, _, _ = calculate_h2ovl_targets( orig_width=image_width, @@ -361,12 +371,14 @@ def get_num_image_tokens( def _images_to_pixel_values_lst( self, images: list[Image.Image], + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, ) -> list[torch.Tensor]: use_msac = self.use_msac if len(images) == 1 else False min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, use_thumbnail=False, # Applied in image_to_pixel_values @@ -389,14 +401,23 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): def get_hf_processor( self, *, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, + **kwargs: object, ) -> H2OVLProcessor: - return H2OVLProcessor( - self.get_hf_config(), - self.get_tokenizer(), - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, + if min_dynamic_patch is not None: + kwargs["min_dynamic_patch"] = min_dynamic_patch + if max_dynamic_patch is not None: + kwargs["max_dynamic_patch"] = max_dynamic_patch + if dynamic_image_size is not None: + kwargs["dynamic_image_size"] = dynamic_image_size + + return self.ctx.init_processor( + H2OVLProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, ) def get_mm_max_tokens_per_item( diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index fdfabbaafce3..579253632c81 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -83,13 +83,15 @@ class Idefics3ImageEmbeddingInputs(TypedDict): class Idefics3ProcessingInfo(BaseProcessingInfo): def get_hf_processor( - self, - *, - size: Optional[Dict[str, int]] = None) -> Idefics3Processor: + self, + *, + size: Optional[Dict[str, int]] = None, + **kwargs: object, + ) -> Idefics3Processor: if size is not None: - return self.ctx.get_hf_processor(Idefics3Processor, size=size) + kwargs["size"] = size - return self.ctx.get_hf_processor(Idefics3Processor) + return self.ctx.get_hf_processor(Idefics3Processor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 380eb40d9eb2..4a6007876776 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -120,6 +120,7 @@ def resolve_internvl_min_max_num( dynamic_image_size: bool, use_thumbnail: bool, ) -> tuple[int, int]: + min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if use_thumbnail and max_dynamic_patch != 1: @@ -247,6 +248,7 @@ def __init__( config: PretrainedConfig, tokenizer: AnyTokenizer, *, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, ) -> None: @@ -258,18 +260,22 @@ def __init__( image_size: int = config.vision_config.image_size patch_size: int = config.vision_config.patch_size - if dynamic_image_size is None: - dynamic_image_size = config.dynamic_image_size - assert isinstance(dynamic_image_size, bool) + if min_dynamic_patch is None: + min_dynamic_patch = config.min_dynamic_patch + assert isinstance(min_dynamic_patch, int) if max_dynamic_patch is None: max_dynamic_patch = config.max_dynamic_patch assert isinstance(max_dynamic_patch, int) + if dynamic_image_size is None: + dynamic_image_size = config.dynamic_image_size + assert isinstance(dynamic_image_size, bool) + self.num_image_token = int( (image_size // patch_size)**2 * (config.downsample_ratio**2)) self.image_size = image_size - self.min_dynamic_patch: int = config.min_dynamic_patch + self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch self.dynamic_image_size = dynamic_image_size self.use_thumbnail: bool = config.use_thumbnail @@ -298,11 +304,13 @@ def get_image_repl_full( def resolve_min_max_num( self, *, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> tuple[int, int]: - min_dynamic_patch = self.min_dynamic_patch + min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch + is None else min_dynamic_patch) max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch) dynamic_image_size = (self.dynamic_image_size if dynamic_image_size @@ -320,11 +328,13 @@ def resolve_min_max_num( def resolve_target_ratios( self, *, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> list[tuple[int, int]]: min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, use_thumbnail=use_thumbnail, @@ -355,10 +365,12 @@ def get_num_image_tokens( def _images_to_pixel_values_lst( self, images: list[Image.Image], + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, ) -> list[torch.Tensor]: min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, use_thumbnail=False, # Applied in image_to_pixel_values @@ -378,6 +390,7 @@ def __call__( self, text: Optional[Union[str, list[str]]] = None, images: Optional[Union[Image.Image, list[Image.Image]]] = None, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, @@ -396,6 +409,7 @@ def __call__( else: pixel_values_lst = self._images_to_pixel_values_lst( images, + min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) @@ -451,8 +465,10 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): def get_hf_processor( self, *, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, + **kwargs: object, ) -> BaseInternVLProcessor: raise NotImplementedError @@ -642,14 +658,23 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo): def get_hf_processor( self, *, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, + **kwargs: object, ) -> InternVLProcessor: - return InternVLProcessor( - self.get_hf_config(), - self.get_tokenizer(), - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, + if min_dynamic_patch is not None: + kwargs["min_dynamic_patch"] = min_dynamic_patch + if max_dynamic_patch is not None: + kwargs["max_dynamic_patch"] = max_dynamic_patch + if dynamic_image_size is not None: + kwargs["dynamic_image_size"] = dynamic_image_size + + return self.ctx.init_processor( + InternVLProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, ) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 6a4277adb6bf..19752ba703f4 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -119,7 +119,7 @@ def get_vision_encoder_info(self): return get_vision_encoder_info(self.get_hf_config()) @abstractmethod - def get_hf_processor(self) -> LlavaLikeProcessor: + def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor: raise NotImplementedError def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: @@ -208,8 +208,8 @@ def get_dummy_processor_inputs( class LlavaProcessingInfo(BaseLlavaProcessingInfo): - def get_hf_processor(self): - return self.ctx.get_hf_processor(LlavaProcessor) + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(LlavaProcessor, **kwargs) class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): @@ -272,8 +272,8 @@ def _get_mm_fields_config( class PixtralHFProcessingInfo(BaseLlavaProcessingInfo): - def get_hf_processor(self): - return self.ctx.get_hf_processor(PixtralProcessor) + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(PixtralProcessor, **kwargs) class PixtralHFMultiModalProcessor( @@ -742,23 +742,24 @@ def load_weights(self, weights: Iterable[Tuple[str, class MantisProcessingInfo(LlavaProcessingInfo): - def get_hf_processor(self): + def get_hf_processor(self, **kwargs: object): hf_config = self.get_hf_config() vision_info = self.get_vision_encoder_info() + kwargs.setdefault("patch_size", vision_info.get_patch_size()) + if Version(TRANSFORMERS_VERSION) < Version("4.48"): # BUG: num_additional_image_tokens = 0 but treated as 1, # so we set vision_feature_select_strategy to None to offset this - vision_feature_select_strategy = None + kwargs.setdefault("vision_feature_select_strategy", None) else: # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150 - vision_feature_select_strategy = hf_config.vision_feature_select_strategy # noqa: E501 + kwargs.setdefault( + "vision_feature_select_strategy", + hf_config.vision_feature_select_strategy, + ) - return self.ctx.get_hf_processor( - LlavaProcessor, - patch_size=vision_info.get_patch_size(), - vision_feature_select_strategy=vision_feature_select_strategy, - ) + return self.ctx.get_hf_processor(LlavaProcessor, **kwargs) class MantisMultiModalProcessor(LlavaMultiModalProcessor): diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 719916642f25..c39daec709fc 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -72,8 +72,8 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo): def get_hf_config(self) -> LlavaNextLikeConfig: return self.ctx.get_hf_config(LlavaNextConfig) - def get_hf_processor(self): - hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor) + def get_hf_processor(self, **kwargs: object): + hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor, **kwargs) # In case patch_size is omitted from `processor_config.json` # e.g. for E5-V: https://huggingface.co/royokong/e5-v diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 817edcef4ba1..2af3cc05080a 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -56,8 +56,8 @@ def get_hf_config(self): def get_vision_encoder_info(self): return get_vision_encoder_info(self.get_hf_config()) - def get_hf_processor(self): - return self.ctx.get_hf_processor(LlavaNextVideoProcessor) + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(LlavaNextVideoProcessor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"video": 1} diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 084d4d51ad23..8eb8071e6577 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -97,8 +97,8 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): def get_hf_config(self) -> LlavaOnevisionLikeConfig: return self.ctx.get_hf_config(LlavaOnevisionConfig) - def get_hf_processor(self): - return self.ctx.get_hf_processor(LlavaOnevisionProcessor) + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(LlavaOnevisionProcessor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2083e7dc0b83..97596f9e82c6 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -331,11 +331,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config() - def get_hf_processor( - self, - **kwargs: object, - ): - hf_processor = self.ctx.get_hf_processor() + def get_hf_processor(self, **kwargs: object): + hf_processor = self.ctx.get_hf_processor(**kwargs) # NumPy arrays are considered as Iterable but not Sequence in # https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428 diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 3ca22d346b79..1f8f5b2eb136 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -94,8 +94,8 @@ class MllamaProcessingInfo(BaseProcessingInfo): def get_hf_config(self) -> MllamaConfig: return self.ctx.get_hf_config(MllamaConfig) - def get_hf_processor(self) -> MllamaProcessor: - return self.ctx.get_hf_processor(MllamaProcessor) + def get_hf_processor(self, **kwargs: object) -> MllamaProcessor: + return self.ctx.get_hf_processor(MllamaProcessor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index b2154ef54af3..1d84d25c96ac 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1200,8 +1200,8 @@ def __call__( class MolmoProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self) -> MolmoProcessorWrapper: - processor = self.ctx.get_hf_processor() + def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper: + processor = self.ctx.get_hf_processor(**kwargs) return MolmoProcessorWrapper(processor) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 9c674ab46446..5de8eeb3fffe 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -69,14 +69,23 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo): def get_hf_processor( self, *, + min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, + **kwargs: object, ) -> NVLMProcessor: - return NVLMProcessor( - self.get_hf_config(), - self.get_tokenizer(), - max_dynamic_patch=max_dynamic_patch, - dynamic_image_size=dynamic_image_size, + if min_dynamic_patch is not None: + kwargs["min_dynamic_patch"] = min_dynamic_patch + if max_dynamic_patch is not None: + kwargs["max_dynamic_patch"] = max_dynamic_patch + if dynamic_image_size is not None: + kwargs["dynamic_image_size"] = dynamic_image_size + + return self.ctx.init_processor( + NVLMProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, ) def get_max_image_tokens(self) -> int: diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 65d810dc23bc..955a59953eb4 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -16,8 +16,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors -from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from .interfaces import SupportsMultiModal, SupportsPP from .siglip import (SiglipVisionModel, dummy_image_for_siglip, @@ -88,7 +88,7 @@ def input_processor_for_paligemma(ctx: InputContext, model_config = ctx.model_config hf_config = ctx.get_hf_config(PaliGemmaConfig) - tokenizer = cached_get_tokenizer(model_config.tokenizer) + tokenizer = cached_tokenizer_from_config(model_config) image_feature_size = hf_config.text_config.num_image_tokens image_token_str = tokenizer.decode(hf_config.image_token_index) bos_token = tokenizer.decode(hf_config.bos_token_id) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 6bbfa40beed1..207204df2055 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -313,11 +313,12 @@ def get_hf_processor( self, *, num_crops: Optional[int] = None, + **kwargs: object, ) -> ProcessorMixin: if num_crops is not None: - return self.ctx.get_hf_processor(num_crops=num_crops) + kwargs["num_crops"] = num_crops - return self.ctx.get_hf_processor() + return self.ctx.get_hf_processor(**kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 44fca852805a..273dc3b1cf75 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -32,9 +32,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import NestedTensors, PlaceholderRange -from vllm.multimodal.utils import (cached_get_tokenizer, - consecutive_placeholder_ranges) +from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from .interfaces import SupportsMultiModal, SupportsPP from .utils import (init_vllm_registered_model, maybe_prefix, @@ -49,9 +49,7 @@ def get_max_pixtral_image_tokens(ctx: InputContext): - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - tokenizer_mode=ctx.model_config.tokenizer_mode) + tokenizer = cached_tokenizer_from_config(ctx.model_config) mm_encoder = tokenizer.instruct.mm_encoder image_config = mm_encoder.mm_config if hasattr( @@ -65,9 +63,7 @@ def get_max_pixtral_image_tokens(ctx: InputContext): def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - tokenizer_mode=ctx.model_config.tokenizer_mode) + tokenizer = cached_tokenizer_from_config(ctx.model_config) mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder image_token_id = mm_encoder.special_ids.img @@ -109,9 +105,7 @@ def input_mapper_for_pixtral(ctx: InputContext, MultiModalKwargs containing the stacked normalized images tensor or image embeddings. """ - model_config = ctx.model_config - tokenizer = cached_get_tokenizer( - model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode) + tokenizer = cached_tokenizer_from_config(ctx.model_config) data_list = data if isinstance(data, list) else [data] @@ -138,9 +132,7 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): prompt_token_ids = inputs.get("prompt_token_ids") prompt = inputs.get("prompt") - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - tokenizer_mode=ctx.model_config.tokenizer_mode) + tokenizer = cached_tokenizer_from_config(ctx.model_config) mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder image_token_id = mm_encoder.special_ids.img diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 632ecaf65f2f..29187eb2ef9c 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -36,8 +36,6 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) -from transformers.models.qwen2_vl import (Qwen2VLImageProcessor, - Qwen2VLImageProcessorFast) from vllm.attention import AttentionMetadata from vllm.config import VllmConfig @@ -690,41 +688,20 @@ def get_hf_processor( *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, - fps: Optional[float] = 2.0, + size: Optional[dict[str, int]] = None, + fps: Optional[float] = None, + **kwargs: object, ) -> Qwen2_5_VLProcessor: - hf_processor = self.ctx.get_hf_processor(Qwen2_5_VLProcessor) - image_processor = hf_processor.image_processor # type: ignore - assert isinstance(image_processor, - (Qwen2VLImageProcessor, Qwen2VLImageProcessorFast)) - - if min_pixels: - image_processor.min_pixels = min_pixels - if max_pixels: - image_processor.max_pixels = max_pixels - if max_pixels or min_pixels: - image_processor.size = { - "min_pixels": image_processor.min_pixels, - "max_pixels": image_processor.max_pixels, - } - - return hf_processor - - def get_image_processor( - self, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, - fps: Optional[float] = 2.0, - ) -> Union[Qwen2VLImageProcessor, Qwen2VLImageProcessorFast]: - hf_processor = self.get_hf_processor( - min_pixels=min_pixels, - max_pixels=max_pixels, - fps=fps, + if fps is not None: + kwargs["fps"] = fps + + return self.ctx.get_hf_processor( + Qwen2_5_VLProcessor, + image_processor=self.get_image_processor(min_pixels=min_pixels, + max_pixels=max_pixels, + size=size), + **kwargs, ) - image_processor = hf_processor.image_processor # type: ignore - assert isinstance(image_processor, - (Qwen2VLImageProcessor, Qwen2VLImageProcessorFast)) - return image_processor class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index cf79544e60e8..3df5dd2bdd41 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -93,8 +93,9 @@ def get_hf_processor( *, # Ignored in initialization sampling_rate: Optional[int] = None, + **kwargs: object, ) -> Qwen2AudioProcessor: - return self.ctx.get_hf_processor(Qwen2AudioProcessor) + return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs) def get_feature_extractor( self, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 68340ace18dd..919445267f4a 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -31,9 +31,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from packaging.version import Version from transformers import BatchFeature -from transformers import __version__ as TRANSFORMERS_VERSION from transformers.models.qwen2_vl import (Qwen2VLImageProcessor, Qwen2VLProcessor) from transformers.models.qwen2_vl.configuration_qwen2_vl import ( @@ -69,6 +67,8 @@ from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope +from vllm.transformers_utils.processor import ( + cached_image_processor_from_config) from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, @@ -722,40 +722,64 @@ def get_hf_processor( *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + **kwargs: object, ) -> Qwen2VLProcessor: - hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor) - image_processor = hf_processor.image_processor # type: ignore - assert isinstance(image_processor, Qwen2VLImageProcessor) - - if min_pixels: - image_processor.min_pixels = min_pixels - if max_pixels: - image_processor.max_pixels = max_pixels - if max_pixels or min_pixels: - image_processor.size = { - "min_pixels": image_processor.min_pixels, - "max_pixels": image_processor.max_pixels, - } - - return hf_processor + return self.ctx.get_hf_processor( + Qwen2VLProcessor, + image_processor=self.get_image_processor(min_pixels=min_pixels, + max_pixels=max_pixels, + size=size), + **kwargs, + ) + + def _get_image_processor_kwargs( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + **kwargs: object, + ): + if self.ctx.model_config.mm_processor_kwargs: + kwargs.update(self.ctx.model_config.mm_processor_kwargs) + + if min_pixels is not None: + kwargs["min_pixels"] = min_pixels + + if size is None: + size = {"shortest_edge": min_pixels} + else: + size["shortest_edge"] = min_pixels + + if max_pixels is not None: + kwargs["max_pixels"] = max_pixels + + if size is None: + size = {"longest_edge": max_pixels} + else: + size["longest_edge"] = max_pixels + + if size is not None: + kwargs["size"] = size + + return kwargs def get_image_processor( self, *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + **kwargs: object, ): - hf_processor = self.get_hf_processor(min_pixels=min_pixels, - max_pixels=max_pixels) - image_processor = hf_processor.image_processor # type: ignore - if Version(TRANSFORMERS_VERSION) >= Version("4.49"): - from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast - assert isinstance( - image_processor, - (Qwen2VLImageProcessor, Qwen2VLImageProcessorFast)) - else: - assert isinstance(image_processor, Qwen2VLImageProcessor) - return image_processor + return cached_image_processor_from_config( + self.ctx.model_config, + **self._get_image_processor_kwargs(min_pixels=min_pixels, + max_pixels=max_pixels, + size=size, + **kwargs), + ) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} @@ -952,6 +976,18 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] def _get_data_parser(self) -> MultiModalDataParser: return Qwen2VLMultiModalDataParser() + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + return self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + self.info._get_image_processor_kwargs(**mm_kwargs), + ) + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, @@ -964,8 +1000,6 @@ def _get_prompt_replacements( tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() - # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has - # image_token and video_token registered placeholder = { "image": vocab[hf_processor.image_token], "video": vocab[hf_processor.video_token], diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 0f4f5072fb2b..61a4584abf85 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -519,8 +519,13 @@ def get_tokenizer(self) -> PreTrainedTokenizer: return _get_tokenizer_without_image_pad(tokenizer) - def get_hf_processor(self) -> QwenVLProcessor: - return QwenVLProcessor(self.get_hf_config(), self.get_tokenizer()) + def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor: + return self.ctx.init_processor( + QwenVLProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, + ) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 063997a14a66..e24b4aeb8ae8 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -68,8 +68,9 @@ def get_hf_processor( *, # Ignored in initialization sampling_rate: Optional[int] = None, + **kwargs: object, ) -> ProcessorMixin: - hf_processor = self.ctx.get_hf_processor() + hf_processor = self.ctx.get_hf_processor(**kwargs) # NOTE: Ultravox processing definition uses '<|eot_id|>' as the # placeholder that will cause confusion with the actual end of turn diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 0b506072094e..073a30d25e23 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -29,7 +29,7 @@ NestedTensors) from vllm.multimodal.audio import resample_audio from vllm.sequence import SequenceData -from vllm.transformers_utils.processor import cached_get_processor +from vllm.transformers_utils.processor import cached_processor_from_config from .interfaces import SupportsMultiModal, SupportsTranscription from .utils import AutoWeightsLoader, WeightsMapper, make_layers @@ -579,7 +579,7 @@ def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): assert mm_counts["audio"] == 1 num_tokens = get_max_whisper_audio_tokens(ctx) - processor = cached_get_processor(ctx.model_config.model) + processor = cached_processor_from_config(ctx.model_config) chunk_length = processor.feature_extractor.chunk_length sampling_rate = processor.feature_extractor.sampling_rate num_samples = chunk_length * sampling_rate @@ -596,7 +596,7 @@ def input_processor_for_whisper(ctx: InputContext, inputs): multi_modal_data["audio"] = multi_modal_data["audio"][0] # Resample and process audio audio, orig_sr = multi_modal_data["audio"] - processor = cached_get_processor(ctx.model_config.model) + processor = cached_processor_from_config(ctx.model_config) target_sr = processor.feature_extractor.sampling_rate audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) multi_modal_data["audio"] = (audio, target_sr) @@ -618,7 +618,7 @@ def input_mapper_for_whisper( if len(multi_modal_data) == 0: return MultiModalKwargs() - processor = cached_get_processor(ctx.model_config.model) + processor = cached_processor_from_config(ctx.model_config) sampling_rate = processor.feature_extractor.sampling_rate audios = [audio for audio, _ in multi_modal_data] diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 98ac8057e8f1..98ece8f806f1 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import base64 -from functools import lru_cache from io import BytesIO from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional @@ -11,7 +10,7 @@ from vllm.inputs.registry import InputContext from vllm.logger import init_logger -from vllm.transformers_utils.processor import get_image_processor +from vllm.transformers_utils.processor import cached_get_image_processor from vllm.utils import is_list_of from .base import MediaIO, MultiModalPlugin @@ -22,8 +21,6 @@ logger = init_logger(__name__) -cached_get_image_processor = lru_cache(get_image_processor) - class ImagePlugin(MultiModalPlugin): """Plugin for image data.""" diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 613d1db41672..1882ffe9bf69 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -11,7 +11,8 @@ from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE from vllm.inputs import InputProcessingContext from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + cached_tokenizer_from_config) from vllm.utils import ClassRegistry from .audio import AudioPlugin @@ -21,7 +22,6 @@ from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache) from .profiling import BaseDummyInputsBuilder, MultiModalProfiler -from .utils import cached_get_tokenizer from .video import VideoPlugin if TYPE_CHECKING: @@ -256,10 +256,7 @@ def get_max_tokens_per_item_by_modality( on underlying model configuration. """ if self.has_processor(model_config): - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code, - ) + tokenizer = cached_tokenizer_from_config(model_config) processor = self.create_processor(model_config, tokenizer) seq_len = model_config.max_model_len mm_limits = self.get_mm_limits_per_prompt(model_config) @@ -374,10 +371,7 @@ def get_mm_limits_per_prompt( This should be called after :meth:`init_mm_limits_per_prompt`. """ if self.has_processor(model_config): - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code, - ) + tokenizer = cached_tokenizer_from_config(model_config) processor = self.create_processor(model_config, tokenizer) profiler = MultiModalProfiler(processor) return profiler.get_mm_limits() diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 583f53655124..6e6c10b34a25 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from functools import lru_cache from itertools import groupby from pathlib import Path from typing import TYPE_CHECKING, Optional, TypeVar, Union @@ -13,7 +12,7 @@ import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer from .audio import AudioMediaIO from .base import MediaIO @@ -23,8 +22,6 @@ logger = init_logger(__name__) -cached_get_tokenizer = lru_cache(get_tokenizer) - _M = TypeVar("_M") if TYPE_CHECKING: diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 78a2918e3ed3..8004377191b3 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import base64 -from functools import lru_cache, partial +from functools import partial from io import BytesIO from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional @@ -12,8 +12,7 @@ from vllm.inputs.registry import InputContext from vllm.logger import init_logger -from vllm.transformers_utils.processor import get_video_processor -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.transformers_utils.processor import cached_get_video_processor from vllm.utils import PlaceholderModule, is_list_of from .base import MediaIO, ModalityData @@ -30,9 +29,6 @@ logger = init_logger(__name__) -cached_get_video_processor = lru_cache(get_video_processor) -cached_get_tokenizer = lru_cache(get_tokenizer) - class VideoPlugin(ImagePlugin): """Plugin for video data.""" diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 3197b07d8a46..29fab16c25c1 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,25 +1,59 @@ # SPDX-License-Identifier: Apache-2.0 from functools import lru_cache -from typing import Any, cast +from typing import TYPE_CHECKING, Any, Union, cast from transformers.processing_utils import ProcessorMixin +from typing_extensions import TypeVar + +if TYPE_CHECKING: + from vllm.config import ModelConfig + +_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) + + +class HashableDict(dict): + """ + A dictionary that can be hashed by lru_cache. + """ + + # NOTE: pythonic dict is not hashable, + # we override on it directly for simplicity + def __hash__(self) -> int: # type: ignore[override] + return hash(frozenset(self.items())) + + +def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs): + base_kwargs = model_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + merged_kwargs = {**base_kwargs, **kwargs} + + # NOTE: Pythonic dict is not hashable and will raise unhashable type + # error when calling `cached_get_processor`, therefore we need to + # wrap it to a hashable dict. + for key, value in merged_kwargs.items(): + if isinstance(value, dict): + merged_kwargs[key] = HashableDict(value) + + return merged_kwargs def get_processor( processor_name: str, *args: Any, trust_remote_code: bool = False, - processor_cls: type[ProcessorMixin] = ProcessorMixin, + processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, **kwargs: Any, -): +) -> _P: """Load a processor for the given model name via HuggingFace.""" # don't put this import at the top level # it will call torch.cuda.device_count() from transformers import AutoProcessor - processor_factory = (AutoProcessor - if processor_cls == ProcessorMixin else processor_cls) + processor_factory = (AutoProcessor if processor_cls == ProcessorMixin or + isinstance(processor_cls, tuple) else processor_cls) try: processor = processor_factory.from_pretrained( @@ -43,12 +77,30 @@ def get_processor( else: raise e - return cast(ProcessorMixin, processor) + if not isinstance(processor, processor_cls): + raise TypeError("Invalid type of HuggingFace processor. " + f"Expected type: {processor_cls}, but " + f"found type: {type(processor)}") + + return processor cached_get_processor = lru_cache(get_processor) +def cached_processor_from_config( + model_config: "ModelConfig", + processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, + **kwargs: Any, +) -> _P: + return cached_get_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + processor_cls=processor_cls, # type: ignore[arg-type] + **_merge_mm_kwargs(model_config, **kwargs), + ) + + def get_image_processor( processor_name: str, *args: Any, @@ -85,6 +137,20 @@ def get_image_processor( return cast(BaseImageProcessor, processor) +cached_get_image_processor = lru_cache(get_image_processor) + + +def cached_image_processor_from_config( + model_config: "ModelConfig", + **kwargs: Any, +): + return cached_get_image_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **_merge_mm_kwargs(model_config, **kwargs), + ) + + def get_video_processor( processor_name: str, *args: Any, @@ -104,3 +170,17 @@ def get_video_processor( ) return cast(BaseImageProcessor, processor.video_processor) + + +cached_get_video_processor = lru_cache(get_video_processor) + + +def cached_video_processor_from_config( + model_config: "ModelConfig", + **kwargs: Any, +): + return cached_get_video_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **_merge_mm_kwargs(model_config, **kwargs), + ) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 0c0f68ac123e..f0aa5fdcaa61 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -3,9 +3,10 @@ import contextlib import os import warnings +from functools import lru_cache from pathlib import Path from types import MethodType -from typing import Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import huggingface_hub from transformers import (AutoTokenizer, PreTrainedTokenizer, @@ -20,6 +21,9 @@ from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import make_async +if TYPE_CHECKING: + from vllm.config import ModelConfig + logger = init_logger(__name__) AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, @@ -232,6 +236,22 @@ def get_tokenizer( return tokenizer +cached_get_tokenizer = lru_cache(get_tokenizer) + + +def cached_tokenizer_from_config( + model_config: "ModelConfig", + **kwargs: Any, +): + return cached_get_tokenizer( + model_config.tokenizer, + tokenizer_mode=model_config.tokenizer_mode, + tokenizer_revision=model_config.tokenizer_revision, + trust_remote_code=model_config.trust_remote_code, + **kwargs, + ) + + def get_lora_tokenizer(lora_request: LoRARequest, *args, **kwargs) -> Optional[AnyTokenizer]: if lora_request is None: From 4b2dd1447c055e426a0c0b6ac6ac31ea70d9a3d8 Mon Sep 17 00:00:00 2001 From: shangmingc Date: Wed, 19 Feb 2025 22:13:15 +0800 Subject: [PATCH 114/317] [Bugfix] Fix device ordinal for multi-node spec decode (#13269) Signed-off-by: Shangming Cai --- vllm/spec_decode/spec_decode_worker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index fce06a81ff04..3f381d5199d7 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -10,6 +10,7 @@ from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig from vllm.distributed.communication_op import (broadcast_tensor_dict, + get_tp_group, tensor_model_parallel_gather) from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler @@ -365,7 +366,7 @@ def init_device(self) -> None: target_lm_head_weight) self._metrics.init_tensors(self.rank, device_type=self.device) - self.spec_decode_sampler.init_tensors(self.rank, + self.spec_decode_sampler.init_tensors(get_tp_group().local_rank, device_type=self.device) scorer_cls: Type[SpeculativeScorer] From c1e40f8624cb2c762818ef3bd80ea9eba8c12af0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 19 Feb 2025 22:32:17 +0800 Subject: [PATCH 115/317] [doc] clarify multi-node serving doc (#13558) Signed-off-by: youkaichao --- docs/source/serving/distributed_serving.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/serving/distributed_serving.md b/docs/source/serving/distributed_serving.md index 6d136147c8dd..54c7ded20421 100644 --- a/docs/source/serving/distributed_serving.md +++ b/docs/source/serving/distributed_serving.md @@ -75,11 +75,15 @@ bash run_cluster.sh \ -e VLLM_HOST_IP=ip_of_this_node ``` -Then you get a ray cluster of containers. Note that you need to keep the shells running these commands alive to hold the cluster. Any shell disconnect will terminate the cluster. In addition, please note that the argument `ip_of_head_node` should be the IP address of the head node, which is accessible by all the worker nodes. The IP addresses of each worker node should be specified in the `VLLM_HOST_IP` environment variable, and should be different for each worker node. Please check the network configuration of your cluster to make sure the nodes can communicate with each other through the specified IP addresses. +Then you get a ray cluster of **containers**. Note that you need to keep the shells running these commands alive to hold the cluster. Any shell disconnect will terminate the cluster. In addition, please note that the argument `ip_of_head_node` should be the IP address of the head node, which is accessible by all the worker nodes. The IP addresses of each worker node should be specified in the `VLLM_HOST_IP` environment variable, and should be different for each worker node. Please check the network configuration of your cluster to make sure the nodes can communicate with each other through the specified IP addresses. + +:::{warning} +Since this is a ray cluster of **containers**, all the following commands should be executed in the **containers**, otherwise you are executing the commands on the host machine, which is not connected to the ray cluster. To enter the container, you can use `docker exec -it node /bin/bash`. +::: Then, on any node, use `docker exec -it node /bin/bash` to enter the container, execute `ray status` to check the status of the Ray cluster. You should see the right number of nodes and GPUs. -After that, on any node, you can use vLLM as usual, just as you have all the GPUs on one node. The common practice is to set the tensor parallel size to the number of GPUs in each node, and the pipeline parallel size to the number of nodes. For example, if you have 16 GPUs in 2 nodes (8 GPUs per node), you can set the tensor parallel size to 8 and the pipeline parallel size to 2: +After that, on any node, use `docker exec -it node /bin/bash` to enter the container again. **In the container**, you can use vLLM as usual, just as you have all the GPUs on one node. The common practice is to set the tensor parallel size to the number of GPUs in each node, and the pipeline parallel size to the number of nodes. For example, if you have 16 GPUs in 2 nodes (8 GPUs per node), you can set the tensor parallel size to 8 and the pipeline parallel size to 2: ```console vllm serve /path/to/the/model/in/the/container \ From 10020cbc44eed4b963b1417673668c8bd1d77cff Mon Sep 17 00:00:00 2001 From: Wilson Wu Date: Thu, 20 Feb 2025 00:55:34 +0800 Subject: [PATCH 116/317] Fix copyright year to auto get current year (#13561) --- docs/source/conf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 84c9a27be3bf..97bec81b1eee 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,6 +12,7 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. +import datetime import inspect import logging import os @@ -27,7 +28,7 @@ # -- Project information ----------------------------------------------------- project = 'vLLM' -copyright = '2024, vLLM Team' +copyright = f'{datetime.datetime.now().year}, vLLM Team' author = 'the vLLM Team' # -- General configuration --------------------------------------------------- From f1ead2cf3bbb1f172d1ee99c6715ac0b359bfaff Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 19 Feb 2025 09:40:50 -0800 Subject: [PATCH 117/317] [MISC] Logging the message about Ray teardown (#13502) Signed-off-by: Cody Yu Co-authored-by: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> --- vllm/executor/ray_distributed_executor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 6a25a4d50fb9..79ca45d55d96 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -101,6 +101,10 @@ def _init_executor(self) -> None: self.driver_worker.execute_method) def shutdown(self) -> None: + logger.info( + "Shutting down Ray distributed executor. If you see error log " + "from logging.cc regarding SIGTERM received, please ignore because " + "this is the expected termination process in Ray.") if hasattr(self, "forward_dag") and self.forward_dag is not None: self.forward_dag.teardown() import ray From 9e112ca214c598e706174d90dc22961a0a53191f Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 20 Feb 2025 02:57:48 +0800 Subject: [PATCH 118/317] [Misc] Avoid calling unnecessary `hf_list_repo_files` for local model path (#13348) Signed-off-by: isotr0py <2037008807@qq.com> --- vllm/transformers_utils/config.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4768226f9a03..dd6ee9a34adb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -115,7 +115,14 @@ def list_repo_files( token: Union[str, bool, None] = None, ) -> list[str]: - def lookup_files(): + def lookup_files() -> list[str]: + # directly list files if model is local + if (local_path := Path(repo_id)).exists(): + return [ + str(file.relative_to(local_path)) + for file in local_path.rglob('*') if file.is_file() + ] + # if model is remote, use hf_hub api to list files try: if VLLM_USE_MODELSCOPE: from vllm.transformers_utils.utils import ( @@ -154,8 +161,8 @@ def file_exists( # In offline mode the result can be a false negative def file_or_path_exists(model: Union[str, Path], config_name: str, revision: Optional[str]) -> bool: - if Path(model).exists(): - return (Path(model) / config_name).is_file() + if (local_path := Path(model)).exists(): + return (local_path / config_name).is_file() # Offline mode support: Check if config file is cached already cached_filepath = try_to_load_from_cache(repo_id=model, From 4b3a90d144065827c74a8cdd8bba7142427f4899 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 19 Feb 2025 16:49:01 -0800 Subject: [PATCH 119/317] [BugFix] Avoid error traceback in logs when V1 `LLM` terminates (#13565) Signed-off-by: Nick Hill --- vllm/v1/engine/core_client.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 8641833e438b..77df9ed54095 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -252,14 +252,18 @@ def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor], outputs_queue = self.outputs_queue def process_outputs_socket(): - while True: - (frame, ) = output_socket.recv_multipart(copy=False) - outputs = decoder.decode(frame.buffer) - if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) - else: - outputs_queue.put_nowait(outputs) + try: + while True: + (frame, ) = output_socket.recv_multipart(copy=False) + outputs = decoder.decode(frame.buffer) + if outputs.utility_output: + _process_utility_output(outputs.utility_output, + utility_results) + else: + outputs_queue.put_nowait(outputs) + except zmq.error.ContextTerminated: + # Expected when the class is GC'd / during process termination. + pass # Process outputs from engine in separate thread. Thread(target=process_outputs_socket, daemon=True).start() From ad6b051dba1a061354d08bb6e3c13b2343c34ddd Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Wed, 19 Feb 2025 18:12:30 -0800 Subject: [PATCH 120/317] [3/n][CI] Load Quantization test models with S3 (#13570) Signed-off-by: <> Co-authored-by: EC2 Default User --- tests/conftest.py | 51 +++++++++++++++++++ .../model_loader/weight_utils.py | 4 +- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 74219e40026c..46b8dd1e1df1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,6 +57,57 @@ "ArthurZ/Ilama-3.2-1B", "llava-hf/llava-1.5-7b-hf", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "JackFram/llama-160m", + "ai21labs/Jamba-tiny-random", + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", + "nm-testing/Phi-3-mini-128k-instruct-FP8", + "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV", + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", + "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue", + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse", + "AMead10/Llama-3.2-1B-Instruct-AWQ", + "shuyuej/Llama-3.2-1B-Instruct-GPTQ", + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", + "ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", + "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8", + "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test", + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", + "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", + "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym", + "nm-testing/tinyllama-oneshot-w4a16-channel-v2", + "nm-testing/tinyllama-oneshot-w4a16-group128-v2", + "nm-testing/tinyllama-oneshot-w8a16-per-channel", + "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", + "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test", + "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor", + "nm-testing/llama2.c-stories42M-pruned2.4-compressed", ] MODEL_WEIGHTS_S3_BUCKET = "s3://vllm-ci-model-weights" diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index ac1be383c15b..18f6f40b32f0 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -27,8 +27,6 @@ from vllm.platforms import current_platform from vllm.utils import PlaceholderModule -logger = init_logger(__name__) - try: from runai_model_streamer import SafetensorsStreamer except (ImportError, OSError): @@ -39,6 +37,8 @@ SafetensorsStreamer = runai_model_streamer.placeholder_attr( "SafetensorsStreamer") +logger = init_logger(__name__) + # use system-level temp directory for file locks, so that multiple users # can share the same lock without error. # lock files in the temp directory will be automatically deleted when the From 3f8fdc1dd6f0d79cdcf62171dae33dd7370aebae Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 20 Feb 2025 10:37:55 +0800 Subject: [PATCH 121/317] [Misc] Qwen2.5 VL support LoRA (#13261) --- docs/source/models/supported_models.md | 2 +- tests/lora/conftest.py | 5 + tests/lora/test_qwen2vl.py | 176 +++++++++++++++-------- vllm/model_executor/models/qwen2_5_vl.py | 10 +- 4 files changed, 130 insertions(+), 63 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 5497b5dba76e..ae851c35e626 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -854,7 +854,7 @@ See [this page](#generative-models) for more information on how to use generativ * Qwen2.5-VL * T + IE+ + VE+ * `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. - * + * ✅︎ * ✅︎ * ✅︎ - * `UltravoxModel` diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 7baa632f5bff..47c89d5fd344 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -237,6 +237,11 @@ def qwen2vl_lora_files(): return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon") +@pytest.fixture(scope="session") +def qwen25vl_lora_files(): + return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon") + + @pytest.fixture(scope="session") def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py index a988f06ab25f..1cf1534e4036 100644 --- a/tests/lora/test_qwen2vl.py +++ b/tests/lora/test_qwen2vl.py @@ -1,83 +1,143 @@ # SPDX-License-Identifier: Apache-2.0 - -from typing import List +from dataclasses import dataclass +from typing import Dict, List, Optional import pytest +from packaging.version import Version +from transformers import __version__ as TRANSFORMERS_VERSION import vllm from vllm.assets.image import ImageAsset from vllm.lora.request import LoRARequest from vllm.platforms import current_platform -MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct" -PROMPT_TEMPLATE = ( - "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" - "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" - "What is in the image?<|im_end|>\n" - "<|im_start|>assistant\n") +@dataclass +class TestConfig: + model_path: str + lora_path: str + max_num_seqs: int = 2 + max_loras: int = 2 + max_lora_rank: int = 16 + max_model_len: int = 4096 + mm_processor_kwargs: Optional[Dict[str, int]] = None + + def __post_init__(self): + if self.mm_processor_kwargs is None: + self.mm_processor_kwargs = { + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + } + + +class Qwen2VLTester: + """Test helper for Qwen2 VL models with LoRA""" + + PROMPT_TEMPLATE = ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" + "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + "What is in the image?<|im_end|>\n" + "<|im_start|>assistant\n") + + def __init__(self, config: TestConfig): + self.config = config + self.llm = self._initialize_llm() + + def _initialize_llm(self) -> vllm.LLM: + """Initialize the LLM with given configuration""" + return vllm.LLM( + model=self.config.model_path, + max_num_seqs=self.config.max_num_seqs, + enable_lora=True, + max_loras=self.config.max_loras, + max_lora_rank=self.config.max_lora_rank, + trust_remote_code=True, + mm_processor_kwargs=self.config.mm_processor_kwargs, + max_model_len=self.config.max_model_len, + ) + + def run_test(self, + images: List[ImageAsset], + expected_outputs: List[str], + lora_id: Optional[int] = None, + temperature: float = 0, + max_tokens: int = 5) -> List[str]: + + sampling_params = vllm.SamplingParams( + temperature=temperature, + max_tokens=max_tokens, + ) + inputs = [{ + "prompt": self.PROMPT_TEMPLATE, + "multi_modal_data": { + "image": asset.pil_image + }, + } for asset in images] + + lora_request = LoRARequest(str(lora_id), lora_id, + self.config.lora_path) + outputs = self.llm.generate(inputs, + sampling_params, + lora_request=lora_request) + generated_texts = [ + output.outputs[0].text.strip() for output in outputs + ] -IMAGE_ASSETS = [ + # Validate outputs + for generated, expected in zip(generated_texts, expected_outputs): + assert expected.startswith( + generated), f"Generated text {generated} doesn't " + f"match expected pattern {expected}" + + return generated_texts + + +TEST_IMAGES = [ ImageAsset("stop_sign"), ImageAsset("cherry_blossom"), ] -# After fine-tuning with LoRA, all generated content should start begin `A`. -EXPECTED_OUTPUT = [ +EXPECTED_OUTPUTS = [ "A red stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.", # noqa: E501 "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501 ] - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: - sampling_params = vllm.SamplingParams( - temperature=0, - max_tokens=5, - ) - - inputs = [{ - "prompt": PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in IMAGE_ASSETS] - - outputs = llm.generate( - inputs, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, - ) - # Print the outputs. - generated_texts: List[str] = [] - for output in outputs: - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Generated text: {generated_text!r}") - return generated_texts +QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct" +QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct" @pytest.mark.xfail( current_platform.is_rocm(), reason="Qwen2-VL dependency xformers incompatible with ROCm") def test_qwen2vl_lora(qwen2vl_lora_files): - llm = vllm.LLM( - MODEL_PATH, - max_num_seqs=2, - enable_lora=True, - max_loras=2, - max_lora_rank=16, - trust_remote_code=True, - mm_processor_kwargs={ - "min_pixels": 28 * 28, - "max_pixels": 1280 * 28 * 28, - }, - max_model_len=4096, - ) - output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1) - for i in range(len(EXPECTED_OUTPUT)): - assert EXPECTED_OUTPUT[i].startswith(output1[i]) - - output2 = do_sample(llm, qwen2vl_lora_files, lora_id=2) - for i in range(len(EXPECTED_OUTPUT)): - assert EXPECTED_OUTPUT[i].startswith(output2[i]) + """Test Qwen 2.0 VL model with LoRA""" + config = TestConfig(model_path=QWEN2VL_MODEL_PATH, + lora_path=qwen2vl_lora_files) + tester = Qwen2VLTester(config) + + # Test with different LoRA IDs + for lora_id in [1, 2]: + tester.run_test(TEST_IMAGES, + expected_outputs=EXPECTED_OUTPUTS, + lora_id=lora_id) + + +@pytest.mark.xfail( + current_platform.is_rocm(), + reason="Qwen2.5-VL dependency xformers incompatible with ROCm", +) +@pytest.mark.skipif( + Version(TRANSFORMERS_VERSION) < Version("4.49.0"), + reason="Qwen2.5-VL require transformers version no lower than 4.49.0", +) +def test_qwen25vl_lora(qwen25vl_lora_files): + """Test Qwen 2.5 VL model with LoRA""" + config = TestConfig(model_path=QWEN25VL_MODEL_PATH, + lora_path=qwen25vl_lora_files) + tester = Qwen2VLTester(config) + + # Test with different LoRA IDs + for lora_id in [1, 2]: + tester.run_test(TEST_IMAGES, + expected_outputs=EXPECTED_OUTPUTS, + lora_id=lora_id) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 29187eb2ef9c..f16fa536791e 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -734,16 +734,17 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "up_proj", ], } - # LoRA specific attributes, TODO: double check + # LoRA specific attributes supported_lora_modules = [ + # language model "qkv_proj", "o_proj", "gate_up_proj", - "down_proj", - "gate_proj" - "up_proj", + "down_proj", # Same name with vision encoder # vision tower "qkv", + "gate_proj", + "up_proj", "attn.proj", # Distinguish patch_embed.proj "fc1", "fc2", @@ -751,6 +752,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "mlp.0", "mlp.2" ] + embedding_modules = {} embedding_padding_modules = [] From 861c978a465746e904d2e4306706bd1128128204 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Wed, 19 Feb 2025 19:56:06 -0800 Subject: [PATCH 122/317] [ci] Add AWS creds for AMD (#13572) --- .buildkite/run-amd-test.sh | 4 ++++ requirements-rocm.txt | 2 ++ 2 files changed, 6 insertions(+) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 3515ccd65667..f8bf1c87603f 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -121,6 +121,8 @@ if [[ $commands == *"--shard-id="* ]]; then --rm \ -e HIP_VISIBLE_DEVICES="${GPU}" \ -e HF_TOKEN \ + -e AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ --name "${container_name}_${GPU}" \ @@ -148,6 +150,8 @@ else --rm \ -e HIP_VISIBLE_DEVICES=0 \ -e HF_TOKEN \ + -e AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ --name "${container_name}" \ diff --git a/requirements-rocm.txt b/requirements-rocm.txt index ccc906234177..d86e039c2326 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -10,3 +10,5 @@ ray >= 2.10.0 peft pytest-asyncio tensorizer>=2.9.0 +runai-model-streamer==0.11.0 +runai-model-streamer-s3==0.11.0 From f6840386901f6f2beb88bacfc1851d92ff2bfb58 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Wed, 19 Feb 2025 22:01:02 -0600 Subject: [PATCH 123/317] [ROCm][MoE] mi300 mixtral8x7B perf for specific BS (#13577) Signed-off-by: Divakar Verma --- .../configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json | 4 ++-- .../configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json | 4 ++-- .../configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json index 66f9106bd1be..4bf775347ecc 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json @@ -45,8 +45,8 @@ }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json index ed5b655d8993..5a3f415d5414 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json @@ -45,8 +45,8 @@ }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json index 822f04e33e87..8d7b78027185 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json @@ -128,7 +128,7 @@ "num_warps": 8, "num_stages": 2, "waves_per_eu": 0, - "matrix_instr_nonkdim": 32, + "matrix_instr_nonkdim": 16, "kpack": 2 }, "512": { From 5cde704e5b106f0a84f4072587e090da8479b1a9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 20 Feb 2025 12:41:17 +0800 Subject: [PATCH 124/317] [core] add sleep and wake up endpoint and v1 support (#12987) Signed-off-by: youkaichao Signed-off-by: cennn <2523403608@qq.com> Co-authored-by: cennn <2523403608@qq.com> --- tests/basic_correctness/test_cumem.py | 12 ++++--- tests/entrypoints/openai/test_sleep.py | 32 +++++++++++++++++++ vllm/engine/async_llm_engine.py | 6 ++++ vllm/engine/multiprocessing/__init__.py | 12 ++++++- vllm/engine/multiprocessing/client.py | 15 +++++++-- vllm/engine/multiprocessing/engine.py | 15 +++++++-- vllm/engine/protocol.py | 10 ++++++ vllm/entrypoints/openai/api_server.py | 18 +++++++++++ .../openai/serving_transcription.py | 1 + vllm/v1/engine/async_llm.py | 6 ++++ vllm/v1/engine/core.py | 6 ++++ vllm/v1/engine/core_client.py | 30 +++++++++++++++++ vllm/v1/engine/llm_engine.py | 6 ++++ 13 files changed, 160 insertions(+), 9 deletions(-) create mode 100644 tests/entrypoints/openai/test_sleep.py diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 24ed5d392839..7ebccdb5caed 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -118,14 +118,16 @@ def model(x): @fork_new_process_for_each_test @pytest.mark.parametrize( - "model", + "model, use_v1", [ # sleep mode with safetensors - f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B", + (f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B", True), # sleep mode with pytorch checkpoint - "facebook/opt-125m" + ("facebook/opt-125m", False), ]) -def test_end_to_end(model): +def test_end_to_end(model: str, use_v1: bool): + import os + os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0" free, total = torch.cuda.mem_get_info() used_bytes_baseline = total - free # in case other process is running load_format = LoadFormat.AUTO @@ -152,3 +154,5 @@ def test_end_to_end(model): # cmp output assert output[0].outputs[0].text == output2[0].outputs[0].text + + del os.environ["VLLM_USE_V1"] diff --git a/tests/entrypoints/openai/test_sleep.py b/tests/entrypoints/openai/test_sleep.py new file mode 100644 index 000000000000..1caa743c4018 --- /dev/null +++ b/tests/entrypoints/openai/test_sleep.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 + +import requests + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "meta-llama/Llama-3.2-1B" + + +def test_sleep_mode(): + # dtype, max-len etc set so that this can run in CI + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enable-sleep-mode", + ] + + with RemoteOpenAIServer(MODEL_NAME, + args, + env_dict={ + "VLLM_SERVER_DEV_MODE": "1", + "CUDA_VISIBLE_DEVICES": "0" + }) as remote_server: + response = requests.post(remote_server.url_for("/sleep"), + data={"level": "1"}) + assert response.status_code == 200 + response = requests.post(remote_server.url_for("/wake_up")) + assert response.status_code == 200 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 053635a28638..93d9b74d8e1e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1187,6 +1187,12 @@ async def stop_profile(self) -> None: async def reset_prefix_cache(self) -> None: self.engine.reset_prefix_cache() + async def sleep(self, level: int = 1) -> None: + self.engine.sleep(level) + + async def wake_up(self) -> None: + self.engine.wake_up() + async def add_lora(self, lora_request: LoRARequest) -> None: self.engine.add_lora(lora_request) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 3cf1850ee65a..26dfb63c3dbf 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -127,6 +127,15 @@ class RPCResetPrefixCacheRequest(Enum): RESET_PREFIX_CACHE = 1 +class RPCSleepRequest(Enum): + SLEEP_LEVEL_1 = 1 + SLEEP_LEVEL_2 = 2 + + +class RPCWakeUpRequest(Enum): + WAKE_UP = 1 + + @dataclass class RPCLoadAdapterRequest: lora_request: LoRARequest @@ -141,7 +150,8 @@ class RPCAdapterLoadedResponse: RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, RPCUProfileRequest, RPCLoadAdapterRequest, - RPCResetPrefixCacheRequest] + RPCResetPrefixCacheRequest, RPCSleepRequest, + RPCWakeUpRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 85b5f31e3a4a..c12fe242082b 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -31,8 +31,9 @@ RPCLoadAdapterRequest, RPCProcessRequest, RPCResetPrefixCacheRequest, - RPCStartupRequest, RPCStartupResponse, - RPCUProfileRequest) + RPCSleepRequest, RPCStartupRequest, + RPCStartupResponse, + RPCUProfileRequest, RPCWakeUpRequest) from vllm.engine.protocol import EngineClient # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT @@ -685,6 +686,16 @@ async def reset_prefix_cache(self) -> None: request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE, socket=self.input_socket) + async def sleep(self, level: int = 1) -> None: + """Sleep the engine for a given level""" + return await self._send_one_way_rpc_request( + request=RPCSleepRequest(level), socket=self.input_socket) + + async def wake_up(self) -> None: + """Wake up the engine""" + return await self._send_one_way_rpc_request( + request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket) + async def add_lora(self, lora_request: LoRARequest) -> None: """Load a new LoRA adapter into the engine for future requests.""" # Uses the same I/O as generate requests diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index a0dd79586588..ce24aa21514d 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -20,8 +20,9 @@ RPCLoadAdapterRequest, RPCProcessRequest, RPCResetPrefixCacheRequest, - RPCStartupRequest, RPCStartupResponse, - RPCUProfileRequest) + RPCSleepRequest, RPCStartupRequest, + RPCStartupResponse, + RPCUProfileRequest, RPCWakeUpRequest) # yapf: enable from vllm.logger import init_logger from vllm.outputs import RequestOutput @@ -242,6 +243,10 @@ def handle_new_input(self): self._handle_load_adapter_request(request) elif isinstance(request, RPCResetPrefixCacheRequest): self.reset_prefix_cache() + elif isinstance(request, RPCSleepRequest): + self.sleep(request.value) + elif isinstance(request, RPCWakeUpRequest): + self.wake_up() else: raise ValueError("Unknown RPCRequest Type: " f"{type(request)}") @@ -369,6 +374,12 @@ def stop_profile(self) -> None: def reset_prefix_cache(self) -> bool: return self.engine.reset_prefix_cache() + def sleep(self, level: int = 1) -> None: + self.engine.sleep(level) + + def wake_up(self) -> None: + self.engine.wake_up() + def signal_handler(*_) -> None: raise KeyboardInterrupt("MQLLMEngine terminated") diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index d1112558666f..ee9accd32f21 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -278,6 +278,16 @@ async def reset_prefix_cache(self) -> None: """Reset the prefix cache""" ... + @abstractmethod + async def sleep(self, level: int = 1) -> None: + """Sleep the engine""" + ... + + @abstractmethod + async def wake_up(self) -> None: + """Wake up the engine""" + ... + @abstractmethod async def add_lora(self, lora_request: LoRARequest) -> None: """Load a new LoRA adapter into the engine for future requests.""" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0de7e2392691..f7162fadbce8 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -625,6 +625,24 @@ async def reset_prefix_cache(raw_request: Request): await engine_client(raw_request).reset_prefix_cache() return Response(status_code=200) + @router.post("/sleep") + async def sleep(raw_request: Request): + # get POST params + level = raw_request.query_params.get("level", "1") + logger.info("sleep the engine with level %s", level) + await engine_client(raw_request).sleep(int(level)) + # FIXME: in v0 with frontend multiprocessing, the sleep command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + + @router.post("/wake_up") + async def wake_up(raw_request: Request): + logger.info("wake up the engine") + await engine_client(raw_request).wake_up() + # FIXME: in v0 with frontend multiprocessing, the wake-up command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + @router.post("/invocations", dependencies=[Depends(validate_json_request)]) async def invocations(raw_request: Request): diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index da4930e0e2d8..0bedb5718a4b 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -295,6 +295,7 @@ async def create_transcription( # TODO(rob): figure out a way to pipe streaming in. # Non-streaming response. try: + assert result_generator is not None async for op in result_generator: result = op return TranscriptionResponse(text=result.outputs[0].text) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 1920dbf7a7dc..670454c283da 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -361,6 +361,12 @@ async def stop_profile(self) -> None: async def reset_prefix_cache(self) -> None: await self.engine_core.reset_prefix_cache_async() + async def sleep(self, level: int = 1) -> None: + await self.engine_core.sleep_async(level) + + async def wake_up(self) -> None: + await self.engine_core.wake_up_async() + async def add_lora(self, lora_request: LoRARequest) -> None: """Load a new LoRA adapter into the engine for future requests.""" await self.engine_core.add_lora_async(lora_request) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 66e252b7ccb0..03825d6ea430 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -213,6 +213,12 @@ def profile(self, is_start: bool = True): def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() + def sleep(self, level: int = 1): + self.model_executor.sleep(level) + + def wake_up(self): + self.model_executor.wake_up() + def add_lora(self, lora_request: LoRARequest) -> None: self.model_executor.add_lora(lora_request) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 77df9ed54095..43ba7583c662 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -81,6 +81,12 @@ def profile(self, is_start: bool = True) -> None: def reset_prefix_cache(self) -> None: raise NotImplementedError + def sleep(self, level: int = 1) -> None: + raise NotImplementedError + + def wake_up(self) -> None: + raise NotImplementedError + def abort_requests(self, request_ids: List[str]) -> None: raise NotImplementedError @@ -99,6 +105,12 @@ async def profile_async(self, is_start: bool = True) -> None: async def reset_prefix_cache_async(self) -> None: raise NotImplementedError + async def sleep_async(self, level: int = 1) -> None: + raise NotImplementedError + + async def wake_up_async(self) -> None: + raise NotImplementedError + async def abort_requests_async(self, request_ids: List[str]) -> None: raise NotImplementedError @@ -138,6 +150,12 @@ def profile(self, is_start: bool = True) -> None: def reset_prefix_cache(self) -> None: self.engine_core.reset_prefix_cache() + def sleep(self, level: int = 1) -> None: + self.engine_core.sleep(level) + + def wake_up(self) -> None: + self.engine_core.wake_up() + def add_lora(self, lora_request: LoRARequest) -> None: self.engine_core.add_lora(lora_request) @@ -307,6 +325,12 @@ def reset_prefix_cache(self) -> None: def add_lora(self, lora_request: LoRARequest) -> None: self._call_utility("add_lora", lora_request) + def sleep(self, level: int = 1) -> None: + self._call_utility("sleep", level) + + def wake_up(self) -> None: + self._call_utility("wake_up") + class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" @@ -384,5 +408,11 @@ async def profile_async(self, is_start: bool = True) -> None: async def reset_prefix_cache_async(self) -> None: await self._call_utility_async("reset_prefix_cache") + async def sleep_async(self, level: int = 1) -> None: + await self._call_utility_async("sleep", level) + + async def wake_up_async(self) -> None: + await self._call_utility_async("wake_up") + async def add_lora_async(self, lora_request: LoRARequest) -> None: await self._call_utility_async("add_lora", lora_request) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c9a4c5369dfd..6b7de4deed39 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -169,6 +169,12 @@ def stop_profile(self): def reset_prefix_cache(self): self.engine_core.reset_prefix_cache() + def sleep(self, level: int = 1): + self.engine_core.sleep(level) + + def wake_up(self): + self.engine_core.wake_up() + def get_tokenizer_group( self, group_type: Type[_G] = BaseTokenizerGroup, From 82a666b0aa9b1cbb54d377634cd0b47ecd5f7649 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Wed, 19 Feb 2025 20:46:28 -0800 Subject: [PATCH 125/317] [bugfix] spec decode worker get tp group only when initialized (#13578) --- vllm/spec_decode/spec_decode_worker.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3f381d5199d7..8af71842224b 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -12,6 +12,7 @@ from vllm.distributed.communication_op import (broadcast_tensor_dict, get_tp_group, tensor_model_parallel_gather) +from vllm.distributed.parallel_state import model_parallel_is_initialized from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.sampler import SamplerOutput @@ -366,8 +367,12 @@ def init_device(self) -> None: target_lm_head_weight) self._metrics.init_tensors(self.rank, device_type=self.device) - self.spec_decode_sampler.init_tensors(get_tp_group().local_rank, - device_type=self.device) + if model_parallel_is_initialized(): + self.spec_decode_sampler.init_tensors(get_tp_group().local_rank, + device_type=self.device) + else: + self.spec_decode_sampler.init_tensors(self.rank, + device_type=self.device) scorer_cls: Type[SpeculativeScorer] if self.disable_mqa_scorer: From 26232b71856715e94fba9d6001f03b606c459439 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 19 Feb 2025 23:24:48 -0700 Subject: [PATCH 126/317] [Misc] Warn if the vLLM version can't be retrieved (#13501) Signed-off-by: Alex-Brooks --- vllm/platforms/__init__.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 724c4357ff74..48cf8f7a323a 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -2,6 +2,7 @@ import logging import traceback +from contextlib import suppress from itertools import chain from typing import TYPE_CHECKING, Optional @@ -14,6 +15,21 @@ logger = logging.getLogger(__name__) +def vllm_version_matches_substr(substr: str) -> bool: + """ + Check to see if the vLLM version matches a substring. + """ + from importlib.metadata import PackageNotFoundError, version + try: + vllm_version = version("vllm") + except PackageNotFoundError as e: + logger.warning( + "The vLLM package was not found, so its version could not be " + "inspected. This may cause platform detection to fail.") + raise e + return substr in vllm_version + + def tpu_platform_plugin() -> Optional[str]: is_tpu = False try: @@ -33,8 +49,6 @@ def cuda_platform_plugin() -> Optional[str]: is_cuda = False try: - from importlib.metadata import version - from vllm.utils import import_pynvml pynvml = import_pynvml() pynvml.nvmlInit() @@ -45,7 +59,7 @@ def cuda_platform_plugin() -> Optional[str]: # Otherwise, vllm will always activate cuda plugin # on a GPU machine, even if in a cpu build. is_cuda = (pynvml.nvmlDeviceGetCount() > 0 - and "cpu" not in version("vllm")) + and not vllm_version_matches_substr("cpu")) finally: pynvml.nvmlShutdown() except Exception as e: @@ -113,8 +127,7 @@ def xpu_platform_plugin() -> Optional[str]: def cpu_platform_plugin() -> Optional[str]: is_cpu = False try: - from importlib.metadata import version - is_cpu = "cpu" in version("vllm") + is_cpu = vllm_version_matches_substr("cpu") if not is_cpu: import platform is_cpu = platform.machine().lower().startswith("arm") @@ -138,11 +151,8 @@ def neuron_platform_plugin() -> Optional[str]: def openvino_platform_plugin() -> Optional[str]: is_openvino = False - try: - from importlib.metadata import version - is_openvino = "openvino" in version("vllm") - except Exception: - pass + with suppress(Exception): + is_openvino = vllm_version_matches_substr("openvino") return "vllm.platforms.openvino.OpenVinoPlatform" if is_openvino else None From 4c3177b63af569d1a0834f455885c5dfda633ba5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=83?= Date: Thu, 20 Feb 2025 15:04:30 +0800 Subject: [PATCH 127/317] [Misc] add mm_processor_kwargs to extra_body for Qwen2.5-VL (#13533) --- vllm/entrypoints/openai/protocol.py | 4 ++++ vllm/entrypoints/openai/serving_engine.py | 2 ++ vllm/model_executor/models/qwen2_5_vl.py | 2 +- vllm/transformers_utils/processor.py | 12 +++++++++++- 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 2bcfdc235776..98ea6a46133f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -312,6 +312,10 @@ class ChatCompletionRequest(OpenAIBaseModel): description=("Additional kwargs to pass to the template renderer. " "Will be accessible by the chat template."), ) + mm_processor_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, description=("If specified, the output will follow the JSON schema."), diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 785117ca1d45..dfc3328677c7 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -451,6 +451,8 @@ async def _preprocess_chat( prompt_token_ids=prompt_inputs["prompt_token_ids"]) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data + if request.mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs return conversation, [request_prompt], [engine_prompt] diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index f16fa536791e..ff10fcb4315c 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -689,7 +689,7 @@ def get_hf_processor( min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, size: Optional[dict[str, int]] = None, - fps: Optional[float] = None, + fps: Optional[Union[float, List[float]]] = None, **kwargs: object, ) -> Qwen2_5_VLProcessor: if fps is not None: diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 29fab16c25c1..1d09b99d50c0 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -23,6 +23,15 @@ def __hash__(self) -> int: # type: ignore[override] return hash(frozenset(self.items())) +class HashableList(list): + """ + A list that can be hashed by lru_cache. + """ + + def __hash__(self) -> int: # type: ignore[override] + return hash(tuple(self)) + + def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs): base_kwargs = model_config.mm_processor_kwargs if base_kwargs is None: @@ -36,7 +45,8 @@ def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs): for key, value in merged_kwargs.items(): if isinstance(value, dict): merged_kwargs[key] = HashableDict(value) - + if isinstance(value, list): + merged_kwargs[key] = HashableList(value) return merged_kwargs From 1d457c76859c3038458c009bebe1e59fa4bb25e2 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Thu, 20 Feb 2025 02:05:00 -0500 Subject: [PATCH 128/317] [ROCm] MI300A compile targets deprecation (#13560) --- CMakeLists.txt | 2 +- csrc/quantization/fp8/amd/hip_float8_impl.h | 3 +-- csrc/rocm/attention.cu | 3 +-- vllm/attention/backends/rocm_flash_attn.py | 3 +-- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8e8f7adf6ea9..cd1c2c9015da 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,7 +34,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") # Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101") +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101") # # Supported/expected torch versions for CUDA/ROCm. diff --git a/csrc/quantization/fp8/amd/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h index 90251c353953..8b9cd26f2f76 100644 --- a/csrc/quantization/fp8/amd/hip_float8_impl.h +++ b/csrc/quantization/fp8/amd/hip_float8_impl.h @@ -1,7 +1,6 @@ #pragma once -#if defined(__HIPCC__) && \ - (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if defined(__HIPCC__) && defined(__gfx942__) #define __HIP__MI300__ #endif diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 366b3cdc23aa..82f7104a9e5a 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -24,8 +24,7 @@ #include "../attention/dtype_fp8.cuh" #include "../quantization/fp8/amd/quant_utils.cuh" -#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ - defined(__gfx941__) || defined(__gfx942__)) +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__)) #define __HIP__MI300_MI250__ #endif diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 02bff57a62b7..f49b37842d9b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -25,8 +25,7 @@ _PARTITION_SIZE_ROCM = 512 _GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName _ON_NAVI = "gfx1" in _GPU_ARCH -_ON_MI250_MI300 = any(arch in _GPU_ARCH - for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"]) +_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"]) class ROCmFlashAttentionBackend(AttentionBackend): From 3e32a6a4720325d18fc6e21d25c3a4f6c5724f6b Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Thu, 20 Feb 2025 02:05:13 -0500 Subject: [PATCH 129/317] [API Server] Add port number range validation (#13506) Signed-off-by: Yuan Tang --- vllm/entrypoints/api_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 00793d4b9677..4294a8aad9a5 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -145,7 +145,7 @@ async def run_server(args: Namespace, if __name__ == "__main__": parser = FlexibleArgumentParser() parser.add_argument("--host", type=str, default=None) - parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--port", type=int, default=8000, ge=1024, le=65535) parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-certfile", type=str, default=None) parser.add_argument("--ssl-ca-certs", From 4ba521b463f5620483a230aa52a2be35b69d9b19 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 20 Feb 2025 02:05:44 -0500 Subject: [PATCH 130/317] [CI/Build] Use uv in the Dockerfile (#13566) --- Dockerfile | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/Dockerfile b/Dockerfile index 26da8c0f2690..310e003d427d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,6 +27,9 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version +# Install uv for faster pip installs +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install uv # Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519 # as it was causing spam when compiling the CUTLASS kernels @@ -52,13 +55,13 @@ WORKDIR /workspace # after this step RUN --mount=type=cache,target=/root/.cache/pip \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \ + uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \ fi COPY requirements-common.txt requirements-common.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-cuda.txt + uv pip install --system -r requirements-cuda.txt # cuda arch list used by torch # can be useful for both `dev` and `test` @@ -79,7 +82,7 @@ ARG TARGETPLATFORM COPY requirements-build.txt requirements-build.txt RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-build.txt + uv pip install --system -r requirements-build.txt COPY . . ARG GIT_REPO_CHECK=0 @@ -144,7 +147,7 @@ COPY requirements-lint.txt requirements-lint.txt COPY requirements-test.txt requirements-test.txt COPY requirements-dev.txt requirements-dev.txt RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-dev.txt + uv pip install --system -r requirements-dev.txt #################### DEV IMAGE #################### #################### vLLM installation IMAGE #################### @@ -174,6 +177,9 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version +# Install uv for faster pip installs +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install uv # Workaround for https://github.com/openai/triton/issues/2507 and # https://github.com/pytorch/pytorch/issues/107960 -- hopefully @@ -187,13 +193,13 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ # after this step RUN --mount=type=cache,target=/root/.cache/pip \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \ + uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \ fi # Install vllm wheel first, so that torch etc will be installed. RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install dist/*.whl --verbose + uv pip install --system dist/*.whl --verbose # If we need to build FlashInfer wheel before its release: # $ export FLASHINFER_ENABLE_AOT=1 @@ -210,7 +216,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist RUN --mount=type=cache,target=/root/.cache/pip \ . /etc/environment && \ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post1/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl ; \ + uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post1/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl ; \ fi COPY examples examples @@ -220,7 +226,7 @@ COPY examples examples # TODO: Remove this once FlashInfer AOT wheel is fixed COPY requirements-build.txt requirements-build.txt RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-build.txt + uv pip install --system -r requirements-build.txt #################### vLLM installation IMAGE #################### @@ -233,15 +239,15 @@ ADD . /vllm-workspace/ # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-dev.txt + uv pip install --system -r requirements-dev.txt # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -e tests/vllm_test_utils + uv pip install --system -e tests/vllm_test_utils # enable fast downloads from hf (for testing) RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install hf_transfer + uv pip install --system hf_transfer ENV HF_HUB_ENABLE_HF_TRANSFER 1 # Copy in the v1 package for testing (it isn't distributed yet) @@ -262,9 +268,9 @@ FROM vllm-base AS vllm-openai-base # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ + uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ else \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ + uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ fi ENV VLLM_USAGE_SOURCE production-docker-image From a6333533c20ed912997e926931f92743bf046c27 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Thu, 20 Feb 2025 00:56:00 -0800 Subject: [PATCH 131/317] [ci] Fix spec decode test (#13600) --- tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 46b8dd1e1df1..ca268dd6657c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,7 +57,6 @@ "ArthurZ/Ilama-3.2-1B", "llava-hf/llava-1.5-7b-hf", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "JackFram/llama-160m", "ai21labs/Jamba-tiny-random", "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", "nm-testing/Phi-3-mini-128k-instruct-FP8", From 66b0fd7b06825a0471fa3fac5dd22a2dfdd553ae Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Thu, 20 Feb 2025 01:20:15 -0800 Subject: [PATCH 132/317] [2/n][ci] S3: Use full model path (#13564) Signed-off-by: <> --- tests/basic_correctness/test_cumem.py | 2 +- tests/conftest.py | 3 +-- tests/engine/test_computed_prefix_blocks.py | 3 ++- tests/engine/test_detokenization.py | 3 ++- tests/engine/test_executor.py | 12 ++++++++---- tests/engine/test_skip_tokenizer_init.py | 3 ++- tests/test_config.py | 13 +++++++------ tests/test_regression.py | 6 +++--- 8 files changed, 26 insertions(+), 19 deletions(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 7ebccdb5caed..f1148fc8e3f4 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -121,7 +121,7 @@ def model(x): "model, use_v1", [ # sleep mode with safetensors - (f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B", True), + (f"{MODEL_WEIGHTS_S3_BUCKET}/meta-llama/Llama-3.2-1B", True), # sleep mode with pytorch checkpoint ("facebook/opt-125m", False), ]) diff --git a/tests/conftest.py b/tests/conftest.py index ca268dd6657c..9304b8f17dca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -746,8 +746,7 @@ def __init__( **kwargs, ) -> None: if model_name in MODELS_ON_S3 and not load_format: - model_name = (f"s3://vllm-ci-model-weights/" - f"{model_name.split('/')[-1]}") + model_name = (f"{MODEL_WEIGHTS_S3_BUCKET}/{model_name}") load_format = LoadFormat.RUNAI_STREAMER if not load_format: load_format = LoadFormat.AUTO diff --git a/tests/engine/test_computed_prefix_blocks.py b/tests/engine/test_computed_prefix_blocks.py index 93907ecae554..51e7c8e7739d 100644 --- a/tests/engine/test_computed_prefix_blocks.py +++ b/tests/engine/test_computed_prefix_blocks.py @@ -10,7 +10,8 @@ from ..conftest import MODEL_WEIGHTS_S3_BUCKET -@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) +@pytest.mark.parametrize("model", + [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) @pytest.mark.parametrize("block_size", [16]) def test_computed_prefix_blocks(model: str, block_size: int): # This test checks if we are able to run the engine to completion diff --git a/tests/engine/test_detokenization.py b/tests/engine/test_detokenization.py index ab594aeee40d..6ae4be2e4786 100644 --- a/tests/engine/test_detokenization.py +++ b/tests/engine/test_detokenization.py @@ -9,7 +9,8 @@ from ..conftest import MODEL_WEIGHTS_S3_BUCKET -@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) +@pytest.mark.parametrize("model", + [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) def test_computed_prefix_blocks(model: str): # This test checks if the engine generates completions both with and # without optional detokenization, that detokenization includes text diff --git a/tests/engine/test_executor.py b/tests/engine/test_executor.py index 31c07e709bd9..6a86401ce5db 100644 --- a/tests/engine/test_executor.py +++ b/tests/engine/test_executor.py @@ -38,7 +38,8 @@ def collective_rpc(self, CustomUniExecutorAsync = CustomUniExecutor -@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) +@pytest.mark.parametrize("model", + [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) def test_custom_executor_type_checking(model): with pytest.raises(ValueError): engine_args = EngineArgs(model=model, @@ -51,7 +52,8 @@ def test_custom_executor_type_checking(model): AsyncLLMEngine.from_engine_args(engine_args) -@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) +@pytest.mark.parametrize("model", + [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) def test_custom_executor(model, tmp_path): cwd = os.path.abspath(".") os.chdir(tmp_path) @@ -75,7 +77,8 @@ def test_custom_executor(model, tmp_path): os.chdir(cwd) -@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) +@pytest.mark.parametrize("model", + [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) def test_custom_executor_async(model, tmp_path): cwd = os.path.abspath(".") os.chdir(tmp_path) @@ -103,7 +106,8 @@ async def t(): os.chdir(cwd) -@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) +@pytest.mark.parametrize("model", + [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) def test_respect_ray(model): # even for TP=1 and PP=1, # if users specify ray, we should use ray. diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index fee7fd3f6aad..b0930eaac17b 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -9,7 +9,8 @@ from ..conftest import MODEL_WEIGHTS_S3_BUCKET -@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"]) +@pytest.mark.parametrize("model", + [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) def test_skip_tokenizer_initialization(model: str): # This test checks if the flag skip_tokenizer_init skips the initialization # of tokenizer and detokenizer. The generated output is expected to contain diff --git a/tests/test_config.py b/tests/test_config.py index 4a1718613302..bc87e6ccdfcc 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -14,13 +14,14 @@ @pytest.mark.parametrize( ("model_id", "expected_runner_type", "expected_task"), [ - (f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2", "generate", "generate"), - (f"{MODEL_WEIGHTS_S3_BUCKET}/e5-mistral-7b-instruct", "pooling", - "embed"), - (f"{MODEL_WEIGHTS_S3_BUCKET}/Qwen2.5-1.5B-apeach", "pooling", + (f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2", "generate", + "generate"), + (f"{MODEL_WEIGHTS_S3_BUCKET}/intfloat/e5-mistral-7b-instruct", + "pooling", "embed"), + (f"{MODEL_WEIGHTS_S3_BUCKET}/jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), - (f"{MODEL_WEIGHTS_S3_BUCKET}/ms-marco-MiniLM-L-6-v2", "pooling", - "score"), + (f"{MODEL_WEIGHTS_S3_BUCKET}/cross-encoder/ms-marco-MiniLM-L-6-v2", + "pooling", "score"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), ("openai/whisper-small", "transcription", "transcription"), ], diff --git a/tests/test_regression.py b/tests/test_regression.py index e9b21e1a7232..8cecc2892b6e 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -21,7 +21,7 @@ def test_duplicated_ignored_sequence_group(): sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=256) - llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2", + llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2", load_format=LoadFormat.RUNAI_STREAMER, max_num_batched_tokens=4096, tensor_parallel_size=1) @@ -35,7 +35,7 @@ def test_max_tokens_none(): sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) - llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2", + llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2", load_format=LoadFormat.RUNAI_STREAMER, max_num_batched_tokens=4096, tensor_parallel_size=1) @@ -46,7 +46,7 @@ def test_max_tokens_none(): def test_gc(): - llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2", + llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2", load_format=LoadFormat.RUNAI_STREAMER, enforce_eager=True) del llm From a1ddfdd9bd26f6326c0e2a023ae48defaf0d7d4c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 20 Feb 2025 17:58:06 +0530 Subject: [PATCH 133/317] [Kernel] LoRA - Refactor sgmv kernels (#13110) --- vllm/lora/ops/triton_ops/kernel_utils.py | 243 +++++++++++++++++++++++ vllm/lora/ops/triton_ops/sgmv_expand.py | 117 ++++------- vllm/lora/ops/triton_ops/sgmv_shrink.py | 96 ++++----- 3 files changed, 327 insertions(+), 129 deletions(-) create mode 100644 vllm/lora/ops/triton_ops/kernel_utils.py diff --git a/vllm/lora/ops/triton_ops/kernel_utils.py b/vllm/lora/ops/triton_ops/kernel_utils.py new file mode 100644 index 000000000000..3572d3018622 --- /dev/null +++ b/vllm/lora/ops/triton_ops/kernel_utils.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Utilities for Punica kernel construction. +""" +import triton +import triton.language as tl + + +@triton.jit +def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr, + b_dtype: tl.constexpr): + """ + Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of + B (k x n), iterate, through the K dimension to compute the partial/complete + matrix block product. + If SPLIT_K == 1, the output m x n product is complete. + If SPLIT_K > 1, the thread block computes partial outputs. The partial + outputs are then atomically summed in the caller code. + Args: + a_ptr: Array of pointers, identifying rows of A + b_ptr: Array of pointers, identifying columns of B + ak_stride: K dimension stride of the A matrix + bk_stride: K dimension stride of the B matrix + K: Length of the K dimension + BLOCK_M: M dimension of the output block m x n + BLOCK_N: N dimension of the output block m x n + BLOCK_K: K dimension atom + EVEN_K: True if the blocks of A and B can be loaded without any + masking. + SPLIT_K: Parameter signifying parallelism in the K dimension. + CAST_TYPE: if True, cast the values from the A matrix to the B + matrix dtype. + b_dtype: datatype of the B matrix + """ + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] + < K - k * (BLOCK_K * SPLIT_K), + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] + < K - k * (BLOCK_K * SPLIT_K), + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(b_dtype) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * SPLIT_K * ak_stride + b_ptr += BLOCK_K * SPLIT_K * bk_stride + return accumulator + + +@triton.jit +def do_expand_kernel( + pid_n, + lora_index, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + M_LEN, + ram, # array identifying the rows of Input ptr to operate on + slice_start_loc, + # input ptr strides + input_d0_stride, + input_d1_stride, + input_d2_stride, + # lora ptr strides + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + # out ptr strides + output_d0_stride, + output_d1_stride, + # constants + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SAME_STRIDE: tl.constexpr, + SLICE_NUM: tl.constexpr, + EVEN_K: tl.constexpr, + CAST_TYPE: tl.constexpr, + ADD_INPUTS: tl.constexpr, +): + """ + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, + compute the matrix product and store in the appropriate output location. + Given that this is an expand kernel, we don't perform any split-K reduction + as the K dimension is assumed to be small. + """ + + # ls_d*_ptr can be either an integer or a pointer + if SAME_STRIDE: + # integer + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + # pointer + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + + # Identify the input_ptr and lora_ptr from slice_id. + if SLICE_NUM == 1: + cur_input_ptr = input_ptr + cur_lora_ptr = lora_ptr + else: + cur_input_ptr = input_ptr + slice_id * input_d0_stride + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = tl.arange(0, BLOCK_K) + a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride, ) + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride) + + # Compute the block matrix product. + SPLIT_K = 1 + accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, + offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, + CAST_TYPE, cur_lora_ptr.dtype.element_ty) + + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) + if SLICE_NUM == 1: + cur_slice_start = slice_start_loc + else: + cur_slice_start = tl.load(slice_start_loc + slice_id) + + # Identify the C output pointers to store the results of the accumulator. + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start + offset_cm = tl.arange(0, BLOCK_M) + c_ptr = (out_ptr + ram[:, None] * output_d0_stride + + offset_cn[None, :] * output_d1_stride) + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] + < (cur_slice_start + N)) + + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@triton.jit +def do_shrink_kernel( + pid_n, + pid_sk, + slice_id, + lora_index, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + M_LEN, + ram, + # input strides + input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr, +): + """ + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, compute the + matrix product and store in the appropriate output location. + """ + + # Identify the lora_ptr from slice_id. + if SLICE_NUM == 1: + # current lora ptr + cur_lora_ptr = lora_ptr + else: + # current lora ptr + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty)) + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + a_ptr = (input_ptr + ram[:, None] * input_d0_stride + + offset_k[None, :] * input_d1_stride) + b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + + rbn[None, :] * lora_d1_stride + + offset_k[:, None] * lora_d2_stride) + + # Compute partial/complete block matrix product. + accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k, + K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False, + cur_lora_ptr.dtype.element_ty) + + # Identify the C output pointers to store the results of the accumulator. + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_cm = tl.arange(0, BLOCK_M) + cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + + slice_id * output_d0_stride) + c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[ + None, :] * output_d2_stride + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N) + + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) diff --git a/vllm/lora/ops/triton_ops/sgmv_expand.py b/vllm/lora/ops/triton_ops/sgmv_expand.py index a8e71cacfe5a..6aa3eafaba4c 100644 --- a/vllm/lora/ops/triton_ops/sgmv_expand.py +++ b/vllm/lora/ops/triton_ops/sgmv_expand.py @@ -14,6 +14,7 @@ from vllm.utils import direct_register_custom_op +from .kernel_utils import do_expand_kernel from .utils import _get_lora_b_ptr @@ -63,86 +64,56 @@ def _sgmv_expand_kernel( curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) pid_m = pid // cta_n_num pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) - if pid_m * BLOCK_M > M: + if pid_m * BLOCK_M >= M: return - if pid_n * BLOCK_N > curr_N: + if pid_n * BLOCK_N >= curr_N: return lora_index = tl.load(lora_indices + cur_batch) if lora_index == -1: return - cur_seq_start = tl.load(b_seq_start_loc + cur_batch) - offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - offset_k = tl.arange(0, BLOCK_K) - ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N), - BLOCK_N) - # ls_d*_ptr can be either an integer or a pointer - if SAME_STRIDE: - # integer - cur_lora_d0_stride = ls_d0_ptr - cur_lora_d1_stride = ls_d1_ptr - cur_lora_d2_stride = ls_d2_ptr - else: - # pointer - cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) - cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) - cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) - if SLICE_NUM == 1: - cur_input_ptr = input_ptr - cur_lora_ptr = lora_ptr - - else: - cur_input_ptr = input_ptr + slice_id * input_d0_stride - cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(out_ptr.dtype.element_ty)) - - a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + - ram[:, None] * input_d1_stride + - offset_k[None, :] * input_d2_stride, ) - b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + - offset_k[:, None] * cur_lora_d2_stride + - rbn[None, :] * cur_lora_d1_stride) - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(tl.cdiv(K, BLOCK_K)): - if EVEN_K: - tiled_a = tl.load(a_ptr) - tiled_b = tl.load(b_ptr) - else: - tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] < K - k * BLOCK_K, - other=0) - tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] < K - k * BLOCK_K, - other=0) - if CAST_TYPE: - tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty) - accumulator += tl.dot( - tiled_a, - tiled_b, - ) - a_ptr += BLOCK_K * input_d2_stride - b_ptr += BLOCK_K * cur_lora_d2_stride - - tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) - if SLICE_NUM == 1: - cur_slice_start = slice_start_loc - else: - cur_slice_start = tl.load(slice_start_loc + slice_id) - - offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start - c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride + - offset_cn[None, :] * output_d1_stride) - M = tl.load(seq_lens + cur_batch) - c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & ( - offset_cn[None, :] < (cur_slice_start + curr_N)) - if ADD_INPUTS: - tiled_out = tl.load(c_ptr, mask=c_mask) - tiled_c += tiled_out - tl.store(c_ptr, tiled_c, mask=c_mask) + m_offset = tl.load(b_seq_start_loc + cur_batch) + + cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M)) + cta_m_offset = m_offset + (pid_m * BLOCK_M) + offset_m = tl.arange(0, BLOCK_M) + ram = cta_m_offset + tl.max_contiguous( + tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M) + do_expand_kernel( + pid_n, + lora_index, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + curr_N, + K, + cta_m_len, + ram, # array identifying the rows of Input ptr to operate on + slice_start_loc, + # input ptr strides + input_d0_stride, + input_d1_stride, + input_d2_stride, + # lora ptr strides + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + # out ptr strides + output_d0_stride, + output_d1_stride, + # constants + BLOCK_M, + BLOCK_N, + BLOCK_K, + SAME_STRIDE, + SLICE_NUM, + EVEN_K, + CAST_TYPE, + ADD_INPUTS, + ) @torch.inference_mode() diff --git a/vllm/lora/ops/triton_ops/sgmv_shrink.py b/vllm/lora/ops/triton_ops/sgmv_shrink.py index 8b26583c11c1..b8ed0b020f9a 100644 --- a/vllm/lora/ops/triton_ops/sgmv_shrink.py +++ b/vllm/lora/ops/triton_ops/sgmv_shrink.py @@ -14,6 +14,7 @@ from vllm.utils import direct_register_custom_op +from .kernel_utils import do_shrink_kernel from .utils import _get_lora_a_ptr @@ -62,67 +63,50 @@ def _sgmv_shrink_kernel( pid_sk = pid_mix % SPLIT_K M = tl.load(seq_lens + cur_batch) - if pid_m * BLOCK_M > M: + if pid_m * BLOCK_M >= M: return lora_index = tl.load(lora_indices + cur_batch) if lora_index == -1: return - cur_seq_start = tl.load(b_seq_start_loc + cur_batch) - offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) - - ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - # input ptr - a_ptr = (input_ptr + cur_seq_start * input_d0_stride + - ram[:, None] * input_d0_stride + - offset_k[None, :] * input_d1_stride) - if SLICE_NUM == 1: - # current lora ptr - cur_lora_ptr = lora_ptr - else: - # current lora ptr - cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(input_ptr.dtype.element_ty)) - - b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + - rbn[None, :] * lora_d1_stride + - offset_k[:, None] * lora_d2_stride) - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): - if EVEN_K: - tiled_a = tl.load(a_ptr) - tiled_b = tl.load(b_ptr) - else: - k_remaining = K - k * (BLOCK_K * SPLIT_K) - tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] < k_remaining, - other=0.0) - tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] < k_remaining, - other=0.0) - accumulator += tl.dot(tiled_a, tiled_b) - - a_ptr += BLOCK_K * SPLIT_K * input_d1_stride - b_ptr += BLOCK_K * SPLIT_K * lora_d2_stride - offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - - offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + - slice_id * output_d0_stride) - c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[ - None, :] * output_d2_stride - c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] - < N) - accumulator *= scaling - # handles write-back with reduction-splitting - if SPLIT_K == 1: - tl.store(c_ptr, accumulator, mask=c_mask) - else: - tl.atomic_add(c_ptr, accumulator, mask=c_mask) + m_offset = tl.load(b_seq_start_loc + cur_batch) + + cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M)) + cta_m_offset = m_offset + (pid_m * BLOCK_M) + offset_m = tl.arange(0, BLOCK_M) + ram = cta_m_offset + tl.max_contiguous( + tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M) + + do_shrink_kernel( + pid_n, + pid_sk, + slice_id, + lora_index, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + cta_m_len, + ram, + # input strides + input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + SLICE_NUM) @torch.inference_mode() From ae773112cdcbda26abcd6cb8a97b79f6f7ed8fb4 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 20 Feb 2025 12:53:51 +0000 Subject: [PATCH 134/317] Merge similar examples in `offline_inference` into single `basic` example (#12737) --- .buildkite/run-cpu-test.sh | 2 +- .buildkite/run-gh200-test.sh | 2 +- .buildkite/run-hpu-test.sh | 2 +- .buildkite/run-openvino-test.sh | 2 +- .buildkite/run-xpu-test.sh | 4 +- .buildkite/test-pipeline.yaml | 12 +-- docs/source/generate_examples.py | 4 +- .../getting_started/installation/cpu/index.md | 4 +- docs/source/getting_started/quickstart.md | 2 +- docs/source/models/generative_models.md | 4 +- docs/source/models/pooling_models.md | 6 +- examples/offline_inference/aqlm_example.py | 47 --------- examples/offline_inference/arctic.py | 28 ------ examples/offline_inference/basic/README.md | 94 ++++++++++++++++++ .../offline_inference/{ => basic}/basic.py | 0 examples/offline_inference/basic/chat.py | 98 +++++++++++++++++++ examples/offline_inference/basic/classify.py | 42 ++++++++ examples/offline_inference/basic/embed.py | 42 ++++++++ examples/offline_inference/basic/generate.py | 57 +++++++++++ examples/offline_inference/basic/score.py | 38 +++++++ .../basic_with_model_default_sampling.py | 32 ------ examples/offline_inference/chat.py | 82 ---------------- examples/offline_inference/classification.py | 30 ------ examples/offline_inference/cli.py | 82 ---------------- examples/offline_inference/cpu_offload.py | 24 ----- examples/offline_inference/embedding.py | 30 ------ examples/offline_inference/gguf_inference.py | 34 ------- examples/offline_inference/scoring.py | 25 ----- tests/plugins_tests/test_platform_plugins.py | 2 +- 29 files changed, 394 insertions(+), 437 deletions(-) delete mode 100644 examples/offline_inference/aqlm_example.py delete mode 100644 examples/offline_inference/arctic.py create mode 100644 examples/offline_inference/basic/README.md rename examples/offline_inference/{ => basic}/basic.py (100%) create mode 100644 examples/offline_inference/basic/chat.py create mode 100644 examples/offline_inference/basic/classify.py create mode 100644 examples/offline_inference/basic/embed.py create mode 100644 examples/offline_inference/basic/generate.py create mode 100644 examples/offline_inference/basic/score.py delete mode 100644 examples/offline_inference/basic_with_model_default_sampling.py delete mode 100644 examples/offline_inference/chat.py delete mode 100644 examples/offline_inference/classification.py delete mode 100644 examples/offline_inference/cli.py delete mode 100644 examples/offline_inference/cpu_offload.py delete mode 100644 examples/offline_inference/embedding.py delete mode 100644 examples/offline_inference/gguf_inference.py delete mode 100644 examples/offline_inference/scoring.py diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index e19ace782feb..2ead1f51ed81 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -30,7 +30,7 @@ function cpu_tests() { # offline inference docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c " set -e - python3 examples/offline_inference/basic.py" + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" # Run basic model test docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh index 99972afa21d1..20aca328ba13 100644 --- a/.buildkite/run-gh200-test.sh +++ b/.buildkite/run-gh200-test.sh @@ -24,5 +24,5 @@ remove_docker_container # Run the image and test offline inference docker run -e HF_TOKEN -v /root/.cache/huggingface:/root/.cache/huggingface --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c ' - python3 examples/offline_inference/cli.py --model meta-llama/Llama-3.2-1B + python3 examples/offline_inference/basic/generate.py --model meta-llama/Llama-3.2-1B ' diff --git a/.buildkite/run-hpu-test.sh b/.buildkite/run-hpu-test.sh index 1edcb1d2669e..f83eb927aae4 100644 --- a/.buildkite/run-hpu-test.sh +++ b/.buildkite/run-hpu-test.sh @@ -20,5 +20,5 @@ trap remove_docker_container_and_exit EXIT remove_docker_container # Run the image and launch offline inference -docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic.py +docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m EXITCODE=$? diff --git a/.buildkite/run-openvino-test.sh b/.buildkite/run-openvino-test.sh index 6159b21ff820..a1103bed66ec 100755 --- a/.buildkite/run-openvino-test.sh +++ b/.buildkite/run-openvino-test.sh @@ -13,4 +13,4 @@ trap remove_docker_container EXIT remove_docker_container # Run the image and launch offline inference -docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference/basic.py +docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference/basic/generate.py --model facebook/opt-125m diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh index 4d344e58db8a..d48639e5720c 100644 --- a/.buildkite/run-xpu-test.sh +++ b/.buildkite/run-xpu-test.sh @@ -14,6 +14,6 @@ remove_docker_container # Run the image and test offline inference/tensor parallel docker run --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test sh -c ' - python3 examples/offline_inference/basic.py - python3 examples/offline_inference/cli.py -tp 2 + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 ' diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9d05ff4c2cfd..66efe3ed3298 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -215,18 +215,18 @@ steps: - examples/ commands: - pip install tensorizer # for tensorizer test - - python3 offline_inference/basic.py - - python3 offline_inference/cpu_offload.py - - python3 offline_inference/chat.py + - python3 offline_inference/basic/generate.py --model facebook/opt-125m + - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 + - python3 offline_inference/basic/chat.py - python3 offline_inference/prefix_caching.py - python3 offline_inference/llm_engine_example.py - python3 offline_inference/vision_language.py - python3 offline_inference/vision_language_multi_image.py - python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/encoder_decoder.py - - python3 offline_inference/classification.py - - python3 offline_inference/embedding.py - - python3 offline_inference/scoring.py + - python3 offline_inference/basic/classify.py + - python3 offline_inference/basic/embed.py + - python3 offline_inference/basic/score.py - python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min diff --git a/docs/source/generate_examples.py b/docs/source/generate_examples.py index 9d4de18a3b79..c5f75953aaf2 100644 --- a/docs/source/generate_examples.py +++ b/docs/source/generate_examples.py @@ -147,7 +147,7 @@ def generate(self) -> str: return content content += "## Example materials\n\n" - for file in self.other_files: + for file in sorted(self.other_files): include = "include" if file.suffix == ".md" else "literalinclude" content += f":::{{admonition}} {file.relative_to(self.path)}\n" content += ":class: dropdown\n\n" @@ -194,7 +194,7 @@ def generate_examples(): path=EXAMPLE_DOC_DIR / "examples_offline_inference_index.md", title="Offline Inference", description= - "Offline inference examples demonstrate how to use vLLM in an offline setting, where the model is queried for predictions in batches.", # noqa: E501 + "Offline inference examples demonstrate how to use vLLM in an offline setting, where the model is queried for predictions in batches. We recommend starting with .", # noqa: E501 caption="Examples", ), } diff --git a/docs/source/getting_started/installation/cpu/index.md b/docs/source/getting_started/installation/cpu/index.md index d53430403583..9c5977939cc5 100644 --- a/docs/source/getting_started/installation/cpu/index.md +++ b/docs/source/getting_started/installation/cpu/index.md @@ -170,7 +170,7 @@ vLLM CPU backend supports the following vLLM features: sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library find / -name *libtcmalloc* # find the dynamic link library path export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD -python examples/offline_inference/basic.py # run vLLM +python examples/offline_inference/basic/basic.py # run vLLM ``` - When using the online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP: @@ -207,7 +207,7 @@ CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE MAXMHZ MINMHZ MHZ # On this platform, it is recommend to only bind openMP threads on logical CPU cores 0-7 or 8-15 $ export VLLM_CPU_OMP_THREADS_BIND=0-7 -$ python examples/offline_inference/basic.py +$ python examples/offline_inference/basic/basic.py ``` - If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores using `VLLM_CPU_OMP_THREADS_BIND` to avoid cross NUMA node memory access. diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md index f4682ee45a48..f3a4773f0fc6 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/source/getting_started/quickstart.md @@ -40,7 +40,7 @@ For non-CUDA platforms, please refer [here](#installation-index) for specific in ## Offline Batched Inference -With vLLM installed, you can start generating texts for list of input prompts (i.e. offline batch inferencing). See the example script: +With vLLM installed, you can start generating texts for list of input prompts (i.e. offline batch inferencing). See the example script: The first line of this example imports the classes {class}`~vllm.LLM` and {class}`~vllm.SamplingParams`: diff --git a/docs/source/models/generative_models.md b/docs/source/models/generative_models.md index 4abe6b776eea..f31e5715d175 100644 --- a/docs/source/models/generative_models.md +++ b/docs/source/models/generative_models.md @@ -46,7 +46,7 @@ for output in outputs: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` -A code example can be found here: +A code example can be found here: ### `LLM.beam_search` @@ -103,7 +103,7 @@ for output in outputs: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` -A code example can be found here: +A code example can be found here: If the model doesn't have a chat template or you want to specify another one, you can explicitly pass a chat template: diff --git a/docs/source/models/pooling_models.md b/docs/source/models/pooling_models.md index 764b67241999..8612935432b8 100644 --- a/docs/source/models/pooling_models.md +++ b/docs/source/models/pooling_models.md @@ -88,7 +88,7 @@ embeds = output.outputs.embedding print(f"Embeddings: {embeds!r} (size={len(embeds)})") ``` -A code example can be found here: +A code example can be found here: ### `LLM.classify` @@ -103,7 +103,7 @@ probs = output.outputs.probs print(f"Class Probabilities: {probs!r} (size={len(probs)})") ``` -A code example can be found here: +A code example can be found here: ### `LLM.score` @@ -125,7 +125,7 @@ score = output.outputs.score print(f"Score: {score}") ``` -A code example can be found here: +A code example can be found here: ## Online Serving diff --git a/examples/offline_inference/aqlm_example.py b/examples/offline_inference/aqlm_example.py deleted file mode 100644 index e8db3811ff17..000000000000 --- a/examples/offline_inference/aqlm_example.py +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm import LLM, SamplingParams -from vllm.utils import FlexibleArgumentParser - - -def main(): - - parser = FlexibleArgumentParser(description='AQLM examples') - - parser.add_argument('--model', - '-m', - type=str, - default=None, - help='model path, as for HF') - parser.add_argument('--choice', - '-c', - type=int, - default=0, - help='known good models by index, [0-4]') - parser.add_argument('--tensor-parallel-size', - '-t', - type=int, - default=1, - help='tensor parallel size') - - args = parser.parse_args() - - models = [ - "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", - "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf", - "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf", - "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf", - "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf", - ] - - model = LLM(args.model if args.model is not None else models[args.choice], - tensor_parallel_size=args.tensor_parallel_size) - - sampling_params = SamplingParams(max_tokens=100, temperature=0) - outputs = model.generate("Hello my name is", - sampling_params=sampling_params) - print(outputs[0].outputs[0].text) - - -if __name__ == '__main__': - main() diff --git a/examples/offline_inference/arctic.py b/examples/offline_inference/arctic.py deleted file mode 100644 index 90c88446c514..000000000000 --- a/examples/offline_inference/arctic.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Create an LLM. -llm = LLM(model="snowflake/snowflake-arctic-instruct", - quantization="deepspeedfp", - tensor_parallel_size=8, - trust_remote_code=True) -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. - -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/offline_inference/basic/README.md b/examples/offline_inference/basic/README.md new file mode 100644 index 000000000000..5cb0177b355d --- /dev/null +++ b/examples/offline_inference/basic/README.md @@ -0,0 +1,94 @@ +# Basic + +The `LLM` class provides the primary Python interface for doing offline inference, which is interacting with a model without using a separate model inference server. + +## Usage + +The first script in this example shows the most basic usage of vLLM. If you are new to Python and vLLM, you should start here. + +```bash +python examples/offline_inference/basic/basic.py +``` + +The rest of the scripts include an [argument parser](https://docs.python.org/3/library/argparse.html), which you can use to pass any arguments that are compatible with [`LLM`](https://docs.vllm.ai/en/latest/api/offline_inference/llm.html). Try running the script with `--help` for a list of all available arguments. + +```bash +python examples/offline_inference/basic/classify.py +``` + +```bash +python examples/offline_inference/basic/embed.py +``` + +```bash +python examples/offline_inference/basic/score.py +``` + +The chat and generate scripts also accept the [sampling parameters](https://docs.vllm.ai/en/latest/api/inference_params.html#sampling-parameters): `max_tokens`, `temperature`, `top_p` and `top_k`. + +```bash +python examples/offline_inference/basic/chat.py +``` + +```bash +python examples/offline_inference/basic/generate.py +``` + +## Features + +In the scripts that support passing arguments, you can experiment with the following features. + +### Default generation config + +The `--generation-config` argument specifies where the generation config will be loaded from when calling `LLM.get_default_sampling_params()`. If set to ‘auto’, the generation config will be loaded from model path. If set to a folder path, the generation config will be loaded from the specified folder path. If it is not provided, vLLM defaults will be used. + +> If max_new_tokens is specified in generation config, then it sets a server-wide limit on the number of output tokens for all requests. + +Try it yourself with the following argument: + +```bash +--generation-config auto +``` + +### Quantization + +#### AQLM + +vLLM supports models that are quantized using AQLM. + +Try one yourself by passing one of the following models to the `--model` argument: + +- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf` +- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf` +- `ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf` +- `ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf` +- `BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf` + +> Some of these models are likely to be too large for a single GPU. You can split them across multiple GPUs by setting `--tensor-parallel-size` to the number of required GPUs. + +#### GGUF + +vLLM supports models that are quantized using GGUF. + +Try one yourself by downloading a GUFF quantised model and using the following arguments: + +```python +from huggingface_hub import hf_hub_download +repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF" +filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf" +print(hf_hub_download(repo_id, filename=filename)) +``` + +```bash +--model {local-path-printed-above} --tokenizer microsoft/Phi-3-medium-4k-instruct +``` + +### CPU offload + +The `--cpu-offload-gb` argument can be seen as a virtual way to increase the GPU memory size. For example, if you have one 24 GB GPU and set this to 10, virtually you can think of it as a 34 GB GPU. Then you can load a 13B model with BF16 weight, which requires at least 26GB GPU memory. Note that this requires fast CPU-GPU interconnect, as part of the model is loaded from CPU memory to GPU memory on the fly in each model forward pass. + +Try it yourself with the following arguments: + +```bash +--model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 +``` diff --git a/examples/offline_inference/basic.py b/examples/offline_inference/basic/basic.py similarity index 100% rename from examples/offline_inference/basic.py rename to examples/offline_inference/basic/basic.py diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py new file mode 100644 index 000000000000..b2523e533a40 --- /dev/null +++ b/examples/offline_inference/basic/chat.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def main(args: dict): + # Pop arguments not used by LLM + max_tokens = args.pop("max_tokens") + temperature = args.pop("temperature") + top_p = args.pop("top_p") + top_k = args.pop("top_k") + chat_template_path = args.pop("chat_template_path") + + # Create an LLM + llm = LLM(**args) + + # Create sampling params object + sampling_params = llm.get_default_sampling_params() + if max_tokens is not None: + sampling_params.max_tokens = max_tokens + if temperature is not None: + sampling_params.temperature = temperature + if top_p is not None: + sampling_params.top_p = top_p + if top_k is not None: + sampling_params.top_k = top_k + + def print_outputs(outputs): + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Generated text: {generated_text!r}") + print("-" * 80) + + print("=" * 80) + + # In this script, we demonstrate how to pass input to the chat method: + conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": + "Write an essay about the importance of higher education.", + }, + ] + outputs = llm.chat(conversation, sampling_params, use_tqdm=False) + print_outputs(outputs) + + # You can run batch inference with llm.chat API + conversations = [conversation for _ in range(10)] + + # We turn on tqdm progress bar to verify it's indeed running batch inference + outputs = llm.chat(conversations, sampling_params, use_tqdm=True) + print_outputs(outputs) + + # A chat template can be optionally supplied. + # If not, the model will use its default chat template. + if chat_template_path is not None: + with open(chat_template_path) as f: + chat_template = f.read() + + outputs = llm.chat( + conversations, + sampling_params, + use_tqdm=False, + chat_template=chat_template, + ) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + # Add engine args + engine_group = parser.add_argument_group("Engine arguments") + EngineArgs.add_cli_args(engine_group) + engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--temperature", type=float) + sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) + # Add example params + parser.add_argument("--chat-template-path", type=str) + args: dict = vars(parser.parse_args()) + main(args) diff --git a/examples/offline_inference/basic/classify.py b/examples/offline_inference/basic/classify.py new file mode 100644 index 000000000000..4ef949b4784d --- /dev/null +++ b/examples/offline_inference/basic/classify.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def main(args: Namespace): + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Create an LLM. + # You should pass task="classify" for classification models + model = LLM(**vars(args)) + + # Generate logits. The output is a list of ClassificationRequestOutputs. + outputs = model.classify(prompts) + + # Print the outputs. + for prompt, output in zip(prompts, outputs): + probs = output.outputs.probs + probs_trimmed = ((str(probs[:16])[:-1] + + ", ...]") if len(probs) > 16 else probs) + print(f"Prompt: {prompt!r} | " + f"Class Probabilities: {probs_trimmed} (size={len(probs)})") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach", + task="classify", + enforce_eager=True) + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py new file mode 100644 index 000000000000..f1655b6dbe11 --- /dev/null +++ b/examples/offline_inference/basic/embed.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def main(args: Namespace): + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Create an LLM. + # You should pass task="embed" for embedding models + model = LLM(**vars(args)) + + # Generate embedding. The output is a list of EmbeddingRequestOutputs. + outputs = model.embed(prompts) + + # Print the outputs. + for prompt, output in zip(prompts, outputs): + embeds = output.outputs.embedding + embeds_trimmed = ((str(embeds[:16])[:-1] + + ", ...]") if len(embeds) > 16 else embeds) + print(f"Prompt: {prompt!r} | " + f"Embeddings: {embeds_trimmed} (size={len(embeds)})") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="intfloat/e5-mistral-7b-instruct", + task="embed", + enforce_eager=True) + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference/basic/generate.py b/examples/offline_inference/basic/generate.py new file mode 100644 index 000000000000..93f4f2a36fac --- /dev/null +++ b/examples/offline_inference/basic/generate.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def main(args: dict): + # Pop arguments not used by LLM + max_tokens = args.pop("max_tokens") + temperature = args.pop("temperature") + top_p = args.pop("top_p") + top_k = args.pop("top_k") + + # Create an LLM + llm = LLM(**args) + + # Create a sampling params object + sampling_params = llm.get_default_sampling_params() + if max_tokens is not None: + sampling_params.max_tokens = max_tokens + if temperature is not None: + sampling_params.temperature = temperature + if top_p is not None: + sampling_params.top_p = top_p + if top_k is not None: + sampling_params.top_k = top_k + + # Generate texts from the prompts. The output is a list of RequestOutput + # objects that contain the prompt, generated text, and other information. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + # Add engine args + engine_group = parser.add_argument_group("Engine arguments") + EngineArgs.add_cli_args(engine_group) + engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--temperature", type=float) + sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) + args: dict = vars(parser.parse_args()) + main(args) diff --git a/examples/offline_inference/basic/score.py b/examples/offline_inference/basic/score.py new file mode 100644 index 000000000000..2d21f1f0e397 --- /dev/null +++ b/examples/offline_inference/basic/score.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def main(args: Namespace): + # Sample prompts. + text_1 = "What is the capital of France?" + texts_2 = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", + ] + + # Create an LLM. + # You should pass task="score" for cross-encoder models + model = LLM(**vars(args)) + + # Generate scores. The output is a list of ScoringRequestOutputs. + outputs = model.score(text_1, texts_2) + + # Print the outputs. + for text_2, output in zip(texts_2, outputs): + score = output.outputs.score + print(f"Pair: {[text_1, text_2]!r} | Score: {score}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="BAAI/bge-reranker-v2-m3", + task="score", + enforce_eager=True) + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference/basic_with_model_default_sampling.py b/examples/offline_inference/basic_with_model_default_sampling.py deleted file mode 100644 index 80de9428f6a9..000000000000 --- a/examples/offline_inference/basic_with_model_default_sampling.py +++ /dev/null @@ -1,32 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm import LLM - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -# Create an LLM with built-in default generation config. -# The generation config is set to None by default to keep -# the behavior consistent with the previous version. -# If you want to use the default generation config from the model, -# you should set the generation_config to "auto". -llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto") - -# Load the default sampling parameters from the model. -sampling_params = llm.get_default_sampling_params() -# Modify the sampling parameters if needed. -sampling_params.temperature = 0.5 - -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/offline_inference/chat.py b/examples/offline_inference/chat.py deleted file mode 100644 index dbc710cc8a0b..000000000000 --- a/examples/offline_inference/chat.py +++ /dev/null @@ -1,82 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm import LLM, SamplingParams - -llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") -sampling_params = SamplingParams(temperature=0.5) - - -def print_outputs(outputs): - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - print("-" * 80) - - -print("=" * 80) - -# In this script, we demonstrate how to pass input to the chat method: - -conversation = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "Hello" - }, - { - "role": "assistant", - "content": "Hello! How can I assist you today?" - }, - { - "role": "user", - "content": "Write an essay about the importance of higher education.", - }, -] -outputs = llm.chat(conversation, - sampling_params=sampling_params, - use_tqdm=False) -print_outputs(outputs) - -# You can run batch inference with llm.chat API -conversation = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "Hello" - }, - { - "role": "assistant", - "content": "Hello! How can I assist you today?" - }, - { - "role": "user", - "content": "Write an essay about the importance of higher education.", - }, -] -conversations = [conversation for _ in range(10)] - -# We turn on tqdm progress bar to verify it's indeed running batch inference -outputs = llm.chat(messages=conversations, - sampling_params=sampling_params, - use_tqdm=True) -print_outputs(outputs) - -# A chat template can be optionally supplied. -# If not, the model will use its default chat template. - -# with open('template_falcon_180b.jinja', "r") as f: -# chat_template = f.read() - -# outputs = llm.chat( -# conversations, -# sampling_params=sampling_params, -# use_tqdm=False, -# chat_template=chat_template, -# ) diff --git a/examples/offline_inference/classification.py b/examples/offline_inference/classification.py deleted file mode 100644 index 4a364aeb8c47..000000000000 --- a/examples/offline_inference/classification.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm import LLM - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -# Create an LLM. -# You should pass task="classify" for classification models -model = LLM( - model="jason9693/Qwen2.5-1.5B-apeach", - task="classify", - enforce_eager=True, -) - -# Generate logits. The output is a list of ClassificationRequestOutputs. -outputs = model.classify(prompts) - -# Print the outputs. -for prompt, output in zip(prompts, outputs): - probs = output.outputs.probs - probs_trimmed = ((str(probs[:16])[:-1] + - ", ...]") if len(probs) > 16 else probs) - print(f"Prompt: {prompt!r} | " - f"Class Probabilities: {probs_trimmed} (size={len(probs)})") diff --git a/examples/offline_inference/cli.py b/examples/offline_inference/cli.py deleted file mode 100644 index bc6833b3f39c..000000000000 --- a/examples/offline_inference/cli.py +++ /dev/null @@ -1,82 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from dataclasses import asdict - -from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs -from vllm.utils import FlexibleArgumentParser - - -def get_prompts(num_prompts: int): - # The default sample prompts. - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - if num_prompts != len(prompts): - prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts] - - return prompts - - -def main(args): - # Create prompts - prompts = get_prompts(args.num_prompts) - - # Create a sampling params object. - sampling_params = SamplingParams(n=args.n, - temperature=args.temperature, - top_p=args.top_p, - top_k=args.top_k, - max_tokens=args.max_tokens) - - # Create an LLM. - # The default model is 'facebook/opt-125m' - engine_args = EngineArgs.from_cli_args(args) - llm = LLM(**asdict(engine_args)) - - # Generate texts from the prompts. - # The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -if __name__ == '__main__': - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - group = parser.add_argument_group("SamplingParams options") - group.add_argument("--num-prompts", - type=int, - default=4, - help="Number of prompts used for inference") - group.add_argument("--max-tokens", - type=int, - default=16, - help="Generated output length for sampling") - group.add_argument('--n', - type=int, - default=1, - help='Number of generated sequences per prompt') - group.add_argument('--temperature', - type=float, - default=0.8, - help='Temperature for text generation') - group.add_argument('--top-p', - type=float, - default=0.95, - help='top_p for text generation') - group.add_argument('--top-k', - type=int, - default=-1, - help='top_k for text generation') - - args = parser.parse_args() - main(args) diff --git a/examples/offline_inference/cpu_offload.py b/examples/offline_inference/cpu_offload.py deleted file mode 100644 index 5511eb738778..000000000000 --- a/examples/offline_inference/cpu_offload.py +++ /dev/null @@ -1,24 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Create an LLM. -llm = LLM(model="meta-llama/Llama-2-13b-chat-hf", cpu_offload_gb=10) -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/offline_inference/embedding.py b/examples/offline_inference/embedding.py deleted file mode 100644 index f9399329d24f..000000000000 --- a/examples/offline_inference/embedding.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm import LLM - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -# Create an LLM. -# You should pass task="embed" for embedding models -model = LLM( - model="intfloat/e5-mistral-7b-instruct", - task="embed", - enforce_eager=True, -) - -# Generate embedding. The output is a list of EmbeddingRequestOutputs. -outputs = model.embed(prompts) - -# Print the outputs. -for prompt, output in zip(prompts, outputs): - embeds = output.outputs.embedding - embeds_trimmed = ((str(embeds[:16])[:-1] + - ", ...]") if len(embeds) > 16 else embeds) - print(f"Prompt: {prompt!r} | " - f"Embeddings: {embeds_trimmed} (size={len(embeds)})") diff --git a/examples/offline_inference/gguf_inference.py b/examples/offline_inference/gguf_inference.py deleted file mode 100644 index 0447e74e0d6f..000000000000 --- a/examples/offline_inference/gguf_inference.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from huggingface_hub import hf_hub_download - -from vllm import LLM, SamplingParams - - -def run_gguf_inference(model_path, tokenizer): - # Sample prompts. - prompts = [ - "How many helicopters can a human eat in one sitting?", - "What's the future of AI?", - ] - prompts = [[{"role": "user", "content": prompt}] for prompt in prompts] - # Create a sampling params object. - sampling_params = SamplingParams(temperature=0, max_tokens=128) - - # Create an LLM. - llm = LLM(model=model_path, tokenizer=tokenizer) - - outputs = llm.chat(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -if __name__ == "__main__": - repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF" - filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf" - tokenizer = "microsoft/Phi-3-medium-4k-instruct" - model = hf_hub_download(repo_id, filename=filename) - run_gguf_inference(model, tokenizer) diff --git a/examples/offline_inference/scoring.py b/examples/offline_inference/scoring.py deleted file mode 100644 index 7daa82b82772..000000000000 --- a/examples/offline_inference/scoring.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm import LLM - -# Sample prompts. -text_1 = "What is the capital of France?" -texts_2 = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." -] - -# Create an LLM. -# You should pass task="score" for cross-encoder models -model = LLM( - model="BAAI/bge-reranker-v2-m3", - task="score", - enforce_eager=True, -) - -# Generate scores. The output is a list of ScoringRequestOutputs. -outputs = model.score(text_1, texts_2) - -# Print the outputs. -for text_2, output in zip(texts_2, outputs): - score = output.outputs.score - print(f"Pair: {[text_1, text_2]!r} | Score: {score}") diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py index ed50fe535014..3be248f5aca4 100644 --- a/tests/plugins_tests/test_platform_plugins.py +++ b/tests/plugins_tests/test_platform_plugins.py @@ -14,7 +14,7 @@ def test_platform_plugins(): import os example_file = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(current_file))), - "examples", "offline_inference/basic.py") + "examples", "offline_inference/basic/basic.py") runpy.run_path(example_file) # check if the plugin is loaded correctly From ecfd03e58e5f1504112bb2fa3e69366181179832 Mon Sep 17 00:00:00 2001 From: chenxiaobing <22113491+Chen-XiaoBing@users.noreply.github.com> Date: Thu, 20 Feb 2025 22:47:01 +0800 Subject: [PATCH 135/317] [Bugfix] Fix deepseekv3 grouped topk error (#13474) Signed-off-by: Chen-XiaoBing --- vllm/model_executor/layers/fused_moe/fused_moe.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d0b6249e1c33..543c8ced165a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -939,15 +939,17 @@ def grouped_topk(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported scoring function: {scoring_func}") + num_token = scores.shape[0] if e_score_correction_bias is not None: # Store original scores before applying correction bias. We use biased # scores for expert selection but original scores for routing weights original_scores = scores scores = scores + e_score_correction_bias.unsqueeze(0) - - num_token = scores.shape[0] - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] @@ -955,7 +957,8 @@ def grouped_topk(hidden_states: torch.Tensor, score_mask = group_mask.unsqueeze(-1).expand( num_token, num_expert_group, scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] if e_score_correction_bias is not None: topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] From 55d7ec54e54eed07b8069a07062e2817fbd0b135 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 20 Feb 2025 16:00:14 +0000 Subject: [PATCH 136/317] Update `pre-commit`'s `isort` version to remove warnings (#13614) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1967065c09b..5c4cb767c9ec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: additional_dependencies: ['tomli'] args: ['--toml', 'pyproject.toml'] - repo: https://github.com/PyCQA/isort - rev: 5.13.2 + rev: 0a0b7a830386ba6a31c2ec8316849ae4d1b8240d # 6.0.0 hooks: - id: isort exclude: 'vllm/third_party/.*' From 63645d2c7305f8f257dacf8e0afb7b158597b06c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 20 Feb 2025 09:24:31 -0800 Subject: [PATCH 137/317] [V1][Minor] Print KV cache size in token counts (#13596) Signed-off-by: Woosuk Kwon --- vllm/v1/core/kv_cache_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6dec87d4dd20..e3eb6b24c195 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -519,11 +519,13 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) num_blocks = num_gpu_blocks_override - logger.info("# GPU blocks: %d", num_blocks) - max_concurrency = (num_blocks * vllm_config.cache_config.block_size / - vllm_config.model_config.max_model_len) + num_tokens = num_blocks * vllm_config.cache_config.block_size + num_tokens_str = f"{num_tokens:,}" + logger.info("GPU KV cache size: %s tokens", num_tokens_str) + max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" + max_concurrency = num_tokens / vllm_config.model_config.max_model_len logger.info("Maximum concurrency for %s tokens per request: %.2fx", - vllm_config.model_config.max_model_len, max_concurrency) + max_model_len_str, max_concurrency) per_layer_size = page_size * num_blocks From 6b813011963413bbc99ac87f790f9014d79d4274 Mon Sep 17 00:00:00 2001 From: ajayvohra2005 Date: Thu, 20 Feb 2025 13:59:36 -0500 Subject: [PATCH 138/317] fix neuron performance issue (#13589) --- vllm/worker/neuron_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 5f0eb0019eee..95e7acd025f0 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -76,7 +76,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Set the number of GPU blocks to be the same as the maximum number of # sequences that can be processed in a single batch. This is equivalent # to schedule without PagedAttention. - num_gpu_blocks = self.scheduler_config.max_num_seqs + num_gpu_blocks = self.scheduler_config.max_num_seqs + 1 # Swap not yet supported with Neuron backend. num_cpu_blocks = 0 @@ -90,7 +90,7 @@ def initialize_cache(self, num_gpu_blocks: int, # Different values are not tested. assert num_cpu_blocks == 0 - assert num_gpu_blocks == self.scheduler_config.max_num_seqs + assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1 self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks From 07f3c1edc2424c0d42ef5c7a848b764017a58774 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 20 Feb 2025 13:07:58 -0700 Subject: [PATCH 139/317] [Frontend] Add backend-specific options for guided decoding (#13505) Signed-off-by: Joe Runde --- docs/source/features/structured_outputs.md | 2 +- ...enai_chat_completion_structured_outputs.py | 25 +++++- tests/entrypoints/llm/test_guided_generate.py | 16 ++++ .../model_executor/test_guided_processors.py | 10 +++ vllm/config.py | 5 +- vllm/engine/arg_utils.py | 7 +- .../guided_decoding/__init__.py | 81 ++++++++++--------- vllm/sampling_params.py | 19 +++++ 8 files changed, 123 insertions(+), 42 deletions(-) diff --git a/docs/source/features/structured_outputs.md b/docs/source/features/structured_outputs.md index 90c880e8cfa4..1d5aa07ab177 100644 --- a/docs/source/features/structured_outputs.md +++ b/docs/source/features/structured_outputs.md @@ -16,7 +16,7 @@ The following parameters are supported, which must be added as extra parameters: - `guided_json`: the output will follow the JSON schema. - `guided_grammar`: the output will follow the context free grammar. - `guided_whitespace_pattern`: used to override the default whitespace pattern for guided json decoding. -- `guided_decoding_backend`: used to select the guided decoding backend to use. +- `guided_decoding_backend`: used to select the guided decoding backend to use. Additional backend-specific options can be supplied in a comma separated list following a colon after the backend name. For example `"xgrammar:no-fallback"` will not allow vLLM to fallback to a different backend on error. You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server)page. diff --git a/examples/online_serving/openai_chat_completion_structured_outputs.py b/examples/online_serving/openai_chat_completion_structured_outputs.py index cddd9318000b..986ff500e586 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs.py @@ -2,7 +2,7 @@ from enum import Enum -from openai import OpenAI +from openai import BadRequestError, OpenAI from pydantic import BaseModel client = OpenAI( @@ -94,3 +94,26 @@ class CarDescription(BaseModel): extra_body={"guided_grammar": simplified_sql_grammar}, ) print(completion.choices[0].message.content) + +# Extra backend options +prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n") + +try: + # The no-fallback option forces vLLM to use xgrammar, so when it fails + # you get a 400 with the reason why + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={ + "guided_regex": "\w+@\w+\.com\n", + "stop": ["\n"], + "guided_decoding_backend": "xgrammar:no-fallback" + }, + ) +except BadRequestError as e: + print("This error is expected:", e) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 70252471cc24..252eb3fb334a 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -280,6 +280,22 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm): guided_options_request=dict(guided_regex=sample_regex)) +@pytest.mark.skip_global_cleanup +def test_disable_guided_decoding_fallback(sample_regex, llm): + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + regex=sample_regex, + backend="xgrammar:no-fallback")) + + with pytest.raises( + ValueError, + match="xgrammar does not support regex guided decoding"): + llm.generate(prompts="This should fail", + sampling_params=sampling_params, + use_tqdm=True) + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) def test_guided_json_object(llm, guided_decoding_backend: str): diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 64d0928f828f..be544698fa03 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -109,6 +109,16 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): GuidedDecodingParams(json=sample_json_schema, grammar="test grammar") +def test_guided_decoding_backend_options(): + """Test backend-specific options""" + params = GuidedDecodingParams( + backend="xgrammar:option-1,option-2,option-3") + assert params.backend_options() == ["option-1", "option-2", "option-3"] + + no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback") + assert no_fallback.no_fallback() + + def test_pickle_xgrammar_tokenizer_data(): # TODO: move to another test file for xgrammar diff --git a/vllm/config.py b/vllm/config.py index 56315aacbe51..6764694f8059 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -25,6 +25,7 @@ get_quantization_config) from vllm.model_executor.models import ModelRegistry from vllm.platforms import CpuArchEnum +from vllm.sampling_params import GuidedDecodingParams from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, @@ -2631,7 +2632,9 @@ def compute_hash(self) -> str: def __post_init__(self): valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar'] - backend = self.guided_decoding_backend + + backend = GuidedDecodingParams( + backend=self.guided_decoding_backend).backend_name if backend not in valid_guided_backends: raise ValueError(f"Invalid guided_decoding_backend '{backend}," f"must be one of {valid_guided_backends}") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 78681008b62e..5aa77a138a3e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -372,14 +372,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: '--guided-decoding-backend', type=str, default='xgrammar', - choices=['outlines', 'lm-format-enforcer', 'xgrammar'], help='Which engine will be used for guided decoding' ' (JSON schema / regex etc) by default. Currently support ' 'https://github.com/outlines-dev/outlines, ' 'https://github.com/mlc-ai/xgrammar, and ' 'https://github.com/noamgat/lm-format-enforcer.' ' Can be overridden per request via guided_decoding_backend' - ' parameter.') + ' parameter.\n' + 'Backend-sepcific options can be supplied in a comma-separated ' + 'list following a colon after the backend name. Valid backends and ' + 'all available options are: [xgrammar:no-fallback, ' + 'outlines:no-fallback, lm-format-enforcer:no-fallback]') parser.add_argument( '--logits-processor-pattern', type=nullable_str, diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 77212a1d8cf1..1522e3404182 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -22,47 +22,56 @@ def maybe_backend_fallback( guided_params: GuidedDecodingParams) -> GuidedDecodingParams: + + def fallback_or_error(guided_params: GuidedDecodingParams, message: str, + fallback: str) -> None: + """Change the backend to the specified fallback with a warning log, + or raise a ValueError if the `no-fallback` option is specified.""" + if guided_params.no_fallback(): + raise ValueError(message) + + logger.warning("%s Falling back to use %s instead.", message, fallback) + guided_params.backend = fallback + # lm-format-enforce doesn't support grammar, fallback to xgrammar - if guided_params.backend == "lm-format-enforcer": + if guided_params.backend_name == "lm-format-enforcer": if guided_params.grammar is not None: - logger.warning( - "lm-format-enforcer does not support grammar guided decoding. " - "Falling back to use xgrammar instead.") - guided_params.backend = "xgrammar" + fallback_or_error( + guided_params, + "lm-format-enforcer does not support grammar guided decoding.", + "xgrammar") # lm-format-enforcer doesn't support some JSON schema features elif (guided_params.json is not None and has_lmf_unsupported_json_features(guided_params.json)): - logger.warning( + fallback_or_error( + guided_params, "lm-format-enforcer does not support advanced JSON schema " - "features like patterns or numeric ranges. " - "Falling back to use outlines instead.") - guided_params.backend = "outlines" + "features like patterns or numeric ranges.", "outlines") - if guided_params.backend == "xgrammar": + if guided_params.backend_name == "xgrammar": from vllm.model_executor.guided_decoding.xgrammar_decoding import ( xgr_installed) # xgrammar only has x86 wheels for linux, fallback to outlines from vllm.platforms import current_platform if current_platform.get_cpu_architecture() is not CpuArchEnum.X86: - logger.warning("xgrammar is only supported on x86 CPUs. " - "Falling back to use outlines instead.") - guided_params.backend = "outlines" + fallback_or_error(guided_params, + "xgrammar is only supported on x86 CPUs.", + "outlines") # xgrammar doesn't support regex, fallback to outlines if guided_params.regex is not None: - logger.warning("xgrammar does not support regex guided decoding. " - "Falling back to use outlines instead.") - guided_params.backend = "outlines" + fallback_or_error( + guided_params, + "xgrammar does not support regex guided decoding.", "outlines") # xgrammar doesn't support some JSON schema features elif (guided_params.json is not None and has_xgrammar_unsupported_json_features(guided_params.json)): - logger.warning( + fallback_or_error( + guided_params, "xgrammar does not support advanced JSON schema features like " - "patterns or numeric ranges. " - "Falling back to use outlines instead.") - guided_params.backend = "outlines" + "enums, patterns or numeric ranges.", "outlines") # xgrammar only supports GBNF grammars, so we must convert Lark. # We must check if the grammar is likely Lark and if that @@ -72,25 +81,23 @@ def maybe_backend_fallback( try: convert_lark_to_gbnf(guided_params.grammar) except Exception: - logger.warning( + fallback_or_error( + guided_params, "xgrammar does not support Lark grammars and the " - "grammar failed to convert to GBNF. " - "Falling back to use outlines instead.") - guided_params.backend = "outlines" + "grammar failed to convert to GBNF.", "outlines") # If the xgrammar module cannot be imported successfully, # we should still allow users to use guided decoding with a fallback. elif not xgr_installed: - logger.warning("xgrammar module cannot be imported successfully. " - "Falling back to use outlines instead.") - guided_params.backend = "outlines" + fallback_or_error( + guided_params, + "xgrammar module cannot be imported successfully.", "outlines") - if (guided_params.backend == "outlines" + if (guided_params.backend_name == "outlines" and guided_params.json_object is not None): # outlines doesn't support json_object, fallback to xgrammar - logger.warning("outlines does not support json_object. " - "Falling back to use xgrammar instead.") - guided_params.backend = "xgrammar" + fallback_or_error(guided_params, + "outlines does not support json_object.", "xgrammar") return guided_params @@ -100,18 +107,18 @@ async def get_guided_decoding_logits_processor( model_config: ModelConfig) -> LogitsProcessor | None: guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead - if guided_params.backend == 'outlines': + if guided_params.backend_name == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( guided_params, tokenizer) - if guided_params.backend == 'lm-format-enforcer': + if guided_params.backend_name == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) - if guided_params.backend == 'xgrammar': + if guided_params.backend_name == 'xgrammar': from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( @@ -127,18 +134,18 @@ def get_local_guided_decoding_logits_processor( model_config: ModelConfig) -> LogitsProcessor | None: guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead - if guided_params.backend == 'outlines': + if guided_params.backend_name == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_params, tokenizer) - if guided_params.backend == 'lm-format-enforcer': + if guided_params.backend_name == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) - if guided_params.backend == 'xgrammar': + if guided_params.backend_name == 'xgrammar': from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 04ddcd73fa95..2ce87283df75 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -64,6 +64,25 @@ def from_optional( whitespace_pattern=whitespace_pattern, ) + @property + def backend_name(self) -> str: + """Return the backend name without any options. + + For example if the backend is "xgrammar:no-fallback", returns "xgrammar" + """ + return (self.backend or "").split(":")[0] + + def backend_options(self) -> List[str]: + """Return the backend options as a list of strings.""" + if not self.backend or ":" not in self.backend: + return [] + return self.backend.split(":")[1].split(",") + + def no_fallback(self) -> bool: + """Returns True if the "no-fallback" option is supplied for the guided + decoding backend""" + return "no-fallback" in self.backend_options() + def __post_init__(self): """Validate that some fields are mutually exclusive.""" guide_count = sum([ From cfb63a63ca01f1e68d5ad0107be8b933510b9bf0 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 20 Feb 2025 20:45:20 -0500 Subject: [PATCH 140/317] [Bugfix] Fix max_num_batched_tokens for MLA (#13620) Signed-off-by: mgoin --- vllm/config.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 6764694f8059..f118004b2f2f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -51,6 +51,9 @@ logger = init_logger(__name__) +# This value is chosen to have a balance between ITL and TTFT. Note it is +# not optimized for throughput. +_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 @@ -1526,15 +1529,17 @@ def __post_init__(self) -> None: # for now. Have max_num_batched_tokens set to max_model_len # so we don't reject sequences on account of a short # max_num_batched_tokens. - self.max_num_batched_tokens = max(self.max_model_len, 2048) + self.max_num_batched_tokens = max( + self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS) else: - # This value is chosen to have a balance between ITL - # and TTFT. Note it is not optimized for throughput. - self.max_num_batched_tokens = 2048 + self.max_num_batched_tokens = ( + _DEFAULT_MAX_NUM_BATCHED_TOKENS) else: - # If max_model_len is too short, use 2048 as the default value + # If max_model_len is too short, use + # _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value # for higher throughput. - self.max_num_batched_tokens = max(self.max_model_len, 2048) + self.max_num_batched_tokens = max( + self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS) if self.runner_type == "pooling": # Choose specific value for higher throughput @@ -3333,6 +3338,9 @@ def __post_init__(self): "caching to be disabled.") self.scheduler_config.enable_chunked_prefill = False self.scheduler_config.chunked_prefill_enabled = False + self.scheduler_config.max_num_batched_tokens = max( + self.scheduler_config.max_model_len, + _DEFAULT_MAX_NUM_BATCHED_TOKENS) if self.cache_config is not None: self.cache_config.enable_prefix_caching = False From 12965381989856335ccd35a540736cf12f3e5c3f Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Thu, 20 Feb 2025 17:45:45 -0800 Subject: [PATCH 141/317] [Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth (#13245) Signed-off-by: Lingfan Yu --- tests/neuron/test_block_table.py | 153 +++++++ tests/neuron/test_prefix_prefill.py | 332 +++++++------- vllm/attention/ops/nki_flash_attn.py | 627 ++++++++++++++++++--------- 3 files changed, 764 insertions(+), 348 deletions(-) create mode 100644 tests/neuron/test_block_table.py diff --git a/tests/neuron/test_block_table.py b/tests/neuron/test_block_table.py new file mode 100644 index 000000000000..30dcdd573edf --- /dev/null +++ b/tests/neuron/test_block_table.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import neuronxcc.nki.language as nl +import pytest +import torch +import torch.nn.functional as F +from neuronxcc import nki + +from vllm.attention.ops.nki_flash_attn import ( + load_block_tables, transform_block_tables_for_indirect_load) + + +def is_power_of_2(n): + return n > 0 and (n & (n - 1) == 0) + + +def nki_load_and_transform_block_tables( + block_tables, + num_tiles, + num_blocks_per_tile, + num_head, + head_id, + block_size_tiling_factor, +): + assert is_power_of_2( + num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2" + block_tables_sbuf = load_block_tables(block_tables, num_tiles, + num_blocks_per_tile) + + # we need to pass an Index as head_id + head_id = nl.arange(1)[None, :] + head_id + + block_tables_transposed = transform_block_tables_for_indirect_load( + block_tables_sbuf, block_size_tiling_factor, num_head, head_id) + B_P_SIZE = 128 + assert block_tables_transposed.shape[1] == B_P_SIZE + + out = nl.ndarray( + block_tables_transposed.shape, + dtype=nl.int32, + buffer=nl.shared_hbm, + ) + for i in nl.affine_range(block_tables_transposed.shape[0]): + nl.store(dst=out[i], value=block_tables_transposed[i]) + return out + + +def ref_block_tables_transform( + block_tables, + num_tiles, + num_blocks_per_tile, + num_head, + head_id, + block_size_tiling_factor, +): + assert block_tables.numel() == num_tiles * num_blocks_per_tile + block_tables = block_tables.view(num_tiles, num_blocks_per_tile) + B_F_SIZE = 128 + num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE + block_tables = F.pad( + block_tables, + (0, 0, 0, num_tiles_padded - num_tiles), + "constant", + 0, + ) + + block_tables = block_tables * num_head + head_id + block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1) + offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1) + block_tables = block_tables * block_size_tiling_factor + offset + block_tables_transposed = block_tables.view(num_tiles_padded, -1).t() + + num_blocks_per_tile = block_tables_transposed.shape[0] + assert num_blocks_per_tile % B_F_SIZE == 0 + return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE, + B_F_SIZE, num_tiles_padded) + + +@pytest.mark.parametrize( + "q_head_per_kv_head,head_id", + [ + (1, 0), + (3, 1), + ], +) +@pytest.mark.parametrize( + "num_tiles,num_blocks_per_tile", + [ + (1, 1), + (13, 16), + (17, 128), + (35, 512), + (128, 128), + (130, 64), + (280, 256), + (315, 1), + ], +) +@torch.inference_mode() +def test_load_and_transform_block_tables( + num_tiles, + num_blocks_per_tile, + q_head_per_kv_head, + head_id, +) -> None: + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + + compiler_flags = [ + "-O1", + "--retry_failed_compilation", + ] + compiler_flags_str = " ".join(compiler_flags) + os.environ["NEURON_CC_FLAGS"] = compiler_flags_str + + torch.manual_seed(10000) + torch.set_printoptions(sci_mode=False) + + # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient + B_P_SIZE = 128 + if num_blocks_per_tile < B_P_SIZE: + assert B_P_SIZE % num_blocks_per_tile == 0 + block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile + else: + block_size_tiling_factor = 1 + max_num_blocks = 100000 + block_tables = torch.randint( + 0, + max_num_blocks, + (num_tiles * num_blocks_per_tile, ), + dtype=torch.int32, + ) + nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1]( + block_tables.to(device=device), + num_tiles, + num_blocks_per_tile, + q_head_per_kv_head, + head_id, + block_size_tiling_factor, + ).cpu() + ref_out = ref_block_tables_transform( + block_tables, + num_tiles, + num_blocks_per_tile, + q_head_per_kv_head, + head_id, + block_size_tiling_factor, + ) + assert (nki_out.shape == ref_out.shape + ), f"{nki_out.shape=} != {ref_out.shape=}" + assert torch.all(nki_out == ref_out) diff --git a/tests/neuron/test_prefix_prefill.py b/tests/neuron/test_prefix_prefill.py index 04d1bd3f0eb0..347a139f39b4 100644 --- a/tests/neuron/test_prefix_prefill.py +++ b/tests/neuron/test_prefix_prefill.py @@ -107,7 +107,7 @@ def ref_masked_attention( masked_score, dim=-1, return_max_reduce=True) else: norm_score = ref_softmax(masked_score, dim=-1) - out = torch.einsum("hqk,khd->qhd", norm_score, value) + out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value) if return_max_reduce: return ( out, @@ -118,7 +118,7 @@ def ref_masked_attention( scaled_qk, ) else: - return out + return (out, ) def ref_context_attention( @@ -128,8 +128,6 @@ def ref_context_attention( query_lens, seq_lens, head_size, - num_kv_heads, - num_heads, num_queries_per_kv, return_max_reduce=False, ): @@ -146,18 +144,19 @@ def ref_context_attention( attn_mask = torch.logical_not(attn_mask) attn_mask = attn_mask.float() * -30000 - output, cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = ( - ref_masked_attention( - query, - key, - value, - scale, - attn_mask, - return_max_reduce=return_max_reduce, - )) + output, *debug_tensors = ref_masked_attention( + query, + key, + value, + scale, + attn_mask, + return_max_reduce=return_max_reduce, + ) output = output.unsqueeze(1) if return_max_reduce: + cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = ( + debug_tensors) return ( output, cached_max, @@ -170,65 +169,22 @@ def ref_context_attention( return output -@pytest.mark.parametrize( - "block_size, large_tile_size", - [ - (32, 2048), # 64 blocks - (32, 4096), # 128 blocks - (32, 8192), # 256 blocks - (64, 8192), # 128 blocks - ], -) -@pytest.mark.parametrize( - "num_heads,num_queries_per_kv,head_size,mixed_precision", - [ - (4, 2, 8, False), - (4, 2, 8, True), - (32, 8, 64, True), - (16, 2, 128, True), - ], -) -@torch.inference_mode() -def test_contexted_kv_attention( - num_heads: int, - num_queries_per_kv: int, - head_size: int, - block_size: int, - large_tile_size, - mixed_precision: bool, -) -> None: - import os - - import torch_xla.core.xla_model as xm - - from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc - - assert large_tile_size % block_size == 0 - - device = xm.xla_device() - - compiler_flags = [ - "--model-type=transformer -O1", - "--internal-hlo2tensorizer-options='--verify-hlo'", - "--retry_failed_compilation", - ] - compiler_flags_str = " ".join(compiler_flags) - os.environ["NEURON_CC_FLAGS"] = compiler_flags_str - - torch.manual_seed(0) - torch.set_printoptions(sci_mode=False) - - min_ctx_len = 32 - max_ctx_len = 1024 - min_query_len = 16 - max_query_len = 512 - prefill_batch_size = 4 - decode_batch_size = 12 +def sample_inputs( + prefill_batch_size, + decode_batch_size, + min_query_len, + max_query_len, + min_ctx_len, + max_ctx_len, + block_size, + num_heads, + num_kv_heads, + head_size, + dtype, +): batch_size = prefill_batch_size + decode_batch_size max_model_len = (max_query_len + max_ctx_len) * 4 - max_block_per_request = max_model_len // block_size - dtype = torch.float32 cache_size = (batch_size * max_block_per_request) + 2 prefill_ctx_lens = torch.randint(min_ctx_len, max_ctx_len + 1, (prefill_batch_size, ), @@ -244,7 +200,6 @@ def test_contexted_kv_attention( dtype=torch.long, ).tolist() + [1 for _ in range(decode_batch_size)] seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] - num_kv_heads = num_heads // num_queries_per_kv num_tokens = sum(query_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) @@ -304,47 +259,139 @@ def test_contexted_kv_attention( cur_ctx += block_size block_id += 1 + return ( + query, + k, + v, + k_cache, + v_cache, + block_table, + key, + value, + query_lens, + seq_lens, + ) + + +def get_active_block_tables(block_tables, query_lens, seq_lens, block_size, + num_blocks): + context_lens = seq_lens - query_lens + blocks_per_seq = (context_lens + block_size - 1) // block_size + num_seqs = len(seq_lens) + active_blocks: list[int] = [] + for seq_id in range(num_seqs): + active_blocks = ( + active_blocks + + block_tables[seq_id, :blocks_per_seq[seq_id]].tolist()) + return F.pad( + torch.tensor(active_blocks, dtype=torch.int32), + (0, num_blocks - len(active_blocks)), + "constant", + 0, + ) + + +@pytest.mark.parametrize( + "prefill_batch_size,decode_batch_size,block_size,large_tile_size", + [ + (1, 199, 1, 512), # 512 blocks + (4, 12, 256, 2048), # 128 blocks + (4, 12, 16, 2048), # 128 blocks + (4, 12, 4, 1024), # 256 blocks + (4, 12, 32, 2048), # 64 blocks + (4, 12, 32, 4096), # 128 blocks + (4, 12, 32, 8192), # 256 blocks + (4, 12, 64, 8192), # 128 blocks + ], +) +@pytest.mark.parametrize( + "num_heads,num_queries_per_kv,head_size", + [ + (4, 2, 8), + (32, 8, 64), + (4, 4, 128), + (8, 1, 32), + ], +) +@pytest.mark.parametrize("mixed_precision", [True, False]) +@torch.inference_mode() +def test_contexted_kv_attention( + prefill_batch_size: int, + decode_batch_size: int, + num_heads: int, + num_queries_per_kv: int, + head_size: int, + block_size: int, + large_tile_size, + mixed_precision: bool, +) -> None: + import os + + import torch_xla.core.xla_model as xm + + from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc, + reorder_context_mask) + + assert large_tile_size % block_size == 0 + + device = xm.xla_device() + + compiler_flags = [ + "-O1", + "--retry_failed_compilation", + ] + compiler_flags_str = " ".join(compiler_flags) + os.environ["NEURON_CC_FLAGS"] = compiler_flags_str + + torch.manual_seed(0) + torch.set_printoptions(sci_mode=False) + dtype = torch.float32 + + min_ctx_len = 32 + max_ctx_len = 1024 + min_query_len = 16 + max_query_len = 512 + num_kv_heads = num_heads // num_queries_per_kv ( - output_ref, - cached_max, - cached_sum_reciprocal, - lse, - masked_score, - scaled_qk, - ) = ref_context_attention( + query, + k_active, + v_active, + k_cache, + v_cache, + block_table, + key, + value, + query_lens, + seq_lens, + ) = sample_inputs( + prefill_batch_size=prefill_batch_size, + decode_batch_size=decode_batch_size, + min_query_len=min_query_len, + max_query_len=max_query_len, + min_ctx_len=min_ctx_len, + max_ctx_len=max_ctx_len, + block_size=block_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + ) + + output_ref = ref_context_attention( query, key, value, query_lens, seq_lens, head_size, - num_kv_heads, - num_heads, num_queries_per_kv, - return_max_reduce=True, + return_max_reduce=False, ) # build neuron program - return_debug_tensors = False B_P_SIZE = 128 - LARGE_TILE_SZ = large_tile_size - - def get_active_block_tables(block_tables, query_lens, seq_lens, block_size, - num_blocks): - context_lens = seq_lens - query_lens - blocks_per_seq = (context_lens + block_size - 1) // block_size - num_seqs = len(seq_lens) - active_blocks: list[int] = [] - for seq_id in range(num_seqs): - active_blocks = ( - active_blocks + - block_tables[seq_id, :blocks_per_seq[seq_id]].tolist()) - return F.pad( - torch.tensor(active_blocks), - (0, num_blocks - len(active_blocks)), - "constant", - 0, - ) + assert (large_tile_size >= B_P_SIZE + ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}" def ceil_div(a, b): return (a + b - 1) // b @@ -357,32 +404,27 @@ def pad_to_next_power_of_2(a): return 2**int(a - 1).bit_length() # calculate input shapes - max_num_queries = pad_to_multiple(sum(query_lens), block_size) - max_num_queries = pad_to_next_power_of_2(max_num_queries) - head_size_padded = B_P_SIZE - assert head_size_padded >= head_size + max_num_queries = pad_to_next_power_of_2(sum(query_lens)) context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens) num_active_blocks = ceil_div(context_lens, block_size).sum().item() num_active_blocks = pad_to_multiple(num_active_blocks, - LARGE_TILE_SZ // block_size) + large_tile_size // block_size) context_kv_len = num_active_blocks * block_size assert (context_kv_len % - LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}" + large_tile_size == 0), f"invalid context_kv_len={context_kv_len}" # pad QKV tensors pad_dims = ( 0, - head_size_padded - query.shape[2], + 0, 0, 0, 0, max_num_queries - query.shape[0], ) query = F.pad(query, pad_dims, "constant", 0) - k = F.pad(k, pad_dims, "constant", 0) - v = F.pad(v, pad_dims, "constant", 0) - k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0) - v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0) + k = F.pad(k_active, pad_dims, "constant", 0) + v = F.pad(v_active, pad_dims, "constant", 0) # permute QKV tensors # query: (1, n_heads, d, seq_q) @@ -391,6 +433,8 @@ def pad_to_next_power_of_2(a): query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous() k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous() v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous() + k_cache = k_cache.permute(0, 2, 1, 3).contiguous() + v_cache = v_cache.permute(0, 2, 1, 3).contiguous() # transform block table active_block_table = get_active_block_tables( @@ -405,33 +449,31 @@ def pad_to_next_power_of_2(a): prior_mask, active_mask = ( BlockDiagonalCausalFromBottomRightMask.from_seqlens( query_lens, seq_lens, block_size=block_size)) - attn_mask = torch.concat( - [ - F.pad( - prior_mask, - ( - 0, - context_kv_len - prior_mask.shape[1], - 0, - max_num_queries - prior_mask.shape[0], - ), - "constant", - 0, - ).bool(), - F.pad( - active_mask, - ( - 0, - max_num_queries - active_mask.shape[1], - 0, - max_num_queries - active_mask.shape[0], - ), - "constant", - 0, - ).bool(), - ], - dim=1, - ) + prior_mask_padded = F.pad( + prior_mask, + ( + 0, + context_kv_len - prior_mask.shape[1], + 0, + max_num_queries - prior_mask.shape[0], + ), + "constant", + 0, + ).bool() + active_mask_padded = F.pad( + active_mask, + ( + 0, + max_num_queries - active_mask.shape[1], + 0, + max_num_queries - active_mask.shape[0], + ), + "constant", + 0, + ).bool() + attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1) + + attn_mask = reorder_context_mask(attn_mask, large_tile_size, block_size) input_args = ( query.to(device=device), @@ -439,29 +481,21 @@ def pad_to_next_power_of_2(a): v.to(device=device), k_cache.to(device=device), v_cache.to(device=device), - active_block_table.to(torch.int32).to(device=device), + active_block_table.to(device=device), attn_mask.to(device=device), ) input_kwargs = dict( n_kv_head=num_kv_heads, head_size=head_size, mixed_precision=mixed_precision, - LARGE_TILE_SZ=LARGE_TILE_SZ, - return_debug_tensors=return_debug_tensors, + LARGE_TILE_SZ=large_tile_size, ) - if return_debug_tensors: - output_nki, *debug_tensors = flash_attn_varlen_nkifunc( - *input_args, **input_kwargs) - else: - output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs) - debug_tensors = [] - - debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors] + output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs) num_actual_tokens = sum(query_lens) # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d) - output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size] + output_nki = output_nki.cpu().permute(0, 2, 1, 3) output_nki = output_nki[0, :num_actual_tokens, :, :] output_ref_padded = F.pad( output_ref, diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py index 5e2a1f7e66d1..20f9dcd163fe 100644 --- a/vllm/attention/ops/nki_flash_attn.py +++ b/vllm/attention/ops/nki_flash_attn.py @@ -1,27 +1,203 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass - import neuronxcc.nki.isa as nisa import neuronxcc.nki.language as nl import numpy as np +import torch from neuronxcc import nki from neuronxcc.nki.language import par_dim -@dataclass(frozen=True) -class FlashConfig: +def ceil_div(a, b): + return (a + b - 1) // b + + +def is_power_of_2(x): + return x > 0 and (x & (x - 1)) == 0 + + +@nki.jit +def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile): + """ + Load block tables from HBM into SRAM + + `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`. + In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension. + """ + B_P_SIZE = 128 + + # reshape as `(num_tiles, num_blocks_per_tile)` + assert len(block_tables_hbm.shape) == 1 + (num_total_blocks, ) = block_tables_hbm.shape + assert num_blocks_per_tile * num_tiles == num_total_blocks + block_tables_hbm = block_tables_hbm.reshape( + (num_tiles, num_blocks_per_tile)) + + block_tables_sbuf = nl.zeros( + (ceil_div(num_tiles, + B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile), + dtype=nl.int32, + ) + for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(num_blocks_per_tile)[None, :] + block_tables_sbuf[i, i_p, i_f] = nl.load( + block_tables_hbm[i_p + i * B_P_SIZE, i_f], + dtype=nl.int32, + mask=(i_p + i * B_P_SIZE < num_tiles), + ) + return block_tables_sbuf + + +@nki.jit +def transform_block_tables_for_indirect_load( + block_tables, + block_size_tiling_factor, + num_head, + head_id, +): """ - Config class for flash attention with default values + This function does two things: + 1. calculate new `block_tables` for a `head_id` after flattening + `num_block`, `num_head`, and `block_size_tiling_factor` dimensions + 2. transpose the result so that `block_table` for each tile is mapped to + SBUF Partition dimension for vectorized DMA + + Tiling trick to further improve DMA performance: + Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M + blocks of a given `head_id` from HBM, the load `cache[block_tables, + head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not + fully utilize hardware parallelization. The solution is to tile `block_size` + into `(block_size_tiling_factor, tiled_block_size)` s.t. `M * + block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape + `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`. + + Note: + We don't further tile D dimension as small DMA size also hurts performance. """ + B_P_SIZE = 128 + num_partitions, num_tiles_per_partition, num_blocks_per_tile = ( + block_tables.shape) + assert num_tiles_per_partition == B_P_SIZE + assert is_power_of_2( + num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2" + + num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE) + block_tables_transposed = nl.ndarray( + ( + num_loads, + par_dim(B_P_SIZE), + num_partitions * num_tiles_per_partition, + ), + dtype=nl.int32, + ) - seq_tile_size: int = 2048 - should_transpose_v: bool = False + # prepare iota ahead of time to avoid repeatedly using Gpsimd + if num_head > 1: + head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1)) + head_id = nl.transpose( + head_id.broadcast_to((1, num_tiles_per_partition))) + if num_blocks_per_tile > 1: + head_id = head_id.broadcast_to( + (num_tiles_per_partition, num_blocks_per_tile)) + + if block_size_tiling_factor > 1: + broadcast_shape = ( + num_tiles_per_partition, + num_blocks_per_tile, + block_size_tiling_factor, + ) + offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :], + dtype=nl.int32).broadcast_to(broadcast_shape) + + for partition_id in nl.affine_range(num_partitions): + block_tables_partition = block_tables[partition_id] + if num_head > 1: + # fuse num_block and num_head dimension + block_tables_partition = block_tables_partition * num_head + head_id + + if block_size_tiling_factor > 1: + # need to apply block size tiling trick + assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE + block_tables_partition = ((block_tables_partition * + block_size_tiling_factor).reshape( + (num_tiles_per_partition, + num_blocks_per_tile, + 1)).broadcast_to(broadcast_shape)) + new_block_tables = block_tables_partition + offset + new_block_tables = new_block_tables.reshape( + (num_tiles_per_partition, B_P_SIZE)) + else: + new_block_tables = block_tables_partition - __annotations__ = { - "seq_tile_size": int, - "should_transpose_v": bool, - } + # transpose the block table so that it can be used by vector DGE + for i in nl.affine_range(num_loads): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = (partition_id * num_tiles_per_partition + + nl.arange(num_tiles_per_partition)[None, :]) + block_tables_transposed[i, i_p, i_f] = nl.transpose( + new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)]) + return block_tables_transposed + + +@nki.jit +def load_kv_tile_from_cache( + cur_k_tile, + cur_v_tile, + key_cache, + value_cache, + block_tables, + large_k_tile_idx, + num_blocks_per_large_tile, + tiled_block_size, + B_P_SIZE, + B_D_SIZE, +): + """ + Load KV cache and transform Key and Value into layout required by Matmul + + Vectorized DMA Load layout: + Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) + + Layout used by attention matmuls: + Key: (par_dim(B_D_SIZE), seqlen_kv) + Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE) + equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) + """ + # load key cache + num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) + for load_idx in nl.affine_range(num_loads): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] + loaded = nl.load(key_cache[block_tables[load_idx, i_p, + large_k_tile_idx], i_f]) + if cur_k_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_k_tile.dtype) + # Transpose SBUF tensor using PE + for tb_i in nl.affine_range(tiled_block_size): + cur_k_tile[ + :, + nl.ds( + load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE, + B_P_SIZE, + ), + ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)]) + + # load value cache + for load_idx in nl.affine_range(num_loads): + loaded = nl.load(value_cache[block_tables[load_idx, i_p, + large_k_tile_idx], i_f]) + if cur_v_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] + cur_v_tile[ + :, + nl.ds( + load_idx * tiled_block_size * B_D_SIZE, + tiled_block_size * B_D_SIZE, + ), + ] = loaded @nki.jit @@ -62,13 +238,13 @@ def _flash_attention_core( o_buffer, l_buffer, m_buffer, - q_tile_idx, kernel_dtype, acc_type, - flash_config: FlashConfig, - use_causal_mask, tile_mask, + use_causal_mask, + q_tile_idx=None, initialize=False, + LARGE_TILE_SZ=2048, B_P_SIZE=128, B_F_SIZE=512, B_D_SIZE=128, @@ -77,19 +253,19 @@ def _flash_attention_core( """ The flash attention core function to calculate self attention between a tile of q and a block of K and V. - The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF - already. The block size of K and V - is defined in the seq_tile_size of the flash_config. The results are stored - in the following three buffers + The q_local_tile has (B_P_SIZE, B_D_SIZE) + The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will + be split into size B_F_SIZE tiles + + The results are stored in the following three buffers o_buffer: (B_P_SIZE, d) l_buffer: (B_P_SIZE, 1) m_buffer: (B_P_SIZE, 1) + + All IO buffers are in SBUF. """ - LARGE_TILE_SZ = flash_config.seq_tile_size num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE - # mask are used to only apply computation to the lower half of the matrix, - # which reduce the arithmetic intensity by half qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), buffer=nl.sbuf, dtype=acc_type) @@ -99,6 +275,8 @@ def _flash_attention_core( k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) if use_causal_mask: + # mask are used to only apply computation to the lower half of the + # matrix, which reduce the arithmetic intensity by up to 50% multiplication_required_selection = (q_tile_idx * B_P_SIZE >= k_i * B_F_SIZE) else: @@ -165,7 +343,9 @@ def _flash_attention_core( REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) p_partial_sum = nl.ndarray( - (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type) + (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), + dtype=acc_type, + ) for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) @@ -194,13 +374,15 @@ def _flash_attention_core( B_F_SIZE=B_F_SIZE, ) - pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), - dtype=np.float32, - buffer=nl.psum) + pv_psum = nl.zeros( + (par_dim(B_P_SIZE), B_D_SIZE), + dtype=np.float32, + buffer=nl.psum, + ) for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): pv_psum[:, :] += nl.matmul( p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], - v[k_i, :, :], + v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)], transpose_x=True, ) # (128, 128) (p(Br), d) @@ -219,44 +401,16 @@ def _flash_attention_core( @nki.jit -def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config): - LARGE_TILE_SZ = config.seq_tile_size +def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ): B_P_SIZE = 128 - - if not config.should_transpose_v: - cur_v_tile[v_i, :, :] = nl.load( - v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :], - dtype=cur_v_tile.dtype, - ) - return - - if nisa.get_nc_version() == nisa.nc_version.gen3: - cur_v_tile_transposed = nisa.dma_transpose( - v_hbm_tile[:, - nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)]) - cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed, - dtype=cur_v_tile.dtype) - return - - cur_v_tile[v_i, :, :] = nl.load_transpose2d( - v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)], - dtype=cur_v_tile.dtype, - ) - - -@nki.jit -def load_block_tables(block_tables_hbm, num_tiles): - (num_blocks, ) = block_tables_hbm.shape - assert num_blocks % num_tiles == 0 - num_blocks_per_tile = num_blocks // num_tiles - block_tables_hbm = block_tables_hbm.reshape( - (num_tiles, num_blocks_per_tile)) - block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32) - return block_tables_buffer - - -def is_power_of_2(x): - return x > 0 and (x & (x - 1)) == 0 + B_D_SIZE = v_hbm_tile.shape[-1] + loaded = nl.load(v_hbm_tile[ + nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), + :, + ]) + if cur_v_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) + cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded @nki.jit @@ -270,24 +424,21 @@ def flash_paged_attention( mask, softmax_scale=None, mixed_precision=True, - config=None, + LARGE_TILE_SZ=2048, return_debug_tensors=False, ): """ Flash PagedAttention Forward Kernel. - - PagedAttention Paper: https://arxiv.org/abs/2309.06180 - - Chunked Prefill Paper: https://arxiv.org/abs/2403.02310 IO tensor layouts: - query: shape (1, n_heads, d, seq_q) - key: shape (1, n_kv_heads, d, seq_k) - value: shape (1, n_kv_heads, seq_v, d) - - key_cache: (num_blocks, block_size, n_kv_heads, d) - - value_cache: (num_blocks, block_size, n_kv_heads, d) + - key_cache: (num_blocks, n_kv_heads, block_size, d) + - value_cache: (num_blocks, n_kv_heads, block_size, d) - block_tables: (num_active_blocks, ) - - mask: (seq_q, num_active_blocks * block_size) + - mask: (seq_q, num_active_blocks * block_size + seq_q) - o: shape (1, n_heads, seq_q, d) - - l_m: shape (1, n_heads, seq_q, 2) - This kernel requires seq_k == seq_v - We use continuous batching by default, so the batch dimension is @@ -306,11 +457,8 @@ def flash_paged_attention( - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` - mixed_precision: flag to set non-matmul ops in fp32 precision, default is set to `true`, if false, we use same precision as input types - - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig` - with Performance config parameters for flash attention with default - values - seq_tile_size: `default=2048`, size of the kv tile size for attention - computation reduction + - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention + computation reduction GQA support Notes: the spmd kernel for launching kernel should be on kv_heads instead of @@ -322,31 +470,65 @@ def flash_paged_attention( GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] usage: `flash_fwd[b, kv_h](q, k, v, ...)` """ - config = config or FlashConfig() B_F_SIZE = 512 B_P_SIZE = 128 b, h, d, seqlen_q = query.shape B_D_SIZE = d - LARGE_TILE_SZ = config.seq_tile_size n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine - num_blocks, block_size, k_h, _ = key_cache.shape + num_blocks, k_h, block_size, _ = key_cache.shape q_h_per_k_h = h // k_h - assert tuple(key_cache.shape) == ( - num_blocks, - block_size, + assert b == 1, f"invalid batch size {b=}" + assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}" + cache_shape = (num_blocks, k_h, block_size, d) + assert (tuple(key_cache.shape) == cache_shape + ), f"{key_cache.shape=} mismatch, expect {cache_shape}" + assert (tuple(value_cache.shape) == cache_shape + ), f"{value_cache.shape=} mismatch, expect {cache_shape}" + assert key is None or tuple(key.shape) == ( + 1, k_h, d, - ), "Input shape mismatch!" - assert tuple(value_cache.shape) == ( - num_blocks, - block_size, + seqlen_q, + ), f"key shape {key.shape} mismatch!" + assert value is None or tuple(value.shape) == ( + 1, k_h, + seqlen_q, d, - ), "Input shape mismatch!" - assert b == 1, f"invalid batch size {b=}" - assert d <= 128, f" we do not support head_dim > 128, got head dim {d}" + ), f"value shape {value.shape} mismatch!" + + assert ( + nl.program_ndim() == 2 + ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" + batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) + + (num_active_blocks, ) = block_tables.shape + context_kv_len = num_active_blocks * block_size + assert ( + LARGE_TILE_SZ % B_F_SIZE == 0 + ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p" + assert (context_kv_len % LARGE_TILE_SZ == 0 + ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" + + num_blocks_per_large_tile = LARGE_TILE_SZ // block_size + assert is_power_of_2( + num_blocks_per_large_tile + ), f"{num_blocks_per_large_tile=} is expected of be power of 2" + if seqlen_q > B_F_SIZE: + MAX_REDUCTION_TILE = 2048 + if seqlen_q // 2 > MAX_REDUCTION_TILE: + assert ( + seqlen_q % MAX_REDUCTION_TILE == 0 + ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}" + else: + assert (seqlen_q % B_F_SIZE == 0 + ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})" + kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype + softmax_scale = softmax_scale or (1.0 / (d**0.5)) + num_large_k_tile = context_kv_len // LARGE_TILE_SZ o = nl.ndarray((b, h, seqlen_q, d), dtype=query.dtype, @@ -373,35 +555,38 @@ def flash_paged_attention( buffer=nl.sbuf, lazy_initialization=True, ) + block_tables_sbuf = load_block_tables( + block_tables_hbm=block_tables, + num_tiles=num_large_k_tile, + num_blocks_per_tile=num_blocks_per_large_tile, + ) - assert ( - nl.program_ndim() == 2 - ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" - batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) - - softmax_scale = softmax_scale or (1.0 / (d**0.5)) - - (num_active_blocks, ) = block_tables.shape - context_kv_len = num_active_blocks * block_size - assert (config.seq_tile_size >= 512 - ), f" seq tile_size {config.seq_tile_size} cannot be less than 512" - assert (context_kv_len % LARGE_TILE_SZ == 0 - ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" - assert ( - LARGE_TILE_SZ % B_P_SIZE == 0 - ), f"Need LARGE_TILE_SZ ({LARGE_TILE_SZ}) to be divisible by {B_P_SIZE=}" - assert (B_P_SIZE % block_size == 0 - ), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}" - num_large_k_tile = context_kv_len // LARGE_TILE_SZ - num_blocks_per_large_tile = LARGE_TILE_SZ // block_size - assert block_size % 32 == 0, "block_size is expected to be a multiple of 32" - assert is_power_of_2( - num_blocks_per_large_tile - ), "The number of blocks in each large tile is expected of be power of 2" - assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2" + # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient + if num_blocks_per_large_tile < B_P_SIZE: + # we checked num_blocks_per_tile is a power of 2 + assert B_P_SIZE % num_blocks_per_large_tile == 0 + block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile + # We assume block_size >= block_size_tiling_factor + assert block_size % block_size_tiling_factor == 0 + else: + block_size_tiling_factor = 1 + tiled_block_size = block_size // block_size_tiling_factor + + # Indirect DMA load must be placed along Partition Dimension + block_tables_sbuf = transform_block_tables_for_indirect_load( + block_tables_sbuf, + block_size_tiling_factor=block_size_tiling_factor, + num_head=k_h, + head_id=head_id, + ) - block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile) + # Flatten KV cache to be 2D for loading into SBUF + new_cache_shape = ( + num_blocks * k_h * block_size_tiling_factor, + tiled_block_size * d, + ) + key_cache = key_cache.reshape(new_cache_shape) + value_cache = value_cache.reshape(new_cache_shape) # Global Flash Attention accumulators o_buffer = nl.zeros( @@ -411,7 +596,7 @@ def flash_paged_attention( lazy_initialization=True, ) l_buffer = nl.zeros( - (par_dim(B_P_SIZE), n_tile_q, q_h_per_k_h), + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), dtype=acc_type, buffer=nl.sbuf, lazy_initialization=True, @@ -423,50 +608,42 @@ def flash_paged_attention( lazy_initialization=True, ) - for j in nl.sequential_range(0, num_large_k_tile): - cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) + for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile): + num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) + cur_k_tile = nl.ndarray( + (par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype, + ) cur_v_tile = nl.ndarray( - (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), + (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE), dtype=kernel_dtype, ) - - for k_i in nl.affine_range(num_blocks_per_large_tile): - loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :, - head_id, :]) - cur_k_tile[:, nl.ds(k_i * - block_size, block_size)] = nl.transpose(loaded) - - load_tile_size = B_P_SIZE - num_blocks_per_partition = load_tile_size // block_size - for partition_idx in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - for block_in_partition in nl.affine_range( - num_blocks_per_partition): - v_i = (partition_idx * num_blocks_per_partition + - block_in_partition) - loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :, - head_id, :]) - cur_v_tile[ - partition_idx, - nl.ds(block_in_partition * block_size, block_size), - :, - ] = loaded_v + load_kv_tile_from_cache( + cur_k_tile=cur_k_tile, + cur_v_tile=cur_v_tile, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables_sbuf, + large_k_tile_idx=large_k_tile_idx, + num_blocks_per_large_tile=num_blocks_per_large_tile, + tiled_block_size=tiled_block_size, + B_P_SIZE=B_P_SIZE, + B_D_SIZE=B_D_SIZE, + ) for i in nl.affine_range(n_tile_q): - cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - dtype=mask.dtype) - for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): - cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[ - nl.ds(i * B_P_SIZE, B_P_SIZE), - nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE), - ]) + cur_mask = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ), + ]) for i_q_h in nl.affine_range(q_h_per_k_h): q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] - q_sbuf_tile = nl.load( - q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)], - dtype=kernel_dtype, - ) # load (d, 128) tile in SBUF + q_sbuf_tile = nl.load(q_hbm_tile[:, + nl.ds(i * + B_P_SIZE, B_P_SIZE)]) + if q_sbuf_tile.dtype != kernel_dtype: + q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) q_tile[:, :] = q_sbuf_tile * softmax_scale _flash_attention_core( @@ -474,15 +651,15 @@ def flash_paged_attention( k=cur_k_tile, v=cur_v_tile, o_buffer=o_buffer[i, i_q_h], - l_buffer=l_buffer[:, i, i_q_h], + l_buffer=l_buffer[i, i_q_h], m_buffer=m_buffer[i, i_q_h], - q_tile_idx=i, kernel_dtype=kernel_dtype, acc_type=acc_type, - flash_config=config, - use_causal_mask=False, tile_mask=cur_mask, - initialize=j == 0, + use_causal_mask=False, + q_tile_idx=i, + initialize=large_k_tile_idx == 0, + LARGE_TILE_SZ=LARGE_TILE_SZ, B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, @@ -492,62 +669,58 @@ def flash_paged_attention( if key is not None and value is not None: B_F_SIZE = min(seqlen_q, B_F_SIZE) LARGE_TILE_SZ = seqlen_q - active_config = FlashConfig( - seq_tile_size=LARGE_TILE_SZ, - should_transpose_v=config.should_transpose_v, - ) cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) cur_v_tile = nl.ndarray( - (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), + (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE), dtype=kernel_dtype, ) - cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :]) + loaded = nl.load(key[batch_id, head_id, :, :]) + if loaded.dtype != kernel_dtype: + loaded = nl.copy(loaded, dtype=kernel_dtype) + cur_k_tile[:, :] = loaded - load_tile_size = B_P_SIZE v_hbm_tile = value[batch_id, head_id] - for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): + for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): load_v_tile( v_hbm_tile=v_hbm_tile, cur_v_tile=cur_v_tile, - j=0, + large_tile_idx=0, v_i=v_i, - config=active_config, + LARGE_TILE_SZ=LARGE_TILE_SZ, ) for i in nl.affine_range(n_tile_q): - cur_mask = nl.load( - mask[ - nl.ds(i * B_P_SIZE, B_P_SIZE), - nl.ds(context_kv_len, LARGE_TILE_SZ), - ], - dtype=mask.dtype, - ) + cur_mask = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(context_kv_len, LARGE_TILE_SZ), + ]) for i_q_h in nl.affine_range(q_h_per_k_h): q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] - q_sbuf_tile = nl.load( - q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)], - dtype=kernel_dtype, - ) # load (d, 128) tile in SBUF + q_sbuf_tile = nl.load(q_hbm_tile[:, + nl.ds(i * + B_P_SIZE, B_P_SIZE)]) + if q_sbuf_tile.dtype != kernel_dtype: + q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) q_tile[:, :] = q_sbuf_tile * softmax_scale _flash_attention_core( q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, o_buffer=o_buffer[i, i_q_h], - l_buffer=l_buffer[:, i, i_q_h], + l_buffer=l_buffer[i, i_q_h], m_buffer=m_buffer[i, i_q_h], - q_tile_idx=i, kernel_dtype=kernel_dtype, acc_type=acc_type, - flash_config=active_config, - use_causal_mask=True, tile_mask=cur_mask, + use_causal_mask=True, + q_tile_idx=i, initialize=False, + LARGE_TILE_SZ=LARGE_TILE_SZ, B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, @@ -559,8 +732,8 @@ def flash_paged_attention( for i_q_h in nl.affine_range(q_h_per_k_h): for i in nl.affine_range(n_tile_q): out = nl.multiply( - o_buffer[i, i_q_h, :, :], - nl.exp(m_buffer[i, i_q_h, :, :] - l_buffer[:, i, i_q_h]), + o_buffer[i, i_q_h], + nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]), dtype=kernel_dtype, ) @@ -589,7 +762,7 @@ def flash_paged_attention( head_id * q_h_per_k_h + i_q_h, nl.ds(i * B_P_SIZE, B_P_SIZE), ], - l_buffer[:, i, i_q_h], + l_buffer[i, i_q_h], ) nl.store( hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], @@ -601,6 +774,49 @@ def flash_paged_attention( return o +def reorder_context_mask(mask, LARGE_TILE_SZ, block_size): + """ + Reorder the mask to make it compatible with the flash attention kernel. + + We vectorize KV cache read to improve DMA utilization. However, the layout + that maximizes DMA bandwidth changes the order tokens are consumed. + + The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE, + tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And + each step the engine consumes a column (rather than a row) of B_P_SIZE + tokens. Therefore, the tokens are visited in a strided way. + + To make sure mask matches the order tokens are consumed, we need to properly + transpose mask. + """ + total_query_len, total_seq_len = mask.shape + context_kv_len = total_seq_len - total_query_len + + B_P_SIZE = 128 + assert (LARGE_TILE_SZ + >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}" + num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size) + tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks + if tiled_block_size > 1: + # Mask reordering is needed when tiled_block_size > 1 + device = mask.device + mask = mask.cpu() + context_mask = mask[:, :context_kv_len] + context_mask = context_mask.view( + total_query_len, + context_kv_len // LARGE_TILE_SZ, + num_tiled_blocks // B_P_SIZE, + B_P_SIZE, + tiled_block_size, + ) + context_mask = context_mask.transpose(3, 4).reshape( + total_query_len, context_kv_len) + new_mask = mask[:, context_kv_len:] + return torch.concat([context_mask, new_mask], dim=1).to(device) + else: + return mask + + def flash_attn_varlen_nkifunc( query, key, @@ -612,13 +828,32 @@ def flash_attn_varlen_nkifunc( n_kv_head=None, head_size=None, LARGE_TILE_SZ=2048, - return_debug_tensors=False, mixed_precision=True, ): - config = FlashConfig( - seq_tile_size=LARGE_TILE_SZ, - should_transpose_v=False, - ) + """ + Compute flash paged attention for variable length sequences. + + This function is a wrapper around the flash attention NKI kernel. It takes + in the following arguments: + - query: (1, n_heads, d, seq_q) + - key: (1, n_kv_heads, d, seq_k) + - value: (1, n_kv_heads, seq_v, d) + - key_cache: (n_blocks, n_kv_heads, block_size, d) + - value_cache: (n_blocks, n_kv_heads, block_size, d) + - block_tables: (n_active_blocks, ) + - attn_mask: (seq_q, n_active_blocks * block_size + seq_q) + + Notes: + - attn_mask must be reordered outside using `reorder_context_mask` + - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d) + for better DMA throughput + """ + if n_kv_head is None: + n_kv_head = key_cache.shape[1] + assert key_cache.shape[1] == n_kv_head + if head_size is None: + head_size = key_cache.shape[-1] + kwargs = dict( query=query, key=key, @@ -628,15 +863,9 @@ def flash_attn_varlen_nkifunc( block_tables=block_table, mask=attn_mask, softmax_scale=1.0 / (head_size**0.5), - config=config, mixed_precision=mixed_precision, - return_debug_tensors=return_debug_tensors, + LARGE_TILE_SZ=LARGE_TILE_SZ, ) - _, n_kv_head, _, _ = key.shape - if return_debug_tensors: - o, *debug_tensors = flash_paged_attention[1, n_kv_head](**kwargs) - return o, *debug_tensors - else: - o = flash_paged_attention[1, n_kv_head](**kwargs) - return o + o = flash_paged_attention[1, n_kv_head](**kwargs) + return o From 0fad7f44ad29c55b5e05a37606b9e71d4ba4e633 Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Fri, 21 Feb 2025 11:52:40 +0800 Subject: [PATCH 142/317] Add llmaz as another integration (#13643) Signed-off-by: kerthcet --- docs/source/deployment/integrations/index.md | 1 + docs/source/deployment/integrations/llmaz.md | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 docs/source/deployment/integrations/llmaz.md diff --git a/docs/source/deployment/integrations/index.md b/docs/source/deployment/integrations/index.md index c286edb4d7bc..a557456c086d 100644 --- a/docs/source/deployment/integrations/index.md +++ b/docs/source/deployment/integrations/index.md @@ -6,4 +6,5 @@ kserve kubeai llamastack +llmaz ::: diff --git a/docs/source/deployment/integrations/llmaz.md b/docs/source/deployment/integrations/llmaz.md new file mode 100644 index 000000000000..cd4a76353d26 --- /dev/null +++ b/docs/source/deployment/integrations/llmaz.md @@ -0,0 +1,7 @@ +(deployment-llmaz)= + +# llmaz + +[llmaz](https://github.com/InftyAI/llmaz) is an easy-to-use and advanced inference platform for large language models on Kubernetes, aimed for production use. It uses vLLM as the default model serving backend. + +Please refer to the [Quick Start](https://github.com/InftyAI/llmaz?tab=readme-ov-file#quick-start) for more details. From b4c6249c6f8e21ff19b0560d21f5d26beb53ad2c Mon Sep 17 00:00:00 2001 From: Edwin Hernandez Date: Thu, 20 Feb 2025 21:16:40 -0800 Subject: [PATCH 143/317] [Misc] Adding script to setup ray for multi-node vllm deployments (#12913) --- examples/online_serving/multi-node-serving.sh | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 examples/online_serving/multi-node-serving.sh diff --git a/examples/online_serving/multi-node-serving.sh b/examples/online_serving/multi-node-serving.sh new file mode 100644 index 000000000000..067f20c69b88 --- /dev/null +++ b/examples/online_serving/multi-node-serving.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +subcommand=$1 +shift + +ray_port=6379 +ray_init_timeout=300 +declare -a start_params + +case "$subcommand" in + worker) + ray_address="" + while [ $# -gt 0 ]; do + case "$1" in + --ray_address=*) + ray_address="${1#*=}" + ;; + --ray_port=*) + ray_port="${1#*=}" + ;; + --ray_init_timeout=*) + ray_init_timeout="${1#*=}" + ;; + *) + start_params+=("$1") + esac + shift + done + + if [ -z "$ray_address" ]; then + echo "Error: Missing argument --ray_address" + exit 1 + fi + + for (( i=0; i < $ray_init_timeout; i+=5 )); do + ray start --address=$ray_address:$ray_port --block "${start_params[@]}" + if [ $? -eq 0 ]; then + echo "Worker: Ray runtime started with head address $ray_address:$ray_port" + exit 0 + fi + echo "Waiting until the ray worker is active..." + sleep 5s; + done + echo "Ray worker starts timeout, head address: $ray_address:$ray_port" + exit 1 + ;; + + leader) + ray_cluster_size="" + while [ $# -gt 0 ]; do + case "$1" in + --ray_port=*) + ray_port="${1#*=}" + ;; + --ray_cluster_size=*) + ray_cluster_size="${1#*=}" + ;; + --ray_init_timeout=*) + ray_init_timeout="${1#*=}" + ;; + *) + start_params+=("$1") + esac + shift + done + + if [ -z "$ray_cluster_size" ]; then + echo "Error: Missing argument --ray_cluster_size" + exit 1 + fi + + # start the ray daemon + ray start --head --port=$ray_port "${start_params[@]}" + + # wait until all workers are active + for (( i=0; i < $ray_init_timeout; i+=5 )); do + active_nodes=`python3 -c 'import ray; ray.init(); print(sum(node["Alive"] for node in ray.nodes()))'` + if [ $active_nodes -eq $ray_cluster_size ]; then + echo "All ray workers are active and the ray cluster is initialized successfully." + exit 0 + fi + echo "Wait for all ray workers to be active. $active_nodes/$ray_cluster_size is active" + sleep 5s; + done + + echo "Waiting for all ray workers to be active timed out." + exit 1 + ;; + + *) + echo "unknown subcommand: $subcommand" + exit 1 + ;; +esac From 13058c75efe2487893bc5b16dfb1e40848a29207 Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Thu, 20 Feb 2025 22:01:48 -0800 Subject: [PATCH 144/317] [NVIDIA] Fix an issue to use current stream for the nvfp4 quant (#13632) --- csrc/quantization/fp4/nvfp4_quant_kernels.cu | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index c3b8e9b3ec42..fef74111624f 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -348,10 +348,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output, auto sf_out = static_cast(output_sf.data_ptr()); auto output_ptr = static_cast(output.data_ptr()); at::cuda::CUDAGuard device_guard{(char)input.get_device()}; - auto stream = at::cuda::getStreamFromPool(false, input.get_device()); - if (stream == nullptr) { - std::cerr << "Warning: Null CUDA stream" << std::endl; - } + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); // We don't support e8m0 scales at this moment. bool useUE8M0 = false; From 75c35cb6ea908e3554c469e03c07a2e8de2e883f Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 21 Feb 2025 06:03:27 +0000 Subject: [PATCH 145/317] Use pre-commit to update `requirements-test.txt` (#13617) --- .pre-commit-config.yaml | 7 +++++++ requirements-test.txt | 31 +++++++++++++------------------ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5c4cb767c9ec..6a66131cdb4d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,6 +44,13 @@ repos: hooks: - id: actionlint exclude: 'vllm/third_party/.*' +repos: +- repo: https://github.com/astral-sh/uv-pre-commit + rev: 0.6.2 + hooks: + - id: pip-compile + args: [requirements-test.in, -o, requirements-test.txt] + files: ^requirements-test\.(in|txt)$ - repo: local hooks: - id: mypy-local diff --git a/requirements-test.txt b/requirements-test.txt index f91586419148..11f0e10969a6 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,9 +1,5 @@ -# -# This file is autogenerated by pip-compile with Python 3.12 -# by the following command: -# -# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt -# +# This file was autogenerated by uv via the following command: +# uv pip compile requirements-test.in -o requirements-test.txt absl-py==2.1.0 # via rouge-score accelerate==1.0.1 @@ -141,7 +137,7 @@ frozenlist==1.5.0 # aiohttp # aiosignal # ray -fsspec[http]==2024.9.0 +fsspec==2024.9.0 # via # datasets # evaluate @@ -221,7 +217,7 @@ librosa==0.10.2.post1 # via -r requirements-test.in llvmlite==0.43.0 # via numba -lm-eval[api]==0.4.4 +lm-eval==0.4.4 # via -r requirements-test.in lxml==5.3.0 # via sacrebleu @@ -238,10 +234,8 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common[opencv]==1.5.1 - # via - # -r requirements-test.in - # mistral-common +mistral-common==1.5.1 + # via -r requirements-test.in more-itertools==10.5.0 # via lm-eval mpmath==1.3.0 @@ -418,7 +412,7 @@ pybind11==2.13.6 # via lm-eval pycparser==2.22 # via cffi -pydantic[email]==2.9.2 +pydantic==2.9.2 # via # datamodel-code-generator # mistral-common @@ -478,7 +472,7 @@ pyyaml==6.0.2 # vocos rapidfuzz==3.12.1 # via jiwer -ray[adag]==2.40.0 +ray==2.40.0 # via -r requirements-test.in redis==5.2.0 # via tensorizer @@ -549,6 +543,10 @@ sentence-transformers==3.2.1 # via -r requirements-test.in sentencepiece==0.2.0 # via mistral-common +setuptools==75.8.0 + # via + # pytablewriter + # torch six==1.16.0 # via # python-dateutil @@ -646,7 +644,7 @@ tritonclient==2.51.0 # via # -r requirements-test.in # genai-perf -typepy[datetime]==1.3.2 +typepy==1.3.2 # via # dataproperty # pytablewriter @@ -683,6 +681,3 @@ yarl==1.17.1 # via aiohttp zstandard==0.23.0 # via lm-eval - -# The following packages are considered to be unsafe in a requirements file: -# setuptools From 761be24523f6a75802e136cba6bb477af4fd4447 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 20 Feb 2025 22:04:33 -0800 Subject: [PATCH 146/317] [Bugfix] Add `mm_processor_kwargs` to chat-related protocols (#13644) --- vllm/entrypoints/openai/protocol.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 98ea6a46133f..29f64d28bdf1 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -974,6 +974,10 @@ class EmbeddingChatRequest(OpenAIBaseModel): description=("Additional kwargs to pass to the template renderer. " "Will be accessible by the chat template."), ) + mm_processor_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) priority: int = Field( default=0, description=( @@ -1394,6 +1398,10 @@ class TokenizeChatRequest(OpenAIBaseModel): description=("Additional kwargs to pass to the template renderer. " "Will be accessible by the chat template."), ) + mm_processor_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) @model_validator(mode="before") @classmethod From d10b6add7768cc88c2c651e6957e51d5cc14a51a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 20 Feb 2025 22:05:56 -0800 Subject: [PATCH 147/317] [V1][Sampler] Avoid an operation during temperature application (#13587) --- vllm/v1/sample/metadata.py | 2 +- vllm/v1/sample/sampler.py | 8 ++++---- vllm/v1/utils.py | 6 ++++-- vllm/v1/worker/gpu_input_batch.py | 12 +++++++++--- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 2184a1866ff5..6d82d3a79c8e 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -9,7 +9,7 @@ @dataclass class SamplingMetadata: - temperature: torch.Tensor + temperature: Optional[torch.Tensor] all_greedy: bool all_random: bool diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 8e2533eefab0..ff978b3b6c41 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -77,11 +77,8 @@ def apply_temperature( logits: torch.Tensor, temp: torch.Tensor, ) -> torch.Tensor: - # Avoid division by zero. - temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) # Use in-place division to avoid creating a new tensor. - logits.div_(temp.unsqueeze(dim=1)) - return logits + return logits.div_(temp.unsqueeze(dim=1)) def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: return logits.argmax(dim=-1).view(-1) @@ -100,6 +97,8 @@ def sample( if sampling_metadata.all_greedy: return greedy_sampled + assert sampling_metadata.temperature is not None + # Apply temperature. logits = self.apply_temperature(logits, sampling_metadata.temperature) @@ -122,6 +121,7 @@ def sample( sampling_metadata.temperature < _SAMPLING_EPS, greedy_sampled, random_sampled, + out=greedy_sampled, # Reuse tensor ) return sampled diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 5be465014242..62271255b0c0 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -191,11 +191,13 @@ def bind_kv_cache( def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, - length: int) -> None: + length: int) -> torch.Tensor: """ Copy the first length elements of a tensor into another tensor in a non-blocking manner. Used to copy pinned CPU tensor data to pre-allocated GPU tensors. + + Returns the sliced target tensor. """ - to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) + return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ccafc325b53f..bd1c369acb30 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -242,10 +242,12 @@ def add_request( self.block_table.add_row(req_index, request.block_ids) sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature if sampling_params.sampling_type == SamplingType.GREEDY: + # Avoid later division by zero. + self.temperature_cpu[req_index] = -1.0 self.greedy_reqs.add(req_id) else: + self.temperature_cpu[req_index] = sampling_params.temperature self.random_reqs.add(req_id) self.top_p_cpu[req_index] = sampling_params.top_p @@ -410,7 +412,11 @@ def refresh_sampling_metadata(self): def _make_sampling_metadata(self) -> SamplingMetadata: num_reqs = self.num_reqs - copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs) + if not self.all_greedy: + temperature = copy_slice(self.temperature_cpu_tensor, + self.temperature, num_reqs) + else: + temperature = None if not self.no_top_p: copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) if not self.no_top_k: @@ -437,7 +443,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata: prompt_token_ids = None return SamplingMetadata( - temperature=self.temperature[:num_reqs], + temperature=temperature, all_greedy=self.all_greedy, all_random=self.all_random, top_p=None if self.no_top_p else self.top_p[:num_reqs], From 3fa24d28f70a899eb9573e8ab087499fc9a93c8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com> Date: Fri, 21 Feb 2025 07:06:54 +0100 Subject: [PATCH 148/317] Missing comment explaining VDR variable in GGUF kernels (#13290) --- csrc/quantization/gguf/vecdotq.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh index e00422637c65..d0d4c74ed379 100644 --- a/csrc/quantization/gguf/vecdotq.cuh +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -37,6 +37,8 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } +// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called +// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 From 6cf9d59e098e9530febf855c95f174d3dcc66304 Mon Sep 17 00:00:00 2001 From: Gabriel Marinho <104592062+gmarinho2@users.noreply.github.com> Date: Fri, 21 Feb 2025 03:09:47 -0300 Subject: [PATCH 149/317] [FEATURE] Enables /score endpoint for embedding models (#12846) --- docs/source/models/pooling_models.md | 3 +- .../serving/openai_compatible_server.md | 10 +- tests/entrypoints/openai/test_rerank.py | 6 +- tests/entrypoints/openai/test_score.py | 284 ++++++----- vllm/entrypoints/llm.py | 46 +- vllm/entrypoints/openai/api_server.py | 17 +- vllm/entrypoints/openai/run_batch.py | 31 +- vllm/entrypoints/openai/serving_engine.py | 4 +- vllm/entrypoints/openai/serving_rerank.py | 208 -------- vllm/entrypoints/openai/serving_score.py | 463 +++++++++++++----- vllm/entrypoints/score_utils.py | 49 ++ 11 files changed, 599 insertions(+), 522 deletions(-) delete mode 100644 vllm/entrypoints/openai/serving_rerank.py create mode 100644 vllm/entrypoints/score_utils.py diff --git a/docs/source/models/pooling_models.md b/docs/source/models/pooling_models.md index 8612935432b8..f774f3d0fa0e 100644 --- a/docs/source/models/pooling_models.md +++ b/docs/source/models/pooling_models.md @@ -108,8 +108,7 @@ A code example can be found here: ### Score API -Our Score API applies a cross-encoder model to predict scores for sentence pairs. +Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair. Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1. -You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). +You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). Code example: @@ -496,11 +496,11 @@ The following extra parameters are supported: ### Re-rank API -Our Re-rank API applies a cross-encoder model to predict relevant scores between a single query, and +Our Re-rank API can apply an embedding model or a cross-encoder model to predict relevant scores between a single query, and each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1. -You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). +You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the `score` task. Additionally, `/rerank`, `/v1/rerank`, and `/v2/rerank` diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index cf114f0641db..ba11cd3a29a8 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -8,17 +8,17 @@ from ...utils import RemoteOpenAIServer MODEL_NAME = "BAAI/bge-reranker-base" +DTYPE = "bfloat16" @pytest.fixture(scope="module") def server(): - args = ["--enforce-eager", "--max-model-len", "100"] + args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server -@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) def test_rerank_texts(server: RemoteOpenAIServer, model_name: str): query = "What is the capital of France?" @@ -42,7 +42,6 @@ def test_rerank_texts(server: RemoteOpenAIServer, model_name: str): assert rerank.results[1].relevance_score <= 0.01 -@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) def test_top_n(server: RemoteOpenAIServer, model_name: str): query = "What is the capital of France?" @@ -68,7 +67,6 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str): assert rerank.results[1].relevance_score <= 0.01 -@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str): diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index bcbcb5702c95..b756680ea9f2 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -1,123 +1,185 @@ # SPDX-License-Identifier: Apache-2.0 +import math +from typing import Any + import pytest import requests +import torch.nn.functional as F +from torch import tensor from vllm.entrypoints.openai.protocol import ScoreResponse from ...utils import RemoteOpenAIServer -MODEL_NAME = "BAAI/bge-reranker-v2-m3" - - -@pytest.fixture(scope="module") -def server(): - args = ["--enforce-eager", "--max-model-len", "100"] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: +MODELS = [ + { + "name": "BAAI/bge-reranker-v2-m3", + "is_cross_encoder": True + }, + { + "name": "BAAI/bge-base-en-v1.5", + "is_cross_encoder": False + }, +] +DTYPE = "half" + + +def run_transformers(hf_model, model, text_pairs): + if model["is_cross_encoder"]: + return hf_model.predict(text_pairs).tolist() + else: + hf_embeddings = [ + hf_model.encode(text_pair) for text_pair in text_pairs + ] + return [ + F.cosine_similarity(tensor(pair[0]), tensor(pair[1]), dim=0) + for pair in hf_embeddings + ] + + +@pytest.fixture(scope="class", params=MODELS) +def model(request): + yield request.param + + +@pytest.fixture(scope="class") +def server(model: dict[str, Any]): + args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE] + + with RemoteOpenAIServer(model["name"], args) as remote_server: yield remote_server -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str): - text_1 = "What is the capital of France?" - text_2 = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." - ] - - score_response = requests.post(server.url_for("score"), - json={ - "model": model_name, - "text_1": text_1, - "text_2": text_2, - }) - score_response.raise_for_status() - score = ScoreResponse.model_validate(score_response.json()) - - assert score.id is not None - assert score.data is not None - assert len(score.data) == 2 - assert score.data[0].score <= 0.01 - assert score.data[1].score >= 0.9 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str): - text_1 = [ - "What is the capital of the United States?", - "What is the capital of France?" - ] - text_2 = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." - ] - - score_response = requests.post(server.url_for("score"), - json={ - "model": model_name, - "text_1": text_1, - "text_2": text_2, - }) - score_response.raise_for_status() - score = ScoreResponse.model_validate(score_response.json()) - - assert score.id is not None - assert score.data is not None - assert len(score.data) == 2 - assert score.data[0].score <= 0.01 - assert score.data[1].score >= 0.9 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str): - text_1 = "What is the capital of France?" - text_2 = "The capital of France is Paris." - - score_response = requests.post(server.url_for("score"), - json={ - "model": model_name, - "text_1": text_1, - "text_2": text_2, - }) - score_response.raise_for_status() - score = ScoreResponse.model_validate(score_response.json()) - - assert score.id is not None - assert score.data is not None - assert len(score.data) == 1 - assert score.data[0].score >= 0.9 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str): - - text_1 = "What is the capital of France?" * 20 - text_2 = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." - ] - - score_response = requests.post(server.url_for("score"), - json={ - "model": model_name, - "text_1": text_1, - "text_2": text_2, - }) - assert score_response.status_code == 400 - # Assert just a small fragments of the response - assert "Please reduce the length of the input." in \ - score_response.text - - # Test truncation - score_response = requests.post(server.url_for("score"), - json={ - "model": model_name, - "text_1": text_1, - "text_2": text_2, - "truncate_prompt_tokens": 101 - }) - assert score_response.status_code == 400 - assert "Please, select a smaller truncation size." in \ - score_response.text +@pytest.fixture(scope="class") +def runner(model: dict[str, Any], hf_runner): + kwargs = { + "dtype": DTYPE, + "is_cross_encoder" if model["is_cross_encoder"]\ + else "is_sentence_transformer": True + } + + with hf_runner(model["name"], **kwargs) as hf_model: + yield hf_model + + +class TestModel: + + def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer, + model: dict[str, Any], runner): + text_1 = "What is the capital of France?" + text_2 = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] + + score_response = requests.post(server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 2 + + vllm_outputs = [d.score for d in score.data] + + text_pairs = [[text_1, text_2[0]], [text_1, text_2[1]]] + hf_outputs = run_transformers(runner, model, text_pairs) + + for i in range(len(vllm_outputs)): + assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01) + + def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer, + model: dict[str, Any], runner): + text_1 = [ + "What is the capital of the United States?", + "What is the capital of France?" + ] + text_2 = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] + + score_response = requests.post(server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 2 + + vllm_outputs = [d.score for d in score.data] + + text_pairs = [[text_1[0], text_2[0]], [text_1[1], text_2[1]]] + hf_outputs = run_transformers(runner, model, text_pairs) + + for i in range(len(vllm_outputs)): + assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01) + + def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer, + model: dict[str, Any], runner): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." + + score_response = requests.post(server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 1 + + vllm_outputs = [d.score for d in score.data] + + text_pairs = [[text_1, text_2]] + hf_outputs = run_transformers(runner, model, text_pairs) + + for i in range(len(vllm_outputs)): + assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01) + + def test_score_max_model_len(self, server: RemoteOpenAIServer, + model: dict[str, Any]): + + text_1 = "What is the capital of France?" * 20 + text_2 = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] + + score_response = requests.post(server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }) + assert score_response.status_code == 400 + # Assert just a small fragments of the response + assert "Please reduce the length of the input." in \ + score_response.text + + # Test truncation + score_response = requests.post(server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + "truncate_prompt_tokens": 101 + }) + assert score_response.status_code == 400 + assert "Please, select a smaller truncation size." in \ + score_response.text diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 40b7a529ebfb..cefb9184b202 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -7,7 +7,6 @@ Tuple, Type, Union, cast, overload) import cloudpickle -import torch import torch.nn as nn from tqdm import tqdm from typing_extensions import TypeVar, deprecated @@ -25,6 +24,8 @@ apply_mistral_chat_template, parse_chat_messages, resolve_chat_template_content_format) +from vllm.entrypoints.score_utils import (_cosine_similarity, + _validate_score_input_lens) from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt from vllm.logger import init_logger @@ -1010,40 +1011,25 @@ def _embedding_score( prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[ScoringRequestOutput]: - encoded_output = self.encode( + encoded_output: List[PoolingRequestOutput] = self.encode( text_1 + text_2, use_tqdm=use_tqdm, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - encoded_output_1 = encoded_output[0:len(text_1)] - encoded_output_2 = encoded_output[len(text_1):] + + encoded_output_1: List[PoolingRequestOutput] = encoded_output[ + 0:len(text_1)] + encoded_output_2: List[PoolingRequestOutput] = encoded_output[ + len(text_1):] if len(encoded_output_1) == 1: encoded_output_1 = encoded_output_1 * len(encoded_output_2) - output_pairs = [(t1, t2) - for t1, t2 in zip(encoded_output_1, encoded_output_2)] - - scores = [] - scorer = torch.nn.CosineSimilarity(0) - - for embed_1, embed_2 in output_pairs: - pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data) + scores: List[PoolingRequestOutput] = [] - if (pad_token_id := getattr(tokenizer, "pad_token_id", - None)) is not None: - tokens = embed_1.prompt_token_ids + [ - pad_token_id - ] + embed_2.prompt_token_ids - else: - tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids - - scores.append( - PoolingRequestOutput( - request_id=f"{embed_1.request_id}_{embed_2.request_id}", - outputs=pair_score, - prompt_token_ids=tokens, - finished=True)) + scores = _cosine_similarity(tokenizer=tokenizer, + embed_1=encoded_output_1, + embed_2=encoded_output_2) items = self.engine_class.validate_outputs(scores, PoolingRequestOutput) @@ -1183,12 +1169,7 @@ def ensure_str(prompt: SingletonPrompt): text_2 = [text_2] input_text_2: List[str] = [ensure_str(t) for t in text_2] - if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2): - raise ValueError("Input lengths must be either 1:1, 1:N or N:N") - if len(input_text_1) == 0: - raise ValueError("At least one text element must be given") - if len(input_text_2) == 0: - raise ValueError("At least one text_pair element must be given") + _validate_score_input_lens(input_text_1, input_text_2) if self.llm_engine.model_config.is_cross_encoder: return self._cross_encoding_score(tokenizer, input_text_1, @@ -1197,7 +1178,6 @@ def ensure_str(prompt: SingletonPrompt): lora_request, prompt_adapter_request) else: - return self._embedding_score( tokenizer, input_text_1, # type: ignore[arg-type] diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f7162fadbce8..d037a4e63484 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -73,8 +73,7 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling -from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank -from vllm.entrypoints.openai.serving_score import OpenAIServingScores +from vllm.entrypoints.openai.serving_score import ServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.entrypoints.openai.serving_transcription import ( @@ -320,12 +319,12 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: return request.app.state.openai_serving_embedding -def score(request: Request) -> Optional[OpenAIServingScores]: +def score(request: Request) -> Optional[ServingScores]: return request.app.state.openai_serving_scores -def rerank(request: Request) -> Optional[JinaAIServingRerank]: - return request.app.state.jinaai_serving_reranking +def rerank(request: Request) -> Optional[ServingScores]: + return request.app.state.openai_serving_scores def tokenization(request: Request) -> OpenAIServingTokenization: @@ -866,13 +865,13 @@ async def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) if model_config.task == "embed" else None - state.openai_serving_scores = OpenAIServingScores( + state.openai_serving_scores = ServingScores( engine_client, model_config, state.openai_serving_models, - request_logger=request_logger - ) if model_config.task == "score" else None - state.jinaai_serving_reranking = JinaAIServingRerank( + request_logger=request_logger) if model_config.task in ( + "score", "embed", "pooling") else None + state.jinaai_serving_reranking = ServingScores( engine_client, model_config, state.openai_serving_models, diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 81e7028ad774..e4496f61e607 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -26,7 +26,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) -from vllm.entrypoints.openai.serving_score import OpenAIServingScores +from vllm.entrypoints.openai.serving_score import ServingScores from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION @@ -342,7 +342,7 @@ async def main(args): chat_template=None, chat_template_content_format="auto", ) if model_config.task == "embed" else None - openai_serving_scores = (OpenAIServingScores( + openai_serving_scores = (ServingScores( engine, model_config, openai_serving_models, @@ -364,9 +364,9 @@ async def main(args): # Determine the type of request and run it. if request.url == "/v1/chat/completions": - handler_fn = (None if openai_serving_chat is None else - openai_serving_chat.create_chat_completion) - if handler_fn is None: + chat_handler_fn = (None if openai_serving_chat is None else + openai_serving_chat.create_chat_completion) + if chat_handler_fn is None: response_futures.append( make_async_error_request_output( request, @@ -375,12 +375,13 @@ async def main(args): )) continue - response_futures.append(run_request(handler_fn, request, tracker)) + response_futures.append( + run_request(chat_handler_fn, request, tracker)) tracker.submitted() elif request.url == "/v1/embeddings": - handler_fn = (None if openai_serving_embedding is None else - openai_serving_embedding.create_embedding) - if handler_fn is None: + embed_handler_fn = (None if openai_serving_embedding is None else + openai_serving_embedding.create_embedding) + if embed_handler_fn is None: response_futures.append( make_async_error_request_output( request, @@ -388,12 +389,13 @@ async def main(args): )) continue - response_futures.append(run_request(handler_fn, request, tracker)) + response_futures.append( + run_request(embed_handler_fn, request, tracker)) tracker.submitted() elif request.url == "/v1/score": - handler_fn = (None if openai_serving_scores is None else - openai_serving_scores.create_score) - if handler_fn is None: + score_handler_fn = (None if openai_serving_scores is None else + openai_serving_scores.create_score) + if score_handler_fn is None: response_futures.append( make_async_error_request_output( request, @@ -401,7 +403,8 @@ async def main(args): )) continue - response_futures.append(run_request(handler_fn, request, tracker)) + response_futures.append( + run_request(score_handler_fn, request, tracker)) tracker.submitted() else: response_futures.append( diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index dfc3328677c7..5619e509c554 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -52,8 +52,8 @@ logger = init_logger(__name__) CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, - EmbeddingCompletionRequest, ScoreRequest, - TokenizeCompletionRequest] + EmbeddingCompletionRequest, RerankRequest, + ScoreRequest, TokenizeCompletionRequest] ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] diff --git a/vllm/entrypoints/openai/serving_rerank.py b/vllm/entrypoints/openai/serving_rerank.py deleted file mode 100644 index 366df71217e9..000000000000 --- a/vllm/entrypoints/openai/serving_rerank.py +++ /dev/null @@ -1,208 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import asyncio -from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast - -from fastapi import Request - -from vllm.config import ModelConfig -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument, - RerankRequest, RerankResponse, - RerankResult, RerankUsage) -from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.inputs.data import TokensPrompt -from vllm.logger import init_logger -from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput -from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer -from vllm.utils import make_async, merge_async_iterators - -logger = init_logger(__name__) - - -class JinaAIServingRerank(OpenAIServing): - - def __init__( - self, - engine_client: EngineClient, - model_config: ModelConfig, - models: OpenAIServingModels, - *, - request_logger: Optional[RequestLogger], - ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger) - - async def do_rerank( - self, - request: RerankRequest, - raw_request: Optional[Request] = None - ) -> Union[RerankResponse, ErrorResponse]: - """ - Rerank API based on JinaAI's rerank API; implements the same - API interface. Designed for compatibility with off-the-shelf - tooling, since this is a common standard for reranking APIs - - See example client implementations at - https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py - numerous clients use this standard. - """ - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - model_name = request.model - request_id = f"rerank-{self._base_request_id(raw_request)}" - truncate_prompt_tokens = request.truncate_prompt_tokens - query = request.query - documents = request.documents - request_prompts = [] - engine_prompts = [] - top_n = request.top_n if request.top_n > 0 else len(documents) - - try: - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) - - tokenizer = await self.engine_client.get_tokenizer(lora_request) - - if prompt_adapter_request is not None: - raise NotImplementedError("Prompt adapter is not supported " - "for scoring models") - - if isinstance(tokenizer, MistralTokenizer): - raise ValueError( - "MistralTokenizer not supported for cross-encoding") - - if not self.model_config.is_cross_encoder: - raise ValueError("Model is not cross encoder.") - - if truncate_prompt_tokens is not None and \ - truncate_prompt_tokens > self.max_model_len: - raise ValueError( - f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " - f"is greater than max_model_len ({self.max_model_len})." - f" Please, select a smaller truncation size.") - for doc in documents: - request_prompt = f"{query}{tokenizer.sep_token}{doc}" - tokenization_kwargs: Dict[str, Any] = {} - if truncate_prompt_tokens is not None: - tokenization_kwargs["truncation"] = True - tokenization_kwargs["max_length"] = truncate_prompt_tokens - - tokenize_async = make_async(tokenizer.__call__, - executor=self._tokenizer_executor) - prompt_inputs = await tokenize_async(text=query, - text_pair=doc, - **tokenization_kwargs) - - input_ids = prompt_inputs["input_ids"] - text_token_prompt = \ - self._validate_input(request, input_ids, request_prompt) - engine_prompt = TokensPrompt( - prompt_token_ids=text_token_prompt["prompt_token_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) - - request_prompts.append(request_prompt) - engine_prompts.append(engine_prompt) - - except ValueError as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) - - # Schedule the request and get the result generator. - generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] - - try: - pooling_params = request.to_pooling_params() - - for i, engine_prompt in enumerate(engine_prompts): - request_id_item = f"{request_id}-{i}" - - self._log_inputs(request_id_item, - request_prompts[i], - params=pooling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) - - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) - - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, - ) - - generators.append(generator) - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - result_generator = merge_async_iterators(*generators) - - num_prompts = len(engine_prompts) - - # Non-streaming response - final_res_batch: List[Optional[PoolingRequestOutput]] - final_res_batch = [None] * num_prompts - - try: - async for i, res in result_generator: - final_res_batch[i] = res - - assert all(final_res is not None for final_res in final_res_batch) - - final_res_batch_checked = cast(List[PoolingRequestOutput], - final_res_batch) - - response = self.request_output_to_rerank_response( - final_res_batch_checked, request_id, model_name, documents, - top_n) - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - - return response - - def request_output_to_rerank_response( - self, final_res_batch: List[PoolingRequestOutput], request_id: str, - model_name: str, documents: List[str], - top_n: int) -> RerankResponse: - """ - Convert the output of do_rank to a RerankResponse - """ - results: List[RerankResult] = [] - num_prompt_tokens = 0 - for idx, final_res in enumerate(final_res_batch): - classify_res = ScoringRequestOutput.from_base(final_res) - - result = RerankResult( - index=idx, - document=RerankDocument(text=documents[idx]), - relevance_score=classify_res.outputs.score, - ) - results.append(result) - prompt_token_ids = final_res.prompt_token_ids - num_prompt_tokens += len(prompt_token_ids) - - # sort by relevance, then return the top n if set - results.sort(key=lambda x: x.relevance_score, reverse=True) - if top_n < len(documents): - results = results[:top_n] - - return RerankResponse( - id=request_id, - model=model_name, - results=results, - usage=RerankUsage(total_tokens=num_prompt_tokens)) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index c7597808f7fe..0e9b355ad4f9 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -1,53 +1,36 @@ # SPDX-License-Identifier: Apache-2.0 - import asyncio import time -from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast +from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Union from fastapi import Request from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest, - ScoreResponse, ScoreResponseData, - UsageInfo) +from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument, + RerankRequest, RerankResponse, + RerankResult, RerankUsage, + ScoreRequest, ScoreResponse, + ScoreResponseData, UsageInfo) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.score_utils import (_cosine_similarity, + _validate_score_input_lens) from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput -from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast) from vllm.utils import make_async, merge_async_iterators logger = init_logger(__name__) -def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str], - str]) -> List: - if isinstance(text_1, (str, dict)): - # Convert a single prompt to a list. - text_1 = [text_1] - text_1 = [t for t in text_1] - - if isinstance(text_2, (str, dict)): - # Convert a single prompt to a list. - text_2 = [text_2] - text_2 = [t for t in text_2] - if len(text_1) > 1 and len(text_1) != len(text_2): - raise ValueError("Input lengths must be either 1:1, 1:N or N:N") - if len(text_1) == 0: - raise ValueError("At least one text element must be given") - if len(text_2) == 0: - raise ValueError("At least one text_pair element must be given") - - if len(text_1) == 1: - text_1 = text_1 * len(text_2) - - return [(t1, t2) for t1, t2 in zip(text_1, text_2)] - - -class OpenAIServingScores(OpenAIServing): +class ServingScores(OpenAIServing): def __init__( self, @@ -62,137 +45,280 @@ def __init__( models=models, request_logger=request_logger) - async def create_score( + async def _embedding_score( self, - request: ScoreRequest, - raw_request: Optional[Request] = None, - ) -> Union[ScoreResponse, ErrorResponse]: - """ - Score API similar to Sentence Transformers cross encoder + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + texts_1: List[str], + texts_2: List[str], + request: Union[RerankRequest, ScoreRequest], + request_id=str, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[Union[LoRARequest, None]] = None, + prompt_adapter_request: Optional[Union[PromptAdapterRequest, + None]] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> List[PoolingRequestOutput]: + + input_texts = texts_1 + texts_2 + + engine_prompts: List[TokensPrompt] = [] + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + + tokenization_kwargs = tokenization_kwargs or {} + tokenized_prompts = await asyncio.gather( + *(tokenize_async(t, **tokenization_kwargs) for t in input_texts)) + + for tok_result, input_text in zip(tokenized_prompts, input_texts): + + text_token_prompt = \ + self._validate_input( + request, + tok_result["input_ids"], + input_text) + + engine_prompts.append( + TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"])) - See https://sbert.net/docs/package_reference/cross_encoder - """ - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret + # Schedule the request and get the result generator. + generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] + pooling_params = request.to_pooling_params() - model_name = request.model - request_id = f"score-{self._base_request_id(raw_request)}" - created_time = int(time.time()) - truncate_prompt_tokens = request.truncate_prompt_tokens + for i, engine_prompt in enumerate(engine_prompts): - request_prompts = [] - engine_prompts = [] + request_id_item = f"{request_id}-{i}" - try: - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) + self._log_inputs(request_id_item, + input_texts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + generators.append( + self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + )) - if prompt_adapter_request is not None: - raise NotImplementedError("Prompt adapter is not supported " - "for scoring models") + result_generator = merge_async_iterators(*generators) - if isinstance(tokenizer, MistralTokenizer): - raise ValueError( - "MistralTokenizer not supported for cross-encoding") + # Non-streaming response + final_res_batch: List[PoolingRequestOutput] = [] - if not self.model_config.is_cross_encoder: - raise ValueError("Model is not cross encoder.") + embeddings: List[Optional[PoolingRequestOutput]] =\ + [None] * len(engine_prompts) - if truncate_prompt_tokens is not None and \ - truncate_prompt_tokens > self.max_model_len: - raise ValueError( - f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " - f"is greater than max_model_len ({self.max_model_len})." - f" Please, select a smaller truncation size.") - - input_pairs = make_pairs(request.text_1, request.text_2) - for q, t in input_pairs: - request_prompt = f"{q}{tokenizer.sep_token}{t}" - - tokenization_kwargs: Dict[str, Any] = {} - if truncate_prompt_tokens is not None: - tokenization_kwargs["truncation"] = True - tokenization_kwargs["max_length"] = truncate_prompt_tokens - - tokenize_async = make_async(tokenizer.__call__, - executor=self._tokenizer_executor) - prompt_inputs = await tokenize_async(q, - text_pair=t, - **tokenization_kwargs) - - input_ids = prompt_inputs["input_ids"] - text_token_prompt = \ - self._validate_input(request, input_ids, request_prompt) - engine_prompt = TokensPrompt( - prompt_token_ids=text_token_prompt["prompt_token_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) - - request_prompts.append(request_prompt) - engine_prompts.append(engine_prompt) + async for i, res in result_generator: + embeddings[i] = res - except ValueError as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + emb_texts_1: List[PoolingRequestOutput] = [] + emb_texts_2: List[PoolingRequestOutput] = [] + + for i in range(0, len(texts_1)): + assert (emb := embeddings[i]) is not None + emb_texts_1.append(emb) + + for i in range(len(texts_1), len(embeddings)): + assert (emb := embeddings[i]) is not None + emb_texts_2.append(emb) + + if len(emb_texts_1) == 1: + emb_texts_1 = emb_texts_1 * len(emb_texts_2) + + final_res_batch = _cosine_similarity(tokenizer=tokenizer, + embed_1=emb_texts_1, + embed_2=emb_texts_2) + + return final_res_batch + + async def _cross_encoding_score( + self, + tokenizer: Union[AnyTokenizer], + texts_1: List[str], + texts_2: List[str], + request: Union[RerankRequest, ScoreRequest], + request_id=str, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[Union[LoRARequest, None]] = None, + prompt_adapter_request: Optional[Union[PromptAdapterRequest, + None]] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> List[PoolingRequestOutput]: + + request_prompts: List[str] = [] + engine_prompts: List[TokensPrompt] = [] + + if len(texts_1) == 1: + texts_1 = texts_1 * len(texts_2) + + input_pairs = [(t1, t2) for t1, t2 in zip(texts_1, texts_2)] + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + + tokenization_kwargs = tokenization_kwargs or {} + tokenized_prompts = await asyncio.gather( + *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs) + for t1, t2 in input_pairs)) + + for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs): + + request_prompt = f"{t1}{tokenizer.sep_token}{t2}" + + input_ids = prompt_inputs["input_ids"] + text_token_prompt = \ + self._validate_input(request, input_ids, request_prompt) + engine_prompt = TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) # Schedule the request and get the result generator. generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] - try: - pooling_params = request.to_pooling_params() - - for i, engine_prompt in enumerate(engine_prompts): - request_id_item = f"{request_id}-{i}" + pooling_params = request.to_pooling_params() - self._log_inputs(request_id_item, - request_prompts[i], - params=pooling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, - ) + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) - generators.append(generator) - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + generators.append(generator) result_generator = merge_async_iterators(*generators) - num_prompts = len(engine_prompts) - # Non-streaming response - final_res_batch: List[Optional[PoolingRequestOutput]] - final_res_batch = [None] * num_prompts + final_res_batch: List[ + Optional[PoolingRequestOutput]] = [None] * len(engine_prompts) - try: - async for i, res in result_generator: - final_res_batch[i] = res + async for i, res in result_generator: + final_res_batch[i] = res + + return [out for out in final_res_batch if out is not None] + + async def _run_scoring( + self, + texts_1: Union[str, list[str]], + texts_2: Union[str, list[str]], + request: Union[ScoreRequest, RerankRequest], + request_id: str, + raw_request: Optional[Request] = None, + truncate_prompt_tokens: Optional[int] = None, + ) -> List[PoolingRequestOutput]: + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for scoring models") + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if truncate_prompt_tokens is not None and \ + truncate_prompt_tokens > self.max_model_len: + raise ValueError( + f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " + f"is greater than max_model_len ({self.max_model_len})." + f" Please, select a smaller truncation size.") + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + if isinstance(texts_1, str): + texts_1 = [texts_1] + if isinstance(texts_2, str): + texts_2 = [texts_2] + + _validate_score_input_lens(texts_1, texts_2) + + if self.model_config.is_cross_encoder: + return await self._cross_encoding_score( + tokenizer=tokenizer, + texts_1=texts_1, + texts_2=texts_2, + request=request, + request_id=request_id, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers) + + else: + return await self._embedding_score( + tokenizer=tokenizer, + texts_1=texts_1, + texts_2=texts_2, + request=request, + request_id=request_id, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers) + + async def create_score( + self, + request: ScoreRequest, + raw_request: Optional[Request] = None, + ) -> Union[ScoreResponse, ErrorResponse]: + """ + Score API similar to Sentence Transformers cross encoder + + See https://sbert.net/docs/package_reference/cross_encoder + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret - assert all(final_res is not None for final_res in final_res_batch) + request_id = f"score-{self._base_request_id(raw_request)}" + created_time = int(time.time()) - final_res_batch_checked = cast(List[PoolingRequestOutput], - final_res_batch) + try: + final_res_batch = await self._run_scoring( + request.text_1, + request.text_2, + request, + request_id, + raw_request, + request.truncate_prompt_tokens, + ) - response = self.request_output_to_score_response( - final_res_batch_checked, + return self.request_output_to_score_response( + final_res_batch, request_id, created_time, - model_name, + request.model, ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -200,7 +326,44 @@ async def create_score( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - return response + async def do_rerank( + self, + request: RerankRequest, + raw_request: Optional[Request] = None + ) -> Union[RerankResponse, ErrorResponse]: + """ + Rerank API based on JinaAI's rerank API; implements the same + API interface. Designed for compatibility with off-the-shelf + tooling, since this is a common standard for reranking APIs + + See example client implementations at + https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py + numerous clients use this standard. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"rerank-{self._base_request_id(raw_request)}" + documents = request.documents + top_n = request.top_n if request.top_n > 0 else len(documents) + + try: + final_res_batch = await self._run_scoring( + request.query, + documents, + request, + request_id, + raw_request, + request.truncate_prompt_tokens, + ) + return self.request_output_to_rerank_response( + final_res_batch, request_id, request.model, documents, top_n) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) def request_output_to_score_response( self, @@ -236,3 +399,35 @@ def request_output_to_score_response( data=items, usage=usage, ) + + def request_output_to_rerank_response( + self, final_res_batch: List[PoolingRequestOutput], request_id: str, + model_name: str, documents: List[str], + top_n: int) -> RerankResponse: + """ + Convert the output of do_rank to a RerankResponse + """ + results: List[RerankResult] = [] + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + classify_res = ScoringRequestOutput.from_base(final_res) + + result = RerankResult( + index=idx, + document=RerankDocument(text=documents[idx]), + relevance_score=classify_res.outputs.score, + ) + results.append(result) + prompt_token_ids = final_res.prompt_token_ids + num_prompt_tokens += len(prompt_token_ids) + + # sort by relevance, then return the top n if set + results.sort(key=lambda x: x.relevance_score, reverse=True) + if top_n < len(documents): + results = results[:top_n] + + return RerankResponse( + id=request_id, + model=model_name, + results=results, + usage=RerankUsage(total_tokens=num_prompt_tokens)) diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py new file mode 100644 index 000000000000..6ec0b5fb024a --- /dev/null +++ b/vllm/entrypoints/score_utils.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Union + +from torch.nn import CosineSimilarity + +from vllm.outputs import PoolingRequestOutput +from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer, + PreTrainedTokenizerFast) + + +def _cosine_similarity( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + embed_1: List[PoolingRequestOutput], + embed_2: List[PoolingRequestOutput], +) -> List[PoolingRequestOutput]: + + scorer = CosineSimilarity(0) + scores: Union[List[PoolingRequestOutput]] = [] + + for emb_1, emb_2 in zip(embed_1, embed_2): + pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data) + + padding = [] + if (pad_token_id := getattr(tokenizer, "pad_token_id", + None)) is not None: + padding = [pad_token_id] + + tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids + + scores.append( + PoolingRequestOutput( + request_id=f"{emb_1.request_id}_{emb_2.request_id}", + outputs=pair_score, + prompt_token_ids=tokens, + finished=True)) + + return scores + + +def _validate_score_input_lens( + texts_1: Union[List[str], List[dict]], + texts_2: Union[List[str], List[dict]], +): + if len(texts_1) > 1 and len(texts_1) != len(texts_2): + raise ValueError("Input lengths must be either 1:1, 1:N or N:N") + if len(texts_1) == 0: + raise ValueError("At least one text element must be given") + if len(texts_2) == 0: + raise ValueError("At least one text_pair element must be given") From fe51a3df8de261b886857620bc3e830d8ffd5edf Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Thu, 20 Feb 2025 22:12:10 -0800 Subject: [PATCH 150/317] [ci] Fix metrics test model path (#13635) --- tests/metrics/test_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 1a9063bc2dc3..45a13488f07e 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -146,7 +146,7 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, metrics_tag_content = stat_logger.labels["model_name"] if served_model_name is None or served_model_name == []: - actual_model_name = f"{MODEL_WEIGHTS_S3_BUCKET}/{model.split('/')[-1]}" + actual_model_name = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}" assert metrics_tag_content == actual_model_name, ( f"Metrics tag model_name is wrong! expect: {actual_model_name!r}\n" f"actual: {metrics_tag_content!r}") From 016178118440caf42e04363f35f399fc1d9b5267 Mon Sep 17 00:00:00 2001 From: leoneo Date: Fri, 21 Feb 2025 14:14:24 +0800 Subject: [PATCH 151/317] [Kernel]Add streamK for block-quantized CUTLASS kernels (#12978) --- .../cutlass_w8a8/c3x/cutlass_gemm_caller.cuh | 16 +++++--- .../scaled_mm_blockwise_sm90_fp8_dispatch.cuh | 40 +++++++++++++++---- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh index 9ac7eee7204e..69a3f64cb0b0 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -30,12 +30,18 @@ static inline cute::Shape get_problem_shape( } template -void cutlass_gemm_caller(torch::Device device, - cute::Shape prob_shape, - typename GemmKernel::MainloopArguments mainloop_args, - typename GemmKernel::EpilogueArguments epilogue_args) { +void cutlass_gemm_caller( + torch::Device device, cute::Shape prob_shape, + typename GemmKernel::MainloopArguments mainloop_args, + typename GemmKernel::EpilogueArguments epilogue_args, + typename GemmKernel::TileSchedulerArguments scheduler = {}) { + cutlass::KernelHardwareInfo hw_info; typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, - prob_shape, mainloop_args, epilogue_args}; + prob_shape, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; // Launch the CUTLASS GEMM kernel. using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index fb7a82b80ee6..e089c3d4be2c 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -22,8 +22,9 @@ namespace vllm { using namespace cute; -template > +template > struct cutlass_3x_gemm_fp8_blockwise { using GroupSizeM = Int; using GroupSizeN = Int; @@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise { using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; + SchedulerType>>; struct GemmKernel : public KernelType {}; @@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, typename GemmKernel::EpilogueArguments epilogue_args{ {}, c_ptr, c_stride, c_ptr, c_stride}; + typename GemmKernel::TileSchedulerArguments scheduler; + + static constexpr bool UsesStreamKScheduler = + cute::is_same_v; + + if constexpr (UsesStreamKScheduler) { + using DecompositionMode = typename cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using ReductionMode = typename cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::ReductionMode; + + scheduler.decomposition_mode = DecompositionMode::StreamK; + scheduler.reduction_mode = ReductionMode::Nondeterministic; + } + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, - epilogue_args); + epilogue_args, scheduler); } template @@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { - cutlass_gemm_caller_blockwise< - cutlass_3x_gemm_fp8_blockwise>(out, a, b, a_scales, - b_scales); + auto k = a.size(1); + auto n = b.size(1); + + if (k > 3 * n) { + cutlass_gemm_caller_blockwise>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise>( + out, a, b, a_scales, b_scales); + } } } // namespace vllm \ No newline at end of file From a9c0ffbbab98ad58a9b27363156eb76d1478bb30 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 21 Feb 2025 14:24:17 +0800 Subject: [PATCH 152/317] [Bugfix][CPU] Fix cpu all-reduce using native pytorch implementation (#13586) --- vllm/distributed/device_communicators/cpu_communicator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py index 4e86396e7135..b920cd7e1acf 100644 --- a/vllm/distributed/device_communicators/cpu_communicator.py +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -30,4 +30,5 @@ def __init__(self, pass def all_reduce(self, input_): - return self.dist_module.all_reduce(input_, group=self.device_group) + self.dist_module.all_reduce(input_, group=self.device_group) + return input_ From feaf88e39ca0fc583e99be193ffc3aa0d849aee4 Mon Sep 17 00:00:00 2001 From: John Zheng Date: Sat, 22 Feb 2025 02:21:05 +0800 Subject: [PATCH 153/317] fix typo of grafana dashboard, with correct datasource (#13668) Signed-off-by: John Zheng --- examples/online_serving/prometheus_grafana/grafana.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/online_serving/prometheus_grafana/grafana.json b/examples/online_serving/prometheus_grafana/grafana.json index f76a61bb5eec..fbe96b48e799 100644 --- a/examples/online_serving/prometheus_grafana/grafana.json +++ b/examples/online_serving/prometheus_grafana/grafana.json @@ -1260,7 +1260,7 @@ { "datasource": { "type": "prometheus", - "uid": "edx8memhpd9tsa" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "code", @@ -1360,7 +1360,7 @@ { "datasource": { "type": "prometheus", - "uid": "edx8memhpd9tsa" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "code", @@ -1473,7 +1473,7 @@ { "datasource": { "type": "prometheus", - "uid": "edx8memhpd9tsa" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "code", @@ -1523,7 +1523,7 @@ }, "datasource": { "type": "prometheus", - "uid": "edx8memhpd9tsa" + "uid": "${DS_PROMETHEUS}" }, "definition": "label_values(model_name)", "hide": 0, From 855ecfbd39d7bdf41d9c2a03f731c77a9632b5b1 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 21 Feb 2025 18:30:12 -0500 Subject: [PATCH 154/317] [Attention] MLA with chunked prefill (#12639) Signed-off-by: Lucas Wilkinson Signed-off-by: Lucas Wilkinson Co-authored-by: Patrick Horn Co-authored-by: simon-mo Co-authored-by: Tyler Michael Smith --- csrc/cache.h | 7 + csrc/cache_kernels.cu | 159 ++ csrc/core/math.hpp | 5 - csrc/cuda_utils.h | 22 +- .../cutlass_w8a8/scaled_mm_c3x.cu | 5 +- csrc/torch_bindings.cpp | 6 + tests/kernels/test_cache.py | 75 +- vllm/_custom_ops.py | 10 + vllm/attention/__init__.py | 12 +- vllm/attention/backends/mla/common.py | 1503 +++++++++++++++++ vllm/attention/backends/mla/utils.py | 515 ------ vllm/attention/backends/triton_mla.py | 664 +------- .../attention/ops/triton_merge_attn_states.py | 84 + vllm/config.py | 13 - vllm/engine/arg_utils.py | 7 +- .../layers/quantization/utils/fp8_utils.py | 23 +- vllm/utils.py | 4 + vllm/v1/attention/backends/flash_attn.py | 71 +- 18 files changed, 1910 insertions(+), 1275 deletions(-) create mode 100644 vllm/attention/backends/mla/common.py delete mode 100644 vllm/attention/backends/mla/utils.py create mode 100644 vllm/attention/ops/triton_merge_attn_states.py diff --git a/csrc/cache.h b/csrc/cache.h index cf4a65c29055..0970b704be3a 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -39,3 +39,10 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); + +void gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, std::optional seq_starts = std::nullopt); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 0960888d1f75..a6f8602a0588 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -2,6 +2,7 @@ #include #include +#include "cuda_utils.h" #include "cuda_compat.h" #include "dispatch_utils.h" @@ -570,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); } } + +namespace vllm { + +// grid is launched with dimensions (batch, num_splits) +template +__global__ void gather_cache( + const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRIES...] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] + const int32_t block_size, const int32_t entry_size, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per + // batch + + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = cu_seq_lens[bid]; + const int32_t seq_end = cu_seq_lens[bid + 1]; + const int32_t seq_len = seq_end - seq_start; + const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size); + const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits); + + const int32_t split_start = split * split_blocks; + const int32_t split_end = min((split + 1) * split_blocks, tot_blocks); + + const bool is_active_split = (split_start < tot_blocks); + const bool is_last_split = (split_end == tot_blocks); + + if (!is_active_split) return; + + int32_t full_blocks_end = split_end; + int32_t partial_block_size = 0; + + // Adjust the pointer for the block_table for this batch. + // If seq_starts is provided, compute an offset based on (seq_starts[bid] / + // page_size) + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = 0; + if (seq_starts != nullptr) { + offset = seq_starts[bid] / block_size; + } + const int32_t* batch_block_table = block_table + batch_offset + offset; + + // Adjust dst pointer based on the cumulative sequence lengths. + dst += seq_start * dst_entry_stride; + + if (is_last_split) { + partial_block_size = seq_len % block_size; + if (partial_block_size) full_blocks_end -= 1; + } + + auto copy_entry = [&](const scalar_t* __restrict__ _src, + scalar_t* __restrict__ _dst) { + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) + _dst[i] = _src[i]; + }; + + for (int pid = split_start; pid < full_blocks_end; ++pid) { + auto block_id = batch_block_table[pid]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + pid * block_size * dst_entry_stride; + for (int eid = 0; eid < block_size; ++eid) { + copy_entry(block_start_ptr + eid * cache_entry_stride, + block_dst_ptr + eid * dst_entry_stride); + } + } + + if (partial_block_size) { + auto block_id = batch_block_table[full_blocks_end]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride; + for (int eid = 0; eid < partial_block_size; ++eid) { + copy_entry(block_start_ptr + eid * cache_entry_stride, + block_dst_ptr + eid * dst_entry_stride); + } + } +} + +} // namespace vllm + +// Macro to dispatch the kernel based on the data type. +#define CALL_GATHER_CACHE(CPY_DTYPE) \ + vllm::gather_cache<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, seq_starts_ptr); + +// Gather sequences from the cache into the destination tensor. +// - cu_seq_lens contains the cumulative sequence lengths for each batch +// - block_table contains the cache block indices for each sequence +// - Optionally, seq_starts (if provided) offsets the starting block index by +// (seq_starts[bid] / page_size) +void gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, + std::optional seq_starts = std::nullopt) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t entry_size = src_cache.flatten(2, -1).size(2); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, + "cu_seq_lens must be int32"); + if (seq_starts.has_value()) { + TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, + "seq_starts must be int32"); + } + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == cu_seq_lens.device(), + "src_cache and cu_seq_lens must be on the same device"); + if (seq_starts.has_value()) { + TORCH_CHECK(src_cache.device() == seq_starts.value().device(), + "src_cache and seq_starts must be on the same device"); + } + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size. + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(1024); + + TORCH_CHECK(src_cache.dtype() == dst.dtype(), + "src_cache and dst must have the same dtype"); + + const int dtype_bits = src_cache.element_size() * 8; + const int32_t* seq_starts_ptr = + seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; + + if (dtype_bits == 32) { + CALL_GATHER_CACHE(uint32_t); + } else if (dtype_bits == 16) { + CALL_GATHER_CACHE(uint16_t); + } else if (dtype_bits == 8) { + CALL_GATHER_CACHE(uint8_t); + } else { + TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); + } +} diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp index ddfaca27147b..b8171133f6aa 100644 --- a/csrc/core/math.hpp +++ b/csrc/core/math.hpp @@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) { if (num <= 1) return num; return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } - -template -inline constexpr std::enable_if_t, T> ceil_div(T a, T b) { - return (a + b - 1) / b; -} \ No newline at end of file diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 6f79d2b74452..6e62ea208db8 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -2,10 +2,14 @@ #include -#if defined(__CUDACC__) || defined(_NVHPC_CUDA) - #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ - #define DEVICE_INLINE __forceinline__ __device__ - #define HOST_INLINE __forceinline__ __host__ +#if defined(__HIPCC__) + #define HOST_DEVICE_INLINE __host__ __device__ + #define DEVICE_INLINE __device__ + #define HOST_INLINE __host__ +#elif defined(__CUDACC__) || defined(_NVHPC_CUDA) + #define HOST_DEVICE_INLINE __host__ __device__ __forceinline__ + #define DEVICE_INLINE __device__ __forceinline__ + #define HOST_INLINE __host__ __forceinline__ #else #define HOST_DEVICE_INLINE inline #define DEVICE_INLINE inline @@ -25,3 +29,13 @@ int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); + +namespace cuda_utils { + +template +HOST_DEVICE_INLINE constexpr std::enable_if_t, T> +ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +}; // namespace cuda_utils \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index e40f28229968..53921abc951c 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -1,7 +1,7 @@ #include #include "c3x/scaled_mm_kernels.hpp" -#include "core/math.hpp" +#include "cuda_utils.h" /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for @@ -33,7 +33,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, auto make_group_shape = [](torch::Tensor const& x, torch::Tensor const& s) -> GroupShape { TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))}; + return {cuda_utils::ceil_div(x.size(0), s.size(0)), + cuda_utils::ceil_div(x.size(1), s.size(1))}; }; GroupShape a_scale_group_shape = make_group_shape(a, a_scales); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index ef81db14bf84..d2aecba442b4 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -493,6 +493,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); + + // Gather cache blocks from src_cache to dst. + cache_ops.def( + "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " + "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); + cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 21c02c5de35c..b8b5e2045457 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -682,8 +682,6 @@ def test_swap_blocks_mla( torch.ops._C_cache_ops.swap_blocks, (src_cache, dst_cache, block_mapping_tensor), test_utils=DEFAULT_OPCHECK_TEST_UTILS, - cond=(kv_lora_rank == KV_LORA_RANKS[0] - and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]), ) ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor) @@ -694,3 +692,76 @@ def test_swap_blocks_mla( dst_cache[dst].cpu(), msg=f"Block {src} from src should have been swapped to block " f"{dst} in dst_cache.") + + +@pytest.mark.parametrize("kv_lora_rank", [512]) +@pytest.mark.parametrize("qk_rope_head_dim", [64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_blocks", [1024]) +@pytest.mark.parametrize("max_seq_len", [512]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("kv_cache_dtype", + ["auto"]) # You can also test "fp8" if needed. +@pytest.mark.parametrize("align_cache", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, + num_blocks, max_seq_len, batch_size, dtype, + kv_cache_dtype, align_cache, device): + entry_size = kv_lora_rank + qk_rope_head_dim + src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, + kv_cache_dtype, device, align_cache) + _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) + + seq_len_tensor = torch.randint(0, + max_seq_len + 1, (batch_size, ), + device=device) + + total_tokens = seq_len_tensor.sum() + cu_seq_lens = torch.empty((batch_size + 1), + dtype=torch.int32, + device=device) + cu_seq_lens[0] = 0 + cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) + print("seq_len_tensor", seq_len_tensor) + + tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size + block_table = torch.empty((batch_size, num_blocks), + dtype=torch.int32, + device=device) + + for b in range(batch_size): + perm = torch.randperm(num_blocks, device=device) + block_table[b, :] = perm + + dst = torch.zeros((total_tokens, entry_size), + dtype=src_cache.dtype, + device=device) + + expected_batches = [] + for b in range(batch_size): + s = seq_len_tensor[b] + if s == 0: + continue + tot = tot_blocks_tensor[b] + blocks = block_table[b, :tot].tolist() + + gathered_rows = [] + for i in range(tot - 1): + gathered_rows.append(src_cache[blocks[i]]) + remaining = s - (tot - 1) * block_size + gathered_rows.append(src_cache[blocks[-1], :remaining, :]) + + batch_expected = torch.cat(gathered_rows, dim=0) + expected_batches.append(batch_expected) + expected = torch.cat(expected_batches, dim=0) + + opcheck( + torch.ops._C_cache_ops.gather_cache, + (src_cache, dst, block_table, cu_seq_lens, batch_size, None), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) + torch.testing.assert_close(dst, expected) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e3e3c644fbdd..2112af1201f3 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1099,6 +1099,16 @@ def convert_fp8(output: torch.Tensor, torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) +def gather_cache(src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, seq_starts) + + def get_device_attribute(attribute: int, device: int) -> int: return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 85c5715faba7..89229e7b87a0 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -4,16 +4,12 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionState, AttentionType) +from vllm.attention.backends.utils import get_flash_attn_version from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend __all__ = [ - "Attention", - "AttentionBackend", - "AttentionMetadata", - "AttentionType", - "AttentionMetadataBuilder", - "Attention", - "AttentionState", - "get_attn_backend", + "Attention", "AttentionBackend", "AttentionMetadata", "AttentionType", + "AttentionMetadataBuilder", "Attention", "AttentionState", + "get_attn_backend", "get_flash_attn_version" ] diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py new file mode 100644 index 000000000000..c3dbbdb86823 --- /dev/null +++ b/vllm/attention/backends/mla/common.py @@ -0,0 +1,1503 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file implements common components for MLA implementations. + +First we define: + +Sq as Q sequence length +Skv as KV sequence length + +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). + +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably +tune. + +Main reference: DeepseekV2 paper, and FlashInfer Implementation +(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + +Deepseek's MLA attention works the following way: +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. the memory friendly approach) the attention "simulates" a +multi-head attention, while the compute is similar to multi-query attention. + +Below is example of both paths assuming batchsize = 1 + +## More Extent Definitions: + +C Context length, `Skv - Sq` +H hidden size +N number of attention heads +Lq latent dimension for Q 1536 in DSV3 +Lkv latent dimension for K/V 512 in DSV3 +P nope dimension, no rope. 128 in DSV3 +R rope dimension, goes through rope. 64 in DSV3 +V V head dim. 128 in DSV3 + +## Vector/Matrix Definitions + +h_t hidden states (input to attention) shape [Sq, H] +q_c latent/compressed Q shape [Sq, Lq] +q_nope uncompressed Q (no-rope) shape [Sq, N, P] +q_pe uncompressed Q (rope) shape [Sq, N, R] +kv_c latent/compressed KV shape [Skv, Lkv] +k_pe decoupled k position embeddings shape [Skv, R] +new_kv_c new kv_c from current iter shape [Sq, Lkv] +new_k_pe new k_pe from current iter shape [Sq, R] +cache_kv_c cached k_c from previous iters shape [C, Lkv] +cache_k_pe cached k_pe from previous iters shape [C, R] +W_DQ project h_t to q_c shape [H, Lq] +W_UQ project q_c to q_nope shape [Lq, N * P] +W_QR project q_c to q_pe shape [Lq, N * R] +W_DKV project h_t to kv_c shape [H, Lkv] +W_UK project kv_c to k_nope shape [Lkv, N * P] +W_KR project h_t to k_pe shape [H, N * R] +W_UV project kv_c to v shape [Lkv, N * V] +W_O project v to h_t shape [N * V, H] + + +## Compute Friendly Approach (i.e. "_forward_prefill"): + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) +k_nope = (kv_c @ W_UK).view(Skv, N, P) +v = (kv_c @ W_UV).view(Skv, N, V) + +// MHA with QK headdim = P + R +// V headdim = V +// spda_o shape [Sq, N, V] +spda_o = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + v +) +return spda_o @ W_O + +NOTE: in the actual code, + `kv_b_proj` is [W_UK; W_UV] concatnated per head + `q_b_proj` is [W_UQ; W_QR] concatnated per head + `out_proj` is W_O + + +## Data-Movement Friendly Approach (i.e. "_forward_decode"): + +Ahead of time, compute: + +% this projects from q_c to [Sq, N * Lkv] +W_UQ_UK = einsum("qnp,knp -> qnk" + W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P) + ).view(Lkv, N * Lkv) +% this projects from attn output [Sq, N * Lkv] to [Sq, H] +W_UV_O = einsum("knv,nvh -> nkh" + W_UV.view(Lkv, N, V), W_O.view(N, V, H) + ).view(N * Lkv, H) + +Runtime +q_c = h_t @ W_DQ +q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) + +// MQA with QK headdim = Lkv + R +// V headdim = Lkv +// spda_o shape [Sq, N, Lkv] +// NOTE: this is less compute-friendly since Lkv > P +// but is more data-movement friendly since its MQA vs MHA +spda_o = scaled_dot_product_attention( + torch.cat([q_latent, q_pe], dim=-1), + torch.cat([kv_c, k_pe], dim=-1), + kv_c +) +return spda_o.reshape(-1, N * Lkv) @ W_UV_O + + +## Chunked Prefill + +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +the data-movement friendly approach if the chunk (i.e. `Sq`) is small. + +However, the compute-friendly approach can potentially run out of memory if Skv +is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` + +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +fixed workspace size. + +The chunked prefill approach is as follows: + +MCC Max chunk of context to process per iter, computed dynamically, + used to bound the memory usage + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P) +new_v = (new_kv_c @ W_UV).view(Sq, N, V) + +// MHA between queries and new KV +// with QK headdim = P + R +// V headdim = V +// curr_o shape [Sq, N, V] +// curr_lse shape [N, Sq], this is just order FA returns +curr_o, curr_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + new_v, + casual=True, + return_softmax_lse=True +) + +// Compute attention with the already existing context +for chunk_idx in range(cdiv(C, MCC)): + chunk_start = chunk_idx * MCC + chunk_end = min(chunk_start + MCC, C) + Sc = chunk_end - chunk_start + cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] + cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] + cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) + cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) + + chunk_o, chunk_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([cache_k_nope_chunk, + cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], + dim=-1), + cache_v_chunk, + casual=False, + return_softmax_lse=True + ) + + curr_o, curr_lse = merge_attn_states( + suffix_output=curr_o, + suffix_lse=curr_lse, + prefix_output=chunk_o, + prefix_lse=chunk_lse, + ) + +return curr_o @ W_O +""" + +import functools +from abc import abstractmethod +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import accumulate +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, + Type, TypeVar) + +import torch +from compressed_tensors.quantization import QuantizationStrategy + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, MLAAttentionImpl) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + get_flash_attn_version, + is_block_tables_empty) +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsLinearMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsW8A8Fp8) +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + scaled_quantize) +from vllm.model_executor.layers.rotary_embedding import ( + DeepseekScalingRotaryEmbedding, RotaryEmbedding) +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down + +try: + from vllm.vllm_flash_attn import flash_attn_varlen_func +except ImportError: + # For rocm use upstream flash attention + from flash_attn import flash_attn_varlen_func + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +class MLACommonBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return MLACommonMetadata + + @staticmethod + def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: + return MLACommonMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["MLACommonState"]: + return MLACommonState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +class MLACommonState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + + scheduler_config = runner.scheduler_config + self.model_config = runner.model_config + cache_config = runner.cache_config + + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max( + 8 * self.model_config.max_model_len, 4 * + scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + self._positions = torch.zeros((max_batch_size, ), + dtype=torch.long, + device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._positions + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + + attn_metadata = self.runner.attn_backend.make_metadata( + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + use_cuda_graph=True, + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + input_positions=self._positions[:batch_size], + head_dim=self.runner.model_config.get_head_size()) + + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + "input_positions": attn_metadata.decode_metadata.input_positions, + } + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_positions = attn_metadata.input_positions + num_positions = input_positions.shape[0] + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + # CUDA graph buffer is padded so only perform a partial copy based on + # num_positions + input_buffers["input_positions"][:num_positions].copy_( + input_positions, non_blocking=True) + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + def begin_forward(self, model_input): + if self.chunked_prefill_enabled: + if not hasattr(self, "chunked_prefill_workspace"): + # not self.runner.device does not return the correct device + # for this process, (init_device sets the correct device but + # only on the Worker). The only way Ive figured out to get the + # correct device is to allocate the workspace on the first call + # to begin_forward and use the device of the input tokens + assert model_input.input_tokens is not None + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=model_input.input_tokens.device, + ) + + model_input.attn_metadata.chunked_prefill_workspace = \ + self.chunked_prefill_workspace + + +@dataclass +class MLACommonMetadata(AttentionMetadata): + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # New for MLA (compared to FlashAttention) + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["MLACommonMetadata"] = None + _cached_decode_metadata: Optional["MLACommonMetadata"] = None + + num_prefill_tokens: int + + # The dimension of the attention heads + head_dim: Optional[int] = None + + # Used when chunked prefill is enabled to simulate worst case workspace + # allocations, hopefully to avoid going OOM + is_profile_run: bool = False + + # New for MLA (compared to FlashAttention) + # For chunked prefill + context_chunk_cu_seq_lens: Optional[torch.Tensor] = None + context_chunk_starts: Optional[torch.Tensor] = None + context_chunk_seq_tot: Optional[List[int]] = None + context_chunk_max_seq_lens: Optional[List[int]] = None + # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted + chunked_prefill_workspace: Optional[torch.Tensor] = None + + def __post_init__(self): + supported_head_sizes = MLACommonBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + @property + def prefill_metadata(self) -> Optional["MLACommonMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + input_positions = (None if self.input_positions is None else + self.input_positions[:self.num_prefill_tokens]) + + self._cached_prefill_metadata = MLACommonMetadata( + # Required by ModelRunner + use_cuda_graph=False, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + input_positions=input_positions, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.head_dim, + is_profile_run=self.is_profile_run, + # MLACommonMetadata Chunk prefill specific + context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens, + context_chunk_starts=self.context_chunk_starts, + context_chunk_seq_tot=self.context_chunk_seq_tot, + context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["MLACommonMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + input_positions = (None if self.input_positions is None else + self.input_positions[self.num_prefill_tokens:]) + + self._cached_decode_metadata = MLACommonMetadata( + # Required by ModelRunner + use_cuda_graph=self.use_cuda_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + input_positions=input_positions, + head_dim=self.head_dim, + is_profile_run=self.is_profile_run) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +T = TypeVar("T", bound=MLACommonMetadata) + + +class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.chunked_prefill_enabled = \ + self.runner.scheduler_config.chunked_prefill_enabled + + if self.chunked_prefill_enabled: + attn_state = self.input_builder.runner.attn_state + self.chunked_prefill_workspace_size = \ + attn_state.chunked_prefill_workspace_size + self.page_size = self.runner.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.input_positions: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block, input_positions) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + inter_data.input_positions): + self.input_positions.extend(input_positions) + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + input_positions = async_tensor_h2d(self.input_positions, torch.long, + device, self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + + context_chunk_cu_seq_lens = None + context_chunk_starts = None + context_chunk_seq_tot = None + context_chunk_max_seq_lens = None + + if self.chunked_prefill_enabled and self.num_prefills > 0 \ + and context_lens_tensor is not None \ + and context_lens_tensor[:self.num_prefills].max() > 0: + + # NOTE: it is recommend you read the `Chunked Prefill` section in + # the comment at the top of the file before trying to understand + # the following code + + num_prefills_with_context = \ + (context_lens_tensor[:self.num_prefills] > 0).sum().item() + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = \ + self.chunked_prefill_workspace_size // num_prefills_with_context + + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, self.page_size) + assert max_context_chunk > 0 + num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + context_chunk_starts = \ + torch.arange(num_chunks, device=device, dtype=torch.int32)\ + .unsqueeze(1).expand(-1, self.num_prefills)\ + * max_context_chunk + chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ + .unsqueeze(0), context_chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) + _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( + torch.int32) + zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ + .unsqueeze(-1) + context_chunk_cu_seq_lens = \ + torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) + context_chunk_max_seq_lens = \ + chunk_seq_lens.max(dim=1).values.tolist() + context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() + assert max(context_chunk_seq_tot) <= \ + self.chunked_prefill_workspace_size + + return MLACommonMetadata( + # Required by ModelRunner + use_cuda_graph=use_captured_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, # Not Attention Related + enable_kv_scales_calculation=False, + # MLACommonMetadata + input_positions=input_positions, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.runner.model_config.get_head_size(), + is_profile_run=self.runner.in_profile_run, + # MLACommonMetadata Chunk prefill specific + context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, + context_chunk_starts=context_chunk_starts, + context_chunk_seq_tot=context_chunk_seq_tot, + context_chunk_max_seq_lens=context_chunk_max_seq_lens, + ) + + +class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + rotary_emb: RotaryEmbedding, + # q_proj should be q_b_proj if q_lora_rank is not None, but from an + # attention backend perspective we rely on the layer to pass in the + # correct matrix + q_proj: ColumnParallelLinear, + kv_b_proj: ColumnParallelLinear, + o_proj: RowParallelLinear, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + + self.rotary_emb = rotary_emb + self.use_yarn_rope = isinstance(rotary_emb, + DeepseekScalingRotaryEmbedding) + self.q_proj = q_proj + self.kv_b_proj = kv_b_proj + self.o_proj = o_proj + self.vllm_flash_attn_version = get_flash_attn_version() + + # Handle the differences between the flash_attn_varlen from flash_attn + # and the one from vllm_flash_attn. The former is used on RoCM and the + # latter has an additional parameter to control FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = \ + functools.partial(flash_attn_varlen_func, + fa_version=self.vllm_flash_attn_version) + + def _v_up_proj_and_o_proj(self, x): + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + if is_fp8(self.W_UV_O): + output_parallel = apply_fp8_linear_generic( + x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, + self.reqaunt_input_group_shape, + self.reqaunt_weight_group_shape) + else: + output_parallel = torch.matmul(x.flatten(start_dim=1), + self.W_UV_O) + if self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + return output + else: + x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) + return self.o_proj(x.reshape(-1, + self.num_heads * self.v_head_dim))[0] + + def _q_proj_and_k_up_proj(self, x): + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + if is_fp8(self.W_Q_UK): + return apply_fp8_linear_generic( + x, self.W_Q_UK, self.W_Q_UK_scales, + self.reqaunt_input_group_shape, + self.reqaunt_weight_group_shape).view( + -1, self.num_heads, self.kv_lora_rank) + return torch.matmul(x, self.W_Q_UK)\ + .view(-1, self.num_heads, self.kv_lora_rank) + else: + x = torch.matmul(x, self.W_Q)\ + .view(-1, self.num_heads, self.qk_nope_head_dim) + return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ + .view(-1, self.num_heads, self.kv_lora_rank) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + # TODO(lucas) This is very gross, we need a more wide scale refactor of + # all the FP8 code with a more standard way of + # defining schemes/group-shapes, we should also potentially force + # quant_methods to support a decompress function + # + # returns input_group_shape, weight_group_shape + def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ + Tuple[Tuple[int, int], Tuple[int, int]]: + if isinstance(layer.quant_method, Fp8LinearMethod): + if layer.quant_method.block_quant: + weight_block_size = \ + layer.quant_method.quant_config.weight_block_size + # per-token-group (1, X), block-quantized (X, Y) + return (1, weight_block_size[-1]), weight_block_size + else: + return (-1, -1), (-1, -1) # per-tensor, per-tensor + elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ + and isinstance(layer.scheme, CompressedTensorsW8A8Fp8): + # this is hacky but we always assume the for + # CompressedTensorsW8A8Fp8 the input is dynamic per-token + # we ignore if it is static-per-tensor since we are going to + # requantize after later anyways + strategy = layer.scheme.strategy + if strategy == QuantizationStrategy.TENSOR: + return (1, -1), (-1, -1) # per-token, per-tensor + elif strategy == QuantizationStrategy.CHANNEL: + return (1, -1), (-1, 1) # per-token, per-channel + else: + raise NotImplementedError( + f"QuantizationStrategy.{strategy} is not supported for " + "fp8 MLA, please run with VLLM_MLA_DISABLE=1") + else: + raise NotImplementedError( + "Can't determine scale group shapes for " + f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" + ) + + def get_layer_weight(layer): + if hasattr(layer, "weight"): + return layer.weight + elif hasattr(layer, "qweight"): + return layer.qweight + else: + raise AttributeError( + f"Layer '{layer}' has neither weight nor qweight") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + weight_dtype = get_layer_weight(self.kv_b_proj).dtype + assert get_layer_weight(self.o_proj).dtype == weight_dtype + assert get_layer_weight(self.q_proj).dtype == weight_dtype + + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ + .view(-1, self.num_heads, self.qk_head_dim) + + # can be W_Q or W_UQ depending q_lora_rank, the former if + # q_lora_rank is None, the latter otherwise. From the Attention backend + # perspective though we call these both W_Q and rely on the layer + # to pass in the correct matrix + W_Q = q_proj_weight[..., :self.qk_nope_head_dim] + self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ + .flatten(start_dim=1).contiguous() + + # W_QR is small so for simplicity we dont bother requantizing it + self.W_QR = self.W_QR.to(act_dtype) + + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION + if is_fp8(weight_dtype) and requantization_enabled: + # This assumes it wise to requantize using the same group shapes + # (i.e. strategy, per-tensor, per-channel, block etc.) that the + # weights were originally quantized + requant_input_group_shape, requant_weight_group_shape = \ + get_scale_group_shapes_for_fp8(self.q_proj) + assert (requant_input_group_shape, requant_weight_group_shape)\ + == get_scale_group_shapes_for_fp8(self.kv_b_proj) + assert (requant_input_group_shape, requant_weight_group_shape)\ + == get_scale_group_shapes_for_fp8(self.o_proj) + self.reqaunt_input_group_shape = requant_input_group_shape + self.reqaunt_weight_group_shape = requant_weight_group_shape + + # + # Perform matrix-absorption following + # https://github.com/flashinfer-ai/flashinfer/pull/551 + # for decode, as a result we end up with absorbed weights for decode + # and another copy of raw weights for prefill. + # + self.W_UK, self.W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK + # depending q_lora_rank, the former if q_lora_rank is None, the + # latter otherwise + # basically if q_lora_rank is none we are absorbing into q_proj + # instead of UQ + W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ + .flatten(start_dim=1).contiguous() + + if is_fp8(weight_dtype) and requantization_enabled: + W_Q_UK, W_Q_UK_scales = scaled_quantize( + W_Q_UK, + self.reqaunt_weight_group_shape, + quant_dtype=current_platform_fp8_dtype) + # For FP8 save the transpose so we can use + # `apply_w8a8_block_fp8_linear` directly + self.W_Q_UK = W_Q_UK.T.contiguous() + self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous() + else: + self.W_Q_UK = W_Q_UK.to(act_dtype) + + W_O = get_and_maybe_dequant_weights(self.o_proj)\ + .view(-1, self.num_heads, self.v_head_dim) + W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ + .flatten(start_dim=0, end_dim=1).contiguous() + + if is_fp8(weight_dtype) and requantization_enabled: + W_UV_O, W_UV_O_scales = scaled_quantize( + W_UV_O, + self.reqaunt_weight_group_shape, + quant_dtype=current_platform_fp8_dtype) + # For FP8 save the transpose so we can use + # `apply_w8a8_block_fp8_linear` directly + self.W_UV_O = W_UV_O.T.contiguous() + self.W_UV_O_scales = W_UV_O_scales.T.contiguous() + else: + self.W_UV_O = W_UV_O.to(act_dtype) + + self.tp_size = get_tensor_model_parallel_world_size() + else: + if is_fp8(weight_dtype): + raise NotImplementedError( + "Currently fp8 requires matrix absorption") + + self.W_UV = W_UV + self.W_UK = W_UK + self.W_Q = W_Q.flatten(start_dim=1) + + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ): + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + assert prefill_metadata.context_chunk_seq_tot is not None + assert prefill_metadata.context_chunk_cu_seq_lens is not None + assert prefill_metadata.context_chunk_starts is not None + assert prefill_metadata.context_chunk_max_seq_lens is not None + assert prefill_metadata.context_lens_tensor is not None + + output = None + iters = len(prefill_metadata.context_chunk_seq_tot) + + # Fetch from attn_metadata directly, since it late bound by + # MLAAttentionState, grabbing it directly `attn_metadata` can avoid + # any weirdness around prefill_metadata caching + assert attn_metadata.chunked_prefill_workspace is not None + workspace = attn_metadata.chunked_prefill_workspace + + for i in range(iters): + toks = prefill_metadata.context_chunk_seq_tot[i] + + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_tables, + cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], + batch_size=prefill_metadata.num_prefills, + seq_starts=prefill_metadata.context_chunk_starts[i], + ) + + kv_c_normed = workspace[:toks]\ + [..., :self.kv_lora_rank].unsqueeze(1) + k_pe = workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad + # out v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, + [0, q.shape[-1] - v.shape[-1]], + value=0) + + attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + has_context = prefill_metadata.context_lens_tensor is not None \ + and prefill_metadata.context_lens_tensor.max() > 0 + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + ) + + if has_context: + suffix_output, suffix_lse = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, attn_metadata) + + output = torch.empty_like(suffix_output) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + # slice by `:v.shape[-1]` in order to remove v headdim padding + output = output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + + return self.o_proj(output)[0] + + @abstractmethod + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if output is not None: + raise NotImplementedError( + "output is not yet supported for MLAImplBase") + + if attn_metadata.is_profile_run and \ + attn_metadata.chunked_prefill_workspace is not None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + (attn_metadata.chunked_prefill_workspace.shape[0], + self.num_heads, self.qk_nope_head_dim + self.v_head_dim), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + has_decode = attn_metadata.decode_metadata is not None + has_prefill = attn_metadata.prefill_metadata is not None + + # Restore head dim (for rotary embedding) + k_pe = k_pe.unsqueeze(1) + assert hasattr(attn_metadata, "input_positions") + + num_prefill_tokens: int = attn_metadata.num_prefill_tokens + + decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:] + decode_k_pe = k_pe[num_prefill_tokens:] + decode_input_positions = \ + attn_metadata.input_positions[num_prefill_tokens:] + + prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens] + prefill_k_pe = k_pe[:num_prefill_tokens] + prefill_input_positions = \ + attn_metadata.input_positions[:num_prefill_tokens] + prefill_k_c_normed = k_c_normed[:num_prefill_tokens] + + if has_decode: + decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) + decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ + .view(-1, self.num_heads, self.qk_rope_head_dim) + decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( + decode_input_positions, decode_q_pe, decode_k_pe) + + if has_prefill: + prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( + prefill_input_positions, prefill_q_pe, prefill_k_pe) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + output = torch.empty(attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens, + self.o_proj.output_size, + device=hidden_states_or_q_c.device, + dtype=hidden_states_or_q_c.dtype) + if has_prefill: + output[:num_prefill_tokens] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata) + + if has_decode: + output[num_prefill_tokens:] = self._forward_decode( + decode_q_nope, decode_q_pe, kv_cache, attn_metadata) + + return output diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py deleted file mode 100644 index df3fb2aeefc4..000000000000 --- a/vllm/attention/backends/mla/utils.py +++ /dev/null @@ -1,515 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import functools -from abc import abstractmethod -from dataclasses import dataclass -from typing import Any, Dict, Generic, List, Optional, Tuple - -import torch -from compressed_tensors.quantization import QuantizationStrategy - -from vllm import _custom_ops as ops -from vllm import envs -from vllm.attention.backends.abstract import (AttentionLayer, - AttentionMetadata, - MLAAttentionImpl, T) -from vllm.attention.backends.utils import get_flash_attn_version -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, RowParallelLinear, - UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsLinearMethod) -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsW8A8Fp8) -from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - scaled_quantize) -from vllm.model_executor.layers.rotary_embedding import ( - DeepseekScalingRotaryEmbedding, RotaryEmbedding) - -try: - from vllm.vllm_flash_attn import flash_attn_varlen_func -except ImportError: - from flash_attn import flash_attn_varlen_func - - -@dataclass -class MLACommonMetadata(AttentionMetadata): - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - - -class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): - """ - Common class for implementing repeated parts - - Main reference: DeepseekV2 paper, and FlashInfer Implementation - (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - - Deepseek's MLA attention works the following way: - * Use a single latent vector to represent the entire KV cache. - * The attention "simulates" a multi-head attention, while the compute is - similar to multi-query attention. - * The dataflow is as follows, - - * B: batch/sequence length - * H: hidden size - * N: number of attention heads - * Lq: latent dimension for Q - * Lkv: latent dimension for K/V - * P: nope dimension, P+R is the actual head_dim in common attention. - * R: rope dimension, this slide of the head_dim goes through rope. - * V: V head dim. - * kv_c: latent/compressed KV - * q_c: latent/compressed Q - - # - # Outside the MLA attention backend - # - - 1. The hidden states (B, H) are projected down into cq (B, Lq) and - kv_c_k_pe (B, Lkv+R). - 2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq - and kv_c are normalized. - - # - # Inside the MLA attention backend - # - - * if prefill: - - 3. The q_c is then projected up into the multi-head version. - * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope - (B, N, P) and q_pe (B, N, R). - 4. q_pe, k_pe are then passed through rotary embeddings. - 5. kv_c and k_pe are concatenated and inserted into the cache - 6. The kv_c is then projected up into the multi-head version. - * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope - dimensions for K and V, which is split into k_nope (B, N, P) - and v (B, N, V). - 7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from - q_nope, q_pe, k_nope, k_pe. - 8. Attention is computued with q, k, v. - 9. The attention computation returns (B, N, V), which is projected back - to (B, H) using out projection. - - * if decode: - - 3. Here's the change, we do not perform up the full up projection for - q_c, and there is no up projection at all for kv_c. This is - achieved by the technique of "weight absorption". The paper says - "Fortunately, due to the associative law of matrix multiplication, - we can absorb WUK into WUQ, and WUV into WO" - * The q up projection turns (B, Lq) into (B, N, (P+R)), we split it - into W_UQ (Lq, N, P) and W_QR (Lq, N, R). - * The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split - it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V). - * The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H). - * We can precompute the product of W_UQ and W_UK into - W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in - attention. - * We can precompute the product of W_UV and W_O into - W_UV_O (N, Lkv, H), which is possible due to V@O as the - "epilogue" of attention - 4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent. - 5. q_pe, k_pe are then passed through rotary embeddings. - 6. kv_c and k_pe are concatenated and inserted into the cache - 7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape - (B, N, Lkv). - 8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe, - kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a. - 9. The attention is computed with q, k, v. Note that we just performed - a MQA attention with (LKv+R) as our head dim. - 10. The KV cache is updated using the new entries k (B, N, (Lkv+R)), - which included the v and rope values. - 11. The attention computation returns (B, N, Lkv), which is projected - back to (B, H) using W_UV_O. - - From @tsu-bin's calculation, we only want to use the absorption technique - for decode. The prefill algorithm should still use the up-projected MHA - for less flops and memory usage. - - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], - logits_soft_cap: Optional[float], - attn_type: str, - # MLA Specific Arguments - q_lora_rank: Optional[int], - kv_lora_rank: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - qk_head_dim: int, - v_head_dim: int, - rotary_emb: RotaryEmbedding, - # q_proj should be q_b_proj if q_lora_rank is not None, but from an - # attention backend perspective we rely on the layer to pass in the - # correct matrix - q_proj: ColumnParallelLinear, - kv_b_proj: ColumnParallelLinear, - o_proj: RowParallelLinear, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.kv_cache_dtype = kv_cache_dtype - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_head_dim - self.v_head_dim = v_head_dim - - self.rotary_emb = rotary_emb - self.use_yarn_rope = isinstance(rotary_emb, - DeepseekScalingRotaryEmbedding) - self.q_proj = q_proj - self.kv_b_proj = kv_b_proj - self.o_proj = o_proj - self.vllm_flash_attn_version = get_flash_attn_version() - - # Handle the differences between the flash_attn_varlen from flash_attn - # and the one from vllm_flash_attn. The former is used on RoCM and the - # latter has an additional parameter to control FA2 vs FA3 - self.flash_attn_varlen_func = flash_attn_varlen_func - if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) - - def _v_up_proj_and_o_proj(self, x): - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - if is_fp8(self.W_UV_O): - output_parallel = apply_fp8_linear_generic( - x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, - self.reqaunt_input_group_shape, - self.reqaunt_weight_group_shape) - else: - output_parallel = torch.matmul(x.flatten(start_dim=1), - self.W_UV_O) - if self.tp_size > 1: - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - return output - else: - x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) - return self.o_proj(x.reshape(-1, - self.num_heads * self.v_head_dim))[0] - - def _q_proj_and_k_up_proj(self, x): - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - if is_fp8(self.W_Q_UK): - return apply_fp8_linear_generic( - x, self.W_Q_UK, self.W_Q_UK_scales, - self.reqaunt_input_group_shape, - self.reqaunt_weight_group_shape).view( - -1, self.num_heads, self.kv_lora_rank) - return torch.matmul(x, self.W_Q_UK)\ - .view(-1, self.num_heads, self.kv_lora_rank) - else: - x = torch.matmul(x, self.W_Q)\ - .view(-1, self.num_heads, self.qk_nope_head_dim) - return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ - .view(-1, self.num_heads, self.kv_lora_rank) - - def process_weights_after_loading(self, act_dtype: torch.dtype): - # TODO(lucas) This is very gross, we need a more wide scale refactor of - # all the FP8 code with a more standard way of - # defining schemes/group-shapes, we should also potentially force - # quant_methods to support a decompress function - # - # returns input_group_shape, weight_group_shape - def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ - Tuple[Tuple[int, int], Tuple[int, int]]: - if isinstance(layer.quant_method, Fp8LinearMethod): - if layer.quant_method.block_quant: - weight_block_size = \ - layer.quant_method.quant_config.weight_block_size - # per-token-group (1, X), block-quantized (X, Y) - return (1, weight_block_size[-1]), weight_block_size - else: - return (-1, -1), (-1, -1) # per-tensor, per-tensor - elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ - and isinstance(layer.scheme, CompressedTensorsW8A8Fp8): - # this is hacky but we always assume the for - # CompressedTensorsW8A8Fp8 the input is dynamic per-token - # we ignore if it is static-per-tensor since we are going to - # requantize after later anyways - strategy = layer.scheme.strategy - if strategy == QuantizationStrategy.TENSOR: - return (1, -1), (-1, -1) # per-token, per-tensor - elif strategy == QuantizationStrategy.CHANNEL: - return (1, -1), (-1, 1) # per-token, per-channel - else: - raise NotImplementedError( - f"QuantizationStrategy.{strategy} is not supported for " - "fp8 MLA, please run with VLLM_MLA_DISABLE=1") - else: - raise NotImplementedError( - "Can't determine scale group shapes for " - f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" - ) - - def get_layer_weight(layer): - if hasattr(layer, "weight"): - return layer.weight - elif hasattr(layer, "qweight"): - return layer.qweight - else: - raise AttributeError( - f"Layer '{layer}' has neither weight nor qweight") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - return layer.weight - - weight_dtype = get_layer_weight(self.kv_b_proj).dtype - assert get_layer_weight(self.o_proj).dtype == weight_dtype - assert get_layer_weight(self.q_proj).dtype == weight_dtype - - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ - .view(-1, self.num_heads, self.qk_head_dim) - - # can be W_Q or W_UQ depending q_lora_rank, the former if - # q_lora_rank is None, the latter otherwise. From the Attention backend - # perspective though we call these both W_Q and rely on the layer - # to pass in the correct matrix - W_Q = q_proj_weight[..., :self.qk_nope_head_dim] - self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ - .flatten(start_dim=1).contiguous() - - # W_QR is small so for simplicity we dont bother requantizing it - self.W_QR = self.W_QR.to(act_dtype) - - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION - if is_fp8(weight_dtype) and requantization_enabled: - # This assumes it wise to requantize using the same group shapes - # (i.e. strategy, per-tensor, per-channel, block etc.) that the - # weights were originally quantized - requant_input_group_shape, requant_weight_group_shape = \ - get_scale_group_shapes_for_fp8(self.q_proj) - assert (requant_input_group_shape, requant_weight_group_shape)\ - == get_scale_group_shapes_for_fp8(self.kv_b_proj) - assert (requant_input_group_shape, requant_weight_group_shape)\ - == get_scale_group_shapes_for_fp8(self.o_proj) - self.reqaunt_input_group_shape = requant_input_group_shape - self.reqaunt_weight_group_shape = requant_weight_group_shape - - # - # Perform matrix-absorption following - # https://github.com/flashinfer-ai/flashinfer/pull/551 - # for decode, as a result we end up with absorbed weights for decode - # and another copy of raw weights for prefill. - # - self.W_UK, self.W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK - # depending q_lora_rank, the former if q_lora_rank is None, the - # latter otherwise - # basically if q_lora_rank is none we are absorbing into q_proj - # instead of UQ - W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ - .flatten(start_dim=1).contiguous() - - if is_fp8(weight_dtype) and requantization_enabled: - W_Q_UK, W_Q_UK_scales = scaled_quantize( - W_Q_UK, - self.reqaunt_weight_group_shape, - quant_dtype=current_platform_fp8_dtype) - # For FP8 save the transpose so we can use - # `apply_w8a8_block_fp8_linear` directly - self.W_Q_UK = W_Q_UK.T.contiguous() - self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous() - else: - self.W_Q_UK = W_Q_UK.to(act_dtype) - - W_O = get_and_maybe_dequant_weights(self.o_proj)\ - .view(-1, self.num_heads, self.v_head_dim) - W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ - .flatten(start_dim=0, end_dim=1).contiguous() - - if is_fp8(weight_dtype) and requantization_enabled: - W_UV_O, W_UV_O_scales = scaled_quantize( - W_UV_O, - self.reqaunt_weight_group_shape, - quant_dtype=current_platform_fp8_dtype) - # For FP8 save the transpose so we can use - # `apply_w8a8_block_fp8_linear` directly - self.W_UV_O = W_UV_O.T.contiguous() - self.W_UV_O_scales = W_UV_O_scales.T.contiguous() - else: - self.W_UV_O = W_UV_O.to(act_dtype) - - self.tp_size = get_tensor_model_parallel_world_size() - else: - if is_fp8(weight_dtype): - raise NotImplementedError( - "Currently fp8 requires matrix absorption") - - self.W_UV = W_UV - self.W_UK = W_UK - self.W_Q = W_Q.flatten(start_dim=1) - - @abstractmethod - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - attn_metadata: T, - ) -> torch.Tensor: - raise NotImplementedError - - @abstractmethod - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: T, - ) -> torch.Tensor: - raise NotImplementedError - - def forward( - self, - layer: AttentionLayer, - hidden_states_or_q_c: torch.Tensor, # query in unified attn - k_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, - attn_metadata: T, - output: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if output is not None: - raise NotImplementedError( - "output is not yet supported for MLAImplBase") - - is_decode = attn_metadata.decode_metadata is not None - is_prefill = attn_metadata.prefill_metadata is not None - - if (is_decode and is_prefill): - raise NotImplementedError( - "chunked prefill is not supported for MLAImplBase") - - # Restore head dim (for rotary embedding) - k_pe = k_pe.unsqueeze(1) - assert hasattr(attn_metadata, "input_positions") - - if is_decode: - q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) - q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ - .view(-1, self.num_heads, self.qk_rope_head_dim) - q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, - k_pe) - else: - assert is_prefill - q = self.q_proj(hidden_states_or_q_c)[0]\ - .view(-1, self.num_heads, self.qk_head_dim) - - # TODO(lucas): there must be a nicer way to write this line - q[..., self.qk_nope_head_dim:], k_pe = \ - self.rotary_emb( - attn_metadata.input_positions, - q[..., self.qk_nope_head_dim:], k_pe) - - # write the latent and rope to kv cache - if kv_cache.numel() > 0: - ops.concat_and_cache_mla( - k_c_normed, - k_pe.squeeze(1), - kv_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - - if attn_metadata.prefill_metadata is not None: - return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata) - - if attn_metadata.decode_metadata is not None: - return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata) - - # Optional common flash-attn based prefill - def _forward_prefill_flash( - self, - q: torch.Tensor, - k_c_normed: torch.Tensor, - k_pe: torch.Tensor, - seq_start_loc: torch.Tensor, - max_prefill_seq_len: int, - ) -> torch.Tensor: - - kv_nope = self.kv_b_proj(k_c_normed)[0]\ - .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=seq_start_loc, - cu_seqlens_k=seq_start_loc, - max_seqlen_q=max_prefill_seq_len, - max_seqlen_k=max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - ) - attn_output = attn_output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(attn_output)[0] diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index 9a1984a931b5..08e8226ab04c 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -1,40 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 -from collections import defaultdict -from contextlib import contextmanager -from dataclasses import dataclass -from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type - -from vllm.multimodal import MultiModalPlaceholderMap - -try: - from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - BatchDecodeMlaWithPagedKVCacheWrapper = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 +from typing import Any, Dict, List, Optional, Type import torch -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, AttentionType) -from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd -from vllm.utils import async_tensor_h2d, make_tensor_with_pad - -if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) -class TritonMLABackend(AttentionBackend): +class TritonMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: @@ -44,610 +21,8 @@ def get_name() -> str: def get_impl_cls() -> Type["TritonMLAImpl"]: return TritonMLAImpl - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return TritonMLAMetadata - - @staticmethod - def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]: - return TritonMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["TritonMLAState"]: - return TritonMLAState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - ) -> Tuple[int, ...]: - return (num_blocks, block_size, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - ops.copy_blocks_mla(kv_caches, src_to_dists) - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [576] - - -class TritonMLAState(AttentionState): - - def __init__(self, runner): - self.runner = runner - self._is_graph_capturing = False - - @contextmanager - def graph_capture(self, max_batch_size: int): - self._is_graph_capturing = True - - self._graph_slot_mapping = torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) - self._graph_seq_lens = torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device) - self._graph_block_tables = torch.from_numpy( - self.runner.graph_block_tables).to(device=self.runner.device) - - self._positions = torch.zeros((max_batch_size, ), - dtype=torch.long, - device=self.runner.device) - - yield - - self._is_graph_capturing = False - del self._graph_slot_mapping - del self._graph_seq_lens - del self._graph_block_tables - del self._positions - - def graph_clone(self, batch_size: int): - assert self._is_graph_capturing - return self.__class__(self.runner) - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - assert self._is_graph_capturing - - attn_metadata = self.runner.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=self._graph_slot_mapping[:batch_size], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=1, - max_decode_query_len=1, - max_prefill_seq_len=0, - max_decode_seq_len=self.runner.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self._graph_block_tables[:batch_size], - use_cuda_graph=True, - input_positions=self._positions[:batch_size], - head_dim=self.runner.model_config.get_head_size()) - - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - return attn_metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = { - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - "input_positions": attn_metadata.decode_metadata.input_positions, - } - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_positions = attn_metadata.input_positions - num_positions = input_positions.shape[0] - input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) - # CUDA graph buffer is padded so only perform a partial copy based on - # num_positions - input_buffers["input_positions"][:num_positions].copy_( - input_positions, non_blocking=True) - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - def begin_forward(self, model_input): - return - - -@dataclass -class TritonMLAMetadata(MLACommonMetadata): - """Metadata for TritonMLAMetadata. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional["TritonMLAMetadata"] = None - _cached_decode_metadata: Optional["TritonMLAMetadata"] = None - num_prefill_tokens: int - - num_kv_splits: int = 4 # TODO(lucas) add heuristic - attn_logits: Optional[torch.Tensor] = None - req_idx: Optional[torch.Tensor] = None - - # The dimension of the attention heads - head_dim: Optional[int] = None - - def __post_init__(self): - supported_head_sizes = TritonMLABackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f"received {self.head_dim}.") - - @property - def prefill_metadata(self) -> Optional["TritonMLAMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - input_positions = (None if self.input_positions is None else - self.input_positions[:self.num_prefill_tokens]) - - self._cached_prefill_metadata = TritonMLAMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - input_positions=input_positions, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - head_dim=self.head_dim) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["TritonMLAMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - input_positions = (None if self.input_positions is None else - self.input_positions[self.num_prefill_tokens:]) - - self._cached_decode_metadata = TritonMLAMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - input_positions=input_positions, - head_dim=self.head_dim) - return self._cached_decode_metadata - - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - if turn_prefills_into_decodes: - # When Mutli-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - - -class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]): - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.input_positions: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block, input_positions) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks, - inter_data.input_positions): - self.input_positions.extend(input_positions) - self.context_lens.append(context_len) - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables( - self, num_seqs: int, - block_tables: List[List[int]]) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs - - graph_block_tables = self.runner.graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - input_positions = async_tensor_h2d(self.input_positions, torch.long, - device, self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - return TritonMLAMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - input_positions=input_positions, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - num_kv_splits=4, # TODO(lucas) add heuristic - head_dim=self.runner.model_config.get_head_size(), - ) - - -class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): +class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): def __init__( self, @@ -662,11 +37,11 @@ def __init__( logits_soft_cap: Optional[float], attn_type: str, # MLA Specific Arguments - **kwargs) -> None: + **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **kwargs) + **mla_args) unsupported_features = [ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap @@ -683,24 +58,12 @@ def __init__( "are not implemented for " "TritonMLAImpl") - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - attn_metadata: TritonMLAMetadata, - ) -> torch.Tensor: - assert isinstance(attn_metadata, TritonMLAMetadata) - return self._forward_prefill_flash(q, kv_c_normed, k_pe, - attn_metadata.seq_start_loc, - attn_metadata.max_prefill_seq_len) - def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: TritonMLAMetadata, + attn_metadata: MLACommonMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 if self.kv_cache_dtype.startswith("fp8"): @@ -717,12 +80,14 @@ def _forward_decode( dtype=q.dtype, device=q.device) + num_kv_splits = 4 # TODO: heuristic + # TODO(lucas) Allocate ahead of time attn_logits = torch.empty( ( B, self.num_heads, - attn_metadata.num_kv_splits, + num_kv_splits, # NOTE(lucas) idk why the +1 is here but sglang has it so we # just mirror that self.kv_lora_rank + 1, @@ -740,7 +105,6 @@ def _forward_decode( decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, decode_meta.block_tables, decode_meta.seq_lens_tensor, attn_logits, - attn_metadata.num_kv_splits, self.scale, - PAGE_SIZE) + num_kv_splits, self.scale, PAGE_SIZE) return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py new file mode 100644 index 000000000000..31545b607fec --- /dev/null +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch +import triton +import triton.language as tl + + +# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +# can be used to combine partial attention results (in the split-KV case) +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + num_tokens = output.shape[0] + num_query_heads = output.shape[1] + head_size = output.shape[2] + padded_head_size = triton.next_power_of_2(head_size) + + # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. + merge_attn_states_kernel[(num_tokens, num_query_heads)]( + output, + output_lse, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + head_size, + padded_head_size, + output_lse is not None, + ) + + +@triton.jit +def merge_attn_states_kernel( + output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + output_lse, # [NUM_HEADS, NUM_TOKENS] + prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse, # [NUM_HEADS, NUM_TOKENS] + suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse, # [NUM_HEADS, NUM_TOKENS] + HEAD_SIZE: tl.constexpr, + PADDED_HEAD_SIZE: tl.constexpr, + OUTPUT_LSE: tl.constexpr, +): + token_idx = tl.program_id(0) + num_tokens = tl.num_programs(0) + head_idx = tl.program_id(1) + num_heads = tl.num_programs(1) + + p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) + s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + max_lse = tl.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + out_se = (tl.exp(p_lse) + tl.exp(s_lse)) + + if OUTPUT_LSE: + out_lse = tl.log(out_se) + max_lse + tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) + + head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + + # NOTE(woosuk): Be careful with the numerical stability. + # We should compute the scale first, and then multiply it with the output. + # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. + p_scale = tl.exp(p_lse) / out_se + s_scale = tl.exp(s_lse) / out_se + out = p_out * p_scale + s_out * s_scale + tl.store(output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask) diff --git a/vllm/config.py b/vllm/config.py index f118004b2f2f..d6e197fe988a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3332,19 +3332,6 @@ def __post_init__(self): current_platform.check_and_update_config(self) - # If MLA is enabled, force disable chunked prefill and prefix caching - if self.model_config and self.model_config.use_mla: - logger.info("MLA is enabled; forcing chunked prefill and prefix " - "caching to be disabled.") - self.scheduler_config.enable_chunked_prefill = False - self.scheduler_config.chunked_prefill_enabled = False - self.scheduler_config.max_num_batched_tokens = max( - self.scheduler_config.max_model_len, - _DEFAULT_MAX_NUM_BATCHED_TOKENS) - - if self.cache_config is not None: - self.cache_config.enable_prefix_caching = False - if not self.instance_id: self.instance_id = random_uuid()[:5] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5aa77a138a3e..8b460b33e235 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1170,9 +1170,9 @@ def create_engine_config(self, # long context (> 32K) models. This is to avoid OOM errors in the # initial memory profiling phase. - # For multimodal models, chunked prefill is disabled by default in - # V0, but enabled by design in V1 - if model_config.is_multimodal_model: + # For multimodal models and models with MLA, chunked prefill is + # disabled by default in V0, but enabled by design in V1 + if model_config.is_multimodal_model or model_config.use_mla: self.enable_chunked_prefill = bool(envs.VLLM_USE_V1) elif use_long_context: @@ -1207,7 +1207,6 @@ def create_engine_config(self, msg = "Chunked prefill is not supported for pooling models" raise ValueError(msg) - speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, target_parallel_config=parallel_config, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9895537c219a..891edf23010c 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -162,6 +162,9 @@ def _per_token_group_quant_fp8( y_q_ptr, y_s_ptr, group_size, + # Num columns of y + y_num_columns, + y_row_stride, # Avoid to divide zero eps, # Information for float8 @@ -174,9 +177,14 @@ def _per_token_group_quant_fp8( quantization on a tensor. This function converts the tensor values into float8 values. """ + groups_per_row = y_num_columns // group_size + # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) - y_ptr += g_id * group_size + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) y_q_ptr += g_id * group_size y_s_ptr += g_id @@ -202,6 +210,7 @@ def _per_token_group_quant_fp8_colmajor( group_size, # Num columns of y y_num_columns, + y_row_stride, # Stride from one column to the next of y_s y_s_col_stride, # Avoid to divide zero @@ -216,9 +225,14 @@ def _per_token_group_quant_fp8_colmajor( quantization on a tensor. This function converts the tensor values into float8 values. """ + groups_per_row = y_num_columns // group_size + # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) - y_ptr += g_id * group_size + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) y_q_ptr += g_id * group_size # Convert g_id the flattened block coordinate to 2D so we can index @@ -267,7 +281,7 @@ def per_token_group_quant_fp8( assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " f"by `group_size` {group_size}") - assert x.is_contiguous(), "`x` must be contiguous" + assert x.stride(-1) == 1, "`x` groups must be contiguous" finfo = torch.finfo(dtype) fp8_min = finfo.min @@ -295,6 +309,7 @@ def per_token_group_quant_fp8( x_s, group_size, x.shape[1], + x.stride(0), x_s.stride(1), eps, fp8_min=fp8_min, @@ -309,6 +324,8 @@ def per_token_group_quant_fp8( x_q, x_s, group_size, + x.shape[1], + x.stride(0), eps, fp8_min=fp8_min, fp8_max=fp8_max, diff --git a/vllm/utils.py b/vllm/utils.py index b1bac649c972..4d3f90c95a7d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -565,6 +565,10 @@ def round_up(x: int, y: int) -> int: return ((x + y - 1) // y) * y +def round_down(x: int, y: int) -> int: + return (x // y) * y + + def _generate_random_fp8( tensor: torch.Tensor, low: float, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b1b5cc359251..1922a3bf2724 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -5,12 +5,11 @@ import numpy as np import torch -import triton -import triton.language as tl from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import get_flash_attn_version +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv @@ -372,70 +371,4 @@ def cascade_attention( # Merge prefix and suffix outputs, and store the result in output. merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) - - -def merge_attn_states( - output: torch.Tensor, - prefix_output: torch.Tensor, - prefix_lse: torch.Tensor, - suffix_output: torch.Tensor, - suffix_lse: torch.Tensor, -) -> None: - num_tokens = output.shape[0] - num_query_heads = output.shape[1] - head_size = output.shape[2] - padded_head_size = triton.next_power_of_2(head_size) - - # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. - merge_attn_states_kernel[(num_tokens, num_query_heads)]( - output, - prefix_output, - prefix_lse, - suffix_output, - suffix_lse, - head_size, - padded_head_size, - ) - - -@triton.jit -def merge_attn_states_kernel( - output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_lse, # [NUM_HEADS, NUM_TOKENS] - suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - suffix_lse, # [NUM_HEADS, NUM_TOKENS] - HEAD_SIZE: tl.constexpr, - PADDED_HEAD_SIZE: tl.constexpr, -): - token_idx = tl.program_id(0) - num_tokens = tl.num_programs(0) - head_idx = tl.program_id(1) - num_heads = tl.num_programs(1) - - p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) - s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) - max_lse = tl.maximum(p_lse, s_lse) - p_lse = p_lse - max_lse - s_lse = s_lse - max_lse - - head_arange = tl.arange(0, PADDED_HEAD_SIZE) - head_mask = head_arange < HEAD_SIZE - p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - mask=head_mask) - s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - mask=head_mask) - - # NOTE(woosuk): Be careful with the numerical stability. - # We should compute the scale first, and then multiply it with the output. - # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. - p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse)) - s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse)) - out = p_out * p_scale + s_out * s_scale - tl.store(output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - out, - mask=head_mask) + suffix_lse) \ No newline at end of file From fe90015016554449fdb00e1a5cd84ebc01cc7257 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 22 Feb 2025 13:10:43 +0800 Subject: [PATCH 155/317] [Misc] Fix yapf linting tools etc not running on pre-commit (#13695) Signed-off-by: Isotr0py <2037008807@qq.com> --- .pre-commit-config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6a66131cdb4d..20d1981c9a05 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,7 +44,6 @@ repos: hooks: - id: actionlint exclude: 'vllm/third_party/.*' -repos: - repo: https://github.com/astral-sh/uv-pre-commit rev: 0.6.2 hooks: From 025d6e6875532a01fc5adfb21bed4b1dcd48fd04 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Sat, 22 Feb 2025 00:53:59 -0500 Subject: [PATCH 156/317] docs: Add a note on full CI run in contributing guide (#13646) --- docs/source/contributing/overview.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/contributing/overview.md b/docs/source/contributing/overview.md index af09bfecc649..5f8f5525e52a 100644 --- a/docs/source/contributing/overview.md +++ b/docs/source/contributing/overview.md @@ -145,6 +145,9 @@ review process: - Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion. +- Note that not all CI checks will be executed due to limited computational + resources. The reviewer will add `ready` label to the PR when the PR is + ready to merge or a full CI run is needed. ## Thank You From 3752fb34b20a66ae31c1aef83dceea7d5cbfb28c Mon Sep 17 00:00:00 2001 From: Keyun Tong Date: Fri, 21 Feb 2025 21:55:50 -0800 Subject: [PATCH 157/317] [HTTP Server] Make model param optional in request (#13568) --- tests/entrypoints/openai/test_chat.py | 32 +++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 20 ++++++------ vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 2 +- vllm/entrypoints/openai/serving_embedding.py | 2 +- vllm/entrypoints/openai/serving_engine.py | 13 +++++++- vllm/entrypoints/openai/serving_models.py | 2 +- vllm/entrypoints/openai/serving_pooling.py | 2 +- vllm/entrypoints/openai/serving_score.py | 4 +-- 9 files changed, 61 insertions(+), 18 deletions(-) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 4b5ad55c5eda..d7ed4afa2861 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -9,6 +9,7 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio +import requests import torch from openai import BadRequestError @@ -996,3 +997,34 @@ async def test_long_seed(client: openai.AsyncOpenAI): assert ("greater_than_equal" in exc_info.value.message or "less_than_equal" in exc_info.value.message) + + +@pytest.mark.asyncio +async def test_http_chat_wo_model_name(server: RemoteOpenAIServer): + url = f"http://localhost:{server.port}/v1/chat/completions" + headers = { + "Content-Type": "application/json", + } + data = { + # model_name is avoided here. + "messages": [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "what is 1+1?" + }], + "max_tokens": + 5 + } + + response = requests.post(url, headers=headers, json=data) + response_data = response.json() + print(response_data) + + choice = response_data.get("choices")[0] + message = choice.get("message") + assert message is not None + content = message.get("content") + assert content is not None + assert len(content) > 0 diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 29f64d28bdf1..45b98a032bda 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -213,7 +213,7 @@ class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create messages: List[ChatCompletionMessageParam] - model: str + model: Optional[str] = None frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None logprobs: Optional[bool] = False @@ -642,7 +642,7 @@ def check_generation_prompt(cls, data): class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create - model: str + model: Optional[str] = None prompt: Union[List[int], List[List[int]], str, List[str]] best_of: Optional[int] = None echo: Optional[bool] = False @@ -907,7 +907,7 @@ def validate_stream_options(cls, data): class EmbeddingCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/embeddings - model: str + model: Optional[str] = None input: Union[List[int], List[List[int]], str, List[str]] encoding_format: Literal["float", "base64"] = "float" dimensions: Optional[int] = None @@ -939,7 +939,7 @@ def to_pooling_params(self): class EmbeddingChatRequest(OpenAIBaseModel): - model: str + model: Optional[str] = None messages: List[ChatCompletionMessageParam] encoding_format: Literal["float", "base64"] = "float" @@ -1007,7 +1007,7 @@ def to_pooling_params(self): class ScoreRequest(OpenAIBaseModel): - model: str + model: Optional[str] = None text_1: Union[List[str], str] text_2: Union[List[str], str] truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None @@ -1031,7 +1031,7 @@ def to_pooling_params(self): class RerankRequest(OpenAIBaseModel): - model: str + model: Optional[str] = None query: str documents: List[str] top_n: int = Field(default_factory=lambda: 0) @@ -1345,7 +1345,7 @@ class BatchRequestOutput(OpenAIBaseModel): class TokenizeCompletionRequest(OpenAIBaseModel): - model: str + model: Optional[str] = None prompt: str add_special_tokens: bool = Field( @@ -1357,7 +1357,7 @@ class TokenizeCompletionRequest(OpenAIBaseModel): class TokenizeChatRequest(OpenAIBaseModel): - model: str + model: Optional[str] = None messages: List[ChatCompletionMessageParam] add_generation_prompt: bool = Field( @@ -1423,7 +1423,7 @@ class TokenizeResponse(OpenAIBaseModel): class DetokenizeRequest(OpenAIBaseModel): - model: str + model: Optional[str] = None tokens: List[int] @@ -1456,7 +1456,7 @@ class TranscriptionRequest(OpenAIBaseModel): formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. """ - model: str + model: Optional[str] = None """ID of the model to use. """ diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 934bd2a95063..02dd2c4881c6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -141,7 +141,7 @@ async def create_chat_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - model_name = self.models.model_name(lora_request) + model_name = self._get_model_name(request.model, lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e7ad263e7fbe..840f0f9b8448 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -166,7 +166,7 @@ async def create_completion( result_generator = merge_async_iterators(*generators) - model_name = self.models.model_name(lora_request) + model_name = self._get_model_name(request.model, lora_request) num_prompts = len(engine_prompts) # Similar to the OpenAI API, when n != best_of, we do not stream the diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 45f8ad90ddcb..607dbd96b194 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -83,7 +83,7 @@ async def create_embedding( return self.create_error_response( "dimensions is currently not supported") - model_name = request.model + model_name = self._get_model_name(request.model) request_id = f"embd-{self._base_request_id(raw_request)}" created_time = int(time.time()) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 5619e509c554..05b5f95a5e59 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -523,5 +523,16 @@ def _get_decoded_token(logprob: Logprob, return logprob.decoded_token return tokenizer.decode(token_id) - def _is_model_supported(self, model_name): + def _is_model_supported(self, model_name) -> bool: + if not model_name: + return True return self.models.is_base_model(model_name) + + def _get_model_name(self, + model_name: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> str: + if lora_request: + return lora_request.lora_name + if model_name is None: + return self.models.base_model_paths[0].name + return model_name diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index f917a4851901..6ade4ece6d03 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -95,7 +95,7 @@ async def init_static_loras(self): if isinstance(load_result, ErrorResponse): raise ValueError(load_result.message) - def is_base_model(self, model_name): + def is_base_model(self, model_name) -> bool: return any(model.name == model_name for model in self.base_model_paths) def model_name(self, lora_request: Optional[LoRARequest] = None) -> str: diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 01a3d211f6ba..bbf5aed1a33c 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -79,7 +79,7 @@ async def create_pooling( return self.create_error_response( "dimensions is currently not supported") - model_name = request.model + model_name = self._get_model_name(request.model) request_id = f"pool-{self._base_request_id(raw_request)}" created_time = int(time.time()) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 0e9b355ad4f9..01e2d3043610 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -318,7 +318,7 @@ async def create_score( final_res_batch, request_id, created_time, - request.model, + self._get_model_name(request.model), ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -358,7 +358,7 @@ async def do_rerank( request.truncate_prompt_tokens, ) return self.request_output_to_rerank_response( - final_res_batch, request_id, request.model, documents, top_n) + final_res_batch, request_id, self._get_model_name(request.model), documents, top_n) except asyncio.CancelledError: return self.create_error_response("Client disconnected") except ValueError as e: From 74b39374988d42656f56f0cc952beb6065c88ab3 Mon Sep 17 00:00:00 2001 From: Robin <863579016@qq.com> Date: Sat, 22 Feb 2025 14:05:28 +0800 Subject: [PATCH 158/317] =?UTF-8?q?[Bugfix][API=20Server]=20Fix=20invalid?= =?UTF-8?q?=20usage=20of=20'ge'=20and=20'le'=20in=20port=20valid=E2=80=A6?= =?UTF-8?q?=20(#13672)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vllm/entrypoints/api_server.py | 2 +- vllm/utils.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 4294a8aad9a5..11ffc4f67cea 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -145,7 +145,7 @@ async def run_server(args: Namespace, if __name__ == "__main__": parser = FlexibleArgumentParser() parser.add_argument("--host", type=str, default=None) - parser.add_argument("--port", type=int, default=8000, ge=1024, le=65535) + parser.add_argument("--port", type=parser.check_port, default=8000) parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-certfile", type=str, default=None) parser.add_argument("--ssl-ca-certs", diff --git a/vllm/utils.py b/vllm/utils.py index 4d3f90c95a7d..dcafd5411bbb 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1194,6 +1194,17 @@ def parse_args(self, args=None, namespace=None): return super().parse_args(processed_args, namespace) + def check_port(self, value): + try: + value = int(value) + except ValueError: + raise argparse.ArgumentTypeError("Port must be an integer") + + if not (1024 <= value <= 65535): + raise argparse.ArgumentTypeError("Port must be between 1024 and 65535") + + return value + def _pull_args_from_config(self, args: List[str]) -> List[str]: """Method to pull arguments specified in the config file into the command-line args variable. From 3cb4af2447dfafc3da774c1e846ef8af855ed4cf Mon Sep 17 00:00:00 2001 From: Jun Duan Date: Sat, 22 Feb 2025 01:06:34 -0500 Subject: [PATCH 159/317] [Misc] Capture and log the time of loading weights (#13666) --- vllm/v1/worker/gpu_model_runner.py | 8 +++++--- vllm/worker/model_runner.py | 7 +++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 31fe095a91bc..d2e9c2650c7b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1048,6 +1048,7 @@ def generate_draft_token_ids( def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 + time_before_load = time.perf_counter() self.model = get_model(vllm_config=self.vllm_config) if self.lora_config: self.model = self.load_lora_model(self.model, @@ -1055,10 +1056,11 @@ def load_model(self) -> None: self.scheduler_config, self.lora_config, self.device) - + time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB", - self.model_memory_usage / float(2**30)) + logger.info("Loading model weights took %.4f GB and %.6f seconds", + self.model_memory_usage / float(2**30), + time_after_load - time_before_load) def _get_prompt_logprobs_dict( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 67d175c373d8..1a78498ad124 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1109,11 +1109,14 @@ def __init__( def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler(self.device) as m: + time_before_load = time.perf_counter() self.model = get_model(vllm_config=self.vllm_config) + time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB", - self.model_memory_usage / float(2**30)) + logger.info("Loading model weights took %.4f GB and %.6f seconds", + self.model_memory_usage / float(2**30), + time_after_load - time_before_load) if self.lora_config: assert supports_lora( From ecfe822246bbdfa7e6c573a47c97c445bc8196f0 Mon Sep 17 00:00:00 2001 From: Gordon Wong Date: Sat, 22 Feb 2025 14:07:04 +0800 Subject: [PATCH 160/317] [ROCM] fix native attention function call (#13650) --- vllm/attention/backends/rocm_flash_attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index f49b37842d9b..e1a8d3d33613 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -717,7 +717,6 @@ def forward( self.num_heads, self.head_size, self.scale, - causal_mask, attn_masks, ) else: From 836ec6b64831544582abe77ae7cd2b7416c98c43 Mon Sep 17 00:00:00 2001 From: Shane A Date: Fri, 21 Feb 2025 22:07:45 -0800 Subject: [PATCH 161/317] [Bugfix][Model] OLMo 2: split qkv correctly for GQA and MQA (#13687) --- vllm/model_executor/models/olmo2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 4b0455098eed..d06f894123ac 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -157,7 +157,7 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) From ce71cab204d143b62b451253711edcf208322f96 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Sat, 22 Feb 2025 01:09:04 -0500 Subject: [PATCH 162/317] [Misc] Bump compressed-tensors (#13619) --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index c52980bc7df7..f72aa40fccec 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -34,6 +34,6 @@ pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. -compressed-tensors == 0.9.1 # required for compressed-tensors +compressed-tensors == 0.9.2 # required for compressed-tensors depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py From 225c16dbc192312b45a515e3b20e1184bc0e1fda Mon Sep 17 00:00:00 2001 From: Robin <863579016@qq.com> Date: Sat, 22 Feb 2025 14:10:38 +0800 Subject: [PATCH 163/317] [Bugfix] Fix benchmark script bug: inaccurate stats for vllm backend when max_model_len < input_len + output_len (#13691) Signed-off-by: WangErXiao <863579016@qq.com> --- benchmarks/benchmark_guided.py | 13 +++++++++++++ benchmarks/benchmark_latency.py | 4 ++++ benchmarks/benchmark_prioritization.py | 18 ++++++++++++++---- benchmarks/benchmark_throughput.py | 13 ++++++++++++- 4 files changed, 43 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmark_guided.py b/benchmarks/benchmark_guided.py index 2b41834baf4d..dc2bf0e79cbc 100644 --- a/benchmarks/benchmark_guided.py +++ b/benchmarks/benchmark_guided.py @@ -46,6 +46,12 @@ def run_vllm(requests: List[SampleRequest], warmup: bool = False) -> float: from vllm import LLM, SamplingParams llm = LLM(**vars(engine_args)) + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") # Add the requests to the engine. prompts: List[str] = [] @@ -115,6 +121,13 @@ async def run_vllm_async( async with build_async_engine_client_from_engine_args( engine_args, disable_frontend_multiprocessing) as llm: + assert all( + llm.model_config.max_model_len >= (request.prompt_len + + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + # Add the requests to the engine. prompts: List[str] = [] sampling_params: List[SamplingParams] = [] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index b041626550b5..b1d68ea24694 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -42,6 +42,10 @@ def main(args: argparse.Namespace): # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. llm = LLM(**dataclasses.asdict(engine_args)) + assert llm.llm_engine.model_config.max_model_len >= ( + args.input_len + args.output_len), ( + "Please ensure that max_model_len is greater than" + " the sum of input_len and output_len.") sampling_params = SamplingParams( n=args.n, diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index a32065e4e7c0..24014e5b6c37 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -13,6 +13,11 @@ from vllm.utils import FlexibleArgumentParser +#Select a equi-probable random priority +def get_random_flag(): + return 0 if random.random() < 0.5 else 1 + + def sample_requests( dataset_path: str, num_requests: int, @@ -55,8 +60,7 @@ def sample_requests( # Prune too long sequences. continue - #Select a equi-probable random priority - priority = 0 if random.random() < 0.5 else 1 + priority = get_random_flag() filtered_dataset.append((prompt, prompt_len, output_len, priority)) @@ -71,6 +75,12 @@ def run_vllm( from vllm import LLM, SamplingParams llm = LLM(**dataclasses.asdict(engine_args)) + assert all( + llm.llm_engine.model_config.max_model_len >= (request[1] + request[2]) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " input_len and output_len for all requests.") + # Add the requests to the engine. prompts = [] sampling_params = [] @@ -103,8 +113,8 @@ def main(args: argparse.Namespace): if args.dataset is None: # Synthesize a prompt with the given input length. prompt = "hi" * (args.input_len - 1) - requests = [(prompt, args.input_len, args.output_len) - for _ in range(args.num_prompts)] + requests = [(prompt, args.input_len, args.output_len, + get_random_flag()) for _ in range(args.num_prompts)] else: requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.output_len) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index f7d87f1b336f..ca54213c0646 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -171,7 +171,12 @@ def run_vllm( ) -> float: from vllm import LLM, SamplingParams llm = LLM(**dataclasses.asdict(engine_args)) - + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") # Add the requests to the engine. prompts: List[TextPrompt] = [] sampling_params: List[SamplingParams] = [] @@ -229,6 +234,12 @@ async def run_vllm_async( async with build_async_engine_client_from_engine_args( engine_args, disable_frontend_multiprocessing) as llm: + assert all( + llm.model_config.max_model_len >= (request.prompt_len + + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") # Add the requests to the engine. prompts: List[TextPrompt] = [] From b8466ec563df480a318a4db1f908c796765e9ca3 Mon Sep 17 00:00:00 2001 From: Lu Fang <30275821+houseroad@users.noreply.github.com> Date: Fri, 21 Feb 2025 22:13:05 -0800 Subject: [PATCH 164/317] [v1] Support allowed_token_ids in v1 Sampler (#13210) Signed-off-by: Lu Fang --- tests/v1/sample/test_rejection_sampler.py | 1 + tests/v1/sample/test_sampler.py | 94 +++++++++++++++++++---- tests/v1/worker/test_gpu_input_batch.py | 13 ++++ vllm/v1/engine/processor.py | 14 ++++ vllm/v1/sample/metadata.py | 4 + vllm/v1/sample/sampler.py | 18 ++++- vllm/v1/worker/gpu_input_batch.py | 43 ++++++++++- 7 files changed, 168 insertions(+), 19 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 3e810e525e1c..956d91c6daf7 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -43,6 +43,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: output_token_ids=[], min_tokens={}, logit_bias=[None] * batch_size, + allowed_token_ids_mask=None, ) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 3f6301c54267..34fba5a9f6d7 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -57,6 +57,26 @@ def _create_logit_bias( return res +def _create_allowed_token_ids( + batch_size: int, + vocab_size: int, + num_allowed_token_ids: int, + device: torch.device, +) -> Optional[torch.Tensor]: + mask: Optional[torch.Tensor] = None + for i in range(batch_size): + if i % 2 == 1: + continue + if mask is None: + mask = torch.zeros((batch_size, vocab_size), + dtype=torch.bool, + device=device) + start = min(i, vocab_size - 1) + end = min(i + num_allowed_token_ids, vocab_size - 1) + mask[i, start:end] = True + return mask + + def _create_default_sampling_metadata( num_output_tokens: int, batch_size: int, @@ -92,6 +112,7 @@ def _create_default_sampling_metadata( no_penalties=True, min_tokens={}, logit_bias=[None] * batch_size, + allowed_token_ids_mask=None, ) return fake_sampling_metadata @@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, sampling_metadata.frequency_penalties = _create_penalty_tensor( batch_size, frequency_penalty, torch.device(device)) output_token_ids, sorted_token_ids_in_output = \ - _create_weighted_output_token_list(batch_size, VOCAB_SIZE) + _create_weighted_output_token_list( + batch_size, + VOCAB_SIZE, + ) sampling_metadata.output_token_ids = output_token_ids sampling_metadata.no_penalties = False sampler = Sampler() @@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, for batch_idx in range(batch_size): non_penalized_token_id = logits[batch_idx].argmax().item() penalized_token_id = logits[batch_idx].argmin().item() - distinct_sorted_token_ids_in_output = \ - sorted_token_ids_in_output[batch_idx] + distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[ + batch_idx] most_frequent_token_id = distinct_sorted_token_ids_in_output[ len(distinct_sorted_token_ids_in_output) - 1] if frequency_penalty > 0: @@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, # non-penalized token ID is not present in the output, while the # most penalized token is the one that occurs most frequently in # the output. - assert non_penalized_token_id \ - not in distinct_sorted_token_ids_in_output + assert (non_penalized_token_id + not in distinct_sorted_token_ids_in_output) assert penalized_token_id == most_frequent_token_id elif frequency_penalty < 0: # If `frequency_penalty` is set to < 0, it indicates @@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, # in the output, while the penalized token ID is one that has not # yet appeared. assert non_penalized_token_id == most_frequent_token_id - assert penalized_token_id \ - not in distinct_sorted_token_ids_in_output + assert penalized_token_id not in distinct_sorted_token_ids_in_output @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, # If `repetition_penalty` > 1.0, verify that the non-penalized # token ID has not been seen before, while the penalized token ID # exists either in the prompt or the output. - assert (non_penalized_token_id not in prompt_tokens and \ - non_penalized_token_id not in output_tokens) - assert (penalized_token_id in prompt_tokens or \ - penalized_token_id in output_tokens) + assert (non_penalized_token_id not in prompt_tokens + and non_penalized_token_id not in output_tokens) + assert (penalized_token_id in prompt_tokens + or penalized_token_id in output_tokens) elif repetition_penalty < 1.0: # If `repetition_penalty` < 1.0, verify that the penalized # token ID has not been seen before, while the non-penalized # token ID exists either in the prompt or the output. - assert (penalized_token_id not in prompt_tokens and \ - penalized_token_id not in output_tokens) - assert (non_penalized_token_id in prompt_tokens or \ - non_penalized_token_id in output_tokens) + assert (penalized_token_id not in prompt_tokens + and penalized_token_id not in output_tokens) + assert (non_penalized_token_id in prompt_tokens + or non_penalized_token_id in output_tokens) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float): 1e-2) else: assert logits_for_req[token_id] == pytest.approx(1e-2) + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32]) +@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2]) +def test_sampler_allowed_token_ids(device: str, batch_size: int, + num_allowed_token_ids: int): + """ + Test to verify that when the repetition penalty is enabled, tokens + are penalized based on their presence in the prompt or the existing + output. + """ + torch.set_default_device(device) + # Create fake logits where each token is assigned the same + # logit value. + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) + sampling_metadata = _create_default_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + mask = _create_allowed_token_ids( + batch_size=batch_size, + vocab_size=VOCAB_SIZE, + num_allowed_token_ids=num_allowed_token_ids, + device=device, + ) + sampling_metadata.allowed_token_ids_mask = mask + sampler = Sampler() + logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata) + logits = logits.cpu() + for batch_idx in range(batch_size): + logits_for_req = logits[batch_idx] + if batch_idx % 2 == 1: + assert torch.all(logits_for_req != -float("inf")) + continue + for token_id in range(VOCAB_SIZE): + start = min(batch_idx, VOCAB_SIZE - 1) + end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1) + if token_id >= start and token_id < end: + assert logits_for_req[token_id] == -float( + "inf"), f"{batch_idx}, {token_id}" + else: + assert logits_for_req[token_id] != -float("inf") diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index cb3b3d21fbb3..0aee266264ac 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -66,6 +66,10 @@ def _construct_expected_sampling_metadata( temperature = [0.0 for _ in range(num_reqs)] min_tokens = {} logit_bias = [None] * num_reqs + allowed_token_ids_mask = torch.zeros(num_reqs, + VOCAB_SIZE, + dtype=torch.bool, + device=device) for req in reqs: if req.req_id not in req_ids_retained: continue @@ -86,6 +90,10 @@ def _construct_expected_sampling_metadata( req.sampling_params.min_tokens, req.sampling_params.all_stop_token_ids) logit_bias[index_in_input_batch] = req.sampling_params.logit_bias + if req.sampling_params.allowed_token_ids: + allowed_token_ids_mask[index_in_input_batch][ + req.sampling_params.allowed_token_ids] = True + return SamplingMetadata( temperature=torch.tensor(temperature, dtype=torch.float, device=device), @@ -121,6 +129,7 @@ def _construct_expected_sampling_metadata( and all(x == 0 for x in frequency_penalties) and all(x == 1 for x in repetition_penalties)), logit_bias=logit_bias, + allowed_token_ids_mask=allowed_token_ids_mask, ) @@ -242,3 +251,7 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: assert expected_sampling_metadata.no_penalties == \ sampling_metadata.no_penalties assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias + if sampling_metadata.allowed_token_ids_mask: + assert torch.allclose( + expected_sampling_metadata.allowed_token_ids_mask, + sampling_metadata.allowed_token_ids_mask) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index b7eee5a39972..2547cebaede7 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -83,6 +83,19 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + def _validate_allowed_token_ids( + self, + params: Union[SamplingParams, PoolingParams], + ) -> None: + if not isinstance(params, SamplingParams): + return + if params.allowed_token_ids is None: + return + if not all(0 <= tid < self.model_config.vocab_size + for tid in params.allowed_token_ids): + raise ValueError( + "allowed_token_ids contains out-of-vocab token id") + def process_inputs( self, request_id: str, @@ -100,6 +113,7 @@ def process_inputs( self._validate_logprobs(params) self._validate_lora(lora_request) + self._validate_allowed_token_ids(params) if arrival_time is None: arrival_time = time.time() diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 6d82d3a79c8e..9f7770bbd078 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -37,3 +37,7 @@ class SamplingMetadata: min_tokens: Dict[int, Tuple[int, Set[int]]] logit_bias: List[Optional[Dict[int, float]]] + + # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, + # vocab size). + allowed_token_ids_mask: Optional[torch.Tensor] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index ff978b3b6c41..47ec26d42024 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -47,6 +47,8 @@ def forward( # Use float32 for the logits. logits = logits.to(torch.float32) + # Apply allowed token ids. + logits = self.apply_allowed_token_ids(logits, sampling_metadata) # Apply logits bias. logits = self.apply_logits_bias(logits, sampling_metadata) # Apply penalties (e.g., min_tokens, freq_penalties). @@ -184,11 +186,13 @@ def apply_penalties( if not sampling_metadata.no_penalties: assert sampling_metadata.prompt_token_ids is not None logits = apply_all_penalties( - logits, sampling_metadata.prompt_token_ids, + logits, + sampling_metadata.prompt_token_ids, sampling_metadata.presence_penalties, sampling_metadata.frequency_penalties, sampling_metadata.repetition_penalties, - sampling_metadata.output_token_ids) + sampling_metadata.output_token_ids, + ) return logits def apply_min_p( @@ -226,3 +230,13 @@ def apply_logits_bias( for token_id, bias in logit_bias.items(): logits[i, token_id] += bias return logits + + def apply_allowed_token_ids( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + if sampling_metadata.allowed_token_ids_mask is not None: + logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, + float("-inf")) + return logits diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index bd1c369acb30..d9fc53490c07 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -143,7 +143,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: Set[str] = set() # Presence penalty related data structures @@ -168,7 +168,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: Set[str] = set() # req_index -> (min_tokens, stop_token_ids) @@ -192,6 +192,9 @@ def __init__( self.logit_bias: List[Optional[Dict[int, float]]] = [None] * max_num_reqs + self.has_allowed_token_ids: Set[str] = set() + self.allowed_token_ids_mask: Optional[torch.Tensor] = None + self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None self.req_output_token_ids: List[Optional[List[int]]] = [] @@ -287,6 +290,22 @@ def add_request( if sampling_params.logit_bias is not None: self.logit_bias[req_index] = sampling_params.logit_bias + if sampling_params.allowed_token_ids: + self.has_allowed_token_ids.add(req_id) + if self.allowed_token_ids_mask_cpu_tensor is None: + # Lazy allocation for this tensor, which can be large. + self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device=self.device) + self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device="cpu") + self.allowed_token_ids_mask_cpu_tensor[req_index][ + sampling_params.allowed_token_ids] = True + # Add request lora ID if request.lora_request: lora_id = request.lora_request.lora_int_id @@ -332,6 +351,9 @@ def remove_request(self, req_id: str) -> Optional[int]: self.request_lora_mapping[req_index] = 0 self.logit_bias[req_index] = None + self.has_allowed_token_ids.discard(req_id) + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) return req_index def condense(self, empty_req_indices: List[int]) -> None: @@ -400,6 +422,11 @@ def condense(self, empty_req_indices: List[int]) -> None: self.logit_bias[empty_index] = self.logit_bias[last_req_index] + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[ + empty_index] = self.allowed_token_ids_mask_cpu_tensor[ + last_req_index] + # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -442,6 +469,13 @@ def _make_sampling_metadata(self) -> SamplingMetadata: else: prompt_token_ids = None + allowed_token_ids_mask: Optional[torch.Tensor] = None + if not self.no_allowed_token_ids: + assert self.allowed_token_ids_mask is not None + copy_slice(self.allowed_token_ids_mask_cpu_tensor, + self.allowed_token_ids_mask, num_reqs) + allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] + return SamplingMetadata( temperature=temperature, all_greedy=self.all_greedy, @@ -460,6 +494,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata: min_tokens=self.min_tokens, no_penalties=self.no_penalties, logit_bias=self.logit_bias[:num_reqs], + allowed_token_ids_mask=allowed_token_ids_mask, ) def get_sampling_metadata( @@ -550,3 +585,7 @@ def max_num_logprobs(self) -> Optional[int]: @property def no_prompt_logprob(self) -> bool: return not self.num_prompt_logprobs + + @property + def no_allowed_token_ids(self) -> bool: + return len(self.has_allowed_token_ids) == 0 From 623a41407570b1dc6342c581833461e089d0cf1e Mon Sep 17 00:00:00 2001 From: Jennifer Zhao Date: Sat, 22 Feb 2025 00:08:29 -0800 Subject: [PATCH 165/317] [Bugfix] V1 Memory Profiling: V0 Sampler Integration without Rejection Sampler (#13594) Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com> Co-authored-by: Roger Wang --- vllm/v1/worker/gpu_model_runner.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d2e9c2650c7b..000b17c99b21 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -31,6 +31,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput +from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.utils import bind_kv_cache @@ -1305,11 +1306,34 @@ def profile_run(self) -> None: if get_pp_group().is_last_rank: hidden_states = hidden_states[logit_indices] logits = self.model.compute_logits(hidden_states, None) - # TODO(woosuk): Consider the memory usage of the sampler. + dummy_tensors = lambda v: torch.full( + (num_reqs, ), v, device=self.device) + dummy_metadata = SamplingMetadata( + temperature=dummy_tensors(0.5), + all_greedy=False, + all_random=False, + spec_token_ids=None, + top_p=dummy_tensors(0.9), + top_k=dummy_tensors(logits.size(1) - 1), + min_p=None, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=torch.ones_like(logits, dtype=torch.int64), + frequency_penalties=dummy_tensors(0.1), + presence_penalties=dummy_tensors(0.1), + repetition_penalties=dummy_tensors(0.1), + output_token_ids=[[] for _ in range(num_reqs)], + min_tokens={}, + logit_bias=[None for _ in range(num_reqs)]) + sampler_output = self.model.sample( + logits=logits, sampling_metadata=dummy_metadata) else: logits = None + sampler_output = None + dummy_metadata = None torch.cuda.synchronize() - del hidden_states, logits + del hidden_states, logits, sampler_output, dummy_metadata self.encoder_cache.clear() gc.collect() From 2bf686db8210fb4524d009d65aa99b6522bfb541 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sat, 22 Feb 2025 16:19:10 +0800 Subject: [PATCH 166/317] Correction to TP logic for Mamba Mixer 2 when Num Groups not divisible by TP Size (#13660) --- .../layers/mamba/mamba_mixer2.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 5fd126491023..a6a95c8da7e9 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -133,7 +133,8 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int): if ngroups % tp_size == 0: return 0 - return tp_size - ngroups % tp_size + # for n_groups == 1, this is exactly tp_size - n_groups + return tp_size - ngroups def mamba_v2_sharded_weight_loader( @@ -153,7 +154,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: boundary, loaded_boundary = 0, 0 # - iterate over the shard specs - for full_dim, extra, ratio in shard_spec: + for full_dim, extra, duplicate_groups in shard_spec: # - full dim is the model dim (before TP). # - extra > 0, means there is expected overall increase # of dimensions. This is so because of replication. @@ -167,7 +168,12 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # - compute the rank into the loaded shard. # - if there is replication, different TP shards will # take from the same rank. - rank = tp_rank // ratio + if duplicate_groups: + # NOTE: currently we only support duplication + # in the case where num_groups == 1 + rank = 0 + else: + rank = tp_rank # - leftmost boundary index into loaded weight. loaded_skip = rank * shard_size @@ -233,12 +239,21 @@ def __init__(self, # - HOWEVER IF, world_size DOES NOT divide groups, then we need # to allocate extra space in the shard, such that groups # may be replicated to follow the head shard. + # - NOTE: currently for the world size DOES NOT divide groups + # case, we only support the case when n_groups == 1 self.tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() assert num_heads % self.tp_size == 0, \ "Tensor parallel world size must divide num heads." + + assert (n_groups % self.tp_size) == 0 or n_groups == 1, \ + ( + "If tensor parallel world size does not divide num_heads, " + "then num_groups must equal 1." + ) + self.ssm_state_size = ssm_state_size self.activation = activation @@ -284,11 +299,10 @@ def __init__(self, self.n_groups * self.ssm_state_size, # expected model size (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned - self.num_heads // - n_groups, # ratio for mapping back to original group + n_groups == 1, # if there was only one group ) - intermediate_settings = (intermediate_size, 0, 1) - head_setings = (self.num_heads, 0, 1) + intermediate_settings = (intermediate_size, 0, False) + head_setings = (self.num_heads, 0, False) # - the weight already has a "weight_loader" attribute # which set_weight_attrs will raise if we do not From c4dbd887cb52803112220aa3ff3fa91a0be0f043 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Sat, 22 Feb 2025 08:20:00 +0000 Subject: [PATCH 167/317] [V1][Metrics] Support `vllm:cache_config_info` (#13299) --- tests/entrypoints/openai/test_metrics.py | 1 + vllm/config.py | 6 ++++++ vllm/engine/metrics.py | 5 ++--- vllm/engine/metrics_types.py | 10 ++-------- vllm/v1/metrics/loggers.py | 22 +++++++++++++++++++++- 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 45a387a14adf..e0323abe2525 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -230,6 +230,7 @@ async def test_metrics_counts(server: RemoteOpenAIServer, "vllm:prompt_tokens_total", "vllm:generation_tokens_total", "vllm:iteration_tokens_total", + "vllm:cache_config_info", "vllm:request_success_total", "vllm:request_prompt_tokens_sum", "vllm:request_prompt_tokens_bucket", diff --git a/vllm/config.py b/vllm/config.py index d6e197fe988a..dbcacdf4d955 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -88,6 +88,12 @@ def compute_hash(self) -> str: ... +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> Dict[str, str]: + ... + + class ModelImpl(str, enum.Enum): AUTO = "auto" VLLM = "vllm" diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 7c55d66e5077..e8736dffc446 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -8,9 +8,8 @@ import numpy as np import prometheus_client -from vllm.config import VllmConfig -from vllm.engine.metrics_types import (StatLoggerBase, Stats, - SupportsMetricsInfo) +from vllm.config import SupportsMetricsInfo, VllmConfig +from vllm.engine.metrics_types import StatLoggerBase, Stats from vllm.executor.ray_utils import ray from vllm.logger import init_logger diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 7f0c2fa70c3f..9e6d5ef29bed 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -15,9 +15,9 @@ import time from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Optional, Protocol +from typing import List, Optional -from vllm.config import VllmConfig +from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -70,12 +70,6 @@ class Stats: spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None -class SupportsMetricsInfo(Protocol): - - def metrics_info(self) -> Dict[str, str]: - ... - - class StatLoggerBase(ABC): """Base class for StatLogger.""" diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 5019e2b3f92a..e112a9f36e68 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -7,7 +7,7 @@ import numpy as np import prometheus_client -from vllm.config import VllmConfig +from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason @@ -228,6 +228,26 @@ def __init__(self, vllm_config: VllmConfig): buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) + self.log_metrics_info("cache_config", vllm_config.cache_config) + + def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): + metrics_info = config_obj.metrics_info() + + name, documentation = None, None + if type == "cache_config": + name = "vllm:cache_config_info" + documentation = "Information of the LLMEngine CacheConfig" + assert name is not None, f"Unknown metrics info type {type}" + + # Info type metrics are syntactic sugar for a gauge permanently set to 1 + # Since prometheus multiprocessing mode does not support Info, emulate + # info here with a gauge. + info_gauge = prometheus_client.Gauge( + name=name, + documentation=documentation, + labelnames=metrics_info.keys()).labels(**metrics_info) + info_gauge.set(1) + def log(self, scheduler_stats: SchedulerStats, iteration_stats: IterationStats): """Log to prometheus.""" From 35380d40dbf14a9b11adc915addfa342d353d6f9 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Sat, 22 Feb 2025 08:20:45 +0000 Subject: [PATCH 168/317] [Metrics] Add `--show-hidden-metrics-for-version` CLI arg (#13295) --- docs/source/serving/metrics.md | 8 ++++++++ tests/test_version.py | 36 ++++++++++++++++++++++++++++++++++ vllm/config.py | 4 +++- vllm/engine/arg_utils.py | 20 +++++++++++++++++++ vllm/engine/metrics.py | 5 +++++ vllm/v1/metrics/loggers.py | 5 +++++ vllm/version.py | 18 +++++++++++++++++ 7 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 tests/test_version.py diff --git a/docs/source/serving/metrics.md b/docs/source/serving/metrics.md index 6c0dc8880a90..1d55f201503c 100644 --- a/docs/source/serving/metrics.md +++ b/docs/source/serving/metrics.md @@ -36,3 +36,11 @@ The following metrics are exposed: :language: python :start-after: begin-metrics-definitions ::: + +The following metrics are deprecated and due to be removed in a future version: + +- *(No metrics are currently deprecated)* + +Note: when metrics are deprecated in version `X.Y`, they are hidden in version `X.Y+1` +but can be re-enabled using the `--show-hidden-metrics-for-version=X.Y` escape hatch, +and are then removed in version `X.Y+2`. diff --git a/tests/test_version.py b/tests/test_version.py new file mode 100644 index 000000000000..56842b6d409d --- /dev/null +++ b/tests/test_version.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import patch + +import pytest + +from vllm import version + + +def test_version_is_defined(): + assert version.__version__ is not None + + +def test_version_tuple(): + assert len(version.__version_tuple__) in (3, 4, 5) + + +@pytest.mark.parametrize( + "version_tuple, version_str, expected", + [ + ((0, 0, "dev"), "0.0", True), + ((0, 0, "dev"), "foobar", True), + ((0, 7, 4), "0.6", True), + ((0, 7, 4), "0.5", False), + ((0, 7, 4), "0.7", False), + ((1, 2, 3), "1.1", True), + ((1, 2, 3), "1.0", False), + ((1, 2, 3), "1.2", False), + # This won't work as expected + ((1, 0, 0), "1.-1", True), + ((1, 0, 0), "0.9", False), + ((1, 0, 0), "0.17", False), + ]) +def test_prev_minor_version_was(version_tuple, version_str, expected): + with patch("vllm.version.__version_tuple__", version_tuple): + assert version._prev_minor_version_was(version_str) == expected diff --git a/vllm/config.py b/vllm/config.py index dbcacdf4d955..797697aac12d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2653,7 +2653,9 @@ def __post_init__(self): @dataclass class ObservabilityConfig: - """Configuration for observability.""" + """Configuration for observability - metrics and tracing.""" + show_hidden_metrics: bool = False + otlp_traces_endpoint: Optional[str] = None # Collecting detailed timing information for each request can be expensive. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8b460b33e235..d75e2324f5c7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -10,6 +10,7 @@ import torch import vllm.envs as envs +from vllm import version from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, DecodingConfig, DeviceConfig, HfOverrides, KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, @@ -188,6 +189,7 @@ class EngineArgs: qlora_adapter_name_or_path: Optional[str] = None disable_logprobs_during_spec_decoding: Optional[bool] = None + show_hidden_metrics_for_version: Optional[str] = None otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False @@ -909,6 +911,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, help='Name or path of the QLoRA adapter.') + parser.add_argument('--show-hidden-metrics-for-version', + type=str, + default=None, + help='Enable deprecated Prometheus metrics that ' + 'have been hidden since the specified version. ' + 'For example, if a previously deprecated metric ' + 'has been hidden since the v0.7.0 release, you ' + 'use --show-hidden-metrics-for-version=0.7 as a ' + 'temporary escape hatch while you migrate to new ' + 'metrics. The metric is likely to be removed ' + 'completely in an upcoming release.') + parser.add_argument( '--otlp-traces-endpoint', type=str, @@ -1317,6 +1331,11 @@ def create_engine_config(self, decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) + show_hidden_metrics = False + if self.show_hidden_metrics_for_version is not None: + show_hidden_metrics = version._prev_minor_version_was( + self.show_hidden_metrics_for_version) + detailed_trace_modules = [] if self.collect_detailed_traces is not None: detailed_trace_modules = self.collect_detailed_traces.split(",") @@ -1326,6 +1345,7 @@ def create_engine_config(self, f"Invalid module {m} in collect_detailed_traces. " f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}") observability_config = ObservabilityConfig( + show_hidden_metrics=show_hidden_metrics, otlp_traces_endpoint=self.otlp_traces_endpoint, collect_model_forward_time="model" in detailed_trace_modules or "all" in detailed_trace_modules, diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index e8736dffc446..cb3ca7a11881 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -516,6 +516,11 @@ def __init__(self, local_interval: float, labels: Dict[str, str], self.metrics = self._metrics_cls(labelnames=list(labels.keys()), vllm_config=vllm_config) + # Use this flag to hide metrics that were deprecated in + # a previous release and which will be removed future + self.show_hidden_metrics = \ + vllm_config.observability_config.show_hidden_metrics + def _log_gauge(self, gauge, data: Union[int, float]) -> None: # Convenience function for logging to gauge. gauge.labels(**self.labels).set(data) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index e112a9f36e68..e562b4145afc 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -95,6 +95,11 @@ class PrometheusStatLogger(StatLoggerBase): def __init__(self, vllm_config: VllmConfig): self._unregister_vllm_metrics() + # Use this flag to hide metrics that were deprecated in + # a previous release and which will be removed future + self.show_hidden_metrics = \ + vllm_config.observability_config.show_hidden_metrics + labelnames = ["model_name"] labelvalues = [vllm_config.model_config.served_model_name] diff --git a/vllm/version.py b/vllm/version.py index 70cd0289b441..ab5909b101a0 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -11,3 +11,21 @@ __version__ = "dev" __version_tuple__ = (0, 0, __version__) + + +def _prev_minor_version_was(version_str): + """Check whether a given version matches the previous minor version. + + Return True if version_str matches the previous minor version. + + For example - return True if the current version if 0.7.4 and the + supplied version_str is '0.6'. + + Used for --show-hidden-metrics-for-version. + """ + # Match anything if this is a dev tree + if __version_tuple__[0:2] == (0, 0): + return True + + # Note - this won't do the right thing when we release 1.0! + return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}" From e4f5b9c27776004f5c06979c3eb2021f1df7348d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 22 Feb 2025 16:21:30 +0800 Subject: [PATCH 169/317] [Misc] Reduce LoRA-related static variable (#13166) --- tests/lora/conftest.py | 17 +++++++-- tests/lora/test_lora_checkpoints.py | 13 ++++--- tests/lora/test_lora_huggingface.py | 7 ++-- tests/lora/test_lora_manager.py | 26 +++++--------- vllm/lora/models.py | 21 +++++------ vllm/lora/utils.py | 26 ++++++++++++++ vllm/lora/worker_manager.py | 8 +++-- vllm/model_executor/models/baichuan.py | 9 ----- vllm/model_executor/models/bamba.py | 6 ---- vllm/model_executor/models/chatglm.py | 10 ------ vllm/model_executor/models/commandr.py | 4 --- vllm/model_executor/models/exaone.py | 8 ----- vllm/model_executor/models/gemma.py | 12 ------- vllm/model_executor/models/gemma2.py | 11 ------ vllm/model_executor/models/glm4v.py | 15 -------- vllm/model_executor/models/gpt_bigcode.py | 5 +-- vllm/model_executor/models/granite.py | 4 --- vllm/model_executor/models/granitemoe.py | 7 ---- vllm/model_executor/models/idefics3.py | 15 -------- vllm/model_executor/models/interfaces.py | 12 +++---- vllm/model_executor/models/internlm2.py | 10 ------ vllm/model_executor/models/jamba.py | 4 --- vllm/model_executor/models/llama.py | 4 --- vllm/model_executor/models/minicpm.py | 8 ----- vllm/model_executor/models/minicpm3.py | 16 --------- vllm/model_executor/models/minicpmv.py | 42 ---------------------- vllm/model_executor/models/mixtral.py | 4 --- vllm/model_executor/models/molmo.py | 20 ----------- vllm/model_executor/models/nemotron.py | 3 -- vllm/model_executor/models/phi.py | 11 ------ vllm/model_executor/models/phimoe.py | 10 ------ vllm/model_executor/models/qwen.py | 9 ----- vllm/model_executor/models/qwen2.py | 20 ----------- vllm/model_executor/models/qwen2_5_vl.py | 21 ----------- vllm/model_executor/models/qwen2_rm.py | 10 ------ vllm/model_executor/models/qwen2_vl.py | 18 ---------- vllm/model_executor/models/qwen_vl.py | 15 -------- vllm/model_executor/models/solar.py | 8 ----- vllm/model_executor/models/transformers.py | 35 ++++++++++++++++++ vllm/model_executor/models/ultravox.py | 8 ----- vllm/worker/hpu_model_runner.py | 3 -- 41 files changed, 120 insertions(+), 395 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 47c89d5fd344..489ffc7d3257 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -23,6 +23,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.platforms import current_platform @@ -98,9 +99,13 @@ def dist_init_torch_only(): backend=backend) +class DummyLoRAModel(nn.Sequential, SupportsLoRA): + pass + + @pytest.fixture def dummy_model() -> nn.Module: - model = nn.Sequential( + model = DummyLoRAModel( OrderedDict([ ("dense1", ColumnParallelLinear(764, 100)), ("dense2", RowParallelLinear(100, 50)), @@ -121,12 +126,13 @@ def dummy_model() -> nn.Module: ("sampler", Sampler()) ])) model.config = MagicMock() + model.embedding_modules = {"lm_head": "lm_head"} return model @pytest.fixture def dummy_model_gate_up() -> nn.Module: - model = nn.Sequential( + model = DummyLoRAModel( OrderedDict([ ("dense1", ColumnParallelLinear(764, 100)), ("dense2", RowParallelLinear(100, 50)), @@ -147,6 +153,13 @@ def dummy_model_gate_up() -> nn.Module: ("sampler", Sampler()) ])) model.config = MagicMock() + model.packed_modules_mapping = { + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + model.embedding_modules = {"lm_head": "lm_head"} return model diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index d2a4b901bd8d..e2c3d20d327f 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -12,6 +12,12 @@ lora_lst = [ "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b" ] +BAICHUAN_LORA_MODULES = [ + "W_pack", + "o_proj", + "gate_up_proj", + "down_proj", +] @pytest.mark.parametrize("lora_name", lora_lst) @@ -22,12 +28,11 @@ def test_load_checkpoints( baichuan_regex_lora_files, chatglm3_lora_files, ): - supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules expected_lora_modules: List[str] = [] - for module in supported_lora_modules: + for module in BAICHUAN_LORA_MODULES: if module in packed_modules_mapping: expected_lora_modules.extend(packed_modules_mapping[module]) else: @@ -90,12 +95,12 @@ def test_load_checkpoints( def test_lora_weights_mapping(baichuan_lora_files): - supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules + packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules expected_lora_modules: List[str] = [] - for module in supported_lora_modules: + for module in BAICHUAN_LORA_MODULES: if module in packed_modules_mapping: expected_lora_modules.extend(packed_modules_mapping[module]) else: diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index 273fe9ae0eb5..44d111732d2a 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -11,17 +11,20 @@ # Provide absolute path and huggingface lora ids lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"] +LLAMA_LORA_MODULES = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" +] @pytest.mark.parametrize("lora_fixture_name", lora_fixture_name) def test_load_checkpoints_from_huggingface(lora_fixture_name, request): lora_name = request.getfixturevalue(lora_fixture_name) - supported_lora_modules = LlamaForCausalLM.supported_lora_modules packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping embedding_modules = LlamaForCausalLM.embedding_modules embed_padding_modules = LlamaForCausalLM.embedding_padding_modules expected_lora_modules: List[str] = [] - for module in supported_lora_modules: + for module in LLAMA_LORA_MODULES: if module in packed_modules_mapping: expected_lora_modules.extend(packed_modules_mapping[module]) else: diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 9fecd11f57af..7ab46b7ff9c9 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -19,7 +19,6 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, WorkerLoRAManager) -from vllm.model_executor.layers.linear import RowParallelLinear from vllm.platforms import current_platform EMBEDDING_MODULES = { @@ -114,19 +113,16 @@ def create_packed_lora( def test_replace_submodules(dist_init, dummy_model): model = dummy_model - model.supported_lora_modules = ["dense1", "layer1.dense2"] - model.packed_modules_mapping = {} manager = LoRAModelManager( model, 1, 1, 1, LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8), torch.device(DEVICES[0])) model = manager.model - assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA) assert isinstance(model.get_submodule("layer1.dense1"), ColumnParallelLinearWithLoRA) - assert isinstance(model.get_submodule("dense2"), RowParallelLinear) + assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA) assert isinstance(model.get_submodule("layer1.dense2"), RowParallelLinearWithLoRA) @@ -134,8 +130,6 @@ def test_replace_submodules(dist_init, dummy_model): @pytest.mark.parametrize("device", DEVICES) def test_lora_model_manager(dist_init, dummy_model, device): model = dummy_model - model.supported_lora_modules = ["dense1", "dense2", "lm_head"] - model.packed_modules_mapping = {} model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"], device=device) @@ -190,13 +184,18 @@ def test_lora_model_manager(dist_init, dummy_model, device): assert manager.device == device assert manager.punica_wrapper.device == device + assert hasattr(manager, "supported_lora_modules") + assert sorted(manager.supported_lora_modules) == [ + "dense1", + "dense2", + "lm_head", + "output", + ] @pytest.mark.parametrize("device", DEVICES) def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): model = dummy_model - model.supported_lora_modules = ["dense1", "dense2", "lm_head"] - model.packed_modules_mapping = {} model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"], device=device) @@ -289,8 +288,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): # This tests just the LRU cache functionality, everything else is # tested in test_lora_model_manager model = dummy_model - model.supported_lora_modules = ["dense1", "dense2", "lm_head"] - model.packed_modules_mapping = {} model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"], device=device) @@ -572,13 +569,6 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, @pytest.mark.parametrize("device", DEVICES) def test_packed_loras(dist_init, dummy_model_gate_up, device): model = dummy_model_gate_up - model.supported_lora_modules = ["gate_up_proj"] - model.packed_modules_mapping = { - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } model_lora = create_packed_lora( 1, model, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index b7403980d0b0..eb53513a2830 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -26,6 +26,7 @@ from vllm.lora.peft_helper import PEFTHelper from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, + get_supported_lora_modules, is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models import SupportsLoRA, supports_multimodal @@ -332,15 +333,15 @@ def __init__( # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} super().__init__(model) - if hasattr(self.model, "supported_lora_modules"): - self.supported_lora_modules = copy.deepcopy( - self.model.supported_lora_modules) - if lora_config.long_lora_scaling_factors: - # We need to replace rotary emb layer to do batch computation - # for long lora. - self.supported_lora_modules.append("rotary_emb") - self.packed_modules_mapping = copy.deepcopy( - self.model.packed_modules_mapping) + self.supported_lora_modules = get_supported_lora_modules(self.model) + assert self.supported_lora_modules, "No supported LoRA modules found in" + f"{self.model.__class__.__name__}." + if lora_config.long_lora_scaling_factors: + # We need to replace rotary emb layer to do batch computation + # for long lora. + self.supported_lora_modules.append("rotary_emb") + self.packed_modules_mapping = copy.deepcopy( + self.model.packed_modules_mapping) # Used to indicate whether the model is a multimodal model self.supports_mm: bool = ( supports_multimodal(self.model) @@ -756,7 +757,7 @@ def create_lora_manager( lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, **kwargs) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" - if not hasattr(model, "supported_lora_modules"): + if not hasattr(model, "packed_modules_mapping"): raise ValueError(f"Model {type(model)} is not supported for LoRA.") lora_manager = lora_manager_cls( model=model, diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index f47b0af15522..361dac5b3313 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -29,6 +29,7 @@ ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) +from vllm.model_executor.layers.linear import LinearBase # yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead @@ -68,6 +69,14 @@ def from_layer(layer: nn.Module, ret = lora_cls(layer) ret.create_lora_weights(max_loras, lora_config, model_config) return ret + + # The Case for HFCompatibleLinear + if (hasattr(layer, "get_lora_class") + and layer.__class__.__name__ == "HFCompatibleLinear"): + lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras) + ret = lora_cls(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret return layer @@ -170,6 +179,23 @@ def is_subset(sub_list, full_list): return False +def get_supported_lora_modules(model: nn.Module) -> List[str]: + """ + In vLLM, all linear layers support LoRA. + """ + supported_lora_modules: Set[str] = set() + # step1: traverse the model to get all the linear subfixes. + for name, module in model.named_modules(): + if isinstance(module, (LinearBase, )): + supported_lora_modules.add(name.split(".")[-1]) + # step 2: get the embedding modules if the model's mbedding_modules + # is not empty. + if model.embedding_modules: + for name in model.embedding_modules: + supported_lora_modules.add(name) + return list(supported_lora_modules) + + def get_adapter_absolute_path(lora_path: str) -> str: """ Resolves the given lora_path to an absolute local path. diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index b103acefe4aa..108beb34b244 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -84,9 +84,10 @@ def create_lora_manager( def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: - model = self._adapter_manager.model - supported_lora_modules = model.supported_lora_modules - packed_modules_mapping = model.packed_modules_mapping + supported_lora_modules = ( + self._adapter_manager.supported_lora_modules) + packed_modules_mapping = ( + self._adapter_manager.packed_modules_mapping) expected_lora_modules: List[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: @@ -107,6 +108,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: # For some models like Qwen2VL, we need to use hf_to_vllm_mapper # to ensure correct loading of lora weights. + model = self._adapter_manager.model hf_to_vllm_mapper = None if (hasattr(model, "hf_to_vllm_mapper") and model.hf_to_vllm_mapper is not None): diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 5dfaa727b75a..b613b70a7564 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -342,15 +342,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "up_proj", ], } - # LoRA specific attributes - supported_lora_modules = [ - "W_pack", - "o_proj", - "gate_up_proj", - "down_proj", - ] - embedding_modules = {} - embedding_padding_modules = [] def __init__( self, diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index b9310108543c..22ae1775c3d9 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -389,12 +389,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 26b4a95c530e..ecf417655452 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -477,16 +477,6 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP): "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"] } - # LoRA specific attributes - supported_lora_modules = [ - "query_key_value", - "dense", - "dense_h_to_4h", - "dense_4h_to_h", - ] - - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index e73627da05d4..0ceefc3e93aa 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -357,11 +357,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ], } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens" - ] embedding_modules = {"embed_tokens": "input_embeddings"} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 2eb91a682242..e795c7e288c4 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -415,14 +415,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "out_proj", - "gate_up_proj", - "c_proj", - "wte", - "lm_head", - ] embedding_modules = { "wte": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index cb81aa41e254..d0589e60a72b 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -344,18 +344,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - - # Gemma does not apply LoRA to the embedding layer. - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index a6dc8f84772b..6ee257d65c50 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -390,17 +390,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - # Gemma does not apply LoRA to the embedding layer. - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 40010ec55906..8fc5a797f824 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -534,21 +534,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, "dense_h_to_4h": ["dense_h_to_4h"], "merged_proj": ["gate_proj", "dense_h_to_4h"] } - # LoRA specific attributes - supported_lora_modules = [ - "query_key_value", - "dense", - "dense_h_to_4h", - "dense_4h_to_h", - # vision - "fc1", - "fc2", - "merged_proj", - "linear_proj" - ] - - embedding_modules = {} - embedding_padding_modules = [] def get_mm_mapping(self) -> MultiModelKeys: """ diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 887a444748ae..799edff46ea3 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -261,15 +261,12 @@ def forward( class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = {"c_attn": ["c_attn"]} - supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"] - + # LoRA specific attributes embedding_modules = { "wte": "input_embeddings", "lm_head": "output_embeddings", } - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 85911a0f41c2..2aeb179ee932 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -351,10 +351,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", - "lm_head" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 8ae661bf15c4..40df9c72c561 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -329,13 +329,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - "layer", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 579253632c81..3a7e2a9a6a57 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -597,21 +597,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, "up_proj", ], } - # LoRA specific attributes - supported_lora_modules = [ - # vision_model - "fc1", - "fc2", - "out_proj", - # text_model - "qkv_proj", # same name with vision encoder - "o_proj", - "gate_up_proj", - "down_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index bd6661d668d9..47bd05f140c8 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -118,11 +118,11 @@ class SupportsLoRA(Protocol): There is no need to redefine this flag if this class is in the MRO of your model class. """ - - packed_modules_mapping: ClassVar[Dict[str, List[str]]] - supported_lora_modules: ClassVar[List[str]] - embedding_modules: ClassVar[Dict[str, str]] - embedding_padding_modules: ClassVar[List[str]] + # The `embedding_module` and `embedding_padding_modules` + # are empty by default. + embedding_modules: ClassVar[Dict[str, str]] = {} + embedding_padding_modules: ClassVar[List[str]] = [] + packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} # We can't use runtime_checkable with ClassVar for issubclass checks @@ -132,7 +132,6 @@ class _SupportsLoRAType(Protocol): supports_lora: Literal[True] packed_modules_mapping: Dict[str, List[str]] - supported_lora_modules: List[str] embedding_modules: Dict[str, str] embedding_padding_modules: List[str] @@ -155,7 +154,6 @@ def supports_lora( if not result: lora_attrs = ( "packed_modules_mapping", - "supported_lora_modules", "embedding_modules", "embedding_padding_modules", ) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index c211ca5f4f8e..b21933dd5da7 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -329,16 +329,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): "gate_up_proj": ["w1", "w3"], } - # LoRA specific attributes - supported_lora_modules = [ - "wqkv", - "wo", - "gate_up_proj", - "w2", - ] - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index efc1496d44f0..5530e3ca708c 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -380,10 +380,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj", - "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2ff52dd78912..011d0a7aafaa 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -452,10 +452,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", - "lm_head" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings" diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 29473f5bbaa0..52ab89488785 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -522,14 +522,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 878f0c895c34..b85306c40880 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -227,21 +227,5 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM): ], } - # LoRA specific attributes - supported_lora_modules = [ - "kv_a_proj_with_mqa", - "q_a_proj", - "q_b_proj", - "kv_b_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] - - # `embedding_modules` and `embedding_padding_modules` - # are inherited from MiniCPMForCausalLM - def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 97596f9e82c6..1f278b65740c 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1228,23 +1228,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): "up_proj", ], } - # LoRA specific attributes - supported_lora_modules = [ - # vision encoder - "fc1", - "fc2", - "out_proj", - # language model - "qkv_proj", # same name with vision encoder - "o_proj", - "gate_up_proj", - "down_proj", - # resampler - "kv_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -1338,23 +1321,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): "up_proj", ], } - # LoRA specific attributes - supported_lora_modules = [ - # vision encoder - "fc1", - "fc2", - "out_proj", - # language model - "qkv_proj", # same name with vision encoder - "o_proj", - "gate_up_proj", - "down_proj", - # resampler - "kv_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -1460,13 +1426,6 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): which is not conducive to the current integration logic of LoRA and bitsandbytes in vLLM. Therefore, it is necessary to separate them. """ - # Ensure that the LoRA support check passes when the class is not - # initialized, but set all these attributes to empty. - # These will be updated when an instance class is selected - packed_modules_mapping = {} - supported_lora_modules = [] - embedding_modules = {} - embedding_padding_modules = [] def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config @@ -1487,7 +1446,6 @@ def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): # quant_config references base class members, # so update values before init is called cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) - cls.supported_lora_modules += instance_cls.supported_lora_modules cls.embedding_modules.update(instance_cls.embedding_modules) cls.embedding_padding_modules += instance_cls.embedding_padding_modules return instance_cls(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 70880eb75224..b83b69fd2c2d 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -332,10 +332,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "embed_tokens", "lm_head", "w1", "w2", "w3", - "gate" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 1d84d25c96ac..6ce9fbda182f 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1440,26 +1440,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, "merged_linear": ["gate_proj", "up_proj"] # image_projector } - # LoRA specific attributes - supported_lora_modules = [ - # language model - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", # same name with image_projector - # vision tower - "wq", - "wk", - "wv", - "wo", - "w1", - "w2", - # image_projector - "merged_linear", - ] - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 6f0b831ac272..a42734edb39a 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -389,9 +389,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "up_proj", "down_proj", "embed_tokens", "lm_head" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 6b05bfee9492..1ca8cad22ad9 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -273,17 +273,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ] } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "dense", - "fc1", - "fc2", - ] - - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index aa4bb52c444f..17369cb58e36 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -526,16 +526,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - "w1", - "w2", - "w3", - "gate", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a45e9463ab67..7c4627036203 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -354,15 +354,6 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA): "w1", ], } - # LoRA specific attributes - supported_lora_modules = [ - "c_attn", - "gate_up_proj", - "c_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e3de6b64fbb3..7da6e558ff33 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -430,16 +430,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -528,16 +518,6 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - embedding_modules = {} - embedding_padding_modules = [] - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index ff10fcb4315c..ef31f18445fd 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -734,27 +734,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "up_proj", ], } - # LoRA specific attributes - supported_lora_modules = [ - # language model - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", # Same name with vision encoder - # vision tower - "qkv", - "gate_proj", - "up_proj", - "attn.proj", # Distinguish patch_embed.proj - "fc1", - "fc2", - # projector - "mlp.0", - "mlp.2" - ] - - embedding_modules = {} - embedding_padding_modules = [] # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 00e4159e28cf..c6588a47d881 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -47,16 +47,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 919445267f4a..31701abd3339 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1048,24 +1048,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - # vision tower - "qkv", - "attn.proj", # Distinguish patch_embed.proj - "fc1", - "fc2", - # projector - "mlp.0", - "mlp.2" - ] - embedding_modules = {} - embedding_padding_modules = [] - # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 61a4584abf85..56faa390fc5d 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -667,21 +667,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, "w1", ], } - # LoRA specific attributes - supported_lora_modules = [ - "c_attn", - "gate_up_proj", - "c_proj", - # visual module - "out_proj", - "in_proj", - "c_fc", - # resampler - "kv_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] def get_mm_mapping(self) -> MultiModelKeys: """ diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 6215ed814bf4..ad98f3b07034 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -386,14 +386,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 9b456b248952..b431abb76b69 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -27,6 +27,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.utils import divide from vllm.logger import init_logger +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA) +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA) from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -103,6 +108,23 @@ def replace_linear_class( "rowwise": RowParallelLinear, }.get(style, ReplicatedLinear) + lora_linear_cls = { + ColumnParallelLinear: { + True: ColumnParallelLinearWithShardedLoRA, # fully sharded + False: ColumnParallelLinearWithLoRA # not fully sharded + }, + RowParallelLinear: { + True: RowParallelLinearWithShardedLoRA, + False: RowParallelLinearWithLoRA + }, + # ReplicatedLinear doesn't support fully sharded LoRA yet, + # so we use the same class for both cases. + ReplicatedLinear: { + True: ReplicatedLinearWithLoRA, + False: ReplicatedLinearWithLoRA + } + } + class HFCompatibleLinear(vllm_linear_cls): """ Wrapper class that removes `output_bias` from returned output. @@ -111,6 +133,19 @@ class HFCompatibleLinear(vllm_linear_cls): def forward(self, input: torch.Tensor) -> torch.Tensor: return super().forward(input)[0] + @classmethod + def get_lora_class(cls, fully_sharded: bool = False): + """ + Get the LoRA class corresponding to the current transformer + linear class. + + Args: + fully_sharded (bool): If True, select the LoRA class variant + that supports fully sharded LoRA. Defaults to False. + + """ + return lora_linear_cls[vllm_linear_cls][fully_sharded] + return HFCompatibleLinear( input_size=linear.in_features, output_size=linear.out_features, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index e24b4aeb8ae8..b99094e5d4ca 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -360,14 +360,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): "gate_up_proj": ["gate_proj", "up_proj"] } - # LoRA specific attributes - # TODO : Add LoRA to the audio tower and projector. - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj" - ] - embedding_modules = {} - embedding_padding_modules = [] - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index fe7c776d0a23..f22526cfad70 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -650,9 +650,6 @@ def load_model(self) -> None: logger.info(msg) if self.lora_config: - assert hasattr(self.model, "supported_lora_modules" - ) and self.model.supported_lora_modules, ( - "Model does not support LoRA") assert hasattr(self.model, "embedding_modules" ), "Model does not have embedding_modules" assert hasattr( From 439a0ead244c0c6038cbf9ecdc8de023c67e6159 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 22 Feb 2025 16:31:26 +0800 Subject: [PATCH 170/317] [CI/Build] Fix pre-commit errors (#13696) --- benchmarks/benchmark_latency.py | 6 +++--- vllm/entrypoints/openai/serving_engine.py | 2 +- vllm/entrypoints/openai/serving_score.py | 7 ++++++- vllm/model_executor/layers/mamba/mamba_mixer2.py | 13 +++++-------- vllm/utils.py | 6 ++++-- vllm/v1/worker/gpu_model_runner.py | 7 +++++-- 6 files changed, 24 insertions(+), 17 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index b1d68ea24694..71ec909cba48 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -43,9 +43,9 @@ def main(args: argparse.Namespace): # the engine will automatically process the request in multiple batches. llm = LLM(**dataclasses.asdict(engine_args)) assert llm.llm_engine.model_config.max_model_len >= ( - args.input_len + args.output_len), ( - "Please ensure that max_model_len is greater than" - " the sum of input_len and output_len.") + args.input_len + + args.output_len), ("Please ensure that max_model_len is greater than" + " the sum of input_len and output_len.") sampling_params = SamplingParams( n=args.n, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 05b5f95a5e59..d097bfcfc5ab 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -523,7 +523,7 @@ def _get_decoded_token(logprob: Logprob, return logprob.decoded_token return tokenizer.decode(token_id) - def _is_model_supported(self, model_name) -> bool: + def _is_model_supported(self, model_name: Optional[str]) -> bool: if not model_name: return True return self.models.is_base_model(model_name) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 01e2d3043610..a087a8d9ba0f 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -358,7 +358,12 @@ async def do_rerank( request.truncate_prompt_tokens, ) return self.request_output_to_rerank_response( - final_res_batch, request_id, self._get_model_name(request.model), documents, top_n) + final_res_batch, + request_id, + self._get_model_name(request.model), + documents, + top_n, + ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") except ValueError as e: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index a6a95c8da7e9..2bcf50e70713 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -134,7 +134,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int): return 0 # for n_groups == 1, this is exactly tp_size - n_groups - return tp_size - ngroups + return tp_size - ngroups def mamba_v2_sharded_weight_loader( @@ -168,12 +168,9 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # - compute the rank into the loaded shard. # - if there is replication, different TP shards will # take from the same rank. - if duplicate_groups: - # NOTE: currently we only support duplication - # in the case where num_groups == 1 - rank = 0 - else: - rank = tp_rank + # NOTE: currently we only support duplication + # in the case where num_groups == 1 + rank = 0 if duplicate_groups else tp_rank # - leftmost boundary index into loaded weight. loaded_skip = rank * shard_size @@ -247,7 +244,7 @@ def __init__(self, assert num_heads % self.tp_size == 0, \ "Tensor parallel world size must divide num heads." - + assert (n_groups % self.tp_size) == 0 or n_groups == 1, \ ( "If tensor parallel world size does not divide num_heads, " diff --git a/vllm/utils.py b/vllm/utils.py index dcafd5411bbb..25a3bdc6daff 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1198,10 +1198,12 @@ def check_port(self, value): try: value = int(value) except ValueError: - raise argparse.ArgumentTypeError("Port must be an integer") + msg = "Port must be an integer" + raise argparse.ArgumentTypeError(msg) from None if not (1024 <= value <= 65535): - raise argparse.ArgumentTypeError("Port must be between 1024 and 65535") + raise argparse.ArgumentTypeError( + "Port must be between 1024 and 65535") return value diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 000b17c99b21..0d76b1a35c74 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1319,13 +1319,16 @@ def profile_run(self) -> None: generators={}, max_num_logprobs=None, no_penalties=True, - prompt_token_ids=torch.ones_like(logits, dtype=torch.int64), + prompt_token_ids=torch.ones_like(logits, + dtype=torch.int64), frequency_penalties=dummy_tensors(0.1), presence_penalties=dummy_tensors(0.1), repetition_penalties=dummy_tensors(0.1), output_token_ids=[[] for _ in range(num_reqs)], min_tokens={}, - logit_bias=[None for _ in range(num_reqs)]) + logit_bias=[None for _ in range(num_reqs)], + allowed_token_ids_mask=None, + ) sampler_output = self.model.sample( logits=logits, sampling_metadata=dummy_metadata) else: From f1a809e4883bc4726d4ae84befc8c67819f7d30e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 22 Feb 2025 19:28:59 +0800 Subject: [PATCH 171/317] [core] set up data parallel communication (#13591) Signed-off-by: youkaichao --- .buildkite/test-pipeline.yaml | 2 + examples/offline_inference/data_parallel.py | 76 ++++++++++++++++ vllm/config.py | 57 ++++++++++++ .../device_communicators/cuda_communicator.py | 4 +- .../device_communicators/custom_all_reduce.py | 11 ++- vllm/distributed/parallel_state.py | 76 +++++++++++++--- vllm/distributed/utils.py | 91 ++++++++++++++++++- vllm/envs.py | 20 ++++ vllm/forward_context.py | 34 ++++++- vllm/utils.py | 18 ++++ vllm/v1/engine/core.py | 3 + vllm/v1/engine/core_client.py | 14 +++ vllm/v1/engine/llm_engine.py | 26 +++++- vllm/v1/executor/multiproc_executor.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- vllm/v1/worker/gpu_worker.py | 3 + vllm/worker/worker_base.py | 5 + 17 files changed, 416 insertions(+), 28 deletions(-) create mode 100644 examples/offline_inference/data_parallel.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 66efe3ed3298..d96f0183bc67 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -134,7 +134,9 @@ steps: - tests/compile/test_basic_correctness - examples/offline_inference/rlhf.py - examples/offline_inference/rlhf_colocate.py + - tests/examples/offline_inference/data_parallel.py commands: + - VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py new file mode 100644 index 000000000000..a9544c8cf8a8 --- /dev/null +++ b/examples/offline_inference/data_parallel.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py +# we need to have a launcher to create multiple data parallel +# ranks. And each rank will create a vLLM instance to process its own prompts. +import os + +from vllm import LLM, SamplingParams +from vllm.utils import get_open_port + + +def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank): + os.environ["VLLM_DP_RANK"] = str(dp_rank) + os.environ["VLLM_DP_SIZE"] = str(dp_size) + os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip + os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) + # set devices for each dp_rank + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + str(i) for i in range(dp_rank * GPUs_per_dp_rank, (dp_rank + 1) * + GPUs_per_dp_rank)) + + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # with DP, each rank should process different prompts. + # usually all the DP ranks process a full dataset, + # and each rank processes a different part of the dataset. + promts_per_rank = len(prompts) // dp_size + start = dp_rank * promts_per_rank + end = start + promts_per_rank + prompts = prompts[start:end] + if len(prompts) == 0: + # if any rank has no prompts to process, + # we need to set a placeholder prompt + prompts = ["Placeholder"] + print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts") + + # Create a sampling params object. + # since we are doing data parallel, every rank can have different + # sampling params. here we set different max_tokens for different + # ranks for demonstration. + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=16 * (dp_rank + 1)) + + # Create an LLM. + llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2, enforce_eager=True) + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print( + f"DP rank {dp_rank}, Prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") + + +if __name__ == "__main__": + from multiprocessing import Process + dp_size = 2 + GPUs_per_dp_rank = 2 + dp_master_ip = "127.0.0.1" + dp_master_port = get_open_port() + procs = [] + for i in range(dp_size): + proc = Process(target=main, + args=(dp_size, i, dp_master_ip, dp_master_port, + GPUs_per_dp_rank)) + proc.start() + procs.append(proc) + for proc in procs: + proc.join() diff --git a/vllm/config.py b/vllm/config.py index 797697aac12d..ed32a5028790 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -16,6 +16,7 @@ import torch from pydantic import BaseModel, Field, PrivateAttr +from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig import vllm.envs as envs @@ -1296,6 +1297,11 @@ class ParallelConfig: pipeline_parallel_size: int = 1 # Number of pipeline parallel groups. tensor_parallel_size: int = 1 # Number of tensor parallel groups. + data_parallel_size: int = 1 # Number of data parallel groups. + data_parallel_rank: int = 0 # Rank of the data parallel group. + # IP of the data parallel master. + data_parallel_master_ip: str = "127.0.0.1" + data_parallel_master_port: int = 29500 # Port of the data parallel master. # Maximum number of multiple batches # when load model sequentially. To avoid RAM OOM when using tensor @@ -1329,10 +1335,55 @@ class ParallelConfig: worker_cls: str = "auto" sd_worker_cls: str = "auto" + # world_size is TPxPP, it affects the number of workers we create. world_size: int = field(init=False) + # world_size_across_dp is TPxPPxDP, it is the size of the world + # including data parallelism. + world_size_across_dp: int = field(init=False) rank: int = 0 + def get_next_dp_init_port(self) -> int: + """ + We might need to initialize process groups in multiple + processes that is related to data parallelism, + e.g. both in the worker and in the engine, which + can live in different processes. To avoid port conflicts, we + increment the port number each time we need to initialize a + new process group related to data parallelism. + """ + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + return answer + + def stateless_init_dp_group(self) -> "ProcessGroup": + from vllm.distributed.utils import ( + stateless_init_torch_distributed_process_group) + + # use gloo since the engine process might not have cuda device + dp_group = stateless_init_torch_distributed_process_group( + self.data_parallel_master_ip, + self.get_next_dp_init_port(), + self.data_parallel_rank, + self.data_parallel_size, + backend="gloo") + + return dp_group + + @staticmethod + def has_unfinished_dp(dp_group: "ProcessGroup", + has_unfinished: bool) -> bool: + tensor = torch.tensor([has_unfinished], + dtype=torch.int32, + device="cpu") + # dp rank 0: has_unfinished_seqs=True + # dp rank 1: has_unfinished_seqs=False + # aggregated: has_unfinished_seqs=True + # so this is an OR operation, i.e. MAX in integers + torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group) + aggregated_has_unfinished = bool(tensor.item()) + return aggregated_has_unfinished + def compute_hash(self): """ Provide a hash that uniquely identifies all the configs @@ -1350,6 +1401,12 @@ def __post_init__(self) -> None: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size + self.data_parallel_size = envs.VLLM_DP_SIZE + self.data_parallel_rank = envs.VLLM_DP_RANK + self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP + self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT + self.world_size_across_dp = self.world_size * self.data_parallel_size + ray_only_devices = ["tpu"] from vllm.platforms import current_platform if (current_platform.device_type in ray_only_devices diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index f806f8b39ef9..07c9ff506092 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -16,8 +16,8 @@ def __init__(self, device_group: Optional[ProcessGroup] = None, unique_name: str = ""): super().__init__(cpu_group, device, device_group, unique_name) - if "pp" in unique_name: - # pipeline parallel does not need custom allreduce + if "tp" not in unique_name: + # only tp uses custom allreduce use_custom_allreduce = False else: from vllm.distributed.parallel_state import ( diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index a2614ed5d0bd..90f7f2d0f982 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -87,6 +87,7 @@ def __init__(self, return rank = dist.get_rank(group=self.group) + self.rank = rank world_size = dist.get_world_size(group=self.group) if world_size == 1: # No need to initialize custom allreduce for single GPU case. @@ -201,8 +202,10 @@ def create_shared_buffer( @staticmethod def free_shared_buffer(pointers: List[int], - group: Optional[ProcessGroup] = None) -> None: - rank = dist.get_rank(group=group) + group: Optional[ProcessGroup] = None, + rank: Optional[int] = None) -> None: + if rank is None: + rank = dist.get_rank(group=group) lib = CudaRTLibrary() lib.cudaFree(ctypes.c_void_p(pointers[rank])) @@ -298,8 +301,8 @@ def close(self): if not self.disabled and self._ptr: ops.dispose(self._ptr) self._ptr = 0 - self.free_shared_buffer(self.meta_ptrs) - self.free_shared_buffer(self.buffer_ptrs) + self.free_shared_buffer(self.meta_ptrs, rank=self.rank) + self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) def __del__(self): self.close() diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 781f870a756c..83484cd73550 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -750,6 +750,13 @@ def get_tp_group() -> GroupCoordinator: _PP: Optional[GroupCoordinator] = None +_DP: Optional[GroupCoordinator] = None + + +def get_dp_group() -> GroupCoordinator: + assert _DP is not None, ("data parallel group is not initialized") + return _DP + def get_pp_group() -> GroupCoordinator: assert _PP is not None, ( @@ -811,6 +818,21 @@ def init_distributed_environment( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", world_size, rank, local_rank, distributed_init_method, backend) + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None and config.parallel_config.data_parallel_size > 1: + parallel_config = config.parallel_config + # adjust to take into account data parallelism + # offset the rank by the data parallel rank + rank = parallel_config.data_parallel_rank * world_size + rank + # adjust the world size to take into account data parallelism + world_size = parallel_config.world_size_across_dp + ip = parallel_config.data_parallel_master_ip + port = parallel_config.get_next_dp_init_port() + distributed_init_method = f"tcp://{ip}:{port}" # noqa + logger.info( + "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", + world_size, rank, distributed_init_method) if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " @@ -870,20 +892,28 @@ def initialize_model_parallel( # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() backend = backend or torch.distributed.get_backend( get_world_group().device_group) + data_parallel_size = 1 + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None: + data_parallel_size = config.parallel_config.data_parallel_size + + # the layout order is: DP x PP x TP + # to get group_ranks for each dimension, transpose that dimension to the + # last dimension, then reshape to 2D, then unbind the last dimension + all_ranks = torch.arange(world_size).reshape( + data_parallel_size, pipeline_model_parallel_size, + tensor_model_parallel_size) # noqa + # Build the tensor model-parallel groups. - num_tensor_model_parallel_groups: int = (world_size // - tensor_model_parallel_size) global _TP assert _TP is None, ("tensor model parallel group is already initialized") - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - ranks = list( - range(i * tensor_model_parallel_size, - (i + 1) * tensor_model_parallel_size)) - group_ranks.append(ranks) + group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, @@ -893,20 +923,33 @@ def initialize_model_parallel( group_name="tp") # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = (world_size // - pipeline_model_parallel_size) global _PP assert _PP is None, ( "pipeline model parallel group is already initialized") - group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group_ranks.append(ranks) + group_ranks = all_ranks.transpose(1, 2).reshape( + -1, pipeline_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="pp") + global _DP + assert _DP is None, ("data parallel group is already initialized") + group_ranks = all_ranks.transpose(0, + 2).reshape(-1, + data_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _DP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="dp") + + logger.info( + "rank %s in world size %s is assigned as " + "DP rank %s, PP rank %s, TP rank %s", rank, world_size, + _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) + def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: """ @@ -1011,6 +1054,11 @@ def destroy_model_parallel(): _PP.destroy() _PP = None + global _DP + if _DP: + _DP.destroy() + _DP = None + def destroy_distributed_environment(): global _WORLD diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 84f8c0a8e51c..79f9a84b476f 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -11,7 +11,11 @@ from typing import Any, Deque, Dict, Optional, Sequence, Tuple import torch -from torch.distributed import TCPStore +from torch.distributed import ProcessGroup, TCPStore +from torch.distributed.distributed_c10d import (Backend, PrefixStore, + _get_default_timeout, + is_nccl_available) +from torch.distributed.rendezvous import rendezvous import vllm.envs as envs from vllm.logger import init_logger @@ -227,3 +231,88 @@ def create( world_size=world_size, store=store, data_expiration_seconds=data_expiration_seconds) + + +def stateless_init_torch_distributed_process_group( + host: str, port: int, rank: int, world_size: int, + backend: str) -> ProcessGroup: + """ + A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. The created ProcessGroup object can be used for + some operations such as `allreduce`, because it does not depend on the + global rank. However, some operations such as `broadcast` cannot be used + because it depends on the global rank. + + # TODO: ask for help from PyTorch team if we need the `broadcast` operation. + + This function is useful when we are not sure about the total number of + processes in the process group. For example, we may have process + 1, 2, ..., 8 who want to communicate, and process 9 might be the same + process as process 1, or it might be a different process; process 10 + might be the same process as process 5, or it might be a different process. + In this case, how can we reliably form a communication channel within + process 9 and 10, without affecting the communication channel within + process 1, 2, ..., 8? + + One possible solution is to figure out if process 9 and 10 are the same + as process 1 and 5 beforehand, and then form a communication channel + based on the information, adjusting the ranks and world_size etc. However, + figuring out the information is not always easy, and it will interfere + with the main communication channel. + + Our solution is to always form a communication channel with process 1, 2, + ..., 8, and then use this function to form another communication channel + with process 9 and 10. This way, regardless of whether process 9 and 10 + are the same as process 1 and 5, the main communication channel is + always formed with process 1, 2, ..., 8, and the additional communication + channel is formed with process 9 and 10. + """ + init_method = f"tcp://{host}:{port}" + backend = Backend(backend) # it is basically string + timeout = _get_default_timeout(backend) + + store, rank, world_size = next( + rendezvous(init_method, rank, world_size, timeout=timeout)) + store.set_timeout(timeout) + + group_rank = rank + group_size = world_size + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + prefix_store = PrefixStore(init_method, store) + + pg_options = ProcessGroup.Options(backend=backend, timeout=timeout) + + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + pg_options, + ) + + if backend == "gloo": + from torch.distributed.distributed_c10d import ProcessGroupGloo + backend_class = ProcessGroupGloo(prefix_store, + group_rank, + group_size, + timeout=timeout) + backend_type = ProcessGroup.BackendType.GLOO + device = torch.device("cpu") + elif backend == "nccl": + assert is_nccl_available() + from torch.distributed.distributed_c10d import ProcessGroupNCCL + + backend_options = ProcessGroupNCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, + backend_options) + backend_type = ProcessGroup.BackendType.NCCL + device = torch.device("cuda") + + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + + return pg diff --git a/vllm/envs.py b/vllm/envs.py index 45547416314f..1eb9b9f1bbf5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -90,6 +90,10 @@ VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_CUDART_SO_PATH: Optional[str] = None VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True + VLLM_DP_RANK: int = 0 + VLLM_DP_SIZE: int = 1 + VLLM_DP_MASTER_IP: str = "" + VLLM_DP_MASTER_PORT: int = 0 def get_default_cache_root(): @@ -593,6 +597,22 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH": lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in ("1", "true"), + + # Rank of the process in the data parallel setting + "VLLM_DP_RANK": + lambda: int(os.getenv("VLLM_DP_RANK", "0")), + + # World size of the data parallel setting + "VLLM_DP_SIZE": + lambda: int(os.getenv("VLLM_DP_SIZE", "1")), + + # IP address of the master node in the data parallel setting + "VLLM_DP_MASTER_IP": + lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"), + + # Port of the master node in the data parallel setting + "VLLM_DP_MASTER_PORT": + lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), } # end-env-vars-definition diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 10de8bc593ab..b91816af1b6d 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,9 +4,10 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch +import torch.distributed as dist import vllm.envs as envs from vllm.config import VllmConfig @@ -32,6 +33,8 @@ class ForwardContext: attn_metadata: "AttentionMetadata" # set dynamically for each forward pass # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass + num_tokens_across_dp: Optional[ + List[int]] = None # set dynamically for each forward pass _forward_context: Optional[ForwardContext] = None @@ -48,7 +51,8 @@ def get_forward_context() -> ForwardContext: @contextmanager def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig, - virtual_engine: int = 0): + virtual_engine: int = 0, + num_tokens: int = 0): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -57,12 +61,36 @@ def set_forward_context(attn_metadata: Any, need_to_track_batchsize = track_batchsize and attn_metadata is not None if need_to_track_batchsize: forward_start_time = time.perf_counter() + num_tokens_across_dp = None + if vllm_config.parallel_config.data_parallel_size > 1: + dp_size = vllm_config.parallel_config.data_parallel_size + dp_rank = vllm_config.parallel_config.data_parallel_rank + if attn_metadata is not None: + if hasattr(attn_metadata, "num_prefill_tokens"): + # for v0 attention backends + batchsize = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + # for v1 attention backends + batchsize = attn_metadata.num_input_tokens + else: + batchsize = num_tokens + num_tokens_across_dp = [0] * dp_size + num_tokens_across_dp[dp_rank] = batchsize + num_tokens_tensor = torch.tensor(num_tokens_across_dp, + device="cpu", + dtype=torch.int32) + from vllm.distributed.parallel_state import get_dp_group + dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) + num_tokens_across_dp = num_tokens_tensor.tolist() + global _forward_context prev_context = _forward_context _forward_context = ForwardContext( attn_layers=vllm_config.compilation_config.static_forward_context, virtual_engine=virtual_engine, - attn_metadata=attn_metadata) + attn_metadata=attn_metadata, + num_tokens_across_dp=num_tokens_across_dp) try: yield finally: diff --git a/vllm/utils.py b/vllm/utils.py index 25a3bdc6daff..7d24154927b8 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -501,6 +501,24 @@ def get_open_zmq_ipc_path() -> str: def get_open_port() -> int: + """ + Get an open port for the vLLM process to listen on. + An edge case to handle, is when we run data parallel, + we need to avoid ports that are potentially used by + the data parallel master process. + Right now we reserve 10 ports for the data parallel master + process. Currently it uses 2 ports. + """ + if "VLLM_DP_MASTER_PORT" in os.environ: + dp_port = envs.VLLM_DP_MASTER_PORT + while True: + port = _get_open_port() + if port >= dp_port and port < dp_port + 10: + continue + return port + return _get_open_port() + +def _get_open_port() -> int: port = envs.VLLM_PORT if port is not None: while True: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 03825d6ea430..981d23237e2a 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -219,6 +219,9 @@ def sleep(self, level: int = 1): def wake_up(self): self.model_executor.wake_up() + def execute_dummy_batch(self): + self.model_executor.collective_rpc("execute_dummy_batch") + def add_lora(self, lora_request: LoRARequest) -> None: self.model_executor.add_lora(lora_request) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 43ba7583c662..e898a872c62b 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -87,6 +87,12 @@ def sleep(self, level: int = 1) -> None: def wake_up(self) -> None: raise NotImplementedError + def execute_dummy_batch(self) -> None: + raise NotImplementedError + + async def execute_dummy_batch_async(self) -> None: + raise NotImplementedError + def abort_requests(self, request_ids: List[str]) -> None: raise NotImplementedError @@ -156,6 +162,9 @@ def sleep(self, level: int = 1) -> None: def wake_up(self) -> None: self.engine_core.wake_up() + def execute_dummy_batch(self) -> None: + self.engine_core.execute_dummy_batch() + def add_lora(self, lora_request: LoRARequest) -> None: self.engine_core.add_lora(lora_request) @@ -331,6 +340,8 @@ def sleep(self, level: int = 1) -> None: def wake_up(self) -> None: self._call_utility("wake_up") + def execute_dummy_batch(self) -> None: + self._call_utility("execute_dummy_batch") class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" @@ -414,5 +425,8 @@ async def sleep_async(self, level: int = 1) -> None: async def wake_up_async(self) -> None: await self._call_utility_async("wake_up") + async def execute_dummy_batch_async(self) -> None: + await self._call_utility_async("execute_dummy_batch") + async def add_lora_async(self, lora_request: LoRARequest) -> None: await self._call_utility_async("add_lora", lora_request) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 6b7de4deed39..04c7ee109e0b 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -4,7 +4,7 @@ from typing_extensions import TypeVar -from vllm.config import VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING @@ -47,6 +47,13 @@ def __init__( self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config + # important: init dp group before init the engine_core + self.parallel_config = vllm_config.parallel_config + self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa + self.should_execute_dummy_batch = False + if self.dp_enabled: + self.dp_group = self.parallel_config.stateless_init_dp_group() + # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, @@ -106,7 +113,17 @@ def get_num_unfinished_requests(self) -> int: return self.output_processor.get_num_unfinished_requests() def has_unfinished_requests(self) -> bool: - return self.output_processor.has_unfinished_requests() + has_unfinished = self.output_processor.has_unfinished_requests() + if not self.dp_enabled: + return has_unfinished + return self.has_unfinished_requests_dp(has_unfinished) + + def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool: + aggregated_has_unfinished = ParallelConfig.has_unfinished_dp( + self.dp_group, has_unfinished) + if not has_unfinished and aggregated_has_unfinished: + self.should_execute_dummy_batch = True + return aggregated_has_unfinished @classmethod def validate_outputs(cls, outputs, output_type): @@ -145,6 +162,11 @@ def add_request( def step(self) -> List[RequestOutput]: + if self.should_execute_dummy_batch: + self.should_execute_dummy_batch = False + self.engine_core.execute_dummy_batch() + return [] + # 1) Get EngineCoreOutput from the EngineCore. outputs = self.engine_core.get_output() diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e3f07172d8cd..14492f273ed3 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -239,7 +239,7 @@ def __init__( ready_socket.send_string(WorkerProc.READY_STR) ready_socket.send(payload) - self.worker.init_device() + wrapper.init_device() self.worker.load_model() @staticmethod diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0d76b1a35c74..f002cbfccd40 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1167,7 +1167,7 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, self.vllm_config): + with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): hidden_states = model( input_ids=input_ids, positions=positions, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 10154a752393..ece0fa555342 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -235,6 +235,9 @@ def profile(self, is_start: bool = True): else: self.profiler.stop() + def execute_dummy_batch(self) -> None: + self.model_runner._dummy_run(1) + def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 190429074d56..44c26ed350a8 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -567,6 +567,11 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: self.worker = worker_class(**kwargs) assert self.worker is not None + def init_device(self): + with set_current_vllm_config(self.vllm_config): + # To make vLLM config available during device initialization + self.worker.init_device() # type: ignore + def execute_method(self, method: Union[str, bytes], *args, **kwargs): try: target = self if self.worker is None else self.worker From 4392f1419e19acf1e5610479bca1b34559716562 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 22 Feb 2025 20:28:59 +0800 Subject: [PATCH 172/317] [ci] fix linter (#13701) Signed-off-by: youkaichao --- examples/offline_inference/data_parallel.py | 9 +++++---- vllm/config.py | 2 +- vllm/utils.py | 1 + vllm/v1/engine/core_client.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 3 ++- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index a9544c8cf8a8..2e1fa50e2ab3 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -48,15 +48,16 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank): max_tokens=16 * (dp_rank + 1)) # Create an LLM. - llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2, enforce_eager=True) + llm = LLM(model="facebook/opt-125m", + tensor_parallel_size=2, + enforce_eager=True) outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print( - f"DP rank {dp_rank}, Prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") + print(f"DP rank {dp_rank}, Prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") if __name__ == "__main__": diff --git a/vllm/config.py b/vllm/config.py index ed32a5028790..d3139b5fd84e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1372,7 +1372,7 @@ def stateless_init_dp_group(self) -> "ProcessGroup": @staticmethod def has_unfinished_dp(dp_group: "ProcessGroup", - has_unfinished: bool) -> bool: + has_unfinished: bool) -> bool: tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu") diff --git a/vllm/utils.py b/vllm/utils.py index 7d24154927b8..675edc3620b5 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -518,6 +518,7 @@ def get_open_port() -> int: return port return _get_open_port() + def _get_open_port() -> int: port = envs.VLLM_PORT if port is not None: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index e898a872c62b..527aa72833ba 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -89,7 +89,7 @@ def wake_up(self) -> None: def execute_dummy_batch(self) -> None: raise NotImplementedError - + async def execute_dummy_batch_async(self) -> None: raise NotImplementedError @@ -343,6 +343,7 @@ def wake_up(self) -> None: def execute_dummy_batch(self) -> None: self._call_utility("execute_dummy_batch") + class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f002cbfccd40..a7b9d4781183 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1167,7 +1167,8 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): + with set_forward_context(None, self.vllm_config, + num_tokens=num_tokens): hidden_states = model( input_ids=input_ids, positions=positions, From f60573412fd92e41d158a17f2715e6403fae095d Mon Sep 17 00:00:00 2001 From: Keyun Tong Date: Sat, 22 Feb 2025 05:17:44 -0800 Subject: [PATCH 173/317] Support SSL Key Rotation in HTTP Server (#13495) --- requirements-common.txt | 3 +- tests/entrypoints/test_ssl_cert_refresher.py | 72 +++++++++++++++++++ vllm/entrypoints/api_server.py | 6 ++ vllm/entrypoints/launcher.py | 14 +++- vllm/entrypoints/openai/api_server.py | 1 + vllm/entrypoints/openai/cli_args.py | 5 ++ vllm/entrypoints/ssl.py | 74 ++++++++++++++++++++ 7 files changed, 173 insertions(+), 2 deletions(-) create mode 100644 tests/entrypoints/test_ssl_cert_refresher.py create mode 100644 vllm/entrypoints/ssl.py diff --git a/requirements-common.txt b/requirements-common.txt index f72aa40fccec..c0df136f500e 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,7 +20,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines == 0.1.11 -lark == 1.2.2 +lark == 1.2.2 xgrammar == 0.1.11; platform_machine == "x86_64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 @@ -37,3 +37,4 @@ einops # Required for Qwen2-VL. compressed-tensors == 0.9.2 # required for compressed-tensors depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py +watchfiles # required for http server to monitor the updates of TLS files diff --git a/tests/entrypoints/test_ssl_cert_refresher.py b/tests/entrypoints/test_ssl_cert_refresher.py new file mode 100644 index 000000000000..23ce7a679f3e --- /dev/null +++ b/tests/entrypoints/test_ssl_cert_refresher.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import tempfile +from pathlib import Path +from ssl import SSLContext + +import pytest + +from vllm.entrypoints.ssl import SSLCertRefresher + + +class MockSSLContext(SSLContext): + + def __init__(self): + self.load_cert_chain_count = 0 + self.load_ca_count = 0 + + def load_cert_chain( + self, + certfile, + keyfile=None, + password=None, + ): + self.load_cert_chain_count += 1 + + def load_verify_locations( + self, + cafile=None, + capath=None, + cadata=None, + ): + self.load_ca_count += 1 + + +def create_file() -> str: + with tempfile.NamedTemporaryFile(dir='/tmp', delete=False) as f: + return f.name + + +def touch_file(path: str) -> None: + Path(path).touch() + + +@pytest.mark.asyncio +async def test_ssl_refresher(): + ssl_context = MockSSLContext() + key_path = create_file() + cert_path = create_file() + ca_path = create_file() + ssl_refresher = SSLCertRefresher(ssl_context, key_path, cert_path, ca_path) + await asyncio.sleep(1) + assert ssl_context.load_cert_chain_count == 0 + assert ssl_context.load_ca_count == 0 + + touch_file(key_path) + await asyncio.sleep(1) + assert ssl_context.load_cert_chain_count == 1 + assert ssl_context.load_ca_count == 0 + + touch_file(cert_path) + touch_file(ca_path) + await asyncio.sleep(1) + assert ssl_context.load_cert_chain_count == 2 + assert ssl_context.load_ca_count == 1 + + ssl_refresher.stop() + + touch_file(cert_path) + touch_file(ca_path) + await asyncio.sleep(1) + assert ssl_context.load_cert_chain_count == 2 + assert ssl_context.load_ca_count == 1 diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 11ffc4f67cea..28b8c847c0fd 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -128,6 +128,7 @@ async def run_server(args: Namespace, shutdown_task = await serve_http( app, sock=None, + enable_ssl_refresh=args.enable_ssl_refresh, host=args.host, port=args.port, log_level=args.log_level, @@ -152,6 +153,11 @@ async def run_server(args: Namespace, type=str, default=None, help="The CA certificates file") + parser.add_argument( + "--enable-ssl-refresh", + action="store_true", + default=False, + help="Refresh SSL Context when SSL certificate files change") parser.add_argument( "--ssl-cert-reqs", type=int, diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 79946a498dad..b09ee526f14a 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -12,13 +12,16 @@ from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger from vllm.utils import find_process_using_port logger = init_logger(__name__) -async def serve_http(app: FastAPI, sock: Optional[socket.socket], +async def serve_http(app: FastAPI, + sock: Optional[socket.socket], + enable_ssl_refresh: bool = False, **uvicorn_kwargs: Any): logger.info("Available routes are:") for route in app.routes: @@ -31,6 +34,7 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket], logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) config = uvicorn.Config(app, **uvicorn_kwargs) + config.load() server = uvicorn.Server(config) _add_shutdown_handlers(app, server) @@ -39,9 +43,17 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket], server_task = loop.create_task( server.serve(sockets=[sock] if sock else None)) + ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher( + ssl_context=config.ssl, + key_path=config.ssl_keyfile, + cert_path=config.ssl_certfile, + ca_path=config.ssl_ca_certs) + def signal_handler() -> None: # prevents the uvicorn signal handler to exit early server_task.cancel() + if ssl_cert_refresher: + ssl_cert_refresher.stop() async def dummy_shutdown() -> None: pass diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d037a4e63484..73061995572b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -960,6 +960,7 @@ def _listen_addr(a: str) -> str: shutdown_task = await serve_http( app, sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, host=args.host, port=args.port, log_level=args.uvicorn_log_level, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 3054958f3c8a..ba953c219708 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -164,6 +164,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=nullable_str, default=None, help="The CA certificates file.") + parser.add_argument( + "--enable-ssl-refresh", + action="store_true", + default=False, + help="Refresh SSL Context when SSL certificate files change") parser.add_argument( "--ssl-cert-reqs", type=int, diff --git a/vllm/entrypoints/ssl.py b/vllm/entrypoints/ssl.py new file mode 100644 index 000000000000..dba916b8bf13 --- /dev/null +++ b/vllm/entrypoints/ssl.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from ssl import SSLContext +from typing import Callable, Optional + +from watchfiles import Change, awatch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SSLCertRefresher: + """A class that monitors SSL certificate files and + reloads them when they change. + """ + + def __init__(self, + ssl_context: SSLContext, + key_path: Optional[str] = None, + cert_path: Optional[str] = None, + ca_path: Optional[str] = None) -> None: + self.ssl = ssl_context + self.key_path = key_path + self.cert_path = cert_path + self.ca_path = ca_path + + # Setup certification chain watcher + def update_ssl_cert_chain(change: Change, file_path: str) -> None: + logger.info("Reloading SSL certificate chain") + assert self.key_path and self.cert_path + self.ssl.load_cert_chain(self.cert_path, self.key_path) + + self.watch_ssl_cert_task = None + if self.key_path and self.cert_path: + self.watch_ssl_cert_task = asyncio.create_task( + self._watch_files([self.key_path, self.cert_path], + update_ssl_cert_chain)) + + # Setup CA files watcher + def update_ssl_ca(change: Change, file_path: str) -> None: + logger.info("Reloading SSL CA certificates") + assert self.ca_path + self.ssl.load_verify_locations(self.ca_path) + + self.watch_ssl_ca_task = None + if self.ca_path: + self.watch_ssl_ca_task = asyncio.create_task( + self._watch_files([self.ca_path], update_ssl_ca)) + + async def _watch_files(self, paths, fun: Callable[[Change, str], + None]) -> None: + """Watch multiple file paths asynchronously.""" + logger.info("SSLCertRefresher monitors files: %s", paths) + async for changes in awatch(*paths): + try: + for change, file_path in changes: + logger.info("File change detected: %s - %s", change.name, + file_path) + fun(change, file_path) + except Exception as e: + logger.error( + "SSLCertRefresher failed taking action on file change. " + "Error: %s", e) + + def stop(self) -> None: + """Stop watching files.""" + if self.watch_ssl_cert_task: + self.watch_ssl_cert_task.cancel() + self.watch_ssl_cert_task = None + if self.watch_ssl_ca_task: + self.watch_ssl_ca_task.cancel() + self.watch_ssl_ca_task = None From b296490b7e54cfded7c57f5fe64608cf3b894997 Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Sat, 22 Feb 2025 05:24:05 -0800 Subject: [PATCH 174/317] [NVIDIA] Support nvfp4 cutlass gemm (#13571) --- CMakeLists.txt | 4 +- csrc/ops.h | 5 + .../quantization/fp4/nvfp4_scaled_mm_entry.cu | 37 +++ .../fp4/nvfp4_scaled_mm_kernels.cu | 280 ++++++++++++++++++ csrc/torch_bindings.cpp | 7 + tests/kernels/test_nvfp4_scaled_mm.py | 150 ++++++++++ vllm/_custom_ops.py | 12 + 7 files changed, 494 insertions(+), 1 deletion(-) create mode 100644 csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu create mode 100644 csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu create mode 100644 tests/kernels/test_nvfp4_scaled_mm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index cd1c2c9015da..4b569ec25f12 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -229,7 +229,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. # Please keep this in sync with FetchContent_Declare line below. - set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -267,6 +267,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp") @@ -383,6 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" ) set_gencode_flags_for_srcs( SRCS "${SRCS}" diff --git a/csrc/ops.h b/csrc/ops.h index 52ccf3b51f1e..13fbbe41286d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -152,6 +152,11 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row); #ifndef USE_ROCM +void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha); + bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu new file mode 100644 index 000000000000..a0852c5732ee --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha); +#endif + +void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel, vLLM should " + "be compiled using CUDA 12.8 and target " + "compute capability 100 or above."); +} diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu new file mode 100644 index 000000000000..26fd91217dbd --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -0,0 +1,280 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include "cutlass_extensions/common.hpp" + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +// Kernel Perf config +template +struct KernelTraits; + +template <> +struct KernelTraits { + using MmaTileShape = Shape<_128, _128, _256>; + using ClusterShape = Shape<_1, _1, _1>; + using PerSmTileShape_MNK = Shape<_128, _128, _256>; +}; + +template <> +struct KernelTraits { + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape<_4, _4, _1>; + using PerSmTileShape_MNK = Shape<_128, _256, _256>; +}; + +template <> +struct KernelTraits { + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape<_4, _4, _1>; + using PerSmTileShape_MNK = Shape<_128, _256, _256>; +}; + +template +struct Fp4GemmSm100 { + // A matrix configuration + using ElementA = cutlass::nv_float4_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 32; + + // B matrix configuration + using ElementB = cutlass::nv_float4_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 32; + + // C/D matrix configuration + using ElementD = T; + using ElementC = T; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + // Kernel functional config + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + // Kernel Perf config + using MmaTileShape = typename KernelTraits::MmaTileShape; + using ClusterShape = typename KernelTraits::ClusterShape; + using PerSmTileShape_MNK = typename KernelTraits::PerSmTileShape_MNK; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD, + LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB, + LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, void>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{})); +}; + +template +typename T::Gemm::Arguments args_from_options( + at::Tensor& D, at::Tensor const& A, at::Tensor const& B, + at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha, + int64_t M, int64_t N, int64_t K) { + using ElementA = typename T::Gemm::ElementA; + using ElementB = typename T::Gemm::ElementB; + using ElementSFA = cutlass::float_ue4m3_t; + using ElementSFB = cutlass::float_ue4m3_t; + using ElementD = typename T::Gemm::ElementD; + using ElementCompute = float; + using StrideA = typename T::StrideA; + using StrideB = typename T::StrideB; + using StrideD = typename T::StrideD; + using Sm100BlkScaledConfig = + typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + + int m = static_cast(M); + int n = static_cast(N); + int k = static_cast(K); + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); + + auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA( + cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB( + cute::make_shape(m, n, k, 1)); + + typename T::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), stride_A, + static_cast(B.data_ptr()), stride_B, + static_cast(A_sf.data_ptr()), layout_SFA, + static_cast(B_sf.data_ptr()), layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + static_cast(D.data_ptr()), + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + return arguments; +} + +template +void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, + at::Tensor const& A_sf, at::Tensor const& B_sf, + at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, + cudaStream_t stream) { + typename Fp4GemmSm100::Gemm gemm; + + auto arguments = + args_from_options>(D, A, B, A_sf, B_sf, alpha, m, n, k); + + size_t workspace_size = Fp4GemmSm100::Gemm::get_workspace_size(arguments); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + + CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); +} +#else +template +void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, + at::Tensor const& A_sf, at::Tensor const& B_sf, + at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, + cudaStream_t stream) { + TORCH_CHECK(false, "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " + "a CUTLASS 3.8 source directory to enable support."); +} +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +#define CHECK_TYPE(x, st, m) \ + TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + +constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; +constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; + +void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha) { + CHECK_INPUT(A, FLOAT4_E2M1X2, "a"); + CHECK_INPUT(B, FLOAT4_E2M1X2, "b"); + + CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); + CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); + + CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); + + TORCH_CHECK(A.dim() == 2, "a must be a matrix"); + TORCH_CHECK(B.dim() == 2, "b must be a matrix"); + TORCH_CHECK(A.sizes()[1] == B.sizes()[1], + "a and b shapes cannot be multiplied (", A.sizes()[0], "x", + A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")"); + + auto const m = A.sizes()[0]; + auto const n = B.sizes()[0]; + auto const k = A.sizes()[1] * 2; + + constexpr int alignment = 32; + TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment, + ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1], + "), k: ", k, "."); + TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment, + ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ")."); + + auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; + int rounded_m = round_up(m, 128); + int rounded_n = round_up(n, 128); + // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an + // integer. + int rounded_k = round_up(k / 16, 4); + + TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); + TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); + TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1], + "scale_a and scale_b shapes cannot be multiplied (", + A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0], + "x", B_sf.sizes()[1], ")"); + TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, + "scale_a must be padded and swizzled to a shape (", rounded_m, + "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x", + A_sf.sizes()[1], ")"); + TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, + "scale_b must be padded and swizzled to a shape (", rounded_n, + "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x", + B_sf.sizes()[1], ")"); + + auto out_dtype = D.dtype(); + at::cuda::CUDAGuard device_guard{(char)A.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); + + if (out_dtype == at::ScalarType::Half) { + runGemm(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (out_dtype == at::ScalarType::BFloat16) { + runGemm(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (out_dtype == at::ScalarType::Float) { + runGemm(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm"); + } +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d2aecba442b4..72de2035d0c1 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -302,6 +302,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "SymInt size_k) -> Tensor"); // conditionally compiled so impl registration is in source file + // CUTLASS nvfp4 block scaled GEMM + ops.def( + "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," + " Tensor block_scale_a, Tensor block_scale_b," + " Tensor alpha) -> ()"); + ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization, as well as bias ops.def( diff --git a/tests/kernels/test_nvfp4_scaled_mm.py b/tests/kernels/test_nvfp4_scaled_mm.py new file mode 100644 index 000000000000..b08026c5867d --- /dev/null +++ b/tests/kernels/test_nvfp4_scaled_mm.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +DTYPES = [torch.float16, torch.bfloat16] +# m, n, k +SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] +PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] +SHAPES.extend(PAD_SHAPES) + +SEEDS = [42] +CUDA_DEVICES = ['cuda:0'] + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +kE2M1ToFloatArray = [ + 0., + 0.5, + 1., + 1.5, + 2., + 3., + 4., + 6., +] + + +def e2m1_to_fp32(int4_value): + signBit = (int4_value & 0x8) + int4_absValue = int4_value & 0x7 + float_result = kE2M1ToFloatArray[int4_absValue] + if (signBit): + float_result = -float_result + return float_result + + +def break_fp4_bytes(a, dtype): + assert (a.dtype == torch.uint8) + m, n = a.shape + a = a.flatten() + # Get upper 4 bits + highHalfByte = (a & 0xF0) >> 4 + # Get lower 4 bits + lowHalfByte = a & 0x0F + fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device) + fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device) + # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC] + out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2) + return out + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + sf_m, sf_k = a_sf_swizzled.shape + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_to_dtype(tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out + + +def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale, + m, n, dtype, block_size, device): + _, m_k = a_fp4.shape + _, n_k = b_fp4.shape + assert (m_k == n_k) + a_in_dtype = dequantize_to_dtype(a_fp4, + a_sf, + a_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + b_in_dtype = dequantize_to_dtype(b_fp4, + b_sf, + b_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + return torch.matmul(a_in_dtype, b_in_dtype.t()) + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_nvfp4_gemm( + dtype: torch.dtype, + shape: tuple[int, int, int], + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + m, n, packed_k = shape + k = packed_k * 2 + block_size = 16 + a_dtype = torch.randn((m, k), dtype=dtype, device=device) + b_dtype = torch.randn((n, k), dtype=dtype, device=device) + + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32) + b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) + alpha = 1. / (a_global_scale * b_global_scale) + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) + b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) + + expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved, + b_scale_interleaved, a_global_scale, + b_global_scale, m, n, dtype, block_size, + device) + out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, + b_scale_interleaved, alpha, dtype) + + torch.testing.assert_close(out, + expected_out.to(dtype=dtype), + atol=1e-1, + rtol=1e-1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 2112af1201f3..3306610ad800 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -433,6 +433,18 @@ def _ggml_mul_mat_a8_fake( # cutlass +def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, alpha: torch.Tensor, + out_dtype: torch.dtype) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + m, n = a.shape[0], b.shape[0] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, + alpha) + return out + + def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) From 3ffae46bf96b6f9c996750dce170a90607dff592 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Sat, 22 Feb 2025 05:25:41 -0800 Subject: [PATCH 175/317] [V1][Kernel] Refactor the prefix_prefill kernel so that the caller no longer has to pass in the context lengths (#13095) --- tests/kernels/test_prefix_prefill.py | 8 ++------ vllm/attention/backends/rocm_flash_attn.py | 1 - vllm/attention/backends/xformers.py | 1 - vllm/attention/ops/paged_attn.py | 4 +--- vllm/attention/ops/prefix_prefill.py | 17 +++++++++-------- vllm/v1/attention/backends/rocm_attn.py | 12 ------------ 6 files changed, 12 insertions(+), 31 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 2184c98525fe..c3ac6a37e717 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -100,7 +100,7 @@ def test_contexted_kv_attention( BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN @@ -154,7 +154,6 @@ def test_contexted_kv_attention( block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, @@ -171,7 +170,6 @@ def test_contexted_kv_attention( block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, @@ -333,7 +331,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN @@ -387,7 +385,6 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, @@ -404,7 +401,6 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e1a8d3d33613..1b1f6ca9beed 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -753,7 +753,6 @@ def forward( prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, - prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window[0], diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 723a4558d0b3..ec8e1f2ee5a6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -580,7 +580,6 @@ def forward( prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, - prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 2c60bd0c38d6..fd703413db90 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -202,7 +202,6 @@ def forward_prefix( block_tables: torch.Tensor, query_start_loc: torch.Tensor, seq_lens_tensor: torch.Tensor, - context_lens: torch.Tensor, max_query_len: int, alibi_slopes: Optional[torch.Tensor], sliding_window: Optional[int], @@ -220,9 +219,8 @@ def forward_prefix( value_cache, block_tables, # query_start_loc is (batch_size + 1,) - query_start_loc[:-1], + query_start_loc, seq_lens_tensor, - context_lens, max_query_len, k_scale, v_scale, diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 362c46a95f32..103c408ebbf4 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -31,7 +31,6 @@ def _fwd_kernel( v_scale, B_Start_Loc, B_Seqlen, - B_Ctxlen, block_size, x, Out, @@ -72,10 +71,12 @@ def _fwd_kernel( cur_kv_head = cur_head // num_queries_per_kv - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len # start position inside of the query # generally, N goes over kv, while M goes over query_len @@ -466,7 +467,6 @@ def _fwd_kernel_alibi( v_scale, B_Start_Loc, B_Seqlen, - B_Ctxlen, Alibi_slopes, block_size, x, @@ -511,9 +511,12 @@ def _fwd_kernel_alibi( # cur_batch_seq_len: the length of prompts # cur_batch_ctx_len: the length of prefix # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len block_start_loc = BLOCK_M * start_m @@ -713,7 +716,6 @@ def context_attention_fwd(q, b_loc, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -765,6 +767,7 @@ def context_attention_fwd(q, batch, head = b_seq_len.shape[0], q.shape[1] num_queries_per_kv = q.shape[1] // k.shape[1] + assert batch + 1 == len(b_start_loc) grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, # 0 means "disable" @@ -784,7 +787,6 @@ def context_attention_fwd(q, v_scale, b_start_loc, b_seq_len, - b_ctx_len, alibi_slopes, v_cache.shape[3], k_cache.shape[4], @@ -838,7 +840,6 @@ def context_attention_fwd(q, v_scale, b_start_loc, b_seq_len, - b_ctx_len, v_cache.shape[3], k_cache.shape[4], o, diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 5f3eb37514d8..0f3fabf05fc2 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -150,17 +150,6 @@ def forward( layer._v_scale, ) - # TODO(sage): Refactor the context_attention_fwd kernel so that this - # overhead can be removed - context_lens = torch.empty_like(attn_metadata.seq_lens) - batch_size = len(attn_metadata.query_start_loc) - 1 - assert len(context_lens) == batch_size - for i in range(batch_size): - query_start = attn_metadata.query_start_loc[i] - query_end = attn_metadata.query_start_loc[i + 1] - context_lens[i] = attn_metadata.seq_lens[i] - (query_end - - query_start) - # Compute attention and update output up to `num_actual_tokens`. context_attention_fwd(q=query[:num_actual_tokens], k=key[:num_actual_tokens], @@ -172,7 +161,6 @@ def forward( b_loc=attn_metadata.block_table, b_start_loc=attn_metadata.query_start_loc, b_seq_len=attn_metadata.seq_lens, - b_ctx_len=context_lens, max_input_len=attn_metadata.max_query_len, k_scale=layer._k_scale, v_scale=layer._v_scale, From 762684495cdc84cfb48772d8725c09fda369dfec Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Sat, 22 Feb 2025 08:54:38 -0500 Subject: [PATCH 176/317] [ROCm] Apply FP8 weights padding to values not divisible by 512 bytes on ROCm (#13231) --- vllm/envs.py | 4 ++++ vllm/model_executor/layers/quantization/fp8.py | 15 +++++++++++++++ .../layers/quantization/utils/fp8_utils.py | 2 +- 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index 1eb9b9f1bbf5..1104f108784f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -74,6 +74,7 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False + VLLM_ROCM_FP8_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -507,6 +508,9 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), + # Pad the fp8 weights to 256 bytes for ROCm + "VLLM_ROCM_FP8_PADDING": + lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), # Divisor for dynamic key scale factor calculation for FP8 KV Cache "K_SCALE_CONSTANT": lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fe8ff7ca5e12..1ca39b0ffa82 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional import torch +import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter @@ -251,6 +252,17 @@ def create_weights( else: layer.register_parameter("input_scale", None) + def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor: + # Pad the weight tensor. This is an optimization on ROCm platform, which + # can benefit from tensors located far enough from one another in memory + if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0): + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + return weight + def process_weights_after_loading(self, layer: Module) -> None: # TODO(rob): refactor block quant into separate class. if self.block_quant: @@ -264,6 +276,8 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight.data weight_scale_inv = layer.weight_scale_inv.data + weight = self.add_padding_to_weight(weight) + # Torch.compile cannot use Parameter subclasses. layer.weight = Parameter(weight, requires_grad=False) layer.weight_scale_inv = Parameter(weight_scale_inv, @@ -327,6 +341,7 @@ def process_weights_after_loading(self, layer: Module) -> None: logical_widths=layer.logical_widths, ) + weight = self.add_padding_to_weight(weight) # Update layer with new values. layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 891edf23010c..61706f485f46 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -494,7 +494,7 @@ def w8a8_block_fp8_matmul( assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] M = A.numel() // A.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert B.ndim == 2 and Bs.ndim == 2 N, K = B.shape assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] From 886189b8f00064637776b5c6757e7b7cd9710263 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 22 Feb 2025 22:04:31 +0800 Subject: [PATCH 177/317] [Doc] Dockerfile instructions for optional dependencies and dev transformers (#13699) --- docs/source/deployment/docker.md | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/docs/source/deployment/docker.md b/docs/source/deployment/docker.md index 334c02225bd6..9e52a2182cfb 100644 --- a/docs/source/deployment/docker.md +++ b/docs/source/deployment/docker.md @@ -27,6 +27,36 @@ container to access the host's shared memory. vLLM uses PyTorch, which uses shar memory to share data between processes under the hood, particularly for tensor parallel inference. ::: +:::{note} +Optional dependencies are not included in order to avoid licensing issues (e.g. ). + +If you need to use those dependencies (having accepted the license terms), +create a custom Dockerfile on top of the base image with an extra layer that installs them: + +```Dockerfile +FROM vllm/vllm-openai:v0.7.3 + +# e.g. install the `audio` and `video` optional dependencies +# NOTE: Make sure the version of vLLM matches the base image! +RUN uv pip install --system vllm[audio,video]==0.7.3 +``` + +::: + +:::{tip} +Some new models may only be available on the main branch of [HF Transformers](https://github.com/huggingface/transformers). + +To use the development version of `transformers`, create a custom Dockerfile on top of the base image +with an extra layer that installs their code from source: + +```Dockerfile +FROM vllm/vllm-openai:latest + +RUN uv pip install --system git+https://github.com/huggingface/transformers.git +``` + +::: + (deployment-docker-build-image-from-source)= ## Building vLLM's Docker Image from Source From 994dad5f0caf8d38c766a2bdc02988bd1c6fd7b5 Mon Sep 17 00:00:00 2001 From: Helena Kloosterman Date: Sat, 22 Feb 2025 17:04:12 +0100 Subject: [PATCH 178/317] [Bugfix] Fix boolean conversion for OpenVINO env variable (#13615) --- vllm/envs.py | 5 +++-- vllm/model_executor/model_loader/openvino.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 1104f108784f..8be9ebb95dde 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -360,8 +360,9 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # Enables weights compression during model export via HF Optimum # default is False "VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS": - lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)), - + lambda: + (os.environ.get("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", "0").lower() in + ("on", "true", "1")), # If the env var is set, then all workers will execute as separate # processes from the engine, and we use the same mechanism to trigger # execution on all workers. diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py index fde200d576e2..805f0cfc585e 100644 --- a/vllm/model_executor/model_loader/openvino.py +++ b/vllm/model_executor/model_loader/openvino.py @@ -125,7 +125,8 @@ def __init__( "as-is, all possible options that may affect model conversion " "are ignored.") - load_in_8bit = envs.VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS + load_in_8bit = (envs.VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS + if export else False) pt_model = OVModelForCausalLM.from_pretrained( model_config.model, export=export, From 04cb85ae3986219c364f6871a840998fa2b6ef53 Mon Sep 17 00:00:00 2001 From: Yan Ma Date: Sun, 23 Feb 2025 00:05:35 +0800 Subject: [PATCH 179/317] [XPU]fix setuptools version for xpu (#13548) --- requirements-xpu.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-xpu.txt b/requirements-xpu.txt index 42c6c321d040..be5cb6a4a99b 100644 --- a/requirements-xpu.txt +++ b/requirements-xpu.txt @@ -6,6 +6,7 @@ cmake>=3.26 ninja packaging setuptools-scm>=8 +setuptools>=75.8.0 wheel jinja2 From 8f46c1361b2ea1f5d5bf04ea2de53a898de887af Mon Sep 17 00:00:00 2001 From: Daniele <36171005+dtrifiro@users.noreply.github.com> Date: Sat, 22 Feb 2025 17:25:20 +0100 Subject: [PATCH 180/317] [CI/Build] fix uv caching in Dockerfile (#13611) --- Dockerfile | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/Dockerfile b/Dockerfile index 310e003d427d..63314b906f15 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,7 +28,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version # Install uv for faster pip installs -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ python3 -m pip install uv # Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519 @@ -53,14 +53,14 @@ WORKDIR /workspace # we need to install torch and torchvision from the nightly builds first, # pytorch will not appear as a vLLM dependency in all of the following steps # after this step -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \ fi COPY requirements-common.txt requirements-common.txt COPY requirements-cuda.txt requirements-cuda.txt -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements-cuda.txt # cuda arch list used by torch @@ -81,7 +81,7 @@ ARG TARGETPLATFORM # install build dependencies COPY requirements-build.txt requirements-build.txt -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements-build.txt COPY . . @@ -101,7 +101,7 @@ ARG SCCACHE_BUCKET_NAME=vllm-build-sccache ARG SCCACHE_REGION_NAME=us-west-2 ARG SCCACHE_S3_NO_CREDENTIALS=0 # if USE_SCCACHE is set, use sccache to speed up compilation -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" = "1" ]; then \ echo "Installing sccache..." \ @@ -121,7 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ - --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" != "1" ]; then \ python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ @@ -146,7 +146,7 @@ FROM base as dev COPY requirements-lint.txt requirements-lint.txt COPY requirements-test.txt requirements-test.txt COPY requirements-dev.txt requirements-dev.txt -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements-dev.txt #################### DEV IMAGE #################### @@ -178,7 +178,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version # Install uv for faster pip installs -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ python3 -m pip install uv # Workaround for https://github.com/openai/triton/issues/2507 and @@ -191,14 +191,14 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ # we need to install torch and torchvision from the nightly builds first, # pytorch will not appear as a vLLM dependency in all of the following steps # after this step -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \ fi # Install vllm wheel first, so that torch etc will be installed. RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ - --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/uv \ uv pip install --system dist/*.whl --verbose # If we need to build FlashInfer wheel before its release: @@ -213,7 +213,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # $ ls dist # $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ . /etc/environment && \ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post1/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl ; \ @@ -225,7 +225,7 @@ COPY examples examples # install build dependencies for JIT compilation. # TODO: Remove this once FlashInfer AOT wheel is fixed COPY requirements-build.txt requirements-build.txt -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements-build.txt #################### vLLM installation IMAGE #################### @@ -238,15 +238,15 @@ FROM vllm-base AS test ADD . /vllm-workspace/ # install development dependencies (for testing) -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements-dev.txt # install development dependencies (for testing) -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -e tests/vllm_test_utils # enable fast downloads from hf (for testing) -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system hf_transfer ENV HF_HUB_ENABLE_HF_TRANSFER 1 @@ -266,7 +266,7 @@ RUN mv vllm test_docs/ FROM vllm-base AS vllm-openai-base # install additional dependencies for openai api server -RUN --mount=type=cache,target=/root/.cache/pip \ +RUN --mount=type=cache,target=/root/.cache/uv \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ else \ From e46908b50197893f2ddd8ca0de1d0a87c4d20944 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sat, 22 Feb 2025 16:50:38 -0800 Subject: [PATCH 181/317] [CI/Build] Fix pre-commit errors from #13571 (#13709) Signed-off-by: Roger Wang --- csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu | 7 ++++--- csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu index a0852c5732ee..7b57b32fdb08 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -31,7 +31,8 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, #if defined ENABLE_NVFP4 && ENABLE_NVFP4 return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel, vLLM should " - "be compiled using CUDA 12.8 and target " - "compute capability 100 or above."); + TORCH_CHECK_NOT_IMPLEMENTED(false, + "No compiled nvfp4 mm kernel, vLLM should " + "be compiled using CUDA 12.8 and target " + "compute capability 100 or above."); } diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu index 26fd91217dbd..9b30e4fef356 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -194,8 +194,9 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, cudaStream_t stream) { - TORCH_CHECK(false, "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " - "a CUTLASS 3.8 source directory to enable support."); + TORCH_CHECK(false, + "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " + "a CUTLASS 3.8 source directory to enable support."); } #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) From e5ad78f548e23ae81b8f2e2429eddcfd6b1e3349 Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Sun, 23 Feb 2025 00:51:13 +0000 Subject: [PATCH 182/317] [BugFix] Minor: logger import in attention backend (#13706) Signed-off-by: Andy Lo --- vllm/attention/backends/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 5c1f9916e22c..baf01c9263d4 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -12,12 +12,12 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) from vllm.attention.backends.abstract import AttentionType -from vllm.logger import logging +from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap from vllm.platforms import current_platform from vllm.utils import async_tensor_h2d, make_tensor_with_pad -logger = logging.getLogger(__name__) +logger = init_logger(__name__) if TYPE_CHECKING: from vllm.worker.model_runner_base import ModelRunnerBase From 5c7134541bf56546de995d6deb518f59c10d9976 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Sat, 22 Feb 2025 19:19:45 -0800 Subject: [PATCH 183/317] [ci] Use env var to control whether to use S3 bucket in CI (#13634) --- .buildkite/test-pipeline.yaml | 4 +- .../test_basic_correctness.py | 11 +- tests/basic_correctness/test_cumem.py | 9 +- tests/conftest.py | 73 +--------- tests/engine/test_computed_prefix_blocks.py | 7 +- tests/engine/test_detokenization.py | 8 +- tests/engine/test_executor.py | 21 +-- tests/engine/test_skip_tokenizer_init.py | 13 +- tests/entrypoints/llm/test_chat.py | 13 +- tests/entrypoints/llm/test_collective_rpc.py | 2 +- tests/entrypoints/llm/test_encode.py | 4 +- tests/entrypoints/llm/test_generate.py | 4 +- .../llm/test_generate_multiple_loras.py | 4 +- tests/entrypoints/llm/test_guided_generate.py | 7 +- tests/entrypoints/llm/test_lazy_outlines.py | 7 +- .../entrypoints/llm/test_prompt_validation.py | 9 +- tests/metrics/test_metrics.py | 55 ++++---- tests/models/test_initialization.py | 6 +- tests/mq_llm_engine/test_abort.py | 4 +- tests/mq_llm_engine/test_error_handling.py | 6 +- tests/mq_llm_engine/test_load.py | 6 +- tests/multimodal/test_processing.py | 6 +- tests/prefix_caching/test_prefix_caching.py | 2 +- tests/test_config.py | 14 +- tests/test_regression.py | 13 +- tests/worker/test_swap.py | 2 +- vllm/engine/arg_utils.py | 9 ++ vllm/envs.py | 4 + vllm/model_executor/model_loader/loader.py | 1 - vllm/test_utils.py | 129 ++++++++++++++++++ 30 files changed, 222 insertions(+), 231 deletions(-) create mode 100644 vllm/test_utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d96f0183bc67..931057e6c197 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -278,7 +278,7 @@ steps: command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py parallelism: 4 -- label: "PyTorch Fullgraph Smoke Test" # 9min +- label: PyTorch Fullgraph Smoke Test # 9min fast_check: true source_file_dependencies: - vllm/ @@ -289,7 +289,7 @@ steps: - pytest -v -s compile/piecewise/test_simple.py - pytest -v -s compile/piecewise/test_toy_llama.py -- label: "PyTorch Fullgraph Test" # 18min +- label: PyTorch Fullgraph Test # 18min source_file_dependencies: - vllm/ - tests/compile diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index cc25c8792aa9..d2fc0916bc55 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -9,7 +9,6 @@ import pytest from vllm import LLM -from vllm.config import LoadFormat from vllm.platforms import current_platform from ..conftest import VllmRunner @@ -34,7 +33,7 @@ def v1(run_with_both_engines): def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" - llm = LLM("distilbert/distilgpt2", load_format=LoadFormat.RUNAI_STREAMER) + llm = LLM("distilbert/distilgpt2") weak_llm = weakref.ref(llm) del llm # If there's any circular reference to vllm, this fails @@ -43,10 +42,10 @@ def test_vllm_gc_ed(): @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("enforce_eager", [False]) def test_models( hf_runner, model: str, @@ -97,8 +96,8 @@ def test_models( "test_suite", [ ("distilbert/distilgpt2", "ray", "", "L4"), ("distilbert/distilgpt2", "mp", "", "L4"), - ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), - ("meta-llama/Llama-2-7b-hf", "mp", "", "L4"), + ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4"), + ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4"), ("distilbert/distilgpt2", "ray", "", "A100"), ("distilbert/distilgpt2", "mp", "", "A100"), ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"), diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index f1148fc8e3f4..61c79a7bbc90 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -4,11 +4,9 @@ import torch from vllm import LLM, SamplingParams -from vllm.config import LoadFormat from vllm.device_allocator.cumem import CuMemAllocator from vllm.utils import GiB_bytes -from ..conftest import MODEL_WEIGHTS_S3_BUCKET from ..utils import fork_new_process_for_each_test @@ -121,7 +119,7 @@ def model(x): "model, use_v1", [ # sleep mode with safetensors - (f"{MODEL_WEIGHTS_S3_BUCKET}/meta-llama/Llama-3.2-1B", True), + ("meta-llama/Llama-3.2-1B", True), # sleep mode with pytorch checkpoint ("facebook/opt-125m", False), ]) @@ -130,10 +128,7 @@ def test_end_to_end(model: str, use_v1: bool): os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0" free, total = torch.cuda.mem_get_info() used_bytes_baseline = total - free # in case other process is running - load_format = LoadFormat.AUTO - if "Llama" in model: - load_format = LoadFormat.RUNAI_STREAMER - llm = LLM(model, load_format=load_format, enable_sleep_mode=True) + llm = LLM(model, enable_sleep_mode=True) prompt = "How are you?" sampling_params = SamplingParams(temperature=0, max_tokens=10) output = llm.generate(prompt, sampling_params) diff --git a/tests/conftest.py b/tests/conftest.py index 9304b8f17dca..dd339030e5e4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import LoadFormat, TaskOption, TokenizerPoolConfig +from vllm.config import TaskOption, TokenizerPoolConfig from vllm.connections import global_http_connection from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, @@ -47,70 +47,6 @@ _M = TypeVar("_M") -MODELS_ON_S3 = [ - "distilbert/distilgpt2", - "meta-llama/Llama-2-7b-hf", - "meta-llama/Meta-Llama-3-8B", - "meta-llama/Llama-3.2-1B", - "meta-llama/Llama-3.2-1B-Instruct", - "openai-community/gpt2", - "ArthurZ/Ilama-3.2-1B", - "llava-hf/llava-1.5-7b-hf", - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "ai21labs/Jamba-tiny-random", - "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", - "nm-testing/Phi-3-mini-128k-instruct-FP8", - "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV", - "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", - "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", - "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue", - "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse", - "AMead10/Llama-3.2-1B-Instruct-AWQ", - "shuyuej/Llama-3.2-1B-Instruct-GPTQ", - "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", - "ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", - "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", - "neuralmagic/Meta-Llama-3-8B-Instruct-FP8", - "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test", - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", - "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", - "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", - "neuralmagic/Llama-3.2-1B-quantized.w8a8", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", - "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", - "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", - "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", - "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym", - "nm-testing/tinyllama-oneshot-w4a16-channel-v2", - "nm-testing/tinyllama-oneshot-w4a16-group128-v2", - "nm-testing/tinyllama-oneshot-w8a16-per-channel", - "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", - "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test", - "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme", - "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", - "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing", - "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", - "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing", - "nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor", - "nm-testing/llama2.c-stories42M-pruned2.4-compressed", -] - -MODEL_WEIGHTS_S3_BUCKET = "s3://vllm-ci-model-weights" - _PromptMultiModalInput = Union[List[_M], List[List[_M]]] PromptImageInput = _PromptMultiModalInput[Image.Image] @@ -742,14 +678,8 @@ def __init__( enable_chunked_prefill: bool = False, swap_space: int = 4, enforce_eager: Optional[bool] = False, - load_format: Optional[LoadFormat] = None, **kwargs, ) -> None: - if model_name in MODELS_ON_S3 and not load_format: - model_name = (f"{MODEL_WEIGHTS_S3_BUCKET}/{model_name}") - load_format = LoadFormat.RUNAI_STREAMER - if not load_format: - load_format = LoadFormat.AUTO self.model = LLM( model=model_name, task=task, @@ -764,7 +694,6 @@ def __init__( max_model_len=max_model_len, block_size=block_size, enable_chunked_prefill=enable_chunked_prefill, - load_format=load_format, **kwargs, ) diff --git a/tests/engine/test_computed_prefix_blocks.py b/tests/engine/test_computed_prefix_blocks.py index 51e7c8e7739d..049fa2c8b12b 100644 --- a/tests/engine/test_computed_prefix_blocks.py +++ b/tests/engine/test_computed_prefix_blocks.py @@ -2,16 +2,12 @@ import pytest -from vllm.config import LoadFormat from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.sampling_params import SamplingParams -from ..conftest import MODEL_WEIGHTS_S3_BUCKET - -@pytest.mark.parametrize("model", - [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) +@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) @pytest.mark.parametrize("block_size", [16]) def test_computed_prefix_blocks(model: str, block_size: int): # This test checks if we are able to run the engine to completion @@ -28,7 +24,6 @@ def test_computed_prefix_blocks(model: str, block_size: int): "decoration.") engine_args = EngineArgs(model=model, - load_format=LoadFormat.RUNAI_STREAMER, block_size=block_size, enable_prefix_caching=True) diff --git a/tests/engine/test_detokenization.py b/tests/engine/test_detokenization.py index 6ae4be2e4786..2b7ebf705bbd 100644 --- a/tests/engine/test_detokenization.py +++ b/tests/engine/test_detokenization.py @@ -2,15 +2,11 @@ import pytest -from vllm.config import LoadFormat from vllm.entrypoints.llm import LLM from vllm.sampling_params import SamplingParams -from ..conftest import MODEL_WEIGHTS_S3_BUCKET - -@pytest.mark.parametrize("model", - [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) +@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) def test_computed_prefix_blocks(model: str): # This test checks if the engine generates completions both with and # without optional detokenization, that detokenization includes text @@ -21,7 +17,7 @@ def test_computed_prefix_blocks(model: str): "paper clips? Is there an easy to follow video tutorial available " "online for free?") - llm = LLM(model=model, load_format=LoadFormat.RUNAI_STREAMER) + llm = LLM(model=model) sampling_params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False) diff --git a/tests/engine/test_executor.py b/tests/engine/test_executor.py index 6a86401ce5db..c0a339e46ec4 100644 --- a/tests/engine/test_executor.py +++ b/tests/engine/test_executor.py @@ -6,17 +6,12 @@ import pytest -from vllm.config import LoadFormat from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine from vllm.executor.uniproc_executor import UniProcExecutor from vllm.sampling_params import SamplingParams -from ..conftest import MODEL_WEIGHTS_S3_BUCKET - -RUNAI_STREAMER_LOAD_FORMAT = LoadFormat.RUNAI_STREAMER - class Mock: ... @@ -38,12 +33,10 @@ def collective_rpc(self, CustomUniExecutorAsync = CustomUniExecutor -@pytest.mark.parametrize("model", - [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) +@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) def test_custom_executor_type_checking(model): with pytest.raises(ValueError): engine_args = EngineArgs(model=model, - load_format=RUNAI_STREAMER_LOAD_FORMAT, distributed_executor_backend=Mock) LLMEngine.from_engine_args(engine_args) with pytest.raises(ValueError): @@ -52,8 +45,7 @@ def test_custom_executor_type_checking(model): AsyncLLMEngine.from_engine_args(engine_args) -@pytest.mark.parametrize("model", - [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) +@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) def test_custom_executor(model, tmp_path): cwd = os.path.abspath(".") os.chdir(tmp_path) @@ -62,7 +54,6 @@ def test_custom_executor(model, tmp_path): engine_args = EngineArgs( model=model, - load_format=RUNAI_STREAMER_LOAD_FORMAT, distributed_executor_backend=CustomUniExecutor, enforce_eager=True, # reduce test time ) @@ -77,8 +68,7 @@ def test_custom_executor(model, tmp_path): os.chdir(cwd) -@pytest.mark.parametrize("model", - [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) +@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) def test_custom_executor_async(model, tmp_path): cwd = os.path.abspath(".") os.chdir(tmp_path) @@ -87,7 +77,6 @@ def test_custom_executor_async(model, tmp_path): engine_args = AsyncEngineArgs( model=model, - load_format=RUNAI_STREAMER_LOAD_FORMAT, distributed_executor_backend=CustomUniExecutorAsync, enforce_eager=True, # reduce test time ) @@ -106,8 +95,7 @@ async def t(): os.chdir(cwd) -@pytest.mark.parametrize("model", - [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) +@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) def test_respect_ray(model): # even for TP=1 and PP=1, # if users specify ray, we should use ray. @@ -116,7 +104,6 @@ def test_respect_ray(model): engine_args = EngineArgs( model=model, distributed_executor_backend="ray", - load_format=RUNAI_STREAMER_LOAD_FORMAT, enforce_eager=True, # reduce test time ) engine = LLMEngine.from_engine_args(engine_args) diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index b0930eaac17b..5e197f5ffe59 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -2,22 +2,19 @@ import pytest -from vllm.config import LoadFormat from vllm.entrypoints.llm import LLM from vllm.sampling_params import SamplingParams -from ..conftest import MODEL_WEIGHTS_S3_BUCKET - -@pytest.mark.parametrize("model", - [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"]) +@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) def test_skip_tokenizer_initialization(model: str): # This test checks if the flag skip_tokenizer_init skips the initialization # of tokenizer and detokenizer. The generated output is expected to contain # token ids. - llm = LLM(model=model, - skip_tokenizer_init=True, - load_format=LoadFormat.RUNAI_STREAMER) + llm = LLM( + model=model, + skip_tokenizer_init=True, + ) sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) with pytest.raises(ValueError, match="cannot pass text prompts when"): diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index f6fda5120d9e..77c80b2f8944 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -5,17 +5,12 @@ import pytest from vllm import LLM -from vllm.config import LoadFormat -from ...conftest import MODEL_WEIGHTS_S3_BUCKET from ..openai.test_vision import TEST_IMAGE_URLS -RUNAI_STREAMER_LOAD_FORMAT = LoadFormat.RUNAI_STREAMER - def test_chat(): - llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B-Instruct", - load_format=RUNAI_STREAMER_LOAD_FORMAT) + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct") prompt1 = "Explain the concept of entropy." messages = [ @@ -33,8 +28,7 @@ def test_chat(): def test_multi_chat(): - llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B-Instruct", - load_format=RUNAI_STREAMER_LOAD_FORMAT) + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct") prompt1 = "Explain the concept of entropy." prompt2 = "Explain what among us is." @@ -71,8 +65,7 @@ def test_multi_chat(): [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) def test_chat_multi_image(image_urls: List[str]): llm = LLM( - model=f"{MODEL_WEIGHTS_S3_BUCKET}/Phi-3.5-vision-instruct", - load_format=RUNAI_STREAMER_LOAD_FORMAT, + model="microsoft/Phi-3.5-vision-instruct", dtype="bfloat16", max_model_len=4096, max_num_seqs=5, diff --git a/tests/entrypoints/llm/test_collective_rpc.py b/tests/entrypoints/llm/test_collective_rpc.py index 69c60bbe6e8a..39d4810de9e7 100644 --- a/tests/entrypoints/llm/test_collective_rpc.py +++ b/tests/entrypoints/llm/test_collective_rpc.py @@ -28,7 +28,7 @@ class MyWorker(Worker): def echo_rank(self): return self.rank - llm = LLM(model="s3://vllm-ci-model-weights/Llama-3.2-1B-Instruct", + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, load_format="dummy", tensor_parallel_size=tp_size, diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index 61085bf43d1b..ebec8baba38d 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -6,10 +6,9 @@ import pytest from vllm import LLM, PoolingParams, PoolingRequestOutput -from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory -MODEL_NAME = "s3://vllm-ci-model-weights/e5-mistral-7b-instruct" +MODEL_NAME = "intfloat/e5-mistral-7b-instruct" PROMPTS = [ "Hello, my name is", @@ -33,7 +32,6 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection llm = LLM(model=MODEL_NAME, - load_format=LoadFormat.RUNAI_STREAMER, max_num_batched_tokens=32768, tensor_parallel_size=1, gpu_memory_utilization=0.75, diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index f1bad876be46..910e1a4507cc 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -6,10 +6,9 @@ import pytest from vllm import LLM, RequestOutput, SamplingParams -from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory -MODEL_NAME = "s3://vllm-ci-model-weights/distilgpt2" +MODEL_NAME = "distilbert/distilgpt2" PROMPTS = [ "Hello, my name is", @@ -31,7 +30,6 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection llm = LLM(model=MODEL_NAME, - load_format=LoadFormat.RUNAI_STREAMER, max_num_batched_tokens=4096, tensor_parallel_size=1, gpu_memory_utilization=0.10, diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py index 487c00460a63..90e1d5814137 100644 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ b/tests/entrypoints/llm/test_generate_multiple_loras.py @@ -7,11 +7,10 @@ from huggingface_hub import snapshot_download from vllm import LLM -from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory from vllm.lora.request import LoRARequest -MODEL_NAME = "s3://vllm-ci-model-weights/zephyr-7b-beta" +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" PROMPTS = [ "Hello, my name is", @@ -28,7 +27,6 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection llm = LLM(model=MODEL_NAME, - load_format=LoadFormat.RUNAI_STREAMER, tensor_parallel_size=1, max_model_len=8192, enable_lora=True, diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 252eb3fb334a..314dc59328cb 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -7,13 +7,12 @@ import jsonschema import pytest -from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.sampling_params import GuidedDecodingParams, SamplingParams -MODEL_NAME = "s3://vllm-ci-model-weights/Qwen2.5-1.5B-Instruct" +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] @@ -21,9 +20,7 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - load_format=LoadFormat.RUNAI_STREAMER, - max_model_len=1024) + llm = LLM(model=MODEL_NAME, max_model_len=1024) with llm.deprecate_legacy_api(): yield weakref.proxy(llm) diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py index 07608e15fe92..0598e3990d86 100644 --- a/tests/entrypoints/llm/test_lazy_outlines.py +++ b/tests/entrypoints/llm/test_lazy_outlines.py @@ -6,7 +6,6 @@ from vllm_test_utils import BlameResult, blame from vllm import LLM, SamplingParams -from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory @@ -44,8 +43,7 @@ def run_normal(): sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM without guided decoding as a baseline. - llm = LLM(model="s3://vllm-ci-model-weights/distilgpt2", - load_format=LoadFormat.RUNAI_STREAMER, + llm = LLM(model="distilbert/distilgpt2", enforce_eager=True, gpu_memory_utilization=0.3) outputs = llm.generate(prompts, sampling_params) @@ -61,8 +59,7 @@ def run_normal(): def run_lmfe(sample_regex): # Create an LLM with guided decoding enabled. - llm = LLM(model="s3://vllm-ci-model-weights/distilgpt2", - load_format=LoadFormat.RUNAI_STREAMER, + llm = LLM(model="distilbert/distilgpt2", enforce_eager=True, guided_decoding_backend="lm-format-enforcer", gpu_memory_utilization=0.3) diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py index 04848131dfc8..61bd1d462a50 100644 --- a/tests/entrypoints/llm/test_prompt_validation.py +++ b/tests/entrypoints/llm/test_prompt_validation.py @@ -3,7 +3,6 @@ import pytest from vllm import LLM -from vllm.config import LoadFormat @pytest.fixture(autouse=True) @@ -15,17 +14,13 @@ def v1(run_with_both_engines): def test_empty_prompt(): - llm = LLM(model="s3://vllm-ci-model-weights/gpt2", - load_format=LoadFormat.RUNAI_STREAMER, - enforce_eager=True) + llm = LLM(model="openai-community/gpt2", enforce_eager=True) with pytest.raises(ValueError, match='Prompt cannot be empty'): llm.generate([""]) @pytest.mark.skip_v1 def test_out_of_vocab_token(): - llm = LLM(model="s3://vllm-ci-model-weights/gpt2", - load_format=LoadFormat.RUNAI_STREAMER, - enforce_eager=True) + llm = LLM(model="openai-community/gpt2", enforce_eager=True) with pytest.raises(ValueError, match='out of vocabulary'): llm.generate({"prompt_token_ids": [999999]}) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 45a13488f07e..d6183379c394 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -8,21 +8,17 @@ from prometheus_client import REGISTRY from vllm import EngineArgs, LLMEngine -from vllm.config import LoadFormat from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.metrics import RayPrometheusStatLogger from vllm.sampling_params import SamplingParams - -from ..conftest import MODEL_WEIGHTS_S3_BUCKET +from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET MODELS = [ "distilbert/distilgpt2", ] -RUNAI_STREAMER_LOAD_FORMAT = LoadFormat.RUNAI_STREAMER - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @@ -146,9 +142,8 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, metrics_tag_content = stat_logger.labels["model_name"] if served_model_name is None or served_model_name == []: - actual_model_name = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}" - assert metrics_tag_content == actual_model_name, ( - f"Metrics tag model_name is wrong! expect: {actual_model_name!r}\n" + assert metrics_tag_content == f"{MODEL_WEIGHTS_S3_BUCKET}/{model}", ( + f"Metrics tag model_name is wrong! expect: {model!r}\n" f"actual: {metrics_tag_content!r}") else: assert metrics_tag_content == served_model_name[0], ( @@ -174,10 +169,11 @@ async def test_async_engine_log_metrics_regression( when disable_log_stats=False (see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678) """ - engine_args = AsyncEngineArgs(model=model, - dtype=dtype, - disable_log_stats=disable_log_stats, - load_format=RUNAI_STREAMER_LOAD_FORMAT) + engine_args = AsyncEngineArgs( + model=model, + dtype=dtype, + disable_log_stats=disable_log_stats, + ) async_engine = AsyncLLMEngine.from_engine_args(engine_args) for i, prompt in enumerate(example_prompts): results = async_engine.generate( @@ -189,7 +185,7 @@ async def test_async_engine_log_metrics_regression( async for _ in results: pass - assert_metrics(async_engine.engine, disable_log_stats, + assert_metrics(model, async_engine.engine, disable_log_stats, len(example_prompts)) @@ -204,10 +200,11 @@ def test_engine_log_metrics_regression( max_tokens: int, disable_log_stats: bool, ) -> None: - engine_args = EngineArgs(model=model, - dtype=dtype, - disable_log_stats=disable_log_stats, - load_format=RUNAI_STREAMER_LOAD_FORMAT) + engine_args = EngineArgs( + model=model, + dtype=dtype, + disable_log_stats=disable_log_stats, + ) engine = LLMEngine.from_engine_args(engine_args) for i, prompt in enumerate(example_prompts): engine.add_request( @@ -218,7 +215,8 @@ def test_engine_log_metrics_regression( while engine.has_unfinished_requests(): engine.step() - assert_metrics(engine, disable_log_stats, len(example_prompts)) + assert_metrics(f"{MODEL_WEIGHTS_S3_BUCKET}/{model}", engine, + disable_log_stats, len(example_prompts)) @pytest.mark.parametrize("model", MODELS) @@ -285,14 +283,15 @@ def test_metric_spec_decode_interval( ) -> None: k = 5 - engine_args = EngineArgs(model=model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4, - speculative_model=model, - num_speculative_tokens=k, - enforce_eager=True, - load_format=RUNAI_STREAMER_LOAD_FORMAT) + engine_args = EngineArgs( + model=model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.4, + speculative_model=model, + num_speculative_tokens=k, + enforce_eager=True, + ) engine = LLMEngine.from_engine_args(engine_args) @@ -359,7 +358,7 @@ def test_metric_spec_decode_interval( cleanup_dist_env_and_memory() -def assert_metrics(engine: LLMEngine, disable_log_stats: bool, +def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool, num_requests: int) -> None: if disable_log_stats: with pytest.raises(AttributeError): @@ -370,7 +369,7 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool, # Ensure the count bucket of request-level histogram metrics matches # the number of requests as a simple sanity check to ensure metrics are # generated - labels = {'model_name': engine.model_config.model} + labels = {'model_name': model} request_histogram_metrics = [ "vllm:e2e_request_latency_seconds", "vllm:request_prompt_tokens", diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index e0d5e0032275..c58c63723168 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -7,7 +7,6 @@ from vllm import LLM -from ..conftest import MODELS_ON_S3 from .registry import HF_EXAMPLE_MODELS @@ -43,11 +42,8 @@ def _initialize_kv_caches(self) -> None: with patch.object(LLM.get_engine_class(), "_initialize_kv_caches", _initialize_kv_caches): - model_name = model_info.default - if model_name in MODELS_ON_S3: - model_name = f"s3://vllm-ci-model-weights/{model_name.split('/')[-1]}" LLM( - model_name, + model_info.default, tokenizer=model_info.tokenizer, tokenizer_mode=model_info.tokenizer_mode, speculative_model=model_info.speculative_model, diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py index b0ac0fb327f4..808346b5e58d 100644 --- a/tests/mq_llm_engine/test_abort.py +++ b/tests/mq_llm_engine/test_abort.py @@ -10,8 +10,8 @@ from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate from vllm.engine.arg_utils import AsyncEngineArgs -MODEL = "s3://vllm-ci-model-weights/gemma-1.1-2b-it" -ENGINE_ARGS = AsyncEngineArgs(model=MODEL, load_format="runai_streamer") +MODEL = "google/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) RAISED_ERROR = KeyError RAISED_VALUE = "foo" EXPECTED_TOKENS = 250 diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 4eac73417ad7..35d001781110 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -21,10 +21,8 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser -MODEL = "s3://vllm-ci-model-weights/gemma-1.1-2b-it" -ENGINE_ARGS = AsyncEngineArgs(model=MODEL, - load_format="runai_streamer", - enforce_eager=True) +MODEL = "google/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True) RAISED_ERROR = KeyError RAISED_VALUE = "foo" diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py index 3162d56c6d4e..2069ff987f2f 100644 --- a/tests/mq_llm_engine/test_load.py +++ b/tests/mq_llm_engine/test_load.py @@ -10,14 +10,12 @@ from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate from vllm.engine.arg_utils import AsyncEngineArgs -MODEL = "s3://vllm-ci-model-weights/gemma-1.1-2b-it" +MODEL = "google/gemma-1.1-2b-it" NUM_EXPECTED_TOKENS = 10 NUM_REQUESTS = 10000 # Scenarios to test for num generated token. -ENGINE_ARGS = AsyncEngineArgs(model=MODEL, - load_format="runai_streamer", - disable_log_requests=True) +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True) @pytest.fixture(scope="function") diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index b247321ebb2f..c2fbe83abc83 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -553,8 +553,7 @@ def test_find_mm_placeholders( assert result == expected -@pytest.mark.parametrize( - "model_id", ["s3://vllm-ci-model-weights/llava-v1.6-mistral-7b-hf"]) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("limit", "num_supported", "is_valid"), [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), @@ -593,8 +592,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): profiler.get_dummy_data(model_config.max_model_len) -@pytest.mark.parametrize( - "model_id", ["s3://vllm-ci-model-weights/llava-v1.6-mistral-7b-hf"]) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("num_images", "limit", "is_valid"), [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 90d424fe35d8..2773d27a6813 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -16,7 +16,7 @@ from ..models.utils import check_outputs_equal MODELS = [ - "facebook/opt-125m", + "distilbert/distilgpt2", ] UNSTABLE_PROMPT_SEQUENCE = [ diff --git a/tests/test_config.py b/tests/test_config.py index bc87e6ccdfcc..8927a14d79ac 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,20 +8,14 @@ from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform -from .conftest import MODEL_WEIGHTS_S3_BUCKET - @pytest.mark.parametrize( ("model_id", "expected_runner_type", "expected_task"), [ - (f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2", "generate", - "generate"), - (f"{MODEL_WEIGHTS_S3_BUCKET}/intfloat/e5-mistral-7b-instruct", - "pooling", "embed"), - (f"{MODEL_WEIGHTS_S3_BUCKET}/jason9693/Qwen2.5-1.5B-apeach", "pooling", - "classify"), - (f"{MODEL_WEIGHTS_S3_BUCKET}/cross-encoder/ms-marco-MiniLM-L-6-v2", - "pooling", "score"), + ("distilbert/distilgpt2", "generate", "generate"), + ("intfloat/e5-mistral-7b-instruct", "pooling", "embed"), + ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), + ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), ("openai/whisper-small", "transcription", "transcription"), ], diff --git a/tests/test_regression.py b/tests/test_regression.py index 8cecc2892b6e..ce9498e8d7e8 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -10,9 +10,6 @@ import torch from vllm import LLM, SamplingParams -from vllm.config import LoadFormat - -from .conftest import MODEL_WEIGHTS_S3_BUCKET def test_duplicated_ignored_sequence_group(): @@ -21,8 +18,7 @@ def test_duplicated_ignored_sequence_group(): sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=256) - llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2", - load_format=LoadFormat.RUNAI_STREAMER, + llm = LLM(model="distilbert/distilgpt2", max_num_batched_tokens=4096, tensor_parallel_size=1) prompts = ["This is a short prompt", "This is a very long prompt " * 1000] @@ -35,8 +31,7 @@ def test_max_tokens_none(): sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) - llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2", - load_format=LoadFormat.RUNAI_STREAMER, + llm = LLM(model="distilbert/distilgpt2", max_num_batched_tokens=4096, tensor_parallel_size=1) prompts = ["Just say hello!"] @@ -46,9 +41,7 @@ def test_max_tokens_none(): def test_gc(): - llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2", - load_format=LoadFormat.RUNAI_STREAMER, - enforce_eager=True) + llm = LLM(model="distilbert/distilgpt2", enforce_eager=True) del llm gc.collect() diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 2c337cc9fed2..3ab8070999b0 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -10,7 +10,7 @@ def test_swap() -> None: # Configure the engine. - engine_args = EngineArgs(model="s3://vllm-ci-model-weights/distilgpt2", + engine_args = EngineArgs(model="distilbert/distilgpt2", dtype="half", load_format="dummy") engine_config = engine_args.create_engine_config() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d75e2324f5c7..bab7cfe2aa3a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -22,6 +22,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.plugins import load_general_plugins +from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, StoreBoolean @@ -1141,6 +1142,14 @@ def create_engine_config(self, f", but got {self.cpu_offload_gb}") device_config = DeviceConfig(device=self.device) + + # NOTE: This is to allow model loading from S3 in CI + if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 + and self.model in MODELS_ON_S3 + and self.load_format == LoadFormat.AUTO): # noqa: E501 + self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" + self.load_format = LoadFormat.RUNAI_STREAMER + model_config = self.create_model_config() if (model_config.is_multimodal_model and not envs.VLLM_USE_V1 diff --git a/vllm/envs.py b/vllm/envs.py index 8be9ebb95dde..dbf1d4623962 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -618,6 +618,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # Port of the master node in the data parallel setting "VLLM_DP_MASTER_PORT": lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), + + # Whether to use S3 path for model loading in CI via RunAI Streamer + "VLLM_CI_USE_S3": + lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", } # end-env-vars-definition diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index df957cfca3c0..8736cf1ca341 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1394,7 +1394,6 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: """Get a model loader based on the load format.""" - if isinstance(load_config.load_format, type): return load_config.load_format(load_config) diff --git a/vllm/test_utils.py b/vllm/test_utils.py new file mode 100644 index 000000000000..eb9a4d80a2c2 --- /dev/null +++ b/vllm/test_utils.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +MODELS_ON_S3 = [ + "adept/fuyu-8b", + "ai21labs/AI21-Jamba-1.5-Mini", + "ai21labs/Jamba-tiny-random", + "ai21labs/Jamba-tiny-reward-dev", + "allenai/Molmo-7B-D-0924", + "allenai/OLMo-1B-hf", + "allenai/OLMoE-1B-7B-0924-Instruct", + "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test", + "AMead10/Llama-3.2-1B-Instruct-AWQ", + "ArthurZ/Ilama-3.2-1B", + "BAAI/bge-base-en-v1.5", + "BAAI/bge-multilingual-gemma2", + "BAAI/bge-reranker-v2-m3", + "bigcode/starcoder2-3b", + "cross-encoder/ms-marco-MiniLM-L-6-v2", + "cross-encoder/quora-roberta-base", + "deepseek-ai/deepseek-vl2-tiny", + "distilbert/distilgpt2", + "facebook/bart-base", + "facebook/bart-large-cnn", + # "fixie-ai/ultravox-v0_5-llama-3_2-1b", + "google/gemma-1.1-2b-it", + "google/gemma-2-2b-it", + "google/paligemma-3b-pt-224", + "h2oai/h2ovl-mississippi-800m", + "HuggingFaceM4/Idefics3-8B-Llama3", + "internlm/internlm2-1_8b-reward", + "intfloat/e5-mistral-7b-instruct", + "intfloat/multilingual-e5-large", + "jason9693/Qwen2.5-1.5B-apeach", + "llava-hf/llava-1.5-7b-hf", + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + "llava-hf/llava-v1.6-mistral-7b-hf", + "llava-hf/LLaVA-NeXT-Video-7B-hf", + # "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-3.2-11B-Vision-Instruct", + "meta-llama/Llama-3.2-1B", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Meta-Llama-3-8B", + "microsoft/phi-2", + "microsoft/Phi-3-mini-4k-instruct", + "microsoft/Phi-3-small-8k-instruct", + "microsoft/Phi-3-vision-128k-instruct", + "microsoft/Phi-3.5-MoE-instruct", + "microsoft/Phi-3.5-vision-instruct", + # "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mixtral-8x7B-Instruct-v0.1", + "mistralai/Pixtral-12B-2409", + "mistral-community/Mixtral-8x22B-v0.1-AWQ", + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse", + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue", + "ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8", + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", + "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", + "nm-testing/llama2.c-stories42M-pruned2.4-compressed", + "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", + "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", + "nm-testing/Phi-3-mini-128k-instruct-FP8", + "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV", + "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", + "nm-testing/tinyllama-oneshot-w4a16-channel-v2", + "nm-testing/tinyllama-oneshot-w4a16-group128-v2", + "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", + "nm-testing/tinyllama-oneshot-w8a16-per-channel", + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym", + "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", + "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing", + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", + "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme", + "nvidia/NVLM-D-72B", + "openai-community/gpt2", + # "openai/whisper-large-v3", + "openbmb/MiniCPM-o-2_6", + "openbmb/MiniCPM-V-2_6", + "OpenGVLab/InternVL2-1B", + "parasail-ai/GritLM-7B-vllm", + "Qwen/Qwen1.5-MoE-A2.7B-Chat", + "Qwen/Qwen2-7B-Instruct", + "Qwen/Qwen2-Audio-7B-Instruct", + "Qwen/Qwen2-VL-2B-Instruct", + "Qwen/Qwen2.5-1.5B-Instruct", + "Qwen/Qwen2.5-Math-PRM-7B", + "Qwen/Qwen2.5-Math-RM-72B", + "Qwen/Qwen2.5-VL-3B-Instruct", + "royokong/e5-v", + "sentence-transformers/all-roberta-large-v1", + "sentence-transformers/stsb-roberta-base-v2", + "shanearora/OLMo-7B-1124-hf", + "shuyuej/Llama-3.2-1B-Instruct-GPTQ", + "ssmits/Qwen2-7B-Instruct-embed-base", + "stabilityai/stablelm-3b-4e1t", + "stabilityai/stablelm-zephyr-3b", + "state-spaces/mamba-130m-hf", + "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", + "THUDM/glm-4v-9b", + "TIGER-Lab/Mantis-8B-siglip-llama3", + "TIGER-Lab/VLM2Vec-Full", + "tiiuae/falcon-40b", + "tiiuae/falcon-mamba-7b-instruct", + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "upstage/solar-pro-preview-instruct", +] + +MODEL_WEIGHTS_S3_BUCKET = "s3://vllm-ci-model-weights" From e684adb9125cc77aa64eaddb23a843f31884b130 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 22 Feb 2025 22:21:15 -0500 Subject: [PATCH 184/317] [Quant] BaiChuan SupportsQuant (#13710) --- vllm/model_executor/models/baichuan.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index b613b70a7564..2e51b9c9c0c7 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -334,7 +334,8 @@ def forward( return hidden_states -class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, + SupportsQuant): packed_modules_mapping = { "W_pack": ["W_pack"], "gate_up_proj": [ From aad6f8ae39089337a0586b0cf03f135b3a42fde4 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sun, 23 Feb 2025 17:46:03 +0800 Subject: [PATCH 185/317] [LMM] Implement merged multimodal processor for whisper (#13278) --- .../multimodal/processing/test_common.py | 11 +- vllm/model_executor/models/whisper.py | 206 +++++++++++------- vllm/multimodal/processing.py | 5 +- vllm/multimodal/profiling.py | 11 +- 4 files changed, 150 insertions(+), 83 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 331ffe82ec85..0115863f5626 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -83,11 +83,11 @@ def _test_processing_correctness( } tokenizer_encode_kwargs = {} - if model_config.hf_config.model_type == "mllama": - # For Mllama, tokenizer will always add bos_token at the beginning of - # prompt by default, causing hf_processor outputs incorrect token ids. - # So we need use `add_special_tokens=False` here to leave bos_token - # to be added by the processor. + if model_config.hf_config.model_type in ("mllama", "whisper"): + # For some encoder-decoder models, tokenizer will always add bos_token + # at the beginning of prompt by default, causing hf_processor outputs + # incorrect token ids. So we need use `add_special_tokens=False` here + # to leave bos_token to be added by the processor. tokenizer_encode_kwargs = {"add_special_tokens": False} for batch_idx in range(num_batches): @@ -173,6 +173,7 @@ def _test_processing_correctness( "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", "fixie-ai/ultravox-v0_5-llama-3_2-1b", + "openai/whisper-large-v3", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 073a30d25e23..2ad1731144ef 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -4,15 +4,15 @@ from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, Union) -import numpy as np import torch from torch import nn +from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, + WhisperProcessor) from transformers.models.whisper.modeling_whisper import sinusoids from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -25,11 +25,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, - NestedTensors) -from vllm.multimodal.audio import resample_audio -from vllm.sequence import SequenceData -from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from .interfaces import SupportsMultiModal, SupportsTranscription from .utils import AutoWeightsLoader, WeightsMapper, make_layers @@ -571,72 +574,126 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -def get_max_whisper_audio_tokens(ctx: InputContext) -> int: - return ctx.model_config.hf_config.max_source_positions - - -def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - assert mm_counts["audio"] == 1 - num_tokens = get_max_whisper_audio_tokens(ctx) - processor = cached_processor_from_config(ctx.model_config) - chunk_length = processor.feature_extractor.chunk_length - sampling_rate = processor.feature_extractor.sampling_rate - num_samples = chunk_length * sampling_rate - return DummyData( - SequenceData.from_prompt_token_counts((0, num_tokens)), - {"audio": [(np.zeros(num_samples), sampling_rate)]}, - ) - - -def input_processor_for_whisper(ctx: InputContext, inputs): - multi_modal_data = inputs["encoder"]["multi_modal_data"] - if isinstance(multi_modal_data["audio"], list): - assert len(multi_modal_data["audio"]) == 1 - multi_modal_data["audio"] = multi_modal_data["audio"][0] - # Resample and process audio - audio, orig_sr = multi_modal_data["audio"] - processor = cached_processor_from_config(ctx.model_config) - target_sr = processor.feature_extractor.sampling_rate - audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) - multi_modal_data["audio"] = (audio, target_sr) - # Pre-allocate placeholder tokens in encoder sequence - num_tokens = get_max_whisper_audio_tokens(ctx) - inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens - return inputs - - -def input_mapper_for_whisper( - ctx: InputContext, - multi_modal_data: Union[np.ndarray, List[np.ndarray]], -) -> MultiModalKwargs: - if not isinstance(multi_modal_data, list): - multi_modal_data = [multi_modal_data] - - assert len(multi_modal_data) == 1 - - if len(multi_modal_data) == 0: - return MultiModalKwargs() - - processor = cached_processor_from_config(ctx.model_config) - sampling_rate = processor.feature_extractor.sampling_rate - - audios = [audio for audio, _ in multi_modal_data] - - kwargs = processor(audios, - sampling_rate=sampling_rate, - return_tensors="pt") - kwargs["input_features"] = kwargs["input_features"].squeeze(0).to( - ctx.model_config.dtype) - - return MultiModalKwargs(kwargs) - - -@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper) -@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper) -@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_max_whisper_audio_tokens) +class WhisperProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> WhisperConfig: + return self.ctx.get_hf_config(WhisperConfig) + + def get_hf_processor(self, + sampling_rate: Optional[int] = None + ) -> WhisperProcessor: + return self.ctx.get_hf_processor(WhisperProcessor) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": 1} + + def get_feature_extractor(self) -> WhisperFeatureExtractor: + hf_processor = self.get_hf_processor() + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_max_audio_tokens(self) -> int: + return self.get_hf_config().max_source_positions + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"audio": self.get_max_audio_tokens()} + + +class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) + + mm_data = { + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + return ProcessorInputs( + prompt_text="<|startoftranscript|>" * num_audios, + mm_data=mm_data, + ) + + +class WhisperMultiModalProcessor( + EncDecMultiModalProcessor[WhisperProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + # Strictly speaking, whisper encoder only accept audio features. + # We create a dummy encoder prompt here which will be padded to + # num_audio_tokens. So that we can create dummy data from this + # for encoder profiling. + return [0] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_data = dict(audio=mm_data.pop("audios")) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + if "labels" in processed_outputs: + processed_outputs["input_ids"] = processed_outputs.pop("labels") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(input_features=MultiModalFieldConfig.batched("audio")) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + num_tokens = self.info.get_max_audio_tokens() + return [ + PromptReplacement( + modality="audio", + target=[0], + replacement=[0] * num_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor, + info=WhisperProcessingInfo, + dummy_inputs=WhisperDummyInputsBuilder) class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, SupportsMultiModal): packed_modules_mapping = { @@ -724,7 +781,8 @@ def _parse_and_validate_audio_input( if not isinstance(input_features, (torch.Tensor, list)): raise ValueError("Incorrect type of audio features. " f"Got type: {type(input_features)}") - input_features = [feat.to(self.dtype) for feat in input_features] + input_features = torch.cat( + [feat.to(self.dtype) for feat in input_features]) return WhisperAudioInputs(input_features=input_features) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index fcd02fbd5203..93756364dea1 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1297,7 +1297,10 @@ def create_encoder_prompt( prompt: Union[str, list[int]], mm_data: MultiModalDataDict, ) -> Union[str, list[int]]: - """Create input prompt for the encoder.""" + """ + Create input prompt for the encoder. HF processor will be applied on + this prompt during profiling and generation. + """ raise NotImplementedError def apply( diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 81c92b38f8e9..802e40a0c952 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -166,8 +166,12 @@ def get_dummy_data( f"({set(mm_max_tokens_per_item.keys())})") mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) - prompt_token_ids = mm_inputs["prompt_token_ids"] placeholders_by_modality = mm_inputs["mm_placeholders"] + # For encoder-decoder models, use encoder prompt token ids instead of + # decoder prompt to construct dummy seq_data for encoder profiling. + prompt_token_ids = ( + mm_inputs["prompt_token_ids"] if not is_encoder_data else + mm_inputs["encoder_prompt_token_ids"]) # type: ignore total_placeholders_by_modality = { modality: sum(item["length"] for item in placeholders) @@ -188,7 +192,7 @@ def get_dummy_data( # V0 does not support chunked prefill. if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data: - if total_len > seq_len: + if total_len > seq_len and not is_encoder_data: logger.warning( "The context length (%d) of the model is too short " "to hold the multi-modal embeddings in the worst case " @@ -201,7 +205,8 @@ def get_dummy_data( total_placeholders_by_modality) return DummyData( - seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), + seq_data=SequenceData.from_prompt_token_counts( + (0, max(seq_len, total_len))), multi_modal_data=None, multi_modal_placeholders=None, ) From efddd99f89f40fe641de590ec0d73c7e9bb0a31c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 23 Feb 2025 02:54:29 -0800 Subject: [PATCH 186/317] [Core][Distributed] Use IPC (domain socket) ZMQ socket for local comms (#13688) --- .../device_communicators/shm_broadcast.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 48ac81ac008b..12a720d47fbb 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -19,7 +19,8 @@ import vllm.envs as envs from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address +from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path, + is_valid_ipv6_address) VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL @@ -165,12 +166,12 @@ def get_metadata(self, current_idx: int): @dataclass class Handle: - connect_ip: str local_reader_ranks: List[int] = field(default_factory=list) buffer_handle: Optional[Tuple[int, int, int, str]] = None - local_subscribe_port: Optional[int] = None - remote_subscribe_port: Optional[int] = None + local_subscribe_addr: Optional[str] = None + remote_subscribe_addr: Optional[str] = None + remote_addr_ipv6: bool = False class MessageQueue: @@ -192,9 +193,6 @@ def __init__( n_remote_reader = n_reader - n_local_reader self.n_remote_reader = n_remote_reader - if connect_ip is None: - connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1" - context = Context() if n_local_reader > 0: @@ -212,32 +210,34 @@ def __init__( # message. otherwise, we will only receive the first subscription # see http://api.zeromq.org/3-3:zmq-setsockopt for more details self.local_socket.setsockopt(XPUB_VERBOSE, True) - local_subscribe_port = get_open_port() - socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}" - logger.debug("Binding to %s", socket_addr) - self.local_socket.bind(socket_addr) + local_subscribe_addr = get_open_zmq_ipc_path() + logger.debug("Binding to %s", local_subscribe_addr) + self.local_socket.bind(local_subscribe_addr) self.current_idx = 0 - else: self.buffer = None # type: ignore - local_subscribe_port = None + local_subscribe_addr = None self.local_socket = None self.current_idx = -1 + remote_addr_ipv6 = False if n_remote_reader > 0: # for remote readers, we will: # create a publish-subscribe socket to communicate large data + if not connect_ip: + connect_ip = get_ip() self.remote_socket = context.socket(XPUB) self.remote_socket.setsockopt(XPUB_VERBOSE, True) remote_subscribe_port = get_open_port() if is_valid_ipv6_address(connect_ip): self.remote_socket.setsockopt(IPV6, 1) + remote_addr_ipv6 = True socket_addr = f"tcp://*:{remote_subscribe_port}" self.remote_socket.bind(socket_addr) - + remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" else: - remote_subscribe_port = None + remote_subscribe_addr = None self.remote_socket = None self._is_writer = True @@ -247,12 +247,12 @@ def __init__( self._is_remote_reader = False self.handle = Handle( - connect_ip=connect_ip, local_reader_ranks=local_reader_ranks, buffer_handle=self.buffer.handle() if self.buffer is not None else None, - local_subscribe_port=local_subscribe_port, - remote_subscribe_port=remote_subscribe_port, + local_subscribe_addr=local_subscribe_addr, + remote_subscribe_addr=remote_subscribe_addr, + remote_addr_ipv6=remote_addr_ipv6, ) logger.info("vLLM message queue communication handle: %s", self.handle) @@ -278,7 +278,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.local_socket = context.socket(SUB) self.local_socket.setsockopt_string(SUBSCRIBE, "") - socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}" + socket_addr = handle.local_subscribe_addr logger.debug("Connecting to %s", socket_addr) self.local_socket.connect(socket_addr) @@ -294,9 +294,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") - if is_valid_ipv6_address(handle.connect_ip): + if handle.remote_addr_ipv6: self.remote_socket.setsockopt(IPV6, 1) - socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" + socket_addr = handle.remote_subscribe_addr logger.debug("Connecting to %s", socket_addr) self.remote_socket.connect(socket_addr) From 1d8bcf73e3b18e989150281a925f42d34e63a5ea Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sun, 23 Feb 2025 05:32:20 -0800 Subject: [PATCH 187/317] [Misc] Deprecate `--dataset` from `benchmark_serving.py` (#13708) Signed-off-by: Roger Wang --- benchmarks/benchmark_serving.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 9760737ccec3..9416a22b7357 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -867,18 +867,10 @@ def main(args: argparse.Namespace): tokenizer_mode=tokenizer_mode, trust_remote_code=args.trust_remote_code) - if args.dataset is not None: - warnings.warn( - "The '--dataset' argument will be deprecated in the next " - "release. Please use '--dataset-name' and " - "'--dataset-path' in the future runs.", - stacklevel=2) - input_requests = sample_sharegpt_requests( - dataset_path=args.dataset, - num_requests=args.num_prompts, - tokenizer=tokenizer, - fixed_output_len=args.sharegpt_output_len, - ) + if args.dataset_name is None: + raise ValueError( + "Please specify '--dataset-name' and the corresponding " + "'--dataset-path' if required.") elif args.dataset_name == "sharegpt": input_requests = sample_sharegpt_requests( @@ -1052,13 +1044,6 @@ def main(args: argparse.Namespace): default="/v1/completions", help="API endpoint.", ) - parser.add_argument( - "--dataset", - type=str, - default=None, - help="Path to the ShareGPT dataset, will be deprecated in the " - "next release.", - ) parser.add_argument( "--dataset-name", type=str, From dc8db38cf4a64475a1004127662f369172ce2318 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 23 Feb 2025 22:47:24 +0800 Subject: [PATCH 188/317] [v1] torchrun compatibility (#13642) Signed-off-by: youkaichao --- .buildkite/test-pipeline.yaml | 1 + tests/distributed/test_torchrun_example.py | 6 ++++++ tests/v1/engine/test_engine_core.py | 6 ++++-- vllm/config.py | 5 +++++ vllm/executor/ray_distributed_executor.py | 2 +- vllm/executor/ray_utils.py | 4 +++- vllm/executor/uniproc_executor.py | 7 ++++--- vllm/v1/engine/core.py | 2 +- vllm/v1/engine/llm_engine.py | 9 +++++++-- vllm/v1/executor/abstract.py | 20 +++++++++++++++++--- vllm/v1/executor/multiproc_executor.py | 5 +++-- vllm/v1/worker/gpu_worker.py | 7 +++---- vllm/v1/worker/tpu_worker.py | 6 +++--- vllm/worker/worker_base.py | 11 +++++++++-- 14 files changed, 67 insertions(+), 24 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 931057e6c197..05c4d2616990 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -503,6 +503,7 @@ steps: - entrypoints/llm/test_collective_rpc.py commands: - pytest -v -s entrypoints/llm/test_collective_rpc.py + - VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index a092a548a59c..1c6c28b4ed35 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -48,6 +48,12 @@ def test_consistent_across_ranks(obj): test_consistent_across_ranks( llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) +# make sure we can access the model parameters from the calling process +# of the `LLM` instance. +params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner. + model.parameters()) +test_consistent_across_ranks(len(params)) + # all ranks should have the same outputs for output in outputs: prompt = output.prompt diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index d035668098eb..8c2998e58892 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -5,6 +5,7 @@ import time import uuid from concurrent.futures import Future +from typing import List import pytest from transformers import AutoTokenizer @@ -211,8 +212,9 @@ def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest: class DummyExecutor(UniProcExecutor): - def initialize(self, kv_cache_config: KVCacheConfig) -> None: - super().initialize(kv_cache_config) + def initialize_from_config( + self, kv_cache_configs: List[KVCacheConfig]) -> None: + super().initialize_from_config(kv_cache_configs) # This executor actually can only run 1 batch at a time self.semaphore = threading.Semaphore(1) diff --git a/vllm/config.py b/vllm/config.py index d3139b5fd84e..6bcf34c3cff9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1407,6 +1407,11 @@ def __post_init__(self) -> None: self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT self.world_size_across_dp = self.world_size * self.data_parallel_size + if self.distributed_executor_backend == "external_launcher": + import os + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + logger.info("Disabling V1 multiprocessing for external launcher.") + ray_only_devices = ["tpu"] from vllm.platforms import current_platform if (current_platform.device_type in ray_only_devices diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 79ca45d55d96..b866413e3a62 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -541,7 +541,7 @@ def _compiled_ray_dag(self, enable_asyncio: bool): # and the TP group executes in SPMD fashion. if self.use_v1: outputs = [ - worker.execute_model. + worker.execute_model_ray. bind( # type: ignore[attr-defined] outputs[i]) for i, worker in enumerate(tp_group) ] diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 8ad466a5572e..1734c670bf10 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -112,10 +112,12 @@ def setup_device_if_necessary(self): torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - def execute_model( + def execute_model_ray( self, scheduler_output: "SchedulerOutput", ) -> "ModelRunnerOutput": + # this method is used to compile ray CG, + # and it needs a special logic of self.setup_device_if_necessary() self.setup_device_if_necessary() assert self.worker is not None, "Worker is not initialized" if isinstance(scheduler_output, tuple): diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 94db232240d5..e041215de660 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -93,9 +93,10 @@ def _init_executor(self) -> None: ("ExecutorWithExternalLauncher needs deterministic " "execution, so it" "does not support delay_factor in scheduling") - assert not envs.VLLM_USE_V1, \ - ("V1 architecture cannot guarantee deterministic execution, " - "so it is not supported in ExecutorWithExternalLauncher.") + if envs.VLLM_USE_V1: + assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \ + ("To get deterministic execution in V1, " + "please set VLLM_ENABLE_V1_MULTIPROCESSING=0") self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) # engines are launched in torchrun-compatible launchers diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 981d23237e2a..85c97293af8b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -110,7 +110,7 @@ def _initialize_kv_caches(self, num_cpu_blocks = 0 # Initialize kv cache and warmup the execution - self.model_executor.initialize(kv_cache_configs) + self.model_executor.initialize_from_config(kv_cache_configs) elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 04c7ee109e0b..33b1ddc0f6fe 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -4,10 +4,10 @@ from typing_extensions import TypeVar +import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase -from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -44,6 +44,7 @@ def __init__( use_cached_outputs: bool = False, multiprocess_mode: bool = False, ) -> None: + self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -83,6 +84,10 @@ def __init__( log_stats=False, # FIXME: implement ) + if not multiprocess_mode: + # for v0 compatibility + self.model_executor = self.engine_core.engine_core.model_executor # type: ignore + @classmethod def from_engine_args( cls, @@ -97,7 +102,7 @@ def from_engine_args( vllm_config = engine_args.create_engine_config(usage_context) executor_class = Executor.get_class(vllm_config) - if VLLM_ENABLE_V1_MULTIPROCESSING: + if envs.VLLM_ENABLE_V1_MULTIPROCESSING: logger.debug("Enabling multiprocessing for LLMEngine.") enable_multiprocessing = True diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 3663cbd08aec..11002ad0022d 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -3,6 +3,9 @@ from concurrent.futures import Future from typing import List, Type, Union +import torch +import torch.distributed as dist + from vllm.config import VllmConfig from vllm.executor.executor_base import ExecutorBase from vllm.executor.uniproc_executor import ( # noqa @@ -49,12 +52,14 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]: f"{distributed_executor_backend}") return executor_class - def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None: + def initialize_from_config(self, + kv_cache_configs: List[KVCacheConfig]) -> None: """ Initialize the KV caches and begin the model execution loop of the underlying workers. """ - self.collective_rpc("initialize_cache", args=(kv_cache_configs, )) + self.collective_rpc("initialize_from_config", + args=(kv_cache_configs, )) self.collective_rpc("compile_or_warm_up_model") def determine_available_memory(self) -> int: # in bytes @@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor): class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): - pass + + def determine_available_memory(self) -> int: # in bytes + # same as determine_num_available_blocks in v0, + # we need to get the min across all ranks. + memory = super().determine_available_memory() + from vllm.distributed.parallel_state import get_world_group + cpu_group = get_world_group().cpu_group + memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) + dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) + return memory_tensor.item() diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 14492f273ed3..d4582122fa6d 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -216,9 +216,10 @@ def __init__( "local_rank": local_rank, "rank": rank, "distributed_init_method": distributed_init_method, + "is_driver_worker": rank == 0, } wrapper.init_worker(all_kwargs) - self.worker = wrapper.worker + self.worker = wrapper pid = os.getpid() _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid) @@ -239,7 +240,7 @@ def __init__( ready_socket.send_string(WorkerProc.READY_STR) ready_socket.send(payload) - wrapper.init_device() + self.worker.init_device() self.worker.load_model() @staticmethod diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index ece0fa555342..d9a415aee528 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -2,7 +2,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional import torch import torch.distributed @@ -185,9 +185,8 @@ def determine_available_memory(self) -> int: def get_kv_cache_spec(self) -> KVCacheSpec: return self.model_runner.get_kv_cache_spec() - def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None: + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" - kv_cache_config = kv_cache_configs[self.rank] if self.vllm_config.model_config.enable_sleep_mode: allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") @@ -225,7 +224,7 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) - return output if self.rank == 0 else None + return output if self.is_driver_worker else None def profile(self, is_start: bool = True): if self.profiler is None: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index af614cfa2843..ae124c819a90 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -37,6 +37,7 @@ def __init__( distributed_init_method: str, is_driver_worker: bool = False, ): + self.is_driver_worker = is_driver_worker self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -152,7 +153,7 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) - return output if self.rank == 0 else None + return output if self.is_driver_worker else None def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) @@ -174,9 +175,8 @@ def get_model(self) -> nn.Module: def get_kv_cache_spec(self) -> KVCacheSpec: return self.model_runner.get_kv_cache_spec() - def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None: + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" - kv_cache_config = kv_cache_configs[self.rank] self.model_runner.initialize_kv_cache(kv_cache_config) def check_health(self) -> None: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 44c26ed350a8..445c0d3285bf 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -567,6 +567,10 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: self.worker = worker_class(**kwargs) assert self.worker is not None + def initialize_from_config(self, kv_cache_configs: List[Any]) -> None: + kv_cache_config = kv_cache_configs[self.rpc_rank] + self.worker.initialize_from_config(kv_cache_config) # type: ignore + def init_device(self): with set_current_vllm_config(self.vllm_config): # To make vLLM config available during device initialization @@ -574,8 +578,11 @@ def init_device(self): def execute_method(self, method: Union[str, bytes], *args, **kwargs): try: - target = self if self.worker is None else self.worker - return run_method(target, method, args, kwargs) + # method resolution order: + # if a method is defined in this class, it will be called directly. + # otherwise, since we define `__getattr__` and redirect attribute + # query to `self.worker`, the method will be called on the worker. + return run_method(self, method, args, kwargs) except Exception as e: # if the driver worker also execute methods, # exceptions in the rest worker may cause deadlock in rpc like ray From 9226797231ef509250bcae53807fd71a03f732cd Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 23 Feb 2025 13:07:43 -0800 Subject: [PATCH 189/317] [V1][BugFix] Fix engine core client shutdown hangs (#13298) Even though ZMQ context.destroy() is meant to close open sockets before terminating the context, it appears to be necessary to do this explicitly or else it can hang in the context.term() method. Close zmq sockets explicitly before terminating context, make shutdown of client resource more robust, shut down engine core process prior to terminating zmq context. Signed-off-by: Nick Hill --- tests/v1/engine/test_engine_core_client.py | 4 +- vllm/v1/engine/core_client.py | 51 ++++++++++++++++------ 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 828d7eed309f..a7c02322ff02 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -3,7 +3,6 @@ import asyncio import time import uuid -from contextlib import ExitStack from typing import Dict, List, Optional import pytest @@ -178,7 +177,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): @pytest.mark.asyncio(loop_scope="function") async def test_engine_core_client_asyncio(monkeypatch): - with monkeypatch.context() as m, ExitStack() as after: + with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") # Monkey-patch core engine utility function to test. @@ -195,7 +194,6 @@ async def test_engine_core_client_asyncio(monkeypatch): executor_class=executor_class, log_stats=True, ) - after.callback(client.shutdown) MAX_TOKENS = 20 params = SamplingParams(max_tokens=MAX_TOKENS) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 527aa72833ba..5ffaf63e6cec 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -8,6 +8,7 @@ import weakref from abc import ABC, abstractmethod from concurrent.futures import Future +from dataclasses import dataclass from threading import Thread from typing import Any, Dict, List, Optional, Type, Union @@ -169,6 +170,31 @@ def add_lora(self, lora_request: LoRARequest) -> None: self.engine_core.add_lora(lora_request) +@dataclass +class BackgroundResources: + """Used as a finalizer for clean shutdown, avoiding + circular reference back to the client object.""" + + ctx: Union[zmq.Context, zmq.asyncio.Context] = None + output_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None + input_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None + proc_handle: Optional[BackgroundProcHandle] = None + + def __call__(self): + """Clean up background resources.""" + + if self.proc_handle is not None: + self.proc_handle.shutdown() + # ZMQ context termination can hang if the sockets + # aren't explicitly closed first. + if self.output_socket is not None: + self.output_socket.close(linger=0) + if self.input_socket is not None: + self.input_socket.close(linger=0) + if self.ctx is not None: + self.ctx.destroy(linger=0) + + class MPClient(EngineCoreClient): """ MPClient: base client for multi-proc EngineCore. @@ -212,21 +238,22 @@ def sigusr1_handler(signum, frame): zmq.asyncio.Context() # type: ignore[attr-defined] if asyncio_mode else zmq.Context()) # type: ignore[attr-defined] - # Note(rob): shutdown function cannot be a bound method, - # else the gc cannot collect the object. - self._finalizer = weakref.finalize(self, lambda x: x.destroy(linger=0), - self.ctx) + # This will ensure resources created so far are closed + # when the client is garbage collected, even if an + # exception is raised mid-construction. + resources = BackgroundResources(ctx=self.ctx) + self._finalizer = weakref.finalize(self, resources) # Paths and sockets for IPC. output_path = get_open_zmq_ipc_path() input_path = get_open_zmq_ipc_path() - self.output_socket = make_zmq_socket(self.ctx, output_path, - zmq.constants.PULL) - self.input_socket = make_zmq_socket(self.ctx, input_path, - zmq.constants.PUSH) + resources.output_socket = make_zmq_socket(self.ctx, output_path, + zmq.constants.PULL) + resources.input_socket = make_zmq_socket(self.ctx, input_path, + zmq.constants.PUSH) # Start EngineCore in background process. - self.proc_handle = BackgroundProcHandle( + resources.proc_handle = BackgroundProcHandle( input_path=input_path, output_path=output_path, process_name="EngineCore", @@ -237,13 +264,11 @@ def sigusr1_handler(signum, frame): "log_stats": log_stats, }) + self.output_socket = resources.output_socket + self.input_socket = resources.input_socket self.utility_results: Dict[int, AnyFuture] = {} def shutdown(self): - """Clean up background resources.""" - if hasattr(self, "proc_handle"): - self.proc_handle.shutdown() - self._finalizer() From bc7c0aa219b76d6ca6682f00e5dc230cea51dada Mon Sep 17 00:00:00 2001 From: Huy Do Date: Sun, 23 Feb 2025 18:23:18 -0800 Subject: [PATCH 190/317] Fix some issues with benchmark data output (#13641) Signed-off-by: Huy Do --- .../convert-results-json-to-markdown.py | 27 +++++++++++++---- .../scripts/run-performance-benchmarks.sh | 3 ++ .../tests/throughput-tests.json | 2 +- benchmarks/benchmark_latency.py | 5 ++-- benchmarks/benchmark_serving.py | 5 ++-- benchmarks/benchmark_throughput.py | 5 ++-- benchmarks/benchmark_utils.py | 30 +++++++++++++++++++ 7 files changed, 61 insertions(+), 16 deletions(-) diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index e031686c7a29..1030ec24e8d7 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -84,8 +84,13 @@ def results_to_json(latency, throughput, serving): # this result is generated via `benchmark_serving.py` # attach the benchmarking command to raw_result - with open(test_file.with_suffix(".commands")) as f: - command = json.loads(f.read()) + try: + with open(test_file.with_suffix(".commands")) as f: + command = json.loads(f.read()) + except OSError as e: + print(e) + continue + raw_result.update(command) # update the test name of this result @@ -99,8 +104,13 @@ def results_to_json(latency, throughput, serving): # this result is generated via `benchmark_latency.py` # attach the benchmarking command to raw_result - with open(test_file.with_suffix(".commands")) as f: - command = json.loads(f.read()) + try: + with open(test_file.with_suffix(".commands")) as f: + command = json.loads(f.read()) + except OSError as e: + print(e) + continue + raw_result.update(command) # update the test name of this result @@ -121,8 +131,13 @@ def results_to_json(latency, throughput, serving): # this result is generated via `benchmark_throughput.py` # attach the benchmarking command to raw_result - with open(test_file.with_suffix(".commands")) as f: - command = json.loads(f.read()) + try: + with open(test_file.with_suffix(".commands")) as f: + command = json.loads(f.read()) + except OSError as e: + print(e) + continue + raw_result.update(command) # update the test name of this result diff --git a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index 9425cb07ec01..a3555f72a666 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -309,11 +309,14 @@ run_serving_tests() { new_test_name=$test_name"_qps_"$qps + # pass the tensor parallel size to the client so that it can be displayed + # on the benchmark dashboard client_command="python3 benchmark_serving.py \ --save-result \ --result-dir $RESULTS_FOLDER \ --result-filename ${new_test_name}.json \ --request-rate $qps \ + --metadata "tensor_parallel_size=$tp" \ $client_args" echo "Running test case $test_name with qps $qps" diff --git a/.buildkite/nightly-benchmarks/tests/throughput-tests.json b/.buildkite/nightly-benchmarks/tests/throughput-tests.json index 91ef6d16be63..9bc87cbcd2bc 100644 --- a/.buildkite/nightly-benchmarks/tests/throughput-tests.json +++ b/.buildkite/nightly-benchmarks/tests/throughput-tests.json @@ -32,4 +32,4 @@ "backend": "vllm" } } -] \ No newline at end of file +] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 71ec909cba48..c82358d14512 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ import numpy as np import torch -from benchmark_utils import convert_to_pytorch_benchmark_format +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from tqdm import tqdm from vllm import LLM, SamplingParams @@ -30,8 +30,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, for k in ["avg_latency", "percentiles"]}) if pt_records: pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" - with open(pt_file, "w") as f: - json.dump(pt_records, f) + write_to_json(pt_file, pt_records) def main(args: argparse.Namespace): diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 9416a22b7357..1bb83b082beb 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -56,7 +56,7 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser -from benchmark_utils import convert_to_pytorch_benchmark_format +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -841,8 +841,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" - with open(pt_file, "w") as f: - json.dump(pt_records, f) + write_to_json(pt_file, pt_records) def main(args: argparse.Namespace): diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index ca54213c0646..04de08fa97c9 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -11,7 +11,7 @@ import torch import uvloop -from benchmark_utils import convert_to_pytorch_benchmark_format +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from PIL import Image from tqdm import tqdm from transformers import (AutoModelForCausalLM, AutoTokenizer, @@ -366,8 +366,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" - with open(pt_file, "w") as f: - json.dump(pt_records, f) + write_to_json(pt_file, pt_records) def main(args: argparse.Namespace): diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py index 6f01cf20e17c..ac0688ca013f 100644 --- a/benchmarks/benchmark_utils.py +++ b/benchmarks/benchmark_utils.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import json +import math import os from typing import Any, Dict, List @@ -34,6 +36,34 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, "extra_info": extra_info, }, } + + tp = record["benchmark"]["extra_info"]["args"].get( + "tensor_parallel_size") + # Save tensor_parallel_size parameter if it's part of the metadata + if not tp and "tensor_parallel_size" in extra_info: + record["benchmark"]["extra_info"]["args"][ + "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + records.append(record) return records + + +class InfEncoder(json.JSONEncoder): + + def clear_inf(self, o: Any): + if isinstance(o, dict): + return {k: self.clear_inf(v) for k, v in o.items()} + elif isinstance(o, list): + return [self.clear_inf(v) for v in o] + elif isinstance(o, float) and math.isinf(o): + return "inf" + return o + + def iterencode(self, o: Any, *args, **kwargs) -> Any: + return super().iterencode(self.clear_inf(o), *args, **kwargs) + + +def write_to_json(filename: str, records: List) -> None: + with open(filename, "w") as f: + json.dump(records, f, cls=InfEncoder) From fa87a0a56a1cdfb055bef66795765863f4b00d32 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Sun, 23 Feb 2025 22:32:11 -0800 Subject: [PATCH 191/317] [ci] Add logic to change model to S3 path only when S3 CI env var is on (#13727) Signed-off-by: <> Co-authored-by: EC2 Default User --- tests/metrics/test_metrics.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index d6183379c394..b276d9d9cb4e 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -7,6 +7,7 @@ import ray from prometheus_client import REGISTRY +import vllm.envs as envs from vllm import EngineArgs, LLMEngine from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import AsyncEngineArgs @@ -141,8 +142,10 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] metrics_tag_content = stat_logger.labels["model_name"] + if envs.VLLM_CI_USE_S3: + model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}" if served_model_name is None or served_model_name == []: - assert metrics_tag_content == f"{MODEL_WEIGHTS_S3_BUCKET}/{model}", ( + assert metrics_tag_content == model, ( f"Metrics tag model_name is wrong! expect: {model!r}\n" f"actual: {metrics_tag_content!r}") else: @@ -215,8 +218,9 @@ def test_engine_log_metrics_regression( while engine.has_unfinished_requests(): engine.step() - assert_metrics(f"{MODEL_WEIGHTS_S3_BUCKET}/{model}", engine, - disable_log_stats, len(example_prompts)) + if envs.VLLM_CI_USE_S3: + model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}" + assert_metrics(model, engine, disable_log_stats, len(example_prompts)) @pytest.mark.parametrize("model", MODELS) From 49f7ae252d13487ef41d2ca40d502dae103fb4b7 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 24 Feb 2025 06:10:06 -0800 Subject: [PATCH 192/317] [V1][Core] Fix memory issue with logits & sampling (#13721) --- vllm/v1/worker/gpu_model_runner.py | 68 +++++++++++++++++------------- vllm/v1/worker/gpu_worker.py | 10 +++++ 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a7b9d4781183..cf6bdd050e4a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1179,6 +1179,43 @@ def _dummy_run( ) return hidden_states + @torch.inference_mode() + def _dummy_sampler_run( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + + logits = self.model.compute_logits(hidden_states, None) + num_reqs = logits.size(0) + + dummy_tensors = lambda v: torch.full( + (num_reqs, ), v, device=self.device) + + dummy_metadata = SamplingMetadata( + temperature=dummy_tensors(0.5), + all_greedy=False, + all_random=False, + spec_token_ids=None, + top_p=dummy_tensors(0.9), + top_k=dummy_tensors(logits.size(1) - 1), + min_p=None, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=dummy_tensors(0.1), + presence_penalties=dummy_tensors(0.1), + repetition_penalties=dummy_tensors(0.1), + output_token_ids=[[] for _ in range(num_reqs)], + min_tokens={}, + logit_bias=[None for _ in range(num_reqs)], + allowed_token_ids_mask=None, + ) + sampler_output = self.model.sample(logits=logits, + sampling_metadata=dummy_metadata) + + return sampler_output + def profile_run(self) -> None: # use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value `None`. @@ -1306,38 +1343,11 @@ def profile_run(self) -> None: dummy_kv_caches) if get_pp_group().is_last_rank: hidden_states = hidden_states[logit_indices] - logits = self.model.compute_logits(hidden_states, None) - dummy_tensors = lambda v: torch.full( - (num_reqs, ), v, device=self.device) - dummy_metadata = SamplingMetadata( - temperature=dummy_tensors(0.5), - all_greedy=False, - all_random=False, - spec_token_ids=None, - top_p=dummy_tensors(0.9), - top_k=dummy_tensors(logits.size(1) - 1), - min_p=None, - generators={}, - max_num_logprobs=None, - no_penalties=True, - prompt_token_ids=torch.ones_like(logits, - dtype=torch.int64), - frequency_penalties=dummy_tensors(0.1), - presence_penalties=dummy_tensors(0.1), - repetition_penalties=dummy_tensors(0.1), - output_token_ids=[[] for _ in range(num_reqs)], - min_tokens={}, - logit_bias=[None for _ in range(num_reqs)], - allowed_token_ids_mask=None, - ) - sampler_output = self.model.sample( - logits=logits, sampling_metadata=dummy_metadata) + sampler_output = self._dummy_sampler_run(hidden_states) else: - logits = None sampler_output = None - dummy_metadata = None torch.cuda.synchronize() - del hidden_states, logits, sampler_output, dummy_metadata + del hidden_states, sampler_output self.encoder_cache.clear() gc.collect() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d9a415aee528..d9030aae51d1 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -211,6 +211,16 @@ def compile_or_warm_up_model(self) -> None: self.model_runner._dummy_run(size) if not self.model_config.enforce_eager: self.model_runner.capture_model() + + # Warm up sampler and preallocate memory buffer for logits and other + # sampling related tensors of max possible shape to avoid memory + # fragmentation issue. + # NOTE: This is called after `capture_model` on purpose to prevent + # memory buffers from being cleared by `torch.cuda.empty_cache`. + self.model_runner._dummy_sampler_run( + hidden_states=self.model_runner._dummy_run( + num_tokens=self.scheduler_config.max_num_seqs)) + # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed) From 72f17436ec53b806d740b58e835ae6ddfa726be2 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Mon, 24 Feb 2025 22:10:14 +0800 Subject: [PATCH 193/317] [model][refactor] remove cuda hard code in models and layers (#13658) --- .../layers/fused_moe/fused_marlin_moe.py | 3 ++- vllm/model_executor/layers/rotary_embedding.py | 13 +++++++++---- .../layers/spec_decode_base_sampler.py | 4 +++- vllm/model_executor/model_loader/loader.py | 3 ++- vllm/model_executor/models/arctic.py | 5 +++-- vllm/model_executor/models/minicpm.py | 5 +++-- vllm/model_executor/models/minicpmv.py | 10 +++++++--- 7 files changed, 29 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 4ca569ca4f19..ee158d7ee474 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import direct_register_custom_op @@ -238,7 +239,7 @@ def fused_marlin_moe( max_workspace_size = (max(2 * N, K) // 64) * 16 workspace = torch.zeros(max_workspace_size, dtype=torch.int, - device="cuda", + device=current_platform.device_type, requires_grad=False) if has_no_zp: diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 5d7f9396c20b..ce1bc98ea426 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -30,6 +30,7 @@ from transformers import PretrainedConfig from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform def _rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -650,9 +651,13 @@ def __init__( is_neox_style, dtype) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / - self.rotary_dim) + pos_freqs = self.base**( + torch.arange(0, + self.rotary_dim, + 2, + dtype=torch.float, + device=current_platform.device_type) / + self.rotary_dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) @@ -670,7 +675,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) t = torch.arange(self.max_position_embeddings * self.scaling_factor, - device="cuda", + device=current_platform.device_type, dtype=torch.float32) freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = (freqs.cos() * self.mscale) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 35c7ffec271e..54fd43fc6592 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -7,6 +7,8 @@ import torch.jit import torch.nn as nn +from vllm.platforms import current_platform + class SpecDecodeBaseSampler(nn.Module): """Base class for samplers used for Speculative Decoding verification @@ -35,7 +37,7 @@ def __init__(self, strict_mode: bool = False): def init_gpu_tensors(self, device: Union[int, str]) -> None: assert self.num_accepted_tokens is None if isinstance(device, int): - device = f"cuda:{device}" + device = f"{current_platform.device_type}:{device}" elif not isinstance(device, str): raise ValueError(f"Device must be int or str, get {type(device)}") self.num_accepted_tokens = torch.tensor(0, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 8736cf1ca341..e23c63758556 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -914,7 +914,8 @@ def _parse_quant_state(param_name: str, if param_name + "." in k: quant_state[k] = temp_state_dict[k] - return QuantState.from_dict(quant_state, device="cuda") + return QuantState.from_dict(quant_state, + device=current_platform.device_type) # Second iterate over all prequant and normal weights # pre quantized weights would have a quant_state diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 27df448e63f7..77f383b6e46d 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -30,6 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.arctic import ArcticConfig @@ -138,13 +139,13 @@ def __init__(self, torch.empty(self.num_experts, 2 * self.intermediate_size, self.hidden_size, - device="cuda", + device=current_platform.device_type, dtype=self.params_dtype)) self.w2s = nn.Parameter( torch.empty(self.num_experts, self.hidden_size, self.intermediate_size, - device="cuda", + device=current_platform.device_type, dtype=self.params_dtype)) set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 52ab89488785..54b691b3572d 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -51,6 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -98,13 +99,13 @@ def __init__( torch.empty(self.num_total_experts, 2 * self.intermediate_size, self.hidden_size, - device="cuda", + device=current_platform.device_type, dtype=self.params_dtype)) self.w2s = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, self.intermediate_size, - device="cuda", + device=current_platform.device_type, dtype=self.params_dtype)) set_weight_attrs(self.ws, { diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 1f278b65740c..5e883d00c1c6 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -59,6 +59,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .idefics2_vision_model import Idefics2VisionTransformer @@ -1184,7 +1185,8 @@ def init_resampler(self, quant_config=quant_config, prefix=prefix) - return resampler.to(device="cuda", dtype=torch.get_default_dtype()) + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -1266,7 +1268,8 @@ def init_resampler(self, quant_config=quant_config, prefix=prefix) - return resampler.to(device="cuda", dtype=torch.get_default_dtype()) + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -1360,7 +1363,8 @@ def init_resampler(self, quant_config=quant_config, prefix=prefix) - return resampler.to(device="cuda", dtype=torch.get_default_dtype()) + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) def get_vision_embedding( self, From 458d3a9af8a9710f5dd603fc5e6b5372f2ea3095 Mon Sep 17 00:00:00 2001 From: Roger Meier Date: Mon, 24 Feb 2025 15:10:25 +0100 Subject: [PATCH 194/317] [Bugfix] fix(logging): add missing opening square bracket (#13011) --- vllm/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/logger.py b/vllm/logger.py index b20d55e3c101..0ee47de173ad 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -20,7 +20,7 @@ VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX _FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " - "%(filename)s:%(lineno)d] %(message)s") + "[%(filename)s:%(lineno)d] %(message)s") _DATE_FORMAT = "%m-%d %H:%M:%S" DEFAULT_LOGGING_CONFIG = { From 449c61fd497136cce6c1e5ec6316f772ea85f3de Mon Sep 17 00:00:00 2001 From: Roger Meier Date: Mon, 24 Feb 2025 15:10:33 +0100 Subject: [PATCH 195/317] [CI/Build] add python-json-logger to requirements-common (#12842) --- examples/other/logging_configuration.md | 9 ++------- requirements-common.txt | 1 + 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/examples/other/logging_configuration.md b/examples/other/logging_configuration.md index acd9c1f2bc0a..c70b853c1276 100644 --- a/examples/other/logging_configuration.md +++ b/examples/other/logging_configuration.md @@ -49,7 +49,8 @@ disabled, an error will occur while starting vLLM. ### Example 1: Customize vLLM root logger For this example, we will customize the vLLM root logger to use -[`python-json-logger`](https://github.com/madzak/python-json-logger) to log to +[`python-json-logger`](https://github.com/nhairs/python-json-logger) +(which is part of the container image) to log to STDOUT of the console in JSON format with a log level of `INFO`. To begin, first, create an appropriate JSON logging configuration file: @@ -82,12 +83,6 @@ To begin, first, create an appropriate JSON logging configuration file: } ``` -Next, install the `python-json-logger` package if it's not already installed: - -```bash -pip install python-json-logger -``` - Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set to the path of the custom logging configuration JSON file: diff --git a/requirements-common.txt b/requirements-common.txt index c0df136f500e..0514bf8adcaf 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -38,3 +38,4 @@ compressed-tensors == 0.9.2 # required for compressed-tensors depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files +python-json-logger # Used by logging as per examples/other/logging_configuration.md From 521dce4bd39b0e86c70e0b4feeabbead3374e557 Mon Sep 17 00:00:00 2001 From: Jongseok Park <37990712+cakeng@users.noreply.github.com> Date: Mon, 24 Feb 2025 07:33:20 -0800 Subject: [PATCH 196/317] Expert Parallelism (EP) Support for DeepSeek V2 (#12583) --- benchmarks/kernels/benchmark_moe.py | 3 +- tests/distributed/test_expert_parallel.py | 227 ++++++++++++++++++ tests/kernels/test_awq_marlin.py | 9 +- tests/kernels/test_moe.py | 65 ++++- tests/kernels/utils.py | 4 +- tests/utils.py | 6 +- vllm/config.py | 20 ++ vllm/envs.py | 7 + .../layers/fused_moe/fused_moe.py | 126 +++++++--- vllm/model_executor/layers/fused_moe/layer.py | 70 +++++- .../layers/fused_moe/moe_torch_iterative.py | 10 +- .../layers/quantization/awq_marlin.py | 7 + .../compressed_tensors_moe.py | 10 + .../layers/quantization/experts_int8.py | 4 + .../model_executor/layers/quantization/fp8.py | 4 + .../layers/quantization/gptq_marlin.py | 2 + .../layers/quantization/moe_wna16.py | 4 + .../layers/quantization/quark/quark_moe.py | 4 + vllm/model_executor/models/deepseek_v2.py | 4 - 19 files changed, 527 insertions(+), 59 deletions(-) create mode 100644 tests/distributed/test_expert_parallel.py diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index a4a45c9cbff2..410750686ee1 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -468,7 +468,8 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif config.architectures[0] == "DeepseekV3ForCausalLM": + elif (config.architectures[0] == "DeepseekV3ForCausalLM" + or config.architectures[0] == "DeepseekV2ForCausalLM"): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size diff --git a/tests/distributed/test_expert_parallel.py b/tests/distributed/test_expert_parallel.py new file mode 100644 index 000000000000..bc5770642b79 --- /dev/null +++ b/tests/distributed/test_expert_parallel.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import List, Literal, NamedTuple, Optional + +import pytest + +from vllm.config import TaskOption +from vllm.logger import init_logger + +from ..utils import compare_two_settings, fork_new_process_for_each_test + +logger = init_logger("test_expert_parallel") + + +class ParallelSetup(NamedTuple): + tp_size: int + eager_mode: bool + chunked_prefill: bool + + +class EPTestOptions(NamedTuple): + trust_remote_code: bool + tokenizer_mode: Optional[str] + load_format: Optional[str] = None + hf_overrides: Optional[str] = None + + +@dataclass +class EPTestSettings: + parallel_setups: List[ParallelSetup] + distributed_backends: List[str] + task: TaskOption + test_options: EPTestOptions + + @staticmethod + def detailed( + *, + tp_base: int = 2, + task: TaskOption = "auto", + trust_remote_code: bool = False, + tokenizer_mode: Optional[str] = None, + load_format: Optional[str] = None, + hf_overrides: Optional[str] = None, + ): + return EPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + eager_mode=True, + chunked_prefill=False), + ParallelSetup(tp_size=2 * tp_base, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=2 * tp_base, + eager_mode=True, + chunked_prefill=False), + ], + distributed_backends=["mp", "ray"], + task=task, + test_options=EPTestOptions(trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + load_format=load_format, + hf_overrides=hf_overrides), + ) + + @staticmethod + def fast( + *, + tp_base: int = 2, + task: TaskOption = "auto", + trust_remote_code: bool = False, + tokenizer_mode: Optional[str] = None, + load_format: Optional[str] = None, + hf_overrides: Optional[str] = None, + ): + return EPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + eager_mode=True, + chunked_prefill=False), + ], + distributed_backends=["mp"], + task=task, + test_options=EPTestOptions(trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + load_format=load_format, + hf_overrides=hf_overrides), + ) + + def iter_params(self, model_name: str): + opts = self.test_options + + for parallel_setup in self.parallel_setups: + for distributed_backend in self.distributed_backends: + yield (model_name, parallel_setup, distributed_backend, + self.task, opts) + + +# NOTE: You can adjust tp_base locally to fit the model in GPU +# The values displayed here are only a rough indicator of the size of the model + +# yapf: disable +TEST_MODELS = { + "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast( + trust_remote_code=True), + "mistralai/Mixtral-8x7B-Instruct-v0.1": EPTestSettings.fast(tp_base=4), +} + + +def _compare_tp( + model_name: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + task: TaskOption, + test_options: EPTestOptions, + num_gpus_available: int, + *, + method: Literal["generate"], +): + ( + tp_size, + eager_mode, + chunked_prefill, + ) = parallel_setup + ( + trust_remote_code, + tokenizer_mode, + load_format, + hf_overrides, + ) = test_options + + if num_gpus_available < tp_size: + pytest.skip(f"Need at least {tp_size} GPUs") + + common_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + "--load-format", + "auto", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if task != "auto": + common_args.extend(["--task", task]) + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + if load_format: + common_args.extend(["--load-format", load_format]) + if hf_overrides: + common_args.extend(["--hf-overrides", hf_overrides]) + + ep_env = { + "VLLM_TEST_ENABLE_EP": "1", + } + + ep_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + # compare without expert parallelism + tp_env = { + "VLLM_TEST_ENABLE_EP": "0", + } + + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + "mp", + ] + + try: + compare_two_settings(model_name, + ep_args, + tp_args, + ep_env, + tp_env, + method=method, + max_wait_seconds=360) + except Exception: + raise + + +@pytest.mark.parametrize( + ("model_name", "parallel_setup", "distributed_backend", "task", + "test_options"), + [ + params for model_name, settings in TEST_MODELS.items() + for params in settings.iter_params(model_name) + ], +) +@fork_new_process_for_each_test +def test_ep( + model_name: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + task: TaskOption, + test_options: EPTestOptions, + num_gpus_available, +): + _compare_tp(model_name, + parallel_setup, + distributed_backend, + task, + test_options, + num_gpus_available, + method="generate") diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index 67595010cb2a..939b0e7157be 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq( num_bits=num_bits, ) - torch_output = torch_moe( - a, - w_ref1.transpose(1, 2), - w_ref2.transpose(1, 2), - score, - topk, - ) + torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2), + score, topk, None) assert compute_max_diff(marlin_output, torch_output) < 4e-2 diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 0f13fbc96503..2f5c69046f48 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -26,6 +26,7 @@ from vllm.scalar_type import scalar_types NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] TOP_KS = [2, 6] @@ -34,6 +35,7 @@ @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_fused_moe( m: int, @@ -41,6 +43,7 @@ def test_fused_moe( k: int, e: int, topk: int, + ep_size: int, dtype: torch.dtype, ): a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 @@ -48,10 +51,38 @@ def test_fused_moe( w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) - torch_output = torch_moe(a, w1, w2, score, topk) + + if ep_size > 1: + local_e = e // ep_size + e_ids = torch.randint(0, + e, (local_e, ), + device="cuda", + dtype=torch.int32) + e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + w1 = w1[e_ids] + w2 = w2[e_ids] + else: + e_map = None + + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + torch_output = torch_moe(a, w1, w2, score, topk, e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) - iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False) + iterative_output = iterative_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) torch.testing.assert_close(iterative_output, torch_output, atol=2e-2, @@ -63,13 +94,14 @@ def test_fused_moe( @pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("weight_bits", [4, 8]) def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, - dtype: torch.dtype, group_size: int, has_zp: bool, - weight_bits: int): + ep_size: int, dtype: torch.dtype, group_size: int, + has_zp: bool, weight_bits: int): print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -130,6 +162,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, if has_zp: w_qzeros[expert_id] = qzeros + if ep_size > 1: + local_e = e // ep_size + e_ids = torch.randint(0, + e, (local_e, ), + device="cuda", + dtype=torch.int32) + e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + w1_ref = w1_ref[e_ids] + w2_ref = w2_ref[e_ids] + w1_qweight = w1_qweight[e_ids] + w2_qweight = w2_qweight[e_ids] + w1_scales = w1_scales[e_ids] + w2_scales = w2_scales[e_ids] + w1_qzeros = w1_qzeros[e_ids] + w2_qzeros = w2_qzeros[e_ids] + else: + e_map = None + triton_output = fused_moe(a, w1_qweight, w2_qweight, @@ -138,12 +189,14 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, renormalize=False, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, + global_num_experts=e, + expert_map=e_map, w1_scale=w1_scales, w2_scale=w2_scales, w1_zp=w1_qzeros if has_zp else None, w2_zp=w2_qzeros if has_zp else None, block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 5be111d71308..1ee3a3325037 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1053,7 +1053,7 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) -def torch_moe(a, w1, w2, score, topk): +def torch_moe(a, w1, w2, score, topk, expert_map): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) @@ -1061,6 +1061,8 @@ def torch_moe(a, w1, w2, score, topk): topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) + if expert_map is not None: + topk_ids = expert_map[topk_ids] for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): diff --git a/tests/utils.py b/tests/utils.py index f39cbe7ede03..2ad91ca2c869 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -297,12 +297,12 @@ def _test_completion_close( logprobs=5, temperature=0.0) - logporbs = completion.choices[0].logprobs.top_logprobs[0] - logporbs = {k: round(v, 2) for k, v in logporbs.items()} + logprobs = completion.choices[0].logprobs.top_logprobs[0] + logprobs = {k: round(v, 2) for k, v in logprobs.items()} results.append({ "test": "completion_close", - "logprobs": logporbs, + "logprobs": logprobs, }) return results diff --git a/vllm/config.py b/vllm/config.py index 6bcf34c3cff9..ace49a86eaef 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -677,6 +677,23 @@ def _verify_bnb_config(self) -> None: "fallback to the eager mode.") self.enforce_eager = True + def _verify_with_expert_parallelism(self) -> None: + num_expert_names = [ + "moe_num_experts", # Dbrx + "num_experts", # Jamba + "n_routed_experts", # DeepSeek + "num_local_experts", # Mixtral + ] + num_experts = 0 + for name in num_expert_names: + num_experts = getattr(self.hf_text_config, name, 0) + if num_experts > 0: + break + if num_experts < 1: + raise ValueError( + "Number of experts in the model must be greater than 0 " + "when expert parallelism is enabled.") + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: @@ -730,6 +747,9 @@ def verify_with_parallel_config( " must be divisible by tensor parallel size " f"({tensor_parallel_size}).") + if envs.VLLM_TEST_ENABLE_EP: + self._verify_with_expert_parallelism() + pipeline_parallel_size = parallel_config.pipeline_parallel_size if pipeline_parallel_size > 1: architectures = getattr(self.hf_config, "architectures", []) diff --git a/vllm/envs.py b/vllm/envs.py index dbf1d4623962..84426cb5bb22 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -86,6 +86,7 @@ VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True VLLM_MLA_DISABLE_REQUANTIZATION: bool = False VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True + VLLM_TEST_ENABLE_EP: bool = False VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" @@ -570,6 +571,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")) ), + # If set, vLLM will use the experimental expert parallel implementation on + # the FusedMoE layer, using tensor parallelism size as expert parallelism + # size. + "VLLM_TEST_ENABLE_EP": + lambda: bool(int(os.getenv("VLLM_TEST_ENABLE_EP", "0"))), + # Number of GPUs per worker in Ray, if it is set to be a fraction, # it allows ray to schedule multiple actors on a single GPU, # so that users can colocate other actors on the same GPUs as vLLM. diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 543c8ced165a..4cab72a29da4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -20,6 +20,18 @@ logger = init_logger(__name__) +@triton.jit +def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, + token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, + compute_type): + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel_gptq_awq( # Pointers to matrices @@ -120,17 +132,26 @@ def fused_moe_kernel_gptq_awq( offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, + offs_token, token_mask, BLOCK_SIZE_M, + BLOCK_SIZE_N, compute_type) + return + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) - off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) - if use_int4_w4a16: b_ptrs = b_ptr + off_experts * stride_be + \ - (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn + (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \ + stride_bn b_shifter = (offs_k[:, None] % 2) * 4 elif use_int8_w8a16: b_ptrs = b_ptr + off_experts * stride_be + \ @@ -170,7 +191,8 @@ def fused_moe_kernel_gptq_awq( b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ offs_bn[None, :] * stride_bsn + \ - ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \ + stride_bsk b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) b_scale = b_scale.to(tl.float32) @@ -319,13 +341,22 @@ def fused_moe_kernel( offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, + offs_token, token_mask, BLOCK_SIZE_M, + BLOCK_SIZE_N, compute_type) + return + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) - off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) if use_int8_w8a16: @@ -349,7 +380,6 @@ def fused_moe_kernel( # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. @@ -544,8 +574,11 @@ def moe_align_block_size_triton( def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, - num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + expert_map: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. @@ -555,6 +588,10 @@ def moe_align_block_size( top-k expert indices for each token. - block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. + - expert_map: A tensor of shape [num_experts] that maps the expert index + from the global space to the local index space of the current + expert parallel shard. If the expert is not in the current expert + parallel shard, the mapping is set to -1. Returns: - sorted_token_ids: A tensor containing the sorted token indices according @@ -589,7 +626,9 @@ def moe_align_block_size( device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty((max_num_m_blocks, ), + # Expert ids must be zeroed out to prevent index out of bounds error while + # mapping global expert ids to local expert ids in expert parallelism. + expert_ids = torch.zeros((max_num_m_blocks, ), dtype=torch.int32, device=topk_ids.device) num_tokens_post_pad = torch.empty((1), @@ -618,6 +657,9 @@ def moe_align_block_size( else: ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + return sorted_ids, expert_ids, num_tokens_post_pad @@ -1001,6 +1043,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1009,8 +1053,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale, - w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, + global_num_experts, expert_map, w1_scale, w2_scale, + w1_zp, w2_zp, a1_scale, a2_scale, block_shape) def inplace_fused_experts_fake( @@ -1022,6 +1067,8 @@ def inplace_fused_experts_fake( use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1049,6 +1096,8 @@ def outplace_fused_experts( use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1058,8 +1107,9 @@ def outplace_fused_experts( block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, use_fp8_w8a8, use_int8_w8a16, - use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, - a1_scale, a2_scale, block_shape) + use_int4_w4a16, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, + a2_scale, block_shape) def outplace_fused_experts_fake( @@ -1071,6 +1121,8 @@ def outplace_fused_experts_fake( use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1098,26 +1150,27 @@ def fused_experts(hidden_states: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None): + block_shape: Optional[List[int]] = None) -> torch.Tensor: + if inplace: - torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, - topk_weights, topk_ids, - use_fp8_w8a8, use_int8_w8a16, - use_int4_w4a16, w1_scale, - w2_scale, w1_zp, w2_zp, a1_scale, - a2_scale, block_shape) + torch.ops.vllm.inplace_fused_experts( + hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, + use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) return hidden_states else: return torch.ops.vllm.outplace_fused_experts( hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, - use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, - a1_scale, a2_scale, block_shape) + use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) def fused_experts_impl(hidden_states: torch.Tensor, @@ -1129,6 +1182,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1153,6 +1208,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_tokens, _ = hidden_states.shape E, N, _ = w1.shape + if global_num_experts == -1: + global_num_experts = E + top_k_num = topk_ids.shape[1] # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE @@ -1166,20 +1224,20 @@ def fused_experts_impl(hidden_states: torch.Tensor, try_get_optimal_moe_config, w1.shape, w2.shape, - topk_ids.shape[1], + top_k_num, config_dtype, block_shape=block_shape, ) config = get_config_func(M) - intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), + intermediate_cache1 = torch.empty((M, top_k_num, N), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + intermediate_cache2 = torch.empty((M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), + intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) @@ -1221,7 +1279,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) invoke_fused_moe_kernel(curr_hidden_states, w1, @@ -1235,7 +1294,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, expert_ids, num_tokens_post_padded, False, - topk_ids.shape[1], + top_k_num, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, @@ -1286,6 +1345,8 @@ def fused_moe( use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1320,6 +1381,11 @@ def fused_moe( - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 activation to compute the inner products for w1 and w2. Defaults to False. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -1334,8 +1400,6 @@ def fused_moe( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - # Check constraints. - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" if use_grouped_topk: assert num_expert_group is not None and topk_group is not None @@ -1358,6 +1422,8 @@ def fused_moe( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=w1_scale, w2_scale=w2_scale, w1_zp=w1_zp, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f18c0313355d..49400b699cce 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -6,6 +6,7 @@ import torch +import vllm.envs as envs from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -55,6 +56,8 @@ def apply( use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None @@ -113,6 +116,8 @@ def apply( use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None @@ -125,6 +130,8 @@ def apply( use_grouped_topk=use_grouped_topk, topk_group=topk_group, num_expert_group=num_expert_group, + global_num_experts=global_num_experts, + expert_map=expert_map, custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) @@ -139,6 +146,8 @@ def forward_cuda( renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None @@ -160,7 +169,9 @@ def forward_cuda( w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True) + inplace=True, + global_num_experts=global_num_experts, + expert_map=expert_map) def forward_cpu( self, @@ -172,6 +183,8 @@ def forward_cpu( renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, **kwargs, ): @@ -196,6 +209,8 @@ def forward_tpu( renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None @@ -215,6 +230,8 @@ def forward_tpu( w2=layer.w2_weight, topk=top_k, gating_output=router_logits, + global_num_experts=global_num_experts, + expert_map=expert_map, renormalize=renormalize) forward_native = forward_cuda @@ -255,6 +272,7 @@ def __init__( topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, + ep_size: Optional[int] = None, prefix: str = "", custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", @@ -267,8 +285,13 @@ def __init__( self.tp_size = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()) + if envs.VLLM_TEST_ENABLE_EP: + self.ep_size = self.tp_size + self.tp_size = 1 + else: + self.ep_size = 1 self.top_k = top_k - self.num_experts = num_experts + self.num_experts = num_experts # Global number of experts assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results @@ -281,6 +304,26 @@ def __init__( self.custom_routing_function = custom_routing_function self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias + self.expert_map = None + + if self.ep_size > 1: + # Create a tensor of size num_experts filled with -1 + self.expert_map = torch.full((self.num_experts, ), + -1, + dtype=torch.int32) + # Create a expert map for the local experts + local_num_experts = num_experts // self.ep_size + ep_rank = get_tensor_model_parallel_rank() + if ep_rank < (self.ep_size - 1): + # Each non-last rank gets local_num_experts experts. + self.expert_map[ep_rank * local_num_experts: + (ep_rank + 1) * local_num_experts] = \ + torch.arange(0, local_num_experts, dtype=torch.int32) + else: + # All remaining experts are assigned to the last rank. + local_num_experts = num_experts - ep_rank * local_num_experts + self.expert_map[-local_num_experts:] = \ + torch.arange(0, local_num_experts, dtype=torch.int32) if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " @@ -293,8 +336,11 @@ def __init__( self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None + local_num_experts = torch.sum(self.expert_map != -1) \ + if self.expert_map is not None else num_experts + moe_quant_params = { - "num_experts": num_experts, + "num_experts": local_num_experts, "hidden_size": hidden_size, "intermediate_size_per_partition": self.intermediate_size_per_partition, @@ -423,10 +469,22 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, assert shard_id in ("w1", "w3") expert_data.copy_(loaded_weight) + def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: + if self.expert_map is None: + return expert_id + return self.expert_map[expert_id].item() + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) + if expert_id == -1: + return + + # TP rank is set to 0 if EP is enabled + tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank() + # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -447,7 +505,6 @@ def weight_loader(self, param: torch.nn.Parameter, SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} expert_data = param.data[expert_id] - tp_rank = get_tensor_model_parallel_rank() # is_transposed: if the dim to shard the weight # should be flipped. Required by GPTQ, compressed-tensors @@ -590,13 +647,16 @@ def forward(self, hidden_states: torch.Tensor, top_k=self.top_k, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.num_experts, + expert_map=self.expert_map, topk_group=self.topk_group, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias) - if self.reduce_results and self.tp_size > 1: + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py index d9a5de1b3033..da27633f2723 100644 --- a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py @@ -10,7 +10,9 @@ def fused_moe( w2: torch.Tensor, gating_output: torch.Tensor, topk: int, - renormalize: bool, + global_num_experts: int, + expert_map: torch.Tensor = None, + renormalize: bool = False, ) -> torch.Tensor: """ Args: @@ -18,6 +20,7 @@ def fused_moe( w1: [num_experts, intermediate_size * 2, hidden_size] w2: [num_experts, hidden_size, intermediate_size] gating_output: [*, num_experts] + expert_map: [num_experts] """ orig_shape = hidden_states.shape hidden_size = hidden_states.shape[-1] @@ -27,13 +30,16 @@ def fused_moe( dtype = hidden_states.dtype hidden_states = hidden_states.view(num_tokens, hidden_size) - gating_output = gating_output.view(num_tokens, num_experts) + gating_output = gating_output.view(num_tokens, global_num_experts) topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights.to(dtype) + if expert_map is not None: + selected_experts = expert_map[selected_experts] + final_hidden_states = None for expert_idx in range(num_experts): expert_w1 = w1[expert_idx] diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 111b3f74d50e..0e8c4c7b3ac5 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -464,10 +464,17 @@ def apply( use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if expert_map is not None: + raise NotImplementedError( + "Expert Parallelism is not supported for " + "fused Marlin MoE method.") + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index db8e8a4b6c11..389359a663cc 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -214,6 +214,8 @@ def apply( use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, @@ -239,6 +241,8 @@ def apply( topk_ids=topk_ids, inplace=True, use_fp8_w8a8=True, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, @@ -540,10 +544,16 @@ def apply( use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if expert_map is not None: + raise NotImplementedError( + "Expert Parallelism is not supported for " + "fused Marlin MoE method.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 663fb8bf5b8e..0767926ee5c0 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -108,6 +108,8 @@ def apply( use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, @@ -133,6 +135,8 @@ def apply( topk_ids=topk_ids, inplace=True, use_int8_w8a16=True, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=layer.w13_scale, w2_scale=layer.w2_scale) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1ca39b0ffa82..9f4cd2aa7378 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -670,6 +670,8 @@ def apply( use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, @@ -697,6 +699,8 @@ def apply( topk_ids=topk_ids, inplace=True, use_fp8_w8a8=True, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=(layer.w13_weight_scale_inv if self.block_quant else layer.w13_weight_scale), w2_scale=(layer.w2_weight_scale_inv diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9f960d9fd37f..241fc7d777a6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -585,6 +585,8 @@ def apply( use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index da06ca3f70ec..a3adac1bb129 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -288,6 +288,8 @@ def apply( use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, @@ -317,6 +319,8 @@ def apply( inplace=True, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, w1_zp=layer.w13_qzeros if has_zp else None, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 98743b15e4b2..36b08589fd16 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -198,6 +198,8 @@ def apply( use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, @@ -223,6 +225,8 @@ def apply( topk_ids=topk_ids, inplace=True, use_fp8_w8a8=True, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a4d52c613b3e..9bf3ec2ffd81 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -106,10 +106,6 @@ def __init__( self.routed_scaling_factor = config.routed_scaling_factor self.n_shared_experts = config.n_shared_experts self.routed_scaling_factor = config.routed_scaling_factor - if self.tp_size > config.n_routed_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.n_routed_experts}.") if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " From 140913ece166bfa6e73cbf0b88306d5a625211dd Mon Sep 17 00:00:00 2001 From: Zhonghua Deng Date: Mon, 24 Feb 2025 23:37:32 +0800 Subject: [PATCH 197/317] [BugFix] Illegal memory access for MoE On H20 (#13693) --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 4cab72a29da4..1ddc3ce6f895 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1271,7 +1271,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, # so the cache size and config are already set correctly and # do not need to be adjusted. intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * topk_ids.shape[1]] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) From 06b6876e376049aec3c87c2de194bbe95896a8af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Mon, 24 Feb 2025 16:43:21 +0100 Subject: [PATCH 198/317] [Misc][Docs] Raise error when flashinfer is not installed and `VLLM_ATTENTION_BACKEND` is set (#12513) Signed-off-by: NickLucche --- docs/source/getting_started/quickstart.md | 10 ++++++++++ vllm/config.py | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md index f3a4773f0fc6..f51856d6eaeb 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/source/getting_started/quickstart.md @@ -184,3 +184,13 @@ chat_response = client.chat.completions.create( ) print("Chat response:", chat_response) ``` + +## On Attention Backends + +Currently, vLLM supports multiple backends for efficient Attention computation across different platforms and accelerator architectures. It automatically selects the most performant backend compatible with your system and model specifications. + +If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`. + +```{attention} +There are no pre-built vllm wheels containing Flash Infer, so you must install it in your environment first. Refer to the [Flash Infer official docs](https://docs.flashinfer.ai/) or see [Dockerfile](https://github.com/vllm-project/vllm/blob/main/Dockerfile) for instructions on how to install it. +``` diff --git a/vllm/config.py b/vllm/config.py index ace49a86eaef..a584bc0d930f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -9,6 +9,7 @@ import warnings from contextlib import contextmanager from dataclasses import dataclass, field, replace +from importlib.util import find_spec from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict, Final, List, Literal, Mapping, Optional, Protocol, Set, @@ -294,6 +295,14 @@ def __init__( self.maybe_pull_model_tokenizer_for_s3(model, tokenizer) + if (backend := envs.VLLM_ATTENTION_BACKEND + ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: + raise ValueError( + "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " + "module was not found." + "See https://github.com/vllm-project/vllm/blob/main/Dockerfile" + "for instructions on how to install it.") + # The tokenizer version is consistent with the model version by default. if tokenizer_revision is None: self.tokenizer_revision = revision From 0395274257e147e4564cc7a339c6a4fb653aad07 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Mon, 24 Feb 2025 11:29:41 -0500 Subject: [PATCH 199/317] [V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) (#10980) Signed-off-by: Andrew Feldman Co-authored-by: Nick Hill --- tests/v1/engine/test_llm_engine.py | 103 ++++- .../v1/entrypoints/openai/test_completion.py | 102 +++++ vllm/v1/engine/async_llm.py | 27 +- vllm/v1/engine/llm_engine.py | 43 +- vllm/v1/engine/parallel_sampling.py | 375 ++++++++++++++++++ 5 files changed, 641 insertions(+), 9 deletions(-) create mode 100644 vllm/v1/engine/parallel_sampling.py diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 84b634316cb4..de2a39ee9c08 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -1,21 +1,114 @@ # SPDX-License-Identifier: Apache-2.0 +import random +from typing import Dict, List, Optional, Tuple + import pytest from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG from vllm import LLM, SamplingParams +MODEL = "facebook/opt-125m" +DTYPE = "half" -def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch): - """Test passes if LLMEngine raises an exception when it is configured - for automatic prefix caching and it receives a request with - prompt_logprobs enabled, which is incompatible.""" +def _vllm_model(apc: bool, vllm_runner, monkeypatch): + """Set up VllmRunner instance.""" monkeypatch.setenv("VLLM_USE_V1", "1") # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now. monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + return vllm_runner( + MODEL, + dtype=DTYPE, + max_model_len=128, + enforce_eager=True, + enable_prefix_caching=apc, + gpu_memory_utilization=0.5, + ) + + +@pytest.fixture( + # Function scope decouples tests & allows + # env var adjustment via monkeypatch + scope="function", + # Prefix caching + params=[False, True]) +def vllm_model(vllm_runner, request, monkeypatch): + """VllmRunner test fixture parameterized by APC True/False.""" + with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model: + yield vllm_model + + +@pytest.fixture(scope="function") +def vllm_model_apc(vllm_runner, monkeypatch): + """VllmRunner test fixture with APC.""" + with _vllm_model(True, vllm_runner, monkeypatch) as vllm_model: + yield vllm_model + + +def _get_test_sampling_params( + prompt_list: List[str], + seed: Optional[int] = 42, +) -> Tuple[List[SamplingParams], List[int]]: + """Generate random sampling params for a batch.""" + + def get_mostly_n_gt1() -> int: + """Mostly n \in [2,20], ~1/3 n=1""" + x = random.randint(0, 28) + if x < 10: + return 1 + else: + return x - 8 + + n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))] + # High temperature to maximize the chance of unique completions + return [ + SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed) + for n in n_list + ], n_list + + +def test_parallel_sampling(vllm_model, example_prompts) -> None: + """Test passes if parallel sampling `n>1` yields `n` unique completions. + + Args: + vllm_model: VllmRunner instance under test. + example_prompt: test fixture providing prompts for testing. + """ + sampling_params_list, n_list = _get_test_sampling_params(example_prompts) + model: LLM = vllm_model.model + outputs = model.generate(example_prompts, sampling_params_list) + + # Validate each request response + for out, n in zip(outputs, n_list): + completion_counts: Dict[str, int] = {} + # Assert correct number of completions + assert len(out.outputs) == n, ( + f"{len(out.outputs)} completions; {n} expected.") + for idx in range(n): + comp = out.outputs[idx] + # Assert correct completion indices + assert comp.index == idx, (f"Index {comp.index}; expected {idx}.") + text = comp.text + completion_counts[text] = completion_counts.get(text, 0) + 1 + # Assert unique completions + if len(completion_counts) != n: + repeats = { + txt: num + for (txt, num) in completion_counts.items() if num > 1 + } + raise AssertionError( + f"{len(completion_counts)} unique completions; expected" + f" {n}. Repeats: {repeats}") + + +def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc): + """Test passes if LLMEngine raises an exception when it is configured + for automatic prefix caching and it receives a request with + prompt_logprobs enabled, which is incompatible.""" + model: LLM = vllm_model_apc.model with pytest.raises(ValueError) as excinfo: - LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate( + model.generate( "Hello, my name is", SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5)) diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index ef46a16ef344..35e059ccb548 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -250,6 +250,108 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, assert "".join(chunks) == single_output +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_parallel_no_streaming(client: openai.AsyncOpenAI, + model_name: str): + """Parallel sampling without streaming. + A single request output contains a list of completions. + """ + + prompt = "What is an LLM?" + n = 3 + max_tokens = 5 + + # High temperature to maximize chance of unique completions. + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + temperature=0.95, + stream=False, + seed=42) + + # Assert `n` completions + num_completions = len(completion.choices) + assert num_completions == n, ( + f"Num completions {num_completions} but expected {n}.") + completion_repeats: Dict[str, int] = {} + for idx, choice in enumerate(completion.choices): + # Assert correct completion index & some finish reason. + assert choice.index == idx, ( + f"Index {choice.index} but expected {idx}.") + assert choice.finish_reason is not None, ( + "None finish_reason is invalid.") + text = choice.text + completion_repeats[text] = completion_repeats.get(text, 0) + 1 + # Assert `n` unique completions + num_unique = len(completion_repeats) + if num_unique != n: + repeats = { + txt: num + for (txt, num) in completion_repeats.items() if num > 1 + } + raise AssertionError( + f"Expected {n} unique completions, got {num_unique};" + f" repeats: {repeats}.") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): + """Streaming for parallel sampling. + The tokens from multiple samples, are flattened into a single stream, + with an index to indicate which sample the token belongs to. + """ + + prompt = "What is an LLM?" + n = 3 + max_tokens = 5 + + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + temperature=0.95, + stream=True, + seed=42) + chunks: List[List[str]] = [[] for i in range(n)] + finish_reason_count = 0 + async for chunk in stream: + index = chunk.choices[0].index + text = chunk.choices[0].text + chunks[index].append(text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # Assert `n` completions with correct finish reasons + assert finish_reason_count == n, ( + f"Expected {n} completions with valid indices and finish_reason.") + completion_repeats: Dict[str, int] = {} + for chunk in chunks: + chunk_len = len(chunk) + # Assert correct number of completion tokens + assert chunk_len == max_tokens, ( + f"max_tokens={max_tokens} but chunk len is {chunk_len}.") + text = "".join(chunk) + completion_repeats[text] = completion_repeats.get(text, 0) + 1 + print(text) + # Assert `n` unique completions + num_unique = len(completion_repeats) + if num_unique != n: + repeats = { + txt: num + for (txt, num) in completion_repeats.items() if num > 1 + } + raise AssertionError(f"{num_unique} unique completions, expected {n};" + f" repeats: {repeats}") + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 670454c283da..36a02628f405 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -24,6 +24,7 @@ from vllm.utils import cdiv, kill_process_tree from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor +from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, @@ -170,7 +171,7 @@ async def add_request( # requests we don't need to send multiple messages to core proc, # and so we don't need multiple streams which then get # re-multiplexed in the API server anyhow. - async def generate( + async def _generate( self, prompt: PromptType, sampling_params: SamplingParams, @@ -241,6 +242,30 @@ async def generate( await self.abort(request_id) raise + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + kwargs = dict(prompt=prompt, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) + if sampling_params.n is None or sampling_params.n == 1: + return self._generate(**kwargs) + else: + # Special handling for parallel sampling requests + return generate_parallel_sampling_async(generate=self._generate, + **kwargs) + async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 33b1ddc0f6fe..64fd8719c82e 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -21,6 +21,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor +from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -48,6 +49,9 @@ def __init__( self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config + # Bookkeeping for parallel sampling requests + self.parallel_manager = SyncParallelSamplingManager() + # important: init dp group before init the engine_core self.parallel_config = vllm_config.parallel_config self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa @@ -115,7 +119,8 @@ def from_engine_args( multiprocess_mode=enable_multiprocessing) def get_num_unfinished_requests(self) -> int: - return self.output_processor.get_num_unfinished_requests() + return self.parallel_manager.get_num_unfinished_requests( + self.output_processor.get_num_unfinished_requests()) def has_unfinished_requests(self) -> bool: has_unfinished = self.output_processor.has_unfinished_requests() @@ -151,7 +156,36 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: - + """Add request.""" + kwargs = dict(request_id=request_id, + prompt=prompt, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) + # Handle parallel sampling requests differently. + if params is None or isinstance(params, + PoolingParams) or params.n == 1: + self._add_request(**kwargs) + else: + # Special handling for parallel sampling requests + self.parallel_manager.add_request_parallel_sampling( + add_request=self._add_request, **kwargs) + + def _add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + """Add request, `n=1`""" # 1) Process raw inputs into the request. request = self.processor.process_inputs(request_id, prompt, params, arrival_time, lora_request, @@ -182,7 +216,10 @@ def step(self) -> List[RequestOutput]: # 3) Abort any reqs that finished due to stop strings. self.engine_core.abort_requests(processed_outputs.reqs_to_abort) - return processed_outputs.request_outputs + request_outputs = processed_outputs.request_outputs + + # 4) Process unfinished parallel sampling requests + return self.parallel_manager.step(request_outputs) def get_model_config(self): return self.model_config diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py new file mode 100644 index 000000000000..5d4ea111abfc --- /dev/null +++ b/vllm/v1/engine/parallel_sampling.py @@ -0,0 +1,375 @@ +# SPDX-License-Identifier: Apache-2.0 + +from copy import copy +from typing import (AsyncGenerator, Dict, List, Mapping, Optional, Protocol, + Tuple, Union) + +from vllm.inputs import PromptType +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.utils import merge_async_iterators + + +class AsyncGenerateMethodType(Protocol): + + def __call__(self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0) -> AsyncGenerator[RequestOutput, None]: + ... + + +class SyncAddRequestMethodType(Protocol): + + def __call__(self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0) -> None: + ... + + +class ParallelSamplingRequest: + """Info, state & processing for parallel sampling request. + + Store parent request ID and sampling params. + Facilitate generating child request sampling params. + Transform child request outputs into parent request + outputs. + When stream mode is disabled, then `self.request_output` + aggregates child request completions. + """ + + request_id: str + sampling_params: SamplingParams + cached_child_sampling_params: Optional[SamplingParams] + request_output: Optional[RequestOutput] + num_finished_completions: int + + def __init__(self, request_id: str, + sampling_params: SamplingParams) -> None: + self.request_id = request_id + self.sampling_params = sampling_params + self.cached_child_sampling_params = None + self.request_output = None + self.num_finished_completions = 0 + + def _get_child_sampling_params( + self, + index: int, + ) -> SamplingParams: + """Efficiently obtain child `sampling_params` + + If `sampling_params.seed` is not `None` then + each child request requires a unique clone of + parent `sampling_params` with a unique seed. + + Args: + index: index within `n` child requests + + Returns: + Child `sampling_params` instance. + """ + seed = self.sampling_params.seed + if self.cached_child_sampling_params: + # Reuse child sampling_params data structure + return self.cached_child_sampling_params + # Build child sampling_params + child_sampling_params = copy(self.sampling_params) + child_sampling_params.n = 1 + if seed is None: + # Cache child sampling_params for later reuse + self.cached_child_sampling_params = child_sampling_params + else: + # Each child gets a clone with a unique seed + child_sampling_params.seed = seed + index + return child_sampling_params + + def _add_output( + self, + child_req_output: RequestOutput, + index: int, + ) -> None: + """Aggregate a parallel sampling child + request output. + + Non-stream-mode (`output_kind == FINAL_ONLY`) + only. Inject correct parent request ID and + completion index. + + Args: + child_req_output: a single request output + from a parallel sampling + child request. + index: index within `n` child + """ + self.num_finished_completions += 1 + new_completion = child_req_output.outputs[0] + new_completion.index = index + if self.request_output is None: + # Save the first request output; reinstate + # original request ID; metrics are not + # supported for parallel sampling + child_req_output.request_id = self.request_id + child_req_output.metrics = None + self.request_output = child_req_output + else: + # Aggregate additional completion into request output + # Note: will be sorted by index later + self.request_output.outputs.append(new_completion) + + def _get_final_request_output(self) -> RequestOutput: + """Invariant: parent completion outputs sorted by index""" + assert self.request_output is not None + self.request_output.finished = True + self.request_output.outputs = sorted(self.request_output.outputs, + key=lambda x: x.index) + return self.request_output + + def get_child_info(self, index: int) -> Tuple[str, SamplingParams]: + """Get child request ID and sampling params. + + Args: + index: index within `n` child requests. + + Returns: + (request ID, sampling_params) tuple + """ + return (f"{index}_{self.request_id}", + self._get_child_sampling_params(index)) + + def process_output( + self, + child_req_output: RequestOutput, + index: int, + ) -> Optional[RequestOutput]: + """Filter, aggregate and transform parallel sampling + child request outputs. + + If the parent request has `stream=false` + (`output_kind == FINAL_ONLY`), each child will also have + `output_kind == FINAL_ONLY`. All child request outputs + must be aggregated into a single request output, with + multiple completions. This request output is only returned + once `n` completions are aggregated. + + If the parent request has `stream=true` + (`output_kind == DELTA`), each child will also have + `output_kind == DELTA`. All child request outputs + must be streamed directly to the caller. + + Args: + child_req_output: a single child request output + index: index within `n` child requests + + Returns: + `None`, unless a processed request output is ready to + send back to the caller. + """ + if self.output_kind != RequestOutputKind.FINAL_ONLY: + # stream=true: return child completions immediately + child_req_output.request_id = self.request_id + child_req_output.outputs[0].index = index + if child_req_output.finished: + # Parent request is complete if all child requests are + # complete. + self.num_finished_completions += 1 + child_req_output.finished = ( + self.num_finished_completions == self.n) + return child_req_output + + # stream=false: aggregate child completions + self._add_output(child_req_output, index) + if self.num_finished_completions == self.n: + # Return aggregated request output after obtaining + # all completions + return self._get_final_request_output() + return None + + async def wrap_child_async_generator( + self, + child_gen: AsyncGenerator[RequestOutput, None], + index: int, + ) -> AsyncGenerator[RequestOutput, None]: + """Output generator for a single parallel sampling + child request. + + Each parallel sampling request triggers at + least two child requests. This generator + yields zero or more request outputs to + return to the caller, as they become + available. + + Args: + child_gen: generator for child request + outputs. + index: index within the `n` child requests + + Returns: + Yields zero or more request outputs to return + to the caller. + """ + async for out in child_gen: + if req_out := self.process_output(out, index): + yield req_out + + @property + def n(self) -> int: + return self.sampling_params.n + + @property + def output_kind(self) -> RequestOutputKind: + return self.sampling_params.output_kind + + +class SyncParallelSamplingManager: + + def __init__(self): + # Parent req ID -> parent request manager + self.parent_reqs: Dict[str, ParallelSamplingRequest] = {} + # Child req ID -> (child req index, parent req ID) + self.child_reqs: Dict[str, Tuple[int, str]] = {} + + def _register_parent_request(self, req: ParallelSamplingRequest) -> None: + """Register parallel sampling parent request.""" + self.parent_reqs[req.request_id] = req + + def _register_child_request(self, req_id: str, child_req_id: str, + index: int) -> None: + """Register parallel sampling child request with parent. + + Args: + req_id: parent request ID + child_req_id: child request ID + index: child request index within `n` child requests + """ + self.child_reqs[child_req_id] = (index, req_id) + + def get_num_unfinished_requests(self, num_core_reqs: int) -> int: + """Get the number of unfinished requests, correcting for parallel + sampling. + + Args: + num_core_reqs: The number of unfinished requests in the engine core. + + Returns: + Number of unfinished requests, where each parallel sampling req + counts as 1 + """ + return num_core_reqs + len(self.parent_reqs) - len(self.child_reqs) + + def add_request_parallel_sampling( + self, + add_request: SyncAddRequestMethodType, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + """Add sync parallel sampling request.""" + req = ParallelSamplingRequest(request_id, params) + self._register_parent_request(req) + # Add n child requests with unique request IDs & random seeds and n=1 + for idx in range(req.n): + child_req_id, child_params = req.get_child_info(idx) + self._register_child_request(request_id, child_req_id, idx) + add_request(request_id=child_req_id, + prompt=prompt, + params=child_params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) # type: ignore + + def step( + self, + outputs: List[RequestOutput], + ) -> List[RequestOutput]: + """Build parallel sampling request outputs. + + Extract child request outputs, aggregate them + into parent request output, and return parent + output when complete. + + Do not modify `n=1` requests. + + Args: + outputs: step request outputs. Mix of child request + outputs & `n=1` request outputs. + + Return: + List of parallel sampling parent request outputs & + unmodified `n=1` request outputs passed-thru from input. + """ + if not (self.parent_reqs and outputs): + # Return unmodified + return outputs + agg_outputs = [] + for output in outputs: + req_id = output.request_id + if child_req_entry := self.child_reqs.get(req_id, None): + # For each parallel sampling child request output: + (index, parent_req_id) = child_req_entry + req = self.parent_reqs[parent_req_id] + # Update parallel sampling request + if out := req.process_output(output, index): + # Return parent request output if complete; + # cleanup parent request bookkeeping. + agg_outputs.append(out) + del self.parent_reqs[parent_req_id] + # Cleanup child request bookkeeping. + del self.child_reqs[req_id] + else: + # Not a parallel sampling request output + agg_outputs.append(output) + return agg_outputs + + +async def generate_parallel_sampling_async( + generate: AsyncGenerateMethodType, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, +) -> AsyncGenerator[RequestOutput, None]: + """Generate completions for async parallel sampling requests.""" + parent_req = ParallelSamplingRequest(request_id, sampling_params) + + # Aggregate generators for n child requests + gens: List[AsyncGenerator[RequestOutput, None]] = [] + for idx in range(parent_req.n): + child_req_id, child_params = parent_req.get_child_info(idx) + child_gen = generate( + prompt=prompt, + sampling_params=child_params, + request_id=child_req_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) # type: ignore + gen = parent_req.wrap_child_async_generator(child_gen, idx) + gens.append(gen) + + # Merge generators + async for _, out in merge_async_iterators(*gens): + yield out From 26b42da04d5c0d210676a8bbcb57f412f42a469a Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 24 Feb 2025 09:16:05 -0800 Subject: [PATCH 200/317] Revert "[V1][Core] Fix memory issue with logits & sampling" (#13775) --- vllm/v1/worker/gpu_model_runner.py | 68 +++++++++++++----------------- vllm/v1/worker/gpu_worker.py | 10 ----- 2 files changed, 29 insertions(+), 49 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cf6bdd050e4a..a7b9d4781183 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1179,43 +1179,6 @@ def _dummy_run( ) return hidden_states - @torch.inference_mode() - def _dummy_sampler_run( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - - logits = self.model.compute_logits(hidden_states, None) - num_reqs = logits.size(0) - - dummy_tensors = lambda v: torch.full( - (num_reqs, ), v, device=self.device) - - dummy_metadata = SamplingMetadata( - temperature=dummy_tensors(0.5), - all_greedy=False, - all_random=False, - spec_token_ids=None, - top_p=dummy_tensors(0.9), - top_k=dummy_tensors(logits.size(1) - 1), - min_p=None, - generators={}, - max_num_logprobs=None, - no_penalties=True, - prompt_token_ids=None, - frequency_penalties=dummy_tensors(0.1), - presence_penalties=dummy_tensors(0.1), - repetition_penalties=dummy_tensors(0.1), - output_token_ids=[[] for _ in range(num_reqs)], - min_tokens={}, - logit_bias=[None for _ in range(num_reqs)], - allowed_token_ids_mask=None, - ) - sampler_output = self.model.sample(logits=logits, - sampling_metadata=dummy_metadata) - - return sampler_output - def profile_run(self) -> None: # use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value `None`. @@ -1343,11 +1306,38 @@ def profile_run(self) -> None: dummy_kv_caches) if get_pp_group().is_last_rank: hidden_states = hidden_states[logit_indices] - sampler_output = self._dummy_sampler_run(hidden_states) + logits = self.model.compute_logits(hidden_states, None) + dummy_tensors = lambda v: torch.full( + (num_reqs, ), v, device=self.device) + dummy_metadata = SamplingMetadata( + temperature=dummy_tensors(0.5), + all_greedy=False, + all_random=False, + spec_token_ids=None, + top_p=dummy_tensors(0.9), + top_k=dummy_tensors(logits.size(1) - 1), + min_p=None, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=torch.ones_like(logits, + dtype=torch.int64), + frequency_penalties=dummy_tensors(0.1), + presence_penalties=dummy_tensors(0.1), + repetition_penalties=dummy_tensors(0.1), + output_token_ids=[[] for _ in range(num_reqs)], + min_tokens={}, + logit_bias=[None for _ in range(num_reqs)], + allowed_token_ids_mask=None, + ) + sampler_output = self.model.sample( + logits=logits, sampling_metadata=dummy_metadata) else: + logits = None sampler_output = None + dummy_metadata = None torch.cuda.synchronize() - del hidden_states, sampler_output + del hidden_states, logits, sampler_output, dummy_metadata self.encoder_cache.clear() gc.collect() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d9030aae51d1..d9a415aee528 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -211,16 +211,6 @@ def compile_or_warm_up_model(self) -> None: self.model_runner._dummy_run(size) if not self.model_config.enforce_eager: self.model_runner.capture_model() - - # Warm up sampler and preallocate memory buffer for logits and other - # sampling related tensors of max possible shape to avoid memory - # fragmentation issue. - # NOTE: This is called after `capture_model` on purpose to prevent - # memory buffers from being cleared by `torch.cuda.empty_cache`. - self.model_runner._dummy_sampler_run( - hidden_states=self.model_runner._dummy_run( - num_tokens=self.scheduler_config.max_num_seqs)) - # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed) From c912e7f5d415969fbf07c69ea8dd18cb5988a121 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 24 Feb 2025 12:25:47 -0500 Subject: [PATCH 201/317] Fix precommit fail in fused_moe intermediate_cache2 chunking (#13772) Signed-off-by: mgoin --- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1ddc3ce6f895..bc9573b36df7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1271,7 +1271,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, # so the cache size and config are already set correctly and # do not need to be adjusted. intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * topk_ids.shape[1]] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * + topk_ids.shape[1]] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) From 92b6a8f225ece9fa59bc5c3d4455210c1af60549 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 24 Feb 2025 13:52:21 -0500 Subject: [PATCH 202/317] [Misc] Clean Up `EngineArgs.create_engine_config` (#13734) Signed-off-by: rshaw@neuralmagic.com --- vllm/config.py | 4 +++ vllm/engine/arg_utils.py | 65 ++++++++++++++++------------------------ 2 files changed, 29 insertions(+), 40 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index a584bc0d930f..0bc9b2f817f6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1124,6 +1124,10 @@ def metrics_info(self): return {key: str(value) for key, value in self.__dict__.items()} def _verify_args(self) -> None: + if self.cpu_offload_gb < 0: + raise ValueError("CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + if self.gpu_memory_utilization > 1.0: raise ValueError( "GPU memory utilization must be less than 1.0. Got " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bab7cfe2aa3a..8378a116a6d4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1062,6 +1062,17 @@ def from_cli_args(cls, args: argparse.Namespace): return engine_args def create_model_config(self) -> ModelConfig: + # gguf file needs a specific model loader and doesn't use hf_repo + if check_gguf_file(self.model): + self.quantization = self.load_format = "gguf" + + # NOTE: This is to allow model loading from S3 in CI + if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 + and self.model in MODELS_ON_S3 + and self.load_format == LoadFormat.AUTO): # noqa: E501 + self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" + self.load_format = LoadFormat.RUNAI_STREAMER + return ModelConfig( model=self.model, task=self.task, @@ -1101,26 +1112,6 @@ def create_model_config(self) -> ModelConfig: ) def create_load_config(self) -> LoadConfig: - return LoadConfig( - load_format=self.load_format, - download_dir=self.download_dir, - model_loader_extra_config=self.model_loader_extra_config, - ignore_patterns=self.ignore_patterns, - ) - - def create_engine_config(self, - usage_context: Optional[UsageContext] = None - ) -> VllmConfig: - from vllm.platforms import current_platform - current_platform.pre_register_and_update() - - if envs.VLLM_USE_V1: - self._override_v1_engine_args(usage_context) - - # gguf file needs a specific model loader and doesn't use hf_repo - if check_gguf_file(self.model): - self.quantization = self.load_format = "gguf" - # bitsandbytes quantization needs a specific model loader # so we make sure the quant method and the load format are consistent if (self.quantization == "bitsandbytes" or @@ -1137,19 +1128,23 @@ def create_engine_config(self, "BitsAndBytes load format and QLoRA adapter only support " f"'bitsandbytes' quantization, but got {self.quantization}") - assert self.cpu_offload_gb >= 0, ( - "CPU offload space must be non-negative" - f", but got {self.cpu_offload_gb}") + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) - device_config = DeviceConfig(device=self.device) + def create_engine_config(self, + usage_context: Optional[UsageContext] = None + ) -> VllmConfig: + from vllm.platforms import current_platform + current_platform.pre_register_and_update() - # NOTE: This is to allow model loading from S3 in CI - if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 - and self.model in MODELS_ON_S3 - and self.load_format == LoadFormat.AUTO): # noqa: E501 - self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" - self.load_format = LoadFormat.RUNAI_STREAMER + if envs.VLLM_USE_V1: + self._override_v1_engine_args(usage_context) + device_config = DeviceConfig(device=self.device) model_config = self.create_model_config() if (model_config.is_multimodal_model and not envs.VLLM_USE_V1 @@ -1281,16 +1276,6 @@ def create_engine_config(self, if speculative_config is None \ else speculative_config.num_lookahead_slots - if not self.use_v2_block_manager: - logger.warning( - "[DEPRECATED] Block manager v1 has been removed, " - "and setting --use-v2-block-manager to True or False has " - "no effect on vLLM behavior. Please remove " - "--use-v2-block-manager in your engine argument. " - "If your use case is not supported by " - "SelfAttnBlockSpaceManager (i.e. block manager v2)," - " please file an issue with detailed information.") - scheduler_config = SchedulerConfig( runner_type=model_config.runner_type, max_num_batched_tokens=self.max_num_batched_tokens, From 2a643b91a8d71539f8032309809df3e93e243c35 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 24 Feb 2025 19:39:07 -0500 Subject: [PATCH 203/317] [Misc][Chore] Clean Up `AsyncOutputProcessing` Logs (#13780) --- vllm/config.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 0bc9b2f817f6..fea673b68560 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -710,8 +710,6 @@ def verify_async_output_proc(self, parallel_config, speculative_config, return if parallel_config.pipeline_parallel_size > 1: - logger.warning("Async output processing can not be enabled " - "with pipeline parallel") self.use_async_output_proc = False return @@ -719,15 +717,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # If the feature combo become valid from vllm.platforms import current_platform if not current_platform.is_async_output_supported(self.enforce_eager): - logger.warning( - "Async output processing is not supported on the " - "current platform type %s.", current_platform.device_type) self.use_async_output_proc = False return if envs.VLLM_USE_RAY_SPMD_WORKER: - logger.warning( - "Async output processing can not be enabled with ray spmd") self.use_async_output_proc = False return @@ -739,8 +732,6 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Reminder: Please update docs/source/features/compatibility_matrix.md # If the feature combo become valid if speculative_config: - logger.warning("Async output processing is not supported with" - " speculative decoding currently.") self.use_async_output_proc = False def verify_with_parallel_config( @@ -768,8 +759,6 @@ def verify_with_parallel_config( "Supported models implement the `SupportsPP` interface.") if self.use_async_output_proc: - logger.warning("Async output processor is not supported with " - "pipeline parallelism currently. Disabling it.") self.use_async_output_proc = False def get_hf_config_sliding_window( From ae7df05a240ce27bbc9c54d10b1de5d7acd8a663 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 25 Feb 2025 01:13:52 +0000 Subject: [PATCH 204/317] Remove unused kwargs from model definitions (#13555) --- docs/source/contributing/model/basic.md | 2 - docs/source/contributing/model/multimodal.md | 2 - tests/kernels/test_encoder_decoder_attn.py | 14 +-- vllm/attention/layer.py | 19 ++-- .../layers/mamba/mamba_mixer.py | 5 +- .../layers/mamba/mamba_mixer2.py | 4 +- vllm/model_executor/models/adapters.py | 6 +- vllm/model_executor/models/arctic.py | 24 +---- vllm/model_executor/models/aria.py | 5 - vllm/model_executor/models/baichuan.py | 24 +---- vllm/model_executor/models/bamba.py | 29 ++---- vllm/model_executor/models/bart.py | 93 ++++--------------- vllm/model_executor/models/bert.py | 44 +++------ vllm/model_executor/models/blip2.py | 7 +- vllm/model_executor/models/bloom.py | 31 ++----- vllm/model_executor/models/chameleon.py | 27 +----- vllm/model_executor/models/chatglm.py | 42 ++------- vllm/model_executor/models/commandr.py | 24 +---- vllm/model_executor/models/dbrx.py | 35 ++----- vllm/model_executor/models/deepseek.py | 26 ++---- vllm/model_executor/models/deepseek_mtp.py | 19 +--- vllm/model_executor/models/deepseek_v2.py | 31 ++----- vllm/model_executor/models/deepseek_vl2.py | 5 - vllm/model_executor/models/eagle.py | 7 +- vllm/model_executor/models/exaone.py | 30 ++---- vllm/model_executor/models/falcon.py | 31 ++----- vllm/model_executor/models/florence2.py | 34 ++----- vllm/model_executor/models/fuyu.py | 5 - vllm/model_executor/models/gemma.py | 24 +---- vllm/model_executor/models/gemma2.py | 24 +---- vllm/model_executor/models/glm4v.py | 10 +- vllm/model_executor/models/gpt2.py | 32 ++----- vllm/model_executor/models/gpt_bigcode.py | 32 ++----- vllm/model_executor/models/gpt_j.py | 31 ++----- vllm/model_executor/models/gpt_neox.py | 31 ++----- vllm/model_executor/models/granite.py | 29 ++---- vllm/model_executor/models/granitemoe.py | 26 ++---- vllm/model_executor/models/gritlm.py | 9 +- vllm/model_executor/models/idefics3.py | 9 -- vllm/model_executor/models/interfaces_base.py | 9 +- vllm/model_executor/models/internlm2.py | 35 ++----- vllm/model_executor/models/internlm2_ve.py | 14 +-- vllm/model_executor/models/internvl.py | 5 - vllm/model_executor/models/jais.py | 32 ++----- vllm/model_executor/models/jamba.py | 29 +----- vllm/model_executor/models/llama.py | 28 ++---- vllm/model_executor/models/llava.py | 5 - vllm/model_executor/models/llava_next.py | 5 - .../model_executor/models/llava_next_video.py | 5 - vllm/model_executor/models/llava_onevision.py | 5 - vllm/model_executor/models/mamba.py | 16 +--- vllm/model_executor/models/mamba2.py | 18 ++-- vllm/model_executor/models/minicpm.py | 24 +---- vllm/model_executor/models/minicpm3.py | 6 +- vllm/model_executor/models/minicpmo.py | 5 - vllm/model_executor/models/minicpmv.py | 5 - vllm/model_executor/models/mixtral.py | 26 ++---- vllm/model_executor/models/mixtral_quant.py | 26 ++---- vllm/model_executor/models/mllama.py | 50 +++------- vllm/model_executor/models/molmo.py | 25 +---- vllm/model_executor/models/mpt.py | 31 ++----- vllm/model_executor/models/nemotron.py | 28 +----- vllm/model_executor/models/olmo.py | 28 ++---- vllm/model_executor/models/olmo2.py | 28 ++---- vllm/model_executor/models/olmoe.py | 24 +---- vllm/model_executor/models/opt.py | 32 ++----- vllm/model_executor/models/orion.py | 29 ++---- vllm/model_executor/models/paligemma.py | 7 +- vllm/model_executor/models/persimmon.py | 27 +----- vllm/model_executor/models/phi.py | 29 ++---- vllm/model_executor/models/phi3_small.py | 28 +----- vllm/model_executor/models/phi3v.py | 5 - vllm/model_executor/models/phimoe.py | 24 +---- vllm/model_executor/models/pixtral.py | 5 - .../models/prithvi_geospatial_mae.py | 5 +- vllm/model_executor/models/qwen.py | 26 ++---- vllm/model_executor/models/qwen2.py | 29 ++---- vllm/model_executor/models/qwen2_5_vl.py | 5 - vllm/model_executor/models/qwen2_audio.py | 9 +- vllm/model_executor/models/qwen2_moe.py | 26 ++---- vllm/model_executor/models/qwen2_rm.py | 8 +- vllm/model_executor/models/qwen2_vl.py | 9 +- vllm/model_executor/models/qwen_vl.py | 8 +- vllm/model_executor/models/roberta.py | 7 +- vllm/model_executor/models/solar.py | 21 +---- vllm/model_executor/models/stablelm.py | 29 ++---- vllm/model_executor/models/starcoder2.py | 26 ++---- vllm/model_executor/models/transformers.py | 13 +-- vllm/model_executor/models/ultravox.py | 17 +--- vllm/model_executor/models/whisper.py | 89 +++--------------- vllm/spec_decode/draft_model_runner.py | 2 - vllm/v1/worker/gpu_model_runner.py | 22 +---- vllm/v1/worker/tpu_model_runner.py | 19 +--- vllm/worker/cpu_enc_dec_model_runner.py | 4 - vllm/worker/cpu_model_runner.py | 2 - vllm/worker/cpu_pooling_model_runner.py | 14 --- vllm/worker/enc_dec_model_runner.py | 14 +-- vllm/worker/hpu_model_runner.py | 36 +++---- vllm/worker/model_runner.py | 13 +-- vllm/worker/multi_step_model_runner.py | 4 +- vllm/worker/openvino_model_runner.py | 4 - vllm/worker/pooling_model_runner.py | 12 --- vllm/worker/tpu_model_runner.py | 24 ++--- vllm/worker/xpu_model_runner.py | 13 +-- 104 files changed, 436 insertions(+), 1654 deletions(-) diff --git a/docs/source/contributing/model/basic.md b/docs/source/contributing/model/basic.md index 180fdd59e9a6..ad31995f76be 100644 --- a/docs/source/contributing/model/basic.md +++ b/docs/source/contributing/model/basic.md @@ -74,8 +74,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> torch.Tensor: ... ``` diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index 14a59953ef48..990eac82d516 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -16,8 +16,6 @@ Further update the model as follows: self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, + pixel_values: torch.Tensor, ) -> SamplerOutput: ``` diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 0d11e8652ce6..0a93f7ce9450 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -644,11 +644,7 @@ def _run_encoder_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward( - reshaped_query, packed_qkv.key, packed_qkv.value, - torch.tensor([], - dtype=torch.float32, - device=packed_qkv.query.device), attn_metadata) + return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value) def _run_decoder_self_attention_test( @@ -682,7 +678,6 @@ def _run_decoder_self_attention_test( & attn_metadata ''' attn = test_rsrcs.attn - kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata, vllm_config): @@ -695,8 +690,7 @@ def _run_decoder_self_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value, - kv_cache, attn_metadata) + return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value) def _run_encoder_decoder_cross_attention_test( @@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test( assert decoder_test_params.packed_qkvo.packed_qkv is not None attn = test_rsrcs.attn - kv_cache = test_rsrcs.kv_cache if cross_test_params is None: key = None value = None @@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, key, value, kv_cache, - attn_metadata) + return attn.forward(reshaped_query, key, value) @pytest.fixture(autouse=True) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index e4df7ffc5885..bd7783cc3981 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,7 +7,7 @@ import torch.nn.functional as F import vllm.envs as envs -from vllm.attention import AttentionMetadata, AttentionType +from vllm.attention import AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import CacheConfig, get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context @@ -153,15 +153,10 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: - # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments - # directly, use `self.kv_cache` and - # `get_forward_context().attn_metadata` instead. if self.calculate_kv_scales: - ctx_attn_metadata = get_forward_context().attn_metadata - if ctx_attn_metadata.enable_kv_scales_calculation: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.enable_kv_scales_calculation: self.calc_kv_scales(key, value) if self.use_output: output = torch.empty_like(query) @@ -177,14 +172,14 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if self.use_direct_call: forward_context: ForwardContext = get_forward_context() - ctx_attn_metadata = forward_context.attn_metadata + attn_metadata = forward_context.attn_metadata self_kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, query, key, value, self_kv_cache, - ctx_attn_metadata, + attn_metadata, output=output) else: torch.ops.vllm.unified_attention_with_output( @@ -193,10 +188,10 @@ def forward( else: if self.use_direct_call: forward_context = get_forward_context() - ctx_attn_metadata = forward_context.attn_metadata + attn_metadata = forward_context.attn_metadata self_kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(self, query, key, value, - self_kv_cache, ctx_attn_metadata) + self_kv_cache, attn_metadata) else: return torch.ops.vllm.unified_attention( query, key, value, self.layer_name) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 93c3cc91bb09..156e8752e96c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -7,6 +7,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -130,14 +131,14 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): ) if use_rms_norm else None def forward_native(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor): pass def forward_cuda(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams): + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) hidden_states, gate = projected_states.chunk(2, dim=-2) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 2bcf50e70713..b53a540ed662 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -14,6 +14,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.forward_context import get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -376,17 +377,16 @@ def __init__(self, eps=rms_norm_eps) def forward_native(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor): pass def forward_cuda( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, sequence_idx: Optional[torch.Tensor] = None, ): + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata seq_len, _ = hidden_states.shape groups_time_state_size = self.n_groups * self.ssm_state_size diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 3e1daa773fc8..23d72d8e60f6 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T: return cls # Lazy import - from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.pooler import PoolingType @@ -201,13 +200,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: list[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = super().forward(input_ids, positions, kv_caches, - attn_metadata, + hidden_states = super().forward(input_ids, positions, intermediate_tensors, inputs_embeds) logits, _ = self.score(hidden_states) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 77f383b6e46d..e2d4a8de605b 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -5,7 +5,7 @@ import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -283,13 +283,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -336,16 +334,12 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual_input = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual_input + hidden_states @@ -400,8 +394,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -413,11 +405,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) @@ -458,13 +447,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index bff4100a1dee..656e9b037d96 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -9,7 +9,6 @@ from transformers.models.aria.modeling_aria import AriaCrossAttention from transformers.models.aria.processing_aria import AriaProcessor -from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, QuantizationConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.activation import get_act_fn @@ -626,8 +625,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -643,8 +640,6 @@ def forward( hidden_states = self.language_model( input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 2e51b9c9c0c7..4fb68e7b48da 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -20,13 +20,13 @@ # limitations under the License. """Inference-only BaiChuan model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -182,14 +182,12 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.W_pack(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) if self.postion_embedding != "ALIBI": q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -232,8 +230,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -246,8 +242,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -301,8 +295,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -316,13 +308,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -379,13 +368,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 22ae1775c3d9..69da05884ded 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -1,17 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only Bamba model.""" # Added by the IBM Team, 2024 -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import BambaConfig -from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -107,7 +107,6 @@ def __init__(self, def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, sequence_idx: Optional[torch.Tensor] = None, @@ -120,8 +119,8 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, attn_metadata, - mamba_cache_params, sequence_idx) + hidden_states = self.mamba(hidden_states, mamba_cache_params, + sequence_idx) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual) @@ -215,15 +214,13 @@ def self_attention( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -231,8 +228,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], **kwargs, ): @@ -246,8 +241,6 @@ def forward( hidden_states = self.self_attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( @@ -312,8 +305,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -323,6 +314,7 @@ def forward( # proper continuous batching computation including # chunked prefill seq_idx = None + attn_metadata = get_forward_context().attn_metadata if attn_metadata.num_prefills > 0: seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) for i, (srt, end) in enumerate( @@ -348,9 +340,7 @@ def forward( num_attn = 0 for i in range(len(self.layers)): layer = self.layers[i] - kv_cache = None if isinstance(layer, BambaAttentionDecoderLayer): - kv_cache = kv_caches[num_attn] num_attn += 1 layer_mamba_cache_params = None @@ -361,8 +351,6 @@ def forward( hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, residual=residual, mamba_cache_params=layer_mamba_cache_params, sequence_idx=seq_idx, @@ -440,8 +428,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): @@ -454,8 +440,7 @@ def forward(self, self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, *self._get_mamba_cache_shape()) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params, + hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 204c48d0d896..5d2a8cdcb97d 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -19,14 +19,14 @@ # limitations under the License. """PyTorch BART model.""" import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import BartConfig from transformers.utils import logging -from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention import Attention, AttentionType from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -181,14 +181,13 @@ def __init__( prefix=f"{prefix}.attn", attn_type=AttentionType.ENCODER) - def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Input shape: Batch x Time x Channel""" qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -261,14 +260,13 @@ def __init__( prefix=f"{prefix}.attn", attn_type=AttentionType.DECODER) - def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Input shape: Batch x Time x Channel""" qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -344,8 +342,6 @@ def __init__( def forward( self, decoder_hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" @@ -363,7 +359,7 @@ def forward( _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -411,23 +407,16 @@ def __init__( self.final_layer_norm = nn.LayerNorm(self.embed_dim) - def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: r""" Args: hidden_states torch.Tensor of *encoder* input embeddings. - kv_cache: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Encoder layer output torch.Tensor """ residual = hidden_states - hidden_states = self.self_attn(hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -509,18 +498,12 @@ def __init__( def forward( self, decoder_hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: r""" Args: decoder_hidden_states torch.Tensor of *decoder* input embeddings. - kv_cache: - KV cache tensor - attn_metadata: - vLLM Attention metadata structure encoder_hidden_states torch.Tensor of *encoder* input embeddings. Returns: @@ -529,9 +512,7 @@ def forward( residual = decoder_hidden_states # Self Attention - hidden_states = self.self_attn(hidden_states=decoder_hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + hidden_states = self.self_attn(hidden_states=decoder_hidden_states) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -542,8 +523,6 @@ def forward( hidden_states = self.encoder_attn( decoder_hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) @@ -609,9 +588,8 @@ def __init__(self, self.layernorm_embedding = nn.LayerNorm(embed_dim) - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward(self, input_ids: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: r""" Args: input_ids @@ -620,10 +598,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, provide it. positions Positions of *encoder* input sequence tokens. - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Decoder output torch.Tensor """ @@ -636,12 +610,8 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) - for idx, encoder_layer in enumerate(self.layers): - hidden_states = encoder_layer( - hidden_states=hidden_states, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, - ) + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states=hidden_states) return hidden_states @@ -693,9 +663,7 @@ def __init__( def forward(self, decoder_input_ids: torch.Tensor, decoder_positions: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor: r""" Args: decoder_input_ids @@ -706,10 +674,6 @@ def forward(self, decoder_input_ids: torch.Tensor, Positions of *decoder* input sequence tokens. encoder_hidden_states: Tensor of encoder output embeddings - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Decoder output torch.Tensor """ @@ -725,11 +689,9 @@ def forward(self, decoder_input_ids: torch.Tensor, # decoder layers - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: hidden_states = decoder_layer( decoder_hidden_states=hidden_states, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) @@ -768,8 +730,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + encoder_positions: torch.Tensor) -> torch.Tensor: r""" Args: input_ids @@ -782,10 +743,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Indices of *encoder* input sequence tokens in the vocabulary. encoder_positions: Positions of *encoder* input sequence tokens. - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Model output torch.Tensor """ @@ -796,18 +753,14 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata) + positions=encoder_positions) # decoder outputs consists of # (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( decoder_input_ids=input_ids, decoder_positions=positions, - encoder_hidden_states=encoder_hidden_states, - kv_caches=kv_caches, - attn_metadata=attn_metadata) + encoder_hidden_states=encoder_hidden_states) return decoder_outputs @@ -845,8 +798,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, *, encoder_input_ids: torch.Tensor, @@ -863,15 +814,11 @@ def forward( torch.Tensor of *encoder* input token ids. encoder_positions torch.Tensor of *encoder* position indices - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Output torch.Tensor """ return self.model(input_ids, positions, encoder_input_ids, - encoder_positions, kv_caches, attn_metadata) + encoder_positions) def compute_logits( self, diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 4d0f5ac8ea5d..4ff69527653d 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import BertConfig -from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -113,12 +114,9 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): def forward( self, hidden_states: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> torch.Tensor: - for i in range(len(self.layer)): - layer = self.layer[i] - hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + for layer in self.layer: + hidden_states = layer(hidden_states) return hidden_states @@ -152,13 +150,8 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.output") - def forward( - self, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, - ): - attn_output = self.attention(hidden_states, kv_cache, attn_metadata) + def forward(self, hidden_states: torch.Tensor): + attn_output = self.attention(hidden_states) intermediate_output = self.intermediate(attn_output) output = self.output(intermediate_output, attn_output) return output @@ -191,10 +184,8 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: - self_output = self.self(hidden_states, kv_cache, attn_metadata) + self_output = self.self(hidden_states) return self.output(self_output, hidden_states) @@ -246,12 +237,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - output = self.attn(q, k, v, kv_cache, attn_metadata) + output = self.attn(q, k, v) return output @@ -343,8 +332,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, @@ -352,13 +339,14 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: + attn_metadata = get_forward_context().attn_metadata assert hasattr(attn_metadata, "seq_lens_tensor") hidden_states = self.embeddings( input_ids=input_ids, seq_lens=attn_metadata.seq_lens_tensor, position_ids=position_ids, token_type_ids=token_type_ids) - return self.encoder(hidden_states, kv_caches, attn_metadata) + return self.encoder(hidden_states) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -420,17 +408,13 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(input_ids=input_ids, position_ids=positions, - kv_caches=kv_caches, inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors, - attn_metadata=attn_metadata) + intermediate_tensors=intermediate_tensors) def pooler( self, @@ -519,16 +503,12 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.bert(input_ids=input_ids, position_ids=positions, - kv_caches=kv_caches, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors, - attn_metadata=attn_metadata, token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 0463a0b97d40..23bb3cd07f1d 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, +from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -9,7 +9,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, apply_chunking_to_forward) -from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig @@ -658,8 +657,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -708,8 +705,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 229677ae7d98..84b79613abc4 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -18,13 +18,13 @@ # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import BloomConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -126,13 +126,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: del position_ids # Unused. qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.dense(attn_output) return output @@ -193,8 +191,6 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) @@ -209,8 +205,6 @@ def forward( attention_output = self.self_attention( position_ids=position_ids, hidden_states=layernorm_output, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) attention_output = attention_output + residual layernorm_output = self.post_attention_layernorm(attention_output) @@ -266,8 +260,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -279,14 +271,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer( - position_ids, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) @@ -322,14 +308,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 2d4dfab60730..e91399b2674d 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import cached_property -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, +from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -10,7 +10,7 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, ChameleonVQVAEConfig) -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -310,15 +310,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -372,8 +370,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -386,8 +382,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -447,8 +441,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -456,8 +448,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = self.input_layernorm(hidden_states) @@ -906,8 +896,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -921,13 +909,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -1028,8 +1013,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, @@ -1048,8 +1031,6 @@ def forward( hidden_states = self.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index ecf417655452..6eca25212ee6 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -2,13 +2,13 @@ # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from torch.nn import LayerNorm -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul @@ -108,19 +108,11 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(position_ids, q, k) - context_layer = self.attn( - q, - k, - v, - kv_cache, - attn_metadata, - ) + context_layer = self.attn(q, k, v) attn_output, _ = self.dense(context_layer) return attn_output @@ -215,8 +207,6 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # hidden_states: [num_tokens, h] # Layer norm at the beginning of the transformer layer. @@ -225,8 +215,6 @@ def forward( attention_output = self.self_attention( hidden_states=layernorm_output, position_ids=position_ids, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Residual connection. @@ -289,17 +277,10 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> Union[torch.Tensor, IntermediateTensors]: - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - hidden_states=hidden_states, - position_ids=position_ids, - kv_cache=kv_caches[i - self.start_layer], - attn_metadata=attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states=hidden_states, + position_ids=position_ids) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -350,8 +331,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -369,8 +348,6 @@ def forward( hidden_states = self.encoder( hidden_states=hidden_states, position_ids=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, ) return hidden_states @@ -494,12 +471,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 0ceefc3e93aa..b0cb4a62333a 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -21,14 +21,14 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from transformers import CohereConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -218,8 +218,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -227,7 +225,7 @@ def forward( q, k = self._apply_qk_norm(q, k) if self.v1 or self.sliding_window: q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -255,8 +253,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -265,8 +261,6 @@ def forward( hidden_states_attention = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states_mlp = self.mlp(hidden_states) # Add everything together @@ -311,8 +305,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -326,13 +318,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -389,13 +378,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index bb3f4f40dd21..7830dd4ce2ec 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch import torch.nn as nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -230,15 +230,13 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.Wqkv(hidden_states) if self.clip_qkv is not None: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) hidden_states, _ = self.out_proj(attn_output) return hidden_states @@ -265,16 +263,12 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.norm_1(hidden_states) x = self.attn( position_ids=position_ids, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + x residual = hidden_states @@ -303,14 +297,10 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states, residual = self.norm_attn_norm( position_ids=position_ids, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = self.ffn(hidden_states) hidden_states = hidden_states + residual @@ -353,8 +343,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -366,14 +354,8 @@ def forward( else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - block = self.blocks[i] - hidden_states = block( - position_ids, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for block in self.blocks[self.start_layer:self.end_layer]: + hidden_states = block(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm_f(hidden_states) @@ -415,14 +397,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 9599e1df6a3c..c04e7a02bae2 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -248,13 +248,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -309,8 +307,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -323,8 +319,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -370,8 +364,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -384,11 +376,8 @@ def forward( else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -425,13 +414,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 1a051992a306..cac1b2b3b11c 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -69,8 +68,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, spec_step_index: int = 0, @@ -88,8 +85,6 @@ def forward( hidden_states, residual = self.mtp_block(positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, residual=None) hidden_states = residual + hidden_states return self.shared_head(hidden_states) @@ -122,8 +117,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, @@ -131,8 +124,6 @@ def forward( return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( input_ids, positions, - kv_caches[spec_step_idx], - attn_metadata, previous_hidden_states, inputs_embeds, spec_step_idx, @@ -165,16 +156,14 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, previous_hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model(input_ids, positions, + previous_hidden_states, inputs_embeds, + spec_step_idx) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 9bf3ec2ffd81..22b2bf7ca469 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only DeepseekV2/DeepseekV3 model.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, @@ -279,8 +279,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] @@ -313,7 +311,7 @@ def forward( v = torch.nn.functional.pad( v, [0, self.qk_head_dim - self.v_head_dim], value=0).view(-1, self.num_local_heads * self.qk_head_dim) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output = attn_output.view( -1, self.num_local_heads, self.qk_head_dim)[..., :self.v_head_dim].reshape( @@ -451,8 +449,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: if self.q_lora_rank is not None: ckq = self.q_a_proj(hidden_states)[0] @@ -462,8 +458,7 @@ def forward( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache, - attn_metadata) + return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe) class DeepseekV2DecoderLayer(nn.Module): @@ -532,8 +527,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -546,8 +539,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -608,8 +599,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -624,11 +613,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -665,13 +651,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 5f684fa295ad..4e2dda33bcab 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -13,7 +13,6 @@ from einops import rearrange, repeat from transformers import BatchFeature -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -595,8 +594,6 @@ def get_input_embeddings( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object): @@ -614,8 +611,6 @@ def forward(self, hidden_states = self.language_model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index ab3f0dc07f4d..f2a2935e6c69 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -121,8 +120,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -140,8 +137,6 @@ def forward( input_ids=None, inputs_embeds=inputs_embeds, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, ) return hidden_states diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index e795c7e288c4..79939f6f40e4 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -24,12 +24,12 @@ # limitations under the License. """Inference-only Exaone model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -179,13 +179,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -225,14 +223,10 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: return self.attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) @@ -288,8 +282,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -301,8 +293,6 @@ def forward( hidden_states = self.attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -365,8 +355,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -381,13 +369,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] + for layer in self.h[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) @@ -471,14 +456,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + model_output = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 01b66a1c2a5f..7154ac2e6a5a 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -20,14 +20,14 @@ """PyTorch Falcon model.""" import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from torch.nn import LayerNorm from transformers import FalconConfig as HF_FalconConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -190,8 +190,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, bias = self.query_key_value(hidden_states) if bias is not None: @@ -199,7 +197,7 @@ def forward( q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.use_rotary: q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output, bias = self.dense(attn_output) return attn_output, bias @@ -291,8 +289,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states @@ -306,8 +302,6 @@ def forward( attention_output, attention_bias = self.self_attention( positions=positions, hidden_states=attention_layernorm_out, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) if self.reduce_row_parallel_results and attention_bias is not None: attention_output += attention_bias @@ -384,8 +378,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -396,14 +388,8 @@ def forward( hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) @@ -450,14 +436,11 @@ def forward( self, input_ids: torch.LongTensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 4a1ad5f4ee0c..06912bcfdc8a 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch import torch.nn as nn -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -50,8 +49,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + encoder_positions: torch.Tensor) -> torch.Tensor: r""" Args: input_ids @@ -64,10 +62,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Indices of *encoder* input sequence tokens in the vocabulary. encoder_positions: Positions of *encoder* input sequence tokens. - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Model output torch.Tensor """ @@ -78,18 +72,14 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata) + positions=encoder_positions) # decoder outputs consists of # (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( decoder_input_ids=input_ids, decoder_positions=positions, - encoder_hidden_states=encoder_hidden_states, - kv_caches=kv_caches, - attn_metadata=attn_metadata) + encoder_hidden_states=encoder_hidden_states) return decoder_outputs @@ -122,8 +112,6 @@ def forward( positions: torch.Tensor, encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: r""" @@ -136,15 +124,11 @@ def forward( torch.Tensor of *encoder* input token ids. encoder_positions torch.Tensor of *encoder* position indices - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Output torch.Tensor """ return self.model(input_ids, positions, encoder_input_ids, - encoder_positions, kv_caches, attn_metadata) + encoder_positions) def compute_logits( self, @@ -213,8 +197,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, *, encoder_input_ids: torch.Tensor, @@ -231,15 +213,11 @@ def forward( torch.Tensor of *encoder* input token ids. encoder_positions torch.Tensor of *encoder* position indices - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Output torch.Tensor """ return self.language_model(input_ids, positions, encoder_input_ids, - encoder_positions, kv_caches, attn_metadata) + encoder_positions) def compute_logits( self, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 42a6aa979427..4f5519f325e0 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -25,7 +25,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, FuyuProcessor) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.sampler import SamplerOutput @@ -351,8 +350,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -371,8 +368,6 @@ def forward( hidden_states = self.language_model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index d0589e60a72b..da17646c540f 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -16,13 +16,13 @@ # limitations under the License. """Inference-only Gemma model compatible with HuggingFace weights.""" from functools import cache -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GemmaConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -183,13 +183,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -233,8 +231,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -247,8 +243,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -298,8 +292,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -313,13 +305,10 @@ def forward( else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -370,13 +359,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 6ee257d65c50..cf744fc2b9d1 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -15,13 +15,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import Gemma2Config -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -164,13 +164,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -220,8 +218,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: if residual is None: @@ -233,8 +229,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -284,8 +278,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -300,13 +292,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -415,13 +404,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 8fc5a797f824..48543c5642ea 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -4,7 +4,7 @@ # https://github.com/THUDM/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" from argparse import Namespace -from typing import List, Literal, Mapping, Optional, TypedDict, Union +from typing import Literal, Mapping, Optional, TypedDict, Union import torch from torch import nn @@ -15,7 +15,6 @@ from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput -from vllm.attention import AttentionMetadata from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -628,8 +627,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -645,8 +642,7 @@ def forward( vision_embeddings) input_ids = None - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 7ad9a24dcbbc..776c03f652bd 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -18,13 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GPT2Config -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed.parallel_state import ( @@ -92,12 +92,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -164,16 +162,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + attn_output = self.attn(hidden_states=hidden_states) # residual connection hidden_states = attn_output + residual @@ -222,8 +214,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor], ) -> Union[torch.Tensor, IntermediateTensors]: @@ -236,11 +226,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer(hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -279,14 +266,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 799edff46ea3..43f3d4f6dc9c 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -19,13 +19,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GPTBigCodeConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -101,8 +101,6 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.split( @@ -112,7 +110,7 @@ def forward( ], dim=-1, ) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -173,16 +171,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + attn_output = self.attn(hidden_states=hidden_states, ) # residual connection hidden_states = attn_output + residual @@ -234,8 +226,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -246,11 +236,8 @@ def forward( else: hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer(hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -302,14 +289,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 815aba145d30..752aec0b223d 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -17,13 +17,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GPTJConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -104,13 +104,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output, _ = self.out_proj(attn_output) return attn_output @@ -167,16 +165,12 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_output = self.attn( position_ids=position_ids, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) mlp_output = self.mlp(hidden_states) hidden_states = attn_output + mlp_output + residual @@ -217,8 +211,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -229,14 +221,8 @@ def forward( hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer( - position_ids, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) @@ -273,14 +259,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 550ca3f7ca9e..4b30c7bb3035 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -17,13 +17,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GPTNeoXConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -104,13 +104,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.dense(attn_output) return output @@ -167,15 +165,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: attn_input = self.input_layernorm(hidden_states) attn_output = self.attention( position_ids=position_ids, hidden_states=attn_input, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) if self.use_parallel_residual: @@ -230,8 +224,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -242,14 +234,8 @@ def forward( hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - position_ids, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layer_norm(hidden_states) @@ -285,14 +271,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.gpt_neox(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 2aeb179ee932..201e15d3a30f 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only IBM Granite model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GraniteConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -166,13 +166,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -233,8 +231,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -242,8 +238,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states * self.residual_multiplier # Fully Connected @@ -300,8 +294,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -318,14 +310,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -405,13 +391,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 40df9c72c561..9b56874a8add 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GraniteMoe model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers.models.granitemoe import GraniteMoeConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -173,13 +173,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -226,8 +224,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -235,8 +231,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states * self.residual_multiplier residual = hidden_states @@ -287,8 +281,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -303,11 +295,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -377,13 +366,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 0f3a2ffe9a13..a20328289f92 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -1,15 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 from array import array -from typing import List, Optional, Union +from typing import Optional, Union import torch import torch.nn as nn from xformers.ops.fmha.attn_bias import BlockDiagonalMask -from vllm.attention import AttentionMetadata from vllm.attention.backends.xformers import XFormersImpl from vllm.config import ModelConfig, VllmConfig +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.pooler import PoolerHead from vllm.model_executor.models.llama import LlamaForCausalLM @@ -217,13 +217,12 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: # Change attention to non-causal for pooling tasks. if self.runner_type == "pooling": + attn_metadata = get_forward_context().attn_metadata assert attn_metadata.prefill_metadata.attn_bias is None attn_metadata.prefill_metadata.attn_bias = [ BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) @@ -232,8 +231,6 @@ def forward( return super().forward( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, **kwargs, ) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 3a7e2a9a6a57..0a8763cf910c 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -25,7 +25,6 @@ from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, Idefics3Processor) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear @@ -563,8 +562,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -572,8 +569,6 @@ def forward( hidden_states = self.text_model( input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds, ) @@ -645,8 +640,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -664,8 +657,6 @@ def forward( hidden_states = self.model.text_model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index c5f7be135d71..22c9287509ed 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union, - overload, runtime_checkable) +from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload, + runtime_checkable) import torch import torch.nn as nn @@ -11,7 +11,6 @@ from vllm.utils import supports_kw if TYPE_CHECKING: - from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import PoolerOutput from vllm.model_executor.layers.sampler import SamplerOutput @@ -46,8 +45,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: "AttentionMetadata", ) -> T_co: ... @@ -62,7 +59,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: if not callable(model_forward): return False - vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata") + vllm_kws = ("input_ids", "positions") missing_kws = tuple(kw for kw in vllm_kws if not supports_kw(model_forward, kw)) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index b21933dd5da7..41ca399b9efb 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -175,13 +175,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.wqkv(hidden_states) q, k, v = self.split_qkv(qkv) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.wo(attn_output) return output @@ -227,8 +225,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -241,8 +237,6 @@ def forward( hidden_states = self.attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -290,8 +284,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -305,15 +297,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -363,13 +348,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states @@ -466,13 +448,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) logits, _ = self.v_head(hidden_states) return logits diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index 106c3b6b78cc..69b0caab8f8e 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm @@ -65,8 +64,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], visual_token_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -80,8 +77,6 @@ def forward( hidden_states = self.attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -113,8 +108,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, visual_token_mask: Optional[torch.Tensor] = None, @@ -129,13 +122,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, visual_token_mask=visual_token_mask, ) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 4a6007876776..52ddb279cca3 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -17,7 +17,6 @@ from PIL import Image from transformers import BatchFeature, PretrainedConfig, TensorType -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig @@ -929,8 +928,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -951,8 +948,6 @@ def forward( forward_kwargs = { "input_ids": input_ids, "positions": positions, - "kv_caches": kv_caches, - "attn_metadata": attn_metadata, "intermediate_tensors": intermediate_tensors, "inputs_embeds": inputs_embeds, } diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 72bcef5e2282..78fe6588eddc 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -21,12 +21,12 @@ """Inference-only Jais model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -123,12 +123,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -200,16 +198,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + attn_output = self.attn(hidden_states=hidden_states, ) # residual connection hidden_states = attn_output + residual @@ -266,8 +258,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: @@ -285,11 +275,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer(hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -332,14 +319,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 5530e3ca708c..14e56df6cadf 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only Jamba model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import JambaConfig -from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -138,7 +137,6 @@ def __init__(self, def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, **kwargs, @@ -150,8 +148,7 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, attn_metadata, - mamba_cache_params) + hidden_states = self.mamba(hidden_states, mamba_cache_params) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual) @@ -223,13 +220,11 @@ def self_attention( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -237,8 +232,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], **kwargs, ): @@ -252,8 +245,6 @@ def forward( hidden_states = self.self_attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( @@ -320,8 +311,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -339,12 +328,9 @@ def forward( kv_cache_index = 0 mamba_cache_index = 0 - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - kv_cache = None + for layer in self.layers[self.start_layer:self.end_layer]: layer_mamba_cache_params = None if isinstance(layer, JambaAttentionDecoderLayer): - kv_cache = kv_caches[kv_cache_index] kv_cache_index += 1 if isinstance(layer, JambaMambaDecoderLayer): current_state_layer = mamba_cache_index @@ -355,8 +341,6 @@ def forward( hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, residual=residual, mamba_cache_params=layer_mamba_cache_params) if not get_pp_group().is_last_rank: @@ -429,8 +413,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): @@ -443,8 +425,7 @@ def forward(self, mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params, + hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 011d0a7aafaa..a0aff9e609d9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union import torch from torch import nn from transformers import LlamaConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -197,13 +197,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -268,8 +266,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -280,9 +276,7 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + hidden_states=hidden_states) # Fully Connected hidden_states, residual = self.post_attention_layernorm( @@ -347,8 +341,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -363,11 +355,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -535,13 +524,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 19752ba703f4..72b1591306f2 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -15,7 +15,6 @@ from transformers.models.llava import LlavaProcessor from transformers.models.pixtral import PixtralProcessor -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.activation import get_act_fn @@ -658,8 +657,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -712,8 +709,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c39daec709fc..6a050d7798a2 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -12,7 +12,6 @@ get_anyres_image_grid_shape, unpad_image) from typing_extensions import NotRequired -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -508,8 +507,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -571,8 +568,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 2af3cc05080a..807d6977ed40 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -10,7 +10,6 @@ from transformers import (BatchFeature, LlavaNextVideoConfig, LlavaNextVideoProcessor) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -443,8 +442,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -468,8 +465,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 8eb8071e6577..e57eea4286e9 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -13,7 +13,6 @@ get_anyres_image_grid_shape, unpad_image) from typing_extensions import NotRequired -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -922,8 +921,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -955,8 +952,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index ba88950ee898..9f1cd8c29a5a 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import MambaConfig -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group @@ -64,7 +63,6 @@ def __init__(self, def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, **kwargs, @@ -75,8 +73,7 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, attn_metadata, - mamba_cache_params) + hidden_states = self.mixer(hidden_states, mamba_cache_params) return hidden_states, residual @@ -125,7 +122,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -146,7 +142,6 @@ def forward( hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, - attn_metadata=attn_metadata, residual=residual, mamba_cache_params=mamba_cache_params.at_layer_idx( i - self.start_layer)) @@ -208,8 +203,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): @@ -222,9 +215,8 @@ def forward(self, mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.backbone(input_ids, positions, attn_metadata, - mamba_cache_params, intermediate_tensors, - inputs_embeds) + hidden_states = self.backbone(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 6366fc023682..266cdc243ac4 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """PyTorch MAMBA2 model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn @@ -10,6 +10,7 @@ from vllm.config import VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import ( @@ -63,7 +64,6 @@ def __init__(self, def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, sequence_idx: Optional[torch.Tensor], @@ -75,8 +75,8 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, attn_metadata, - mamba_cache_params, sequence_idx) + hidden_states = self.mixer(hidden_states, mamba_cache_params, + sequence_idx) return hidden_states, residual @@ -122,7 +122,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -142,6 +141,7 @@ def forward( # proper continuous batching computation including # chunked prefill seq_idx = None + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata if attn_metadata.num_prefills > 0: seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) for i, (srt, end) in enumerate( @@ -158,7 +158,6 @@ def forward( hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, - attn_metadata=attn_metadata, residual=residual, mamba_cache_params=mamba_cache_params.at_layer_idx( i - self.start_layer), @@ -224,8 +223,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): @@ -238,9 +235,8 @@ def forward(self, mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.backbone(input_ids, positions, attn_metadata, - mamba_cache_params, intermediate_tensors, - inputs_embeds) + hidden_states = self.backbone(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 54b691b3572d..34e1f3927a9a 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -23,13 +23,13 @@ # limitations under the License. """Inference-only MiniCPM model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -257,8 +257,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -266,7 +264,7 @@ def forward( q, k = q.float(), k.float() q, k = self.rotary_emb(positions, q, k) q, k = q.to(orig_dtype), k.to(orig_dtype) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -331,8 +329,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -341,8 +337,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states * \ (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) @@ -409,8 +403,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -424,13 +416,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -579,13 +568,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index b85306c40880..1b24c38cef1b 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -29,7 +29,7 @@ from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm @@ -129,8 +129,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: q, _ = self.q_a_proj(hidden_states) q = self.q_a_layernorm(q) @@ -170,7 +168,7 @@ def forward( v, [0, self.qk_head_dim - self.v_head_dim], value=0).view(-1, self.num_local_heads * self.qk_head_dim) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output = attn_output.view( -1, self.num_local_heads, self.qk_head_dim)[..., :self.v_head_dim].reshape( diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index aa8c193ed6a5..e354e5323327 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -33,7 +33,6 @@ from transformers.models.whisper.modeling_whisper import ( ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import MultiModalFieldConfig @@ -792,8 +791,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: Any, ) -> torch.Tensor: @@ -818,8 +815,6 @@ def forward( output = self.llm.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=vlm_embeddings, ) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 5e883d00c1c6..46f794e88ad5 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -37,7 +37,6 @@ from transformers import BatchFeature, PretrainedConfig from typing_extensions import TypeVar -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, @@ -1030,8 +1029,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: Any, ) -> torch.Tensor: @@ -1051,8 +1048,6 @@ def forward( output = self.llm.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=vlm_embeddings, ) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index b83b69fd2c2d..c8dea557e571 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import MixtralConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -175,13 +175,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -224,8 +222,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -238,8 +234,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -291,8 +285,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -306,11 +298,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -377,13 +366,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index fdc438917542..21b52d9f54c7 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import numpy as np import torch @@ -30,7 +30,7 @@ from torch import nn from transformers import MixtralConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -229,13 +229,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -274,8 +272,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -288,8 +284,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -333,8 +327,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -348,11 +340,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -390,13 +379,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 1f8f5b2eb136..459928fe3fb0 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -38,7 +38,8 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.selector import _Backend from vllm.config import VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tp_group +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -416,11 +417,11 @@ def __init__(self, prefix: str = ""): super().__init__() - model_parallel_size = get_tensor_model_parallel_world_size() + tensor_parallel_size = get_tp_group().world_size self.embed_dim = config.hidden_size self.num_heads = config.attention_heads self.head_dim = config.hidden_size // config.attention_heads - self.num_local_heads = self.num_heads // model_parallel_size + self.num_local_heads = self.num_heads // tensor_parallel_size self.q_size = self.num_local_heads * self.head_dim self.kv_size = self.num_local_heads * self.head_dim @@ -771,12 +772,13 @@ def __init__( ): super().__init__() self.config = config - self.model_parallel_size = get_tensor_model_parallel_world_size() + self.pipeline_parallel_rank = get_pp_group().rank_in_group + self.tensor_parallel_size = get_tp_group().world_size self.num_heads = self.config.num_attention_heads - self.num_local_heads = self.num_heads // self.model_parallel_size + self.num_local_heads = self.num_heads // self.tensor_parallel_size self.num_key_value_heads = self.config.num_key_value_heads self.num_local_key_value_heads = \ - self.num_key_value_heads // self.model_parallel_size + self.num_key_value_heads // self.tensor_parallel_size self.dropout = config.dropout self.hidden_size = config.hidden_size self.head_dim = config.hidden_size // self.num_heads @@ -824,8 +826,6 @@ def forward( attention_mask: Optional[torch.Tensor], kv_range_for_decode: Optional[List[Tuple[int, int]]], cross_attention_states: Optional[torch.Tensor], - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv_dec, _ = self.qkv_proj(hidden_states) q, _, _ = qkv_dec.split( @@ -846,14 +846,11 @@ def forward( q = self.q_norm(q) if attention_mask is not None: - output = self._attention_with_mask(q, k, v, kv_cache, - attention_mask, - kv_range_for_decode, - attn_metadata) + output = self._attention_with_mask(q, k, v, attention_mask, + kv_range_for_decode) else: output = self.attn( - q.view(-1, self.num_local_heads * self.head_dim), k, v, - kv_cache, attn_metadata) + q.view(-1, self.num_local_heads * self.head_dim), k, v) out, _ = self.o_proj(output) return out @@ -862,11 +859,11 @@ def _attention_with_mask( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - kv_cache: torch.Tensor, attention_mask: torch.Tensor, kv_range_for_decode: List[Tuple[int, int]], - attn_metadata: AttentionMetadata, ) -> torch.Tensor: + kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank] + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata # Skip writing kv-cache for the initial profiling run. if len(kv_cache.shape) > 1: i = torch.ones(1, dtype=torch.float32) @@ -978,8 +975,6 @@ def forward( cross_attention_mask: torch.Tensor, kv_range_for_decode: Optional[List[Tuple[int, int]]], full_text_row_masked_out_mask: torch.Tensor, - kv_cache: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -989,8 +984,6 @@ def forward( attention_mask=cross_attention_mask, kv_range_for_decode=kv_range_for_decode, cross_attention_states=cross_attention_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = full_text_row_masked_out_mask * hidden_states hidden_states = residual + self.cross_attn_attn_gate.tanh( @@ -1054,14 +1047,12 @@ def forward( kv_range_for_decode: Optional[List[Tuple[int, int]]], full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, skip_cross_attention: bool, ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): if not skip_cross_attention: hidden_states = decoder_layer( @@ -1071,15 +1062,11 @@ def forward( kv_range_for_decode=kv_range_for_decode, full_text_row_masked_out_mask= full_text_row_masked_out_mask, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, ) elif isinstance(decoder_layer, LlamaDecoderLayer): hidden_states, residual = decoder_layer( positions=positions, hidden_states=hidden_states, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, residual=None, ) hidden_states = hidden_states + residual @@ -1124,8 +1111,6 @@ def forward( kv_range_for_decode: Optional[List[Tuple[int, int]]], full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, skip_cross_attention: bool, ) -> torch.Tensor: hidden_states = self.model( @@ -1135,8 +1120,6 @@ def forward( cross_attention_mask=cross_attention_mask, kv_range_for_decode=kv_range_for_decode, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - kv_caches=kv_caches, - attn_metadata=attn_metadata, skip_cross_attention=skip_cross_attention, ) return hidden_states @@ -1353,10 +1336,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, **kwargs: object, ) -> Union[Tuple, CausalLMOutputWithPast]: + attn_metadata = get_forward_context().attn_metadata if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: raise ValueError("Chunk prefill not supported") @@ -1410,8 +1392,6 @@ def forward( cross_attention_mask=cross_attention_mask, kv_range_for_decode=kv_range_for_decode, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - kv_caches=kv_caches, - attn_metadata=attn_metadata, skip_cross_attention=skip_cross_attention, ) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 6ce9fbda182f..cc4d38d8740b 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -16,7 +16,7 @@ from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.attention.layer import MultiHeadAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig @@ -460,15 +460,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.q_norm is not None and self.k_norm is not None: q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -580,8 +578,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Self Attention @@ -594,8 +590,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states, residual = self.post_attention_layernorm( @@ -610,8 +604,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Self Attention @@ -619,8 +611,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = self.input_layernorm(hidden_states) @@ -841,8 +831,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -858,13 +846,10 @@ def forward( residual = intermediate_tensors["residual"] # Apply blocks one-by-one. - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -1643,8 +1628,6 @@ def forward( self, input_ids: torch.LongTensor, positions: torch.LongTensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -1663,8 +1646,6 @@ def forward( hidden_states = self.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 676c960623ed..d716818f31c0 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -2,12 +2,12 @@ # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch import torch.nn as nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -125,8 +125,6 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: del position_ids # unused. qkv, _ = self.Wqkv(hidden_states) @@ -136,7 +134,7 @@ def forward( if self.qk_ln: q = self.q_ln(q) k = self.k_ln(k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -196,15 +194,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: x = self.norm_1(hidden_states) x = self.attn( position_ids=position_ids, hidden_states=x, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = hidden_states + x x = self.norm_2(hidden_states) @@ -253,8 +247,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -267,14 +259,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - block = self.blocks[i] - hidden_states = block( - position_ids, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for block in self.blocks[self.start_layer:self.end_layer]: + hidden_states = block(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm_f(hidden_states) @@ -306,14 +292,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index a42734edb39a..3b86b91465ca 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -27,7 +27,7 @@ import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -204,13 +204,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -269,8 +267,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -283,8 +279,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -343,8 +337,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -359,15 +351,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -444,13 +429,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 3b470dfdd05b..4a341c97d6cd 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import OlmoConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -119,15 +119,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) if self.clip_qkv is not None: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -212,14 +210,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Attention block. residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn(positions, hidden_states, kv_cache, - attn_metadata) + hidden_states = self.self_attn(positions, hidden_states) hidden_states = hidden_states + residual # MLP block. @@ -263,8 +258,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -281,14 +274,9 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] # Apply blocks one-by-one. - for i in range(self.start_layer, self.end_layer): + for layer in self.layers[self.start_layer:self.end_layer]: # shape: (batch_size, seq_len, d_model) - hidden_states = self.layers[i]( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -332,16 +320,12 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index d06f894123ac..54cc851de934 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -24,12 +24,12 @@ """Inference-only OLMo2 model compatible with HuggingFace weights.""" from functools import partial -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple, Union import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.communication_op import tensor_model_parallel_all_gather @@ -153,14 +153,12 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -239,13 +237,10 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Attention block. residual = hidden_states - hidden_states = self.self_attn(positions, hidden_states, kv_cache, - attn_metadata) + hidden_states = self.self_attn(positions, hidden_states) hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = hidden_states + residual @@ -287,8 +282,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], ) -> Union[torch.Tensor, IntermediateTensors]: """ @@ -307,14 +300,9 @@ def forward( assert isinstance(hidden_states, torch.Tensor) # Apply blocks one-by-one. - for i in range(self.start_layer, self.end_layer): + for layer in self.layers[self.start_layer:self.end_layer]: # shape: (batch_size, seq_len, d_model) - hidden_states = self.layers[i]( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -357,15 +345,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, ) return hidden_states diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index d6e24c6d67f3..e27ff5deace2 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -168,14 +168,12 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous()) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -222,8 +220,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -237,8 +233,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -283,8 +277,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -299,13 +291,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) @@ -347,13 +336,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index ad1d66902435..e4775478a54d 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -18,13 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OPT model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import OPTConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -107,12 +107,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -164,17 +162,13 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Self Attention residual = hidden_states # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention if self.do_layer_norm_before: hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn(hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states # 350m applies layer norm AFTER attention if not self.do_layer_norm_before: @@ -261,8 +255,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -277,11 +269,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer(hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -317,15 +306,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: return self.decoder(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) @@ -362,13 +347,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index f4f5cdff6437..6668ede91eec 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -5,13 +5,13 @@ # Copyright (c) OrionStar Inc. # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -136,13 +136,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -189,8 +187,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -198,8 +194,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -247,8 +241,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -260,14 +252,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -303,13 +289,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 955a59953eb4..02d1861b8027 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, +from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch from torch import nn from transformers import PaliGemmaConfig -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) @@ -288,8 +287,6 @@ def get_input_embeddings( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: @@ -306,8 +303,6 @@ def forward(self, hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 6a80bea348ea..db8d170a8c91 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -21,13 +21,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only persimmon model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PersimmonConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -142,8 +142,6 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # [seq_length, 3 x hidden_size] qkv, _ = self.query_key_value(hidden_states) @@ -161,7 +159,7 @@ def forward( k = self._merge_heads(k) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.dense(attn_output) return output @@ -189,8 +187,6 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states @@ -200,8 +196,6 @@ def forward( hidden_states = self.self_attn( position_ids=position_ids, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -248,8 +242,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -261,13 +253,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - hidden_states = self.layers[i]( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layernorm(hidden_states) @@ -298,16 +285,12 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ): hidden_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 1ca8cad22ad9..6ee80210c2b4 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -36,13 +36,13 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PhiConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -126,13 +126,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.dense(attn_output) return output @@ -186,16 +184,12 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) attn_outputs = self.self_attn( position_ids=position_ids, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) feed_forward_hidden_states = self.mlp(hidden_states) hidden_states = attn_outputs + feed_forward_hidden_states + residual @@ -234,8 +228,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -247,14 +239,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -304,13 +290,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 873e9d37771d..33984f54ae27 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers.configuration_utils import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -231,8 +231,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: qkv, _ = self.query_key_value(hidden_states) @@ -248,7 +246,7 @@ def forward( v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.dense(attn_output) return output @@ -282,8 +280,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -291,8 +287,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -338,8 +332,6 @@ def forward( self, input_ids: torch.LongTensor, positions: Optional[torch.LongTensor], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor], ) -> Union[torch.Tensor, IntermediateTensors]: @@ -354,14 +346,8 @@ def forward( else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layernorm(hidden_states) @@ -438,16 +424,12 @@ def forward( self, input_ids: torch.LongTensor, positions: Optional[torch.LongTensor], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: output_hidden_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 207204df2055..61d63e104de4 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -23,7 +23,6 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, ProcessorMixin) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig @@ -672,8 +671,6 @@ def get_input_embeddings( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object): @@ -691,8 +688,6 @@ def forward(self, hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 17369cb58e36..c35c7e9fcce7 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only PhiMoE model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers.configuration_utils import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -357,13 +357,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -410,8 +408,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: residual = hidden_states @@ -422,8 +418,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = hidden_states + residual @@ -478,8 +472,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -494,13 +486,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) @@ -571,13 +560,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 273dc3b1cf75..87b1d50749a2 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -16,7 +16,6 @@ from transformers.models.pixtral.modeling_pixtral import ( PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, @@ -270,8 +269,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -291,8 +288,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 9383cbae11bc..0d0c367e677e 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -15,13 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only IBM/NASA Prithvi Geospatial model.""" -from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union +from typing import Iterable, Mapping, Optional, Set, Tuple, Union import torch import torch.nn as nn from transformers import BatchFeature -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (IsAttentionFree, @@ -181,8 +180,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 7c4627036203..96abfb9d1096 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -6,13 +6,13 @@ # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -124,13 +124,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.c_proj(attn_output) return output @@ -168,8 +166,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -181,8 +177,6 @@ def forward( hidden_states = self.attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -225,8 +219,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -241,13 +233,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] + for layer in self.h[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -373,12 +362,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7da6e558ff33..fe615c41aeaa 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -23,13 +23,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import Qwen2Config -from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -170,13 +170,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -233,8 +231,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -247,8 +243,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -328,8 +322,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -343,13 +335,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -468,13 +457,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states @@ -553,12 +539,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors) + return self.model(input_ids, positions, intermediate_tensors) def pooler( self, diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index ef31f18445fd..858cf28d2b87 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -37,7 +37,6 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils @@ -992,8 +991,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -1047,8 +1044,6 @@ def forward( hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 3df5dd2bdd41..f0dc8573ee14 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -22,8 +22,8 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from functools import cached_property -from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict, + Union) import torch import torch.nn as nn @@ -33,7 +33,6 @@ Qwen2AudioProcessor) from transformers.models.whisper import WhisperFeatureExtractor -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -380,8 +379,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -400,8 +397,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 35d9854a55d6..41536b34b2f2 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -23,14 +23,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, @@ -232,13 +232,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -296,8 +294,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -310,8 +306,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -358,8 +352,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -373,11 +365,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -416,13 +405,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index c6588a47d881..21cc9e8ed1c6 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -5,12 +5,11 @@ # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. """Inference-only Qwen2-RM model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -80,13 +79,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) logits, _ = self.score(hidden_states) return logits diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 31701abd3339..849ef7293bb7 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -24,8 +24,8 @@ # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import cached_property, partial -from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, - Set, Tuple, Type, TypedDict, Union) +from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set, + Tuple, Type, TypedDict, Union) import torch import torch.nn as nn @@ -38,7 +38,6 @@ Qwen2VLConfig, Qwen2VLVisionConfig) from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils @@ -1302,8 +1301,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -1354,8 +1351,6 @@ def forward( hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 56faa390fc5d..e0d8bf2fa3d2 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -22,7 +22,6 @@ from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -766,8 +765,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -783,7 +780,6 @@ def forward( vision_embeddings) input_ids = None - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 742e63a065b1..f86fa268072d 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import itertools -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import RobertaConfig -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import CrossEncodingPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -243,16 +242,12 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.roberta(input_ids=input_ids, position_ids=positions, - kv_caches=kv_caches, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors, - attn_metadata=attn_metadata, token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index ad98f3b07034..0f9e517aeb55 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -23,13 +23,13 @@ # limitations under the License. """Inference-only Solar model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -172,13 +172,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -238,8 +236,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -252,8 +248,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -315,8 +309,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -357,8 +349,6 @@ def forward( hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) @@ -438,13 +428,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index a5d4432669f4..a15faec547b9 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -20,13 +20,13 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import StableLmConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul @@ -147,13 +147,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -183,8 +181,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -192,8 +188,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -241,8 +235,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -254,14 +246,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) @@ -296,13 +282,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 01ea43666482..90098af9dde0 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -19,13 +19,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Starcoder2 model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import Starcoder2Config -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -118,13 +118,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -184,8 +182,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -193,8 +189,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -246,8 +240,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -259,11 +251,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) @@ -306,13 +295,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index b431abb76b69..1c3c443b2941 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -22,7 +22,7 @@ from transformers import AutoModel, PreTrainedModel from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.utils import divide @@ -59,7 +59,6 @@ def vllm_flash_attention_forward( # Transformers kwargs scaling: Optional[float] = None, # vLLM kwargs - attn_metadata: Optional[AttentionMetadata] = None, attention_instances: Optional[list[Attention]] = None, **kwargs): self_attn = attention_instances[module.layer_idx] @@ -68,12 +67,7 @@ def vllm_flash_attention_forward( hidden = query.shape[-2] query, key, value = (x.transpose(1, 2) for x in (query, key, value)) query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) - return self_attn.forward( - query, - key, - value, - kv_cache=None, # argument not used - attn_metadata=attn_metadata), None + return self_attn.forward(query, key, value), None ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward @@ -251,8 +245,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: list[torch.Tensor], # argument not used - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -260,7 +252,6 @@ def forward( input_ids[None, ...], use_cache=False, position_ids=positions[None, ...], - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, attention_instances=self.attention_instances, return_dict=False)[0][0, ...] # we remove batch dimension for now diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b99094e5d4ca..1dbba3c50b19 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -4,8 +4,8 @@ """PyTorch Ultravox model.""" import math from functools import cached_property -from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) import torch import torch.utils.checkpoint @@ -16,8 +16,8 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder from vllm import envs -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -495,13 +495,13 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[NestedTensors] = None, - attn_metadata: Optional[AttentionMetadata] = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: # TODO(ywang96): remove this block after v0 is deprecated. if not envs.VLLM_USE_V1: + attn_metadata = get_forward_context().attn_metadata merge_multimodal_embeddings_from_map( inputs_embeds, multimodal_embeddings, attn_metadata.multi_modal_placeholder_index_maps["audio"]) @@ -514,8 +514,6 @@ def get_input_embeddings( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, IntermediateTensors]: @@ -540,17 +538,12 @@ def forward(self, elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - # TODO(ywang96): remove attn_metadata from get_input_embeddings - # after v0 is deprecated inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings, - attn_metadata) + multimodal_embeddings) input_ids = None hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 2ad1731144ef..e5f77e08c403 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -10,7 +10,7 @@ WhisperProcessor) from transformers.models.whisper.modeling_whisper import sinusoids -from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention import Attention, AttentionType from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -134,13 +134,11 @@ def _init_qkv( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) @@ -196,8 +194,6 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor], - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ): q, _ = self.q_proj(hidden_states) @@ -209,13 +205,7 @@ def forward( else: k = v = None - attn_output = self.attn( - q, - k, - v, - kv_cache, - attn_metadata, - ) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) @@ -285,16 +275,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -348,14 +332,10 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor], - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn(hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states @@ -363,8 +343,6 @@ def forward( hidden_states = self.encoder_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -411,12 +389,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_positions.weight.copy_( sinusoids(*self.embed_positions.weight.shape)) - def forward( - self, - input_features: Union[torch.Tensor, List[torch.Tensor]], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ): + def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]): hidden_states = [] for features in input_features: embeds = nn.functional.gelu(self.conv1(features)) @@ -426,12 +399,8 @@ def forward( hidden_states.append(embeds) hidden_states = torch.cat(hidden_states) - for idx, encoder_layer in enumerate(self.layers): - hidden_states = encoder_layer( - hidden_states, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, - ) + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) hidden_states = self.layer_norm(hidden_states) return hidden_states @@ -466,19 +435,15 @@ def forward( input_ids, positions: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ): inputs_embeds = self.get_input_embeddings(input_ids) positions = self.embed_positions(positions) hidden_states = inputs_embeds + positions - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: hidden_states = decoder_layer( hidden_states, encoder_hidden_states=encoder_hidden_states, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, ) hidden_states = self.layer_norm(hidden_states) @@ -505,36 +470,22 @@ def forward( input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> torch.Tensor: - encoder_outputs = self.get_encoder_outputs( - input_features, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - ) + encoder_outputs = self.get_encoder_outputs(input_features) decoder_outputs = self.decoder( input_ids=input_ids, positions=positions, encoder_hidden_states=encoder_outputs, - kv_caches=kv_caches, - attn_metadata=attn_metadata, ) return decoder_outputs def get_encoder_outputs( self, input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> Optional[torch.Tensor]: if input_features is None: return None - return self.encoder( - input_features, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - ) + return self.encoder(input_features) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -733,8 +684,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: audio_input = self._parse_and_validate_audio_input(**kwargs) @@ -742,31 +691,19 @@ def forward( input_features=audio_input["input_features"], input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, ) return decoder_outputs - def get_multimodal_embeddings( - self, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - **kwargs, - ) -> Optional[NestedTensors]: + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: # TODO: This method does not obey the interface for SupportsMultiModal. # Refactor this once encoder/decoder support is implemented in V1. audio_input = self._parse_and_validate_audio_input(**kwargs) - return self.model.get_encoder_outputs( - audio_input["input_features"], - kv_caches=kv_caches, - attn_metadata=attn_metadata, - ) + return self.model.get_encoder_outputs(audio_input["input_features"]) def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[NestedTensors] = None, - attn_metadata: Optional[AttentionMetadata] = None, ) -> torch.Tensor: # TODO: This method just returns the decoder sequence embeddings since # Whisper does not have encoder text tokens. Refactor this once diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 7353d3c53ae9..40ecc3481e6b 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -288,8 +288,6 @@ def execute_model( hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a7b9d4781183..1fbce3098a34 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -939,8 +939,6 @@ def execute_model( hidden_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=self.kv_caches, - attn_metadata=None, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) @@ -1137,11 +1135,8 @@ def _get_prompt_logprobs_dict( def _dummy_run( self, num_tokens: int, - kv_caches: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: model = self.model - if kv_caches is None: - kv_caches = self.kv_caches if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -1172,26 +1167,12 @@ def _dummy_run( hidden_states = model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=None, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states def profile_run(self) -> None: - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value `None`. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - dummy_kv_caches = [ - torch.tensor((), dtype=torch.float32, device=self.device) - for _ in range(self.num_attn_layers) - ] - # Profile with multimodal encoder & encoder cache. # TODO: handle encoder-decoder models once we support them. if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 @@ -1302,8 +1283,7 @@ def profile_run(self) -> None: with self.maybe_profile_with_lora(self.lora_config, num_scheduled_tokens): # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.max_num_tokens, - dummy_kv_caches) + hidden_states = self._dummy_run(self.max_num_tokens) if get_pp_group().is_last_rank: hidden_states = hidden_states[logit_indices] logits = self.model.compute_logits(hidden_states, None) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 9aa74ddee81b..b68c1ac9d71b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -14,11 +14,10 @@ import torch_xla.core.xla_model as xm import torch_xla.runtime as xr -from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingType @@ -645,7 +644,6 @@ def execute_model( assert self.model is not None selected_token_ids = self.model(prompt_data.input_tokens, prompt_data.input_positions, - prompt_data.attn_metadata, self.kv_caches) # In parallel to TPU execution, prepare the next iteration @@ -684,7 +682,6 @@ def execute_model( assert self.model is not None selected_token_ids = self.model(decode_data.input_tokens, decode_data.input_positions, - decode_data.attn_metadata, self.kv_caches) # Transfer sampled tokens from TPU to CPU @@ -873,7 +870,7 @@ def dummy_run( with set_forward_context(attn_metadata, self.vllm_config, 0): assert self.model is not None - self.model(token_ids, position_ids, attn_metadata, kv_caches) + self.model(token_ids, position_ids, kv_caches) def capture_model(self) -> None: """Compile the model.""" @@ -1000,7 +997,6 @@ def forward( self, token_ids: torch.Tensor, position_ids: torch.Tensor, - attn_metadata: AttentionMetadata, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. @@ -1008,7 +1004,6 @@ def forward( Args: token_ids: The input token IDs of shape [batch_size, seq_len]. position_ids: The input position IDs of shape [batch_size, seq_len]. - attn_metadata: The Pallas attention metadata. input_lens: The actual input lengths of shape [batch_size]. t: The sampling temperature of shape [batch_size]. p: The top-p probability of shape [batch_size]. @@ -1017,7 +1012,8 @@ def forward( memory profiling at initialization. """ # Skip this in memory profiling at initialization. - if attn_metadata is not None and kv_caches[0][0].numel() > 0: + if kv_caches[0][0].numel() > 0: + attn_metadata = get_forward_context().attn_metadata # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape # [num_kv_heads, num_blocks, block_size, head_size]. To make it @@ -1038,12 +1034,7 @@ def forward( attn_metadata.slot_mapping = slot_mapping assert self.model is not None - hidden_states = self.model( - token_ids, - position_ids, - kv_caches, - attn_metadata, - ) + hidden_states = self.model(token_ids, position_ids) hidden_states = hidden_states.flatten(0, 1) logits = self.model.compute_logits(hidden_states, None) diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index 71e32c5f7aca..ac7c93e48395 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -297,10 +297,6 @@ def execute_model( model_input.encoder_input_tokens, "encoder_positions": model_input.encoder_input_positions, - "kv_caches": - kv_caches, - "attn_metadata": - model_input.attn_metadata, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), "intermediate_tensors": diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 9400893105d7..8407f073040e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -654,8 +654,6 @@ def execute_model( hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **execute_model_kwargs, **multimodal_kwargs, diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py index c0744d63b8d0..1ceb2557c6b3 100644 --- a/vllm/worker/cpu_pooling_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -41,16 +41,6 @@ def execute_model( raise ValueError( "CPU worker does not support multi-step execution.") - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(num_layers) - ] - model_executable = self.model cross_enc_kwargs = {} if model_input.token_type_ids is not None: @@ -60,10 +50,6 @@ def execute_model( model_input.input_tokens, "positions": model_input.input_positions, - "kv_caches": - kv_caches, - "attn_metadata": - model_input.attn_metadata, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), **cross_enc_kwargs, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index e2d338f75761..5f39f2fa4947 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -184,8 +184,6 @@ def execute_model( positions=model_input.input_positions, encoder_input_ids=model_input.encoder_input_tokens, encoder_positions=model_input.encoder_input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), @@ -324,21 +322,11 @@ def profile_run(self) -> None: or encoder_dummy_data.multi_modal_placeholders) seqs.append(seq) - # Run the model with the dummy inputs. - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(num_layers) - ] finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) intermediate_tensors = None - self.execute_model(model_input, kv_caches, intermediate_tensors) + self.execute_model(model_input, None, intermediate_tensors) torch.cuda.synchronize() return diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index f22526cfad70..d6eaf84e40f6 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -384,11 +384,12 @@ def forward(self, *args, **kwargs): if 'virtual_engine' in kwargs: virtual_engine = kwargs.pop('virtual_engine') input_ids = kwargs['input_ids'] - kwargs['attn_metadata'] = self._update_metadata( - kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), - input_ids.device, self.dtype) + attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'), + input_ids.size(0), + input_ids.size(1), + input_ids.device, self.dtype) LoraMask.setLoraMask(kwargs.pop('lora_mask')) - with set_forward_context(kwargs['attn_metadata'], self.vllm_config, + with set_forward_context(attn_metadata, self.vllm_config, virtual_engine): hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) @@ -1346,15 +1347,13 @@ def profile_run(self) -> None: max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1] max_batch_size = min(self.max_num_batched_tokens // max_seq_len, self.scheduler_config.max_num_seqs) - self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, - False, True) + self.warmup_scenario(max_batch_size, max_seq_len, True, False, True) return def warmup_scenario(self, batch_size, seq_len, is_prompt, - kv_caches, is_pt_profiler_run=False, is_lora_profile_run=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) @@ -1418,7 +1417,7 @@ def warmup_scenario(self, profiler.start() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches, warmup_mode=True) + self.execute_model(inputs, None, warmup_mode=True) torch.hpu.synchronize() if profiler: profiler.step() @@ -1470,17 +1469,16 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len): f"free_mem:{free_mem}") logger.info(msg) - def warmup_all_buckets(self, buckets, is_prompt, kv_caches): + def warmup_all_buckets(self, buckets, is_prompt): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + self.warmup_scenario(batch_size, seq_len, is_prompt) def warmup_graphs(self, strategy, buckets, is_prompt, - kv_caches, available_mem, starting_mem=0, total_batch_seq=0.001): @@ -1512,7 +1510,7 @@ def warmup_graphs(self, self.graphed_buckets.add(graphed_bucket) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with HabanaMemoryProfiler() as mem_prof: - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + self.warmup_scenario(batch_size, seq_len, is_prompt) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) available_mem -= used_mem @@ -1542,8 +1540,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: graphs = graph == 't' if graphs: self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) - self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, - True) + self.warmup_scenario(int(bs), int(seq_len), is_prompt, True) raise AssertionError("Finished profiling") if self.skip_warmup: logger.info("Skipping warmup...") @@ -1608,9 +1605,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets, - True, kv_caches) + True) self.warmup_all_buckets(self.bucketing_global_state.decode_buckets, - False, kv_caches) + False) if not self.enforce_eager and htorch.utils.internal.is_lazy(): assert self.mem_margin is not None, \ @@ -1641,11 +1638,11 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ self.warmup_graphs( prompt_strategy, self.bucketing_global_state.prompt_buckets, - True, kv_caches, prompt_available_memory) + True, prompt_available_memory) mem_post_decode, decode_batch_seq, decode_captured_all = \ self.warmup_graphs( decode_strategy, self.bucketing_global_state.decode_buckets, - False, kv_caches, decode_available_memory) + False, decode_available_memory) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space @@ -1656,7 +1653,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: self.warmup_graphs( prompt_strategy, self.bucketing_global_state.prompt_buckets, True, - kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_prompt, prompt_batch_seq)) @@ -1669,7 +1665,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: mem_post_decode, _, _ = self.warmup_graphs( decode_strategy, self.bucketing_global_state.decode_buckets, False, - kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_decode, decode_batch_seq) @@ -1982,7 +1977,6 @@ def execute_model( execute_model_kwargs = { "input_ids": input_tokens, "positions": input_positions, - "kv_caches": kv_caches, "attn_metadata": self.trim_attn_metadata(attn_metadata), "intermediate_tensors": intermediate_tensors, "lora_mask": lora_mask, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1a78498ad124..86dcde234f86 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -26,7 +26,7 @@ from vllm.distributed import get_kv_transfer_group, get_pp_group from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, graph_capture) -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping @@ -1727,8 +1727,6 @@ def execute_model( hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), @@ -1913,8 +1911,6 @@ def capture( self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_inputs, **kwargs, ) @@ -1927,8 +1923,6 @@ def capture( output_hidden_or_intermediate_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_inputs, **kwargs, ) @@ -1976,13 +1970,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], **kwargs, ) -> torch.Tensor: - # KV caches are fixed tensors, so we don't need to copy them. - del kv_caches + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata # Copy the input tensors to the input buffers. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 90771e8ac75d..7ddf382079c6 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -476,7 +476,7 @@ def execute_model( # path for warm up runs if not model_input.is_multi_step: return self._base_model_runner.execute_model( - frozen_model_input, kv_caches, intermediate_tensors, num_steps) + frozen_model_input, None, intermediate_tensors, num_steps) # make sure we skip the sampler on the lask rank and only pythonize # if CPU is ahead. @@ -538,7 +538,7 @@ def execute_model( # Execute the model output = self._base_model_runner.execute_model(frozen_model_input, - kv_caches, + None, intermediate_tensors, num_steps=1) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index f7a5ab9de9fa..5035ea20294c 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -346,10 +346,6 @@ def execute_model( input_tokens, "positions": input_positions, - "kv_caches": - kv_caches, - "attn_metadata": - attn_metadata, **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {}, device=self.device), } diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index 4cbe5db44534..cbd5e2060cad 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -91,16 +91,6 @@ def execute_model( else: model_executable = self.model - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(num_layers) - ] - multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, @@ -121,8 +111,6 @@ def execute_model( hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 1051fa1b74c7..bb973f883248 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -15,7 +15,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -334,8 +334,8 @@ def _dummy_run( torch._dynamo.mark_dynamic(p, 0) # Dummy run. with set_forward_context(attn_metadata, self.vllm_config, 0): - self.model(token_ids, position_ids, attn_metadata, input_lens, t, - p, num_samples, kv_caches) + self.model(token_ids, position_ids, input_lens, t, p, num_samples, + kv_caches) def warmup_model( self, @@ -814,8 +814,8 @@ def execute_model( self.vllm_config, model_input.virtual_engine): output_token_ids = self.model(token_ids, position_ids, - attn_metadata, input_lens, t, - p, model_input.num_samples, + input_lens, t, p, + model_input.num_samples, kv_caches) next_token_ids.append(output_token_ids[0]) start_idx = end_idx @@ -871,8 +871,8 @@ def execute_model( self.vllm_config, model_input.virtual_engine): output_token_ids = self.model(token_ids, position_ids, - attn_metadata, input_lens, t, - p, model_input.num_samples, + input_lens, t, p, + model_input.num_samples, kv_caches) self.cached_step_outputs.append(output_token_ids) @@ -949,7 +949,6 @@ def forward( self, token_ids: torch.Tensor, position_ids: torch.Tensor, - attn_metadata: AttentionMetadata, input_lens: torch.Tensor, t: torch.Tensor, p: torch.Tensor, @@ -961,7 +960,6 @@ def forward( Args: token_ids: The input token IDs of shape [batch_size, seq_len]. position_ids: The input position IDs of shape [batch_size, seq_len]. - attn_metadata: The Pallas attention metadata. input_lens: The actual input lengths of shape [batch_size]. t: The sampling temperature of shape [batch_size]. p: The top-p probability of shape [batch_size]. @@ -974,6 +972,7 @@ def forward( start_indicies = torch.arange( batch_size, dtype=torch.int32, device=input_lens.device) * seq_len logits_indices = start_indicies + input_lens - 1 + attn_metadata = get_forward_context().attn_metadata # FIXME(woosuk): This is a temporary hack to avoid using the existing # sampler and sampling metadata. @@ -1005,12 +1004,7 @@ def forward( slot_mapping = slot_mapping.flatten() attn_metadata.slot_mapping = slot_mapping - hidden_states = self.model( - token_ids, - position_ids, - kv_caches, - attn_metadata, - ) + hidden_states = self.model(token_ids, position_ids) hidden_states = hidden_states.flatten(0, 1) logits = self.model.compute_logits(hidden_states, sampling_metadata) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 9c726e1a107e..39957e661c47 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -484,15 +484,6 @@ def profile_run(self) -> None: multi_modal_placeholders=dummy_data.multi_modal_placeholders) seqs.append(seq) - # Run the model with the dummy inputs. - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - ] * num_layers finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) @@ -502,7 +493,7 @@ def profile_run(self) -> None: batch_size=batch_size, dtype=self.model_config.dtype, device=self.device) - self.execute_model(model_input, kv_caches, intermediate_tensors) + self.execute_model(model_input, None, intermediate_tensors) torch.xpu.synchronize() return @@ -581,8 +572,6 @@ def execute_model( hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, From 9d08eccf4c92d5394cb0956c2492c087b3c13726 Mon Sep 17 00:00:00 2001 From: Eli Boyarski Date: Tue, 25 Feb 2025 04:23:04 +0200 Subject: [PATCH 205/317] [Doc] arg_utils.py: fixed a typo (#13785) --- vllm/engine/arg_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8378a116a6d4..663ea1ef8afd 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -382,7 +382,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'https://github.com/noamgat/lm-format-enforcer.' ' Can be overridden per request via guided_decoding_backend' ' parameter.\n' - 'Backend-sepcific options can be supplied in a comma-separated ' + 'Backend-specific options can be supplied in a comma-separated ' 'list following a colon after the backend name. Valid backends and ' 'all available options are: [xgrammar:no-fallback, ' 'outlines:no-fallback, lm-format-enforcer:no-fallback]') From 82eef03d0ae92591882afdc9ec40eb22c5c1fb47 Mon Sep 17 00:00:00 2001 From: cjackal <44624812+cjackal@users.noreply.github.com> Date: Tue, 25 Feb 2025 11:26:12 +0900 Subject: [PATCH 206/317] [Misc] set single whitespace between log sentences (#13771) Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com> --- vllm/attention/backends/flashinfer.py | 2 +- vllm/attention/backends/mla/common.py | 2 +- vllm/attention/backends/rocm_flash_attn.py | 4 ++-- vllm/config.py | 12 ++++++------ .../device_communicators/pynccl_wrapper.py | 4 ++-- .../distributed/kv_transfer/kv_pipe/mooncake_pipe.py | 2 +- vllm/entrypoints/chat_utils.py | 2 +- vllm/entrypoints/llm.py | 2 +- vllm/entrypoints/openai/api_server.py | 2 +- vllm/executor/ray_distributed_executor.py | 2 +- vllm/executor/ray_utils.py | 2 +- vllm/lora/models.py | 2 +- .../compressed_tensors/compressed_tensors_moe.py | 2 +- vllm/model_executor/layers/quantization/gptq.py | 2 +- vllm/model_executor/layers/quantization/modelopt.py | 2 +- .../layers/quantization/neuron_quant.py | 4 ++-- .../layers/quantization/quark/quark_moe.py | 2 +- .../layers/quantization/utils/marlin_utils.py | 2 +- vllm/model_executor/model_loader/loader.py | 2 +- vllm/model_executor/models/deepseek_vl2.py | 2 +- vllm/model_executor/models/fuyu.py | 2 +- vllm/model_executor/models/gritlm.py | 8 ++++---- vllm/model_executor/models/minicpmv.py | 2 +- vllm/model_executor/models/phi3v.py | 2 +- vllm/model_executor/models/prithvi_geospatial_mae.py | 4 ++-- vllm/multimodal/profiling.py | 2 +- vllm/platforms/cuda.py | 2 +- vllm/platforms/openvino.py | 2 +- vllm/platforms/xpu.py | 2 +- vllm/prompt_adapter/models.py | 2 +- vllm/spec_decode/draft_model_runner.py | 2 +- vllm/transformers_utils/configs/jais.py | 8 ++++---- vllm/utils.py | 6 +++--- vllm/v1/worker/gpu_worker.py | 2 +- vllm/worker/openvino_worker.py | 2 +- vllm/worker/worker.py | 4 ++-- 36 files changed, 54 insertions(+), 54 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 715ed6748b84..0556c191ddea 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -438,7 +438,7 @@ def __post_init__(self): not in supported_head_sizes: raise ValueError( f"Only {supported_head_sizes} are supported for head_dim,", - f"received {self.head_dim}.") + f" received {self.head_dim}.") def begin_forward(self): if self.num_prefill_tokens > 0: diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index c3dbbdb86823..f47ea3684e03 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -533,7 +533,7 @@ def __post_init__(self): not in supported_head_sizes: raise ValueError( f"Only {supported_head_sizes} are supported for head_dim,", - f"received {self.head_dim}.") + f" received {self.head_dim}.") @property def prefill_metadata(self) -> Optional["MLACommonMetadata"]: diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 1b1f6ca9beed..3f40686ee2fd 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -497,7 +497,7 @@ def __init__( if logits_soft_cap is not None: raise ValueError( "ROCm Triton FlashAttention does not support attention" - "logits soft capping." + " logits soft capping." " please try using the ROCm CK " "FA backend instead by setting the env var " "`VLLM_USE_TRITON_FLASH_ATTN=0`") @@ -528,7 +528,7 @@ def __init__( if self.use_naive_attn: if logits_soft_cap is not None: raise ValueError( - "ROCm Naive FlashAttention does not support" + "ROCm Naive FlashAttention does not support " "attention logits soft capping.") self.attn_func = _sdpa_attention diff --git a/vllm/config.py b/vllm/config.py index fea673b68560..8e1ce87438af 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -924,8 +924,8 @@ def get_num_layers_by_block_type( layers_block_type_value = getattr(self.hf_config, "layers_block_type", None) if layers_block_type_value is None: - raise ValueError("The model is an hybrid without a" - "layers_block_type in the hf_config," + raise ValueError("The model is an hybrid without a " + "layers_block_type in the hf_config, " "cannot determine the num of " f"{block_type.value} layers") @@ -2516,7 +2516,7 @@ def _get_and_verify_dtype( if current_platform.is_hpu() and config_dtype == torch.float16: logger.info( - "For HPU, we cast models to bfloat16 instead of" + "For HPU, we cast models to bfloat16 instead of " "using float16 by default. Please specify `dtype` if you " "want to use float16.") torch_dtype = torch.bfloat16 @@ -2732,7 +2732,7 @@ def __post_init__(self): backend=self.guided_decoding_backend).backend_name if backend not in valid_guided_backends: raise ValueError(f"Invalid guided_decoding_backend '{backend}," - f"must be one of {valid_guided_backends}") + f" must be one of {valid_guided_backends}") @dataclass @@ -3008,7 +3008,7 @@ def uuid(self): def model_post_init(self, __context: Any) -> None: if not self.enable_reshape and self.enable_fusion: logger.warning_once( - "Fusion enabled but reshape elimination disabled." + "Fusion enabled but reshape elimination disabled. " "RMSNorm + quant (fp8) fusion might not work") pass_config: PassConfig = Field(default_factory=PassConfig) @@ -3563,7 +3563,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): logger.warning( "`torch.compile` is turned on, but the model %s" " does not support it. Please open an issue on GitHub" - "if you want it to be supported.", + " if you want it to be supported.", vllm_config.model_config.model) _current_vllm_config = old_vllm_config diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 03c3b0be7639..4f04899e92e6 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -227,10 +227,10 @@ def __init__(self, so_file: Optional[str] = None): self.lib = NCCLLibrary.path_to_library_cache[so_file] except Exception as e: logger.error( - "Failed to load NCCL library from %s ." + "Failed to load NCCL library from %s. " "It is expected if you are not running on NVIDIA/AMD GPUs." "Otherwise, the nccl library might not exist, be corrupted " - "or it does not support the current platform %s." + "or it does not support the current platform %s. " "If you already have the library, please set the " "environment variable VLLM_NCCL_SO_PATH" " to point to the correct nccl library path.", so_file, diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index 58ab7f0b6424..57a2b0393ba4 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -137,7 +137,7 @@ def initialize(self, local_hostname: str, metadata_server: str, if metadata_backend not in supported_backend: raise ValueError( "Mooncake Configuration error. `metadata_backend`" - f"should be one of {supported_backend}.") + f" should be one of {supported_backend}.") self.engine.initializeExt(local_hostname, metadata_server, protocol, device_name, metadata_backend) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index f04902ae1c76..c50c631dafcc 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -823,7 +823,7 @@ def _parse_chat_message_content_part( # content is empty, log a warning and skip if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content: logger.warning( - "Skipping multimodal part (type: '%s')" + "Skipping multimodal part (type: '%s') " "with empty / unparsable content.", part_type) return None diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index cefb9184b202..3f3262f6e72c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1342,7 +1342,7 @@ def _add_guided_params( return params if params.guided_decoding is not None: - raise ValueError("Cannot set both guided_options_request and" + raise ValueError("Cannot set both guided_options_request and " "params.guided_decoding.") params.guided_decoding = GuidedDecodingParams( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 73061995572b..9995951b3f3d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -575,7 +575,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request): async def do_rerank_v1(request: RerankRequest, raw_request: Request): logger.warning_once( "To indicate that the rerank API is not part of the standard OpenAI" - " API, we have located it at `/rerank`. Please update your client" + " API, we have located it at `/rerank`. Please update your client " "accordingly. (Note: Conforms to JinaAI rerank API)") return await do_rerank(request, raw_request) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index b866413e3a62..cf834fdca426 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -513,7 +513,7 @@ def _check_ray_adag_installation(self): if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: raise ValueError( "cupy is not installed but required since " - "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set." + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. " "Run `pip install ray[adag]` and check cupy installation.") def _compiled_ray_dag(self, enable_asyncio: bool): diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 1734c670bf10..7104004fcfae 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -317,7 +317,7 @@ def initialize_ray_cluster( if parallel_config.world_size > device_bundles: raise ValueError( f"The number of required {device_str}s exceeds the total " - f"number of available {device_str}s in the placement group." + f"number of available {device_str}s in the placement group. " f"Required number of devices: {parallel_config.world_size}. " f"Total number of devices: {device_bundles}.") else: diff --git a/vllm/lora/models.py b/vllm/lora/models.py index eb53513a2830..774c3876e774 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -437,7 +437,7 @@ def _add_adapter(self, lora: LoRAModel): def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" raise NotImplementedError( - "Pinning is not supported in LoRAModelManager." + "Pinning is not supported in LoRAModelManager. " "Use LRUCacheLoRAModelManager for pinning") # type: ignore def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 389359a663cc..a8de36491c5c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -71,7 +71,7 @@ def __init__( if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR and self.input_quant.strategy == QuantizationStrategy.TENSOR): raise ValueError( - "For FP8 Fused MoE layers, only per-tensor scales" + "For FP8 Fused MoE layers, only per-tensor scales " "for weights and activations are supported. Found " f"{self.weight_quant}, {self.input_quant}") diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 09291c2bf1f0..1c8d6cb1ea79 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -74,7 +74,7 @@ def __init__( def __repr__(self) -> str: return (f"GPTQConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size}, " - f"desc_act={self.desc_act})," + f"desc_act={self.desc_act}), " f"lm_head_quantized={self.lm_head_quantized}), " f"dynamic={self.dynamic}") diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 050130de1c0f..36711a7a5098 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -56,7 +56,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": quant_method = quant_config["quant_algo"] is_checkpoint_fp8_serialized = ("FP8" in quant_method) if not is_checkpoint_fp8_serialized: - raise ValueError("ModelOpt currently only supports static FP8" + raise ValueError("ModelOpt currently only supports static FP8 " "quantization in vLLM. Please check the " "`hf_quant_config.json` file for your model's " "quant configuration.") diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py index 82954612fb2a..f6f66803f816 100644 --- a/vllm/model_executor/layers/quantization/neuron_quant.py +++ b/vllm/model_executor/layers/quantization/neuron_quant.py @@ -25,8 +25,8 @@ def __init__( if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: raise ValueError( f"Neuron quantization datatype {self.quant_dtype} is not valid," - f"the quantization datatype should match one of the below types" - f"{SUPPORTED_QUANT_DTYPE_LIST}") + f" the quantization datatype should match one of the below " + f"types {SUPPORTED_QUANT_DTYPE_LIST}") self.dequant_dtype = dequant_dtype self.quantize_method = quantize_method diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 36b08589fd16..18393517a0bf 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -55,7 +55,7 @@ def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str, if not (weight_qscheme == "per_tensor" and input_qscheme == "per_tensor"): raise ValueError( - "For FP8 Fused MoE layers, only per-tensor scales" + "For FP8 Fused MoE layers, only per-tensor scales " "for weights and activations are supported. Found " f"{weight_qscheme}, {input_qscheme}") # noqa E501 diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 05e37251aa16..80416c1bc6eb 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -118,7 +118,7 @@ def verify_marlin_supports_shape(output_size_per_partition: int, and input_size_per_partition % group_size != 0): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." + f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " "with --quantization gptq.") diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index e23c63758556..4e8ef49235ed 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1088,7 +1088,7 @@ def _load_weights(self, model_config: ModelConfig, self.model_type = type(model).__name__ logger.info("Loading weights with BitsAndBytes quantization. " - " May take a while ...") + "May take a while ...") quant_config = getattr(model_config.hf_config, "quantization_config", None) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 4e2dda33bcab..c58b65d49348 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -562,7 +562,7 @@ def _process_image_input( # 3D tensor return list(torch.unbind(image_data, dim=0)) raise ValueError( - "We expect batched 2D tensors;" + "We expect batched 2D tensors; " "this can be either a list of 2D tensors or a single 3D tensor." ) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 4f5519f325e0..7e4cc6bac5e6 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -290,7 +290,7 @@ def _validate_shape(d: torch.Tensor): expected_expr = str(expected_dims) raise ValueError( "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. " + f"per patch is {expected_expr}. " f"You supplied {tuple(d.shape)}.") for d in data: diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index a20328289f92..16223953ff83 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -90,8 +90,8 @@ def _get_instruction_len(self, prompt_token_ids: array) -> int: # Return no instruction in case of missing BOS token. if prompt_token_ids[0] != self.token_ids[""]: - logger.warning("BOS token not found in prompt," - "thus using empty string for instruction." + logger.warning("BOS token not found in prompt, " + "thus using empty string for instruction. " "GritLM requires BOS token in prompt.") return instruction_len @@ -111,8 +111,8 @@ def _get_instruction_len(self, prompt_token_ids: array) -> int: if found_embed_pattern_idx != -1: instruction_len = found_embed_pattern_idx + len(embed_pattern_ids) else: - logger.warning("Query instruction not found in prompt," - "thus using BOS token as instruction instead." + logger.warning("Query instruction not found in prompt, " + "thus using BOS token as instruction instead. " "GritLM requires query instruction in prompt.") instruction_len = 1 diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 46f794e88ad5..2699958331f3 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -673,7 +673,7 @@ def check_mm_inputs(self, inputs: Dict[str, object], for modality, count in counts.items(): if modality not in inputs or not inputs[modality]: raise ValueError(f"None input data of {modality}." - "But prompt requires.") + " But prompt requires.") counter_key = self.get_modality_num_counter(modality) if len(inputs[modality][counter_key]) != count: raise ValueError(f"The prompt requires {count} " diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 61d63e104de4..0f45f131065a 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -639,7 +639,7 @@ def _process_image_input( # 3D tensor return list(torch.unbind(image_data, dim=0)) raise ValueError( - "We expect batched 2D tensors;" + "We expect batched 2D tensors; " "this can be either a list of 2D tensors or a single 3D tensor." ) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 0d0c367e677e..3d95e949e71d 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -153,8 +153,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]) if self.model is None: raise ValueError( - "Unsupported task." - "Only SemanticSegmentationTask is supported for now" + "Unsupported task. " + "Only SemanticSegmentationTask is supported for now " "by PrithviGeospatialMAE.") def _parse_and_validate_multimodal_data( diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 802e40a0c952..093f8b7a8179 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -160,7 +160,7 @@ def get_dummy_data( if mm_counts.keys() != mm_max_tokens_per_item.keys(): raise AssertionError( - "The keys returned by `get_supported_mm_limits`" + "The keys returned by `get_supported_mm_limits` " f"({set(mm_counts.keys())}) should be the same as those " "returned by `get_mm_max_tokens_per_item` " f"({set(mm_max_tokens_per_item.keys())})") diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 5b0731256147..bf425b89132e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -190,7 +190,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, "Cannot use FlashAttention-2 backend for FP8 KV cache.") logger.warning( "Please use FlashInfer backend with FP8 KV Cache for " - "better performance by setting environment variable " + "better performance by setting environment variable " "VLLM_ATTENTION_BACKEND=FLASHINFER") target_backend = _Backend.XFORMERS elif block_size % 16 != 0: diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 41221de0afe5..f385064875ca 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -97,7 +97,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": if not OpenVinoPlatform.is_openvino_cpu(): - logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is" + logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is " "ignored for GPU, f16 data type will be used.") cache_config.cache_dtype = ov.Type.f16 else: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 04af319566af..d99d4ef3dac0 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -73,7 +73,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: logger.warning( "bfloat16 is only supported on Intel Data Center GPU, " "Intel Arc GPU is not supported yet. Your device is %s," - "which is not supported. will fallback to float16", + " which is not supported. will fallback to float16", cls.get_device_name()) model_config.dtype = torch.float16 if not model_config.enforce_eager: diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 3ba7d0896f95..795591606f25 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -226,7 +226,7 @@ def register_module(self, module_name: str, module: nn.Module): def pin_adapter(self, prompt_adapter_id: int) -> bool: """Pin a PromptAdapterModel in the manager cache.""" raise NotImplementedError( - "Pinning is not supported in PromptAdapterModelManager." + "Pinning is not supported in PromptAdapterModelManager. " "Use LRUCachePromptAdapterModelManager for pinning" ) # type: ignore diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 40ecc3481e6b..c54e6abe18d7 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -16,7 +16,7 @@ ROCmFlashAttentionMetadata as FlashAttentionMetadata) except (ModuleNotFoundError, ImportError) as err: raise RuntimeError( - "Draft model speculative decoding currently only supports" + "Draft model speculative decoding currently only supports " "CUDA and ROCm flash attention backend.") from err from vllm.logger import init_logger diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 0cab2c42e579..be0f3b7e5e52 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -212,26 +212,26 @@ def _alibi_scaling_validation(self): if (not isinstance(self.alibi_scaling, dict) or len(self.alibi_scaling) != 2): raise ValueError( - "`alibi_scaling` must be a dictionary with two fields," + "`alibi_scaling` must be a dictionary with two fields, " "`type` and `factor` or `type` and `train_seq_len`, " f"got {self.alibi_scaling}") alibi_scaling_type = self.alibi_scaling.get("type", None) alibi_scaling_factor = self.alibi_scaling.get("factor", None) alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) if alibi_scaling_type is None or alibi_scaling_type != "linear": - raise ValueError(f"`alibi_scaling`'s type field must be 'linear'," + raise ValueError(f"`alibi_scaling`'s type field must be 'linear', " f"got {alibi_scaling_type}") if (alibi_scaling_factor is not None and not isinstance(alibi_scaling_factor, float) or (alibi_scaling_factor is not None and alibi_scaling_factor <= 1.0)): raise ValueError( - f"`alibi_scaling`'s factor field must be a float > 1.0," + f"`alibi_scaling`'s factor field must be a float > 1.0, " f"got {alibi_scaling_factor}") if (alibi_dynamic_scaling is not None and not isinstance(alibi_dynamic_scaling, int) or (alibi_dynamic_scaling is not None and alibi_dynamic_scaling <= 1)): raise ValueError( - f"`alibi_scaling`'s `train_seq_len` field must be an" + f"`alibi_scaling`'s `train_seq_len` field must be an " f"integer > 1, got {alibi_dynamic_scaling}") diff --git a/vllm/utils.py b/vllm/utils.py index 675edc3620b5..29e60a9c9be2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -447,7 +447,7 @@ def get_ip() -> str: logger.warning( "The environment variable HOST_IP is deprecated and ignored, as" " it is often used by Docker and other software to" - "interact with the container's network stack. Please " + " interact with the container's network stack. Please " "use VLLM_HOST_IP instead to set the IP address for vLLM processes" " to communicate with each other.") if host_ip: @@ -2091,8 +2091,8 @@ def set_ulimit(target_soft_limit=65535): (target_soft_limit, current_hard)) except ValueError as e: logger.warning( - "Found ulimit of %s and failed to automatically increase" - "with error %s. This can cause fd limit errors like" + "Found ulimit of %s and failed to automatically increase " + "with error %s. This can cause fd limit errors like " "`OSError: [Errno 24] Too many open files`. Consider " "increasing with ulimit -n", current_soft, e) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d9a415aee528..a14a7082df4b 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -277,5 +277,5 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): raise ValueError( "Bfloat16 is only supported on GPUs with compute capability " f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the" + "You can use float16 instead by explicitly setting the " "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index 0690222d91af..1ad66e6f3be7 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -545,7 +545,7 @@ def model_profile_run(): "value. This may cause low performance due to " "occupying the majority of available system " "memory. Please consider decreasing " - "gpu_memory_utilization or explicitly setting" + "gpu_memory_utilization or explicitly setting " "`VLLM_OPENVINO_KVCACHE_SPACE` (GB) environment " "variable.", memory_utilization) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ff38e3bfc207..5d548bdb59f7 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -525,7 +525,7 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): raise ValueError( "Bfloat16 is only supported on GPUs with compute capability " f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the" + "You can use float16 instead by explicitly setting the " "`dtype` flag in CLI, for example: --dtype=half.") @@ -533,7 +533,7 @@ def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, max_model_len) -> None: if is_attention_free and num_gpu_blocks != 0: raise ValueError("No memory should be allocated for the cache blocks " - f"for an attention-free model, but {num_gpu_blocks}" + f"for an attention-free model, but {num_gpu_blocks} " "blocks are allocated.") if not is_attention_free and num_gpu_blocks <= 0: raise ValueError("No available memory for the cache blocks. " From 0fe54d7d3e74b462485f31bcae8905d6b07bb168 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 24 Feb 2025 21:54:17 -0500 Subject: [PATCH 207/317] [Bugfix][Quantization] Fix FP8 + EP (#13784) Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/fused_moe/layer.py | 30 +++++++++---------- .../layers/quantization/awq_marlin.py | 2 +- .../compressed_tensors_moe.py | 2 +- .../model_executor/layers/quantization/fp8.py | 6 ++-- .../layers/quantization/gptq_marlin.py | 2 +- .../layers/quantization/quark/quark_moe.py | 2 +- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 49400b699cce..452f390f4987 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -260,7 +260,7 @@ class FusedMoE(torch.nn.Module): def __init__( self, - num_experts: int, + num_experts: int, # Global number of experts top_k: int, hidden_size: int, intermediate_size: int, @@ -291,7 +291,8 @@ def __init__( else: self.ep_size = 1 self.top_k = top_k - self.num_experts = num_experts # Global number of experts + self.global_num_experts = num_experts + self.local_num_experts = self.global_num_experts // self.ep_size assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results @@ -308,27 +309,29 @@ def __init__( if self.ep_size > 1: # Create a tensor of size num_experts filled with -1 - self.expert_map = torch.full((self.num_experts, ), + self.expert_map = torch.full((self.global_num_experts, ), -1, dtype=torch.int32) # Create a expert map for the local experts - local_num_experts = num_experts // self.ep_size ep_rank = get_tensor_model_parallel_rank() if ep_rank < (self.ep_size - 1): # Each non-last rank gets local_num_experts experts. - self.expert_map[ep_rank * local_num_experts: - (ep_rank + 1) * local_num_experts] = \ - torch.arange(0, local_num_experts, dtype=torch.int32) + self.expert_map[ep_rank * self.local_num_experts: + (ep_rank + 1) * self.local_num_experts] = \ + torch.arange(0, self.local_num_experts, dtype=torch.int32) else: # All remaining experts are assigned to the last rank. - local_num_experts = num_experts - ep_rank * local_num_experts - self.expert_map[-local_num_experts:] = \ - torch.arange(0, local_num_experts, dtype=torch.int32) + self.local_num_experts = (self.global_num_experts - + ep_rank * self.local_num_experts) + self.expert_map[-self.local_num_experts:] = \ + torch.arange(0, self.local_num_experts, dtype=torch.int32) if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") + # Note: get_quant_method will look at the layer's local_num_experts + # for heuristic purposes, so it must be initialized first. if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( UnquantizedFusedMoEMethod()) @@ -336,11 +339,8 @@ def __init__( self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None - local_num_experts = torch.sum(self.expert_map != -1) \ - if self.expert_map is not None else num_experts - moe_quant_params = { - "num_experts": local_num_experts, + "num_experts": self.local_num_experts, "hidden_size": hidden_size, "intermediate_size_per_partition": self.intermediate_size_per_partition, @@ -647,7 +647,7 @@ def forward(self, hidden_states: torch.Tensor, top_k=self.top_k, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.num_experts, + global_num_experts=self.global_num_experts, expert_map=self.expert_map, topk_group=self.topk_group, num_expert_group=self.num_expert_group, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 0e8c4c7b3ac5..7a2fb203dec3 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -136,7 +136,7 @@ def get_quant_method(self, layer: torch.nn.Module, self.full_config).get_quant_method(layer, prefix) return AWQMarlinLinearMethod(self) elif isinstance(layer, FusedMoE): - if layer.num_experts > 32: + if layer.local_num_experts > 32: # For MoEs with many experts the moe_wna16 kernel is faster return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index a8de36491c5c..e7f08a91e268 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -190,7 +190,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: assert layer.w13_weight_scale is not None shard_size = layer.intermediate_size_per_partition max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.num_experts): + for expert_id in range(layer.local_num_experts): start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9f4cd2aa7378..5e1bec0bb4be 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -573,11 +573,11 @@ def process_weights_after_loading(self, layer: Module) -> None: # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. layer.w13_weight_scale = torch.nn.Parameter(torch.ones( - layer.num_experts, + layer.local_num_experts, dtype=torch.float32, device=w13_weight.device), requires_grad=False) - for expert in range(layer.num_experts): + for expert in range(layer.local_num_experts): w13_weight[expert, :, :], layer.w13_weight_scale[ expert] = ops.scaled_fp8_quant( layer.w13_weight.data[expert, :, :]) @@ -644,7 +644,7 @@ def process_weights_after_loading(self, layer: Module) -> None: assert layer.w13_weight_scale is not None shard_size = layer.intermediate_size_per_partition max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.num_experts): + for expert_id in range(layer.local_num_experts): start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 241fc7d777a6..94a1de71bbca 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -153,7 +153,7 @@ def override_quantization_method(cls, hf_quant_cfg, def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, FusedMoE): - if layer.num_experts > 32: + if layer.local_num_experts > 32: # For MoEs with many experts the moe_wna16 kernel is faster return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 18393517a0bf..32dce5aaf5e0 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -174,7 +174,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: assert layer.w13_weight_scale is not None shard_size = layer.intermediate_size_per_partition max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.num_experts): + for expert_id in range(layer.local_num_experts): start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( From 8c6e1ca82ac50a3429ce7036b0213771f926a675 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Tue, 25 Feb 2025 11:19:30 +0800 Subject: [PATCH 208/317] [Misc][Attention][Quantization] init property earlier (#13733) Signed-off-by: wangxiyuan --- vllm/attention/layer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index bd7783cc3981..24f2a6372b45 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -85,6 +85,11 @@ def __init__( self._k_scale_float = 1.0 self._v_scale_float = 1.0 + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.sliding_window = sliding_window + quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None: @@ -116,10 +121,6 @@ def __init__( alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, **extra_impl_args) - self.num_heads = num_heads - self.head_size = head_size - self.num_kv_heads = num_kv_heads - self.sliding_window = sliding_window self.backend = backend_name_to_enum(attn_backend.get_name()) self.dtype = dtype From e48f1b8f3efb96862c62fa00480969659ffcf4be Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Tue, 25 Feb 2025 04:01:33 +0000 Subject: [PATCH 209/317] [V1][Metrics] Implement vllm:lora_requests_info metric (#13504) --- vllm/v1/engine/output_processor.py | 23 +++++++-- vllm/v1/metrics/loggers.py | 31 +++++++++++- vllm/v1/metrics/stats.py | 77 ++++++++++++++++++++++++++++-- 3 files changed, 121 insertions(+), 10 deletions(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 1438f9d5a7b4..9ae8303df54d 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -11,7 +11,8 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor -from vllm.v1.metrics.stats import IterationStats, RequestStateStats +from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, + RequestStateStats) @dataclass @@ -26,6 +27,7 @@ class RequestState: def __init__( self, request_id: str, + lora_name: Optional[str], output_kind: RequestOutputKind, prompt: Optional[str], prompt_token_ids: List[int], @@ -36,6 +38,7 @@ def __init__( log_stats: bool, ): self.request_id = request_id + self.lora_name = lora_name self.output_kind = output_kind self.prompt = prompt self.prompt_token_ids = prompt_token_ids @@ -58,6 +61,8 @@ def from_new_request( ) -> "RequestState": return cls( request_id=request.request_id, + lora_name=(request.lora_request.name + if request.lora_request is not None else None), output_kind=request.sampling_params.output_kind, prompt=request.prompt, prompt_token_ids=request.prompt_token_ids, @@ -86,6 +91,7 @@ def __init__( self.log_stats = log_stats self.tokenizer = tokenizer self.request_states: Dict[str, RequestState] = {} + self.lora_states = LoRARequestStates() def is_request_active(self, request_id: str) -> bool: return request_id in self.request_states @@ -101,7 +107,9 @@ def abort_requests( request_ids: List[str], ) -> None: for request_id in request_ids: - self.request_states.pop(request_id, None) + req_state = self.request_states.pop(request_id, None) + if req_state is not None: + self.lora_states.abort_request(req_state) def add_request( self, @@ -112,11 +120,13 @@ def add_request( if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") - self.request_states[request_id] = RequestState.from_new_request( + req_state = RequestState.from_new_request( tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), request=request, queue=queue, log_stats=self.log_stats) + self.request_states[request_id] = req_state + self.lora_states.add_request(req_state) def process_outputs( self, @@ -214,6 +224,8 @@ def process_outputs( finish_reason, iteration_stats) + self.lora_states.update_iteration_stats(iteration_stats) + return OutputProcessorOutput( request_outputs=request_outputs, reqs_to_abort=reqs_to_abort, @@ -226,13 +238,15 @@ def _update_stats_from_output(self, req_state: RequestState, if iteration_stats is None: return + lora_stats = self.lora_states.get_stats(req_state) + assert engine_core_timestamp is not None assert req_state.stats is not None iteration_stats.update_from_output(engine_core_output, engine_core_timestamp, req_state.is_prefilling, req_state.prompt_len, - req_state.stats) + req_state.stats, lora_stats) def _update_stats_from_finished(self, req_state: RequestState, request_output: RequestOutput, @@ -246,6 +260,7 @@ def _update_stats_from_finished(self, req_state: RequestState, iteration_stats.update_from_finished_request(finish_reason, request_output, req_state.stats) + self.lora_states.finish_request(req_state) @staticmethod def _make_request_output( diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index e562b4145afc..2c17da0ebc83 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -2,7 +2,7 @@ import time from abc import ABC, abstractmethod -from typing import Dict, List +from typing import Dict, List, Optional import numpy as np import prometheus_client @@ -233,6 +233,22 @@ def __init__(self, vllm_config: VllmConfig): buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) + self.gauge_lora_info: Optional[prometheus_client.Gauge] = None + if vllm_config.lora_config is not None: + self.labelname_max_lora = "max_lora" + self.labelname_waiting_lora_adapters = "waiting_lora_adapters" + self.labelname_running_lora_adapters = "running_lora_adapters" + self.max_lora = vllm_config.lora_config.max_loras + self.gauge_lora_info = \ + prometheus_client.Gauge( + name="vllm:lora_requests_info", + documentation="Running stats on lora requests.", + labelnames=[ + self.labelname_max_lora, + self.labelname_waiting_lora_adapters, + self.labelname_running_lora_adapters, + ]) + self.log_metrics_info("cache_config", vllm_config.cache_config) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): @@ -295,6 +311,19 @@ def log(self, scheduler_stats: SchedulerStats, for prefill_time in iteration_stats.prefill_times_iter: self.histogram_prefill_time_request.observe(prefill_time) + if self.gauge_lora_info is not None: + running_lora_adapters = \ + ",".join(iteration_stats.running_lora_adapters.keys()) + waiting_lora_adapters = \ + ",".join(iteration_stats.waiting_lora_adapters.keys()) + lora_info_labels = { + self.labelname_running_lora_adapters: running_lora_adapters, + self.labelname_waiting_lora_adapters: waiting_lora_adapters, + self.labelname_max_lora: self.max_lora, + } + self.gauge_lora_info.labels(**lora_info_labels)\ + .set_to_current_time() + @staticmethod def _unregister_vllm_metrics(): # Unregister any existing vLLM collectors (for CI/CD diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index a0e6204929eb..74d4a1bc4fba 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -2,11 +2,12 @@ import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Dict, List, Optional, Set if TYPE_CHECKING: from vllm.outputs import RequestOutput from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason + from vllm.v1.output_processor import RequestState @dataclass @@ -36,6 +37,12 @@ class SchedulerStats: default_factory=PrefixCacheStats) +@dataclass +class LoRAStats: + waiting_requests: Set[str] = field(default_factory=set) + running_requests: Set[str] = field(default_factory=set) + + @dataclass class RequestStateStats: """Stats that need to be tracked across delta updates.""" @@ -76,6 +83,8 @@ def __init__(self): self.time_per_output_tokens_iter: List[float] = [] self.queue_times_iter: List[float] = [] self.prefill_times_iter: List[float] = [] + self.waiting_lora_adapters: Dict[str, int] = {} + self.running_lora_adapters: Dict[str, int] = {} def _time_since(self, start: float) -> float: """Calculate an interval relative to this iteration's timestamp.""" @@ -83,7 +92,8 @@ def _time_since(self, start: float) -> float: def update_from_output(self, output: "EngineCoreOutput", engine_core_timestamp: float, is_prefilling: bool, - prompt_len: int, req_stats: RequestStateStats): + prompt_len: int, req_stats: RequestStateStats, + lora_stats: Optional[LoRAStats]): num_new_generation_tokens = len(output.new_token_ids) self.num_generation_tokens += num_new_generation_tokens @@ -105,7 +115,8 @@ def update_from_output(self, output: "EngineCoreOutput", # Process request-level engine core events if output.events is not None: - self.update_from_events(output.events, is_prefilling, req_stats) + self.update_from_events(output.request_id, output.events, + is_prefilling, req_stats, lora_stats) # Process the batch-level "new tokens" engine core event if is_prefilling: @@ -123,17 +134,21 @@ def update_from_output(self, output: "EngineCoreOutput", if num_new_generation_tokens > 0: req_stats.last_token_ts = engine_core_timestamp - def update_from_events(self, events: List["EngineCoreEvent"], - is_prefilling: bool, req_stats: RequestStateStats): + def update_from_events(self, req_id: str, events: List["EngineCoreEvent"], + is_prefilling: bool, req_stats: RequestStateStats, + lora_stats: Optional[LoRAStats]): # Avoid circular dependency from vllm.v1.engine import EngineCoreEventType for event in events: if event.type == EngineCoreEventType.QUEUED: req_stats.queued_ts = event.timestamp + if lora_stats is not None: + lora_stats.waiting_requests.add(req_id) elif event.type == EngineCoreEventType.SCHEDULED: queued_interval = event.timestamp - req_stats.queued_ts self.queue_times_iter.append(queued_interval) req_stats.scheduled_ts = event.timestamp + LoRARequestStates.scheduled_request(lora_stats, req_id) def update_from_finished_request(self, finish_reason: "FinishReason", request_output: "RequestOutput", @@ -151,3 +166,55 @@ def update_from_finished_request(self, finish_reason: "FinishReason", inference_time=inference_time, decode_time=decode_time) self.finished_requests.append(finished_req) + + +class LoRARequestStates: + """Per-LoRA request state stats.""" + + def __init__(self): + self.lora_name_to_stats: Dict[str, LoRAStats] = {} + + def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]: + if req_state.lora_name is None: + return None + if req_state.lora_name not in self.lora_name_to_stats: + self.lora_name_to_stats[req_state.lora_name] = LoRAStats() + return self.lora_name_to_stats[req_state.lora_name] + + def add_request(self, req_state: 'RequestState'): + if (lora_stats := self.get_stats(req_state)) is not None: + lora_stats.waiting_requests.add(req_state.request_id) + + def finish_request(self, req_state: 'RequestState'): + if req_state.lora_name is None: + return + lora_stats = self.lora_name_to_stats[req_state.lora_name] + lora_stats.running_requests.remove(req_state.request_id) + + def abort_request(self, req_state: 'RequestState'): + if req_state.lora_name is None: + return + lora_stats = self.lora_name_to_stats[req_state.lora_name] + lora_stats.waiting_requests.discard(req_state.request_id) + lora_stats.running_requests.discard(req_state.request_id) + + # Break the pattern for this lifecycle methods so we can + # call this from IterationStats.update_from_events() + @staticmethod + def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str): + if lora_stats is None: + return + lora_stats.waiting_requests.remove(request_id) + lora_stats.running_requests.add(request_id) + + def update_iteration_stats(self, + iteration_stats: Optional[IterationStats]): + if iteration_stats is None: + return + for lora_name, stats in self.lora_name_to_stats.items(): + if stats.waiting_requests: + iteration_stats.waiting_lora_adapters[lora_name] = \ + len(stats.waiting_requests) + if stats.running_requests: + iteration_stats.running_lora_adapters[lora_name] = \ + len(stats.running_requests) From 6ff6dd0b4f0394621f6cab6b0716f5b85ac20e62 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 23:33:59 -0500 Subject: [PATCH 210/317] [Bugfix] Fix deepseek-v2 error: "missing 1 required positional argument: 'residual'" (#13802) --- vllm/model_executor/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 22b2bf7ca469..79484cee167d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -614,7 +614,7 @@ def forward( residual = intermediate_tensors["residual"] for layer in self.layers[self.start_layer:self.end_layer]: - hidden_states, residual = layer(positions, hidden_states) + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ From 3660d88dc0d1322080d3e3397781a576806d0c15 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 25 Feb 2025 01:10:31 -0500 Subject: [PATCH 211/317] [Bugfix] Support MLA for CompressedTensorsWNA16 (#13725) Signed-off-by: mgoin --- vllm/attention/backends/mla/common.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index f47ea3684e03..4dd562be3838 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1130,13 +1130,13 @@ def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ ) def get_layer_weight(layer): - if hasattr(layer, "weight"): - return layer.weight - elif hasattr(layer, "qweight"): - return layer.qweight - else: - raise AttributeError( - f"Layer '{layer}' has neither weight nor qweight") + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") def get_and_maybe_dequant_weights(layer: LinearBase): if not isinstance(layer.quant_method, UnquantizedLinearMethod): From 9127c8e899740a714829ca14e18ac5a808891d12 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 25 Feb 2025 03:17:14 -0500 Subject: [PATCH 212/317] Fix CompressedTensorsWNA16MoE with grouped scales (#13769) --- .../quantization/compressed_tensors/compressed_tensors_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index e7f08a91e268..f1f316f08339 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -527,7 +527,8 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, replace_tensor("w13_weight_scale", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( layer.w2_weight_scale, - layer.w2_weight_scale.shape[1] * self.packed_factor, + layer.w2_weight_scale.shape[1] * + (self.group_size if self.group_size != -1 else self.packed_factor), size_k2, self.group_size, self.num_bits, From 36646a8708742c3b68235382770e392236bece25 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 25 Feb 2025 13:48:02 +0530 Subject: [PATCH 213/317] [Core] LoRA V1 - Add add/pin/list/remove_lora functions (#13705) --- tests/lora/test_add_lora.py | 13 +- tests/lora/test_lora_functions.py | 137 ++++++++++++++++++++++ vllm/v1/engine/async_llm.py | 18 ++- vllm/v1/engine/core.py | 15 ++- vllm/v1/engine/core_client.py | 63 ++++++++-- vllm/v1/engine/llm_engine.py | 18 ++- vllm/v1/worker/gpu_worker.py | 11 +- vllm/v1/worker/lora_model_runner_mixin.py | 17 ++- 8 files changed, 270 insertions(+), 22 deletions(-) create mode 100644 tests/lora/test_lora_functions.py diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index 2b421bfd9eb8..70b058b201d6 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -7,6 +7,7 @@ import pytest from huggingface_hub import snapshot_download +import vllm.envs as env from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import TextPrompt from vllm.lora.request import LoRARequest @@ -144,10 +145,14 @@ async def test_add_lora(): await requests_processing_time(llm, dummy_run_requests) # Run with warmup - for lr in warmup_run_requests: - await llm.add_lora(lr) - # Wait for the add_lora function to complete on the server side. - await asyncio.sleep(30) + add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests] + add_lora_results = await asyncio.gather(*add_lora_tasks) + if env.VLLM_USE_V1: + # Test that all all_lora calls are successful. + assert all(add_lora_results) + else: + # No way to check V0 engine results as the calls just return None. + pass time_with_add_lora = await requests_processing_time( llm, warmup_run_requests) diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py new file mode 100644 index 000000000000..1309848868b4 --- /dev/null +++ b/tests/lora/test_lora_functions.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Script to test add_lora, remove_lora, pin_lora, list_loras functions. +""" + +import os +from typing import List + +import pytest + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.llm import LLM +from vllm.lora.request import LoRARequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" +LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test" +LORA_RANK = 8 + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +def make_lora_request(lora_id: int): + return LoRARequest(lora_name=f"{lora_id}", + lora_int_id=lora_id, + lora_path=LORA_MODULE_PATH) + + +def test_lora_functions_sync(): + + max_loras = 4 + # Create engine in eager-mode. Due to high max_loras, the CI can + # OOM during cuda-graph capture. + engine_args = EngineArgs(model=MODEL_PATH, + enable_lora=True, + max_loras=max_loras, + max_lora_rank=LORA_RANK, + max_model_len=128, + gpu_memory_utilization=0.8, + enforce_eager=True) + + llm = LLM.get_engine_class().from_engine_args(engine_args) + + def run_check(fn, args, expected: List): + fn(args) + assert set(llm.list_loras()) == set(expected) + + run_check(llm.add_lora, make_lora_request(1), [1]) + run_check(llm.add_lora, make_lora_request(2), [1, 2]) + + # Pin LoRA 1 and test that it is never removed on subsequent adds. + run_check(llm.pin_lora, 1, [1, 2]) + run_check(llm.add_lora, make_lora_request(3), [1, 2, 3]) + run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4]) + run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4]) + run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4]) + run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7]) + run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7]) + run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7]) + run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10]) + + # Remove LoRA 1 and continue adding. + run_check(llm.remove_lora, 1, [8, 9, 10]) + run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11]) + run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11]) + run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11]) + + # Remove all LoRAs + run_check(llm.remove_lora, 13, [12, 10, 11]) + run_check(llm.remove_lora, 12, [10, 11]) + run_check(llm.remove_lora, 11, [10]) + run_check(llm.remove_lora, 10, []) + + +@pytest.mark.asyncio +async def test_lora_functions_async(): + + if os.getenv("VLLM_USE_V1") == "0": + pytest.skip( + reason= + "V0 AsyncLLMEngine does not expose remove/list/pin LoRA functions") + + # The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1` + # environment variable. reload vllm.enging.async_llm_engine as + # vllm.engine.async_llm_engine.AsyncLLMEgnine changes depending on the + # env var. + import importlib + + import vllm.engine.async_llm_engine + importlib.reload(vllm.engine.async_llm_engine) + from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) + + max_loras = 4 + engine_args = AsyncEngineArgs(model=MODEL_PATH, + enable_lora=True, + max_loras=max_loras, + max_lora_rank=LORA_RANK, + max_model_len=128, + gpu_memory_utilization=0.8, + enforce_eager=True) + + async def run_check(fn, args, expected: List): + await fn(args) + assert set(await llm.list_loras()) == set(expected) + + async with build_async_engine_client_from_engine_args(engine_args) as llm: + await run_check(llm.add_lora, make_lora_request(1), [1]) + await run_check(llm.add_lora, make_lora_request(2), [1, 2]) + + # Pin LoRA 1 and test that it is never removed on subsequent adds. + await run_check(llm.pin_lora, 1, [1, 2]) + await run_check(llm.add_lora, make_lora_request(3), [1, 2, 3]) + await run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4]) + await run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4]) + await run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4]) + await run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7]) + await run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7]) + await run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7]) + await run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10]) + + # Remove LoRA 1 and continue adding. + await run_check(llm.remove_lora, 1, [8, 9, 10]) + await run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11]) + await run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11]) + await run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11]) + + # Remove all LoRAs + await run_check(llm.remove_lora, 13, [12, 10, 11]) + await run_check(llm.remove_lora, 12, [10, 11]) + await run_check(llm.remove_lora, 11, [10]) + await run_check(llm.remove_lora, 10, []) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 36a02628f405..0c04e14cec2f 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -2,7 +2,7 @@ import asyncio import os -from typing import AsyncGenerator, List, Mapping, Optional, Type, Union +from typing import AsyncGenerator, List, Mapping, Optional, Set, Type, Union import numpy as np @@ -392,9 +392,21 @@ async def sleep(self, level: int = 1) -> None: async def wake_up(self) -> None: await self.engine_core.wake_up_async() - async def add_lora(self, lora_request: LoRARequest) -> None: + async def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" - await self.engine_core.add_lora_async(lora_request) + return await self.engine_core.add_lora_async(lora_request) + + async def remove_lora(self, lora_id: int) -> bool: + """Remove an already loaded LoRA adapter.""" + return await self.engine_core.remove_lora_async(lora_id) + + async def list_loras(self) -> Set[int]: + """List all registered adapters.""" + return await self.engine_core.list_loras_async() + + async def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted.""" + return await self.engine_core.pin_lora_async(lora_id) @property def is_running(self) -> bool: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 85c97293af8b..041896f1c7cc 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -7,7 +7,7 @@ from concurrent.futures import Future from inspect import isclass, signature from multiprocessing.connection import Connection -from typing import Any, List, Optional, Tuple, Type +from typing import Any, List, Optional, Set, Tuple, Type import msgspec import psutil @@ -222,8 +222,17 @@ def wake_up(self): def execute_dummy_batch(self): self.model_executor.collective_rpc("execute_dummy_batch") - def add_lora(self, lora_request: LoRARequest) -> None: - self.model_executor.add_lora(lora_request) + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_executor.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_executor.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_executor.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_executor.pin_lora(lora_id) class EngineCoreProc(EngineCore): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 5ffaf63e6cec..9f36e11d12d7 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -10,7 +10,7 @@ from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Set, Type, Union import zmq import zmq.asyncio @@ -97,7 +97,16 @@ async def execute_dummy_batch_async(self) -> None: def abort_requests(self, request_ids: List[str]) -> None: raise NotImplementedError - def add_lora(self, lora_request: LoRARequest) -> None: + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + def list_loras(self) -> Set[int]: + raise NotImplementedError + + def pin_lora(self, lora_id: int) -> bool: raise NotImplementedError async def get_output_async(self) -> EngineCoreOutputs: @@ -121,7 +130,16 @@ async def wake_up_async(self) -> None: async def abort_requests_async(self, request_ids: List[str]) -> None: raise NotImplementedError - async def add_lora_async(self, lora_request: LoRARequest) -> None: + async def add_lora_async(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + async def remove_lora_async(self, lora_id: int) -> bool: + raise NotImplementedError + + async def list_loras_async(self) -> Set[int]: + raise NotImplementedError + + async def pin_lora_async(self, lora_id: int) -> bool: raise NotImplementedError @@ -166,8 +184,17 @@ def wake_up(self) -> None: def execute_dummy_batch(self) -> None: self.engine_core.execute_dummy_batch() - def add_lora(self, lora_request: LoRARequest) -> None: - self.engine_core.add_lora(lora_request) + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.engine_core.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.engine_core.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.engine_core.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.engine_core.pin_lora(lora_id) @dataclass @@ -356,8 +383,17 @@ def profile(self, is_start: bool = True) -> None: def reset_prefix_cache(self) -> None: self._call_utility("reset_prefix_cache") - def add_lora(self, lora_request: LoRARequest) -> None: - self._call_utility("add_lora", lora_request) + def add_lora(self, lora_request: LoRARequest) -> bool: + return self._call_utility("add_lora", lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self._call_utility("remove_lora", lora_id) + + def list_loras(self) -> Set[int]: + return self._call_utility("list_loras") + + def pin_lora(self, lora_id: int) -> bool: + return self._call_utility("pin_lora", lora_id) def sleep(self, level: int = 1) -> None: self._call_utility("sleep", level) @@ -454,5 +490,14 @@ async def wake_up_async(self) -> None: async def execute_dummy_batch_async(self) -> None: await self._call_utility_async("execute_dummy_batch") - async def add_lora_async(self, lora_request: LoRARequest) -> None: - await self._call_utility_async("add_lora", lora_request) + async def add_lora_async(self, lora_request: LoRARequest) -> bool: + return await self._call_utility_async("add_lora", lora_request) + + async def remove_lora_async(self, lora_id: int) -> bool: + return await self._call_utility_async("remove_lora", lora_id) + + async def list_loras_async(self) -> Set[int]: + return await self._call_utility_async("list_loras") + + async def pin_lora_async(self, lora_id: int) -> bool: + return await self._call_utility_async("pin_lora", lora_id) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 64fd8719c82e..ccf52250c1d6 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Mapping, Optional, Type, Union +from typing import Dict, List, Mapping, Optional, Set, Type, Union from typing_extensions import TypeVar @@ -254,3 +254,19 @@ def get_tokenizer_group( f"found type: {type(tokenizer_group)}") return tokenizer_group + + def add_lora(self, lora_request: LoRARequest) -> bool: + """Load a new LoRA adapter into the engine for future requests.""" + return self.engine_core.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + """Remove an already loaded LoRA adapter.""" + return self.engine_core.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + """List all registered adapters.""" + return self.engine_core.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted.""" + return self.engine_core.pin_lora(lora_id) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index a14a7082df4b..f681925f557e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -2,7 +2,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Set import torch import torch.distributed @@ -240,6 +240,15 @@ def execute_dummy_batch(self) -> None: def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + def check_health(self) -> None: # worker will always be healthy as long as it's running. return diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 053897da0aa7..731e758e6e74 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -131,4 +131,19 @@ def maybe_profile_with_lora(self, lora_config: LoRAConfig, def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_adapter(lora_request) \ No newline at end of file + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() \ No newline at end of file From 4d5b2e3bbdb324a8beb33d6e7bac4711f2ce2fbe Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 25 Feb 2025 16:18:19 +0800 Subject: [PATCH 214/317] [Misc] Check that the model can be inspected upon registration (#13743) --- vllm/model_executor/models/registry.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 81623defd337..bae6444267fa 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -347,6 +347,10 @@ def register_model( when importing the model and thus the related error :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`. """ + if not isinstance(model_arch, str): + msg = f"`model_arch` should be a string, not a {type(model_arch)}" + raise TypeError(msg) + if model_arch in self.models: logger.warning( "Model architecture %s is already registered, and will be " @@ -360,8 +364,18 @@ def register_model( raise ValueError(msg) model = _LazyRegisteredModel(*split_str) - else: + + try: + model.inspect_model_cls() + except Exception as exc: + msg = f"Unable to inspect model {model_cls}" + raise RuntimeError(msg) from exc + elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module): model = _RegisteredModel.from_model_cls(model_cls) + else: + msg = ("`model_cls` should be a string or PyTorch model class, " + f"not a {type(model_arch)}") + raise TypeError(msg) self.models[model_arch] = model From 53fd4859eba297c7679dc222e09700ed64919100 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 03:21:25 -0500 Subject: [PATCH 215/317] [Core] xgrammar: Expand list of unsupported jsonschema keywords (#13783) Signed-off-by: Russell Bryant --- vllm/model_executor/guided_decoding/utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/model_executor/guided_decoding/utils.py b/vllm/model_executor/guided_decoding/utils.py index c3c0378ea952..10981776e768 100644 --- a/vllm/model_executor/guided_decoding/utils.py +++ b/vllm/model_executor/guided_decoding/utils.py @@ -33,6 +33,18 @@ def check_object(obj: dict) -> bool: ]): return True + # Unsupported keywords for strings + if obj.get("type") == "string" and any( + key in obj for key in ["minLength", "maxLength", "format"]): + return True + + # Unsupported keywords for objects + if obj.get("type") == "object" and any(key in obj for key in [ + "minProperties", "maxProperties", "propertyNames", + "patternProperties" + ]): + return True + # Recursively check all nested objects and arrays for value in obj.values(): if isinstance(value, dict): From a6a99fe606e35981ba4283a1abe5553badcf51f7 Mon Sep 17 00:00:00 2001 From: Shanshan Shen <467638484@qq.com> Date: Tue, 25 Feb 2025 16:36:07 +0800 Subject: [PATCH 216/317] [Bugfix] Modify modelscope api usage in transformer_utils (#13807) --- vllm/transformers_utils/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index d0b5d7f01a99..87e446f89438 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -29,9 +29,8 @@ def modelscope_list_repo_files( ) -> List[str]: """List files in a modelscope repo.""" from modelscope.hub.api import HubApi - from modelscope.utils.hf_util import _try_login - _try_login(token) api = HubApi() + api.login(token) # same as huggingface_hub.list_repo_files files = [ file['Path'] for file in api.get_model_files( From 77c117b3d0f6d369d5900324de14e849bbed1db9 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Tue, 25 Feb 2025 00:37:08 -0800 Subject: [PATCH 217/317] [misc] Clean up ray compiled graph type hints (#13731) --- vllm/executor/ray_distributed_executor.py | 16 ++++++++++++---- vllm/executor/ray_utils.py | 7 +++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index cf834fdca426..673d0fc5d23e 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -528,10 +528,18 @@ def _compiled_ray_dag(self, enable_asyncio: bool): envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) with InputNode() as input_data: # Example DAG: PP=2, TP=4 - # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501 - # -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501 - # -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501 - # -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501 + # + # For V0: + # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput # noqa: E501 + # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput # noqa: E501 + # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput # noqa: E501 + # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput # noqa: E501 + # + # For V1: + # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput # noqa: E501 + # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput # noqa: E501 + # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput # noqa: E501 + # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput # noqa: E501 # All workers in the first TP group will take in the # ExecuteModelRequest as input. diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 7104004fcfae..a9661fe0ef16 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -114,8 +114,11 @@ def setup_device_if_necessary(self): def execute_model_ray( self, - scheduler_output: "SchedulerOutput", - ) -> "ModelRunnerOutput": + scheduler_output: Union["SchedulerOutput", + Tuple["SchedulerOutput", + "IntermediateTensors"]], + ) -> Union["ModelRunnerOutput", Tuple["SchedulerOutput", + "IntermediateTensors"]]: # this method is used to compile ray CG, # and it needs a special logic of self.setup_device_if_necessary() self.setup_device_if_necessary() From a5f567484a1814314993c5e78f5fe3132e71b485 Mon Sep 17 00:00:00 2001 From: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com> Date: Tue, 25 Feb 2025 02:38:42 -0600 Subject: [PATCH 218/317] [Feature] Support KV cache offloading and disagg prefill with LMCache connector. (#12953) --- .../offline_inference/cpu_offload_lmcache.py | 65 +++++++++ .../disaggregated_prefill_lmcache.py | 130 ++++++++++++++++++ .../kv_transfer/kv_connector/factory.py | 5 + .../kv_connector/lmcache_connector.py | 108 +++++++++++++++ vllm/distributed/parallel_state.py | 4 +- 5 files changed, 310 insertions(+), 2 deletions(-) create mode 100644 examples/offline_inference/cpu_offload_lmcache.py create mode 100644 examples/offline_inference/disaggregated_prefill_lmcache.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py diff --git a/examples/offline_inference/cpu_offload_lmcache.py b/examples/offline_inference/cpu_offload_lmcache.py new file mode 100644 index 000000000000..8211629b24ec --- /dev/null +++ b/examples/offline_inference/cpu_offload_lmcache.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of cpu offloading +with LMCache. + +Note that `pip install lmcache` is needed to run this example. +Learn more about LMCache in https://github.com/LMCache/LMCache. +""" +import os +import time + +from lmcache.experimental.cache_engine import LMCacheEngineBuilder +from lmcache.integration.vllm.utils import ENGINE_NAME + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# LMCache-related environment variables +# Use experimental features in LMCache +os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" +# LMCache is set to use 256 tokens per chunk +os.environ["LMCACHE_CHUNK_SIZE"] = "256" +# Enable local CPU backend in LMCache +os.environ["LMCACHE_LOCAL_CPU"] = "True" +# Set local CPU memory limit to 5.0 GB +os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" + +# This example script runs two requests with a shared prefix. +shared_prompt = "Hello, how are you?" * 1000 +first_prompt = [ + shared_prompt + "Hello, my name is", +] +second_prompt = [ + shared_prompt + "Tell me a very long story", +] + +sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + +ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}') +# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB +# memory. Reduce the value if your GPU has less memory. +# Note that LMCache is not compatible with chunked prefill for now. +llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + enable_chunked_prefill=False, + gpu_memory_utilization=0.8) + +outputs = llm.generate(first_prompt, sampling_params) +for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") +print("First request done.") + +time.sleep(1) + +outputs = llm.generate(second_prompt, sampling_params) +for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") +print("Second request done.") + +# Clean up lmcache backend +LMCacheEngineBuilder.destroy(ENGINE_NAME) diff --git a/examples/offline_inference/disaggregated_prefill_lmcache.py b/examples/offline_inference/disaggregated_prefill_lmcache.py new file mode 100644 index 000000000000..36d343c6812e --- /dev/null +++ b/examples/offline_inference/disaggregated_prefill_lmcache.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of disaggregated prefilling +with LMCache. +We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode), +and launch an additional LMCache server. +KV cache is transferred in the following manner: +VLLM prefill node -> LMCache server -> VLLM decode node. + +Note that `pip install lmcache` is needed to run this example. +Learn more about LMCache in https://github.com/LMCache/LMCache. +""" +import os +import subprocess +import time +from multiprocessing import Event, Process + +from lmcache.experimental.cache_engine import LMCacheEngineBuilder +from lmcache.integration.vllm.utils import ENGINE_NAME + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# LMCache-related environment variables +# The port to start LMCache server +port = 8100 +# Use experimental features in LMCache +os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" +# LMCache is set to use 256 tokens per chunk +os.environ["LMCACHE_CHUNK_SIZE"] = "256" +# Disable local CPU backend in LMCache +os.environ["LMCACHE_LOCAL_CPU"] = "False" +# Set local CPU memory buffer limit to 5.0 GB +os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" +# Set the remote URL for LMCache server +os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}" +# Set the serializer/deserializer between vllm and LMCache server +# `naive` indicates using raw bytes of the tensor without any compression +os.environ["LMCACHE_REMOTE_SERDE"] = "naive" + + +def run_prefill(prefill_done, prompts): + # We use GPU 0 for prefill node. + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' + ) + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # memory. Reduce the value if your GPU has less memory. + llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True) + + #llm.generate(prompts, sampling_params) + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + print("Prefill node is finished.") + prefill_done.set() + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_decode(prefill_done, prompts, timeout=1): + # We use GPU 1 for decode node. + os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' + ) + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # of memory. Reduce the value if your GPU has less memory. + llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True) + + print("Waiting for prefill node to finish...") + prefill_done.wait() + time.sleep(timeout) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_lmcache_server(port): + server_proc = subprocess.Popen([ + "python", "-m", "lmcache.experimental.server", "localhost", + str(port) + ]) + return server_proc + + +if __name__ == "__main__": + + prompts = [ + "Hello, how are you?" * 1000, + ] + + prefill_done = Event() + prefill_process = Process(target=run_prefill, args=(prefill_done, prompts)) + decode_process = Process(target=run_decode, args=(prefill_done, prompts)) + lmcache_server_process = run_lmcache_server(port) + + # Start prefill node + prefill_process.start() + + # Start decode node + decode_process.start() + + # Clean up the processes + decode_process.join() + prefill_process.terminate() + lmcache_server_process.terminate() + lmcache_server_process.wait() diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index fe480533458b..7336c54ec8a3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -48,3 +48,8 @@ def create_connector(cls, rank: int, local_rank: int, "MooncakeConnector", "vllm.distributed.kv_transfer.kv_connector.simple_connector", "SimpleConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnector", + "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", + "LMCacheConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py new file mode 100644 index 000000000000..bf9117133af5 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +LMCache KV Cache Connector for Distributed Machine Learning Inference + +The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker +(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache; +(2) offload and share KV caches. +""" + +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class LMCacheConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.transfer_config = config.kv_transfer_config + self.vllm_config = config + + from lmcache.experimental.cache_engine import LMCacheEngineBuilder + from lmcache.integration.vllm.utils import ENGINE_NAME + from lmcache.integration.vllm.vllm_adapter import ( + RetrieveStatus, StoreStatus, init_lmcache_engine, + lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv) + logger.info("Initializing LMCacheConfig under kv_transfer_config %s", + self.transfer_config) + + # TODO (Jiayi): Find model_config, parallel_config, and cache_config + self.engine = init_lmcache_engine(config.model_config, + config.parallel_config, + config.cache_config) + self.lmcache_engine_name = ENGINE_NAME + self.lmcache_engine_builder = LMCacheEngineBuilder + + self.model_config = config.model_config + self.parallel_config = config.parallel_config + self.cache_config = config.cache_config + self.lmcache_retrieve_kv = lmcache_retrieve_kv + self.lmcache_store_kv = lmcache_store_kv + self.lmcache_should_store = lmcache_should_store + self.store_status = StoreStatus + self.retrieve_status = RetrieveStatus + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + hidden_or_intermediate_states = None + + # TODO (Jiayi): Need to support chunked prefill + retrieve_status = self.retrieve_status.PREFILL + + model_input, bypass_model_exec = self.lmcache_retrieve_kv( + model_executable, model_input, self.cache_config, kv_caches, + retrieve_status) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + num_reqs = 0 + seq_group_list = model_input.sampling_metadata.seq_groups + assert seq_group_list is not None + for seq_group in seq_group_list: + seq_ids = seq_group.seq_ids + for seq_id in seq_ids: + num_reqs += 1 + + # TODO (Jiayi): Only normal prefill is supported for now + store_status = self.lmcache_should_store(model_input) + self.lmcache_store_kv( + self.model_config, + self.parallel_config, + self.cache_config, + model_executable, + model_input, + kv_caches, + store_status, + ) + + def close(self): + self.lmcache_engine_builder.destroy(self.lmcache_engine_name) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 83484cd73550..86166dd5bb83 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -962,8 +962,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: return if all([ - vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER - is None + vllm_config.kv_transfer_config.is_kv_transfer_instance, + _KV_TRANSFER is None ]): _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, From d8c31f33e226e981d35c86ae7375b0ccf185cc6c Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Tue, 25 Feb 2025 03:39:59 -0500 Subject: [PATCH 219/317] [ROCm][Quantization][Kernel] Using HIP FP8 header (#12593) --- CMakeLists.txt | 19 + csrc/quantization/fp8/amd/hip_float8.h | 137 ------- csrc/quantization/fp8/amd/hip_float8_impl.h | 315 ---------------- csrc/quantization/fp8/amd/quant_utils.cuh | 398 +++++++++++--------- csrc/quantization/fp8/common.cuh | 8 +- tests/kernels/test_cache.py | 24 +- 6 files changed, 267 insertions(+), 634 deletions(-) delete mode 100644 csrc/quantization/fp8/amd/hip_float8.h delete mode 100644 csrc/quantization/fp8/amd/hip_float8_impl.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b569ec25f12..82ad7b8819d5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -174,6 +174,25 @@ include(FetchContent) file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") +# +# Set rocm version dev int. +# +if(VLLM_GPU_LANG STREQUAL "HIP") + # + # Overriding the default -O set up by cmake, adding ggdb3 for the most verbose devug info + # + set(CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG "${CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG} -O0 -ggdb3") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb3") + + + # + # Certain HIP functions are marked as [[nodiscard]], yet vllm ignores the result which generates + # a lot of warnings that always mask real issues. Suppressing until this is properly addressed. + # + set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result") +endif() + # # Define other extension targets # diff --git a/csrc/quantization/fp8/amd/hip_float8.h b/csrc/quantization/fp8/amd/hip_float8.h deleted file mode 100644 index f9c80fcdec57..000000000000 --- a/csrc/quantization/fp8/amd/hip_float8.h +++ /dev/null @@ -1,137 +0,0 @@ -#pragma once - -#ifdef __HIPCC__ - #include -#else - #include - #include - #include - #include -#endif - -#include "hip_float8_impl.h" - -struct alignas(1) hip_fp8 { - struct from_bits_t {}; - HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { - return from_bits_t(); - } - uint8_t data; - - hip_fp8() = default; - HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; - HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; - explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) - : data(v) {} - -#ifdef __HIP__MI300__ - // NOTE: ON-DEVICE... always optimal bias - explicit HIP_FP8_DEVICE hip_fp8(float v) - : data(hip_fp8_impl::to_fp8_from_fp32(v)) {} - - explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) - : hip_fp8(static_cast(v)) {} - - // Host only implementation using s/w simulation - explicit HIP_FP8_HOST -#else // __HIP__MI300__ - // both Host and DEVICE for non-MI300 using s/w simulation - explicit HIP_FP8_HOST_DEVICE -#endif // __HIP__MI300__ - hip_fp8(float v) { - data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, - true /*clip*/>(v); - } - - explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) - : hip_fp8(static_cast(v)) {} - -#ifdef __HIP__MI300__ - // upcast using device specific intrinsic - explicit inline HIP_FP8_DEVICE operator float() const { - float fval; - uint32_t i32val = static_cast(data); - - // upcast - asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" - : "=v"(fval) - : "v"(i32val)); - - return fval; - } - - explicit inline HIP_FP8_HOST operator float() const -#else // __HIP__MI300__ - explicit inline HIP_FP8_HOST_DEVICE operator float() const -#endif // __HIP__MI300__ - { - return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>( - data); - } -}; - -namespace std { -inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); } -inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); } -HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; } -} // namespace std - -// Special operator overloading -inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) { - return os << float(f8); -} - -// all + operator overloading with mixed types -// mixed types, always converts to f32, does computation in f32, and returns -// float -inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) { - return (fa + float(b)); -} - -inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) { - return (float(a) + fb); -} - -inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) { - return hip_fp8(float(a) + float(b)); -} - -inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) { - return a = hip_fp8(float(a) + float(b)); -} - -// overloading multiplication, always returns float, -inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) { - return float(a) * float(b); -} - -inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) { - return (a * float(b)); -} - -inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) { - return (float(a) * b); -} - -inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) { - return ((float)a * float(b)); -} - -inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) { - return ((float)a * float(b)); -} - -// overloading for compare -inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) { - return (a.data == b.data); -} -inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) { - return (a.data != b.data); -} - -inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) { - return static_cast(a) >= static_cast(b); -} -inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) { - return static_cast(a) > static_cast(b); -} diff --git a/csrc/quantization/fp8/amd/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h deleted file mode 100644 index 8b9cd26f2f76..000000000000 --- a/csrc/quantization/fp8/amd/hip_float8_impl.h +++ /dev/null @@ -1,315 +0,0 @@ -#pragma once - -#if defined(__HIPCC__) && defined(__gfx942__) - #define __HIP__MI300__ -#endif - -#ifdef __HIPCC__ - #define HIP_FP8_HOST_DEVICE __host__ __device__ - #define HIP_FP8_HOST __host__ - #define HIP_FP8_DEVICE __device__ -#else - #define HIP_FP8_HOST_DEVICE - #define HIP_FP8_HOST - #define HIP_FP8_DEVICE -#endif - -namespace hip_fp8_impl { - -#ifdef __HIP__MI300__ -HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) { - uint8_t i8data; - union { - float fval; - uint32_t i32val; - uint8_t i8val[4]; // NOTE: not endian independent - } val; - - uint32_t ival = 0; - val.fval = v; - - if ((val.i32val & 0x7F800000) != - 0x7F800000) { /// propagate NAN/INF, no clipping - val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); - } - - ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, - false); // false -> WORD0 - val.i32val = ival; - i8data = val.i8val[0]; - - return i8data; -} -#endif // __HIP__MI300__ - -HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); } -#if defined(__HIPCC__) || defined(__CUDA_ARCH__) -HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); } -#endif - -template -HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, - uint32_t rng = 0) { -#ifdef __HIPCC__ - constexpr bool is_half = std::is_same::value; -#else - constexpr bool is_half = false; -#endif - constexpr bool is_float = std::is_same::value; - static_assert(wm + we == 7, "wm+we==7"); - static_assert(is_half || is_float, "Only half and float can be cast to f8"); - - const int mfmt = (sizeof(T) == 4) ? 23 : 10; - uint32_t x; - if (sizeof(T) == 4) { - x = reinterpret_cast(_x); - } else { - x = reinterpret_cast(_x); - } - - uint32_t head, mantissa; - int exponent, bias; - uint32_t sign; - - if (sizeof(T) == 4) { - head = x & 0xFF800000; - mantissa = x & 0x7FFFFF; - exponent = (head >> 23) & 0xFF; - sign = head >> 31; - bias = 127; - } else { - head = x & 0xFC00; - mantissa = x & 0x3FF; - exponent = (head >> 10) & 0x1F; - sign = head >> 15; - bias = 15; - } - - uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); - - // Deal with inf and NaNs - if (negative_zero_nan) { - if (sizeof(T) == 4) { - if ((x & 0x7F800000) == 0x7F800000) { - return 0x80; - } - } else { - // if(__hisinf(x) || __hisnan(x)) - if ((x & 0x7C00) == 0x7C00) { - return 0x80; - } - } - } else { - if (sizeof(T) == 4) { - if ((x & 0x7F800000) == 0x7F800000) { - return signed_inf + (mantissa != 0 ? 1 : 0); - } - } else { - if ((x & 0x7C00) == 0x7C00) { - return signed_inf + (mantissa != 0 ? 1 : 0); - } - } - } - if (x == 0) { - return 0; - } - - // First need to check if it is normal or denorm as there is a difference of - // implicit 1 Then need to adjust the exponent to align with the F8 exponent, - // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng - // to mantissa and truncate. And for RNE, no need to add rng. Then probably - // need to check whether there is carry and adjust exponent and mantissa again - - // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent - // bits - const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); - const int f8_denormal_act_exponent = - 1 - f8_bias; // actual exponent of f8 denormal - // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) - // f8_exponent is the converted f8 exponent with bias encoding - // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, - // the difference needs to be adjusted and mantissa shifted - int act_exponent, f8_exponent, exponent_diff; - - if (exponent == 0) { // fp32/fp16 is in denormal. - /* fp32 denormal is below 2^-127 so it is usually not a concern here, we -mostly concern fp16 here. In this case, f8 is usually in denormal. But there -could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has -exponent bias 16. It means that there are some numbers in fp16 denormal but they -are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers -where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 -(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ - act_exponent = exponent - bias + 1; - exponent_diff = - f8_denormal_act_exponent - - act_exponent; // actual exponent is exponent-bias+1 as it is denormal - } else { // fp32/fp16 is normal with implicit 1 - act_exponent = exponent - bias; - if (act_exponent <= f8_denormal_act_exponent) { - /* This is the case where fp32/fp16 is normal but it is in f8 denormal -range. For example fp8 nanoo mode, denormal exponent is -7, but if the -fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, -Therefore it needs to be adjust to -6 and mantissa shift right by 1. -So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ - exponent_diff = f8_denormal_act_exponent - act_exponent; - } else { // both fp32/fp16 and f8 are in normal range - exponent_diff = 0; // exponent_diff=0 does not mean there is no - // difference for this case, act_exponent could be - // larger. Just that it does not need shift mantissa - } - mantissa += (1 << mfmt); // Add the implicit 1 into mantissa - } - - bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == - static_cast(1 << (mfmt - wm + exponent_diff - 1)); - /* This part is a bit tricky. The judgment of whether it is a tie needs to be - done before we shift right as shift right could rip off some residual part - and make something not midpoint look like midpoint. For example, the fp16 - number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after - shift right by 4 bits, it would look like midpoint. -*/ - - if (exponent_diff > 0) { - mantissa >>= exponent_diff; - } else if (exponent_diff == -1) { - mantissa <<= -exponent_diff; - } - bool implicit_one = mantissa & (1 << mfmt); - // if there is no implicit 1, it means the f8 is denormal and need to adjust - // to denorm exponent - f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + - f8_bias - (implicit_one ? 0 : 1); - - // Now we have the exponent and mantissa adjusted - uint32_t drop_mask = (1 << (mfmt - wm)) - 1; - bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit - // that is not truncated is 1 - mantissa += - (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & - drop_mask; - - // Now we deal with overflow - if (f8_exponent == 0) { - if ((1 << mfmt) & mantissa) { - f8_exponent = 1; // denormal overflow to become normal, promote exponent - } - } else { - if ((1 << (mfmt + 1)) & mantissa) { - mantissa >>= 1; - f8_exponent++; - } - } - - mantissa >>= (mfmt - wm); - - // above range: quantize to maximum possible float of the same sign - const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); - if (f8_exponent > max_exp) { - if (clip) { - mantissa = (1 << wm) - 1; - f8_exponent = max_exp; - } else { - return signed_inf; - } - } - - if (f8_exponent == 0 && mantissa == 0) { - return negative_zero_nan ? 0 : (sign << 7); - } - mantissa &= (1 << wm) - 1; - return (sign << 7) | (f8_exponent << wm) | mantissa; -} - -template -inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) { -#ifdef __HIPCC__ - constexpr bool is_half = std::is_same::value; -#else - constexpr bool is_half = false; -#endif - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "only half and float are supported"); - - constexpr int weo = is_half ? 5 : 8; - constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); - - T fInf, fNegInf, fNaN, fNeg0; - -#ifdef __HIPCC__ - if (is_half) { - const uint16_t ihInf = 0x7C00; - const uint16_t ihNegInf = 0xFC00; - const uint16_t ihNaN = 0x7C01; - const uint16_t ihNeg0 = 0x8000; - fInf = reinterpret_cast(ihInf); - fNegInf = reinterpret_cast(ihNegInf); - fNaN = reinterpret_cast(ihNaN); - fNeg0 = reinterpret_cast(ihNeg0); - } else -#endif - if (is_float) { - const uint32_t ifInf = 0x7F800000; - const uint32_t ifNegInf = 0xFF800000; - const uint32_t ifNaN = 0x7F800001; - const uint32_t ifNeg0 = 0x80000000; - fInf = reinterpret_cast(ifInf); - fNegInf = reinterpret_cast(ifNegInf); - fNaN = reinterpret_cast(ifNaN); - fNeg0 = reinterpret_cast(ifNeg0); - } - - if (x == 0) { - return 0; - } - - uint32_t sign = x >> 7; - uint32_t mantissa = x & ((1 << wm) - 1); - int exponent = (x & 0x7F) >> wm; - if (negative_zero_nan) { - if (x == 0x80) { - return fNaN; - } - } else { - if (x == 0x80) { - return fNeg0; - } - if (exponent == ((1 << we) - 1)) { - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; - } - } - typename std::conditional::type retval; - if (we == 5 && is_half && !negative_zero_nan) { - retval = x << 8; - return reinterpret_cast(retval); - } - - const int exp_low_cutoff = - (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); - - // subnormal input - if (exponent == 0) { - // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + clz(mantissa) - (32 - wm); - mantissa <<= sh; - exponent += 1 - sh; - mantissa &= ((1 << wm) - 1); - } - exponent += exp_low_cutoff - 1; - mantissa <<= wmo - wm; - - // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) - if (exponent <= 0) { - mantissa |= 1 << wmo; - mantissa >>= 1 - exponent; - exponent = 0; - } - - if (sizeof(T) == 2) { - retval = (sign << 15) | (exponent << 10) | mantissa; - } else { - retval = (sign << 31) | (exponent << 23) | mantissa; - } - return reinterpret_cast(retval); -} - -} // namespace hip_fp8_impl diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index eb66834222f3..b2196b8ed516 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -1,13 +1,11 @@ #pragma once -#include "hip_float8.h" +#include #include #include #include -#include "../../../attention/dtype_fp8.cuh" -#include "../../../attention/dtype_float32.cuh" -#include "../../../attention/dtype_bfloat16.cuh" +#include "../../../attention/attention_dtypes.h" namespace vllm { #ifdef USE_ROCM @@ -26,40 +24,31 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x, return x; } + #if HIP_FP8_TYPE_FNUZ +using fp8_type = __hip_fp8_e4m3_fnuz; +using fp8x2_type = __hip_fp8x2_e4m3_fnuz; + #elif HIP_FP8_TYPE_OCP +using fp8_type = __hip_fp8_e4m3; +using fp8x2_type = __hip_fp8x2_e4m3; + #endif + // fp8 -> half template <> __inline__ __device__ uint16_t vec_conversion(const uint8_t& a) { - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8); - return res.x; + return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x; } // fp8x2 -> half2 template <> __inline__ __device__ uint32_t vec_conversion(const uint16_t& a) { - #if defined(__HIP__MI300__) && \ - defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); union { __half2_raw h2r; uint32_t ui32; } tmp; - tmp.h2r.x.data = f2[0]; - tmp.h2r.y.data = f2[1]; + tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); return tmp.ui32; - #else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = vec_conversion(static_cast(a)); - tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); - return tmp.u32; - #endif } // fp8x4 -> half2x2 @@ -92,9 +81,9 @@ using __nv_bfloat16 = __hip_bfloat16; template <> __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f); + fp8_type f8; + f8.__x = a; + return __float2bfloat16(static_cast(f8)); } using __nv_bfloat162 = __hip_bfloat162; @@ -136,27 +125,18 @@ __inline__ __device__ bf16_8_t vec_conversion(const uint2& a) { // fp8 -> float template <> __inline__ __device__ float vec_conversion(const uint8_t& a) { - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8); + fp8_type f8; + f8.__x = a; + return static_cast(f8); } // fp8x2 -> float2 template <> __inline__ __device__ float2 vec_conversion(const uint16_t& a) { - #if defined(__HIP__MI300__) && \ - defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0]; - res.y = f2[1]; - return res; - #else - float2 res; - res.x = vec_conversion(static_cast(a)); - res.y = vec_conversion(static_cast(a >> 8U)); - return res; - #endif + fp8x2_type f8x2; + f8x2.__x = a; + return static_cast(f8x2); } // fp8x4 -> float4 @@ -169,6 +149,15 @@ vec_conversion(const uint32_t& a) { return res; } +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 +vec_conversion(const uint32_t& a) { + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + // fp8x8 -> float8 template <> __inline__ __device__ Float8_ vec_conversion(const uint2& a) { @@ -189,33 +178,36 @@ __inline__ __device__ uint8_t vec_conversion(const uint16_t& a) { __half_raw tmp; tmp.x = a; + return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation, + fp8_type::__default_interpret); +} - hip_fp8 f8{static_cast(tmp.data)}; - return f8.data; +template <> +__inline__ __device__ uint16_t +vec_conversion(const uint32_t& a) { + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation, + fp8_type::__default_interpret); } // bf16 -> fp8 template <> __inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) { - hip_fp8 res{__bfloat162float(a)}; - return res.data; + return __hip_cvt_float_to_fp8(__bfloat162float(a), + fp8_type::__default_saturation, + fp8_type::__default_interpret); } // float -> fp8 template <> __inline__ __device__ uint8_t vec_conversion(const float& a) { - hip_fp8 f8(a); - return f8.data; -} - -// fp8x4 -> float4 -template <> -__inline__ __device__ float4 -vec_conversion(const uint32_t& a) { - Float4_ tmp = vec_conversion(a); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; + return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation, + fp8_type::__default_interpret); } // float2 -> half2 @@ -307,90 +299,22 @@ vec_conversion(const Float8_& a) { */ -// fp8 -> half -template <> -__inline__ __device__ uint16_t -scaled_vec_conversion(const uint8_t& a, const float scale) { - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8) * scale; - return res.x; -} - -// fp8x2 -> half2 -template <> -__inline__ __device__ uint32_t scaled_vec_conversion( - const uint16_t& a, const float scale) { - #if defined(__HIP__MI300__) && \ - defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0] * scale; - tmp.h2r.y.data = f2[1] * scale; - return tmp.ui32; - #else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = - scaled_vec_conversion(static_cast(a), scale); - tmp.u16[1] = scaled_vec_conversion( - static_cast(a >> 8U), scale); - return tmp.u32; - #endif -} - -// fp8x4 -> half2x2 -template <> -__inline__ __device__ uint2 -scaled_vec_conversion(const uint32_t& a, const float scale) { - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); - tmp.u32[1] = - scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return tmp.u32x2; -} - -// fp8x8 -> half2x4 -template <> -__inline__ __device__ uint4 -scaled_vec_conversion(const uint2& a, const float scale) { - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = scaled_vec_conversion(a.x, scale); - tmp.u64[1] = scaled_vec_conversion(a.y, scale); - return tmp.u64x2; -} - using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> __inline__ __device__ __nv_bfloat16 -scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, - const float scale) { - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f * scale); +scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { + fp8_type f8; + f8.__x = a; + return __float2bfloat16(static_cast(f8) * scale); } -using __nv_bfloat162 = __hip_bfloat162; - // fp8x2 -> __nv_bfloat162 template <> __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, - const float scale) { + float scale) { __nv_bfloat162 res; res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); res.y = @@ -400,8 +324,8 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t scaled_vec_conversion( - const uint32_t& a, const float scale) { +__inline__ __device__ bf16_4_t +scaled_vec_conversion(const uint32_t& a, float scale) { bf16_4_t res; res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), @@ -412,7 +336,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion( // fp8x8 -> bf16_8_t template <> __inline__ __device__ bf16_8_t -scaled_vec_conversion(const uint2& a, const float scale) { +scaled_vec_conversion(const uint2& a, float scale) { bf16_4_t tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale); tmp2 = scaled_vec_conversion(a.y, scale); @@ -427,29 +351,19 @@ scaled_vec_conversion(const uint2& a, const float scale) { // fp8 -> float template <> __inline__ __device__ float scaled_vec_conversion( - const uint8_t& a, const float scale) { - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8) * scale; + const uint8_t& a, float scale) { + fp8_type f8; + f8.__x = a; + return static_cast(f8) * scale; } // fp8x2 -> float2 template <> __inline__ __device__ float2 -scaled_vec_conversion(const uint16_t& a, const float scale) { - #if defined(__HIP__MI300__) && \ - defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0] * scale; - res.y = f2[1] * scale; - return res; - #else - float2 res; - res.x = scaled_vec_conversion(static_cast(a), scale); - res.y = scaled_vec_conversion(static_cast(a >> 8U), - scale); - return res; - #endif +scaled_vec_conversion(const uint16_t& a, float scale) { + fp8x2_type f8x2; + f8x2.__x = a; + return static_cast(f8x2) * scale; } // fp8x4 -> float4 @@ -462,10 +376,18 @@ scaled_vec_conversion(const uint32_t& a, const float scale) { return res; } +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 +scaled_vec_conversion(const uint32_t& a, float scale) { + Float4_ res = scaled_vec_conversion(a, scale); + return {res.x.x, res.x.y, res.y.x, res.y.y}; +} + // fp8x8 -> float8 template <> __inline__ __device__ Float8_ -scaled_vec_conversion(const uint2& a, const float scale) { +scaled_vec_conversion(const uint2& a, float scale) { Float4_ tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale); tmp2 = scaled_vec_conversion(a.y, scale); @@ -477,44 +399,184 @@ scaled_vec_conversion(const uint2& a, const float scale) { return res; } -/* Quantize(HP / scale) => FP8 */ +// fp8 -> half +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint8_t& a, float scale) { + __half_raw res; + res.data = scaled_vec_conversion(a, scale); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint16_t& a, float scale) { + __half2_raw h2r = + __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); + tmp.h2r.x.data *= scale; + tmp.h2r.y.data *= scale; + return tmp.ui32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 +scaled_vec_conversion(const uint32_t& a, float scale) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = + scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; +} -// TODO(Hai): vectorized to add +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, + float scale) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; +} // half -> fp8 template <> __inline__ __device__ uint8_t -scaled_vec_conversion(const uint16_t& a, const float scale) { +scaled_vec_conversion(const uint16_t& a, float scale) { __half_raw tmp; tmp.x = a; + tmp.data /= scale; + return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation, + fp8_type::__default_interpret); +} - hip_fp8 f8{static_cast(tmp.data) / scale}; - return f8.data; +// halfx2 -> fp8x2 +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint32_t& a, float scale) { + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + tmp.h2r.x.data /= scale; + tmp.h2r.y.data /= scale; + return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// half2x2 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint2& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// half2x4 -> fp8x8 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, + float scale) { + union { + uint2 ui2[2]; + uint4 ui4; + } tmp; + tmp.ui4 = a; + uint2 res; + res.x = scaled_vec_conversion(tmp.ui2[0], scale); + res.y = scaled_vec_conversion(tmp.ui2[1], scale); + return res; } // bf16 -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const __nv_bfloat16& a, const float scale) { - hip_fp8 res{__bfloat162float(a) / scale}; - return res.data; + const __nv_bfloat16& a, float scale) { + return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale, + fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// bf16x2 -> fp8x2 +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const __nv_bfloat162& a, float scale) { + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; +} + +// bf16x4 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const bf16_4_t& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// bf16x8 -> fp8x8 +template <> +__inline__ __device__ uint2 +scaled_vec_conversion(const bf16_8_t& a, float scale) { + uint2 res; + res.x = scaled_vec_conversion({a.x, a.y}, scale); + res.y = scaled_vec_conversion({a.z, a.w}, scale); + return res; } // float -> fp8 template <> __inline__ __device__ uint8_t -scaled_vec_conversion(const float& a, const float scale) { - hip_fp8 f8(a / scale); - return f8.data; +scaled_vec_conversion(const float& a, float scale) { + return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation, + fp8_type::__default_interpret); } -// fp8x4 -> float4 +// floatx2 -> fp8x2 template <> -__inline__ __device__ float4 -scaled_vec_conversion(const uint32_t& a, const float scale) { - Float4_ tmp = scaled_vec_conversion(a, scale); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +__inline__ __device__ uint16_t +scaled_vec_conversion(const float2& a, float scale) { + return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// floatx4 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const float4& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); + tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); + return tmp.ui32; } #endif // ENABLE_FP8 diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index 15bd5b6ed156..fac99b297342 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -12,7 +12,7 @@ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); #else #include - #include "amd/hip_float8.h" + #include "amd/quant_utils.cuh" using FP8_TYPE = c10::Float8_e4m3fnuz; // Using the default max value from pytorch (240.0) will cause accuracy // issue when running dynamic quantization. Here use 224.0f for rocm. @@ -47,8 +47,10 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, return static_cast(r); #else // Use hardware cvt instruction for fp8 on rocm - return c10::Float8_e4m3fnuz(hip_fp8(r).data, - c10::Float8_e4m3fnuz::from_bits()); + return c10::Float8_e4m3fnuz( + __hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation, + fp8::fp8_type::__default_interpret), + c10::Float8_e4m3fnuz::from_bits()); #endif } diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index b8b5e2045457..fb3688748214 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -159,19 +159,20 @@ def test_reshape_and_cache( device) key_cache, value_cache = key_caches[0], value_caches[0] + # Using default kv_scale + k_scale = (key.amax() / 64.0).to(torch.float32) + v_scale = (value.amax() / 64.0).to(torch.float32) + # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(cloned_key_cache, key_cache) + ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item()) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(cloned_value_cache, value_cache) + ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item()) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() - # Using default kv_scale - k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) - # Call the reshape_and_cache kernel. opcheck(torch.ops._C_cache_ops.reshape_and_cache, (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, @@ -182,9 +183,9 @@ def test_reshape_and_cache( if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(result_key_cache, key_cache) + ops.convert_fp8(result_key_cache, key_cache, k_scale.item()) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(result_value_cache, value_cache) + ops.convert_fp8(result_value_cache, value_cache, v_scale.item()) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -268,15 +269,16 @@ def test_reshape_and_cache_flash( del key_caches del value_caches - k_scale = (key.amax() / 256.0).to(torch.float32) - v_scale = (value.amax() / 256.0).to(torch.float32) + k_scale = (key.amax() / 64.0).to(torch.float32) + v_scale = (value.amax() / 64.0).to(torch.float32) # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype) + ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(), + kv_cache_dtype) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(cloned_value_cache, value_cache, v_scale, + ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(), kv_cache_dtype) else: cloned_key_cache = key_cache.clone() From ab1bdb3a8beb8529d2deedf3d10ae3bf4f00107b Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 25 Feb 2025 18:01:15 +0800 Subject: [PATCH 220/317] [CI/Build] Fix V1 LoRA failure (#13767) --- tests/lora/test_gemma.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index a1b4c897c45e..bbdfbe37175e 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -41,6 +41,8 @@ def v1(run_with_both_engines_lora): pass +# The V1 lora test for this model requires more than 24GB. +@pytest.mark.skip_v1 @pytest.mark.xfail(current_platform.is_rocm(), reason="There can be output mismatch on ROCm") def test_gemma_lora(gemma_lora_files): From 2248377a0d4a83e5778dceef1a588f694f314ebe Mon Sep 17 00:00:00 2001 From: Chen1022 <112855051+Chen-0210@users.noreply.github.com> Date: Tue, 25 Feb 2025 18:12:19 +0800 Subject: [PATCH 221/317] [Misc]Clarify Error Handling for Non-existent Model Paths and HF Repo IDs (#13724) Signed-off-by: Chen-0210 Co-authored-by: Michael Goin --- vllm/transformers_utils/config.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index dd6ee9a34adb..55a620b4bf14 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -253,14 +253,28 @@ def get_config( model = Path(model).parent if config_format == ConfigFormat.AUTO: - if is_gguf or file_or_path_exists( - model, HF_CONFIG_NAME, revision=revision): - config_format = ConfigFormat.HF - elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, - revision=revision): - config_format = ConfigFormat.MISTRAL - else: - raise ValueError(f"No supported config format found in {model}.") + try: + if is_gguf or file_or_path_exists( + model, HF_CONFIG_NAME, revision=revision): + config_format = ConfigFormat.HF + elif file_or_path_exists(model, + MISTRAL_CONFIG_NAME, + revision=revision): + config_format = ConfigFormat.MISTRAL + + except Exception as e: + error_message = ( + "Invalid repository ID or local directory specified:" + " '{model}'.\nPlease verify the following requirements:\n" + "1. Provide a valid Hugging Face repository ID.\n" + "2. Specify a local directory that contains a recognized " + "configuration file.\n" + " - For Hugging Face models: ensure the presence of a " + "'config.json'.\n" + " - For Mistral models: ensure the presence of a " + "'params.json'.\n") + + raise ValueError(error_message) from e if config_format == ConfigFormat.HF: config_dict, _ = PretrainedConfig.get_config_dict( From 5bf3a9b0790df15051fdd88d4dbf1126298f4410 Mon Sep 17 00:00:00 2001 From: Junlin Zhou Date: Tue, 25 Feb 2025 18:13:09 +0800 Subject: [PATCH 222/317] [Bugfix] Initialize attention bias on the same device as Query/Key/Value (#13468) --- vllm/attention/backends/xformers.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index ec8e1f2ee5a6..9fa76634e1fc 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -673,7 +673,9 @@ def _run_memory_efficient_xformers_forward( # Cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) + attn_metadata.seq_lens, + attn_metadata.encoder_seq_lens, + device=query.device) # Encoder branch of encoder-decoder model uses # attn_metadata.encoder_seq_lens @@ -683,7 +685,7 @@ def _run_memory_efficient_xformers_forward( # Encoder self-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.encoder_seq_lens) + attn_metadata.encoder_seq_lens, device=query.device) # Self-attention block of encoder-only model just # uses the seq_lens directly. @@ -692,7 +694,7 @@ def _run_memory_efficient_xformers_forward( # Encoder self-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens) + attn_metadata.seq_lens, device=query.device) # Self-attention block of decoder branch just # uses the seq_lens directly @@ -701,7 +703,7 @@ def _run_memory_efficient_xformers_forward( # Decoder self-attention mask is causal attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens) + attn_metadata.seq_lens, device=query.device) else: raise ValueError("Unknown AttentionType: %s", attn_type) From be48b410877413f9493a998d728054c19c9dca90 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" <165712832+naromero77amd@users.noreply.github.com> Date: Tue, 25 Feb 2025 05:08:20 -0600 Subject: [PATCH 223/317] [Bugfix] Flush TunableOp results before worker processes are destroyed. (#13623) Signed-off-by: Nichols A. Romero --- vllm/executor/multiproc_worker_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index cef6a994a9c0..68a83bb610a4 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -250,6 +250,15 @@ def _run_worker_process( except Exception: logger.exception("Worker failed") + # Flush TunableOp results when TunableOp is enabled and + # online (in situ) tuning is enabled. + # Offline tuning API (record_untuned_is_enabled()) only + # available in PyTorch 2.6 or later. + import torch.cuda.tunable as tunable + if (tunable.is_enabled() and tunable.tuning_is_enabled() + and not tunable.record_untuned_is_enabled()): + tunable.write_file() + logger.info("Worker exiting") From 6d05cde2fd917a0948093cd67d0ace4a2af30007 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 25 Feb 2025 22:03:02 +0800 Subject: [PATCH 224/317] [Bugfix] Fix deepseek-vl2 inference with more than 2 images (#13818) --- vllm/model_executor/models/deepseek_vl2.py | 50 ++++++++++++++++++---- vllm/model_executor/models/h2ovl.py | 6 ++- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index c58b65d49348..ea217e244404 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -25,7 +25,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, ProcessingCache, + PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, @@ -138,18 +139,24 @@ def get_hf_processor(self, **kwargs: object): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_num_image_tokens(self, *, image_width: int, - image_height: int) -> int: + def get_num_image_tokens(self, + *, + image_width: int, + image_height: int, + cropping: bool = True) -> int: hf_processor = self.get_hf_processor() image_size = hf_processor.image_size patch_size = hf_processor.patch_size downsample_ratio = hf_processor.downsample_ratio - best_width, best_height = hf_processor.select_best_resolution( - (image_width, image_height)) + if cropping: + best_width, best_height = hf_processor.select_best_resolution( + (image_width, image_height)) + num_width_tiles, num_height_tiles = (best_width // image_size, + best_height // image_size) + else: + num_width_tiles = num_height_tiles = 1 - num_width_tiles, num_height_tiles = (best_width // image_size, - best_height // image_size) h = w = math.ceil((image_size // patch_size) / downsample_ratio) global_views_tokens = h * (w + 1) @@ -169,10 +176,12 @@ def get_mm_max_tokens_per_item( seq_len: int, mm_counts: Mapping[str, int], ) -> Mapping[str, int]: + num_images = mm_counts.get("image", 0) max_image_size = self.get_image_size_with_most_features() max_image_tokens = self.get_num_image_tokens( image_height=max_image_size.height, - image_width=max_image_size.width) + image_width=max_image_size.width, + cropping=num_images <= 2) return {"image": max_image_tokens} @@ -207,6 +216,30 @@ def get_dummy_processor_inputs( class DeepseekVL2MultiModalProcessor( BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]): + def __init__( + self, + info: DeepseekVL2ProcessingInfo, + dummy_inputs: "BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]", + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True) -> None: + super().__init__( + info, + dummy_inputs, + cache=cache, + enable_sanity_checks=enable_sanity_checks, + ) + + mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt + if self.cache is not None and mm_limit["image"] > 2: + # The processor output depends on the number of images passed, + # making it incompatible with processing cache which is supposed + # to be invariant of how many images are passed per prompt + self.cache = None + logger.warning_once( + f"{type(self).__name__} does not support processing cache with " + "image limit larger than 2.") + def _call_hf_processor( self, prompt: str, @@ -271,6 +304,7 @@ def get_replacement_deepseek_vl2(item_idx: int): num_image_tokens = self.info.get_num_image_tokens( image_width=image_size.width, image_height=image_size.height, + cropping=len(images) <= 2, ) return [image_token_id] * num_image_tokens diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 01b721fa79e1..bab9c256b9aa 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -477,13 +477,15 @@ def __init__(self, enable_sanity_checks=enable_sanity_checks, ) - if self.cache is not None: + mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt + if self.cache is not None and mm_limit["image"] >= 2: # The processor output depends on the number of images passed, # making it incompatible with processing cache which is supposed # to be invariant of how many images are passed per prompt self.cache = None logger.warning_once( - f"{type(self).__name__} does not support processing cache.") + f"{type(self).__name__} does not support processing cache with " + "multi-image support enabled.") def _get_prompt_replacements( self, From f63bfeac34fce38b5c6b81c81b1f5cb49414a5cf Mon Sep 17 00:00:00 2001 From: Wen Sun <35923278+HermitSun@users.noreply.github.com> Date: Tue, 25 Feb 2025 22:03:33 +0800 Subject: [PATCH 225/317] Fix `/v1/audio/transcriptions ` Bad Request Error (#13811) --- requirements-common.txt | 3 +-- vllm/entrypoints/openai/protocol.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 0514bf8adcaf..942c3e039eaf 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -9,8 +9,7 @@ py-cpuinfo transformers >= 4.48.2 # Required for Bamba model and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. -fastapi[standard] >= 0.107.0, < 0.113.0; python_version < '3.9' -fastapi[standard] >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' +fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) pydantic >= 2.9 diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 45b98a032bda..cd2902f934bf 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1448,7 +1448,7 @@ class UnloadLoraAdapterRequest(BaseModel): class TranscriptionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation - #https://platform.openai.com/docs/api-reference/audio/createTranscription + # https://platform.openai.com/docs/api-reference/audio/createTranscription file: UploadFile """ From d8265fd0b00df759d460ab8c83e123e17a767047 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 26 Feb 2025 00:18:50 +0800 Subject: [PATCH 226/317] [Bugfix] Revert inspection code in #13743 (#13832) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/registry.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index bae6444267fa..05fb3d21953d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -364,12 +364,6 @@ def register_model( raise ValueError(msg) model = _LazyRegisteredModel(*split_str) - - try: - model.inspect_model_cls() - except Exception as exc: - msg = f"Unable to inspect model {model_cls}" - raise RuntimeError(msg) from exc elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module): model = _RegisteredModel.from_model_cls(model_cls) else: From 8c32ae8662dfb23cf5d4dd8a6b3e2c674a928877 Mon Sep 17 00:00:00 2001 From: Chen1022 <112855051+Chen-0210@users.noreply.github.com> Date: Wed, 26 Feb 2025 00:20:29 +0800 Subject: [PATCH 227/317] Fix string parsing error (#13825) --- vllm/transformers_utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 55a620b4bf14..1937b1388471 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -272,7 +272,7 @@ def get_config( " - For Hugging Face models: ensure the presence of a " "'config.json'.\n" " - For Mistral models: ensure the presence of a " - "'params.json'.\n") + "'params.json'.\n").format(model=model) raise ValueError(error_message) from e From fecd1c2d639335cf6e547f9f605c9d5c0c036128 Mon Sep 17 00:00:00 2001 From: Liangfu Chen Date: Tue, 25 Feb 2025 11:47:49 -0800 Subject: [PATCH 228/317] [Neuron] Add custom_ops for neuron backend (#13246) Signed-off-by: Liangfu Chen Co-authored-by: George Novack Co-authored-by: Aoyu Zhang --- tests/neuron/test_activation.py | 42 ++++++++ tests/neuron/test_layernorm.py | 56 +++++++++++ tests/neuron/test_logits_processor.py | 95 +++++++++++++++++++ tests/neuron/test_prefix_prefill.py | 7 +- tests/neuron/test_rotary_embedding.py | 58 +++++++++++ vllm/model_executor/custom_op.py | 7 ++ vllm/model_executor/layers/activation.py | 7 ++ .../model_executor/layers/logits_processor.py | 1 + .../model_executor/layers/rotary_embedding.py | 76 +++++++++++++++ 9 files changed, 346 insertions(+), 3 deletions(-) create mode 100644 tests/neuron/test_activation.py create mode 100644 tests/neuron/test_layernorm.py create mode 100644 tests/neuron/test_logits_processor.py create mode 100644 tests/neuron/test_rotary_embedding.py diff --git a/tests/neuron/test_activation.py b/tests/neuron/test_activation.py new file mode 100644 index 000000000000..ec2b1238e404 --- /dev/null +++ b/tests/neuron/test_activation.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.activation import FastGELU, SiluAndMul +from vllm.platforms import current_platform + + +@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"]) +@pytest.mark.parametrize("num_tokens,d,dtype", [ + (7, 512, torch.half), + (7, 512, torch.float), + (83, 512, torch.half), +]) +@torch.inference_mode() +def test_act_and_mul( + activation: str, + num_tokens: int, + d: int, + dtype: torch.dtype, +) -> None: + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + current_platform.seed_everything(0) + torch.set_default_device("cpu") + x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device) + if activation == "silu_and_mul": + layer = SiluAndMul() + fn = layer.forward_native + elif activation == "gelu_fast": + layer = FastGELU() + fn = F.gelu + else: + raise NotImplementedError( + f"activation {activation} is not implemented.") + assert x.is_xla, "input tensor under testing is expected to be XLA tensor." + out = layer.to(device=device).forward_neuron(x) + ref_out = fn(x.cpu()) + torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0) diff --git a/tests/neuron/test_layernorm.py b/tests/neuron/test_layernorm.py new file mode 100644 index 000000000000..e96df8db6ccd --- /dev/null +++ b/tests/neuron/test_layernorm.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform + + +@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [ + (7, 8, False, torch.half), + (83, 768, False, torch.half), + (83, 768, True, torch.half), + (83, 768, True, torch.bfloat16), + (83, 768, True, torch.float32), +]) +@torch.inference_mode() +def test_rms_norm( + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, +) -> None: + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + current_platform.seed_everything(0) + torch.set_default_device("cpu") + layer = RMSNorm(hidden_size).to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device) + x *= scale + residual = torch.randn_like(x) * scale if add_residual else None + + residual_cpu = residual.cpu() if add_residual else None + ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu) + assert x.is_xla, "input tensor under testing is expected to be XLA tensor." + out = layer.to(device=device)(x, residual) + + # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger + # numerical errors than other operators because they involve reductions. + # Therefore, we use a larger tolerance. + if add_residual: + assert out[0].is_xla, "output tensor is expected to be XLA tensor" + torch.testing.assert_close(out[0].cpu(), + ref_out[0], + atol=1e-2, + rtol=1e-2) + torch.testing.assert_close(out[1].cpu(), + ref_out[1], + atol=1e-2, + rtol=1e-2) + else: + assert out.is_xla, "output tensor is expected to be XLA tensor" + torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2) diff --git a/tests/neuron/test_logits_processor.py b/tests/neuron/test_logits_processor.py new file mode 100644 index 000000000000..37d59c9e76a7 --- /dev/null +++ b/tests/neuron/test_logits_processor.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 + +import random +from typing import Tuple +from unittest.mock import patch + +import pytest +import torch + +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_random_seed +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.utils import is_pin_memory_available + + +class MockLogitsProcessor(LogitsProcessor): + + def __init__(self, vocab_size: int, scale: float, + fake_logits: torch.Tensor): + super().__init__(vocab_size=vocab_size, scale=scale) + self.fake_logits = fake_logits.clone() + + def forward(self, *args, **kwargs): + with patch( + "vllm.model_executor.layers.logits_processor._prune_hidden_states", + lambda x, y: x + ), patch( + "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", + lambda *args, **kwargs: self.fake_logits): + return super().forward(*args, **kwargs) + + +def _prepare_test( + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: + vocab_size = 32000 + input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) + fake_logits = torch.full((batch_size, vocab_size), + 1e-2, + dtype=input_tensor.dtype) + logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) + return input_tensor, fake_logits, logits_processor + + +RANDOM_SEEDS = list(range(8)) + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_logits_processors(seed: int): + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + set_random_seed(seed) + torch.set_default_device("cpu") + batch_size = random.randint(1, 256) + input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) + + # This sample logits processor gives infinite score to the i-th token, + # where i is the length of the input sequence. + # We therefore expect the output token sequence to be [0, 1, 2, ...] + def pick_ith(token_ids, logits): + logits[len(token_ids)] = float("inf") + return logits + + seq_group_metadata_list = [] + seq_lens = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0, + logits_processors=[pick_ith]), + block_tables={0: [1]}, + )) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=device, + pin_memory=is_pin_memory_available()) + logits_processor_output = logits_processor( + lm_head=None, + hidden_states=input_tensor, + sampling_metadata=sampling_metadata) + + fake_logits *= logits_processor.scale + torch.testing.assert_close(logits_processor_output[:, 1], + fake_logits[:, 1], + rtol=1e-4, + atol=0.0) diff --git a/tests/neuron/test_prefix_prefill.py b/tests/neuron/test_prefix_prefill.py index 347a139f39b4..2c6ac47888d5 100644 --- a/tests/neuron/test_prefix_prefill.py +++ b/tests/neuron/test_prefix_prefill.py @@ -345,6 +345,7 @@ def test_contexted_kv_attention( torch.manual_seed(0) torch.set_printoptions(sci_mode=False) + torch.set_default_device("cpu") dtype = torch.float32 min_ctx_len = 32 @@ -438,9 +439,9 @@ def pad_to_next_power_of_2(a): # transform block table active_block_table = get_active_block_tables( - block_table, - torch.tensor(query_lens), - torch.tensor(seq_lens), + block_table.cpu(), + torch.tensor(query_lens).cpu(), + torch.tensor(seq_lens).cpu(), block_size, num_active_blocks, ) diff --git a/tests/neuron/test_rotary_embedding.py b/tests/neuron/test_rotary_embedding.py new file mode 100644 index 000000000000..c015b80bd472 --- /dev/null +++ b/tests/neuron/test_rotary_embedding.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for miscellaneous utilities +""" + +import pytest +import torch + +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform + + +@pytest.mark.parametrize( + "max_position,is_neox_style,rotary_dim,head_size,seq_len", [ + (16, False, 32, 32, 1024), + (16, False, 32, 128, 1024), + (16, True, 32, 32, 1024), + (16, True, 32, 128, 1024), + ]) +def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim, + head_size, seq_len): + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + current_platform.seed_everything(0) + torch.set_default_device("cpu") + + batch_size = 1 + base = 10000 + num_heads = 8 + + rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, torch.float32) + + positions = torch.randint(0, + max_position, (batch_size, seq_len), + device="cpu") + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=torch.float32, + device="cpu") + key = torch.randn_like(query) + + assert positions.is_cpu, \ + "reference input tensor is expected to be CPU tensor." + ref_query, ref_key = rot.to(device="cpu").forward_native( + positions, query, key) + out_query, out_key = rot.to(device=device).forward_neuron( + positions.to(device=device), query.to(device=device), + key.to(device=device)) + assert out_query.is_xla and out_key.is_xla, \ + "output tensor is expected to be XLA tensor" + torch.testing.assert_close(out_query.cpu(), + ref_query, + atol=1e-2, + rtol=1e-2) + torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2) diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index ee4f41ea6ec9..dfd052f62521 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -59,6 +59,11 @@ def forward_hpu(self, *args, **kwargs): # PyTorch-native implementation. return self.forward_native(*args, **kwargs) + def forward_neuron(self, *args, **kwargs): + # By default, we assume that Neuron ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + def forward_oot(self, *args, **kwargs): # By default, we assume that OOT ops are compatible with the # PyTorch-native implementation. @@ -88,6 +93,8 @@ def dispatch_forward(self): return self.forward_tpu elif current_platform.is_xpu(): return self.forward_xpu + elif current_platform.is_neuron(): + return self.forward_neuron elif current_platform.is_out_of_tree(): return self.forward_oot else: diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index f782920d06a0..1de0f499c1a6 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -89,6 +89,13 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: self.op(out, x) return out + def forward_neuron(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + x_reshaped = x.view(-1, x.shape[-1]) + s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d]) + result = s * x_reshaped[:, d:] + return result.view(*x.shape[:-1], d) + @CustomOp.register("mul_and_silu") class MulAndSilu(CustomOp): diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 9b1742998578..2f39a0e87854 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -53,6 +53,7 @@ def __init__(self, # Whether to use gather or all-gather to gather the logits. parallel_config = get_current_vllm_config().parallel_config self.use_all_gather = current_platform.is_tpu() \ + or current_platform.is_neuron() \ or envs.VLLM_USE_V1 \ or parallel_config.distributed_executor_backend == "external_launcher" # noqa diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index ce1bc98ea426..64c2dac524f2 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -254,6 +254,82 @@ def forward_hpu( key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + def forward_neuron( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + def _apply_rotary_emb_neuron( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, + ) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + # x1 = x[..., ::2] + + # x2 = x[..., 1::2] + d = x.shape[-1] // 2 + x_reshaped = x.view(-1, x.shape[-1]) + x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d) + x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d) + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + if offsets is not None: + positions = positions + offsets + + self.cos_sin_cache = self.cos_sin_cache.to(query.device, + dtype=query.dtype) + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + + if self.rotary_dim == self.head_size: + query = _apply_rotary_emb(query, cos, sin, self.is_neox_style) + query = query.reshape(query_shape) + key = _apply_rotary_emb(key, cos, sin, self.is_neox_style) + key = key.reshape(key_shape) + else: + head_size = query.shape[-1] + query_reshaped = query.view(-1, head_size) + query_pass = query_reshaped[:, self.rotary_dim:].view( + *query.shape[:-1], head_size - self.rotary_dim) + query_rot = query_reshaped[:, :self.rotary_dim].view( + *query.shape[:-1], self.rotary_dim) + query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), + dim=-1).reshape(query_shape) + + key_reshaped = key.view(-1, head_size) + key_pass = key_reshaped[:, self.rotary_dim:].view( + *key.shape[:-1], head_size - self.rotary_dim) + key_rot = key_reshaped[:, :self.rotary_dim].view( + *key.shape[:-1], self.rotary_dim) + key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + def extra_repr(self) -> str: s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s += f", max_position_embeddings={self.max_position_embeddings}" From e78fc4b387ec6a0a5609b1410ab962d995f202bb Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 25 Feb 2025 20:33:03 +0000 Subject: [PATCH 229/317] Fix failing `MyGemma2Embedding` test (#13820) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .../vllm_add_dummy_model/my_gemma_embedding.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index 3af62b2885e5..a376d2cb340c 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple, Union import torch import torch.nn as nn -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.models.gemma2 import Gemma2Model @@ -37,16 +36,12 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) From 5b754aadaeae00c3c47b45a64151f3c726d34aac Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 25 Feb 2025 20:07:12 -0500 Subject: [PATCH 230/317] [Model] Support Grok1 (#13795) Signed-off-by: mgoin --- docs/source/models/supported_models.md | 5 + tests/models/registry.py | 2 + .../layers/fused_moe/fused_moe.py | 43 +- vllm/model_executor/layers/fused_moe/layer.py | 22 +- .../layers/quantization/awq_marlin.py | 2 + .../compressed_tensors_moe.py | 4 + .../layers/quantization/experts_int8.py | 2 + .../model_executor/layers/quantization/fp8.py | 2 + .../layers/quantization/gptq_marlin.py | 3 + vllm/model_executor/models/grok1.py | 565 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 11 files changed, 634 insertions(+), 17 deletions(-) create mode 100644 vllm/model_executor/models/grok1.py diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index ae851c35e626..9959f7233e86 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -286,6 +286,11 @@ See [this page](#generative-models) for more information on how to use generativ * `parasail-ai/GritLM-7B-vllm`. * ✅︎ * ✅︎ +- * `Grok1ModelForCausalLM` + * Grok1 + * `hpcai-tech/grok-1`. + * ✅︎ + * ✅︎ - * `InternLMForCausalLM` * InternLM * `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. diff --git a/tests/models/registry.py b/tests/models/registry.py index d89a41dae3aa..566a4418feb1 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -130,6 +130,8 @@ def check_available_online( "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), + "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", + trust_remote_code=True), "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", trust_remote_code=True), "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bc9573b36df7..00260313e72e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1040,6 +1040,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1053,9 +1054,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, - global_num_experts, expert_map, w1_scale, w2_scale, - w1_zp, w2_zp, a1_scale, a2_scale, block_shape) + activation, use_fp8_w8a8, use_int8_w8a16, + use_int4_w4a16, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape) def inplace_fused_experts_fake( @@ -1064,6 +1066,7 @@ def inplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1093,6 +1096,7 @@ def outplace_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1106,7 +1110,7 @@ def outplace_fused_experts( a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, - False, use_fp8_w8a8, use_int8_w8a16, + False, activation, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) @@ -1118,6 +1122,7 @@ def outplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1147,6 +1152,7 @@ def fused_experts(hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1162,15 +1168,17 @@ def fused_experts(hidden_states: torch.Tensor, if inplace: torch.ops.vllm.inplace_fused_experts( - hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, - use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, - w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) + hidden_states, w1, w2, topk_weights, topk_ids, activation, + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape) return hidden_states else: return torch.ops.vllm.outplace_fused_experts( - hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, - use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, - w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) + hidden_states, w1, w2, topk_weights, topk_ids, activation, + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape) def fused_experts_impl(hidden_states: torch.Tensor, @@ -1179,6 +1187,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1303,8 +1312,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int4_w4a16=use_int4_w4a16, block_shape=block_shape) - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) + if activation == "silu": + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") invoke_fused_moe_kernel(intermediate_cache2, w2, @@ -1339,6 +1354,7 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, + activation: str = "silu", use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, @@ -1370,6 +1386,8 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk @@ -1420,6 +1438,7 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, + activation=activation, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 452f390f4987..42554b61f67a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -120,7 +120,8 @@ def apply( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: return self.forward(x=x, layer=layer, @@ -134,7 +135,8 @@ def apply( expert_map=expert_map, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + activation=activation) def forward_cuda( self, @@ -150,7 +152,8 @@ def forward_cuda( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -170,6 +173,7 @@ def forward_cuda( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, global_num_experts=global_num_experts, expert_map=expert_map) @@ -186,9 +190,11 @@ def forward_cpu( global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, + activation: str = "silu", **kwargs, ): assert custom_routing_function is None + assert activation == "silu", f"{activation} is not supported." return layer.ipex_fusion( x, use_grouped_topk, @@ -213,7 +219,8 @@ def forward_tpu( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: assert not use_grouped_topk assert num_expert_group is None @@ -225,6 +232,7 @@ def forward_tpu( if e_score_correction_bias is not None: raise NotImplementedError( "Expert score correction bias is not supported for TPU.") + assert activation == "silu", f"{activation} is not supported for TPU." return fused_moe_pallas(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -277,6 +285,7 @@ def __init__( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ): super().__init__() @@ -305,6 +314,7 @@ def __init__( self.custom_routing_function = custom_routing_function self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias + self.activation = activation self.expert_map = None if self.ep_size > 1: @@ -653,7 +663,9 @@ def forward(self, hidden_states: torch.Tensor, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, - e_score_correction_bias=self.e_score_correction_bias) + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + ) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 7a2fb203dec3..473816fcc3ec 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -469,7 +469,9 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." if expert_map is not None: raise NotImplementedError( "Expert Parallelism is not supported for " diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f1f316f08339..c9aa0ec285ba 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -219,6 +219,7 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -240,6 +241,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_fp8_w8a8=True, global_num_experts=global_num_experts, expert_map=expert_map, @@ -550,7 +552,9 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." if expert_map is not None: raise NotImplementedError( "Expert Parallelism is not supported for " diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 0767926ee5c0..d18ca55afebd 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -113,6 +113,7 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -134,6 +135,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_int8_w8a16=True, global_num_experts=global_num_experts, expert_map=expert_map, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5e1bec0bb4be..76a7d4df8a36 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -675,6 +675,7 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -698,6 +699,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_fp8_w8a8=True, global_num_experts=global_num_experts, expert_map=expert_map, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 94a1de71bbca..21db8ccba059 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -590,7 +590,10 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." + # The input must currently be float16 orig_dtype = x.dtype x = x.half() diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py new file mode 100644 index 000000000000..f2e82017f653 --- /dev/null +++ b/vllm/model_executor/models/grok1.py @@ -0,0 +1,565 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from +# https://github.com/ROCm/vllm/blob/cea7419f151cc50293a05b7fac8547f8f887c9f6/vllm/model_executor/models/grok1.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Grok1 model.""" +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +# Default Grok1-specific constants, overridden by config values if present +DEFAULT_ATTN_OUTPUT_MULTIPLIER = 0.08838834764831845 +DEFAULT_OUTPUT_MULTIPLIER_SCALE = 0.5773502691896257 +DEFAULT_EMBEDDING_MULTIPLIER_SCALE = 78.38367176906169 + + +class Grok1MoE(nn.Module): + """A tensor-parallel MoE implementation for Grok1 that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + activation="gelu", + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + router_logits = 30.0 * F.tanh(router_logits / 30.0) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class Grok1Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + config=None, # Added config parameter + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.config = config # Store config reference + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + + attn_logits_soft_cap = max( + getattr(config, "attn_logit_softcapping", 30.0), 0.0) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + + # Apply attention output multiplier if specified in config + attn_multiplier = getattr(self.config, "attn_output_multiplier", + None) if self.config else None + if attn_multiplier is not None: + output = output * attn_multiplier + return output + + +class Grok1DecoderLayer(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Check for fp8 quantization + self.use_fp8 = False + if quant_config is not None: + self.use_fp8 = getattr(quant_config, "is_fp8_w8a8", + lambda: False)() + if not self.use_fp8 and hasattr(quant_config, "is_fp8"): + self.use_fp8 = quant_config.is_fp8 + + # Requires transformers > 4.32.0 + # Default rope_theta value if not in config + rope_theta = 10000 + self.attn = Grok1Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + config=config) # Pass config to Grok1Attention + + # Grok1 uses "num_experts" in its config + num_experts = getattr(config, "num_experts", 8) + num_experts_per_tok = getattr(config, "num_experts_per_tok", 2) + + self.moe_block = Grok1MoE(num_experts=num_experts, + top_k=num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.moe_block") + + self.pre_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.pre_attn_norm(hidden_states) + else: + hidden_states, residual = self.pre_attn_norm( + hidden_states, residual) + + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Post attention normalization + hidden_states = self.post_attn_norm(hidden_states) + + # MoE block with normalization + hidden_states, residual = self.pre_moe_norm(hidden_states, residual) + hidden_states = self.moe_block(hidden_states) + hidden_states = self.post_moe_norm(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class Grok1Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embedding_multiplier_scale = getattr( + config, "embedding_multiplier_scale", + DEFAULT_EMBEDDING_MULTIPLIER_SCALE) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Grok1DecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states * self.embedding_multiplier_scale + return hidden_states + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = Grok1Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.output_multiplier_scale = getattr( + config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + self.output_multiplier_scale) + + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Map Grok1's unique expert parameter names to standard names + # Grok1 uses "num_experts" in its config + num_experts = getattr(self.config, "num_experts", 8) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="linear", # Grok1 specific + ckpt_down_proj_name="linear_1", # Grok1 specific + ckpt_up_proj_name="linear_v", # Grok1 specific + num_experts=num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + # Handle Grok1-specific norm.scale naming + if "norm.scale" in name: + name = name.replace("scale", "weight") + + # Skip lm_head when tie_word_embeddings is True + if "lm_head" in name and self.config.tie_word_embeddings: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 05fb3d21953d..58155905a7b7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -60,6 +60,7 @@ "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), "GritLM": ("gritlm", "GritLM"), + "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"), From 1f903085866685bf6da801f926c34f4f7f1527b9 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 26 Feb 2025 01:24:57 +0000 Subject: [PATCH 231/317] DeepSeek V2/V3/R1 only place `lm_head` on last pp rank (#13833) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/deepseek_v2.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 79484cee167d..6ff3ef129a74 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -636,9 +636,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config self.model = DeepseekV2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( From 1b1e51dc66a1941d9d0fd4f2e49406301cf6ca49 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Tue, 25 Feb 2025 17:53:43 -0800 Subject: [PATCH 232/317] [misc] Show driver IP info when Ray fails to allocate driver worker (#13858) Signed-off-by: Rui Qiao --- vllm/executor/ray_distributed_executor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 673d0fc5d23e..bcad274bab49 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -229,9 +229,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: raise ValueError( - "Ray does not allocate any GPUs on the driver node. Consider " - "adjusting the Ray placement group or running the driver on a " - "GPU node.") + "Ray does not allocate any GPUs on the driver node." + f"Driver IP: {driver_ip}, worker IPs: {worker_ips}." + "Consider adjusting the Ray placement group or running " + "the driver on a GPU node.") ip_counts: Dict[str, int] = {} for ip in worker_ips: From cfb690d380d658501612f01ef63ce816cf37948b Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Tue, 25 Feb 2025 18:14:48 -0800 Subject: [PATCH 233/317] [V1][Spec Decode] Change Spec Decode Rejection Sampling API (#13729) --- tests/v1/sample/test_rejection_sampler.py | 17 ++- tests/v1/sample/test_sampler.py | 1 - tests/v1/worker/test_gpu_input_batch.py | 1 - vllm/v1/sample/metadata.py | 3 - vllm/v1/sample/rejection_sampler.py | 134 +++++++++++----------- vllm/v1/sample/sampler.py | 19 ++- vllm/v1/worker/gpu_input_batch.py | 11 -- vllm/v1/worker/gpu_model_runner.py | 29 +++-- 8 files changed, 104 insertions(+), 111 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 956d91c6daf7..f00585b40ba3 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -29,7 +29,6 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: temperature=torch.tensor([]), all_greedy=True, all_random=False, - spec_token_ids=spec_tokens, top_p=None, top_k=None, min_p=torch.empty(batch_size, ), @@ -55,7 +54,7 @@ def test_perfect_match(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler(logits, metadata) + output = sampler(spec_tokens, logits, metadata) expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) @@ -70,7 +69,7 @@ def test_early_mismatch(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler(logits, metadata) + output = sampler(spec_tokens, logits, metadata) expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], dtype=torch.int, device=logits.device) @@ -85,7 +84,7 @@ def test_multiple_sequences(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler(logits, metadata) + output = sampler(spec_tokens, logits, metadata) expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]], dtype=torch.int, device=logits.device) @@ -100,7 +99,7 @@ def test_single_token_sequence(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler(logits, metadata) + output = sampler(spec_tokens, logits, metadata) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) assert torch.equal(output.sampled_token_ids, expected) @@ -113,7 +112,7 @@ def test_empty_sequence(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler(logits, metadata) + output = sampler(spec_tokens, logits, metadata) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) assert torch.equal(output.sampled_token_ids, expected) @@ -126,7 +125,7 @@ def test_multiple_mismatches(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler(logits, metadata) + output = sampler(spec_tokens, logits, metadata) expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID], [4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], dtype=torch.int, @@ -147,7 +146,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler(logits, metadata) + output = sampler(spec_tokens, logits, metadata) expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) @@ -163,7 +162,7 @@ def test_logits_shape_handling(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens, vocab_size) - output = sampler(logits, metadata) + output = sampler(spec_tokens, logits, metadata) expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device) assert torch.equal(output.sampled_token_ids, expected) assert logits.shape[-1] == vocab_size diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 34fba5a9f6d7..435c1b7b5fda 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -105,7 +105,6 @@ def _create_default_sampling_metadata( prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, vocab_size, device), output_token_ids=output_token_ids, - spec_token_ids=None, frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 0aee266264ac..327370e71fff 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -123,7 +123,6 @@ def _construct_expected_sampling_metadata( dtype=torch.float, device=device), output_token_ids=output_token_ids, - spec_token_ids=None, min_tokens=min_tokens, no_penalties=(all(x == 0 for x in presence_penalties) and all(x == 0 for x in frequency_penalties) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 9f7770bbd078..b757a1dc60c7 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -13,9 +13,6 @@ class SamplingMetadata: all_greedy: bool all_random: bool - # None when there are no speculated tokens. - spec_token_ids: Optional[List[List[int]]] - top_p: Optional[torch.Tensor] top_k: Optional[torch.Tensor] min_p: Optional[torch.Tensor] diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 580ad44297aa..2e3927345eb5 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import List + import torch import torch.nn as nn from torch.nn.utils.rnn import pad_sequence @@ -52,62 +54,62 @@ def __init__(self): else: self.forward_method = self.forward_native - def forward(self, logits: torch.Tensor, + def forward(self, draft_token_ids: List[List[int]], + target_probs: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: if not sampling_metadata.all_greedy: raise NotImplementedError( "Currently, only greedy sampling is supported by " "rejection sampler.") - return self.forward_method(logits, sampling_metadata) + return self.forward_method(draft_token_ids, target_probs, + sampling_metadata) def flashinfer_sample( self, - logits: torch.Tensor, + draft_token_ids: List[List[int]], + target_probs: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: # NOTE: The following input preparationg can be moved # to the model runner with a persistent manner for better # performance. - assert sampling_metadata.spec_token_ids is not None - spec_token_ids = sampling_metadata.spec_token_ids - max_spec_len = max(len(s) for s in spec_token_ids) - batch_size = len(spec_token_ids) - draft_token_ids = torch.full((batch_size, max_spec_len), - INVALID_TOKEN_ID, - device="cpu", - dtype=torch.long) - - target_token_ids = torch.full((batch_size, max_spec_len + 1), - fill_value=INVALID_TOKEN_ID, - device=logits.device, - dtype=torch.long) - - # TODO: Vectorize the following loop for better performance. - start_loc = 0 - for i in range(batch_size): - num_spec_tokens = len(spec_token_ids[i]) - draft_token_ids[i, :num_spec_tokens] = torch.tensor( - spec_token_ids[i], device="cpu", dtype=torch.long) - end_loc = start_loc + num_spec_tokens + 1 - # Assume greedy sampling. - target_token_ids[i, :num_spec_tokens + 1] = torch.argmax( - logits[start_loc:end_loc], dim=-1) - start_loc = end_loc - - vocab_size = logits.size(-1) - # NOTE: CPU <-> GPU synchronization happens here. - draft_token_ids = draft_token_ids.to(logits.device) - draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size, - logits.device) - target_probs = _create_greedy_token_probs(target_token_ids, vocab_size, - logits.device) - uniform_samples = torch.zeros(batch_size, - max_spec_len + 1, - device=logits.device) + sample_lens = [len(x) + 1 for x in draft_token_ids] + # Convert draft token IDs to a tensor, split by sample_lens, then pad. + draft_token_ids = [ + torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids + ] + draft_token_ids_tensor = pad_sequence(draft_token_ids, + batch_first=True, + padding_value=INVALID_TOKEN_ID) + + if sampling_metadata.all_greedy: + target_token_ids = target_probs.argmax(dim=-1).view(-1) + target_token_ids = target_token_ids.split(sample_lens) + target_token_ids = pad_sequence(target_token_ids, + batch_first=True, + padding_value=INVALID_TOKEN_ID) + + vocab_size = target_probs.size(-1) + # NOTE: CPU <-> GPU synchronization happens here. + draft_token_ids_tensor = draft_token_ids_tensor.to( + target_probs.device) + draft_probs = _create_greedy_token_probs(draft_token_ids_tensor, + vocab_size, + target_probs.device) + target_probs = _create_greedy_token_probs(target_token_ids, + vocab_size, + target_probs.device) + uniform_samples = torch.zeros(draft_token_ids_tensor.size(0), + draft_token_ids_tensor.size(1) + 1, + device=target_probs.device) + else: + raise NotImplementedError( + "Currently, only greedy sampling is supported by " + "rejection sampler.") sampled_token_ids, _, _ = fs.chain_speculative_sampling( draft_probs, - draft_token_ids, + draft_token_ids_tensor, uniform_samples, target_probs, ) @@ -117,35 +119,35 @@ def flashinfer_sample( # TODO: The following method can be optimized for better performance. def forward_native( self, - logits: torch.Tensor, + draft_token_ids: List[List[int]], + target_probs: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - assert sampling_metadata.spec_token_ids is not None - spec_lens = [len(x) for x in sampling_metadata.spec_token_ids] - # Add 1 to include the 'bonus' token. - sample_lens = [x + 1 for x in spec_lens] - - output_token_ids = logits.argmax(dim=-1).view(-1) - output_token_ids = output_token_ids.split(sample_lens) - output_token_ids = pad_sequence(output_token_ids, - batch_first=True, - padding_value=INVALID_TOKEN_ID) - - # Convert spec token IDs to a tensor, split by sample_lens, then pad. - spec_token_ids = [ - torch.tensor(x, - dtype=output_token_ids.dtype, - device=output_token_ids.device) - for x in sampling_metadata.spec_token_ids + sample_lens = [len(x) + 1 for x in draft_token_ids] + # Convert draft token IDs to a tensor, split by sample_lens, then pad. + draft_token_ids = [ + torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids ] - spec_token_ids = pad_sequence(spec_token_ids, - batch_first=True, - padding_value=INVALID_TOKEN_ID) - - # Produce a mask that remains 1 (True) until the first - # mismatch (cumprod turns 0 after a mismatch). - accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod( - dim=1) + draft_token_ids_tensor = pad_sequence(draft_token_ids, + batch_first=True, + padding_value=INVALID_TOKEN_ID) + draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device) + # Add 1 to include the 'bonus' token. + if sampling_metadata.all_greedy: + output_token_ids = target_probs.argmax(dim=-1).view(-1) + output_token_ids = output_token_ids.split(sample_lens) + output_token_ids = pad_sequence(output_token_ids, + batch_first=True, + padding_value=INVALID_TOKEN_ID) + # Produce a mask that remains 1 (True) until the first + # mismatch (cumprod turns 0 after a mismatch). + accept_mask = ( + output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod( + dim=1) + else: + raise NotImplementedError( + "Currently, only greedy sampling is supported by " + "rejection sampler.") # Identify valid positions (non-padding). valid_mask = output_token_ids != INVALID_TOKEN_ID # Generate mask with bonus token. diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 47ec26d42024..b0eb533ae2e5 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -9,7 +9,6 @@ from vllm.v1.sample.ops.penalties import (apply_all_penalties, apply_min_token_penalties) from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler -from vllm.v1.sample.rejection_sampler import RejectionSampler _SAMPLING_EPS = 1e-5 @@ -19,22 +18,12 @@ class Sampler(nn.Module): def __init__(self): super().__init__() self.topk_topp_sampler = TopKTopPSampler() - self.rejection_sampler = RejectionSampler() def forward( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - if sampling_metadata.spec_token_ids: - if sampling_metadata.max_num_logprobs: - raise NotImplementedError( - "Rejection sampling does not support logprobs.") - return self.rejection_sampler( - logits, - sampling_metadata, - ) - # NOTE(woosuk): Use the original logits (before any penalties or # temperature scaling) for the top-k logprobs. # This is different from the V0 sampler, which uses the logits that @@ -127,6 +116,14 @@ def sample( ) return sampled + def compute_probs(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + if sampling_metadata.all_greedy: + return logits + # Apply temperature. This is an in-place op changing logits. + logits = self.apply_temperature(logits, sampling_metadata.temperature) + return logits.softmax(dim=-1, dtype=torch.float32) + def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: return logits.log_softmax(dim=-1, dtype=torch.float32) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index d9fc53490c07..e4e6b88245d0 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -490,23 +490,12 @@ def _make_sampling_metadata(self) -> SamplingMetadata: presence_penalties=self.presence_penalties[:num_reqs], repetition_penalties=self.repetition_penalties[:num_reqs], output_token_ids=cast(List[List[int]], self.req_output_token_ids), - spec_token_ids=None, min_tokens=self.min_tokens, no_penalties=self.no_penalties, logit_bias=self.logit_bias[:num_reqs], allowed_token_ids_mask=allowed_token_ids_mask, ) - def get_sampling_metadata( - self, - req_id_to_spec_token_ids: Dict[str, List[int]], - ) -> SamplingMetadata: - # Set the new spec token ids in the cached sampling metadata. - self.sampling_metadata.spec_token_ids = [ - req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids - ] if req_id_to_spec_token_ids else None - return self.sampling_metadata - def _make_prompt_token_ids_tensor(self) -> torch.Tensor: max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1fbce3098a34..4d0ae9a205a1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -32,7 +32,7 @@ KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID +from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -122,7 +122,7 @@ def __init__( self.use_spec_decode = False if self.speculative_config: self.use_spec_decode = True - + self.rejection_sampler = RejectionSampler() # TODO: find a better way to check if we are using ngram. assert self.speculative_config.ngram_prompt_lookup_min, \ "Currently, only ngram spec decode is supported in V1." @@ -951,12 +951,24 @@ def execute_model( logits = self.model.compute_logits(sample_hidden_states, None) # Sample the next token and get logprobs if needed. - sampling_metadata = self.input_batch.get_sampling_metadata( - scheduler_output.scheduled_spec_decode_tokens) - sampler_output = self.model.sample( - logits=logits, - sampling_metadata=sampling_metadata, - ) + sampling_metadata = self.input_batch.sampling_metadata + if not self.use_spec_decode: + sampler_output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + target_probs = self.model.sampler.compute_probs( + logits, sampling_metadata) + scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys( + ) + draft_token_ids = [ + scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + for req_id in scheduled_request_ids + ] + sampler_output = self.rejection_sampler(draft_token_ids, + target_probs, + sampling_metadata) # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. @@ -1293,7 +1305,6 @@ def profile_run(self) -> None: temperature=dummy_tensors(0.5), all_greedy=False, all_random=False, - spec_token_ids=None, top_p=dummy_tensors(0.9), top_k=dummy_tensors(logits.size(1) - 1), min_p=None, From c1bc1acf6998c203c1032bdd2ded5eb686856bc8 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Wed, 26 Feb 2025 10:44:30 +0800 Subject: [PATCH 234/317] [Misc]Code Cleanup (#13859) Signed-off-by: noemotiovon Co-authored-by: noemotiovon --- vllm/executor/ray_distributed_executor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index bcad274bab49..2908fefc8e7e 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -95,7 +95,6 @@ def _init_executor(self) -> None: self.use_v1 = envs.VLLM_USE_V1 self.pp_locks: Optional[List[asyncio.Lock]] = None - self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER if not self.use_ray_compiled_dag: self.driver_exec_method = make_async( self.driver_worker.execute_method) From db656e20d27ecb0e24fde2dfe67dfad4356cfb4f Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 25 Feb 2025 18:52:03 -0800 Subject: [PATCH 235/317] [Kernel][Build/CI] Bump CUTLASS to 3.8 and add initializers for cutlass epilogues (#13797) --- CMakeLists.txt | 8 +++--- .../epilogue/scaled_mm_epilogues_c2x.hpp | 26 +++++++++-------- .../epilogue/scaled_mm_epilogues_c3x.hpp | 28 ++++++++++--------- 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 82ad7b8819d5..02a60c0e3520 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -266,7 +266,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git # Please keep this in sync with CUTLASS_REVISION line above. - GIT_TAG v3.7.0 + GIT_TAG v3.8.0 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. @@ -321,7 +321,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # CUDA 12.0 or later (and only work on Hopper, 9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - set(SRCS + set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu" @@ -401,7 +401,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # FP4 Archs and flags cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) - set(SRCS + set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" ) @@ -612,7 +612,7 @@ endif() if(VLLM_FLASH_ATTN_SRC_DIR) FetchContent_Declare( - vllm-flash-attn SOURCE_DIR + vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR} BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn ) diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp index ef413e6dd75c..64b7ddae3d2d 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -122,8 +122,8 @@ struct ScaledEpilogue auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; } }; @@ -167,8 +167,8 @@ struct ScaledEpilogueBias auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; } }; @@ -230,9 +230,10 @@ struct ScaledEpilogueBiasAzp auto azp_adj_args = SUPER::template args_from_tensor(azp_adj); - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; } }; @@ -309,11 +310,12 @@ struct ScaledEpilogueBiasAzpToken auto azp_adj_args = SUPER::template args_from_tensor(azp_adj); - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; } }; -}; // namespace vllm::c2x \ No newline at end of file +}; // namespace vllm::c2x diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 583fa3c45511..1a0cd45f4e20 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -146,8 +146,8 @@ struct ScaledEpilogue auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; } }; @@ -193,8 +193,8 @@ struct ScaledEpilogueBias auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; } }; @@ -236,8 +236,8 @@ struct ScaledEpilogueColumnBias auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; } }; @@ -297,9 +297,10 @@ struct ScaledEpilogueBiasAzp auto azp_adj_args = SUPER::template args_from_tensor(azp_adj); - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; } }; @@ -374,10 +375,11 @@ struct ScaledEpilogueBiasAzpToken auto azp_adj_args = SUPER::template args_from_tensor(azp_adj); - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; } }; From 043428f165d929d449ba6c8a6caf1f994ae64846 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 26 Feb 2025 02:53:56 +0000 Subject: [PATCH 236/317] Improve pipeline partitioning (#13839) --- tests/distributed/test_pipeline_partition.py | 24 ++++++++++++++++ vllm/distributed/utils.py | 30 ++++++++++++++------ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/tests/distributed/test_pipeline_partition.py b/tests/distributed/test_pipeline_partition.py index 3ed104820b47..18c5be29c5ce 100644 --- a/tests/distributed/test_pipeline_partition.py +++ b/tests/distributed/test_pipeline_partition.py @@ -34,3 +34,27 @@ def _verify(partition_str, num_layers, pp_size, goldens): # Wrong number of layers with pytest.raises(ValueError): _verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)]) + + +@pytest.mark.parametrize( + "num_hidden_layers,pp_size,pp_rank,indices", + [ + # pp_size 2 + (2, 2, 0, (0, 1)), + (2, 2, 1, (1, 2)), + (3, 2, 0, (0, 2)), + (3, 2, 1, (2, 3)), + # pp_size 3 + (3, 3, 0, (0, 1)), + (3, 3, 1, (1, 2)), + (3, 3, 2, (2, 3)), + (4, 3, 0, (0, 1)), + (4, 3, 1, (1, 3)), + (4, 3, 2, (3, 4)), + (5, 3, 0, (0, 2)), + (5, 3, 1, (2, 4)), + (5, 3, 2, (4, 5)), + ]) +def test_uneven_auto_partition(num_hidden_layers: int, pp_size: int, + pp_rank: int, indices: tuple[int, int]): + assert indices == get_pp_indices(num_hidden_layers, pp_rank, pp_size) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 79f9a84b476f..d6fca4f0221b 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -67,8 +67,17 @@ def split_tensor_along_last_dim( def get_pp_indices(num_hidden_layers: int, pp_rank: int, pp_size: int) -> Tuple[int, int]: """Try to evenly distribute layers across partitions. + If the number of layers is not divisible by the number of partitions, - the last partition will have the remaining layers. + the remaining layers are evenly distributed across all but the last + partition. The last partition is excluded because it often contains an + additional norm layer and we are attempting to balance compute. + + If `pp_size > 2` and the number of remaining layers is + `0 < x <= pp_size - 2` then the remaining layers are evenly distributed + across the middle partitions. The first and last partitions are excluded + because they contain the input and output embeddings respectively and we + are attempting to reduce maximum memory consumption across partitions. """ partition_list_str = envs.VLLM_PP_LAYER_PARTITION if partition_list_str is not None: @@ -84,15 +93,20 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int, if sum(partitions) != num_hidden_layers: raise ValueError( f"{sum(partitions)=} does not match {num_hidden_layers=}.") - start_layer = sum(partitions[:pp_rank]) - end_layer = start_layer + partitions[pp_rank] else: layers_per_partition = num_hidden_layers // pp_size - start_layer = pp_rank * layers_per_partition - end_layer = start_layer + layers_per_partition - - if pp_rank == pp_size - 1: - end_layer = num_hidden_layers + partitions = [layers_per_partition for _ in range(pp_size)] + + if remaining_layers := num_hidden_layers % pp_size: + for i in range(2, remaining_layers + 2): + partitions[-i] += 1 + logger.info("Hidden layers were unevenly partitioned: %s", + ",".join(str(p) for p in partitions)) + logger.info("This can be manually overridden using the " + "VLLM_PP_LAYER_PARTITION environment variable") + + start_layer = sum(partitions[:pp_rank]) + end_layer = start_layer + partitions[pp_rank] return (start_layer, end_layer) From 734cb2ec7c9a8285ffe9d0c54a45e47148a74168 Mon Sep 17 00:00:00 2001 From: Albert Date: Wed, 26 Feb 2025 14:56:19 +0800 Subject: [PATCH 237/317] [Doc] fix the incorrect module path of tensorize_vllm_model (#13863) --- examples/other/tensorize_vllm_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/other/tensorize_vllm_model.py b/examples/other/tensorize_vllm_model.py index 68345e6cb98d..7d11ba51a094 100644 --- a/examples/other/tensorize_vllm_model.py +++ b/examples/other/tensorize_vllm_model.py @@ -27,7 +27,7 @@ To serialize a model, install vLLM from source, then run something like this from the root level of this repository: -python -m examples.offline_inference.tensorize_vllm_model \ +python -m examples.other.tensorize_vllm_model \ --model facebook/opt-125m \ serialize \ --serialized-directory s3://my-bucket \ @@ -47,7 +47,7 @@ To deserialize a model, you can run something like this from the root level of this repository: -python -m examples.offline_inference.tensorize_vllm_model \ +python -m examples.other.tensorize_vllm_model \ --model EleutherAI/gpt-j-6B \ --dtype float16 \ deserialize \ @@ -65,11 +65,11 @@ model-rank-%03d.tensors For more information on the available arguments for serializing, run -`python -m examples.offline_inference.tensorize_vllm_model serialize --help`. +`python -m examples.other.tensorize_vllm_model serialize --help`. Or for deserializing: -`python -m examples.offline_inference.tensorize_vllm_model deserialize --help`. +`python -m examples.other.tensorize_vllm_model deserialize --help`. Once a model is serialized, tensorizer can be invoked with the `LLM` class directly to load models: @@ -90,7 +90,7 @@ In order to see all of the available arguments usable to configure loading with tensorizer that are given to `TensorizerConfig`, run: -`python -m examples.offline_inference.tensorize_vllm_model deserialize --help` +`python -m examples.other.tensorize_vllm_model deserialize --help` under the `tensorizer options` section. These can also be used for deserialization in this example script, although `--tensorizer-uri` and From 021ef73d337af73c23b89f9f96d8f7d3530e58fa Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 25 Feb 2025 22:56:58 -0800 Subject: [PATCH 238/317] [ROCm] Disable chunked prefill/prefix caching when running MLA on non-cuda platforms (#13844) Signed-off-by: Sage Moore --- vllm/attention/backends/mla/common.py | 42 +++++++++++++++++++-------- vllm/config.py | 14 +++++++++ 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 4dd562be3838..225fee8d2a0d 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -232,6 +232,7 @@ from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) from vllm.multimodal import MultiModalPlaceholderMap +from vllm.platforms import current_platform from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down try: @@ -1371,18 +1372,35 @@ def _forward_prefill( v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0) - output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=has_context, - ) + if has_context: + if not current_platform.is_cuda(): + raise NotImplementedError( + "Chunked Prefill for MLA is not currently supported on" + "non-cuda platforms") + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=True, + ) + else: + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + ) if has_context: suffix_output, suffix_lse = output diff --git a/vllm/config.py b/vllm/config.py index 8e1ce87438af..a5d8ee9303d0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3422,6 +3422,20 @@ def __post_init__(self): "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION + if self.model_config and self.model_config.use_mla and \ + not current_platform.is_cuda(): + logger.info( + "MLA is enabled on a non-cuda platform; forcing chunked " + "prefill and prefix caching to be disabled.") + self.scheduler_config.enable_chunked_prefill = False + self.scheduler_config.chunked_prefill_enabled = False + self.scheduler_config.max_num_batched_tokens = max( + self.scheduler_config.max_model_len, + _DEFAULT_MAX_NUM_BATCHED_TOKENS) + + if self.cache_config is not None: + self.cache_config.enable_prefix_caching = False + current_platform.check_and_update_config(self) if not self.instance_id: From 77ca08ee2d843de8b23e40d6acd00622002307d4 Mon Sep 17 00:00:00 2001 From: Seth Kimmel Date: Tue, 25 Feb 2025 22:58:24 -0800 Subject: [PATCH 239/317] [v0][Core] Use xgrammar shared context to avoid copy overhead for offline engine (#13837) Signed-off-by: Seth Kimmel --- .../guided_decoding/xgrammar_decoding.py | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 329b03a573da..e6ba7f5ecc6e 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -3,7 +3,6 @@ # noqa: UP007 from __future__ import annotations -import copy import json import re from dataclasses import dataclass, field @@ -348,5 +347,26 @@ def __call__(self, input_ids: list[int], return scores def clone(self) -> XGrammarLogitsProcessor: - """Deepcopy due to per-sequence state in the matchers""" - return copy.deepcopy(self) + """Create a new instance with shared compiled grammar + but separate state""" + new_processor = XGrammarLogitsProcessor(self.config) + + # Share the compiled grammar context (immutable after compilation) + new_processor.ctx = self.ctx + + # Create fresh matchers for the new sequence + if self.ctx is not None: + new_processor.matchers = [ + xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size) + ] + + # Create a new token bitmask with the same size + if hasattr(self, 'token_bitmask') and self.token_bitmask is not None: + new_processor.token_bitmask = self.token_bitmask + + # Copy simple attributes + new_processor.batch_size = self.batch_size + # Reset prefilled state for new sequence + new_processor.prefilled = False + + return new_processor From 241fa245726992cae3f8f395224d9962f4086d7a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 3 Mar 2025 18:06:03 +0000 Subject: [PATCH 240/317] [Misc] Improve LoRA spelling (#13831) --- benchmarks/kernels/benchmark_lora.py | 2 +- docs/source/features/lora.md | 2 +- tests/core/test_scheduler.py | 2 +- tests/entrypoints/openai/test_cli_args.py | 2 +- .../entrypoints/openai/test_serving_models.py | 20 +++++++++---------- tests/lora/test_layers.py | 18 ++++++++--------- tests/lora/test_long_context.py | 4 ++-- vllm/engine/llm_engine.py | 2 +- vllm/entrypoints/openai/api_server.py | 10 +++++----- vllm/entrypoints/openai/protocol.py | 4 ++-- vllm/entrypoints/openai/serving_models.py | 14 ++++++------- vllm/lora/fully_sharded_layers.py | 12 +++++------ vllm/lora/layers.py | 8 ++++---- vllm/lora/models.py | 10 +++++----- vllm/lora/peft_helper.py | 2 +- vllm/lora/punica_wrapper/punica_base.py | 2 +- vllm/lora/utils.py | 18 ++++++++--------- vllm/spec_decode/proposer_worker_base.py | 4 ++-- vllm/spec_decode/spec_decode_worker.py | 4 ++-- vllm/transformers_utils/configs/arctic.py | 2 +- vllm/worker/neuron_worker.py | 4 ++-- vllm/worker/openvino_worker.py | 4 ++-- vllm/worker/worker_base.py | 2 +- vllm/worker/xpu_worker.py | 4 ++-- 24 files changed, 78 insertions(+), 78 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index ecde8fbaa15b..1deb0026a6e5 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -89,7 +89,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str) -> torch.Tensor: """ - All prompts are mapped to a Lora ID in range [0, num_active_loras). + All prompts are mapped to a LoRA ID in range [0, num_active_loras). where 0 refers to first lora, 1 refers to second lora and so on. """ assert num_active_loras > 0 diff --git a/docs/source/features/lora.md b/docs/source/features/lora.md index fb5a7a0d519c..dff7e916fb46 100644 --- a/docs/source/features/lora.md +++ b/docs/source/features/lora.md @@ -170,7 +170,7 @@ Now, you can specify a base_model_name alongside the name and path using JSON fo To provide the backward compatibility support, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case. -## Lora model lineage in model card +## LoRA model lineage in model card The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this: diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index dcc97ebaa7c5..66bc5257f081 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -491,7 +491,7 @@ def test_prefill_schedule_max_lora(): lora_path="abc")) scheduler.add_seq_group(seq_group) # Add two more requests to verify lora is prioritized. - # 0: Lora, 1: Lora, 2: regular, 3: regular + # 0: LoRA, 1: LoRA, 2: regular, 3: regular # In the first iteration, index 0, 2 is scheduled. # If a request is not scheduled because it hits max lora, it is # prioritized. Verify that. diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index 2f065ec1070e..e0285b5e5566 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -26,7 +26,7 @@ def serve_parser(): return make_arg_parser(parser) -### Tests for Lora module parsing +### Tests for LoRA module parsing def test_valid_key_value_format(serve_parser): # Test old format: name=path args = serve_parser.parse_args([ diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index 55900163eef5..e8f3c2f8b39e 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -8,8 +8,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import (ErrorResponse, - LoadLoraAdapterRequest, - UnloadLoraAdapterRequest) + LoadLoRAAdapterRequest, + UnloadLoRAAdapterRequest) from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) from vllm.lora.request import LoRARequest @@ -51,7 +51,7 @@ async def test_serving_model_name(): @pytest.mark.asyncio async def test_load_lora_adapter_success(): serving_models = await _async_serving_models_init() - request = LoadLoraAdapterRequest(lora_name="adapter", + request = LoadLoRAAdapterRequest(lora_name="adapter", lora_path="/path/to/adapter2") response = await serving_models.load_lora_adapter(request) assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') @@ -62,7 +62,7 @@ async def test_load_lora_adapter_success(): @pytest.mark.asyncio async def test_load_lora_adapter_missing_fields(): serving_models = await _async_serving_models_init() - request = LoadLoraAdapterRequest(lora_name="", lora_path="") + request = LoadLoRAAdapterRequest(lora_name="", lora_path="") response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" @@ -72,14 +72,14 @@ async def test_load_lora_adapter_missing_fields(): @pytest.mark.asyncio async def test_load_lora_adapter_duplicate(): serving_models = await _async_serving_models_init() - request = LoadLoraAdapterRequest(lora_name="adapter1", + request = LoadLoRAAdapterRequest(lora_name="adapter1", lora_path="/path/to/adapter1") response = await serving_models.load_lora_adapter(request) assert response == LORA_LOADING_SUCCESS_MESSAGE.format( lora_name='adapter1') assert len(serving_models.lora_requests) == 1 - request = LoadLoraAdapterRequest(lora_name="adapter1", + request = LoadLoRAAdapterRequest(lora_name="adapter1", lora_path="/path/to/adapter1") response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) @@ -91,12 +91,12 @@ async def test_load_lora_adapter_duplicate(): @pytest.mark.asyncio async def test_unload_lora_adapter_success(): serving_models = await _async_serving_models_init() - request = LoadLoraAdapterRequest(lora_name="adapter1", + request = LoadLoRAAdapterRequest(lora_name="adapter1", lora_path="/path/to/adapter1") response = await serving_models.load_lora_adapter(request) assert len(serving_models.lora_requests) == 1 - request = UnloadLoraAdapterRequest(lora_name="adapter1") + request = UnloadLoRAAdapterRequest(lora_name="adapter1") response = await serving_models.unload_lora_adapter(request) assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format( lora_name='adapter1') @@ -106,7 +106,7 @@ async def test_unload_lora_adapter_success(): @pytest.mark.asyncio async def test_unload_lora_adapter_missing_fields(): serving_models = await _async_serving_models_init() - request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None) + request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None) response = await serving_models.unload_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" @@ -116,7 +116,7 @@ async def test_unload_lora_adapter_missing_fields(): @pytest.mark.asyncio async def test_unload_lora_adapter_not_found(): serving_models = await _async_serving_models_init() - request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter") + request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter") response = await serving_models.unload_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "NotFoundError" diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 0838ca02c9b7..61699e7052c9 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -14,16 +14,16 @@ from vllm.lora.fully_sharded_layers import ( ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA) # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - LinearScalingRotaryEmbeddingWithLora, + LinearScalingRotaryEmbeddingWithLoRA, LogitsProcessorWithLoRA, LoRAMapping, MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLora, - QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) @@ -866,9 +866,9 @@ def create_column_parallel_packed_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedQKVParallelLinearWithLora(linear) + lora_linear = (MergedQKVParallelLinearWithLoRA(linear) if not fully_shard else - MergedQKVParallelLinearWithShardedLora(linear)) + MergedQKVParallelLinearWithShardedLoRA(linear)) else: linear = QKVParallelLinear(4096, 64, @@ -876,9 +876,9 @@ def create_column_parallel_packed_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = QKVParallelLinearWithLora( + lora_linear = QKVParallelLinearWithLoRA( linear - ) if not fully_shard else QKVParallelLinearWithShardedLora(linear) + ) if not fully_shard else QKVParallelLinearWithShardedLoRA(linear) @dataclass class FakeConfig: @@ -1024,7 +1024,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, base, is_neox_style, ) - lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) + lora_rope = LinearScalingRotaryEmbeddingWithLoRA(rope) lora_rope.set_mapping(punica_wrapper) lora_rope.create_lora_weights(max_loras, lora_config) linear_rope = get_rope(head_size, rotary_dim, max_position, base, diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 62005de73ddb..0a94298c9f77 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -8,7 +8,7 @@ import vllm from vllm import SamplingParams -from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora +from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLoRA from vllm.lora.request import LoRARequest from vllm.model_executor.layers.rotary_embedding import ( LinearScalingRotaryEmbedding) @@ -151,7 +151,7 @@ def test_rotary_emb_replaced(dist_init): if "rotary_emb" in module_name: if "base_layer" not in module_name: rotary_emb_count += 1 - assert isinstance(module, LinearScalingRotaryEmbeddingWithLora) + assert isinstance(module, LinearScalingRotaryEmbeddingWithLoRA) else: assert isinstance(module, LinearScalingRotaryEmbedding) # Llama 2 has 32 layers. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3ce9a0461368..1690017f924c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1629,7 +1629,7 @@ def _get_stats(self, max_tokens_requests: List[int] = [] finished_reason_requests: List[str] = [] - # Lora requests + # LoRA requests running_lora_adapters = dict( collectionsCounter([ running_request.lora_request.lora_name diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 9995951b3f3d..1b65484c446a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -53,7 +53,7 @@ EmbeddingResponse, EmbeddingResponseData, ErrorResponse, - LoadLoraAdapterRequest, + LoadLoRAAdapterRequest, PoolingChatRequest, PoolingCompletionRequest, PoolingRequest, PoolingResponse, @@ -63,7 +63,7 @@ TokenizeResponse, TranscriptionRequest, TranscriptionResponse, - UnloadLoraAdapterRequest) + UnloadLoRAAdapterRequest) from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat @@ -690,12 +690,12 @@ async def stop_profile(raw_request: Request): if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: logger.warning( - "Lora dynamic loading & unloading is enabled in the API server. " + "LoRA dynamic loading & unloading is enabled in the API server. " "This should ONLY be used for local development!") @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)]) - async def load_lora_adapter(request: LoadLoraAdapterRequest, + async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request): handler = models(raw_request) response = await handler.load_lora_adapter(request) @@ -707,7 +707,7 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest, @router.post("/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]) - async def unload_lora_adapter(request: UnloadLoraAdapterRequest, + async def unload_lora_adapter(request: UnloadLoRAAdapterRequest, raw_request: Request): handler = models(raw_request) response = await handler.unload_lora_adapter(request) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index cd2902f934bf..31214211cfc4 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1431,12 +1431,12 @@ class DetokenizeResponse(OpenAIBaseModel): prompt: str -class LoadLoraAdapterRequest(BaseModel): +class LoadLoRAAdapterRequest(BaseModel): lora_name: str lora_path: str -class UnloadLoraAdapterRequest(BaseModel): +class UnloadLoRAAdapterRequest(BaseModel): lora_name: str lora_int_id: Optional[int] = Field(default=None) diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index 6ade4ece6d03..0f4a174a8c15 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -9,10 +9,10 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import (ErrorResponse, - LoadLoraAdapterRequest, + LoadLoRAAdapterRequest, ModelCard, ModelList, ModelPermission, - UnloadLoraAdapterRequest) + UnloadLoRAAdapterRequest) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest @@ -88,7 +88,7 @@ async def init_static_loras(self): if self.static_lora_modules is None: return for lora in self.static_lora_modules: - load_request = LoadLoraAdapterRequest(lora_path=lora.path, + load_request = LoadLoRAAdapterRequest(lora_path=lora.path, lora_name=lora.name) load_result = await self.load_lora_adapter( request=load_request, base_model_name=lora.base_model_name) @@ -140,7 +140,7 @@ async def show_available_models(self) -> ModelList: async def load_lora_adapter( self, - request: LoadLoraAdapterRequest, + request: LoadLoRAAdapterRequest, base_model_name: Optional[str] = None ) -> Union[ErrorResponse, str]: error_check_ret = await self._check_load_lora_adapter_request(request) @@ -177,7 +177,7 @@ async def load_lora_adapter( async def unload_lora_adapter( self, - request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: + request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]: error_check_ret = await self._check_unload_lora_adapter_request(request ) if error_check_ret is not None: @@ -192,7 +192,7 @@ async def unload_lora_adapter( return f"Success: LoRA adapter '{lora_name}' removed successfully." async def _check_load_lora_adapter_request( - self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: + self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]: # Check if both 'lora_name' and 'lora_path' are provided if not request.lora_name or not request.lora_path: return create_error_response( @@ -214,7 +214,7 @@ async def _check_load_lora_adapter_request( async def _check_unload_lora_adapter_request( self, - request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]: + request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]: # Check if either 'lora_name' or 'lora_int_id' is provided if not request.lora_name and not request.lora_int_id: return create_error_response( diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index 3d6620817b4b..41e1ec94145d 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -13,8 +13,8 @@ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank from vllm.lora.layers import (ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLora, - QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA) if TYPE_CHECKING: @@ -167,9 +167,9 @@ def can_replace_layer( ) -class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): +class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): """ - Differs from QKVParallelLinearWithLora by slicing the + Differs from QKVParallelLinearWithLoRA by slicing the LoRA A's also. Based on S-LoRA, slicing happens along the rank dim. @@ -202,9 +202,9 @@ def can_replace_layer(cls, source_layer: nn.Module, ) -class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): +class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): """ - Differs from MergedQKVParallelLinearWithLora by slicing the + Differs from MergedQKVParallelLinearWithLoRA by slicing the LoRA A's also. Based on S-LoRA, slicing happens along the rank dim. diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 5e700d2e10d2..7b718458c70d 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -368,7 +368,7 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], lora_bias: Optional[torch.Tensor] = None, ): - # Except for QKVParallelLinearWithLora and + # Except for QKVParallelLinearWithLoRA and # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers # store weights in a tuple of size 1. These two layers will # override this function. @@ -693,7 +693,7 @@ def can_replace_layer( and len(packed_modules_list) == 2) -class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ ColumnParallelLinear layer that is specifically designed for qkv_proj. Certain models, such as chatglm3 and baichuan-7b, @@ -761,7 +761,7 @@ def can_replace_layer(cls, source_layer: nn.Module, packed_modules_list) == 1 -class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA): +class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) packed together in qkv proj fashion (q_proj + k_proj + v_proj -> qkv_proj). @@ -1135,7 +1135,7 @@ def can_replace_layer( return False -class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): +class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): """Implements RoPE-scaled embeddings with linear scaling for multiple LoRA adapters with a specialized kernel. diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 774c3876e774..e1294884ac2a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -20,7 +20,7 @@ from vllm.config import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import (BaseLayerWithLoRA, - LinearScalingRotaryEmbeddingWithLora, + LinearScalingRotaryEmbeddingWithLoRA, LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.peft_helper import PEFTHelper @@ -201,7 +201,7 @@ def from_local_checkpoint( expected_lora_modules: Name of modules that are expected to be replaced by lora. peft_helper: Loaded lora configuration information. - lora_model_id: Lora model id. If not given, automatically set by + lora_model_id: LoRA model id. If not given, automatically set by a global counter. device: Device where the lora model is loaded. dtype: dtype of the lora model weights. @@ -480,9 +480,9 @@ def _create_lora_modules(self): from_layer(module, self.lora_slots, self.lora_config, packed_moduled_lst, self.model.config)) - # LinearScalingRotaryEmbeddingWithLora is used to handle + # LinearScalingRotaryEmbeddingWithLoRA is used to handle # long context lora. Register relevant metadata. - if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): + if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA): self.long_lora_context = LongContextLoRAContext( new_module.scaling_factors, new_module.rotary_dim) self.scaling_factor_to_offset = \ @@ -527,7 +527,7 @@ def create_dummy_lora( bias_enabled = self.lora_config.bias_enabled if (not self._match_target_modules(module_name) or not isinstance(module, BaseLayerWithLoRA) - or isinstance(module, LinearScalingRotaryEmbeddingWithLora) + or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA) or self._filter_unsupported_mm_module(module_name)): continue parts = module_name.split(".") diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index 9496ab5a75c0..f6944368b36e 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -42,7 +42,7 @@ class PEFTHelper: def _validate_features(self) -> List[str]: """ - Check if there are any unsupported Lora features. + Check if there are any unsupported LoRA features. """ error_msg = [] if self.modules_to_save: diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index d160b2739bc7..0332867055b7 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -314,7 +314,7 @@ def embeddings_indices(self) -> torch.Tensor: def long_lora_indices(self) -> torch.Tensor: """ This property provides access to the indices used for long context - lora, specifically for LinearScalingRotaryEmbeddingWithLora. + lora, specifically for LinearScalingRotaryEmbeddingWithLoRA. """ long_lora_len = self.indices_len[4] return self._long_lora_indices[:long_lora_len] diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 361dac5b3313..63b465fdf743 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -15,17 +15,17 @@ from vllm.lora.fully_sharded_layers import ( ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA) # being imported for _all_lora_classes below # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - LinearScalingRotaryEmbeddingWithLora, + LinearScalingRotaryEmbeddingWithLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLora, - QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) @@ -41,17 +41,17 @@ VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, - QKVParallelLinearWithLora, - MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, RowParallelLinearWithLoRA, ReplicatedLinearWithLoRA, LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, - QKVParallelLinearWithShardedLora, + QKVParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLora, + MergedQKVParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA, - LinearScalingRotaryEmbeddingWithLora, + LinearScalingRotaryEmbeddingWithLoRA, } diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py index 2bebf80fadae..2829d631b49e 100644 --- a/vllm/spec_decode/proposer_worker_base.py +++ b/vllm/spec_decode/proposer_worker_base.py @@ -6,10 +6,10 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposer -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import LoRANotSupportedWorkerBase -class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer): +class ProposerWorkerBase(LoRANotSupportedWorkerBase, SpeculativeProposer): """Interface for proposer workers""" @abstractmethod diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 8af71842224b..871a3aee6306 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -47,7 +47,7 @@ get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) from vllm.utils import resolve_obj_by_qualname -from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase +from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase logger = init_logger(__name__) @@ -118,7 +118,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": # Reminder: Please update docs/source/features/compatibility_matrix.md # If the feature combo become valid -class SpecDecodeWorker(LoraNotSupportedWorkerBase): +class SpecDecodeWorker(LoRANotSupportedWorkerBase): """Worker which implements speculative decoding. Speculative decoding reduces decoding per-token latency by using a proposal diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py index 6625ccf0f2a8..5ab70c0e4136 100644 --- a/vllm/transformers_utils/configs/arctic.py +++ b/vllm/transformers_utils/configs/arctic.py @@ -21,7 +21,7 @@ @dataclass -class ArcticLoraConfig: +class ArcticLoRAConfig: lora_r: int = 64 lora_alpha: float = 16 shard_base_weights: bool = False diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 95e7acd025f0..df651e05a7bb 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -13,11 +13,11 @@ from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerBase, + LoRANotSupportedWorkerBase, WorkerBase, WorkerInput) -class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): +class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): """A worker class that executes the model on a group of neuron cores. """ diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index 1ad66e6f3be7..fad91270ea2a 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -24,7 +24,7 @@ from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.utils import bind_kv_cache from vllm.worker.openvino_model_runner import OpenVINOModelRunner -from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase +from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase logger = init_logger(__name__) @@ -203,7 +203,7 @@ def get_cache_block_size( return dtype_size * total -class OpenVINOWorker(LoraNotSupportedWorkerBase): +class OpenVINOWorker(LoRANotSupportedWorkerBase): """A worker class that executes the model on OpenVINO backend. Each worker is associated with a single OpenVINO device. The worker is diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 445c0d3285bf..7cc1562a5bce 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -189,7 +189,7 @@ def __getattr__(self, attr): return getattr(self.worker, attr) -class LoraNotSupportedWorkerBase(WorkerBase): +class LoRANotSupportedWorkerBase(WorkerBase): """Partial implementation of WorkerBase that raises exceptions when LoRA methods are invoked. """ diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 047c0bbbc355..3aea0d7419d0 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -18,13 +18,13 @@ from vllm.platforms import current_platform from vllm.worker.cache_engine import CacheEngine from vllm.worker.worker import Worker -from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase +from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase from vllm.worker.xpu_model_runner import XPUModelRunner logger = init_logger(__name__) -class XPUWorker(LoraNotSupportedWorkerBase, Worker): +class XPUWorker(LoRANotSupportedWorkerBase, Worker): """A worker class that executes (a partition of) the model on a GPU. Each worker is associated with a single XPU device. The worker is From 101ff850efb7c4ed341d4b2b47b3138d7a5d58ab Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Tue, 25 Feb 2025 23:56:34 -0800 Subject: [PATCH 241/317] [Misc] Fix input processing for Ultravox (#13871) --- tests/models/multimodal/processing/test_common.py | 6 +++--- tests/models/registry.py | 2 +- vllm/model_executor/models/ultravox.py | 13 ++----------- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 0115863f5626..a84999cfbf4f 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -83,8 +83,8 @@ def _test_processing_correctness( } tokenizer_encode_kwargs = {} - if model_config.hf_config.model_type in ("mllama", "whisper"): - # For some encoder-decoder models, tokenizer will always add bos_token + if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"): + # For some multimodal models, tokenizer will always add bos_token # at the beginning of prompt by default, causing hf_processor outputs # incorrect token ids. So we need use `add_special_tokens=False` here # to leave bos_token to be added by the processor. @@ -172,7 +172,7 @@ def _test_processing_correctness( "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", - "fixie-ai/ultravox-v0_5-llama-3_2-1b", + "fixie-ai/ultravox-v0_4", "openai/whisper-large-v3", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) diff --git a/tests/models/registry.py b/tests/models/registry.py index 566a4418feb1..b47eaef30bf2 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -284,7 +284,7 @@ def check_available_online( "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 min_transformers_version="4.49"), # noqa: E501 - "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", + "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_4", trust_remote_code=True), # [Encoder-decoder] "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 1dbba3c50b19..b8d4aef252e5 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -146,7 +146,8 @@ def _call_hf_processor( ) -> BatchFeature: # Text-only input not supported in composite processor if not mm_data or not mm_data.get("audios", []): - prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self.info.get_tokenizer().encode( + prompt, add_special_tokens=False) prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") @@ -185,16 +186,6 @@ def _call_hf_processor( ) return BatchFeature(combined_outputs) - def _apply_hf_processor_tokens_only( - self, - prompt_tokens: list[int], - ) -> list[int]: - # HF processor omits bos_token_id by setting add_special_tokens=False - tokenizer = self.info.get_tokenizer() - assert prompt_tokens[0] == tokenizer.bos_token_id - - return prompt_tokens[1:] - def _get_mm_fields_config( self, hf_inputs: BatchFeature, From 9001578c2d422b41600fd620a0fb2c4ce37cb644 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 26 Feb 2025 18:31:43 +0800 Subject: [PATCH 242/317] [Bugfix] Add test example for Ultravox v0.5 (#13890) --- tests/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index b47eaef30bf2..8614baf18f3b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -285,6 +285,7 @@ def check_available_online( "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 min_transformers_version="4.49"), # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_4", + extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"}, # noqa: E501 trust_remote_code=True), # [Encoder-decoder] "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 From b8fe8c1c4330e65f938996b1f333b0ad71ce9f4a Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 26 Feb 2025 10:41:02 +0000 Subject: [PATCH 243/317] Add comments on accessing `kv_cache` and `attn_metadata` (#13887) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/attention/layer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 24f2a6372b45..c45c83a0707f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -47,6 +47,10 @@ def __init__( attn_type: str = AttentionType.DECODER, **extra_impl_args, ) -> None: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + """ super().__init__() if per_layer_sliding_window is not None: # per-layer sliding window @@ -155,6 +159,15 @@ def forward( key: torch.Tensor, value: torch.Tensor, ) -> torch.Tensor: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + + Attention metadata (`attn_metadata`) is set using a context manager in + the model runner's `execute_model` method. It is accessed via forward + context using + `vllm.forward_context.get_forward_context().attn_metadata`. + """ if self.calculate_kv_scales: attn_metadata = get_forward_context().attn_metadata if attn_metadata.enable_kv_scales_calculation: From 71fa6b57ae6ecc80497023a374447fb08ed2c4c7 Mon Sep 17 00:00:00 2001 From: Florian Greinacher Date: Wed, 26 Feb 2025 12:06:21 +0100 Subject: [PATCH 244/317] [Bugfix] Handle None parameters in Mistral function calls. (#13786) --- tests/tokenization/test_mistral_tokenizer.py | 35 ++++++++++++++++++- vllm/transformers_utils/tokenizers/mistral.py | 3 +- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py index 03e1f1fadd73..f1c880286951 100644 --- a/tests/tokenization/test_mistral_tokenizer.py +++ b/tests/tokenization/test_mistral_tokenizer.py @@ -41,7 +41,40 @@ ) ], ), - )], + ), + ( + { + "messages": + [{ + "role": "user", + "content": "What is the current local date and time?", + }], + "tools": [{ + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": None, + }, + }], + }, + ChatCompletionRequest( + messages=[ + UserMessage( + content="What is the current local date and time?") + ], + tools=[ + Tool( + type="function", + function=Function( + name="get_current_time", + description="Fetch the current local date and time.", + parameters={}, + ), + ) + ], + ), + )], ) def test_make_mistral_chat_completion_request(openai_request, expected_mistral_request): diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 4e76f2dc871b..801597bd3650 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -164,7 +164,8 @@ def make_mistral_chat_completion_request( tool["function"] for tool in tools if tool["type"] == "function" ]: - function.setdefault("parameters", {}) + if function.get("parameters") is None: + function["parameters"] = {} from mistral_common.protocol.instruct.request import ChatCompletionRequest return ChatCompletionRequest(messages=messages, From 87b6aeb00328a2c03628dd33e693f5ffd395db73 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Wed, 26 Feb 2025 06:06:47 -0500 Subject: [PATCH 245/317] [Misc]: Add support for goodput on guided benchmarking + TPOT calculation refactor (#13736) Signed-off-by: Brayden Zhong --- benchmarks/benchmark_serving_guided.py | 87 ++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmark_serving_guided.py b/benchmarks/benchmark_serving_guided.py index 04942b06ffd5..05eadff79787 100644 --- a/benchmarks/benchmark_serving_guided.py +++ b/benchmarks/benchmark_serving_guided.py @@ -9,7 +9,7 @@ ./launch_tgi_server.sh On the client side, run: - python benchmarks/benchmark_serving.py \ + python benchmarks/benchmark_serving_guided.py \ --backend \ --model \ --dataset json \ @@ -31,7 +31,7 @@ import time import warnings from dataclasses import dataclass -from typing import AsyncGenerator, List, Optional, Tuple +from typing import AsyncGenerator, Dict, List, Optional, Tuple import datasets import numpy as np @@ -264,6 +264,7 @@ def calculate_metrics( tokenizer: PreTrainedTokenizerBase, selected_percentile_metrics: List[str], selected_percentiles: List[float], + goodput_config_dict: Optional[Dict[str, float]] = None, ) -> Tuple[BenchmarkMetrics, List[int]]: actual_output_lens: List[int] = [] total_input = 0 @@ -287,10 +288,10 @@ def calculate_metrics( total_input += input_requests[i].prompt_len tpot = 0 if output_len > 1: - tpot = (outputs[i].latency - outputs[i].ttft) / (output_len - - 1) + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) tpots.append(tpot) - outputs[i].tpot = sum(tpots) / len(tpots) if len(tpots) else 0 + outputs[i].tpot = tpot # Note: if output_len <= 1, we regard tpot as 0 for goodput all_tpots.append(tpot) itls += outputs[i].itl @@ -300,6 +301,28 @@ def calculate_metrics( else: actual_output_lens.append(0) + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(goodput_config_dict["ttft"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in goodput_config_dict: + valid_metrics.append(all_tpots) + slo_values.append(goodput_config_dict["tpot"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append(goodput_config_dict["e2el"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + if completed == 0: warnings.warn( "All requests failed. This is likely due to a misconfiguration " @@ -356,6 +379,7 @@ async def benchmark( max_concurrency: Optional[int], guided_decoding_ratio: float, guided_decoding_backend: str, + goodput_config_dict: Optional[Dict[str, float]] = None, ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -483,6 +507,7 @@ async def limited_request_func(request_func_input, pbar): tokenizer=tokenizer, selected_percentile_metrics=selected_percentile_metrics, selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, ) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) @@ -494,6 +519,9 @@ async def limited_request_func(request_func_input, pbar): metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) + if goodput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", + metrics.request_goodput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", @@ -617,6 +645,40 @@ def _eval_correctness(expected, actual): 100) if len(not_none_scores) > 0 else None +def parse_goodput(slo_pairs): + goodput_config_dict = {} + try: + for slo_pair in slo_pairs: + slo_name, slo_val = slo_pair.split(":") + goodput_config_dict[slo_name] = float(slo_val) + except ValueError as err: + raise argparse.ArgumentTypeError( + "Invalid format found for service level objectives. " + "Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is a " + "number in milliseconds.") from err + return goodput_config_dict + + +def check_goodput_args(args): + goodput_config_dict = {} + VALID_NAMES = ["ttft", "tpot", "e2el"] + if args.goodput: + goodput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in goodput_config_dict.items(): + if slo_name not in VALID_NAMES: + raise ValueError( + f"Invalid metric name found, {slo_name}: {slo_val}. " + "The service level objective name should be one of " + f"{str(VALID_NAMES)}. ") + if slo_val < 0: + raise ValueError( + f"Invalid value found, {slo_name}: {slo_val}. " + "The service level objective value should be " + "non-negative.") + return goodput_config_dict + + def main(args: argparse.Namespace): print(args) random.seed(args.seed) @@ -661,6 +723,8 @@ def main(args: argparse.Namespace): input_requests = sample_requests(tokenizer, args) + goodput_config_dict = check_goodput_args(args) + benchmark_result, ret = asyncio.run( benchmark( backend=backend, @@ -681,6 +745,7 @@ def main(args: argparse.Namespace): max_concurrency=args.max_concurrency, guided_decoding_ratio=args.guided_decoding_ratio, guided_decoding_backend=args.guided_decoding_backend, + goodput_config_dict=goodput_config_dict, )) # Save config and results to json @@ -865,6 +930,18 @@ def main(args: argparse.Namespace): "Default value is \"99\". " "Use \"--percentile-metrics\" to select metrics.", ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help="Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is in " + "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + "separated by spaces. Allowed request level metric names are " + "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve") + parser.add_argument("--no-guided-decoding", action='store_true', default=False, From ddd560fba60751f08ec2358d99567cc6b673daa8 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 26 Feb 2025 04:07:29 -0700 Subject: [PATCH 246/317] [Bugfix] Do not crash V0 engine on input errors (#13101) Signed-off-by: Joe Runde --- tests/mq_llm_engine/test_error_handling.py | 78 ++++++++++++++++++++++ vllm/engine/llm_engine.py | 62 ++++++++++++++++- vllm/engine/multiprocessing/engine.py | 9 +++ vllm/worker/model_runner.py | 11 ++- vllm/worker/model_runner_base.py | 18 +++++ 5 files changed, 172 insertions(+), 6 deletions(-) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 35d001781110..aad7fc5303c1 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -18,6 +18,7 @@ from vllm.entrypoints.openai.api_server import build_async_engine_client from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.lora.request import LoRARequest +from vllm.sequence import SequenceGroupMetadata from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser @@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket): await client.check_health() client.close() + + +def run_with_evil_input_processing(engine_args: AsyncEngineArgs, + ipc_path: str): + """Simulate an exception while preparing inputs for the model. + In the wild, this could be something like a multimodal input processor + failing on invalid image data.""" + + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + runner = engine.engine.model_executor.driver_worker.worker.model_runner + + # Raise error in the model runner when adding a sequence group. + # See class ModelInputForGPUBuilder + def raiser(_, seq_group_metadata: SequenceGroupMetadata): + if seq_group_metadata.request_id.startswith("evil"): + raise RAISED_ERROR(RAISED_VALUE) + + runner.builder.per_seq_group_compute_fns.append(raiser) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_failed_inputs(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_input_processing) as engine: + + client = await engine.make_client() + assert client.is_running + + # Engine should be healthy + await client.check_health() + + async def run_failing_request(): + async for _ in client.generate( + prompt="Hello my name is", + sampling_params=SamplingParams(max_tokens=10), + request_id="evil" + str(uuid.uuid4())): + pass + + async def run_passing_request(): + async for _ in client.generate( + prompt="Hello my name is", + sampling_params=SamplingParams(max_tokens=10), + request_id=str(uuid.uuid4())): + pass + + passing_tasks = [ + asyncio.create_task(run_passing_request()) for _ in range(10) + ] + failing_tasks = [ + asyncio.create_task(run_failing_request()) for _ in range(10) + ] + await asyncio.gather(*failing_tasks, return_exceptions=True) + await asyncio.gather(*passing_tasks) + + # All the bad inputs should have raised + for task in failing_tasks: + with pytest.raises(RAISED_ERROR): + task.result() + + # But all good inputs should have still succeeded + for task in passing_tasks: + task.result() + + # And the engine should remain healthy + assert not client.errored + await client.check_health() + + client.close() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1690017f924c..3dee4dab4c47 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -60,6 +60,7 @@ from vllm.utils import (Counter, Device, deprecate_kwargs, resolve_obj_by_qualname, weak_bind) from vllm.version import __version__ as VLLM_VERSION +from vllm.worker.model_runner_base import InputProcessingError logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -410,6 +411,10 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} + # Flag to set when an input fails to process and the engine should run + # the next step without re-scheduling. + self._skip_scheduling_next_step = False + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1334,7 +1339,11 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: # Skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. - if not self._has_remaining_steps(seq_group_metadata_list): + # The scheduler is also skipped if a single request caused the last + # engine step to fail, and the previous schedule needs to be rerun. + if not self._has_remaining_steps( + seq_group_metadata_list + ) and not self._skip_scheduling_next_step: # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc @@ -1388,8 +1397,23 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: execute_model_req.async_callback = self.async_callbacks[ virtual_engine] - outputs = self.model_executor.execute_model( - execute_model_req=execute_model_req) + try: + outputs = self.model_executor.execute_model( + execute_model_req=execute_model_req) + self._skip_scheduling_next_step = False + except InputProcessingError as e: + # The input for this request cannot be processed, so we must + # abort it. If there are remaining requests in the batch that + # have been scheduled, they will be retried on the next step. + invalid_request_id = e.request_id + self._abort_and_cache_schedule( + request_id=invalid_request_id, + virtual_engine=virtual_engine, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + allow_async_output_proc=allow_async_output_proc) + # Raise so the caller is notified that this request failed + raise # We need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. @@ -1464,6 +1488,38 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: return ctx.request_outputs + def _abort_and_cache_schedule( + self, request_id: str, virtual_engine: int, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + """Aborts a single request, and caches the scheduler outputs minus that + request. This allows the next step to continue processing the remaining + requests without having to re-run the scheduler.""" + + # Abort the request and remove its sequence group from the current + # schedule + self.abort_request(request_id) + for i, metadata in enumerate(seq_group_metadata_list): + if metadata.request_id == request_id: + del seq_group_metadata_list[i] + break + for i, group in enumerate(scheduler_outputs.scheduled_seq_groups): + if group.seq_group.request_id == request_id: + del scheduler_outputs.scheduled_seq_groups[i] + break + + # If there are still other sequence groups left in the schedule, cache + # them and flag the engine to reuse the schedule. + if len(seq_group_metadata_list) > 0: + self._skip_scheduling_next_step = True + # Reuse multi-step caching logic + self._cache_scheduler_outputs_for_multi_step( + virtual_engine=virtual_engine, + scheduler_outputs=scheduler_outputs, + seq_group_metadata_list=seq_group_metadata_list, + allow_async_output_proc=allow_async_output_proc) + def _has_remaining_steps( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] ) -> bool: diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index ce24aa21514d..efea6ee2c69a 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -27,6 +27,7 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext +from vllm.worker.model_runner_base import InputProcessingError logger = init_logger(__name__) @@ -210,6 +211,14 @@ def engine_step(self) -> List[RequestOutput]: return self.engine.step() except SystemExit: raise + except InputProcessingError as e: + # Special case where we handle an error preparing the inputs for + # a single request in the batch + rpc_err = RPCError(request_id=e.request_id, + is_engine_errored=False, + exception=e.__cause__) + self._send_outputs(rpc_err) + return [] except BaseException as e: self._set_errored(e) rpc_err = RPCError(request_id=None, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 86dcde234f86..a37a3168bbbc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -53,8 +53,8 @@ is_pin_memory_available, supports_dynamo, weak_ref_tensor) from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, - _add_attn_metadata_broadcastable_dict, + InputProcessingError, ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) @@ -1216,7 +1216,12 @@ def _prepare_model_input_tensors( """ self.builder.prepare(finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: - self.builder.add_seq_group(seq_group_metadata) + try: + self.builder.add_seq_group(seq_group_metadata) + except Exception as e: + # Raise an exception that tracks the ID of the bad request + raise InputProcessingError(seq_group_metadata.request_id, + str(e)) from e self.builder.reset_cached_inter_data() diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index bae37cb7155f..935325cb2e1c 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -261,3 +261,21 @@ def __init__( def __getattr__(self, attr): return getattr(self.model_runner, attr) + + +class InputProcessingError(Exception): + """This exception is raised when an error occurs preparing the inputs for + a single sequence group. + This allows the engine to gracefully handle errors with a single sequence + group without having to fail the entire batch. + """ + + def __init__(self, request_id, message): + """request_id is the id of the offending sequence group""" + self.request_id = request_id + self.message = message + super().__init__(self.message) + + def __str__(self): + return "Failed to prepare inputs for sequence group with request id: " \ + f"{self.request_id}, Error: {self.message}" From afdf70210a34ee728be67849d2f22f068b82bb62 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 26 Feb 2025 20:56:50 +0800 Subject: [PATCH 247/317] [Bugfix] Update expected token counts for Ultravox tests (#13895) --- tests/entrypoints/openai/test_audio.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index fe7299a48e6f..7e08fdaf1ad9 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -83,7 +83,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=201, total_tokens=211) + completion_tokens=10, prompt_tokens=202, total_tokens=212) message = choice.message message = chat_completion.choices[0].message @@ -140,7 +140,7 @@ async def test_single_chat_session_audio_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=201, total_tokens=211) + completion_tokens=10, prompt_tokens=202, total_tokens=212) message = choice.message message = chat_completion.choices[0].message @@ -196,7 +196,7 @@ async def test_single_chat_session_input_audio( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=201, total_tokens=211) + completion_tokens=10, prompt_tokens=202, total_tokens=212) message = choice.message message = chat_completion.choices[0].message From b7a622d94fc7bce07b9b09a4ce5b20a2089ce57c Mon Sep 17 00:00:00 2001 From: Chenyaaang <42742451+Chenyaaang@users.noreply.github.com> Date: Wed, 26 Feb 2025 05:18:54 -0800 Subject: [PATCH 248/317] [TPU] use torch2.6 with whl package (#13860) Signed-off-by: Chenyaaang --- requirements-tpu.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 1abde714af7c..8bfbb2dda194 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -17,7 +17,9 @@ ray[default] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.6.0.dev20241216+cpu +torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" From 20f3457378a0696d967a72508aef3694ae985af8 Mon Sep 17 00:00:00 2001 From: Chauncey Date: Thu, 27 Feb 2025 00:31:53 +0800 Subject: [PATCH 249/317] [Misc] fixed qwen_vl_utils parameter error (#13906) --- examples/offline_inference/vision_language_multi_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 5dc6a936d1c1..872c9481a229 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -439,7 +439,7 @@ def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData: image_data = [fetch_image(url) for url in image_urls] else: image_data, _ = process_vision_info(messages, - return_video_sample_fps=False) + return_video_kwargs=False) return ModelRequestData( llm=llm, From 2462b654005f29bb305ba1bcff44c02ccb837d4f Mon Sep 17 00:00:00 2001 From: Wallas Henrique Date: Wed, 26 Feb 2025 15:52:34 -0300 Subject: [PATCH 250/317] [Bugfix] Backend option to disable xgrammar any_whitespace (#12744) Signed-off-by: Wallas Santos Signed-off-by: Joe Runde Co-authored-by: Joe Runde --- tests/entrypoints/llm/test_guided_generate.py | 54 +++++++++++++++++++ vllm/engine/arg_utils.py | 1 + .../guided_decoding/xgrammar_decoding.py | 36 +++++++++++-- 3 files changed, 88 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 314dc59328cb..fce581c78288 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -6,6 +6,7 @@ import jsonschema import pytest +from pydantic import BaseModel from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM @@ -322,3 +323,56 @@ def test_guided_json_object(llm, guided_decoding_backend: str): # Parse to verify it is valid JSON parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) + + +@pytest.mark.skip_global_cleanup +def test_json_with_any_whitespace_disabled(llm): + + class ResponseSchema(BaseModel): + clarifying_question: str + cost_per_serving: str + calories: str + type_dish_ids: str + type_meal_ids: str + product_ids: list[str] + exclude_product_ids: list[str] + allergen_ids: list[str] + total_cooking_time: str + kitchen_ids: str + holiday_ids: str + + # Note: Without this setting, the response is sometimes full of `\n` + # for some models. This option prevents that. + guided_decoding_backend = 'xgrammar:disable-any-whitespace' + + schema = ResponseSchema.model_json_schema() + guided_params = GuidedDecodingParams(json=schema, + backend=\ + guided_decoding_backend) + sampling_params = SamplingParams(max_tokens=2000, + frequency_penalty=0, + presence_penalty=-1.1, + repetition_penalty=1.3, + guided_decoding=guided_params) + + prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You" + "are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a " + "quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n") + outputs = llm.generate(prompts=prompt, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + + generated_text = output.outputs[0].text + assert generated_text is not None + assert "\n" not in generated_text + + # Parse to verify it is valid JSON + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) + jsonschema.validate(instance=parsed_json, schema=schema) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 663ea1ef8afd..26d4a84b841c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -385,6 +385,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'Backend-specific options can be supplied in a comma-separated ' 'list following a colon after the backend name. Valid backends and ' 'all available options are: [xgrammar:no-fallback, ' + 'xgrammar:disable-any-whitespace, ' 'outlines:no-fallback, lm-format-enforcer:no-fallback]') parser.add_argument( '--logits-processor-pattern', diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index e6ba7f5ecc6e..eb9d83acb286 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -19,6 +19,7 @@ xgr_installed = False pass +from vllm.logger import init_logger from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf, grammar_is_likely_lark) from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer @@ -29,6 +30,8 @@ from vllm.config import ModelConfig from vllm.sampling_params import GuidedDecodingParams +logger = init_logger(__name__) + # TODO: passing batch size to max threads here def get_local_xgrammar_guided_decoding_logits_processor( @@ -161,6 +164,7 @@ class GrammarConfig: json_str: str | None = None grammar_str: str | None = None json_object: bool | None = None + any_whitespace: bool = True max_threads: int = 8 tokenizer_data: TokenizerData | None = None @@ -180,11 +184,33 @@ def from_guided_params(cls, else: json_str = guided_params.json + any_whitespace = 'disable-any-whitespace' not in \ + guided_params.backend_options() + + # Check and log if model with xgrammar and whitespace have history + # of runaway generation of whitespaces. + # References: + # https://github.com/vllm-project/vllm/pull/12744 + # https://github.com/mlc-ai/xgrammar/issues/212 + model_with_warn = None + + if 'Mistral' in model_config.model: + model_with_warn = 'Mistral' + elif 'Qwen' in model_config.model: + model_with_warn = 'Qwen' + + if model_with_warn is not None and any_whitespace: + msg = (f"{model_with_warn} " + f"model detected, consider set " + f"`guided_backend=xgrammar:disable-any-whitespace` " + f"to prevent runaway generation of whitespaces.") + logger.info_once(msg) # Validate the schema and raise ValueError here if it is invalid. # This is to avoid exceptions in model execution, which will crash # the engine worker process. try: - xgr.Grammar.from_json_schema(json_str) + xgr.Grammar.from_json_schema(json_str, + any_whitespace=any_whitespace) except RuntimeError as err: raise ValueError(str(err)) from err @@ -192,7 +218,8 @@ def from_guided_params(cls, vocab_size=model_config.hf_text_config.vocab_size, tokenizer_hash=tokenizer_hash, max_threads=max_threads, - tokenizer_data=tokenizer_data) + tokenizer_data=tokenizer_data, + any_whitespace=any_whitespace) elif guided_params.grammar: # XGrammar only supports GBNF grammars, so we must convert Lark if grammar_is_likely_lark(guided_params.grammar): @@ -290,7 +317,10 @@ def _ensure_ctx(self): if self.ctx is None: compiler = GrammarCompilerCache.get_compiler(self.config) if self.config.json_str is not None: - self.ctx = compiler.compile_json_schema(self.config.json_str) + any_whitespace = self.config.any_whitespace + self.ctx = compiler\ + .compile_json_schema(self.config.json_str, + any_whitespace=any_whitespace) elif self.config.grammar_str is not None: self.ctx = compiler.compile_grammar(self.config.grammar_str) elif self.config.json_object: From 5eb0d6398eb290426f672ef6eb3eb6968c57ccf0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 26 Feb 2025 13:48:55 -0800 Subject: [PATCH 251/317] [BugFix] Make FP8 Linear compatible with torch.compile (#13918) Signed-off-by: Woosuk Kwon --- .../model_executor/layers/quantization/fp8.py | 5 +---- .../layers/quantization/utils/fp8_utils.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 76a7d4df8a36..a705f63be4ac 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -369,12 +369,9 @@ def apply(self, size_k=layer.input_size_per_partition, bias=bias) - # Note: lazy import to avoid triton import error. - from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_w8a8_block_fp8_linear) if self.block_quant: assert self.quant_config.weight_block_size is not None - return apply_w8a8_block_fp8_linear( + return torch.ops.vllm.apply_w8a8_block_fp8_linear( input=x, weight=layer.weight, block_size=self.quant_config.weight_block_size, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 61706f485f46..7d91d2cf1c6e 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear) from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op logger = init_logger(__name__) @@ -81,6 +82,25 @@ def apply_w8a8_block_fp8_linear( return output.to(dtype=input.dtype).view(*output_shape) +def apply_w8a8_block_fp8_linear_fake( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + output_shape = [*input.shape[:-1], weight.shape[0]] + return torch.empty(output_shape, dtype=input.dtype, device=input.device) + + +direct_register_custom_op( + op_name="apply_w8a8_block_fp8_linear", + op_func=apply_w8a8_block_fp8_linear, + mutates_args=[], + fake_impl=apply_w8a8_block_fp8_linear_fake, +) + + # Unify the interface between `apply_w8a8_block_fp8_linear` and # `apply_fp8_linear` # NOTE(lucas): this is quite messy, we should think through this more formally From b6ce76233734bb2eb489ad93a9bd69039c4a32d6 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 26 Feb 2025 21:35:08 -0500 Subject: [PATCH 252/317] [Kernel] FlashMLA integration (#13747) Signed-off-by: Lucas Wilkinson Signed-off-by: Lucas Wilkinson --- CMakeLists.txt | 77 +----- cmake/external_projects/flashmla.cmake | 66 +++++ cmake/external_projects/vllm_flash_attn.cmake | 67 +++++ setup.py | 6 + tests/kernels/test_flashmla.py | 132 ++++++++++ vllm/_custom_ops.py | 64 +++++ vllm/attention/backends/flashmla.py | 239 ++++++++++++++++++ vllm/attention/backends/mla/common.py | 28 +- vllm/attention/ops/flashmla.py | 115 +++++++++ vllm/platforms/cuda.py | 24 ++ vllm/platforms/interface.py | 1 + 11 files changed, 733 insertions(+), 86 deletions(-) create mode 100644 cmake/external_projects/flashmla.cmake create mode 100644 cmake/external_projects/vllm_flash_attn.cmake create mode 100644 tests/kernels/test_flashmla.py create mode 100644 vllm/attention/backends/flashmla.py create mode 100644 vllm/attention/ops/flashmla.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 02a60c0e3520..0dd350c93ed5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -575,77 +575,8 @@ if(VLLM_GPU_LANG STREQUAL "HIP") WITH_SOABI) endif() -# vllm-flash-attn currently only supported on CUDA -if (NOT VLLM_GPU_LANG STREQUAL "CUDA") - return() +# For CUDA we also build and ship some external projects. +if (VLLM_GPU_LANG STREQUAL "CUDA") + include(cmake/external_projects/flashmla.cmake) + include(cmake/external_projects/vllm_flash_attn.cmake) endif () - -# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target -# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the -# arches in the CUDA case (and instead set the gencodes on a per file basis) -# we need to manually set VLLM_GPU_ARCHES here. -if(VLLM_GPU_LANG STREQUAL "CUDA") - foreach(_ARCH ${CUDA_ARCHS}) - string(REPLACE "." "" _ARCH "${_ARCH}") - list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") - endforeach() -endif() - -# -# Build vLLM flash attention from source -# -# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. -# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. -# They should be identical but if they aren't, this is a massive footgun. -# -# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. -# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). -# If no component is specified, vllm-flash-attn is still installed. - -# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. -# This is to enable local development of vllm-flash-attn within vLLM. -# It can be set as an environment variable or passed as a cmake argument. -# The environment variable takes precedence. -if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) - set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) -endif() - -if(VLLM_FLASH_ATTN_SRC_DIR) - FetchContent_Declare( - vllm-flash-attn SOURCE_DIR - ${VLLM_FLASH_ATTN_SRC_DIR} - BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn - ) -else() - FetchContent_Declare( - vllm-flash-attn - GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade - GIT_PROGRESS TRUE - # Don't share the vllm-flash-attn build between build types - BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn - ) -endif() - - -# Fetch the vllm-flash-attn library -FetchContent_MakeAvailable(vllm-flash-attn) -message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") - -# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in -# case only one is built, in the case both are built redundant work is done) -install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm_flash_attn - COMPONENT _vllm_fa2_C - FILES_MATCHING PATTERN "*.py" -) - -install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm_flash_attn - COMPONENT _vllm_fa3_C - FILES_MATCHING PATTERN "*.py" -) - -# Nothing after vllm-flash-attn, see comment about macros above diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake new file mode 100644 index 000000000000..6291475164ba --- /dev/null +++ b/cmake/external_projects/flashmla.cmake @@ -0,0 +1,66 @@ +include(FetchContent) + +# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory +# instead of downloading. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. +if (DEFINED ENV{FLASH_MLA_SRC_DIR}) + set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR}) +endif() + +if(FLASH_MLA_SRC_DIR) + FetchContent_Declare( + flashmla + SOURCE_DIR ${FLASH_MLA_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +else() + FetchContent_Declare( + flashmla + GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git + GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845 + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +endif() + + +FetchContent_MakeAvailable(flashmla) +message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") + +# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later. +# Only build FlashMLA kernels if we are building for something compatible with +# sm90a +cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) + set(FlashMLA_SOURCES + ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu) + + set(FlashMLA_INCLUDES + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${flashmla_SOURCE_DIR}/csrc/include) + + set_gencode_flags_for_srcs( + SRCS "${FlashMLA_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}") + + define_gpu_extension_target( + _flashmla_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${FlashMLA_SOURCES} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES} + USE_SABI 3 + WITH_SOABI) +else() + # Create an empty target for setup.py when not targeting sm90a systems + add_custom_target(_flashmla_C) +endif() + diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake new file mode 100644 index 000000000000..ef6261fa6d9b --- /dev/null +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -0,0 +1,67 @@ +# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target +# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the +# arches in the CUDA case (and instead set the gencodes on a per file basis) +# we need to manually set VLLM_GPU_ARCHES here. +if(VLLM_GPU_LANG STREQUAL "CUDA") + foreach(_ARCH ${CUDA_ARCHS}) + string(REPLACE "." "" _ARCH "${_ARCH}") + list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") + endforeach() +endif() + +# +# Build vLLM flash attention from source +# +# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. +# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. +# They should be identical but if they aren't, this is a massive footgun. +# +# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. +# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). +# If no component is specified, vllm-flash-attn is still installed. + +# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. +# This is to enable local development of vllm-flash-attn within vLLM. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. +if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) + set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) +endif() + +if(VLLM_FLASH_ATTN_SRC_DIR) + FetchContent_Declare( + vllm-flash-attn SOURCE_DIR + ${VLLM_FLASH_ATTN_SRC_DIR} + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) +else() + FetchContent_Declare( + vllm-flash-attn + GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git + GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade + GIT_PROGRESS TRUE + # Don't share the vllm-flash-attn build between build types + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) +endif() + + +# Fetch the vllm-flash-attn library +FetchContent_MakeAvailable(vllm-flash-attn) +message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") + +# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in +# case only one is built, in the case both are built redundant work is done) +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa2_C + FILES_MATCHING PATTERN "*.py" +) + +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa3_C + FILES_MATCHING PATTERN "*.py" +) \ No newline at end of file diff --git a/setup.py b/setup.py index d8a336c2d426..a636d266cfbd 100755 --- a/setup.py +++ b/setup.py @@ -328,6 +328,7 @@ def run(self) -> None: files_to_copy = [ "vllm/_C.abi3.so", "vllm/_moe_C.abi3.so", + "vllm/_flashmla_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", "vllm/vllm_flash_attn/flash_attn_interface.py", @@ -612,6 +613,11 @@ def _read_requirements(filename: str) -> List[str]: # FA3 requires CUDA 12.0 or later ext_modules.append( CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) + if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"): + # Optional since this doesn't get built (produce an .so file) when + # not targeting a hopper system + ext_modules.append( + CMakeExtension(name="vllm._flashmla_C", optional=True)) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/test_flashmla.py new file mode 100644 index 000000000000..21c1079fc8eb --- /dev/null +++ b/tests/kernels/test_flashmla.py @@ -0,0 +1,132 @@ +# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py +# SPDX-License-Identifier: Apache-2.0 +import math +import random + +import pytest +import torch +import triton + +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) + + +def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: + x, y = x.double(), y.double() + cos_diff = 1 - 2 * (x * y).sum().item() / max( + (x * x + y * y).sum().item(), 1e-12) + assert cos_diff < 1e-5 + +FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ + if not is_flashmla_supported()[0] else "FlashMLA is supported" + + +@pytest.mark.skipif(not is_flashmla_supported()[0], + reason=FLASH_MLA_UNSUPPORTED_REASON) +@pytest.mark.parametrize("b", [128]) +@pytest.mark.parametrize("s_q", [1, 2]) +@pytest.mark.parametrize("mean_sk", [4096, 8192]) +@pytest.mark.parametrize("h_q", [16, 32, 64, 128]) +@pytest.mark.parametrize("h_kv", [1]) +@pytest.mark.parametrize("d", [576]) +@pytest.mark.parametrize("dv", [512]) +@pytest.mark.parametrize("block_size", [64]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("varlen", [False, True]) +@torch.inference_mode() +def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, + varlen): + # TODO: parametrize using pytest + dtype = torch.bfloat16 + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + + print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " + f"{d=}, {dv=}, {causal=}, {varlen=}") + + cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) + if varlen: + for i in range(b): + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), + s_q) + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + + q = torch.randn(b, s_q, h_q, d) + block_table = torch.arange(b * max_seqlen_pad // block_size, + dtype=torch.int32).view( + b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + for i in range(b): + blocked_k.view(b, max_seqlen_pad, h_kv, + d)[i, cache_seqlens[i].item():] = float("nan") + blocked_v = blocked_k[..., :dv] + + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens, s_q * h_q // h_kv, h_kv) + + def flash_mla(): + return flash_mla_with_kvcache( + q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, + ) + + def scaled_dot_product_attention(query, key, value, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, + dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + ref_O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + is_causal=causal, + ) + out[i] = ref_O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_flash, lse_flash = flash_mla() + out_torch, lse_torch = ref_mla() + cal_diff(out_flash, out_torch, "out") + cal_diff(lse_flash, lse_torch, "lse") + + t = triton.testing.do_bench(flash_mla, fast_flush=False) + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} " + f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3306610ad800..0e83bcaead94 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1163,3 +1163,67 @@ def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: def register_graph_buffers(fa: int, handles: List[List[int]], offsets: List[List[int]]) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + + +def get_flash_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return torch.ops._C.get_flash_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py new file mode 100644 index 000000000000..273c69b63ec6 --- /dev/null +++ b/vllm/attention/backends/flashmla.py @@ -0,0 +1,239 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA" + + @staticmethod + def get_impl_cls() -> Type["FlashMLAImpl"]: + return FlashMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashMLAState"]: + return FlashMLAState + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata): + decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None + decode_num_splits: Optional[torch.Tensor] = None + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + # TODO: cache assignment? + if decode_metadata is not None: + decode_metadata.decode_tile_scheduler_metadata=\ + self.decode_tile_scheduler_metadata + decode_metadata.decode_num_splits=\ + self.decode_num_splits + return decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + raise NotImplementedError( + "advance_step is not implemented for FlashMLA") + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + m = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + + if m.num_decode_tokens > 0: + m.decode_tile_scheduler_metadata, m.decode_num_splits = \ + get_mla_metadata( + m.seq_lens_tensor[m.num_prefills:], + self.num_q_heads, + 1, # MQA for the decode path + ) + + return m + + +class FlashMLAState(MLACommonState[FlashMLAMetadata]): + + def __init__(self, *args, **kwds): + super().__init__(*args, **kwds) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + @contextmanager + def graph_capture(self, max_batch_size: int): + # Run a dummy `get_mla_metadata` so we can get the right shapes + self._graph_decoder_tile_scheduler_metadata, \ + self._graph_decode_num_splits = get_mla_metadata( + torch.ones( + max_batch_size, dtype=torch.int32, device=self.runner.device), + self.num_q_heads, + 1, # MQA for the decode path + ) + + with super().graph_capture(max_batch_size): + yield + + del self._graph_decoder_tile_scheduler_metadata + del self._graph_decode_num_splits + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + assert metadata.num_decode_tokens > 0 + + decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( + self._graph_seq_lens[:batch_size], + self.num_q_heads, + 1, # MQA for the decode path + ) + + self._graph_decoder_tile_scheduler_metadata.copy_( + decoder_tile_scheduler_metadata) + self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) + + metadata.decode_tile_scheduler_metadata=\ + self._graph_decoder_tile_scheduler_metadata + metadata.decode_num_splits=\ + self._graph_decode_num_splits[:batch_size + 1] + + return metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers["decode_tile_scheduler_metadata"] = \ + attn_metadata.decode_metadata.decode_tile_scheduler_metadata + input_buffers["decode_num_splits"] = \ + attn_metadata.decode_metadata.decode_num_splits + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + input_buffers["decode_tile_scheduler_metadata"].copy_( + attn_metadata.decode_metadata.decode_tile_scheduler_metadata) + input_buffers["decode_num_splits"].copy_( + attn_metadata.decode_metadata.decode_num_splits) + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + assert is_flashmla_supported(), \ + "FlashMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 FlashMLA not yet supported") + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1)\ + .unsqueeze(1) # Add seqlen dim of 1 (decode) + + o, _ = flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, + num_splits=decode_meta.decode_num_splits, + softmax_scale=self.scale, + causal=True, + ) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 225fee8d2a0d..1befcb6b45df 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -293,7 +293,10 @@ def get_supported_head_sizes() -> List[int]: return [576] -class MLACommonState(AttentionState): +T = TypeVar("T", bound="MLACommonMetadata") + + +class MLACommonState(AttentionState, Generic[T]): def __init__(self, runner): self.runner = runner @@ -355,7 +358,9 @@ def graph_clone(self, batch_size: int): return self.__class__(self.runner) def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: assert self._is_graph_capturing attn_metadata = self.runner.attn_backend.make_metadata( @@ -507,8 +512,8 @@ class MLACommonMetadata(AttentionMetadata): # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] = None - _cached_prefill_metadata: Optional["MLACommonMetadata"] = None - _cached_decode_metadata: Optional["MLACommonMetadata"] = None + _cached_prefill_metadata: Optional[Any] = None + _cached_decode_metadata: Optional[Any] = None num_prefill_tokens: int @@ -537,7 +542,7 @@ def __post_init__(self): f" received {self.head_dim}.") @property - def prefill_metadata(self) -> Optional["MLACommonMetadata"]: + def prefill_metadata(self): if self.num_prefills == 0: return None @@ -565,7 +570,7 @@ def prefill_metadata(self) -> Optional["MLACommonMetadata"]: input_positions = (None if self.input_positions is None else self.input_positions[:self.num_prefill_tokens]) - self._cached_prefill_metadata = MLACommonMetadata( + self._cached_prefill_metadata = self.__class__( # Required by ModelRunner use_cuda_graph=False, # Not Attention Related # Required by Attention Metadata @@ -599,7 +604,7 @@ def prefill_metadata(self) -> Optional["MLACommonMetadata"]: return self._cached_prefill_metadata @property - def decode_metadata(self) -> Optional["MLACommonMetadata"]: + def decode_metadata(self): if self.num_decode_tokens == 0: return None @@ -617,7 +622,7 @@ def decode_metadata(self) -> Optional["MLACommonMetadata"]: input_positions = (None if self.input_positions is None else self.input_positions[self.num_prefill_tokens:]) - self._cached_decode_metadata = MLACommonMetadata( + self._cached_decode_metadata = self.__class__( # Required by ModelRunner use_cuda_graph=self.use_cuda_graph, # Not Attention Related # Required by Attention Metadata @@ -723,10 +728,7 @@ def advance_step(self, block_tables=self.block_tables) -T = TypeVar("T", bound=MLACommonMetadata) - - -class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): +class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -959,7 +961,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], assert max(context_chunk_seq_tot) <= \ self.chunked_prefill_workspace_size - return MLACommonMetadata( + return self.runner.attn_backend.make_metadata( # Required by ModelRunner use_cuda_graph=use_captured_graph, # Not Attention Related # Required by Attention Metadata diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py new file mode 100644 index 000000000000..18b69a6b3ddf --- /dev/null +++ b/vllm/attention/ops/flashmla.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py +from typing import Optional, Tuple + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +if current_platform.is_cuda(): + try: + import vllm._flashmla_C # noqa: F401 + _flashmla_C_AVAILABLE = True + except ImportError: + _flashmla_C_AVAILABLE = False +else: + _flashmla_C_AVAILABLE = False + + +def is_flashmla_supported() -> Tuple[bool, Optional[str]]: + """ + Return: is_supported_flag, unsupported_reason (optional). + """ + if not current_platform.is_cuda(): + return False, "FlashMLA is only supported on CUDA devices." + if current_platform.get_device_capability()[0] != 9: + return False, "FlashMLA is only supported on Hopper devices." + if not _flashmla_C_AVAILABLE: + return False, "vllm._flashmla_C is not available, likely was not "\ + "compiled due to insufficient nvcc version or a supported arch "\ + "(only sm90a currently) was not in the list of target arches to "\ + "compile for." + return True, None + + +def get_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse + + +# +# TODO: Add fake functions +# +# @register_fake("_flashmla_C::get_mla_metadata") +# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# +# @register_fake("_flashmla_C::fwd_kvcache_mla") +# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bf425b89132e..c6f3ccf0a3c4 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -141,6 +141,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: cache_config.block_size = 16 + # TODO(lucas): handle this more gracefully + if envs.VLLM_ATTENTION_BACKEND is not None \ + and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \ + and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "FlashMLA: Forcing kv cache block size to 64 since this" + " is currently the only block size supported by the kernel.") @classmethod def get_current_memory_usage(cls, @@ -157,6 +165,22 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" if use_mla: + if selected_backend == _Backend.FLASHMLA: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + if not is_flashmla_supported()[0]: + logger.warning( + "FlashMLA backend is not supported due to %s", + is_flashmla_supported()[1]) + elif block_size != 64: + logger.warning( + "FlashMLA backend is not supported for block size %d" + " (currently only supports block size 64).", + block_size) + else: + logger.info("Using FlashMLA backend.") + return "vllm.attention.backends.flashmla.FlashMLABackend" + logger.info("Using Triton MLA backend.") return "vllm.attention.backends.triton_mla.TritonMLABackend" if selected_backend == _Backend.FLASHINFER: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d764000c363c..e3ef7c4ac7c5 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -35,6 +35,7 @@ class _Backend(enum.Enum): OPENVINO = enum.auto() FLASHINFER = enum.auto() TRITON_MLA = enum.auto() + FLASHMLA = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() From 375e5708d97a3e68fe10ea70cf946371380d8a5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=84=8D=F0=9D=95=A0=F0=9D=95=9D=F0=9D=95=9D=F0=9D=95=A0?= =?UTF-8?q?=F0=9D=95=A8=20=F0=9D=95=84=F0=9D=95=92=F0=9D=95=9F?= Date: Thu, 27 Feb 2025 04:39:10 +0200 Subject: [PATCH 253/317] [ROCm][Quantization][Kernel] Use FP8 FNUZ when OCP flag is 0 or undefined (#13851) Signed-off-by: Hollow Man --- csrc/quantization/fp8/amd/quant_utils.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index b2196b8ed516..b812b28b607e 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -24,12 +24,12 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x, return x; } - #if HIP_FP8_TYPE_FNUZ -using fp8_type = __hip_fp8_e4m3_fnuz; -using fp8x2_type = __hip_fp8x2_e4m3_fnuz; - #elif HIP_FP8_TYPE_OCP + #if HIP_FP8_TYPE_OCP using fp8_type = __hip_fp8_e4m3; using fp8x2_type = __hip_fp8x2_e4m3; + #else +using fp8_type = __hip_fp8_e4m3_fnuz; +using fp8x2_type = __hip_fp8x2_e4m3_fnuz; #endif // fp8 -> half From 258b598136d053c7cde91f71bc5abcac79506016 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 26 Feb 2025 22:06:37 -0500 Subject: [PATCH 254/317] Use CUDA 12.4 as default for release and nightly wheels (#12098) --- .buildkite/release-pipeline.yaml | 13 ++++++++++++- .buildkite/upload-wheels.sh | 10 ++++++++-- .../getting_started/installation/gpu/cuda.inc.md | 4 ++-- setup.py | 7 +++---- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 829414bf8a3b..37cdab9e01ec 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,4 +1,15 @@ steps: + - label: "Build wheel - CUDA 12.4" + agents: + queue: cpu_queue_postmerge + commands: + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag vllm-ci:build-image --target build --progress plain ." + - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + - "bash .buildkite/upload-wheels.sh" + env: + DOCKER_BUILDKIT: "1" + - label: "Build wheel - CUDA 12.1" agents: queue: cpu_queue_postmerge @@ -37,7 +48,7 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ." - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - label: "Build and publish TPU release image" diff --git a/.buildkite/upload-wheels.sh b/.buildkite/upload-wheels.sh index 3c756659a715..a681f8927060 100644 --- a/.buildkite/upload-wheels.sh +++ b/.buildkite/upload-wheels.sh @@ -50,8 +50,11 @@ aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" if [[ $normal_wheel == *"cu118"* ]]; then # if $normal_wheel matches cu118, do not upload the index.html echo "Skipping index files for cu118 wheels" +elif [[ $normal_wheel == *"cu121"* ]]; then + # if $normal_wheel matches cu121, do not upload the index.html + echo "Skipping index files for cu121 wheels" else - # only upload index.html for cu12 wheels (default wheels) + # only upload index.html for cu124 wheels (default wheels) aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" fi @@ -63,8 +66,11 @@ aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" if [[ $normal_wheel == *"cu118"* ]]; then # if $normal_wheel matches cu118, do not upload the index.html echo "Skipping index files for cu118 wheels" +elif [[ $normal_wheel == *"cu121"* ]]; then + # if $normal_wheel matches cu121, do not upload the index.html + echo "Skipping index files for cu121 wheels" else - # only upload index.html for cu12 wheels (default wheels) + # only upload index.html for cu124 wheels (default wheels) aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" fi diff --git a/docs/source/getting_started/installation/gpu/cuda.inc.md b/docs/source/getting_started/installation/gpu/cuda.inc.md index 948bdbffbeb7..2477c3e4c93f 100644 --- a/docs/source/getting_started/installation/gpu/cuda.inc.md +++ b/docs/source/getting_started/installation/gpu/cuda.inc.md @@ -23,12 +23,12 @@ Therefore, it is recommended to install vLLM with a **fresh new** environment. I You can install vLLM using either `pip` or `uv pip`: ```console -# Install vLLM with CUDA 12.1. +# Install vLLM with CUDA 12.4. pip install vllm # If you are using pip. uv pip install vllm # If you are using uv. ``` -As of now, vLLM's binaries are compiled with CUDA 12.1 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 11.8 and public PyTorch release versions: +As of now, vLLM's binaries are compiled with CUDA 12.4 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 12.1, 11.8, and public PyTorch release versions: ```console # Install vLLM with CUDA 11.8. diff --git a/setup.py b/setup.py index a636d266cfbd..6fe433517a05 100755 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ def load_module_from_path(module_name, path): # fallback to cpu VLLM_TARGET_DEVICE = "cpu" -MAIN_CUDA_VERSION = "12.1" +MAIN_CUDA_VERSION = "12.4" def is_sccache_available() -> bool: @@ -571,9 +571,8 @@ def _read_requirements(filename: str) -> List[str]: cuda_major, cuda_minor = torch.version.cuda.split(".") modified_requirements = [] for req in requirements: - if ("vllm-flash-attn" in req - and not (cuda_major == "12" and cuda_minor == "1")): - # vllm-flash-attn is built only for CUDA 12.1. + if ("vllm-flash-attn" in req and cuda_major != "12"): + # vllm-flash-attn is built only for CUDA 12.x. # Skip for other versions. continue modified_requirements.append(req) From f81c37f25dd9f4c1e67e72047bbb8085e37d7f33 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Wed, 26 Feb 2025 20:03:28 -0800 Subject: [PATCH 255/317] [misc] Rename Ray ADAG to Compiled Graph (#13928) --- tests/basic_correctness/test_basic_correctness.py | 2 +- tests/basic_correctness/test_chunked_prefill.py | 2 +- tests/distributed/test_pipeline_parallel.py | 11 ++++++----- vllm/envs.py | 9 +++++---- vllm/executor/ray_distributed_executor.py | 10 +++++----- vllm/executor/ray_utils.py | 8 ++++---- 6 files changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index d2fc0916bc55..0cb3b739b724 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -117,7 +117,7 @@ def test_models_distributed( pytest.skip(f"Skip test for {test_suite}") if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa - # test ray adag + # test Ray Compiled Graph os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index a500ba9dfe02..fd4a804183bf 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -93,7 +93,7 @@ def test_models_distributed( if (model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray"): - # test ray adag + # test Ray Compiled Graph os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 9677ccd2ea82..390ed91c2605 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -324,8 +324,8 @@ def _compare_tp( specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill if distributed_backend == "ray" and (vllm_major_version == "1" or specific_case): - # For V1, test Ray ADAG for all the tests - # For V0, test Ray ADAG for a subset of the tests + # For V1, test Ray Compiled Graph for all the tests + # For V0, test Ray Compiled Graph for a subset of the tests pp_env = { "VLLM_USE_V1": vllm_major_version, "VLLM_USE_RAY_COMPILED_DAG": "1", @@ -333,7 +333,7 @@ def _compare_tp( "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1", } # Temporary. Currently when zeromq + SPMD is used, it does not properly - # terminate because of aDAG issue. + # terminate because of a Ray Compiled Graph issue. common_args.append("--disable-frontend-multiprocessing") else: pp_env = None @@ -367,8 +367,9 @@ def _compare_tp( if pp_env is None: raise else: - # Ray ADAG tests are flaky, so we don't want to fail the test - logger.exception("Ray ADAG tests failed") + # Ray Compiled Graph tests are flaky, + # so we don't want to fail the test + logger.exception("Ray Compiled Graph tests failed") @pytest.mark.parametrize( diff --git a/vllm/envs.py b/vllm/envs.py index 84426cb5bb22..048d63bfec0f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -371,21 +371,22 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_USE_RAY_SPMD_WORKER": lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))), - # If the env var is set, it uses the Ray's compiled DAG API - # which optimizes the control plane overhead. + # If the env var is set, it uses the Ray's Compiled Graph + # (previously known as ADAG) API which optimizes the + # control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. "VLLM_USE_RAY_COMPILED_DAG": lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))), # If the env var is set, it uses NCCL for communication in - # Ray's compiled DAG. This flag is ignored if + # Ray's Compiled Graph. This flag is ignored if # VLLM_USE_RAY_COMPILED_DAG is not set. "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1")) ), # If the env var is set, it enables GPU communication overlap - # (experimental feature) in Ray's compiled DAG. This flag is ignored if + # (experimental feature) in Ray's Compiled Graph. This flag is ignored if # VLLM_USE_RAY_COMPILED_DAG is not set. "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0")) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 2908fefc8e7e..c3b41d1c1134 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -491,7 +491,7 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _check_ray_adag_installation(self): + def _check_ray_cgraph_installation(self): import pkg_resources from packaging import version @@ -503,10 +503,10 @@ def _check_ray_adag_installation(self): f"required, but found {current_version}") import importlib.util - adag_spec = importlib.util.find_spec( + cgraph_spec = importlib.util.find_spec( "ray.experimental.compiled_dag_ref") - if adag_spec is None: - raise ValueError("Ray accelerated DAG is not installed. " + if cgraph_spec is None: + raise ValueError("Ray Compiled Graph is not installed. " "Run `pip install ray[adag]` to install it.") cupy_spec = importlib.util.find_spec("cupy") @@ -518,7 +518,7 @@ def _check_ray_adag_installation(self): def _compiled_ray_dag(self, enable_asyncio: bool): assert self.parallel_config.use_ray - self._check_ray_adag_installation() + self._check_ray_cgraph_installation() from ray.dag import InputNode, MultiOutputNode from ray.experimental.channel.torch_tensor_type import TorchTensorType diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index a9661fe0ef16..6067f9a3c13b 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -83,9 +83,9 @@ def execute_model_spmd( execute_model_req = self.input_decoder.decode(serialized_req) - # TODO(swang): This is needed right now because Ray aDAG executes - # on a background thread, so we need to reset torch's current - # device. + # TODO(swang): This is needed right now because Ray Compiled Graph + # executes on a background thread, so we need to reset torch's + # current device. import torch if not self.compiled_dag_cuda_device_set: torch.cuda.set_device(self.worker.device) @@ -119,7 +119,7 @@ def execute_model_ray( "IntermediateTensors"]], ) -> Union["ModelRunnerOutput", Tuple["SchedulerOutput", "IntermediateTensors"]]: - # this method is used to compile ray CG, + # This method is used by Ray Compiled Graph to execute the model, # and it needs a special logic of self.setup_device_if_necessary() self.setup_device_if_necessary() assert self.worker is not None, "Worker is not initialized" From 3f108620f6ebea45e5be32d71534fc7dad44b60f Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Wed, 26 Feb 2025 20:04:12 -0800 Subject: [PATCH 256/317] [ROCm][V1] Update reshape_and_cache to properly work with CUDA graph padding (#13922) --- csrc/cache_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index a6f8602a0588..d06eac2b3d4f 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -375,7 +375,7 @@ void reshape_and_cache( torch::Tensor& slot_mapping, // [num_tokens] const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale) { - int num_tokens = key.size(0); + int num_tokens = slot_mapping.size(0); int num_heads = key.size(1); int head_size = key.size(2); int block_size = key_cache.size(3); From c9095e0d4a74cc376ea45e094ff9c496a291aee7 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Thu, 27 Feb 2025 04:04:59 +0000 Subject: [PATCH 257/317] [V1][Metrics] Handle preemptions (#13169) --- tests/entrypoints/openai/test_metrics.py | 1 + vllm/v1/core/scheduler.py | 10 +++++++- vllm/v1/engine/__init__.py | 1 + vllm/v1/metrics/loggers.py | 24 +++++++++++------- vllm/v1/metrics/stats.py | 31 +++++++++++++++++------- 5 files changed, 48 insertions(+), 19 deletions(-) diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index e0323abe2525..5aa259a4f318 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -227,6 +227,7 @@ async def test_metrics_counts(server: RemoteOpenAIServer, "vllm:gpu_cache_usage_perc", "vllm:gpu_prefix_cache_queries", "vllm:gpu_prefix_cache_hits", + "vllm:num_preemptions_total", "vllm:prompt_tokens_total", "vllm:generation_tokens_total", "vllm:iteration_tokens_total", diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 535aa644c53c..87c9c0cd12b7 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -164,6 +164,7 @@ def schedule(self) -> "SchedulerOutput": self.kv_cache_manager.free(preempted_req) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 + self.request_preempted(preempted_req, scheduled_timestamp) self.waiting.appendleft(preempted_req) preempted_reqs.append(preempted_req) @@ -281,9 +282,9 @@ def schedule(self) -> "SchedulerOutput": self.waiting.popleft() self.running.append(request) self.scheduled_req_ids.add(request.request_id) + self.request_scheduled(request, scheduled_timestamp) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) - self.request_scheduled(request, scheduled_timestamp) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: @@ -675,6 +676,13 @@ def request_scheduled(self, request: Request, timestamp: float): EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED, timestamp)) + def request_preempted(self, request: Request, timestamp: float): + if not self.log_stats: + return + request.events.append( + EngineCoreEvent.new_event(EngineCoreEventType.PREEMPTED, + timestamp)) + def make_stats(self) -> Optional[SchedulerStats]: if not self.log_stats: return None diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 7420dde1f7e4..32fb3c5bd62e 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -65,6 +65,7 @@ class EngineCoreEventType(enum.IntEnum): """The type of engine core request event.""" QUEUED = 1 SCHEDULED = 2 + PREEMPTED = 3 class EngineCoreEvent(msgspec.Struct): diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 2c17da0ebc83..40dfc5661672 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -132,6 +132,11 @@ def __init__(self, vllm_config: VllmConfig): "GPU prefix cache hits, in terms of number of cached blocks.", labelnames=labelnames).labels(*labelvalues) + self.counter_num_preempted_reqs = prometheus_client.Counter( + name="vllm:num_preemptions_total", + documentation="Cumulative number of preemption from the engine.", + labelnames=labelnames).labels(*labelvalues) + self.counter_prompt_tokens = prometheus_client.Counter( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", @@ -282,6 +287,7 @@ def log(self, scheduler_stats: SchedulerStats, self.counter_gpu_prefix_cache_hits.inc( scheduler_stats.prefix_cache_stats.hits) + self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs) self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens) self.counter_generation_tokens.inc( iteration_stats.num_generation_tokens) @@ -289,10 +295,19 @@ def log(self, scheduler_stats: SchedulerStats, iteration_stats.num_prompt_tokens + \ iteration_stats.num_generation_tokens) + for ttft in iteration_stats.time_to_first_tokens_iter: + self.histogram_time_to_first_token.observe(ttft) + for tpot in iteration_stats.time_per_output_tokens_iter: + self.histogram_time_per_output_token.observe(tpot) + for finished_request in iteration_stats.finished_requests: self.counter_request_success[finished_request.finish_reason].inc() self.histogram_e2e_time_request.observe( finished_request.e2e_latency) + self.histogram_queue_time_request.observe( + finished_request.queued_time) + self.histogram_prefill_time_request.observe( + finished_request.prefill_time) self.histogram_inference_time_request.observe( finished_request.inference_time) self.histogram_decode_time_request.observe( @@ -302,15 +317,6 @@ def log(self, scheduler_stats: SchedulerStats, self.histogram_num_generation_tokens_request.observe( finished_request.num_generation_tokens) - for ttft in iteration_stats.time_to_first_tokens_iter: - self.histogram_time_to_first_token.observe(ttft) - for tpot in iteration_stats.time_per_output_tokens_iter: - self.histogram_time_per_output_token.observe(tpot) - for queue_time in iteration_stats.queue_times_iter: - self.histogram_queue_time_request.observe(queue_time) - for prefill_time in iteration_stats.prefill_times_iter: - self.histogram_prefill_time_request.observe(prefill_time) - if self.gauge_lora_info is not None: running_lora_adapters = \ ",".join(iteration_stats.running_lora_adapters.keys()) diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 74d4a1bc4fba..30f460e5a691 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -67,6 +67,8 @@ class FinishedRequestStats: e2e_latency: float = 0.0 num_prompt_tokens: int = 0 num_generation_tokens: int = 0 + queued_time: float = 0.0 + prefill_time: float = 0.0 inference_time: float = 0.0 decode_time: float = 0.0 @@ -78,11 +80,10 @@ def __init__(self): self.iteration_timestamp = time.time() self.num_generation_tokens = 0 self.num_prompt_tokens = 0 + self.num_preempted_reqs = 0 self.finished_requests: List[FinishedRequestStats] = [] self.time_to_first_tokens_iter: List[float] = [] self.time_per_output_tokens_iter: List[float] = [] - self.queue_times_iter: List[float] = [] - self.prefill_times_iter: List[float] = [] self.waiting_lora_adapters: Dict[str, int] = {} self.running_lora_adapters: Dict[str, int] = {} @@ -122,9 +123,6 @@ def update_from_output(self, output: "EngineCoreOutput", if is_prefilling: # TODO: re-enable no-output-for-partial-prefills invariant as above if num_new_generation_tokens > 0: - prefill_interval = \ - engine_core_timestamp - req_stats.scheduled_ts - self.prefill_times_iter.append(prefill_interval) req_stats.first_token_ts = engine_core_timestamp else: tpot = engine_core_timestamp - req_stats.last_token_ts @@ -145,24 +143,39 @@ def update_from_events(self, req_id: str, events: List["EngineCoreEvent"], if lora_stats is not None: lora_stats.waiting_requests.add(req_id) elif event.type == EngineCoreEventType.SCHEDULED: - queued_interval = event.timestamp - req_stats.queued_ts - self.queue_times_iter.append(queued_interval) - req_stats.scheduled_ts = event.timestamp + if req_stats.scheduled_ts == 0.0: # ignore preemptions + req_stats.scheduled_ts = event.timestamp LoRARequestStates.scheduled_request(lora_stats, req_id) + elif event.type == EngineCoreEventType.PREEMPTED: + self.num_preempted_reqs += 1 def update_from_finished_request(self, finish_reason: "FinishReason", request_output: "RequestOutput", req_stats: RequestStateStats): e2e_latency = self._time_since(req_stats.arrival_time) - inference_time = req_stats.last_token_ts - req_stats.scheduled_ts + # Queued interval is from first QUEUED event to first SCHEDULED + queued_time = req_stats.scheduled_ts - req_stats.queued_ts + + # Prefill interval is from first SCHEDULED to first NEW_TOKEN + # Any preemptions during prefill is included in the interval + prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts + + # Decode interval is from first NEW_TOKEN to last NEW_TOKEN + # Any preemptions during decode are included decode_time = req_stats.last_token_ts - req_stats.first_token_ts + # Inference interval is from first SCHEDULED to last NEW_TOKEN + # Any preemptions during prefill or decode are included + inference_time = req_stats.last_token_ts - req_stats.scheduled_ts + finished_req = \ FinishedRequestStats(finish_reason=finish_reason, e2e_latency=e2e_latency, num_prompt_tokens=len(request_output.prompt_token_ids), num_generation_tokens=req_stats.num_generation_tokens, + queued_time=queued_time, + prefill_time=prefill_time, inference_time=inference_time, decode_time=decode_time) self.finished_requests.append(finished_req) From 96faffa24a192e8d18cfc4b4043bd93c35b11ce2 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Thu, 27 Feb 2025 03:24:11 -0500 Subject: [PATCH 258/317] [CI/Build] Add examples/ directory to be labelled by `mergify` (#13944) Signed-off-by: Brayden Zhong --- .github/mergify.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/mergify.yml b/.github/mergify.yml index 43bc5ce623d3..e41107ae0a01 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -5,6 +5,7 @@ pull_request_rules: - or: - files~=^[^/]+\.md$ - files~=^docs/ + - files~=^examples/ actions: label: add: From 20b147abc7ee7fe2ede51c685f10f8914749ad7b Mon Sep 17 00:00:00 2001 From: Chauncey Date: Thu, 27 Feb 2025 17:06:49 +0800 Subject: [PATCH 259/317] [Misc] fixed 'required' is an invalid argument for positionals (#13948) Signed-off-by: chaunceyjiang --- .../openai_chat_embedding_client_for_multimodal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/openai_chat_embedding_client_for_multimodal.py index e410620378a5..2c63c5ec370e 100644 --- a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_embedding_client_for_multimodal.py @@ -102,7 +102,7 @@ def dse_qwen2_vl(inp: dict): parser = argparse.ArgumentParser( "Script to call a specified VLM through the API. Make sure to serve " "the model with --task embed before running this.") - parser.add_argument("model", + parser.add_argument("--model", type=str, choices=["vlm2vec", "dse_qwen2_vl"], required=True, From f0a2f15a5ab177bc90919e7f1ba8593fab7cb271 Mon Sep 17 00:00:00 2001 From: Yang Zheng <50227060+zhengy001@users.noreply.github.com> Date: Thu, 27 Feb 2025 17:47:29 +0800 Subject: [PATCH 260/317] [PP] Correct cache size check (#13873) Signed-off-by: Yang Zheng --- vllm/worker/hpu_worker.py | 13 +++++++------ vllm/worker/worker.py | 13 +++++++------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index a1f31bead729..ccb175d88fd3 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -258,9 +258,10 @@ def initialize_cache(self, num_gpu_blocks: int, This also warms up the model, which may record CUDA graphs. """ - raise_if_cache_size_invalid(num_gpu_blocks, - self.cache_config.block_size, - self.model_config.max_model_len) + raise_if_cache_size_invalid( + num_gpu_blocks, self.cache_config.block_size, + self.model_config.max_model_len, + self.parallel_config.pipeline_parallel_size) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -442,13 +443,13 @@ def init_worker_distributed_environment( parallel_config.pipeline_parallel_size) -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, - max_model_len) -> None: +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len, + pipeline_parallel_size) -> None: if num_gpu_blocks <= 0: raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " "initializing the engine.") - max_seq_len = block_size * num_gpu_blocks + max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size) if max_model_len > max_seq_len: raise ValueError( f"The model's max seq len ({max_model_len}) " diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 5d548bdb59f7..ad94a6a4db7a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -288,10 +288,11 @@ def initialize_cache(self, num_gpu_blocks: int, This also warms up the model, which may record CUDA graphs. """ - raise_if_cache_size_invalid(num_gpu_blocks, - self.cache_config.block_size, - self.cache_config.is_attention_free, - self.model_config.max_model_len) + raise_if_cache_size_invalid( + num_gpu_blocks, self.cache_config.block_size, + self.cache_config.is_attention_free, + self.model_config.max_model_len, + self.parallel_config.pipeline_parallel_size) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -530,7 +531,7 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, - max_model_len) -> None: + max_model_len, pipeline_parallel_size) -> None: if is_attention_free and num_gpu_blocks != 0: raise ValueError("No memory should be allocated for the cache blocks " f"for an attention-free model, but {num_gpu_blocks} " @@ -539,7 +540,7 @@ def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " "initializing the engine.") - max_seq_len = block_size * num_gpu_blocks + max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size) if not is_attention_free and max_model_len > max_seq_len: raise ValueError( f"The model's max seq len ({max_model_len}) " From c3128cb0365991b3aca575ff9e2e6729067fa2f1 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 27 Feb 2025 05:00:00 -0500 Subject: [PATCH 261/317] Fix test_block_fp8.py test for MoE (#13915) Signed-off-by: mgoin --- tests/kernels/test_block_fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 20eff1c20723..6206cbd5f76f 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -30,8 +30,8 @@ N_moe = [4608] # [128, 4608, 13824] K_moe = [7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] -E = [256] # [8, 24, 128, 256] -TOP_KS = [1] # [1, 2, 6] +E = [8, 24] # [8, 24, 128, 256] +TOP_KS = [2] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] From 93a4a5005c012c17bfb8b636d59e18f0552a9b9f Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 27 Feb 2025 18:06:41 +0800 Subject: [PATCH 262/317] [VLM] Support multimodal inputs for Florence-2 models (#13320) --- docs/source/models/supported_models.md | 7 + .../offline_inference/florence2_inference.py | 39 +- examples/offline_inference/vision_language.py | 17 + tests/conftest.py | 6 +- .../audio_language/test_ultravox.py | 4 +- .../vision_language/test_florence2.py | 139 ++- .../multimodal/processing/test_common.py | 5 +- tests/models/registry.py | 10 +- vllm/model_executor/models/bart.py | 27 +- vllm/model_executor/models/florence2.py | 913 +++++++++++++++++- vllm/model_executor/models/registry.py | 2 +- vllm/multimodal/processing.py | 20 +- vllm/multimodal/profiling.py | 6 +- 13 files changed, 1078 insertions(+), 117 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 9959f7233e86..4b1f3e180ed5 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -715,6 +715,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `Florence2ForConditionalGeneration` + * Florence-2 + * T + I + * `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc. + * + * + * - * `FuyuForCausalLM` * Fuyu * T + I diff --git a/examples/offline_inference/florence2_inference.py b/examples/offline_inference/florence2_inference.py index 58610b0fd2a5..27aceee43cbf 100644 --- a/examples/offline_inference/florence2_inference.py +++ b/examples/offline_inference/florence2_inference.py @@ -1,34 +1,45 @@ # SPDX-License-Identifier: Apache-2.0 -''' +""" Demonstrate prompting of text-to-text encoder/decoder models, specifically Florence-2 -''' +""" # TODO(Isotr0py): # Move to offline_inference/vision_language.py # after porting vision backbone from vllm import LLM, SamplingParams - -dtype = "float" +from vllm.assets.image import ImageAsset # Create a Florence-2 encoder/decoder model instance llm = LLM( - model="microsoft/Florence-2-base", - tokenizer="facebook/bart-base", - dtype=dtype, + model="microsoft/Florence-2-large", + tokenizer="facebook/bart-large", + max_num_seqs=8, trust_remote_code=True, ) prompts = [ - "", "", "", - "", "", "", - "", "", "" + { # implicit prompt with task token + "prompt": "", + "multi_modal_data": { + "image": ImageAsset("stop_sign").pil_image + }, + }, + { # explicit encoder/decoder prompt + "encoder_prompt": { + "prompt": "Describe in detail what is shown in the image.", + "multi_modal_data": { + "image": ImageAsset("cherry_blossom").pil_image + }, + }, + "decoder_prompt": "", + }, ] # Create a sampling params object. sampling_params = SamplingParams( temperature=0, top_p=1.0, min_tokens=0, - max_tokens=20, + max_tokens=128, ) # Generate output tokens from the prompts. The output is a list of @@ -38,9 +49,5 @@ # Print the outputs. for output in outputs: - prompt = output.prompt - encoder_prompt = output.encoder_prompt generated_text = output.outputs[0].text - print(f"Encoder prompt: {encoder_prompt!r}, " - f"Decoder prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") + print(f"Generated text: {generated_text!r}") diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 5f05389faf80..e2ec36211b86 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -82,6 +82,22 @@ def run_deepseek_vl2(question: str, modality: str): return llm, prompt, stop_token_ids +# Florence2 +def run_florence2(question: str, modality: str): + assert modality == "image" + + llm = LLM(model="microsoft/Florence-2-large", + tokenizer="facebook/bart-large", + max_num_seqs=8, + trust_remote_code=True, + dtype="bfloat16", + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + + prompt = "" + stop_token_ids = None + return llm, prompt, stop_token_ids + + # Fuyu def run_fuyu(question: str, modality: str): assert modality == "image" @@ -571,6 +587,7 @@ def run_qwen2_5_vl(question: str, modality: str): "blip-2": run_blip2, "chameleon": run_chameleon, "deepseek_vl_v2": run_deepseek_vl2, + "florence2": run_florence2, "fuyu": run_fuyu, "glm4v": run_glm4v, "h2ovl_chat": run_h2ovl, diff --git a/tests/conftest.py b/tests/conftest.py index dd339030e5e4..871f0b62c532 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -600,8 +600,8 @@ def generate_encoder_decoder_greedy_logprobs_limit( if images is not None and images[i] is not None: processor_kwargs["images"] = images[i] - encoder_input_ids = self.wrap_device( - self.processor(**processor_kwargs).input_ids, + encoder_inputs = self.wrap_device( + self.processor(**processor_kwargs), device=self.model.device.type, ) @@ -615,13 +615,13 @@ def generate_encoder_decoder_greedy_logprobs_limit( ) output = self.model.generate( - encoder_input_ids, decoder_input_ids=decoder_input_ids, use_cache=True, do_sample=False, max_new_tokens=max_tokens, output_hidden_states=True, return_dict_in_generate=True, + **encoder_inputs, **kwargs, ) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index d1f643a8fdb7..0ea17247028f 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -15,7 +15,7 @@ from ....utils import RemoteOpenAIServer from ...utils import check_logprobs_close -MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" +MODEL_NAME = "fixie-ai/ultravox-v0_4" AudioTuple = Tuple[np.ndarray, int] @@ -187,7 +187,7 @@ def run_multi_audio_test( @pytest.mark.core_model -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("vllm_kwargs", [ diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py index a1d15679918b..de18deab11f6 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -1,52 +1,59 @@ # SPDX-License-Identifier: Apache-2.0 -from functools import partial -from typing import List, Optional, Tuple, Type +from typing import Optional, Type import pytest from PIL import Image -from vllm.inputs.data import ExplicitEncoderDecoderPrompt +from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt +from vllm.multimodal.image import rescale_image_size from vllm.sequence import SampleLogprobs -from ....conftest import HfRunner, VllmRunner +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from ...utils import check_logprobs_close -Florence2Prompt = partial(ExplicitEncoderDecoderPrompt, - decoder_prompt=None, - mm_processor_kwargs=None) - MODELS = ["microsoft/Florence-2-base"] # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model TOKENIZER = "facebook/bart-base" -PROMPTS = [ - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), -] +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "", # special task token + "cherry_blossom": + "Describe in detail what is shown in the image.", +}) + +def get_hf_images_prompts( + prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]], +) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]: + prompts, images = [], [] + for prompt in prompts_: + encoder_prompt = prompt["encoder_prompt"] + prompts.append( + ExplicitEncoderDecoderPrompt( + encoder_prompt=encoder_prompt["prompt"], + decoder_prompt=None, + )) + images.append(encoder_prompt["multi_modal_data"]["image"]) + return prompts, images -def vllm_to_hf_output(vllm_output: Tuple[List[int], str, - Optional[SampleLogprobs]], ): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - hf_output_str = "" + output_str + "" +def hf_to_vllm_output(hf_output: tuple[list[int], str, + Optional[SampleLogprobs]]): + """Sanitize hf output to be comparable with vllm output.""" + output_ids, output_str, out_logprobs = hf_output - return output_ids, hf_output_str, out_logprobs + output_str = output_str.replace("", "").replace("", "") + output_ids = [ids for ids in output_ids if ids not in [0, 2]] + + return output_ids, output_str, out_logprobs def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - prompts: List[ExplicitEncoderDecoderPrompt], + inputs: list[list[ExplicitEncoderDecoderPrompt]], model: str, *, dtype: str, @@ -56,46 +63,76 @@ def run_test( distributed_executor_backend: Optional[str] = None, ) -> None: with vllm_runner(model, + max_num_seqs=8, tokenizer_name=TOKENIZER, dtype=dtype, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs) + vllm_outputs_per_case = [ + vllm_model.generate_encoder_decoder_greedy_logprobs( + prompts, max_tokens, num_logprobs=num_logprobs) + for prompts in inputs + ] + + hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs] - # Florence-2 processors require image inputs - dummy_image = Image.new(mode="RGB", size=(2, 2)) with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model: hf_model.model.get_output_embeddings = lambda: \ hf_model.model.language_model.lm_head - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs, - images=[dummy_image] * len(prompts), - )) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) - - + hf_outputs_per_case = [ + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + prompts, max_tokens, num_logprobs=num_logprobs, images=images) + for prompts, images in hf_inputs + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, + vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs], + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.core_model @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, model, dtype, max_tokens, - num_logprobs) -> None: +def test_models(hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, model: str, + size_factors: list[int], dtype: str, max_tokens: int, + num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [[ + ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=prompt, + multi_modal_data={"image": rescale_image_size(image, factor)}), + decoder_prompt=None, + ) for factor in size_factors + ] for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + run_test( hf_runner, vllm_runner, - PROMPTS, + inputs_per_image, model, dtype=dtype, max_tokens=max_tokens, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index a84999cfbf4f..7534f0c97798 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -29,8 +29,8 @@ def _test_processing_correctness( model_config = ModelConfig( model_id, task="auto", - tokenizer=model_id, - tokenizer_mode="auto", + tokenizer=model_info.tokenizer or model_id, + tokenizer_mode=model_info.tokenizer_mode, trust_remote_code=model_info.trust_remote_code, seed=0, dtype="float16", @@ -151,6 +151,7 @@ def _test_processing_correctness( "Salesforce/blip2-opt-2.7b", "facebook/chameleon-7b", "deepseek-ai/deepseek-vl2-tiny", + "microsoft/Florence-2-base", "adept/fuyu-8b", "THUDM/glm-4v-9b", "h2oai/h2ovl-mississippi-800m", diff --git a/tests/models/registry.py b/tests/models/registry.py index 8614baf18f3b..95bda0293498 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -193,11 +193,6 @@ def check_available_online( # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), - # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer - # Therefore, we borrow the BartTokenizer from the original Bart model - "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 - tokenizer="facebook/bart-base", - trust_remote_code=True), # noqa: E501 } _EMBEDDING_EXAMPLE_MODELS = { @@ -288,6 +283,11 @@ def check_available_online( extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"}, # noqa: E501 trust_remote_code=True), # [Encoder-decoder] + # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer + # Therefore, we borrow the BartTokenizer from the original Bart model + "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 + tokenizer="facebook/bart-base", + trust_remote_code=True), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 } diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 5d2a8cdcb97d..93452696dca5 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -588,8 +588,12 @@ def __init__(self, self.layernorm_embedding = nn.LayerNorm(embed_dim) - def forward(self, input_ids: torch.Tensor, - positions: torch.Tensor) -> torch.Tensor: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: r""" Args: input_ids @@ -602,7 +606,8 @@ def forward(self, input_ids: torch.Tensor, Decoder output torch.Tensor """ # retrieve input_ids and inputs_embeds - inputs_embeds = self.embed_tokens(input_ids) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(positions) embed_pos = embed_pos.to(inputs_embeds.device) @@ -661,9 +666,13 @@ def __init__( self.layernorm_embedding = nn.LayerNorm(config.d_model) - def forward(self, decoder_input_ids: torch.Tensor, - decoder_positions: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor: + def forward( + self, + decoder_input_ids: torch.Tensor, + decoder_positions: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: r""" Args: decoder_input_ids @@ -677,8 +686,10 @@ def forward(self, decoder_input_ids: torch.Tensor, Returns: Decoder output torch.Tensor """ - - inputs_embeds = self.embed_tokens(decoder_input_ids) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(decoder_input_ids) + else: + decoder_positions = inputs_embeds[:, -1] # embed positions embed_pos = self.embed_positions(decoder_positions) diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 06912bcfdc8a..b71d0de8d707 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Iterable, Optional, Set, Tuple +from functools import cached_property +from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict, + Set, Tuple, TypedDict, Union) import torch import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import BatchFeature, PretrainedConfig from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -14,11 +19,567 @@ BartParallelLMHead, BartScaledWordEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptReplacement, + PromptReplacementDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from .utils import AutoWeightsLoader +from .interfaces import SupportsMultiModal +from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings +class Florence2ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channel, height, width)""" + + +# ViT implementation are all copied from +# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py +class LearnedAbsolutePositionEmbedding2D(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, embedding_dim=256, num_pos=50): + super().__init__() + self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) + self.column_embeddings = nn.Embedding( + num_pos, embedding_dim - (embedding_dim // 2)) + + def forward(self, pixel_values): + """ + pixel_values: (batch_size, height, width, num_channels) + returns: (batch_size, height, width, embedding_dim * 2) + """ + if len(pixel_values.shape) != 4: + raise ValueError('pixel_values must be a 4D tensor') + height, width = pixel_values.shape[1:3] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + # (height, width, embedding_dim * 2) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(height, 1, 1), + y_emb.unsqueeze(1).repeat(1, width, 1) + ], + dim=-1) + # (embedding_dim * 2, height, width) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + # (batch_size, embedding_dim * 2, height, width) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + # (batch_size, height, width, embedding_dim * 2) + pos = pos.permute(0, 2, 3, 1) + return pos + + +class PositionalEmbeddingCosine1D(nn.Module): + """ + This class implements a very simple positional encoding. It follows closely + the encoder from the link below: + https://pytorch.org/tutorials/beginner/translation_transformer.html + Args: + embed_dim: The dimension of the embeddings. + dropout_prob: The dropout probability. + max_seq_len: The maximum length to precompute the positional encodings. + """ + + def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None: + super().__init__() + self.embed_dim = embed_dim + self.max_seq_len = max_seq_len + # Generate the sinusoidal arrays. + factor = math.log(10000) + denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) / + self.embed_dim) + # Matrix where rows correspond to a positional embedding as a function + # of the position index (i.e., the row index). + frequencies = \ + torch.arange(0, self.max_seq_len) \ + .reshape(self.max_seq_len, 1) * denominator + pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) + # Populate uneven entries. + pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) + pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) + # Save the positional embeddings in a constant buffer. + # self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) + self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed, + requires_grad=False) + + def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: + """ + Args: + seq_embeds: The sequence embeddings in order. Allowed size: + 1. [T, D], where T is the length of the sequence, and D is the + frame embedding dimension. + 2. [B, T, D], where B is the batch size and T and D are the + same as above. + Returns a tensor of with the same dimensions as the input: i.e., + [1, T, D] or [T, D]. + """ + shape_len = len(seq_embeds.shape) + assert 2 <= shape_len <= 3 + len_seq = seq_embeds.size(-2) + assert len_seq <= self.max_seq_len + pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] + # Adapt pre-computed positional embeddings to the input. + if shape_len == 3: + pos_embeds = pos_embeds.view( + (1, pos_embeds.size(0), pos_embeds.size(1))) + return pos_embeds + + +class MySequential(nn.Sequential): + + def forward(self, *inputs): + for module in self._modules.values(): + if isinstance(inputs, tuple): + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs + + +class PreNorm(nn.Module): + + def __init__(self, norm, fn): + super().__init__() + self.norm = norm + self.fn = fn + + def forward(self, x, *args, **kwargs): + shortcut = x + if self.norm is not None: + x, size = self.fn(self.norm(x), *args, **kwargs) + else: + x, size = self.fn(x, *args, **kwargs) + + x = shortcut + x + + return x, size + + +class Mlp(nn.Module): + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.net = nn.Sequential( + OrderedDict([("fc1", nn.Linear(in_features, hidden_features)), + ("act", act_layer()), + ("fc2", nn.Linear(hidden_features, out_features))])) + + def forward(self, x, size): + return self.net(x), size + + +class DepthWiseConv2d(nn.Module): + + def __init__( + self, + dim_in, + kernel_size, + padding, + stride, + bias=True, + ): + super().__init__() + self.dw = nn.Conv2d(dim_in, + dim_in, + kernel_size=kernel_size, + padding=padding, + groups=dim_in, + stride=stride, + bias=bias) + + def forward(self, x, size): + B, N, C = x.shape + H, W = size + assert N == H * W + + x = self.dw(x.transpose(1, 2).view(B, C, H, W)) + size = (x.size(-2), x.size(-1)) + x = x.flatten(2).transpose(1, 2) + return x, size + + +class ConvEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, + patch_size=7, + in_chans=3, + embed_dim=64, + stride=4, + padding=2, + norm_layer=None, + pre_norm=True): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Conv2d(in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=padding) + + dim_norm = in_chans if pre_norm else embed_dim + self.norm = norm_layer(dim_norm) if norm_layer else None + + self.pre_norm = pre_norm + + def forward(self, x, size): + H, W = size + if len(x.size()) == 3: + if self.norm and self.pre_norm: + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) + + x = self.proj(x) + + _, _, H, W = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + if self.norm and not self.pre_norm: + x = self.norm(x) + + return x, (H, W) + + +class ChannelAttention(nn.Module): + + def __init__(self, dim, groups=8, qkv_bias=True): + super().__init__() + + self.groups = groups + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, size): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.groups, + C // self.groups).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * (float(N)**-0.5) + attention = q.transpose(-1, -2) @ k + attention = attention.softmax(dim=-1) + x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x, size + + +class ChannelBlock(nn.Module): + + def __init__(self, + dim, + groups, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + conv_at_attn=True, + conv_at_ffn=True): + super().__init__() + + self.conv1 = PreNorm(None, DepthWiseConv2d( + dim, 3, 1, 1)) if conv_at_attn else None + self.channel_attn = PreNorm( + norm_layer(dim), + ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, + 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer), + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.channel_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + + return x, size + + +def window_partition(x, window_size: int): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): + B = batch_size + + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + + def __init__(self, dim, num_heads, window_size, qkv_bias=True): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = float(head_dim)**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, size): + + H, W = size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x = window_partition(x, self.window_size) + x = x.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # attn_windows = self.attn(x_windows) + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = self.softmax(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + + # merge windows + x = x.view(-1, self.window_size, self.window_size, C) + x = window_reverse(x, B, self.window_size, Hp, Wp) + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + return x, size + + +class SpatialBlock(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + conv_at_attn=True, + conv_at_ffn=True): + super().__init__() + + self.conv1 = PreNorm(None, DepthWiseConv2d( + dim, 3, 1, 1)) if conv_at_attn else None + self.window_attn = PreNorm( + norm_layer(dim), + WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, + 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer), + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.window_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + return x, size + + +class DaViT(nn.Module): + + def __init__( + self, + in_chans=3, + num_classes=1000, + depths=(1, 1, 3, 1), + patch_size=(7, 2, 2, 2), + patch_stride=(4, 2, 2, 2), + patch_padding=(3, 0, 0, 0), + patch_prenorm=(False, False, False, False), + embed_dims=(64, 128, 192, 256), + num_heads=(3, 6, 12, 24), + num_groups=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + enable_checkpoint=False, + conv_at_attn=True, + conv_at_ffn=True, + ): + super().__init__() + + self.num_classes = num_classes + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_groups = num_groups + self.num_stages = len(self.embed_dims) + self.enable_checkpoint = enable_checkpoint + assert self.num_stages == len(self.num_heads) == len(self.num_groups) + + num_stages = len(embed_dims) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths) * 2) + ] + + depth_offset = 0 + convs = [] + blocks = [] + for i in range(num_stages): + conv_embed = ConvEmbed( + patch_size=patch_size[i], + stride=patch_stride[i], + padding=patch_padding[i], + in_chans=in_chans if i == 0 else self.embed_dims[i - 1], + embed_dim=self.embed_dims[i], + norm_layer=norm_layer, + pre_norm=patch_prenorm[i]) + convs.append(conv_embed) + + block = MySequential(*[ + MySequential( + OrderedDict([('spatial_block', + SpatialBlock( + embed_dims[i], + num_heads[i], + window_size, + drop_path_rate=dpr[depth_offset + j * 2], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + )), + ('channel_block', + ChannelBlock( + embed_dims[i], + num_groups[i], + drop_path_rate=dpr[depth_offset + j * 2 + + 1], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + ))])) for j in range(depths[i]) + ]) + blocks.append(block) + depth_offset += depths[i] * 2 + + self.convs = nn.ModuleList(convs) + self.blocks = nn.ModuleList(blocks) + + self.avgpool = nn.AdaptiveAvgPool1d(1) + + @property + def dim_out(self): + return self.embed_dims[-1] + + def forward_features_unpool(self, x): + """ + forward until avg pooling + Args: + x (_type_): input image tensor + """ + input_size = (x.size(2), x.size(3)) + for conv, block in zip(self.convs, self.blocks): + x, input_size = conv(x, input_size) + x, input_size = block(x, input_size) + return x + + def forward_features(self, x): + x = self.forward_features_unpool(x) + + # (batch_size, num_tokens, token_dim) + x = self.avgpool(x.transpose(1, 2)) + # (batch_size, 1, num_tokens) + x = torch.flatten(x, 1) + x = self.norms(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + @classmethod + def from_config(cls, config): + return cls( + depths=config.depths, + embed_dims=config.dim_embed, + num_heads=config.num_heads, + num_groups=config.num_groups, + patch_size=config.patch_size, + patch_stride=config.patch_stride, + patch_padding=config.patch_padding, + patch_prenorm=config.patch_prenorm, + drop_path_rate=config.drop_path_rate, + window_size=config.window_size, + ) + + +# Language backbone and processor implementation class Florence2LanguageModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -47,9 +608,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.encoder.embed_tokens.weight = self.shared.weight self.decoder.embed_tokens.weight = self.shared.weight - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor) -> torch.Tensor: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: r""" Args: input_ids @@ -68,11 +634,12 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, encoder_hidden_states = None - if encoder_input_ids.numel() > 0: + if inputs_embeds is not None or encoder_input_ids.numel() > 0: # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions) + positions=encoder_positions, + inputs_embeds=inputs_embeds) # decoder outputs consists of # (dec_features, past_key_value, dec_hidden, dec_attn) @@ -112,6 +679,7 @@ def forward( positions: torch.Tensor, encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: r""" @@ -127,8 +695,15 @@ def forward( Returns: Output torch.Tensor """ - return self.model(input_ids, positions, encoder_input_ids, - encoder_positions) + + return self.model(input_ids, + positions, + encoder_input_ids, + encoder_positions, + inputs_embeds=inputs_embeds) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.encoder.embed_tokens(input_ids) def compute_logits( self, @@ -177,21 +752,312 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -class Florence2ForConditionalGeneration(nn.Module): +class Florence2ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_max_image_tokens(self) -> int: + processor_config = self.ctx.get_hf_image_processor_config() + return processor_config["image_seq_length"] + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + +class Florence2DummyInputsBuilder( + BaseDummyInputsBuilder[Florence2ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + target_width = target_height = self.info.get_hf_config().projection_dim + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class Florence2MultiModalProcessor( + EncDecMultiModalProcessor[Florence2ProcessingInfo]): + + def _hf_processor_applies_repl( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return [self.info.get_hf_config().eos_token_id] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs) + else: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + prompt = hf_processor._construct_prompts([prompt])[0] + processed_outputs = tokenizer(prompt, + add_special_tokens=True, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + pad_token_id = hf_config.pad_token_id + bos_token_id = hf_config.bos_token_id + num_image_tokens = self.info.get_max_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[bos_token_id], + replacement=PromptReplacementDetails( + full=image_tokens + [bos_token_id], + features=image_tokens, + ), + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + Florence2MultiModalProcessor, + info=Florence2ProcessingInfo, + dummy_inputs=Florence2DummyInputsBuilder) +class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config - # TODO(Isotr0py): Add vision backbone + self.config = config + self.vision_config = config.vision_config + self.processor_config = processor_config + assert config.vision_config.model_type == 'davit', ( + 'only DaViT is supported for now') + self.vision_tower = DaViT.from_config(config=config.vision_config) + self._build_image_projection_layers(config) self.language_model = Florence2LanguageForConditionalGeneration( vllm_config=vllm_config.with_hf_config(config.text_config), prefix=f"{prefix}.language_model", ) + self.pad_token_id = config.pad_token_id - @property + def _build_image_projection_layers(self, config: PretrainedConfig): + image_dim_out = config.vision_config.dim_embed[-1] + dim_projection = config.vision_config.projection_dim + self.image_projection = nn.Parameter( + torch.empty(image_dim_out, dim_projection)) + self.image_proj_norm = nn.LayerNorm(dim_projection) + image_pos_embed_config = config.vision_config.image_pos_embed + if image_pos_embed_config['type'] == 'learned_abs_2d': + self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( + embedding_dim=image_dim_out, + num_pos=image_pos_embed_config['max_pos_embeddings']) + else: + raise NotImplementedError("Florence2 only supports learned_abs_2d " + "as image position embedding.") + + self.image_feature_source = config.vision_config.image_feature_source + + # temporal embedding + visual_temporal_embedding_config = ( + self.vision_config.visual_temporal_embedding) + if visual_temporal_embedding_config['type'] == 'COSINE': + self.visual_temporal_embed = PositionalEmbeddingCosine1D( + embed_dim=image_dim_out, + max_seq_len=visual_temporal_embedding_config[ + 'max_temporal_embeddings']) + else: + raise NotImplementedError( + 'Florence2 only supports COSINE as temporal embedding.') + + @cached_property def sampler(self): - return self.language_model.sampler + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + return get_sampler() + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + size = self.processor_config["size"] + h, w = size["height"], size["width"] + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = tuple(*map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + return Florence2ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor: + dtype = next(self.vision_tower.parameters()).dtype + pixel_values = pixel_values.to(dtype) + + batch_size, T = pixel_values.size(0), 1 + x = self.vision_tower.forward_features_unpool(pixel_values) + if self.image_pos_embed is not None: + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens**0.5), int(num_tokens**0.5) + assert h * w == num_tokens, ( + 'only support square feature maps for now') + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h * w, x.shape[-1]) + + if self.visual_temporal_embed is not None: + visual_temporal_embed = self.visual_temporal_embed( + x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + x = x.view(batch_size, T, -1, + x.shape[-1]) + visual_temporal_embed.view( + 1, T, 1, x.shape[-1]) + + x_feat_dict = {} + + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x + + temporal_avg_pool_x = x.view(batch_size, T, -1, + x.shape[-1]).mean(dim=1) + x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x + + x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x_feat_dict['last_frame'] = x + + new_x = [] + for _image_feature_source in self.image_feature_source: + if _image_feature_source not in x_feat_dict: + raise ValueError('invalid image feature source: {}'.format( + _image_feature_source)) + new_x.append(x_feat_dict[_image_feature_source]) + + x = torch.cat(new_x, dim=1) + + x = x @ self.image_projection + x = self.image_proj_norm(x) + + return x + + def _process_image_input( + self, image_input: Florence2ImagePixelInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + return self._encode_image(pixel_values) + + def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.pad_token_id) + return inputs_embeds def forward( self, @@ -216,8 +1082,19 @@ def forward( Returns: Output torch.Tensor """ - return self.language_model(input_ids, positions, encoder_input_ids, - encoder_positions) + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + if encoder_input_ids.numel() > 0 or vision_embeddings is not None: + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) + else: + inputs_embeds = None + + hidden_states = self.language_model(input_ids, + positions, + encoder_input_ids, + encoder_positions, + inputs_embeds=inputs_embeds) + return hidden_states def compute_logits( self, @@ -236,9 +1113,5 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - skip_prefixes = [ - 'image_projection', "vision_tower", "image_proj_norm", - "image_pos_embed", "visual_temporal_embed" - ] - loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 58155905a7b7..75e31d557dd1 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -105,7 +105,6 @@ # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), - "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 } _EMBEDDING_MODELS = { @@ -182,6 +181,7 @@ "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), # [Encoder-decoder] + "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 } diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 93756364dea1..60b000e2b34f 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1303,6 +1303,14 @@ def create_encoder_prompt( """ raise NotImplementedError + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + """Create input prompt for the decoder.""" + return prompt + def apply( self, prompt: Union[str, list[int]], @@ -1323,17 +1331,15 @@ def apply( hf_processor_mm_kwargs, ) - # We assumed the decoder prompt text is copied from - # the original encoder prompt without extra process tokenizer = self.info.get_tokenizer() - if isinstance(prompt, str): - decoder_prompt = prompt + decoder_prompt = self.create_decoder_prompt(prompt, mm_data) + if isinstance(decoder_prompt, str): decoder_prompt_ids = encode_tokens(tokenizer, - prompt, + decoder_prompt, add_special_tokens=False) else: - decoder_prompt = decode_tokens(tokenizer, prompt) - decoder_prompt_ids = prompt + decoder_prompt_ids = decoder_prompt + decoder_prompt = decode_tokens(tokenizer, decoder_prompt) mm_inputs = MultiModalEncDecInputs( encoder_prompt=encoder_inputs["prompt"], diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 093f8b7a8179..3178b0f8c3e6 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -204,9 +204,11 @@ def get_dummy_data( "and/or reduce `mm_counts`.", seq_len, total_len, total_placeholders_by_modality) + num_tokens_to_pad = max(total_len, seq_len) - total_len + prompt_token_ids.extend([0] * num_tokens_to_pad) + return DummyData( - seq_data=SequenceData.from_prompt_token_counts( - (0, max(seq_len, total_len))), + seq_data=SequenceData.from_seqs(prompt_token_ids), multi_modal_data=None, multi_modal_placeholders=None, ) From feaa8ce8a0eb655f240e8fd24fbb3184cc07c74d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com> Date: Thu, 27 Feb 2025 11:08:35 +0100 Subject: [PATCH 263/317] [Model] Deepseek GGUF support (#13167) --- docs/source/features/quantization/gguf.md | 7 + vllm/config.py | 9 +- vllm/engine/arg_utils.py | 8 ++ vllm/model_executor/layers/fused_moe/layer.py | 22 ++- vllm/model_executor/layers/linear.py | 15 ++- .../layers/quantization/gguf.py | 127 +++++++++++++++++- vllm/model_executor/model_loader/loader.py | 19 ++- .../model_loader/weight_utils.py | 1 - 8 files changed, 198 insertions(+), 10 deletions(-) diff --git a/docs/source/features/quantization/gguf.md b/docs/source/features/quantization/gguf.md index 65c181900f9b..4b1ff4a22a23 100644 --- a/docs/source/features/quantization/gguf.md +++ b/docs/source/features/quantization/gguf.md @@ -29,6 +29,13 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlam We recommend using the tokenizer from base model instead of GGUF model. Because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with large vocab size. ::: +GGUF assumes that huggingface can convert the metadata to a config file. In case huggingface doesn't support your model you can manually create a config and pass it as hf-confing-path + +```console +# If you model is not supported by huggingface you can manually provide a huggingface compatible config path +vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --hf-config-path Tinyllama/TInyLlama-1.1B-Chat-v1.0 +``` + You can also use the GGUF model directly through the LLM entrypoint: ```python diff --git a/vllm/config.py b/vllm/config.py index a5d8ee9303d0..d1384c6375f3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -229,6 +229,7 @@ def __init__( trust_remote_code: bool, dtype: Union[str, torch.dtype], seed: int, + hf_config_path: Optional[str] = None, allowed_local_media_path: str = "", revision: Optional[str] = None, code_revision: Optional[str] = None, @@ -259,6 +260,7 @@ def __init__( model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, ) -> None: self.model = model + self.hf_config_path = hf_config_path self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode self.trust_remote_code = trust_remote_code @@ -321,8 +323,9 @@ def __init__( if self.enable_sleep_mode and not current_platform.is_cuda(): raise ValueError("Sleep mode is only supported on CUDA devices.") - hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, config_format) + hf_config = get_config(self.hf_config_path or self.model, + trust_remote_code, revision, code_revision, + config_format) if hf_overrides_kw: logger.info("Overriding HF config with %s", hf_overrides_kw) @@ -947,7 +950,7 @@ def get_multimodal_config(self) -> "MultiModalConfig": def try_get_generation_config(self) -> Dict[str, Any]: if self.generation_config is None or self.generation_config == "auto": config = try_get_generation_config( - self.model, + self.hf_config_path or self.model, trust_remote_code=self.trust_remote_code, revision=self.revision, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 26d4a84b841c..1a2f794c9151 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -93,6 +93,7 @@ class EngineArgs: model: str = 'facebook/opt-125m' served_model_name: Optional[Union[str, List[str]]] = None tokenizer: Optional[str] = None + hf_config_path: Optional[str] = None task: TaskOption = "auto" skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' @@ -262,6 +263,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.tokenizer, help='Name or path of the huggingface tokenizer to use. ' 'If unspecified, model name or path will be used.') + parser.add_argument( + "--hf-config-path", + type=nullable_str, + default=EngineArgs.hf_config_path, + help='Name or path of the huggingface config to use. ' + 'If unspecified, model name or path will be used.') parser.add_argument( '--skip-tokenizer-init', action='store_true', @@ -1076,6 +1083,7 @@ def create_model_config(self) -> ModelConfig: return ModelConfig( model=self.model, + hf_config_path=self.hf_config_path, task=self.task, # We know this is not None because we set it in __post_init__ tokenizer=cast(str, self.tokenizer), diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 42554b61f67a..28a88571dab4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Tuple import torch +from torch.nn.parameter import UninitializedParameter import vllm.envs as envs from vllm.distributed import (get_tensor_model_parallel_rank, @@ -514,7 +515,12 @@ def weight_loader(self, param: torch.nn.Parameter, # dimension intermediate_size_per_partition is used. SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} - expert_data = param.data[expert_id] + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + param.data.copy_(loaded_weight) + return # is_transposed: if the dim to shard the weight # should be flipped. Required by GPTQ, compressed-tensors @@ -524,6 +530,20 @@ def weight_loader(self, param: torch.nn.Parameter, if is_transposed: shard_dim = int(not shard_dim) + full_load = len(loaded_weight.shape) == 3 + if full_load: + shard_dim += 1 + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + final_shape = list(loaded_weight.shape) + if shard_id in ["w1", "w3"]: + final_shape[1] *= 2 + final_shape[shard_dim] = final_shape[ + shard_dim] // get_tensor_model_parallel_world_size() + param.materialize(final_shape, dtype=loaded_weight.dtype) + + expert_data = param.data if full_load else param.data[expert_id] # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: # this is needed for compressed-tensors only diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 521724765beb..b9c85aaf50b5 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -235,10 +235,23 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # If the weight on disk does not have a shape, give it one # (such scales for AutoFp8). + # Special case for GGUF + + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) - assert param.size() == loaded_weight.size() + assert param.size() == loaded_weight.size(), ( + f"Tried to load weights of size {loaded_weight.size()}" + f"to a parameter of size {param.size()}") param.data.copy_(loaded_weight) def forward(self, diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index b1fecb32f4d8..ba176e4a567c 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import gguf import torch @@ -8,6 +8,9 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -29,7 +32,7 @@ def get_name(self) -> str: return "gguf" def get_supported_act_dtypes(self) -> List[torch.dtype]: - return [torch.half, torch.bfloat16] + return [torch.half] @classmethod def get_min_capability(cls) -> int: @@ -49,6 +52,8 @@ def get_quant_method(self, layer: torch.nn.Module, return GGUFLinearMethod(self) elif isinstance(layer, VocabParallelEmbedding): return GGUFEmbeddingMethod(self) + elif isinstance(layer, FusedMoE): + return GGUFMoEMethod(self) return None @@ -184,6 +189,124 @@ def apply(self, return out +class GGUFMoEMethod(FusedMoEMethodBase): + """MoE method for GGUF. + + Args: + quant_config: The GGUF quantization config. + """ + + def __init__(self, quant_config: GGUFConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + tensor_shape = (num_experts, 2 * intermediate_size_per_partition, + hidden_size) + #gate up proj + w13_qweight = GGUFUninitializedParameter(requires_grad=False) + set_weight_attrs( + w13_qweight, { + "input_dim": 1, + "output_dim": 0, + "tensor_shape": tensor_shape, + "is_gguf_weight": True, + "data_container": [], + }) + set_weight_attrs(w13_qweight, extra_weight_attrs) + layer.register_parameter("w13_qweight", w13_qweight) + + w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8), + requires_grad=False) + set_weight_attrs(w13_qweight_type, { + "is_gguf_weight_type": True, + "weight_type": 0, + "ignore_warning": True + }) + set_weight_attrs(w13_qweight_type, extra_weight_attrs) + layer.register_parameter("w13_qweight_type", w13_qweight_type) + + tensor_shape = (num_experts, intermediate_size_per_partition, + hidden_size) + #gate down proj + w2_qweight = GGUFUninitializedParameter(requires_grad=False) + set_weight_attrs( + w2_qweight, { + "input_dim": 1, + "output_dim": 0, + "tensor_shape": tensor_shape, + "is_gguf_weight": True, + "data_container": [], + }) + set_weight_attrs(w2_qweight, extra_weight_attrs) + layer.register_parameter("w2_qweight", w2_qweight) + + w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8), + requires_grad=False) + set_weight_attrs(w2_qweight_type, { + "is_gguf_weight_type": True, + "weight_type": 0, + "ignore_warning": True + }) + + set_weight_attrs(w2_qweight_type, extra_weight_attrs) + layer.register_parameter("w2_qweight_type", w2_qweight_type) + self.act = SiluAndMul() + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + ): + assert activation == "silu", "Only SiLU activation is supported." + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + final_hidden_states = torch.empty_like(x) + for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)): + inp = x[tok].reshape((1, ) + x.shape[1:]) + current_hidden_state = None + for ww, ii in zip(w, idx): + expert_up = layer.w13_qweight[ii] + + out = _fuse_mul_mat(inp, expert_up, + layer.w13_qweight_type.weight_type) + out = self.act(out) + + expert_down = layer.w2_qweight[ii] + current_state = _fuse_mul_mat( + out, expert_down, + layer.w2_qweight_type.weight_type).mul_(ww) + if current_hidden_state is None: + current_hidden_state = current_state + else: + current_hidden_state.add_(current_state) + final_hidden_states[tok] = current_hidden_state + return final_hidden_states + + class GGUFEmbeddingMethod(GGUFLinearMethod): """Embedding method for GGUF. diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 4e8ef49235ed..46247eaf2a60 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1245,9 +1245,24 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): """ config = model_config.hf_config model_type = config.model_type + gguf_to_hf_name_map = {} # hack: ggufs have a different name than transformers if model_type == "cohere": model_type = "command-r" + if model_type in ("deepseek_v3", "deepseek_v2"): + model_type = "deepseek2" + # GGUF layer map assumes that we will have a merged expert weights + # so we need to map them manually + for idx in range(config.num_hidden_layers): + gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \ + f"model.layers.{idx}.mlp.gate.e_score_correction_bias" + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + arch = None for key, value in gguf.MODEL_ARCH_NAMES.items(): if value == model_type: @@ -1258,10 +1273,10 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): num_layers = config.num_hidden_layers name_map = gguf.get_tensor_name_map(arch, num_layers) with torch.device("meta"): - dummy_model = AutoModelForCausalLM.from_config(config) + dummy_model = AutoModelForCausalLM.from_config( + config, trust_remote_code=model_config.trust_remote_code) state_dict = dummy_model.state_dict() - gguf_to_hf_name_map = {} for hf_name in state_dict: name, suffix = hf_name.rsplit(".", 1) gguf_name = name_map.get_name(name) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 18f6f40b32f0..245c199f75b1 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -496,7 +496,6 @@ def gguf_quant_weights_iterator( weight = tensor.data weight_type = tensor.tensor_type name = gguf_to_hf_name_map[tensor.name] - if weight_type.name != "F32": name = name.replace("weight", "qweight") param = torch.tensor(weight) From 954a82a7bc0c255316a39449f9082a2d22e30fc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=8D=9A=E4=BC=9F?= Date: Fri, 28 Feb 2025 00:05:11 +0800 Subject: [PATCH 264/317] Update quickstart.md (#13958) --- docs/source/getting_started/quickstart.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md index f51856d6eaeb..452bee2385fe 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/source/getting_started/quickstart.md @@ -24,6 +24,12 @@ source myenv/bin/activate uv pip install vllm ``` +Another delightful way is to use `uv run` with `--with [dependency]` option, which allows you to run commands such as `vllm serve` without creating an environment: + +```console +uv run --with vllm vllm --help +``` + You can also use [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html) to create and manage Python environments. ```console From e2ec2ebc99f359eae49efa1598040fa994a159fb Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 27 Feb 2025 16:27:47 +0000 Subject: [PATCH 265/317] Deduplicate `.pre-commit-config.yaml`'s `exclude` (#13967) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .pre-commit-config.yaml | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 20d1981c9a05..23a38d49638f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,7 @@ default_stages: - pre-commit # Run locally - manual # Run in CI +exclude: 'vllm/third_party/.*' repos: - repo: https://github.com/google/yapf rev: v0.43.0 @@ -8,13 +9,11 @@ repos: - id: yapf args: [--in-place, --verbose] additional_dependencies: [toml] # TODO: Remove when yapf is upgraded - exclude: 'vllm/third_party/.*' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.3 hooks: - id: ruff args: [--output-format, github, --fix] - exclude: 'vllm/third_party/.*' - repo: https://github.com/codespell-project/codespell rev: v2.4.0 hooks: @@ -25,7 +24,6 @@ repos: rev: 0a0b7a830386ba6a31c2ec8316849ae4d1b8240d # 6.0.0 hooks: - id: isort - exclude: 'vllm/third_party/.*' - repo: https://github.com/pre-commit/mirrors-clang-format rev: v19.1.7 hooks: @@ -38,12 +36,10 @@ repos: hooks: - id: pymarkdown args: [fix] - exclude: 'vllm/third_party/.*' - repo: https://github.com/rhysd/actionlint rev: v1.7.7 hooks: - id: actionlint - exclude: 'vllm/third_party/.*' - repo: https://github.com/astral-sh/uv-pre-commit rev: 0.6.2 hooks: @@ -59,7 +55,6 @@ repos: types: [python] additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests] stages: [pre-commit] # Don't run in CI - exclude: 'vllm/third_party/.*' - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.9 entry: tools/mypy.sh 1 "3.9" @@ -67,7 +62,6 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI - exclude: 'vllm/third_party/.*' - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.10 entry: tools/mypy.sh 1 "3.10" @@ -75,7 +69,6 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI - exclude: 'vllm/third_party/.*' - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.11 entry: tools/mypy.sh 1 "3.11" @@ -83,7 +76,6 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI - exclude: 'vllm/third_party/.*' - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.12 entry: tools/mypy.sh 1 "3.12" @@ -91,19 +83,16 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI - exclude: 'vllm/third_party/.*' - id: shellcheck name: Lint shell scripts entry: tools/shellcheck.sh language: script types: [shell] - exclude: 'vllm/third_party/.*' - id: png-lint name: Lint PNG exports from excalidraw entry: tools/png-lint.sh language: script types: [png] - exclude: 'vllm/third_party/.*' - id: signoff-commit name: Sign-off Commit entry: bash @@ -116,13 +105,11 @@ repos: language: system verbose: true stages: [commit-msg] - exclude: 'vllm/third_party/.*' - id: check-spdx-header name: Check SPDX headers entry: python tools/check_spdx_header.py language: python types: [python] - exclude: 'vllm/third_party/.*' - id: check-filenames name: Check for spaces in all filenames entry: bash @@ -132,7 +119,6 @@ repos: language: system always_run: true pass_filenames: false - exclude: 'vllm/third_party/.*' # Keep `suggestion` last - id: suggestion name: Suggestion @@ -140,5 +126,4 @@ repos: language: system verbose: true pass_filenames: false - exclude: 'vllm/third_party/.*' # Insert new entries above the `suggestion` entry From 4ef625ac7166e20eae832b7c88ed59d65130d0f6 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Thu, 27 Feb 2025 09:01:21 -0800 Subject: [PATCH 266/317] [bugfix] Fix profiling for RayDistributedExecutor (#13945) Signed-off-by: Rui Qiao --- vllm/executor/ray_distributed_executor.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index c3b41d1c1134..2accb9e17f3d 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -309,19 +309,24 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): ",".join(map(str, node_gpus[node_id])), } for (node_id, _) in worker_node_and_gpu_ids] + # Environment variables to copy from driver to workers + env_vars_to_copy = [ + "VLLM_ATTENTION_BACKEND", "TPU_CHIPS_PER_HOST_BOUNDS", + "TPU_HOST_BOUNDS", "VLLM_USE_V1", "VLLM_TRACE_FUNCTION", + "VLLM_TORCH_PROFILER_DIR", "VLLM_TEST_ENABLE_EP" + ] + + # Copy existing env vars to each worker's args for args in all_args_to_update_environment_variables: - # some carry-over env vars from the driver # TODO: refactor platform-specific env vars - for name in [ - "VLLM_ATTENTION_BACKEND", - "TPU_CHIPS_PER_HOST_BOUNDS", - "TPU_HOST_BOUNDS", - "VLLM_USE_V1", - "VLLM_TRACE_FUNCTION", - ]: + for name in env_vars_to_copy: if name in os.environ: args[name] = os.environ[name] + logger.info( + "Copying the following environment variables to workers: %s", + [v for v in env_vars_to_copy if v in os.environ]) + self._env_vars_for_all_workers = ( all_args_to_update_environment_variables) From d805e8d7ff86da7c661c35e8f1caca741f69c662 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Thu, 27 Feb 2025 19:16:12 +0200 Subject: [PATCH 267/317] =?UTF-8?q?Update=20LMFE=20version=20to=20v0.10.11?= =?UTF-8?q?=20to=20support=20new=20versions=20of=20transforme=E2=80=A6=20(?= =?UTF-8?q?#13930)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 942c3e039eaf..fb84d6d9e7b6 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -17,7 +17,7 @@ prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer -lm-format-enforcer >= 0.10.9, < 0.11 +lm-format-enforcer >= 0.10.11, < 0.11 outlines == 0.1.11 lark == 1.2.2 xgrammar == 0.1.11; platform_machine == "x86_64" From 41ea5429894c9e2945262ac07eb6f97898f9ac93 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 28 Feb 2025 01:30:39 +0800 Subject: [PATCH 268/317] [Bugfix] Fix qwen2.5-vl overflow issue (#13968) Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/minicpmo.py | 11 +++-------- vllm/model_executor/models/qwen2_5_vl.py | 7 ++++++- vllm/model_executor/models/utils.py | 10 ++++++++++ vllm/model_executor/models/whisper.py | 9 +++------ 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index e354e5323327..e6111f46143d 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -47,7 +47,7 @@ MiniCPMVMultiModalDataParser, MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, _minicpmv_field_config) -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix CPU_DEVICE = torch.device("cpu") @@ -469,13 +469,8 @@ def forward( training=self.training) hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() - or torch.isnan(hidden_states).any()): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, - min=-clamp_value, - max=clamp_value) + if hidden_states.dtype == torch.float16: + hidden_states = cast_overflow_tensors(hidden_states) outputs = (hidden_states, ) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 858cf28d2b87..0dbff665b5d3 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -63,7 +63,7 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, apply_rotary_pos_emb_vision) -from .utils import (AutoWeightsLoader, WeightsMapper, +from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) from .vision import get_vit_attn_backend @@ -641,6 +641,11 @@ def forward( cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb) + # For Qwen2.5-VL-3B, float16 will overflow at last block + # for long visual tokens sequences. + if hidden_states.dtype == torch.float16: + hidden_states = cast_overflow_tensors(hidden_states) + # adapter hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index fff4be34ddbe..f9aa5da39a5f 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -641,3 +641,13 @@ def extract_layer_index(layer_name: str) -> int: assert len(int_vals) == 1, (f"layer name {layer_name} should" " only contain one integer") return int_vals[0] + + +def cast_overflow_tensors( + tensors: torch.Tensor, + offset: float = 1000, +) -> torch.Tensor: + if tensors.isinf().any() or tensors.isnan().any(): + clamp_value = torch.finfo(tensors.dtype).max - offset + tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value) + return tensors diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index e5f77e08c403..a2eefbc6d899 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -35,7 +35,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from .interfaces import SupportsMultiModal, SupportsTranscription -from .utils import AutoWeightsLoader, WeightsMapper, make_layers +from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, + make_layers) logger = init_logger(__name__) @@ -285,11 +286,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - if hidden_states.isinf().any() or hidden_states.isnan().any(): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, - min=-clamp_value, - max=clamp_value) + hidden_states = cast_overflow_tensors(hidden_states) return hidden_states From af836bee011e57d9eaa39bcb48f8480059ca8cc7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 28 Feb 2025 01:44:25 +0800 Subject: [PATCH 269/317] [VLM] Generalized prompt updates for multi-modal processor (#13964) Signed-off-by: DarkLight1337 --- docs/source/contributing/model/multimodal.md | 26 +- docs/source/design/mm_processing.md | 23 +- tests/multimodal/test_processing.py | 210 ++++---- vllm/model_executor/models/aria.py | 12 +- vllm/model_executor/models/blip2.py | 24 +- vllm/model_executor/models/chameleon.py | 14 +- vllm/model_executor/models/deepseek_vl2.py | 10 +- vllm/model_executor/models/florence2.py | 24 +- vllm/model_executor/models/fuyu.py | 12 +- vllm/model_executor/models/glm4v.py | 11 +- vllm/model_executor/models/h2ovl.py | 11 +- vllm/model_executor/models/idefics3.py | 10 +- vllm/model_executor/models/internvl.py | 13 +- vllm/model_executor/models/llava.py | 23 +- .../model_executor/models/llava_next_video.py | 11 +- vllm/model_executor/models/llava_onevision.py | 20 +- vllm/model_executor/models/minicpmo.py | 9 +- vllm/model_executor/models/minicpmv.py | 11 +- vllm/model_executor/models/mllama.py | 10 +- vllm/model_executor/models/molmo.py | 35 +- vllm/model_executor/models/nvlm_d.py | 13 +- vllm/model_executor/models/phi3v.py | 21 +- .../models/prithvi_geospatial_mae.py | 24 +- vllm/model_executor/models/qwen2_audio.py | 12 +- vllm/model_executor/models/qwen2_vl.py | 16 +- vllm/model_executor/models/qwen_vl.py | 15 +- vllm/model_executor/models/ultravox.py | 11 +- vllm/model_executor/models/whisper.py | 10 +- vllm/multimodal/processing.py | 486 +++++++++++------- 29 files changed, 635 insertions(+), 492 deletions(-) diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index 990eac82d516..c8046d248506 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -720,13 +720,13 @@ def _get_mm_fields_config( ::::: -### Prompt replacements +### Prompt updates -Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to -return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances. +Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` to +return a list of {class}`~vllm.multimodal.processing.PromptUpdate` instances. -Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace -operation performed by the HF processor. +Each {class}`~vllm.multimodal.processing.PromptUpdate` instance specifies an update operation +(e.g.: insertion, replacement) performed by the HF processor. ::::{tab-set} :::{tab-item} Basic example: LLaVA @@ -743,15 +743,15 @@ for sample in text: ``` It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`). -Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` as follows: +Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` as follows: ```python -def _get_prompt_replacements( +def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, -) -> list[PromptReplacement]: +) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -859,7 +859,7 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch( ) ``` -To accommodate this, instead of a string you can return an instance of `PromptReplacementDetails` +To accommodate this, instead of a string you can return an instance of `PromptUpdateDetails` with different `full` and `feature` attributes: ```python @@ -878,7 +878,7 @@ def get_replacement_fuyu(item_idx: int): image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows - return PromptReplacementDetails( + return PromptUpdateDetails( full=image_tokens + [bos_token_id], features=image_tokens, ) @@ -888,12 +888,12 @@ Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the we can search for it to conduct the replacement at the start of the string: ```python -def _get_prompt_replacements( +def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, -) -> list[PromptReplacement]: +) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id assert isinstance(bos_token_id, int) @@ -913,7 +913,7 @@ def _get_prompt_replacements( image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows - return PromptReplacementDetails( + return PromptUpdateDetails( full=image_tokens + [bos_token_id], features=image_tokens, ) diff --git a/docs/source/design/mm_processing.md b/docs/source/design/mm_processing.md index a0d01205e638..2a4dac786d4b 100644 --- a/docs/source/design/mm_processing.md +++ b/docs/source/design/mm_processing.md @@ -6,11 +6,16 @@ To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefi Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`: -## Prompt Replacement Detection +## Prompt Update Detection -One of the main responsibilies of HF processor is to replace input placeholder tokens (e.g. `` for a single image) with feature placeholder tokens (e.g. `...`, the number of which equals to the feature size). The information about which tokens have been replaced is key to finding the correspondence between placeholder feature tokens and multi-modal inputs. +One of the main responsibilies of HF processor is to update the prompt with placeholder tokens. For example: -In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptReplacement` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. Given this specification, we can automatically detect whether HF has replaced the input placeholder tokens by checking whether the feature placeholder tokens exist in the prompt. +- Insert feature placeholder tokens (e.g. `...`, the number of which equals to the feature size) at the start of the string. +- Replace existing input placeholder tokens (e.g. `` for a single image) with feature placeholder tokens (e.g. `...`, the number of which equals to the feature size). + +The information about which tokens have been updated is key to finding the correspondence between placeholder feature tokens and multi-modal inputs. + +In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptUpdate` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. We can automatically detect whether HF has updated the prompt by checking the existence of the updated tokens. ## Tokenized Prompt Inputs @@ -22,7 +27,7 @@ Consider that HF processors follow these main steps: 1. Tokenize the text 2. Process multi-modal inputs -3. Perform prompt replacement +3. Perform prompt updates And we require that: @@ -44,16 +49,16 @@ Moreover, since the tokenized text has not passed through the HF processor, we h We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data. -(mm-automatic-prompt-replacement)= +(mm-automatic-prompt-updating)= -### Automatic prompt replacement +### Automatic prompt updating We address the second issue by implementing model-agnostic code in -{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_replacements` to automatically replace input placeholder tokens with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. +{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_updates` to automatically update the prompt with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. ### Summary -With the help of dummy text and automatic prompt replacement, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`. +With the help of dummy text and automatic prompt updating, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`. ## Processor Output Caching @@ -61,4 +66,4 @@ Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238) When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache. -Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt replacement code, we apply [automatic prompt replacement](#mm-automatic-prompt-replacement) afterwards to keep the output tokens and multi-modal data consistent with each other. +Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating](#mm-automatic-prompt-updating) afterwards to keep the output tokens and multi-modal data consistent with each other. diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index c2fbe83abc83..878b15925006 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -14,12 +14,12 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (PlaceholderFeaturesInfo, - PromptReplacement, + PromptInsertion, PromptReplacement, + apply_text_matches, + apply_token_matches, find_mm_placeholders, find_text_matches, find_token_matches, - iter_token_matches, - replace_text_matches, - replace_token_matches) + iter_token_matches) # yapf: enable from vllm.multimodal.profiling import MultiModalProfiler from vllm.transformers_utils.tokenizer import (AnyTokenizer, @@ -102,7 +102,7 @@ def test_iter_token_matches(token_ids, match_ids, expected): { "pattern_1": [], "pattern_2": [], - } + }, ), ( [32000, 32000, 32000, 32000], @@ -147,16 +147,22 @@ def test_iter_token_matches(token_ids, match_ids, expected): ), ], ) +@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) # yapf: enable -def test_find_token_matches(prompt, target_by_key, expected_by_key): +def test_find_token_matches( + prompt, + target_by_key, + expected_by_key, + update_type, +): # Should not be used since there is nothing to convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - prompt_repls = [ - PromptReplacement(key, target, []).bind(mock_tokenizer) + prompt_updates = [ + update_type(key, target, []).bind(mock_tokenizer) for key, target in target_by_key.items() ] - result = find_token_matches(prompt, prompt_repls) + result = find_token_matches(prompt, prompt_updates) # Only displayed on error print("result:", result) @@ -254,16 +260,22 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key): ), ], ) +@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) # yapf: enable -def test_find_text_matches(prompt, target_by_key, expected_by_key): +def test_find_text_matches( + prompt, + target_by_key, + expected_by_key, + update_type, +): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - prompt_repls = [ - PromptReplacement(key, target, []).bind(mock_tokenizer) + prompt_updates = [ + update_type(key, target, []).bind(mock_tokenizer) for key, target in target_by_key.items() ] - result = find_text_matches(prompt, prompt_repls) + result = find_text_matches(prompt, prompt_updates) # Only displayed on error print("result:", result) @@ -281,7 +293,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # yapf: disable @pytest.mark.parametrize( - ("prompt", "target_by_key", "repl_by_key"), + ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ ( "Image:Image:!", @@ -300,58 +312,66 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # Test dynamic replacement (beyond the form of `unit * count`) "pattern_3": "?!?", }, + { + PromptInsertion: { + 0: "Image:Image:!", + 1: "Image:Image:!?!?", + 2: "Image:Image:!?!??!?", # noqa: E501 + }, + PromptReplacement: { + 0: "Image:Image:!", + 1: "Image:?!?", + 2: "?!?", + }, + }, ), ] ) -@pytest.mark.parametrize( - ("mm_count", "expected"), - [ - (0, "Image:Image:!"), - (1, "Image:?!?"), - (2, "?!?"), - ] -) # yapf: enable -def test_find_replace_text( +def test_find_update_text( prompt, target_by_key, repl_by_key, - mm_count, - expected, + expected_by_update_type_mm_count, ): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - mm_prompt_repls = { - key: [ - PromptReplacement(key, target, - repl_by_key[key]).bind(mock_tokenizer) - ] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_text_matches(prompt, prompt_repls) - for key, prompt_repls in mm_prompt_repls.items() - } - - result = replace_text_matches( - prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, - ) - - # Only displayed on error - print("mm_matches:", mm_matches) - print("result:", result) - - # Manually constructed results - assert result == expected + for ( + update_type, + expected_by_mm_count, + ) in expected_by_update_type_mm_count.items(): + mm_prompt_updates = { + key: + [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] + for key, target in target_by_key.items() + } + mm_matches = { + key: find_text_matches(prompt, updates) + for key, updates in mm_prompt_updates.items() + } + + for mm_count, expected in expected_by_mm_count.items(): + result = apply_text_matches( + prompt, + mm_matches, + {key: mm_count + for key in repl_by_key}, + ) + + # Only displayed on error + print("update_type:", update_type) + print("mm_count:", mm_count) + print("mm_matches:", mm_matches) + print("result:", result) + + # Manually constructed results + assert result == expected # yapf: disable @pytest.mark.parametrize( - ("prompt", "target_by_key", "repl_by_key"), + ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ # Tokenized test cases of `test_find_replace_text` # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf @@ -372,53 +392,61 @@ def test_find_replace_text( # Test dynamic replacement (beyond the form of `unit * count`) "pattern_3": [1550, 918, 1550], }, + { + PromptInsertion: { + 0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + 1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501 + 2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501 + }, + PromptReplacement: { + 0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + 1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], # noqa: E501 + 2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], + }, + }, ), ] ) -@pytest.mark.parametrize( - ("mm_count", "expected"), - [ - (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]), - (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]), - (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]), - ] -) # yapf: enable -def test_find_replace_tokens( +def test_find_update_tokens( prompt, target_by_key, repl_by_key, - mm_count, - expected, + expected_by_update_type_mm_count, ): # Should not be used since there is nothing to convert to tokens mock_tokenizer = cast(AnyTokenizer, object()) - mm_prompt_repls = { - key: [ - PromptReplacement(key, target, - repl_by_key[key]).bind(mock_tokenizer) - ] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_token_matches(prompt, prompt_repls) - for key, prompt_repls in mm_prompt_repls.items() - } - - result = replace_token_matches( - prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, - ) - - # Only displayed on error - print("mm_matches:", mm_matches) - print("result:", result) - - # Manually constructed results - assert result == expected + for ( + update_type, + expected_by_mm_count, + ) in expected_by_update_type_mm_count.items(): + mm_prompt_updates = { + key: + [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] + for key, target in target_by_key.items() + } + mm_matches = { + key: find_token_matches(prompt, updates) + for key, updates in mm_prompt_updates.items() + } + + for mm_count, expected in expected_by_mm_count.items(): + result = apply_token_matches( + prompt, + mm_matches, + {key: mm_count + for key in repl_by_key}, + ) + + # Only displayed on error + print("update_type:", update_type) + print("mm_count:", mm_count) + print("mm_matches:", mm_matches) + print("result:", result) + + # Manually constructed results + assert result == expected # yapf: disable @@ -524,22 +552,24 @@ def test_find_replace_tokens( ), ] ) +@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) # yapf: enable def test_find_mm_placeholders( repl_by_key, prompt, expected, + update_type, ): # Should not be used since there is nothing to convert to tokens mock_tokenizer = cast(AnyTokenizer, object()) - mm_prompt_repls = { - key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)] + mm_prompt_updates = { + key: [update_type(key, [], repl).bind(mock_tokenizer)] for key, repl in repl_by_key.items() } result = find_mm_placeholders( - mm_prompt_repls, + mm_prompt_updates, prompt, # Effectively match all occurrences in the prompt {key: 3 diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 656e9b037d96..061a9a5bd2bc 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - -from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, - Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import List, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -26,7 +25,8 @@ NestedTensors) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -457,12 +457,12 @@ def _get_mm_fields_config( pixel_mask=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 23bb3cd07f1d..61f2f8974d91 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -19,8 +19,8 @@ NestedTensors) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + BaseProcessingInfo, PromptInsertion, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -474,30 +474,24 @@ def _get_mm_fields_config( image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() - bos_token_id = tokenizer.bos_token_id - assert isinstance(bos_token_id, int) - image_token_id = vocab[""] num_image_tokens = self.info.get_num_image_tokens() image_tokens = [image_token_id] * num_image_tokens return [ - PromptReplacement( + PromptInsertion( modality="image", - target=[bos_token_id], - replacement=PromptReplacementDetails( - full=image_tokens + [bos_token_id], - features=image_tokens, - ), + target="", + insertion=image_tokens, ) ] diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index e91399b2674d..9d597e240951 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -35,7 +35,7 @@ from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -141,12 +141,12 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict(pixel_values=MultiModalFieldConfig.batched("image")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -162,7 +162,7 @@ def _get_prompt_replacements( PromptReplacement( modality="image", target=[image_token_id], - replacement=PromptReplacementDetails( + replacement=PromptUpdateDetails( full=([image_start_id] + image_tokens + [image_end_id]), features=image_tokens, ), @@ -371,7 +371,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if residual is None: residual = hidden_states diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index ea217e244404..3d2e452bb50e 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -3,9 +3,9 @@ # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py """Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" import math +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -26,7 +26,7 @@ ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, @@ -281,12 +281,12 @@ def _get_mm_fields_config( image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token_id = hf_processor.image_token_id diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index b71d0de8d707..c51fcf3d438b 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import math +from collections import OrderedDict +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict, - Set, Tuple, TypedDict, Union) +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -24,8 +25,7 @@ from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, - PromptReplacement, - PromptReplacementDetails) + PromptInsertion, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -803,7 +803,7 @@ def get_dummy_processor_inputs( class Florence2MultiModalProcessor( EncDecMultiModalProcessor[Florence2ProcessingInfo]): - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, @@ -850,26 +850,22 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict(pixel_values=MultiModalFieldConfig.batched("image")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() pad_token_id = hf_config.pad_token_id - bos_token_id = hf_config.bos_token_id num_image_tokens = self.info.get_max_image_tokens() image_tokens = [pad_token_id] * num_image_tokens return [ - PromptReplacement( + PromptInsertion( modality="image", - target=[bos_token_id], - replacement=PromptReplacementDetails( - full=image_tokens + [bos_token_id], - features=image_tokens, - ), + target="", + insertion=image_tokens, ) ] diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 7e4cc6bac5e6..581ec54b2cab 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -17,8 +17,8 @@ # limitations under the License. """ PyTorch Fuyu model.""" import math -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict) +from collections.abc import Iterable, Mapping, Sequence +from typing import List, Literal, Optional, Set, Tuple, TypedDict import torch import torch.nn as nn @@ -37,7 +37,7 @@ MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -203,12 +203,12 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict(image_patches=MultiModalFieldConfig.batched("image")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id assert isinstance(bos_token_id, int) @@ -228,7 +228,7 @@ def get_replacement_fuyu(item_idx: int): image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows - return PromptReplacementDetails( + return PromptUpdateDetails( full=image_tokens + [bos_token_id], features=image_tokens, ) diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 48543c5642ea..ca34c4f8d53f 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -4,7 +4,8 @@ # https://github.com/THUDM/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" from argparse import Namespace -from typing import Literal, Mapping, Optional, TypedDict, Union +from collections.abc import Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union import torch from torch import nn @@ -32,7 +33,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, BatchFeature, MultiModalFieldConfig, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig @@ -480,7 +481,7 @@ def get_dummy_processor_inputs( class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, @@ -495,12 +496,12 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict(pixel_values=MultiModalFieldConfig.batched("image")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() boi_token_id = hf_config.boi_token_id diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index bab9c256b9aa..d336d7521a27 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -7,7 +7,8 @@ # Copyright (c) 2024 H2O.AI # Licensed under Apache 2.0 License [see LICENSE for details] # -------------------------------------------------------- -from typing import Mapping, Optional +from collections.abc import Mapping, Sequence +from typing import Optional import torch from PIL import Image @@ -20,7 +21,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -487,12 +488,12 @@ def __init__(self, f"{type(self).__name__} does not support processing cache with " "multi-image support enabled.") - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) if "image_num_patches" in out_mm_kwargs: @@ -527,7 +528,7 @@ def get_replacement_internvl(item_idx: int): if num_patches is not None: assert isinstance(num_patches, int) - return PromptReplacementDetails( + return PromptUpdateDetails( full=hf_processor.get_image_repl_full(feature_size, num_patches), features=hf_processor.get_image_repl_features( diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 0a8763cf910c..286a75339d20 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -16,8 +16,8 @@ """Inference-only Idefics3 model compatible with HuggingFace weights.""" import math -from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.utils.checkpoint @@ -41,7 +41,7 @@ BaseProcessingInfo, MultiModalDataItems, MultiModalFieldConfig, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -274,12 +274,12 @@ def _get_mm_fields_config( image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.image_token.content diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 52ddb279cca3..48c2eb8c9f6e 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -7,9 +7,10 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, TypeVar, Union) +from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar, + Union) import torch import torch.nn as nn @@ -31,7 +32,7 @@ ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -599,12 +600,12 @@ def _get_mm_fields_config( image_token_id=MultiModalFieldConfig.shared("image", num_images), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) if "image_num_patches" in out_mm_kwargs: @@ -636,7 +637,7 @@ def get_replacement_internvl(item_idx: int): if num_patches is not None: assert isinstance(num_patches, int) - return PromptReplacementDetails( + return PromptUpdateDetails( full=hf_processor.get_image_repl_full(feature_size, num_patches), features=hf_processor.get_image_repl_features( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 72b1591306f2..8318a496e608 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 from abc import abstractmethod +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Final, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, TypedDict, TypeVar, Union) +from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, + TypedDict, TypeVar, Union) import torch import torch.nn as nn @@ -31,7 +32,7 @@ ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -222,12 +223,12 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: raise NotImplementedError - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -328,12 +329,12 @@ def _get_mm_fields_config( image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_config = self.info.get_hf_config() tokenizer = self.info.get_tokenizer() @@ -789,7 +790,7 @@ def get_replacement_mantis(item_idx: int): ")", # 3 tokens ]) - mantis_mm_repls = self._bind_and_group_repls([ + mantis_mm_repls = self._bind_and_group_updates([ PromptReplacement( modality="image", target=[image_token_id] * num_image_tokens, @@ -797,18 +798,18 @@ def get_replacement_mantis(item_idx: int): ) ]) - prompt_ids, prompt, _ = self._apply_prompt_replacements( + prompt_ids, prompt, _ = self._apply_prompt_updates( result["prompt_token_ids"], mantis_mm_repls, mm_item_counts, ) - unbound_orig_repls = self._get_prompt_replacements( + unbound_orig_repls = self._get_prompt_updates( mm_items, hf_processor_mm_kwargs, mm_kwargs, ) - orig_repls = self._bind_and_group_repls(unbound_orig_repls) + orig_repls = self._bind_and_group_updates(unbound_orig_repls) mm_placeholders = self._find_mm_placeholders( orig_repls, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 807d6977ed40..ca9406657df5 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import math +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -21,7 +21,8 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -183,12 +184,12 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict(pixel_values_videos=MultiModalFieldConfig.batched("video")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() video_token_id = hf_config.video_token_index diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index e57eea4286e9..e87ef24ce2ca 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import math +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Final, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, TypedDict, Union) +from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -22,7 +23,7 @@ NestedTensors) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) -from vllm.multimodal.processing import PromptReplacement +from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.profiling import ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -347,13 +348,13 @@ def _call_hf_processor( ) return BatchFeature(combined_outputs) - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], ) -> bool: - base_result = super()._hf_processor_applies_repl( + base_result = super()._hf_processor_applies_updates( prompt_text=prompt_text, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, @@ -361,13 +362,13 @@ def _hf_processor_applies_repl( return base_result and mm_items.get_count("video", strict=False) == 0 - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: - image_repls = super()._get_prompt_replacements( + ) -> Sequence[PromptUpdate]: + image_repls = super()._get_prompt_updates( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, out_mm_kwargs=out_mm_kwargs, @@ -392,7 +393,8 @@ def get_video_replacement(item_idx: int): return [video_token_id] * num_video_tokens - return image_repls + [ + return [ + *image_repls, PromptReplacement( modality="video", target=[video_token_id], diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index e6111f46143d..f35c230c0cea 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -22,9 +22,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only MiniCPM-O model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Set, Tuple, TypedDict, Union) +from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, + TypedDict, Union) import torch from torch import nn @@ -356,10 +357,10 @@ def get_prompt_texts_by_modality(self, inputs: Dict[str, object], inputs["audio"]["audio_lens"][index]) return super().get_prompt_texts_by_modality(inputs, modality, index) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]: + out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]: placeholder = { "image": self.info.image_pattern, "video": self.info.video_pattern, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2699958331f3..fb6ea53acf9e 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -25,9 +25,10 @@ import math import re from collections import Counter +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property, partial -from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Set, Tuple, TypedDict, Union) +from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, + TypedDict, Union) import numpy as np import torch @@ -732,7 +733,7 @@ def _call_hf_processor( } } - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, @@ -740,10 +741,10 @@ def _hf_processor_applies_repl( ) -> bool: return False - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]: + out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]: placeholder = { "image": self.info.image_pattern, "video": self.info.video_pattern, diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 459928fe3fb0..36e653e41e1b 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -15,8 +15,8 @@ # limitations under the License. """PyTorch Mllama model.""" import math -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import numpy as np import torch @@ -59,7 +59,7 @@ MultiModalDataDict, MultiModalDataItems) from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from .clip import CLIPMLP @@ -243,12 +243,12 @@ def create_encoder_prompt( image_token_id = self.info.get_hf_config().image_token_index return [image_token_id] * num_images - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: token_per_chunk = self.info.get_token_per_chunk_from_config() image_token_id = self.info.get_hf_config().image_token_index diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index cc4d38d8740b..60af103189f8 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import math +from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from functools import cached_property, partial -from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, - Union, cast) +from typing import List, Optional, Set, Tuple, TypedDict, Union, cast import numpy as np import torch @@ -46,8 +46,8 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + BaseProcessingInfo, PromptInsertion, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import JSONTree, json_map_leaves @@ -1190,6 +1190,8 @@ def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper: return MolmoProcessorWrapper(processor) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + # TODO: Investigate different `embed_is_patch` between cache/no-cache + # in multi-image case return {"image": 1} def get_mm_max_tokens_per_item( @@ -1328,25 +1330,18 @@ def _get_mm_fields_config( img_patch_id=MultiModalFieldConfig.shared("image", num_images), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - tokenizer = self.info.get_tokenizer() image_token_length_w = processor.image_token_length_w image_token_length_h = processor.image_token_length_h pooling_size = processor.pooling_size - user_str = "User:" - if processor.always_start_with_space: - user_str = " " + user_str - - user_tokens = tokenizer.encode(user_str, add_special_tokens=False) - img_patch_id = processor.image_patch_id img_col_id = processor.im_col_id img_start_id = processor.im_start_id @@ -1356,7 +1351,7 @@ def _get_prompt_replacements( extra_joint = ([img_start_id] + extra_row * image_token_length_h + [img_end_id]) - def get_replacement_molmo(item_idx: int): + def get_insertion_molmo(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) @@ -1371,17 +1366,13 @@ def get_replacement_molmo(item_idx: int): ((nrows + 1) // pooling_size) + [img_end_id]) image_tokens = extra_joint + joint - - return PromptReplacementDetails( - full=image_tokens + user_tokens, - features=image_tokens, - ) + return image_tokens return [ - PromptReplacement( + PromptInsertion( modality="image", - target=user_str, - replacement=get_replacement_molmo, + target="<|endoftext|>", + insertion=get_insertion_molmo, ) ] diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 5de8eeb3fffe..1e1760491a97 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -6,7 +6,8 @@ # Copyright (c) 2024 NVIDIA # Licensed under Apache 2.0 License [see LICENSE for details] # -------------------------------------------------------- -from typing import Mapping, Optional +from collections.abc import Mapping, Sequence +from typing import Optional import torch import torch.nn as nn @@ -17,8 +18,8 @@ from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) -from vllm.multimodal.processing import (PromptReplacement, - PromptReplacementDetails) +from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.multimodal.profiling import ProcessorInputs from .intern_vit import InternVisionModel @@ -142,12 +143,12 @@ def get_dummy_processor_inputs( class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) if "image_num_patches" in out_mm_kwargs: @@ -179,7 +180,7 @@ def get_replacement_nvlm(item_idx: int): if num_patches is not None: assert isinstance(num_patches, int) - return PromptReplacementDetails( + return PromptUpdateDetails( full=hf_processor.get_image_repl_full(feature_size, num_patches) + "\n", features=hf_processor.get_image_repl_features( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 0f45f131065a..0fd4b3c70211 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -38,11 +38,10 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - BoundPromptReplacement, + BaseProcessingInfo, BoundPromptUpdate, PlaceholderFeaturesInfo, - PromptReplacement, - PromptReplacementDetails) + PromptReplacement, PromptUpdate, + PromptUpdateDetails) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -420,12 +419,12 @@ def _get_mm_fields_config( image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_tokens: list[str] = hf_processor.img_tokens # type: ignore @@ -449,7 +448,7 @@ def get_replacement_phi3v(item_idx: int): image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens - return PromptReplacementDetails( + return PromptUpdateDetails( full=image_tokens + [bos_token_id], features=image_tokens, ) @@ -464,15 +463,15 @@ def get_replacement_phi3v(item_idx: int): ) for image_token in image_tokens[:num_images] ] - def _apply_prompt_replacements( + def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], mm_item_counts: Mapping[str, int], ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: - token_ids, text, placeholders = super()._apply_prompt_replacements( + token_ids, text, placeholders = super()._apply_prompt_updates( token_ids=token_ids, - mm_prompt_repls=mm_prompt_repls, + mm_prompt_updates=mm_prompt_updates, mm_item_counts=mm_item_counts, ) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 3d95e949e71d..bfa90e42733d 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -15,7 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only IBM/NASA Prithvi Geospatial model.""" -from typing import Iterable, Mapping, Optional, Set, Tuple, Union +from collections.abc import Iterable, Mapping, Sequence +from typing import Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -32,7 +33,7 @@ MultiModalInputs, MultiModalKwargs) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import (IntermediateTensors, PoolerOutput, PoolingSequenceGroupOutput) @@ -44,7 +45,7 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - pass + return {"image": 0} class PrithviGeoSpatialMAEInputBuilder( @@ -78,20 +79,13 @@ def _get_mm_fields_config( location_coords=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: - pass - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - pass + ) -> Sequence[PromptUpdate]: + return [] def apply( self, @@ -120,7 +114,7 @@ def apply( class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): """ Prithvi Masked Autoencoder""" - def _instantiate_model(self, config: dict) -> nn.Module | None: + def _instantiate_model(self, config: dict) -> Optional[nn.Module]: # We might be able/need to support different tasks with this same model if config["task_args"]["task"] == "SemanticSegmentationTask": @@ -158,7 +152,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): "by PrithviGeospatialMAE.") def _parse_and_validate_multimodal_data( - self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]: + self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: pixel_values = kwargs.pop("pixel_values", None) if not isinstance(pixel_values, torch.Tensor): diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index f0dc8573ee14..1c3107e76eb6 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -21,9 +21,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict, - Union) +from typing import Any, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -43,7 +43,7 @@ MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -188,12 +188,12 @@ def _get_mm_fields_config( feature_attention_mask=MultiModalFieldConfig.batched("audio"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -230,7 +230,7 @@ def get_replacement_qwen2_audio(item_idx: int): audio_tokens = [audio_token_id] * num_features - return PromptReplacementDetails( + return PromptUpdateDetails( full=[audio_bos_id] + audio_tokens + [audio_eos_id], features=audio_tokens, ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 849ef7293bb7..cb92fcbe9fa1 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -23,9 +23,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property, partial -from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set, - Tuple, Type, TypedDict, Union) +from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, + Union) import torch import torch.nn as nn @@ -61,7 +62,8 @@ ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors @@ -169,7 +171,7 @@ def __init__( self, in_features: int, hidden_features: int, - act_layer: Type[nn.Module] = QuickGELU, + act_layer: type[nn.Module] = QuickGELU, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -383,7 +385,7 @@ def __init__( dim: int, num_heads: int, mlp_ratio: float, - act_layer: Type[nn.Module] = QuickGELU, + act_layer: type[nn.Module] = QuickGELU, norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -987,12 +989,12 @@ def _call_hf_processor( self.info._get_image_processor_kwargs(**mm_kwargs), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( **hf_processor_mm_kwargs) diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index e0d8bf2fa3d2..b8aaa7f1db1b 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -9,9 +9,10 @@ import math import re import unicodedata +from collections.abc import Collection, Mapping, Sequence +from collections.abc import Set as AbstractSet from functools import lru_cache, partial -from typing import (AbstractSet, Callable, Collection, List, Literal, Mapping, - Optional, TypedDict, Union) +from typing import Callable, List, Literal, Optional, TypedDict, Union import torch from torch import nn @@ -36,7 +37,7 @@ from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -606,7 +607,7 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, ) - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, @@ -624,12 +625,12 @@ def _get_mm_fields_config( image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() special_tokens: dict[str, int] = tokenizer.special_tokens # type: ignore @@ -646,7 +647,7 @@ def _get_prompt_replacements( PromptReplacement( modality="image", target=[img_start_id, img_end_id], - replacement=PromptReplacementDetails( + replacement=PromptUpdateDetails( full=[img_start_id] + image_tokens + [img_end_id], features=image_tokens, ), diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b8d4aef252e5..d47f924ea19e 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,9 +3,9 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" import math +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.utils.checkpoint @@ -29,7 +29,8 @@ NestedTensors) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig @@ -197,12 +198,12 @@ def _get_mm_fields_config( audio_embeds=MultiModalFieldConfig.batched("audio"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index a2eefbc6d899..2da8c5c8b0e2 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, - Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import List, Optional, Set, Tuple, TypedDict, Union import torch from torch import nn @@ -31,7 +31,7 @@ MultiModalDataParser) from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from .interfaces import SupportsMultiModal, SupportsTranscription @@ -623,12 +623,12 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict(input_features=MultiModalFieldConfig.batched("audio")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: num_tokens = self.info.get_max_audio_tokens() return [ PromptReplacement( diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 60b000e2b34f..ac33af7c10c7 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -6,11 +6,14 @@ from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, Sequence) from dataclasses import dataclass, field +from enum import Enum from functools import lru_cache +from itertools import groupby from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, - TypeVar, Union) + TypeVar, Union, cast) from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +from typing_extensions import assert_never import vllm.envs as envs from vllm.inputs import InputProcessingContext @@ -38,35 +41,129 @@ @dataclass -class PromptReplacementDetails: - """Details about the replacement token sequence or text.""" +class PromptUpdateDetails: + """Details about the token sequence or text that are part of the update.""" full: PromptSeq - """The full replacement.""" + """The full content.""" features: PromptSeq """ - The part of the replacement that corresponds to feature placeholders; + The part of the content that corresponds to feature placeholders; this will be replaced by the output of the vision encoder during model inference. """ @staticmethod - def from_seq(seq: PromptSeq) -> "PromptReplacementDetails": - return PromptReplacementDetails(full=seq, features=seq) + def from_seq(seq: PromptSeq) -> "PromptUpdateDetails": + return PromptUpdateDetails(full=seq, features=seq) -PromptRepl = Union[PromptSeq, PromptReplacementDetails] +PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] """ -The replacement token sequence or text. +The token sequence or text that are part of the update. -If only part of the replacement corresponds to feature placeholders, you can -use :class:`PromptReplacementDetails` to specify which part. +If only part of the content corresponds to feature placeholders, you can +use :class:`PromptUpdateDetails` to specify which part. """ +PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], + PromptUpdateInfo] +""" +Given the index of the processed item within :attr:`modality`, +output the corresponding token sequence (or text). + +For convenience, you can directly pass in the token sequence (or text) +instead of a function if it does not depend on the input. +""" + + +class UpdateMode(str, Enum): + INSERT = "insert" + REPLACE = "replace" + + +@dataclass +class PromptUpdate: + """ + Defines how to update a prompt with placeholder tokens. + """ + + modality: str + """The modality for which the update is made.""" + + target: PromptSeq + """The token sequence (or text) to update.""" + + @property + @abstractmethod + def content(self) -> PromptUpdateContent: + """The placeholder tokens that are part of the update.""" + raise NotImplementedError + + @property + @abstractmethod + def mode(self) -> UpdateMode: + """Defines how to update the prompt.""" + raise NotImplementedError + + def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate": + return BoundPromptUpdate( + _origin=self, + tokenizer=tokenizer, + ) + @dataclass -class PromptReplacement: +class PromptInsertion(PromptUpdate): + """ + Defines how to insert placeholder tokens into a prompt. + + Example: + + For each image, insert a number of ```` feature placeholders + equal to the feature size of the vision encoder at the start of the + prompt: + + .. code-block:: python + + PromptInsertion( + modality="image", + target="", + insertion="" * image_feature_size, + ) + + As above, but insert after the ```` token: + + .. code-block:: python + + PromptInsertion( + modality="image", + target="", + insertion="" * image_feature_size, + ) + """ + + insertion: PromptUpdateContent = field(repr=False) + """ + Given the index of the processed item within :attr:`modality`, + output the token sequence (or text) to insert right after :attr:`target`. + + For convenience, you can directly pass in the token sequence (or text) + instead of a function if it does not depend on the input. + """ + + @property + def content(self) -> PromptUpdateContent: + return self.insertion + + @property + def mode(self) -> UpdateMode: + return UpdateMode.INSERT + + +@dataclass +class PromptReplacement(PromptUpdate): """ Defines how to replace portions of an input prompt with placeholder tokens. @@ -93,7 +190,7 @@ class PromptReplacement: PromptReplacement( modality="image", target="", - replacement=PromptReplacementDetails( + replacement=PromptUpdateDetails( full="".join([ "", "" * image_feature_size, @@ -111,7 +208,7 @@ class PromptReplacement: PromptReplacement( modality="image", target=[image_token_id], - replacement=PromptReplacementDetails( + replacement=PromptUpdateDetails( full=([image_bos_id] + [image_token_id] * image_feature_size + [image_eos_id]), features=[image_token_id] * image_feature_size, @@ -119,29 +216,22 @@ class PromptReplacement: ) """ - modality: str - """The modality for which the replacement is made.""" - - target: PromptSeq - """The token sequence (or text) to find and replace.""" - - replacement: Union[Callable[[int], PromptRepl], - PromptRepl] = field(repr=False) + replacement: PromptUpdateContent = field(repr=False) """ Given the index of the processed item within :attr:`modality`, - output the replacement token sequence (or text). + output the token sequence (or text) to replace :attr:`target`. - For convenience, you can directly pass in the replacement token sequence - (or text) instead of a function if it does not depend on the input. + For convenience, you can directly pass in the token sequence (or text) + instead of a function if it does not depend on the input. """ - def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement": - return BoundPromptReplacement( - tokenizer=tokenizer, - modality=self.modality, - _target=self.target, - _replacement=self.replacement, - ) + @property + def content(self) -> PromptUpdateContent: + return self.replacement + + @property + def mode(self) -> UpdateMode: + return UpdateMode.REPLACE @lru_cache(maxsize=2048) @@ -232,64 +322,73 @@ def token_ids(self) -> list[int]: @dataclass -class _BoundPromptReplacementGroup: +class _BoundPromptContent: full: _BoundPromptSequence features: _BoundPromptSequence @dataclass -class BoundPromptReplacement: +class BoundPromptUpdate: """ - A :class:`PromptReplacement` bound to a tokenizer to automatically - convert :attr:`target` and the result of :meth:`get_replacement` between + A :class:`PromptUpdate` bound to a tokenizer to automatically convert + :attr:`target` and the result of :meth:`get_content` between token sequence and text representations. """ + _origin: PromptUpdate tokenizer: AnyTokenizer = field(repr=False) - modality: str - - _target: PromptSeq - _replacement: Union[Callable[[int], PromptRepl], - PromptRepl] = field(repr=False) def __post_init__(self) -> None: - self._replacement_cache = dict[int, _BoundPromptReplacementGroup]() + self._content_cache = dict[int, _BoundPromptContent]() + + @property + def modality(self) -> str: + return self._origin.modality @property def target(self) -> _BoundPromptSequence: - """The token sequence (or text) to find and replace.""" - return _BoundPromptSequence.from_seq(self.tokenizer, self._target) + """The token sequence (or text) to update.""" + return _BoundPromptSequence.from_seq(self.tokenizer, + self._origin.target) - def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup: + @property + def content(self) -> PromptUpdateContent: + """The placeholder tokens that are part of the update.""" + return self._origin.content + + @property + def mode(self) -> UpdateMode: + """Defines how to update the prompt.""" + return self._origin.mode + + def get_content(self, item_idx: int) -> _BoundPromptContent: """ Given the index of the processed item within :attr:`modality`, - output the replacement token sequence (or text). + output the token sequence (or text) to update. """ - replacement = self._replacement - if callable(replacement): + content = self.content + if callable(content): cache_key = item_idx - if cache_key in self._replacement_cache: - return self._replacement_cache[cache_key] + if cache_key in self._content_cache: + return self._content_cache[cache_key] - replacement = replacement(item_idx) + content = content(item_idx) else: cache_key = None - if not isinstance(replacement, PromptReplacementDetails): - replacement = PromptReplacementDetails.from_seq(replacement) + if not isinstance(content, PromptUpdateDetails): + content = PromptUpdateDetails.from_seq(content) bound_full = _BoundPromptSequence.from_seq(self.tokenizer, - replacement.full) + content.full) bound_features = _BoundPromptSequence.from_seq(self.tokenizer, - replacement.features) - bound_replacement = _BoundPromptReplacementGroup( - full=bound_full, - features=bound_features, - ) + content.features) + bound_content = _BoundPromptContent(full=bound_full, + features=bound_features) if cache_key is not None: - self._replacement_cache[cache_key] = bound_replacement + self._content_cache[cache_key] = bound_content - return bound_replacement + return bound_content class _TokenMatch(NamedTuple): @@ -326,12 +425,12 @@ def iter_token_matches( @dataclass(repr=False) -class _PromptReplacementMatch(ABC): - prompt_repl: BoundPromptReplacement +class _PromptTargetMatch(ABC): + _origin: BoundPromptUpdate @property def modality(self) -> str: - return self.prompt_repl.modality + return self._origin.modality @property @abstractmethod @@ -349,7 +448,7 @@ def __repr__(self) -> str: @dataclass(repr=False) -class _PromptReplacementTokenMatch(_PromptReplacementMatch): +class _PromptTargetTokenMatch(_PromptTargetMatch): match: _TokenMatch @property @@ -362,7 +461,7 @@ def end_idx(self) -> int: @dataclass(repr=False) -class _PromptReplacementTextMatch(_PromptReplacementMatch): +class _PromptTargetTextMatch(_PromptTargetMatch): match: re.Match[str] @property @@ -394,40 +493,37 @@ def to_range(self) -> PlaceholderRange: def find_token_matches( prompt: list[int], - prompt_repls: Sequence[BoundPromptReplacement], -) -> list[_PromptReplacementTokenMatch]: - """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" + prompt_updates: Sequence[BoundPromptUpdate], +) -> Sequence[_PromptTargetMatch]: + """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" return [ - _PromptReplacementTokenMatch(prompt_repl, match) - for prompt_repl in prompt_repls - for match in iter_token_matches(prompt, prompt_repl.target.token_ids) + _PromptTargetTokenMatch(update, match) for update in prompt_updates + for match in iter_token_matches(prompt, update.target.token_ids) ] def find_text_matches( prompt: str, - prompt_repls: Sequence[BoundPromptReplacement], -) -> list[_PromptReplacementTextMatch]: - """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" + prompt_updates: Sequence[BoundPromptUpdate], +) -> Sequence[_PromptTargetMatch]: + """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" return [ - _PromptReplacementTextMatch(prompt_repl, match) - for prompt_repl in prompt_repls - for match in re.finditer(re.escape(prompt_repl.target.text), prompt) + _PromptTargetTextMatch(update, match) for update in prompt_updates + for match in re.finditer(re.escape(update.target.text), prompt) ] def _resolve_matches( prompt: PromptSeq, - mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]], -) -> list[_PromptReplacementMatch]: + mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], +) -> list[_PromptTargetMatch]: """ Resolve :code:`mm_matches` to ensure that there are no overlapping matches, and sort them such that earlier matches take priority over later ones. """ matches = [m for matches in mm_matches.values() for m in matches] - seen_matches: list[Optional[_PromptReplacementMatch]] = [None - ] * len(prompt) + seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt) for match in matches: for idx in range(match.start_idx, match.end_idx): @@ -441,74 +537,91 @@ def _resolve_matches( return sorted(matches, key=lambda x: x.start_idx) -def _replace_matches( +def _apply_matches( prompt: _S, - mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]], + mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], mm_item_counts: Mapping[str, int], ) -> list[_S]: - """Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" - out_seqs = list[_S]() + """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" + out_seqs = list[Union[str, list[int]]]() prev_end_idx = 0 next_idx_by_modality = defaultdict[str, int](lambda: 0) - for match in _resolve_matches(prompt, mm_matches): - modality = match.modality + for (start_idx, end_idx), group in groupby( + _resolve_matches(prompt, mm_matches), + key=lambda x: (x.start_idx, x.end_idx), + ): + matches = tuple(group) + assert len(matches) == 1 - item_idx = next_idx_by_modality[modality] - if item_idx >= mm_item_counts.get(modality, 0): - continue + for match in matches: + modality = match.modality - start_idx = match.start_idx - end_idx = match.end_idx + item_idx = next_idx_by_modality[modality] + if item_idx >= mm_item_counts.get(modality, 0): + continue - repl_info = match.prompt_repl - replacement = repl_info.get_replacement(item_idx) + origin = match._origin + content = origin.get_content(item_idx) + mode = origin.mode - if isinstance(prompt, str): - repl_seq = replacement.full.text - out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq) - else: - repl_seq = replacement.full.token_ids - out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq) + if mode == UpdateMode.INSERT: + out_seqs.append(prompt[prev_end_idx:end_idx]) + num_inserts = mm_item_counts.get(modality, 0) + elif mode == UpdateMode.REPLACE: + out_seqs.append(prompt[prev_end_idx:start_idx]) + num_inserts = 1 + else: + assert_never(mode) - prev_end_idx = end_idx - next_idx_by_modality[modality] += 1 + for _ in range(num_inserts): + if item_idx >= mm_item_counts.get(modality, 0): + continue + + if isinstance(prompt, str): + out_seqs.append(content.full.text) + else: + out_seqs.append(content.full.token_ids) + + next_idx_by_modality[modality] += 1 + + prev_end_idx = end_idx out_seqs.append(prompt[prev_end_idx:]) - return out_seqs + return cast(list[_S], out_seqs) -def replace_token_matches( +def apply_token_matches( prompt: list[int], - mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]], + mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], mm_item_counts: Mapping[str, int], ) -> list[int]: - """Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" + """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" if not mm_matches: return prompt - token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts) + token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts) return flatten_2d_lists(token_id_seqs) -def replace_text_matches( +def apply_text_matches( prompt: str, - mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]], + mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], mm_item_counts: Mapping[str, int], ) -> str: - """Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" + """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" if not mm_matches: return prompt - texts = _replace_matches(prompt, mm_matches, mm_item_counts) + texts = _apply_matches(prompt, mm_matches, mm_item_counts) return "".join(texts) def _iter_placeholders( - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], mm_item_counts: Mapping[str, int], ) -> Iterable[PlaceholderFeaturesInfo]: @@ -517,7 +630,7 @@ def _iter_placeholders( Matches are exclusive even when multiple modalities share the same placeholder tokens. In that case, the modality that - appears earlier in `mm_prompt_repls` takes priority. + appears earlier in `mm_prompt_updates` takes priority. Note that empty matches are ignored. """ @@ -528,37 +641,37 @@ def _iter_placeholders( while start_idx < prompt_len: found = False - for modality, modality_repls in mm_prompt_repls.items(): + for modality, modality_updates in mm_prompt_updates.items(): item_idx = item_idx_by_modality[modality] if item_idx >= mm_item_counts.get(modality, 0): continue - for repl_info in modality_repls: - replacement = repl_info.get_replacement(item_idx) - repl_tokens_full = replacement.full.token_ids - repl_len_full = len(repl_tokens_full) - end_idx_full = start_idx + repl_len_full + for update_info in modality_updates: + content = update_info.get_content(item_idx) + content_tokens_full = content.full.token_ids + content_len_full = len(content_tokens_full) + end_idx_full = start_idx + content_len_full - if repl_len_full == 0 or end_idx_full > prompt_len: + if content_len_full == 0 or end_idx_full > prompt_len: continue - if prompt[start_idx:end_idx_full] == repl_tokens_full: - repl_tokens_feat = replacement.features.token_ids + if prompt[start_idx:end_idx_full] == content_tokens_full: + content_tokens_feat = content.features.token_ids try: match = next( - iter_token_matches(repl_tokens_full, - repl_tokens_feat)) + iter_token_matches(content_tokens_full, + content_tokens_feat)) yield PlaceholderFeaturesInfo( modality=modality, item_idx=item_idx, start_idx=start_idx + match.start_idx, - tokens=repl_tokens_feat, + tokens=content_tokens_feat, ) except StopIteration: raise AssertionError( - f"{repl_tokens_feat=} should be a " - f"subsequence of {repl_tokens_full=}") from None + f"{content_tokens_feat=} should be a " + f"subsequence of {content_tokens_full=}") from None # Exclude overlapping matches start_idx = end_idx_full @@ -574,11 +687,11 @@ def _iter_placeholders( def find_mm_placeholders( - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], mm_item_counts: Mapping[str, int], ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts) + it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts) return dict(full_groupby_modality(it)) @@ -712,6 +825,12 @@ def __init__(self, *, cache: Optional[ProcessingCache] = None, enable_sanity_checks: bool = True) -> None: + if get_repls := getattr(self, "_get_prompt_replacements", None): + logger.warning_once("`_get_prompt_replacements` has been renamed " + "to `_get_prompt_updates`. The old name will " + "be removed in an upcoming release.") + self._get_prompt_updates = get_repls # type: ignore[method-assign] + super().__init__() self.info = info @@ -770,34 +889,34 @@ def _get_mm_fields_config( raise NotImplementedError @abstractmethod - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> list[PromptUpdate]: """ Given the original multi-modal items for this modality - and HF-processed data, output the replacements to perform. + and HF-processed data, output the updates to perform. Notes: - You should not assume that HF processor always performs prompt - replacement: in :meth:`_apply_hf_processor_missing`, this method + updates: in :meth:`_apply_hf_processor_missing`, this method is called on text-only and multimodal-only inputs separately, instead of passing them in the same call. - - The replacement information returned by this method is also used - to determine the placeholder token positions for each multi-modal + - The update information returned by this method is also used to + determine the placeholder token positions for each multi-modal item. """ raise NotImplementedError def _find_mm_placeholders( self, - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], mm_item_counts: Mapping[str, int], ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - return find_mm_placeholders(mm_prompt_repls, new_token_ids, + return find_mm_placeholders(mm_prompt_updates, new_token_ids, mm_item_counts) def _get_hf_mm_data( @@ -831,14 +950,14 @@ def _call_hf_processor( mm_kwargs, ) - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], ) -> bool: """ - Return whether the HF processor applies prompt replacements. + Return whether the HF processor applies prompt updates. For most HF processors, this should be :code:`True` when multi-modal data items are passed, but :code:`False` when multi-modal embeddings @@ -858,7 +977,7 @@ def _apply_hf_processor_text_mm( Apply the HF processor on the prompt text and multi-modal data together. - In addition, return whether prompt replacements have been applied. + In addition, return whether prompt updates have been applied. """ processor_data, passthrough_data = self._get_hf_mm_data(mm_items) @@ -876,13 +995,13 @@ def _apply_hf_processor_text_mm( self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), ) - is_repl_applied = self._hf_processor_applies_repl( + is_update_applied = self._hf_processor_applies_updates( prompt_text=prompt_text, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, ) - return prompt_ids, mm_kwargs, is_repl_applied + return prompt_ids, mm_kwargs, is_update_applied def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]: """ @@ -948,21 +1067,21 @@ def _apply_hf_processor_main( mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], *, - enable_hf_prompt_replacement: bool, + enable_hf_prompt_update: bool, ) -> tuple[list[int], MultiModalKwargs, bool]: """ Apply the HF processor on the prompt text and multi-modal data. - In addition, return whether prompt replacements have been applied + In addition, return whether prompt updates have been applied (for most HF processors, this should be :code:`True`). Note: - If :code:`enable_hf_prompt_replacement=False`, we use HF processor - to perform prompt replacement if available; HF processor requires + If :code:`enable_hf_prompt_update=False`, we use HF processor + to perform prompt updates if available; HF processor requires that the prompt corresponds to multi-modal items. """ if isinstance(prompt, str): - if enable_hf_prompt_replacement: + if enable_hf_prompt_update: return self._apply_hf_processor_text_mm( prompt_text=prompt, mm_items=mm_items, @@ -999,7 +1118,7 @@ def _cached_apply_hf_processor( prompt=prompt, mm_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, - enable_hf_prompt_replacement=True, + enable_hf_prompt_update=True, ) mm_maybe_cached_kw_items = { @@ -1022,17 +1141,17 @@ def _cached_apply_hf_processor( mm_missing_data_items = self._to_mm_items(mm_missing_data) # NOTE: `prompt` does not correspond to `mm_missing_data_items`, - # so we can't apply prompt replacements until the new multimodal + # so we can't apply prompt updates until the new multimodal # items are combined with the cached multimodal items ( prompt_ids, mm_missing_kwargs, - is_repl_applied, + is_update_applied, ) = self._apply_hf_processor_main( prompt=prompt, mm_items=mm_missing_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, - enable_hf_prompt_replacement=False, + enable_hf_prompt_update=False, ) mm_missing_next_idx = { @@ -1071,28 +1190,28 @@ def _cached_apply_hf_processor( mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) - return prompt_ids, mm_kwargs, is_repl_applied + return prompt_ids, mm_kwargs, is_update_applied - def _bind_and_group_repls( + def _bind_and_group_updates( self, - prompt_repls: list[PromptReplacement], - ) -> dict[str, list[BoundPromptReplacement]]: + prompt_updates: list[PromptUpdate], + ) -> dict[str, list[BoundPromptUpdate]]: tokenizer = self.info.get_tokenizer() - it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) + it = (update.bind(tokenizer) for update in prompt_updates) return dict(full_groupby_modality(it)) - def _apply_prompt_replacements( + def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], mm_item_counts: Mapping[str, int], ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: tokenizer = self.info.get_tokenizer() mm_token_matches = { - modality: find_token_matches(token_ids, prompt_repls) - for modality, prompt_repls in mm_prompt_repls.items() + modality: find_token_matches(token_ids, updates) + for modality, updates in mm_prompt_updates.items() } mm_match_counts = { modality: len(matches) @@ -1107,31 +1226,31 @@ def _apply_prompt_replacements( # up a token, then the token ID of "foo" will not appear at all # ---- # Since it is inefficient to search for all possible tokenizations - # of the search text in the prompt, we instead perform string - # replacement on the decoded token IDs, then encode them back. + # of the search text in the prompt, we instead perform string-based + # updates on the decoded token IDs, then encode them back. if all( mm_match_counts.get(modality, 0) >= item_count for modality, item_count in mm_item_counts.items() ): # yapf: disable - token_ids = replace_token_matches( + token_ids = apply_token_matches( token_ids, mm_token_matches, mm_item_counts, ) text = decode_tokens(tokenizer, token_ids) - matched_repls = { - modality: [match.prompt_repl for match in token_matches] + matched_updates = { + modality: [match._origin for match in token_matches] for modality, token_matches in mm_token_matches.items() } else: text = decode_tokens(tokenizer, token_ids) mm_text_matches = { - modality: find_text_matches(text, prompt_repls) - for modality, prompt_repls in mm_prompt_repls.items() + modality: find_text_matches(text, updates) + for modality, updates in mm_prompt_updates.items() } - text = replace_text_matches( + text = apply_text_matches( text, mm_text_matches, mm_item_counts, @@ -1140,13 +1259,13 @@ def _apply_prompt_replacements( token_ids = encode_tokens(tokenizer, text, add_special_tokens=False) - matched_repls = { - modality: [match.prompt_repl for match in token_matches] + matched_updates = { + modality: [match._origin for match in token_matches] for modality, token_matches in mm_text_matches.items() } placeholders = self._find_mm_placeholders( - matched_repls, + matched_updates, token_ids, mm_item_counts, ) @@ -1184,14 +1303,14 @@ def _validate_mm_placeholders( if len(placeholders) != item_count: raise RuntimeError( - f"Expected there to be {item_count} prompt replacements " + f"Expected there to be {item_count} prompt updates " f"corresponding to {item_count} {modality} items, but " - f"instead found {len(placeholders)} prompt replacements! " + f"instead found {len(placeholders)} prompt updates! " "Either the prompt text has missing/incorrect tokens for " "multi-modal inputs, or there is a problem with your " "implementation of merged multi-modal processor for this " "model (usually arising from an inconsistency between " - "`_call_hf_processor` and `_get_prompt_replacements`).") + "`_call_hf_processor` and `_get_prompt_updates`).") def apply( self, @@ -1206,7 +1325,7 @@ def apply( 1. Apply HF Processor on prompt text and multi-modal data together, outputting token IDs and processed tensors. - 2. Find and replace sequences in the token IDs with placeholder tokens. + 2. Find and update sequences in the token IDs with placeholder tokens. The number of placeholder tokens equals the feature size of the multi-modal data outputted by the multi-modal encoder. 3. Extract information about the placeholder tokens from the @@ -1235,26 +1354,27 @@ def apply( ( prompt_ids, mm_kwargs, - is_repl_applied, + is_update_applied, ) = self._cached_apply_hf_processor( prompt, mm_items, hf_processor_mm_kwargs, ) - unbound_prompt_repls = self._get_prompt_replacements( + unbound_prompt_updates = self._get_prompt_updates( mm_items, hf_processor_mm_kwargs, mm_kwargs, ) - mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls) + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates) mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) - if is_repl_applied: + if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_repls, + mm_prompt_updates, prompt_ids, mm_item_counts, ) @@ -1267,9 +1387,9 @@ def apply( prompt_ids, prompt, mm_placeholders, - ) = self._apply_prompt_replacements( + ) = self._apply_prompt_updates( prompt_ids, - mm_prompt_repls, + mm_prompt_updates, mm_item_counts, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) From c8452982536d634c48ce50537bd4b825e4f06c64 Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Thu, 27 Feb 2025 10:14:17 -0800 Subject: [PATCH 270/317] [Attention] MLA support for V1 (#13789) Signed-off-by: Yang Chen --- vllm/attention/layer.py | 35 +- vllm/model_executor/models/deepseek_v2.py | 13 +- vllm/platforms/cuda.py | 9 +- vllm/platforms/interface.py | 1 + vllm/v1/attention/backends/flash_attn.py | 69 +- vllm/v1/attention/backends/mla/__init__.py | 0 vllm/v1/attention/backends/mla/common.py | 1022 ++++++++++++++++++++ vllm/v1/attention/backends/triton_mla.py | 110 +++ vllm/v1/worker/gpu_input_batch.py | 64 +- vllm/v1/worker/gpu_model_runner.py | 76 +- 10 files changed, 1340 insertions(+), 59 deletions(-) create mode 100644 vllm/v1/attention/backends/mla/__init__.py create mode 100644 vllm/v1/attention/backends/mla/common.py create mode 100644 vllm/v1/attention/backends/triton_mla.py diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index c45c83a0707f..58a3b4ee43ce 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -89,6 +89,7 @@ def __init__( self._k_scale_float = 1.0 self._v_scale_float = 1.0 + self.use_mla = use_mla self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads @@ -158,6 +159,10 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + # For some alternate attention backends like MLA the attention output + # shape does not match the query shape, so we optionally let the model + # definition specify the output tensor shape. + output_shape: Optional[torch.Size] = None, ) -> torch.Tensor: """ The KV cache is stored inside this class and is accessed via @@ -173,17 +178,25 @@ def forward( if attn_metadata.enable_kv_scales_calculation: self.calc_kv_scales(key, value) if self.use_output: - output = torch.empty_like(query) - hidden_size = query.size(-1) - # Reshape the query, key, and value tensors. - # NOTE(woosuk): We do this outside the custom op to minimize the - # CPU overheads from the non-CUDA-graph regions. - query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size) - if key is not None: - key = key.view(-1, self.num_kv_heads, self.head_size) - if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size) + output_shape = (output_shape + if output_shape is not None else query.shape) + output = torch.empty(output_shape, + dtype=query.dtype, + device=query.device) + hidden_size = output_shape[-1] + # We skip reshaping query, key and value tensors for the MLA + # backend since these tensors have different semantics and are + # processed differently. + if not self.use_mla: + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 6ff3ef129a74..b5409c7fe1b7 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -420,9 +420,15 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. + # kv_lora_rank + qk_rope_head_dim == head_size self.mla_attn = Attention( num_heads=self.num_local_heads, - head_size=self.kv_lora_rank, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, scale=self.scaling, num_kv_heads=1, cache_config=cache_config, @@ -458,7 +464,10 @@ def forward( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe) + return self.mla_attn(hidden_states_or_q_c, + kv_c_normed, + k_pe, + output_shape=hidden_states.shape) class DeepseekV2DecoderLayer(nn.Module): diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c6f3ccf0a3c4..0209c7236278 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -162,8 +162,13 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla) -> str: if use_v1: - logger.info("Using Flash Attention backend on V1 engine.") - return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + if use_mla: + logger.info("Using Triton MLA backend on V1 engine.") + return "vllm.v1.attention.backends.triton_mla.TritonMLABackend" + else: + logger.info("Using Flash Attention backend on V1 engine.") + return ("vllm.v1.attention.backends.flash_attn." + "FlashAttentionBackend") if use_mla: if selected_backend == _Backend.FLASHMLA: from vllm.attention.backends.flashmla import ( diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index e3ef7c4ac7c5..5f988e1479c5 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -35,6 +35,7 @@ class _Backend(enum.Enum): OPENVINO = enum.auto() FLASHINFER = enum.auto() TRITON_MLA = enum.auto() + TRITON_MLA_VLLM_V1 = enum.auto() FLASHMLA = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1922a3bf2724..353bf46d503e 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import numpy as np import torch @@ -14,6 +14,11 @@ from vllm.platforms import current_platform from vllm.utils import cdiv +if TYPE_CHECKING: + from vllm.v1.core.scheduler_output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + if current_platform.is_cuda(): from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -40,6 +45,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashAttentionMetadata + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -85,6 +94,62 @@ class FlashAttentionMetadata: num_input_tokens: int = 0 # Number of tokens including padding. +class FlashAttentionMetadataBuilder: + + def __init__(self, runner: "GPUModelRunner"): + self.runner = runner + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput"): + pass + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int): + max_seq_len = self.runner.seq_lens_np[:num_reqs].max() + query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( + self.runner.device, non_blocking=True) + seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, + non_blocking=True) + block_table = ( + self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + self.runner.device, non_blocking=True).long() + + use_cascade = common_prefix_len > 0 + if use_cascade: + # TODO: Optimize. + cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device=self.runner.device) + prefix_kv_lens = torch.tensor([common_prefix_len], + dtype=torch.int32, + device=self.runner.device) + suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] - + common_prefix_len) + suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( + self.runner.device) + else: + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + ) + return attn_metadata + + class FlashAttentionImpl(AttentionImpl): def __init__( @@ -371,4 +436,4 @@ def cascade_attention( # Merge prefix and suffix outputs, and store the result in output. merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) \ No newline at end of file + suffix_lse) diff --git a/vllm/v1/attention/backends/mla/__init__.py b/vllm/v1/attention/backends/mla/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py new file mode 100644 index 000000000000..2a742f5ce524 --- /dev/null +++ b/vllm/v1/attention/backends/mla/common.py @@ -0,0 +1,1022 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file implements common components for MLA implementations. + +First we define: + +Sq as Q sequence length +Skv as KV sequence length + +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). + +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably +tune. + +Main reference: DeepseekV2 paper, and FlashInfer Implementation +(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + +Deepseek's MLA attention works the following way: +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. the memory friendly approach) the attention "simulates" a +multi-head attention, while the compute is similar to multi-query attention. + +Below is example of both paths assuming batchsize = 1 + +## More Extent Definitions: + +C Context length, `Skv - Sq` +H hidden size +N number of attention heads +Lq latent dimension for Q 1536 in DSV3 +Lkv latent dimension for K/V 512 in DSV3 +P nope dimension, no rope. 128 in DSV3 +R rope dimension, goes through rope. 64 in DSV3 +V V head dim. 128 in DSV3 + +## Vector/Matrix Definitions + +h_t hidden states (input to attention) shape [Sq, H] +q_c latent/compressed Q shape [Sq, Lq] +q_nope uncompressed Q (no-rope) shape [Sq, N, P] +q_pe uncompressed Q (rope) shape [Sq, N, R] +kv_c latent/compressed KV shape [Skv, Lkv] +k_pe decoupled k position embeddings shape [Skv, R] +new_kv_c new kv_c from current iter shape [Sq, Lkv] +new_k_pe new k_pe from current iter shape [Sq, R] +cache_kv_c cached k_c from previous iters shape [C, Lkv] +cache_k_pe cached k_pe from previous iters shape [C, R] +W_DQ project h_t to q_c shape [H, Lq] +W_UQ project q_c to q_nope shape [Lq, N * P] +W_QR project q_c to q_pe shape [Lq, N * R] +W_DKV project h_t to kv_c shape [H, Lkv] +W_UK project kv_c to k_nope shape [Lkv, N * P] +W_KR project h_t to k_pe shape [H, N * R] +W_UV project kv_c to v shape [Lkv, N * V] +W_O project v to h_t shape [N * V, H] + + +## Compute Friendly Approach (i.e. "_forward_prefill"): + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) +k_nope = (kv_c @ W_UK).view(Skv, N, P) +v = (kv_c @ W_UV).view(Skv, N, V) + +// MHA with QK headdim = P + R +// V headdim = V +// spda_o shape [Sq, N, V] +spda_o = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + v +) +return spda_o @ W_O + +NOTE: in the actual code, + `kv_b_proj` is [W_UK; W_UV] concatnated per head + `q_b_proj` is [W_UQ; W_QR] concatnated per head + `out_proj` is W_O + + +## Data-Movement Friendly Approach (i.e. "_forward_decode"): + +Ahead of time, compute: + +% this projects from q_c to [Sq, N * Lkv] +W_UQ_UK = einsum("qnp,knp -> qnk" + W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P) + ).view(Lkv, N * Lkv) +% this projects from attn output [Sq, N * Lkv] to [Sq, H] +W_UV_O = einsum("knv,nvh -> nkh" + W_UV.view(Lkv, N, V), W_O.view(N, V, H) + ).view(N * Lkv, H) + +Runtime +q_c = h_t @ W_DQ +q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) + +// MQA with QK headdim = Lkv + R +// V headdim = Lkv +// spda_o shape [Sq, N, Lkv] +// NOTE: this is less compute-friendly since Lkv > P +// but is more data-movement friendly since its MQA vs MHA +spda_o = scaled_dot_product_attention( + torch.cat([q_latent, q_pe], dim=-1), + torch.cat([kv_c, k_pe], dim=-1), + kv_c +) +return spda_o.reshape(-1, N * Lkv) @ W_UV_O + + +## Chunked Prefill + +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +the data-movement friendly approach if the chunk (i.e. `Sq`) is small. + +However, the compute-friendly approach can potentially run out of memory if Skv +is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` + +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +fixed workspace size. + +The chunked prefill approach is as follows: + +MCC Max chunk of context to process per iter, computed dynamically, + used to bound the memory usage + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P) +new_v = (new_kv_c @ W_UV).view(Sq, N, V) + +// MHA between queries and new KV +// with QK headdim = P + R +// V headdim = V +// curr_o shape [Sq, N, V] +// curr_lse shape [N, Sq], this is just order FA returns +curr_o, curr_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + new_v, + casual=True, + return_softmax_lse=True +) + +// Compute attention with the already existing context +for chunk_idx in range(cdiv(C, MCC)): + chunk_start = chunk_idx * MCC + chunk_end = min(chunk_start + MCC, C) + Sc = chunk_end - chunk_start + cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] + cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] + cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) + cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) + + chunk_o, chunk_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([cache_k_nope_chunk, + cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], + dim=-1), + cache_v_chunk, + casual=False, + return_softmax_lse=True + ) + + curr_o, curr_lse = merge_attn_states( + suffix_output=curr_o, + suffix_lse=curr_lse, + prefix_output=chunk_o, + prefix_lse=chunk_lse, + ) + +return curr_o @ W_O +""" + +import functools +from abc import abstractmethod +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, + Type, TypeVar) + +import torch +from compressed_tensors.quantization import QuantizationStrategy + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata, + MLAAttentionImpl) +from vllm.attention.backends.utils import get_flash_attn_version +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsLinearMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsW8A8Fp8) +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + scaled_quantize) +from vllm.model_executor.layers.rotary_embedding import ( + DeepseekScalingRotaryEmbedding, RotaryEmbedding) +from vllm.utils import cdiv, round_down + +try: + from vllm.vllm_flash_attn import flash_attn_varlen_func +except ImportError: + # For rocm use upstream flash attention + from flash_attn import flash_attn_varlen_func + +if TYPE_CHECKING: + from vllm.v1.core.scheduler_output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +logger = init_logger(__name__) + + +class MLACommonBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "TRITON_MLA_VLLM_V1" + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return MLACommonMetadata + + @staticmethod + def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: + return MLACommonMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + +@dataclass +class MLACommonMetadata: + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # New for MLA (compared to FlashAttention) + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + # The dimension of the attention heads + head_dim: Optional[int] = None + + # New for MLA (compared to FlashAttention) + # For chunked prefill + num_decodes: Optional[int] = None + num_decode_tokens: Optional[int] = None + num_prefills: Optional[int] = None + has_context: bool = False + context_chunk_cu_seq_lens: Optional[torch.Tensor] = None + context_chunk_starts: Optional[torch.Tensor] = None + context_chunk_seq_tot: Optional[List[int]] = None + context_chunk_max_seq_lens: Optional[List[int]] = None + chunked_prefill_workspace: Optional[torch.Tensor] = None + + def __post_init__(self): + supported_head_sizes = MLACommonBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + +T = TypeVar("T", bound=MLACommonMetadata) + + +class MLACommonMetadataBuilder: + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__(self, runner: "GPUModelRunner"): + self.runner = runner + scheduler_config = runner.scheduler_config + model_config = runner.model_config + cache_config = runner.cache_config + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max( + 8 * model_config.max_model_len, 4 * + scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + model_config.get_head_size()), + dtype=model_config.dtype, + device=runner.device, + ) + self.page_size = self.runner.block_size + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput"): + # We now want to reorder the batch so that the "decode" requests are and + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + num_prefill_tokens = 0 + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if its not, + # we should update this to something like < 8 in the future but + # currently the TritonMLA._forward_decode only supports + # num_tokens = 1 + if num_tokens == 1: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + num_prefill_tokens += num_tokens + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + self._num_decodes = num_decodes + self._num_prefills = num_prefills + self._num_decode_tokens = num_decode_tokens + self._num_prefill_tokens = num_prefill_tokens + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int): + device = self.runner.device + max_seq_len = self.runner.seq_lens_np[:num_reqs].max() + query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( + device, non_blocking=True) + seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device, + non_blocking=True) + block_table = ( + self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + device, non_blocking=True).long() + input_positions = self.runner.positions_cpu[:num_actual_tokens].to( + device, non_blocking=True).long() + + context_chunk_cu_seq_lens = None + context_chunk_starts = None + context_chunk_seq_tot = None + context_chunk_max_seq_lens = None + + num_computed_tokens_cpu_tensor = \ + self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] + context_lens_tensor = \ + num_computed_tokens_cpu_tensor.to(device, non_blocking=True) + + if self.chunked_prefill_enabled and self._num_prefills > 0 \ + and context_lens_tensor[self._num_decodes:].max() > 0: + # NOTE: it is recommend you read the `Chunked Prefill` section in + # the comment at the top of the file before trying to understand + # the following code + + self.has_context = True + + num_prefills_with_context = \ + (context_lens_tensor[self._num_decodes:] > 0).sum().item() + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = \ + self.chunked_prefill_workspace_size // num_prefills_with_context + + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, self.page_size) + assert max_context_chunk > 0 + num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + context_chunk_starts = \ + torch.arange(num_chunks, device=device, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, self._num_prefills) \ + * max_context_chunk + chunk_ends = torch.min(context_lens_tensor[self._num_decodes:] \ + .unsqueeze(0), context_chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) + _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( + torch.int32) + zero = torch.zeros(num_chunks, dtype=torch.int32, device=device) \ + .unsqueeze(-1) + context_chunk_cu_seq_lens = \ + torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) + context_chunk_max_seq_lens = \ + chunk_seq_lens.max(dim=1).values.tolist() + context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() + assert max(context_chunk_seq_tot) <= \ + self.chunked_prefill_workspace_size + + return MLACommonMetadata( + input_positions=input_positions, + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table, + slot_mapping=slot_mapping, + head_dim=self.runner.model_config.get_head_size(), + # MLACommonMetadata Chunk prefill specific + num_decodes=self._num_decodes, + num_decode_tokens=self._num_decode_tokens, + num_prefills=self._num_prefills, + context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, + context_chunk_starts=context_chunk_starts, + context_chunk_seq_tot=context_chunk_seq_tot, + context_chunk_max_seq_lens=context_chunk_max_seq_lens, + ) + + +class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + rotary_emb: RotaryEmbedding, + # q_proj should be q_b_proj if q_lora_rank is not None, but from an + # attention backend perspective we rely on the layer to pass in the + # correct matrix + q_proj: ColumnParallelLinear, + kv_b_proj: ColumnParallelLinear, + o_proj: RowParallelLinear, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + + self.rotary_emb = rotary_emb + self.use_yarn_rope = isinstance(rotary_emb, + DeepseekScalingRotaryEmbedding) + self.q_proj = q_proj + self.kv_b_proj = kv_b_proj + self.o_proj = o_proj + self.vllm_flash_attn_version = get_flash_attn_version() + + # Handle the differences between the flash_attn_varlen from flash_attn + # and the one from vllm_flash_attn. The former is used on RoCM and the + # latter has an additional parameter to control FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = \ + functools.partial(flash_attn_varlen_func, + fa_version=self.vllm_flash_attn_version) + + def _v_up_proj_and_o_proj(self, x): + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + if is_fp8(self.W_UV_O): + output_parallel = apply_fp8_linear_generic( + x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, + self.reqaunt_input_group_shape, + self.reqaunt_weight_group_shape) + else: + output_parallel = torch.matmul(x.flatten(start_dim=1), + self.W_UV_O) + if self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + return output + else: + x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) + return self.o_proj(x.reshape(-1, + self.num_heads * self.v_head_dim))[0] + + def _q_proj_and_k_up_proj(self, x): + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + if is_fp8(self.W_Q_UK): + return apply_fp8_linear_generic( + x, self.W_Q_UK, self.W_Q_UK_scales, + self.reqaunt_input_group_shape, + self.reqaunt_weight_group_shape).view( + -1, self.num_heads, self.kv_lora_rank) + return torch.matmul(x, self.W_Q_UK)\ + .view(-1, self.num_heads, self.kv_lora_rank) + else: + x = torch.matmul(x, self.W_Q)\ + .view(-1, self.num_heads, self.qk_nope_head_dim) + return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ + .view(-1, self.num_heads, self.kv_lora_rank) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + # TODO(lucas) This is very gross, we need a more wide scale refactor of + # all the FP8 code with a more standard way of + # defining schemes/group-shapes, we should also potentially force + # quant_methods to support a decompress function + # + # returns input_group_shape, weight_group_shape + def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ + Tuple[Tuple[int, int], Tuple[int, int]]: + if isinstance(layer.quant_method, Fp8LinearMethod): + if layer.quant_method.block_quant: + weight_block_size = \ + layer.quant_method.quant_config.weight_block_size + # per-token-group (1, X), block-quantized (X, Y) + return (1, weight_block_size[-1]), weight_block_size + else: + return (-1, -1), (-1, -1) # per-tensor, per-tensor + elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ + and isinstance(layer.scheme, CompressedTensorsW8A8Fp8): + # this is hacky but we always assume the for + # CompressedTensorsW8A8Fp8 the input is dynamic per-token + # we ignore if it is static-per-tensor since we are going to + # requantize after later anyways + strategy = layer.scheme.strategy + if strategy == QuantizationStrategy.TENSOR: + return (1, -1), (-1, -1) # per-token, per-tensor + elif strategy == QuantizationStrategy.CHANNEL: + return (1, -1), (-1, 1) # per-token, per-channel + else: + raise NotImplementedError( + f"QuantizationStrategy.{strategy} is not supported for " + "fp8 MLA, please run with VLLM_MLA_DISABLE=1") + else: + raise NotImplementedError( + "Can't determine scale group shapes for " + f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" + ) + + def get_layer_weight(layer): + if hasattr(layer, "weight"): + return layer.weight + elif hasattr(layer, "qweight"): + return layer.qweight + else: + raise AttributeError( + f"Layer '{layer}' has neither weight nor qweight") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + weight_dtype = get_layer_weight(self.kv_b_proj).dtype + assert get_layer_weight(self.o_proj).dtype == weight_dtype + assert get_layer_weight(self.q_proj).dtype == weight_dtype + + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ + .view(-1, self.num_heads, self.qk_head_dim) + + # can be W_Q or W_UQ depending q_lora_rank, the former if + # q_lora_rank is None, the latter otherwise. From the Attention backend + # perspective though we call these both W_Q and rely on the layer + # to pass in the correct matrix + W_Q = q_proj_weight[..., :self.qk_nope_head_dim] + self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ + .flatten(start_dim=1).contiguous() + + # W_QR is small so for simplicity we dont bother requantizing it + self.W_QR = self.W_QR.to(act_dtype) + + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION + if is_fp8(weight_dtype) and requantization_enabled: + # This assumes it wise to requantize using the same group shapes + # (i.e. strategy, per-tensor, per-channel, block etc.) that the + # weights were originally quantized + requant_input_group_shape, requant_weight_group_shape = \ + get_scale_group_shapes_for_fp8(self.q_proj) + assert (requant_input_group_shape, requant_weight_group_shape)\ + == get_scale_group_shapes_for_fp8(self.kv_b_proj) + assert (requant_input_group_shape, requant_weight_group_shape)\ + == get_scale_group_shapes_for_fp8(self.o_proj) + self.reqaunt_input_group_shape = requant_input_group_shape + self.reqaunt_weight_group_shape = requant_weight_group_shape + + # + # Perform matrix-absorption following + # https://github.com/flashinfer-ai/flashinfer/pull/551 + # for decode, as a result we end up with absorbed weights for decode + # and another copy of raw weights for prefill. + # + self.W_UK, self.W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK + # depending q_lora_rank, the former if q_lora_rank is None, the + # latter otherwise + # basically if q_lora_rank is none we are absorbing into q_proj + # instead of UQ + W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ + .flatten(start_dim=1).contiguous() + + if is_fp8(weight_dtype) and requantization_enabled: + W_Q_UK, W_Q_UK_scales = scaled_quantize( + W_Q_UK, + self.reqaunt_weight_group_shape, + quant_dtype=current_platform_fp8_dtype) + # For FP8 save the transpose so we can use + # `apply_w8a8_block_fp8_linear` directly + self.W_Q_UK = W_Q_UK.T.contiguous() + self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous() + else: + self.W_Q_UK = W_Q_UK.to(act_dtype) + + W_O = get_and_maybe_dequant_weights(self.o_proj)\ + .view(-1, self.num_heads, self.v_head_dim) + W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ + .flatten(start_dim=0, end_dim=1).contiguous() + + if is_fp8(weight_dtype) and requantization_enabled: + W_UV_O, W_UV_O_scales = scaled_quantize( + W_UV_O, + self.reqaunt_weight_group_shape, + quant_dtype=current_platform_fp8_dtype) + # For FP8 save the transpose so we can use + # `apply_w8a8_block_fp8_linear` directly + self.W_UV_O = W_UV_O.T.contiguous() + self.W_UV_O_scales = W_UV_O_scales.T.contiguous() + else: + self.W_UV_O = W_UV_O.to(act_dtype) + + self.tp_size = get_tensor_model_parallel_world_size() + else: + if is_fp8(weight_dtype): + raise NotImplementedError( + "Currently fp8 requires matrix absorption") + + self.W_UV = W_UV + self.W_UK = W_UK + self.W_Q = W_Q.flatten(start_dim=1) + + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ): + assert attn_metadata.num_prefills is not None + assert attn_metadata.context_chunk_seq_tot is not None + assert attn_metadata.context_chunk_cu_seq_lens is not None + assert attn_metadata.context_chunk_starts is not None + assert attn_metadata.context_chunk_max_seq_lens is not None + + output = None + iters = len(attn_metadata.context_chunk_seq_tot) + + assert attn_metadata.chunked_prefill_workspace is not None + workspace = attn_metadata.chunked_prefill_workspace + + for i in range(iters): + toks = attn_metadata.context_chunk_seq_tot[i] + + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=attn_metadata.block_table, + cu_seq_lens=attn_metadata.context_chunk_cu_seq_lens[i], + batch_size=attn_metadata.num_prefills, + seq_starts=attn_metadata.context_chunk_starts[i], + ) + + kv_c_normed = workspace[:toks]\ + [..., :self.kv_lora_rank].unsqueeze(1) + k_pe = workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad + # out v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, + [0, q.shape[-1] - v.shape[-1]], + value=0) + + attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=attn_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=attn_metadata.max_query_len, + max_seqlen_k=attn_metadata.context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + has_context = attn_metadata.has_context + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=attn_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + ) + + if has_context: + suffix_output, suffix_lse = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, attn_metadata) + + output = torch.empty_like(suffix_output) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + # slice by `:v.shape[-1]` in order to remove v headdim padding + output = output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + + return self.o_proj(output)[0] + + @abstractmethod + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] + + # Restore head dim (for rotary embedding) + k_pe = k_pe.unsqueeze(1) + assert hasattr(attn_metadata, "input_positions") + + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + decode_k_pe = k_pe[:num_decode_tokens] + decode_input_positions = \ + attn_metadata.input_positions[:num_decode_tokens] + + prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_input_positions = \ + attn_metadata.input_positions[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + if has_decode: + decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) + decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ + .view(-1, self.num_heads, self.qk_rope_head_dim) + decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( + decode_input_positions, decode_q_pe, decode_k_pe) + + if has_prefill: + prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( + prefill_input_positions, prefill_q_pe, prefill_k_pe) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if has_prefill: + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata) + + if has_decode: + output[:num_decode_tokens] = self._forward_decode( + decode_q_nope, decode_q_pe, kv_cache, attn_metadata) + + return output_padded diff --git a/vllm/v1/attention/backends/triton_mla.py b/vllm/v1/attention/backends/triton_mla.py new file mode 100644 index 000000000000..7747509f1a4b --- /dev/null +++ b/vllm/v1/attention/backends/triton_mla.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.ops.triton_decode_attention import decode_attention_fwd +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) + +logger = init_logger(__name__) + + +class TritonMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA_VLLM_V1" + + @staticmethod + def get_impl_cls() -> Type["TritonMLAImpl"]: + return TritonMLAImpl + + +class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "TritonMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 Triton MLA not yet supported") + + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + num_kv_splits = 4 # TODO: heuristic + + # TODO(lucas) Allocate ahead of time + attn_logits = torch.empty( + ( + B, + self.num_heads, + num_kv_splits, + # NOTE(lucas) idk why the +1 is here but sglang has it so we + # just mirror that + self.kv_lora_rank + 1, + ), + dtype=torch.float32, + device=q.device, + ) + + # Add a head dim of 1 + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + PAGE_SIZE = kv_c_and_k_pe_cache.size(1) + + # Run MQA + decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, + attn_metadata.block_table, attn_metadata.seq_lens, + attn_logits, num_kv_splits, self.scale, PAGE_SIZE) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index e4e6b88245d0..1b6ea559a7b7 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -80,7 +80,14 @@ def __init__( self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + self.num_computed_tokens_cpu_tensor = torch.zeros( + (max_num_reqs, ), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.num_computed_tokens_cpu = \ + self.num_computed_tokens_cpu_tensor.numpy() # Block table. self.block_table = BlockTable( @@ -356,6 +363,61 @@ def remove_request(self, req_id: str) -> Optional[int]: self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) return req_index + def swap_states(self, i1: int, i2: int) -> None: + old_id_i1 = self._req_ids[i1] + old_id_i2 = self._req_ids[i2] + self._req_ids[i1], self._req_ids[i2] =\ + self._req_ids[i2], self._req_ids[i1] # noqa + self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ + self.req_output_token_ids[i2], self.req_output_token_ids[i1] + assert old_id_i1 is not None and old_id_i2 is not None + self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ + self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] + self.num_tokens[i1], self.num_tokens[i2] =\ + self.num_tokens[i2], self.num_tokens[i1] + self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ + self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] + self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ + self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] + self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ + self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] + self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ + self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] + self.temperature_cpu[i1], self.temperature_cpu[i2] =\ + self.temperature_cpu[i2], self.temperature_cpu[i1] + self.top_p_cpu[i1], self.top_p_cpu[i2] =\ + self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] =\ + self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ + self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ + self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ + self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + self.min_p_cpu[i1], self.min_p_cpu[i2] =\ + self.min_p_cpu[i2], self.min_p_cpu[i1] + + g1 = self.generators.get(i1) + g2 = self.generators.get(i2) + if g1 is not None: + self.generators[i2] = g1 + if g2 is not None: + self.generators[i1] = g2 + + t1 = self.min_tokens.get(i1) + t2 = self.min_tokens.get(i2) + if t1 is not None: + self.min_tokens[i2] = t1 + if t2 is not None: + self.min_tokens[i1] = t2 + + self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ + self.request_lora_mapping[i2], self.request_lora_mapping[i1] + self.logit_bias[i1], self.logit_bias[i2] =\ + self.logit_bias[i2], self.logit_bias[i1] + self.block_table.swap_row(i1, i2) + def condense(self, empty_req_indices: List[int]) -> None: num_reqs = self.num_reqs if num_reqs == 0: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4d0ae9a205a1..c9212d993f2b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2,6 +2,7 @@ import gc import time +import weakref from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np @@ -9,7 +10,7 @@ import torch.distributed import torch.nn as nn -from vllm.attention.backends.abstract import AttentionType +from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import get_pp_group, graph_capture @@ -24,8 +25,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LayerBlockType, cdiv, is_pin_memory_available) -from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, - FlashAttentionMetadata) +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -92,6 +92,27 @@ def __init__( self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() + self.attn_backend = get_attn_backend( + self.head_size, + self.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) + if self.attn_backend is None: + error_msg = ( + f"Error with get_att_backend: {self.head_size=}, " + f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{self.model_config.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 GPUModelRunner.") + + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + weakref.proxy(self)) + # Multi-modal data support self.input_registry = INPUT_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY @@ -433,6 +454,12 @@ def _prepare_inputs( num_reqs = self.input_batch.num_reqs assert num_reqs > 0 + # Some attention backends (namely MLA) may want to separate requests + # based on if the attention computation will be compute-bound or + # memory-bound. This gives them a hook to do that. + self.attn_metadata_builder.reorder_batch(self.input_batch, + scheduler_output) + # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit(num_reqs) @@ -515,7 +542,6 @@ def _prepare_inputs( self.seq_lens_np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) - max_seq_len = self.seq_lens_np[:num_reqs].max() # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( @@ -530,49 +556,17 @@ def _prepare_inputs( self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( - self.device, non_blocking=True) - seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device, - non_blocking=True) - slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( - self.device, non_blocking=True).long() # Prepare for cascade attention if needed. common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, scheduler_output.num_common_prefix_blocks, ) - use_cascade = common_prefix_len > 0 - if use_cascade: - # TODO: Optimize. - cu_prefix_query_lens = torch.tensor( - [0, total_num_scheduled_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len) - suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device) - else: - cu_prefix_query_lens = None - prefix_kv_lens = None - suffix_kv_lens = None - - attn_metadata = FlashAttentionMetadata( + attn_metadata = self.attn_metadata_builder.build( + num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=( - self.input_batch.block_table.get_device_tensor()[:num_reqs]), - slot_mapping=slot_mapping, - use_cascade=use_cascade, common_prefix_len=common_prefix_len, - cu_prefix_query_lens=cu_prefix_query_lens, - prefix_kv_lens=prefix_kv_lens, - suffix_kv_lens=suffix_kv_lens, ) use_spec_decode = len( @@ -586,7 +580,7 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = query_start_loc[1:] - 1 + logits_indices = attn_metadata.query_start_loc[1:] - 1 # Hot-Swap lora model if self.lora_config: @@ -667,7 +661,7 @@ def _compute_cascade_attn_prefix_len( # common_prefix_len should be a multiple of the block size. common_prefix_len = (common_prefix_len // self.block_size * self.block_size) - use_cascade = FlashAttentionBackend.use_cascade_attention( + use_cascade = self.attn_backend.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=self.num_query_heads, @@ -1379,7 +1373,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: assert tensor_config.size % layer_spec.page_size_bytes == 0 num_blocks = tensor_config.size // layer_spec.page_size_bytes if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, layer_spec.head_size) dtype = layer_spec.dtype From 284a899c3005f29a1cba8406d7ee20a7517ae653 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 27 Feb 2025 19:04:10 +0000 Subject: [PATCH 271/317] Bump azure/setup-helm from 4.2.0 to 4.3.0 (#13742) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/lint-and-deploy.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml index a4e9acc414d4..b199d0867a64 100644 --- a/.github/workflows/lint-and-deploy.yaml +++ b/.github/workflows/lint-and-deploy.yaml @@ -12,7 +12,7 @@ jobs: fetch-depth: 0 - name: Set up Helm - uses: azure/setup-helm@fe7b79cd5ee1e45176fcad797de68ecaf3ca4814 # v4.2.0 + uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0 with: version: v3.14.4 From 3e575f6c35d735a833ebaab79b1e333cf4b50b59 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 28 Feb 2025 03:14:55 +0800 Subject: [PATCH 272/317] [VLM] Deprecate legacy input mapper for OOT multimodal models (#13979) Signed-off-by: DarkLight1337 --- vllm/config.py | 45 ++++++++++++++++++++------------------- vllm/inputs/preprocess.py | 14 +++++++----- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d1384c6375f3..cb683d19386b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -400,7 +400,7 @@ def __init__( else: self.override_neuron_config = None - supported_tasks, task = self._resolve_task(task, self.hf_config) + supported_tasks, task = self._resolve_task(task) self.supported_tasks = supported_tasks self.task: Final = task if self.task in ("draft", "generate"): @@ -418,6 +418,14 @@ def __init__( self._verify_cuda_graph() self._verify_bnb_config() + @property + def registry(self): + return ModelRegistry + + @property + def architectures(self) -> list[str]: + return getattr(self.hf_config, "architectures", []) + def maybe_pull_model_tokenizer_for_s3(self, model: str, tokenizer: str) -> None: """ @@ -446,8 +454,7 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str, def _init_multimodal_config( self, limit_mm_per_prompt: Optional[Mapping[str, int]] ) -> Optional["MultiModalConfig"]: - architectures = getattr(self.hf_config, "architectures", []) - if ModelRegistry.is_multimodal_model(architectures): + if self.registry.is_multimodal_model(self.architectures): return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {}) if limit_mm_per_prompt: @@ -480,16 +487,13 @@ def _init_pooler_config( return None def _init_attention_free(self) -> bool: - architectures = getattr(self.hf_config, "architectures", []) - return ModelRegistry.is_attention_free_model(architectures) + return self.registry.is_attention_free_model(self.architectures) def _init_is_hybrid(self) -> bool: - architectures = getattr(self.hf_config, "architectures", []) - return ModelRegistry.is_hybrid_model(architectures) + return self.registry.is_hybrid_model(self.architectures) def _init_has_inner_state(self) -> bool: - architectures = getattr(self.hf_config, "architectures", []) - return ModelRegistry.model_has_inner_state(architectures) + return self.registry.model_has_inner_state(self.architectures) def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() @@ -507,9 +511,9 @@ def _get_preferred_task( model_id = self.model if get_pooling_config(model_id, self.revision): return "embed" - if ModelRegistry.is_cross_encoder_model(architectures): + if self.registry.is_cross_encoder_model(architectures): return "score" - if ModelRegistry.is_transcription_model(architectures): + if self.registry.is_transcription_model(architectures): return "transcription" suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [ @@ -522,7 +526,7 @@ def _get_preferred_task( ("EmbeddingModel", "embed"), ("RewardModel", "reward"), ] - _, arch = ModelRegistry.inspect_model_cls(architectures) + _, arch = self.registry.inspect_model_cls(architectures) for suffix, pref_task in suffix_to_preferred_task: if arch.endswith(suffix) and pref_task in supported_tasks: @@ -533,20 +537,19 @@ def _get_preferred_task( def _resolve_task( self, task_option: Union[TaskOption, Literal["draft"]], - hf_config: PretrainedConfig, ) -> Tuple[Set[_ResolvedTask], _ResolvedTask]: if task_option == "draft": return {"draft"}, "draft" - architectures = getattr(hf_config, "architectures", []) + registry = self.registry + architectures = self.architectures runner_support: Dict[RunnerType, bool] = { # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them - "transcription": - ModelRegistry.is_transcription_model(architectures), - "generate": ModelRegistry.is_text_generation_model(architectures), - "pooling": ModelRegistry.is_pooling_model(architectures), + "transcription": registry.is_transcription_model(architectures), + "generate": registry.is_text_generation_model(architectures), + "pooling": registry.is_pooling_model(architectures), } supported_runner_types_lst: List[RunnerType] = [ runner_type @@ -755,8 +758,7 @@ def verify_with_parallel_config( pipeline_parallel_size = parallel_config.pipeline_parallel_size if pipeline_parallel_size > 1: - architectures = getattr(self.hf_config, "architectures", []) - if not ModelRegistry.is_pp_supported_model(architectures): + if not self.registry.is_pp_supported_model(self.architectures): raise NotImplementedError( "Pipeline parallelism is not supported for this model. " "Supported models implement the `SupportsPP` interface.") @@ -1023,8 +1025,7 @@ def is_multimodal_model(self) -> bool: @property def is_cross_encoder(self) -> bool: - architectures = getattr(self.hf_config, "architectures", []) - return ModelRegistry.is_cross_encoder_model(architectures) + return self.registry.is_cross_encoder_model(self.architectures) @property def use_mla(self) -> bool: diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index bc5856990da6..206a76e52b7a 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -236,11 +236,15 @@ def _can_process_multimodal(self) -> bool: # updated to use the new multi-modal processor can_process_multimodal = self.mm_registry.has_processor(model_config) if not can_process_multimodal: - logger.info_once( - "Your model uses the legacy input pipeline instead of the new " - "multi-modal processor. Please note that the legacy pipeline " - "will be removed in a future release. For more details, see: " - "https://github.com/vllm-project/vllm/issues/10114") + from vllm.model_executor.models.registry import _VLLM_MODELS + if not any(arch in _VLLM_MODELS + for arch in model_config.architectures): + logger.warning_once( + "Your model uses the legacy input pipeline, which will be " + "removed in an upcoming release. " + "Please upgrade to the new multi-modal processing pipeline " + "(https://docs.vllm.ai/en/latest/design/mm_processing.html)" + ) return can_process_multimodal From 5d1129225694d3266e8037c9a599adbfaff918fe Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 27 Feb 2025 12:31:47 -0800 Subject: [PATCH 273/317] [ROCm] Fix the Kernels, Core, and Prefix Caching AMD CI groups (#13970) Signed-off-by: Sage Moore --- .buildkite/run-amd-test.sh | 4 +++- .../core/block/e2e/test_correctness_sliding_window.py | 10 ++++++++++ tests/prefix_caching/test_prefix_caching.py | 10 ++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index f8bf1c87603f..35d2ba1f8bab 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -92,7 +92,9 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_moe.py \ --ignore=kernels/test_prefix_prefill.py \ --ignore=kernels/test_rand.py \ - --ignore=kernels/test_sampler.py" + --ignore=kernels/test_sampler.py \ + --ignore=kernels/test_cascade_flash_attn.py \ + --ignore=kernels/test_mamba_mixer2.py" fi #ignore certain Entrypoints tests diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index c874608e40a2..a7dafcf8be87 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -7,6 +7,7 @@ from tests.kernels.utils import override_backend_env_variable from vllm import LLM, SamplingParams +from vllm.platforms import current_platform from .conftest import get_text_from_llm_generator @@ -42,6 +43,11 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, Additionally, we compare the results of the v1 and v2 managers. """ + if backend == "FLASHINFER" and current_platform.is_rocm(): + pytest.skip("Flashinfer does not support ROCm/HIP.") + if backend == "XFORMERS" and current_platform.is_rocm(): + pytest.skip("Xformers does not support ROCm/HIP.") + override_backend_env_variable(monkeypatch, backend) sampling_params = SamplingParams( @@ -101,6 +107,10 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, The results with and without chunked prefill are not the same due to numerical instabilities. """ + if backend == "FLASHINFER" and current_platform.is_rocm(): + pytest.skip("Flashinfer does not support ROCm/HIP.") + if backend == "XFORMERS" and current_platform.is_rocm(): + pytest.skip("Xformers does not support ROCm/HIP.") override_backend_env_variable(monkeypatch, backend) sampling_params = SamplingParams( diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 2773d27a6813..d7d84bdcf382 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -12,6 +12,7 @@ from vllm import SamplingParams, TokensPrompt from vllm.core.scheduler import Scheduler from vllm.engine.llm_engine import LLMEngine +from vllm.platforms import current_platform from ..models.utils import check_outputs_equal @@ -53,6 +54,10 @@ def test_mixed_requests( and the others don't. The cached position determines where the sequence is at among the batch of prefills. """ + if backend == "FLASHINFER" and current_platform.is_rocm(): + pytest.skip("Flashinfer does not support ROCm/HIP.") + if backend == "XFORMERS" and current_platform.is_rocm(): + pytest.skip("Xformers does not support ROCm/HIP.") override_backend_env_variable(monkeypatch, backend) with hf_runner(model, dtype=dtype) as hf_model: @@ -103,6 +108,11 @@ def test_unstable_prompt_sequence( backend: str, monkeypatch, ) -> None: + + if backend == "FLASHINFER" and current_platform.is_rocm(): + pytest.skip("Flashinfer does not support ROCm/HIP.") + if backend == "XFORMERS" and current_platform.is_rocm(): + pytest.skip("Xformers does not support ROCm/HIP.") override_backend_env_variable(monkeypatch, backend) with vllm_runner( From 810e7c5349207a3f2d63288cb12ae8443ffdb834 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 27 Feb 2025 13:11:40 -0800 Subject: [PATCH 274/317] [V1][Minor] Minor cleanup for GPU Model Runner (#13983) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c9212d993f2b..2730e6770dc3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1187,8 +1187,9 @@ def profile_run(self) -> None: # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. - max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501 - self.model_config) + max_tokens_by_modality_dict = ( + MULTIMODAL_REGISTRY. + get_max_tokens_per_item_by_nonzero_modality(self.model_config)) dummy_data_modality, max_tokens_per_mm_item = max( max_tokens_by_modality_dict.items(), key=lambda item: item[1]) @@ -1275,15 +1276,15 @@ def profile_run(self) -> None: # maximum num_tokens. num_reqs = self.scheduler_config.max_num_seqs num_tokens = self.max_num_tokens - min_tokens_per_req: int = num_tokens // num_reqs + min_tokens_per_req = num_tokens // num_reqs - num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, - dtype=np.int32) + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) logit_indices = np.cumsum(num_scheduled_tokens) - 1 with self.maybe_profile_with_lora(self.lora_config, From 9d019efda049d1977b8758add25560c3e5ed782a Mon Sep 17 00:00:00 2001 From: qli88 Date: Thu, 27 Feb 2025 16:14:30 -0600 Subject: [PATCH 275/317] [core] Perf improvement for DSv3 on AMD GPUs (#13718) Signed-off-by: qli88 --- vllm/attention/backends/mla/common.py | 92 ++++++++++--- vllm/attention/ops/triton_decode_attention.py | 15 +- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 128 ++++++++++++++++++ 3 files changed, 210 insertions(+), 25 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 1befcb6b45df..f240074f252d 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -237,14 +237,20 @@ try: from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True except ImportError: # For rocm use upstream flash attention from flash_attn import flash_attn_varlen_func + is_vllm_fa = False + +from vllm.attention.ops.triton_flash_attention import triton_attention if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) +is_hip = current_platform.is_rocm() + class MLACommonBackend(AttentionBackend): @@ -1046,12 +1052,13 @@ def __init__( self.q_proj = q_proj self.kv_b_proj = kv_b_proj self.o_proj = o_proj - self.vllm_flash_attn_version = get_flash_attn_version() + self.triton_fa_func = triton_attention # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() if self.vllm_flash_attn_version is not None: self.flash_attn_varlen_func = \ functools.partial(flash_attn_varlen_func, @@ -1315,18 +1322,48 @@ def _compute_prefill_context( [0, q.shape[-1] - v.shape[-1]], value=0) - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: + attn_output, attn_softmax_lse = self.triton_fa_func( + q, + k, + v_padded, + None, + prefill_metadata.query_start_loc, + prefill_metadata.context_chunk_cu_seq_lens[i], + prefill_metadata.max_query_len, + prefill_metadata.context_chunk_max_seq_lens[i], + False, # causal + self.scale, + None, # attn_mask is None unless applying ALiBi mask + ) + elif is_vllm_fa: + attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata. + context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + else: + attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata. + context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_attn_probs=True, + ) if output is None: output = attn_output @@ -1374,11 +1411,24 @@ def _forward_prefill( v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0) - if has_context: - if not current_platform.is_cuda(): - raise NotImplementedError( - "Chunked Prefill for MLA is not currently supported on" - "non-cuda platforms") + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: + output = self.triton_fa_func( + q, + k, + v_padded, + None, + prefill_metadata.query_start_loc, + prefill_metadata.query_start_loc, + prefill_metadata.max_prefill_seq_len, + prefill_metadata.max_prefill_seq_len, + True, # causal + self.scale, + None, # attn_mask is None unless applying ALiBi mask + ) + ## triton flash attention always return 2 objects + if not has_context: + output = output[0] + elif is_vllm_fa: output = self.flash_attn_varlen_func( q=q, k=k, @@ -1389,7 +1439,7 @@ def _forward_prefill( max_seqlen_k=prefill_metadata.max_prefill_seq_len, softmax_scale=self.scale, causal=True, - return_softmax_lse=True, + return_softmax_lse=has_context, ) else: output = self.flash_attn_varlen_func( @@ -1402,10 +1452,12 @@ def _forward_prefill( max_seqlen_k=prefill_metadata.max_prefill_seq_len, softmax_scale=self.scale, causal=True, + return_attn_probs=has_context, ) if has_context: - suffix_output, suffix_lse = output + # ROCm flash_attn_varlen_func will return 3 objects instead of 2 + suffix_output, suffix_lse, *rest = output context_output, context_lse = self._compute_prefill_context( \ q, kv_c_and_k_pe_cache, attn_metadata) diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py index 057fccb5e598..40daec3ec124 100644 --- a/vllm/attention/ops/triton_decode_attention.py +++ b/vllm/attention/ops/triton_decode_attention.py @@ -178,7 +178,8 @@ def _decode_att_m_fwd( page_size, logit_cap, ): - BLOCK = 64 + BLOCK = 64 if not is_hip_ else 8 + NUM_KV_SPLITS = num_kv_splits Lk = k_buffer.shape[-1] Lv = v_buffer.shape[-1] @@ -188,7 +189,9 @@ def _decode_att_m_fwd( grid = (batch, head_num, NUM_KV_SPLITS) kv_group_num = q.shape[1] // k_buffer.shape[-2] - num_warps = 4 if kv_group_num == 1 else 2 + num_warps = 4 + if kv_group_num != 1: + num_warps = 1 if is_hip_ else 2 BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DV = triton.next_power_of_2(Lv) @@ -418,14 +421,16 @@ def _decode_grouped_att_m_fwd( ) extra_kargs = {} + num_stages = 2 if is_hip_: - # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = { - "waves_per_eu": 4, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2 } + num_stages = 1 _fwd_grouped_kernel_stage1[grid]( q, @@ -456,7 +461,7 @@ def _decode_grouped_att_m_fwd( PAGE_SIZE=page_size, logit_cap=logit_cap, num_warps=4, - num_stages=2, + num_stages=num_stages, Lk=Lk, Lv=Lv, **extra_kargs, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..2b1167fc71e2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,128 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} From 4962680f34caf4e2dda0bfd0da8d366a2efd325e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 27 Feb 2025 18:03:41 -0500 Subject: [PATCH 276/317] [Attention] Flash MLA for V1 (#13867) Signed-off-by: Yang Chen Signed-off-by: Lucas Wilkinson Signed-off-by: Lucas Wilkinson Co-authored-by: Yang Chen --- vllm/platforms/cuda.py | 35 +++-- vllm/platforms/interface.py | 5 +- vllm/v1/attention/backends/mla/common.py | 11 +- vllm/v1/attention/backends/mla/flashmla.py | 139 ++++++++++++++++++ .../backends/{ => mla}/triton_mla.py | 0 5 files changed, 170 insertions(+), 20 deletions(-) create mode 100644 vllm/v1/attention/backends/mla/flashmla.py rename vllm/v1/attention/backends/{ => mla}/triton_mla.py (100%) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 0209c7236278..2a4cac46c066 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -161,15 +161,9 @@ def get_current_memory_usage(cls, def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla) -> str: - if use_v1: - if use_mla: - logger.info("Using Triton MLA backend on V1 engine.") - return "vllm.v1.attention.backends.triton_mla.TritonMLABackend" - else: - logger.info("Using Flash Attention backend on V1 engine.") - return ("vllm.v1.attention.backends.flash_attn." - "FlashAttentionBackend") if use_mla: + # TODO(lucas): refactor to be more concise + # we should probably consider factoring out V1 here if selected_backend == _Backend.FLASHMLA: from vllm.attention.backends.flashmla import ( is_flashmla_supported) @@ -183,11 +177,26 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, " (currently only supports block size 64).", block_size) else: - logger.info("Using FlashMLA backend.") - return "vllm.attention.backends.flashmla.FlashMLABackend" - - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" + if use_v1: + logger.info("Using FlashMLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "flashmla.FlashMLABackend") + else: + logger.info("Using FlashMLA backend.") + return ("vllm.attention.backends." + "flashmla.FlashMLABackend") + + if use_v1: + logger.info("Using Triton MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "triton_mla.TritonMLABackend") + else: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" + if use_v1: + logger.info("Using Flash Attention backend on V1 engine.") + return ("vllm.v1.attention.backends.flash_attn." + "FlashAttentionBackend") if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") return "vllm.attention.backends.flashinfer.FlashInferBackend" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 5f988e1479c5..6e80a1ff269a 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -34,9 +34,8 @@ class _Backend(enum.Enum): TORCH_SDPA = enum.auto() OPENVINO = enum.auto() FLASHINFER = enum.auto() - TRITON_MLA = enum.auto() - TRITON_MLA_VLLM_V1 = enum.auto() - FLASHMLA = enum.auto() + TRITON_MLA = enum.auto() # Supported by V1 + FLASHMLA = enum.auto() # Supported by V1 HPU_ATTN = enum.auto() PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 2a742f5ce524..30bce5cc8b68 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -333,13 +333,16 @@ def __post_init__(self): T = TypeVar("T", bound=MLACommonMetadata) -class MLACommonMetadataBuilder: +class MLACommonMetadataBuilder(Generic[T]): """ NOTE: Please read the comment at the top of the file before trying to understand this class """ - def __init__(self, runner: "GPUModelRunner"): + def __init__(self, + runner: "GPUModelRunner", + cls: Optional[type[T]] = None): + self.cls = cls if cls is not None else MLACommonMetadata self.runner = runner scheduler_config = runner.scheduler_config model_config = runner.model_config @@ -431,7 +434,7 @@ def reorder_batch(self, input_batch: "InputBatch", self._num_prefill_tokens = num_prefill_tokens def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): + common_prefix_len: int) -> T: device = self.runner.device max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( @@ -502,7 +505,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, assert max(context_chunk_seq_tot) <= \ self.chunked_prefill_workspace_size - return MLACommonMetadata( + return self.cls( input_positions=input_positions, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py new file mode 100644 index 000000000000..8a7b7b974e36 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) + +logger = init_logger(__name__) + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA_VLLM_V1" + + @staticmethod + def get_metadata_cls() -> Type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> Type["FlashMLAImpl"]: + return FlashMLAImpl + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata): + decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None + decode_num_splits: Optional[torch.Tensor] = None + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + + def __init__(self, runner): + super().__init__(runner, cls=FlashMLAMetadata) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int): + m = super().build(num_reqs, num_actual_tokens, max_query_len, + common_prefix_len) + + if m.num_decode_tokens is not None and m.num_decode_tokens > 0: + m.decode_tile_scheduler_metadata, m.decode_num_splits = \ + get_mla_metadata( + m.seq_lens[:m.num_decode_tokens], + self.num_q_heads, + 1, # MQA for the decode path + ) + + return m + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + assert is_flashmla_supported(), \ + "FlashMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 FlashMLA not yet supported") + + q = torch.cat([q_nope, q_pe], dim=-1)\ + .unsqueeze(1) # Add seqlen dim of 1 (decode) + + o, _ = flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + block_table=attn_metadata.block_table[:attn_metadata.num_decodes, + ...], + cache_seqlens=attn_metadata.seq_lens[:attn_metadata. + num_decode_tokens], + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=attn_metadata. + decode_tile_scheduler_metadata, + num_splits=attn_metadata.decode_num_splits, + softmax_scale=self.scale, + causal=True, + ) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/v1/attention/backends/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py similarity index 100% rename from vllm/v1/attention/backends/triton_mla.py rename to vllm/v1/attention/backends/mla/triton_mla.py From b5ae7d9df15fda23a50cbc14c342e89f8540e519 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 27 Feb 2025 18:28:08 -0500 Subject: [PATCH 277/317] [Model][Speculative Decoding] Expand DeepSeek MTP code to support k > n_predict (#13626) Signed-off-by: Benjamin Chislett --- vllm/config.py | 11 +++++------ vllm/model_executor/models/deepseek_mtp.py | 14 +++++++++----- vllm/spec_decode/draft_model_runner.py | 11 ++++------- vllm/spec_decode/multi_step_worker.py | 17 +++++++++++++++++ vllm/spec_decode/spec_decode_worker.py | 6 +++--- vllm/worker/model_runner.py | 12 +++++++++++- 6 files changed, 49 insertions(+), 22 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index cb683d19386b..c3f9932ab8b3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1978,13 +1978,12 @@ def maybe_create_spec_config( if num_speculative_tokens is None: # Default to max value defined in draft model config. num_speculative_tokens = n_predict - elif num_speculative_tokens > n_predict: - # Verify provided value doesn't exceed the maximum - # supported by the draft model. + elif num_speculative_tokens > n_predict and \ + num_speculative_tokens % n_predict != 0: + # Ensure divisibility for MTP module reuse. raise ValueError( - "This speculative model supports a maximum of " - f"num_speculative_tokens={n_predict}, but " - f"{num_speculative_tokens=} was provided.") + f"{num_speculative_tokens=} must be divisible by " + f"{n_predict=}") speculative_draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index cac1b2b3b11c..e7fde76cd0ba 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -87,7 +87,7 @@ def forward( hidden_states=hidden_states, residual=None) hidden_states = residual + hidden_states - return self.shared_head(hidden_states) + return hidden_states class DeepSeekMultiTokenPredictor(nn.Module): @@ -121,12 +121,13 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( + current_step_idx = (spec_step_idx % self.num_mtp_layers) + return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( input_ids, positions, previous_hidden_states, inputs_embeds, - spec_step_idx, + current_step_idx, ) def compute_logits( @@ -135,9 +136,12 @@ def compute_logits( sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> torch.Tensor: - mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers[str(self.mtp_start_layer_idx + + current_step_idx)] logits = self.logits_processor(mtp_layer.shared_head.head, - hidden_states, sampling_metadata) + mtp_layer.shared_head(hidden_states), + sampling_metadata) return logits diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index c54e6abe18d7..bc1b3e2319d0 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -50,12 +50,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): """ def __init__(self, model_runner: ModelRunnerBase): - if hasattr( - model_runner, - "return_hidden_states") and model_runner.return_hidden_states: - raise ValueError( - "return_hidden_states is not supported for TP1DraftModelRunner." - ) super().__init__(model_runner) self.indices_of_seq_with_bonus_tokens = None @@ -153,7 +147,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): return False # TODO: Add support for other attn backends - if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"): + if self.attn_backend.get_name() not in ("FLASH_ATTN", ): return False # TODO: Add support for LORA @@ -307,6 +301,9 @@ def execute_model( ) outputs.append(output) + if self.return_hidden_states and is_fallback: + output.hidden_states = hidden_states + if model_input.attn_metadata.num_prefills == 0 \ and self.indices_of_seq_with_bonus_tokens is not None: assert output.sampled_token_ids is not None diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index c28d413efe74..d8d54918fa98 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -96,12 +96,16 @@ def sampler_output( # TODO: Remove this branch once DraftModelRunner supports TP>1 # and other restrictions that are part of DraftModelRunner's # supports_gpu_multi_step(..) + if expanded_request.previous_hidden_states is not None: + self.worker.model_runner.return_hidden_states = True for _ in range(sample_len): model_output: List[SamplerOutput] = self.worker.execute_model( execute_model_req=expanded_request) assert (len(model_output) == 1 ), "composing multistep workers not supported" model_output = model_output[0] + self._maybe_update_previous_hidden_states( + model_output, expanded_request) self._append_new_tokens( model_output, expanded_request.seq_group_metadata_list, @@ -115,6 +119,19 @@ def sampler_output( model_outputs, indices_of_seq_with_bonus_tokens) return filtered_model_outputs, True + @staticmethod + def _maybe_update_previous_hidden_states( + model_output: SamplerOutput, + expanded_request: ExecuteModelRequest) -> None: + """ + Updates the previous hidden states in an expanded request + in-place with the hidden states from the model output. + """ + if expanded_request.previous_hidden_states is not None: + expanded_request.previous_hidden_states = HiddenStates( + model_output.hidden_states, + expanded_request.seq_group_metadata_list) + @staticmethod def _expand_execute_model_request( execute_model_req: ExecuteModelRequest, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 871a3aee6306..8909a41bc99f 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -184,8 +184,7 @@ def create_worker( elif draft_model_config.hf_config.model_type == "medusa": proposer_worker = MedusaWorker(**draft_worker_kwargs) else: - if draft_tp == 1 or draft_model_config.hf_config.model_type ==\ - "deepseek_mtp": + if draft_tp == 1: if current_platform.is_cuda_alike(): draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner @@ -203,7 +202,8 @@ def create_worker( proposer_worker = MultiStepWorker(**draft_worker_kwargs) if draft_model_config.hf_config.model_type == "deepseek_mtp": - num_spec_prefill_steps = num_speculative_tokens + num_spec_prefill_steps = \ + draft_model_config.hf_config.n_predict proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( proposer_worker, draft_tp, target_tp) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a37a3168bbbc..bb2228165b52 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1685,11 +1685,22 @@ def execute_model( # TODO(andoorve): We can remove this once all # virtual engines share the same kv cache. virtual_engine = model_input.virtual_engine + previous_hidden_states = kwargs.get("previous_hidden_states") if prefill_meta is None and decode_meta.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[virtual_engine][ graph_batch_size] + if previous_hidden_states is not None: + previous_hidden_states = torch.cat([ + previous_hidden_states, + torch.empty([ + graph_batch_size - previous_hidden_states.shape[0], + *previous_hidden_states.shape[1:] + ], + dtype=previous_hidden_states.dtype, + device=previous_hidden_states.device) + ]) else: model_executable = self.model @@ -1716,7 +1727,6 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_inner_state else {} - previous_hidden_states = kwargs.get("previous_hidden_states") model_kwargs = {} if previous_hidden_states is not None: model_kwargs["previous_hidden_states"] = previous_hidden_states From 76f457e67c74fe36ddff14781b95c99f3bf47da0 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 28 Feb 2025 07:53:13 +0800 Subject: [PATCH 278/317] [Misc] Print FusedMoE detail info (#13974) --- vllm/model_executor/layers/fused_moe/layer.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 28a88571dab4..052d4d54601f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -737,3 +737,23 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, # If we are in the row parallel case (down_proj) else: param_data[expert_id] = loaded_weight + + def extra_repr(self) -> str: + + s = ( + f"global_num_experts={self.global_num_experts}, " + f"local_num_experts={self.local_num_experts}, " + f"top_k={self.top_k}, " + f"intermediate_size_per_partition={self.intermediate_size_per_partition}, " # noqa: E501 + f"tp_size={self.tp_size},\n" + f"ep_size={self.ep_size}, " + f"reduce_results={self.reduce_results}, " + f"renormalize={self.renormalize}, " + f"use_grouped_topk={self.use_grouped_topk}") + + if self.use_grouped_topk: + s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501 + + s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501 + + return s From 8d02f59b3d9af3076eeff549325c7cc6d9300af6 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 27 Feb 2025 17:02:15 -0800 Subject: [PATCH 279/317] [V1]`SupportsV0Only` protocol for model definitions (#13959) Signed-off-by: Roger Wang --- vllm/config.py | 5 ++++ vllm/model_executor/models/__init__.py | 7 +++-- vllm/model_executor/models/bamba.py | 5 ++-- vllm/model_executor/models/bart.py | 3 ++- vllm/model_executor/models/bert.py | 4 +-- vllm/model_executor/models/florence2.py | 4 +-- vllm/model_executor/models/gritlm.py | 4 ++- vllm/model_executor/models/interfaces.py | 26 +++++++++++++++++++ vllm/model_executor/models/jamba.py | 5 ++-- vllm/model_executor/models/mamba.py | 6 +++-- vllm/model_executor/models/mamba2.py | 6 +++-- vllm/model_executor/models/minicpmv.py | 6 +++-- vllm/model_executor/models/mllama.py | 5 ++-- vllm/model_executor/models/paligemma.py | 4 +-- .../models/prithvi_geospatial_mae.py | 6 +++-- vllm/model_executor/models/qwen2_rm.py | 5 ++-- vllm/model_executor/models/registry.py | 14 ++++++++-- vllm/model_executor/models/roberta.py | 5 ++-- vllm/model_executor/models/whisper.py | 5 ++-- 19 files changed, 93 insertions(+), 32 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index c3f9932ab8b3..78d02b017350 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1039,6 +1039,11 @@ def supported_runner_types(self) -> Set[RunnerType]: def runner_type(self) -> RunnerType: return _TASK_RUNNER[self.task] + @property + def is_v1_compatible(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_v1_compatible(architectures) + class CacheConfig: """Configuration for the KV cache. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 6be4a8341306..3580c4fa5252 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, - SupportsPP, has_inner_state, supports_lora, - supports_multimodal, supports_pp) + SupportsPP, SupportsV0Only, has_inner_state, + supports_lora, supports_multimodal, supports_pp, + supports_v0_only) from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration, is_pooling_model, is_text_generation_model) from .registry import ModelRegistry @@ -21,4 +22,6 @@ "supports_multimodal", "SupportsPP", "supports_pp", + "SupportsV0Only", + "supports_v0_only", ] diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 69da05884ded..ec62e41d59f0 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -32,7 +32,8 @@ from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsV0Only) from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -366,7 +367,7 @@ def forward( class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid): + IsHybrid, SupportsV0Only): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 93452696dca5..82684dfa730e 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -43,6 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsV0Only from .utils import maybe_prefix logger = logging.get_logger(__name__) @@ -776,7 +777,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, return decoder_outputs -class BartForConditionalGeneration(nn.Module): +class BartForConditionalGeneration(nn.Module, SupportsV0Only): base_model_prefix = "model" def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 4ff69527653d..77b2ef0fce5f 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -26,7 +26,7 @@ from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) -from .interfaces import SupportsCrossEncoding +from .interfaces import SupportsCrossEncoding, SupportsV0Only from .utils import WeightsMapper, maybe_prefix @@ -385,7 +385,7 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -class BertEmbeddingModel(nn.Module): +class BertEmbeddingModel(nn.Module, SupportsV0Only): """A model that uses Bert to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index c51fcf3d438b..6fa1bb80995d 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -29,7 +29,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsV0Only from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings @@ -651,7 +651,7 @@ def forward( return decoder_outputs -class Florence2LanguageForConditionalGeneration(nn.Module): +class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 16223953ff83..2984f2241286 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -19,6 +19,8 @@ PoolingSequenceGroupOutput) from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from .interfaces import SupportsV0Only + logger = init_logger(__name__) @@ -177,7 +179,7 @@ def forward( return PoolerOutput(outputs=pooled_outputs) -class GritLM(LlamaForCausalLM): +class GritLM(LlamaForCausalLM, SupportsV0Only): """This class implements the embedding model for parasail-ai/GritLM-7B-vllm. The class inherits from LlamaForCausalLM and provides a custom pooling diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 47bd05f140c8..fb3ceb005295 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -498,3 +498,29 @@ def supports_transcription( return isinstance(model, SupportsTranscription) return isinstance(model, SupportsTranscription) + + +@runtime_checkable +class SupportsV0Only(Protocol): + """Models with this interface are not compatible with V1 vLLM.""" + + supports_v0_only: ClassVar[Literal[True]] = True + + +@overload +def supports_v0_only(model: Type[object]) -> TypeIs[Type[SupportsV0Only]]: + ... + + +@overload +def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: + ... + + +def supports_v0_only( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]: + if isinstance(model, type): + return isinstance(model, SupportsV0Only) + + return isinstance(model, SupportsV0Only) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 14e56df6cadf..58eccd6a6b87 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -30,7 +30,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.utils import LayerBlockType -from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsV0Only) from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -353,7 +354,7 @@ def forward( class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid): + IsHybrid, SupportsV0Only): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 9f1cd8c29a5a..46b9182f2d79 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -19,7 +19,8 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree, SupportsPP) + IsAttentionFree, SupportsPP, + SupportsV0Only) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -155,7 +156,8 @@ def forward( return hidden_states -class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): +class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, + SupportsV0Only): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 266cdc243ac4..da5cbddbcbc5 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -22,7 +22,8 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree) + IsAttentionFree, + SupportsV0Only) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -174,7 +175,8 @@ def forward( return hidden_states -class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): +class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, + SupportsV0Only): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index fb6ea53acf9e..1816bf5d008d 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -63,7 +63,8 @@ from vllm.sequence import IntermediateTensors from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, + SupportsV0Only) from .utils import AutoWeightsLoader, maybe_prefix CPU_DEVICE = torch.device("cpu") @@ -804,7 +805,8 @@ def apply( return result -class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): +class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, + SupportsV0Only): """ The abstract class of MiniCPMV can only be inherited, but cannot be instantiated. diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 36e653e41e1b..7122fea2b3a8 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -63,7 +63,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from .clip import CLIPMLP -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsV0Only from .llama import LlamaDecoderLayer, LlamaMLP from .utils import maybe_prefix @@ -1128,7 +1128,8 @@ def forward( @MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor, info=MllamaProcessingInfo, dummy_inputs=MllamaDummyInputsBuilder) -class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): +class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"] diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 02d1861b8027..9a1398c28dbc 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -18,7 +18,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from .interfaces import SupportsMultiModal, SupportsPP +from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) from .utils import (AutoWeightsLoader, init_vllm_registered_model, @@ -136,7 +136,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma) @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma) class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + SupportsPP, SupportsV0Only): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index bfa90e42733d..d922329b3a49 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -25,7 +25,8 @@ from vllm.config import VllmConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (IsAttentionFree, - SupportsMultiModal) + SupportsMultiModal, + SupportsV0Only) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -111,7 +112,8 @@ def apply( PrithviGeoSpatialMAEMultiModalProcessor, info=PrithviGeoSpatialMAEProcessingInfo, dummy_inputs=PrithviGeoSpatialMAEInputBuilder) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, + SupportsV0Only): """ Prithvi Masked Autoencoder""" def _instantiate_model(self, config: dict) -> Optional[nn.Module]: diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 21cc9e8ed1c6..90f799e6734e 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -17,7 +17,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsLoRA, SupportsPP, SupportsV0Only from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, maybe_prefix @@ -33,7 +33,8 @@ def forward(self, input): return self.activation(input) -class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): +class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP, + SupportsV0Only): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 75e31d557dd1..028658b52644 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -22,7 +22,7 @@ from .interfaces import (has_inner_state, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, - supports_pp, supports_transcription) + supports_pp, supports_transcription, supports_v0_only) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -228,6 +228,7 @@ class _ModelInfo: is_attention_free: bool is_hybrid: bool supports_transcription: bool + supports_v0_only: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": @@ -241,7 +242,9 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), is_hybrid=is_hybrid(model), - supports_transcription=supports_transcription(model)) + supports_transcription=supports_transcription(model), + supports_v0_only=supports_v0_only(model), + ) class _BaseRegisteredModel(ABC): @@ -504,6 +507,13 @@ def is_transcription_model( model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_transcription + def is_v1_compatible( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return not model_cls.supports_v0_only + ModelRegistry = _ModelRegistry({ model_arch: diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index f86fa268072d..ba92eef12707 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -19,7 +19,7 @@ from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) -from .interfaces import SupportsCrossEncoding +from .interfaces import SupportsCrossEncoding, SupportsV0Only def roberta_task_weights_filter( @@ -191,7 +191,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): assert len(loaded), "Unable to load RobertaEmbeddingModel" -class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): +class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, + SupportsV0Only): """A model that uses Roberta to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 2da8c5c8b0e2..656e5fc6dcf3 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -34,7 +34,8 @@ PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs -from .interfaces import SupportsMultiModal, SupportsTranscription +from .interfaces import (SupportsMultiModal, SupportsTranscription, + SupportsV0Only) from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, make_layers) @@ -643,7 +644,7 @@ def _get_prompt_updates( info=WhisperProcessingInfo, dummy_inputs=WhisperDummyInputsBuilder) class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, - SupportsMultiModal): + SupportsMultiModal, SupportsV0Only): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", From 0e21ae344d8dc296cd695ccb520552763d1d65df Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Thu, 27 Feb 2025 21:00:45 -0700 Subject: [PATCH 280/317] [Bugfix] Check that number of images matches number of <|image|> tokens with mllama (#13911) Signed-off-by: Travis Johnson --- .../vision_language/test_mllama.py | 5 ++-- vllm/model_executor/models/mllama.py | 24 ++++++++++++++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index 202516f4c209..4fee04fdb7b6 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -479,8 +479,9 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, # Regression tests for https://github.com/vllm-project/vllm/issues/10648 - # Number of image tags is greater than the number of images provided - prompt = "<|begin_of_text|><|image|><|image|> Compare the two images" # noqa: E501 + # Number of groups of image tokens is greater than the number of images + # provided (the whitespace between the tags is necessary) + prompt = "<|begin_of_text|><|image|> <|image|> Compare the two images" # noqa: E501 image = stop_sign with pytest.raises(ValueError): vllm_model.generate_greedy_logprobs([prompt], diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 7122fea2b3a8..2a829bf0e61e 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -54,7 +54,8 @@ default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalEncDecInputs, + MultiModalFieldConfig, MultiModalKwargs) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataDict, MultiModalDataItems) from vllm.multimodal.processing import (BaseProcessingInfo, @@ -169,6 +170,27 @@ def get_dummy_processor_inputs( class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ): + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalEncDecInputs: + mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs) + + # Check that the number of image tokens in the decoder prompt matches + # the number of images provided in mm_data + num_image_tokens = mm_inputs['prompt_token_ids'].count( + self.info.get_hf_config().image_token_index) + image_data = mm_data.get("image", []) + num_images = 1 if isinstance(image_data, Image) else len(image_data) + if num_image_tokens != num_images: + raise ValueError( + f"The number of image tokens ({num_image_tokens}) must be" + f" the same as the number of images ({num_images})") + + return mm_inputs + def _call_hf_processor( self, prompt: str, From f00446b345222c6fbae7c62d824928e8878bfc05 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 28 Feb 2025 15:12:04 +0800 Subject: [PATCH 281/317] [Doc] Move multimodal Embedding API example to Online Serving page (#14017) Signed-off-by: DarkLight1337 --- docs/source/serving/multimodal_inputs.md | 89 ++----------------- .../serving/openai_compatible_server.md | 80 ++++++++++++++++- vllm/model_executor/models/registry.py | 4 +- 3 files changed, 89 insertions(+), 84 deletions(-) diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index 5cec5548ba18..c540bff2cf30 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -16,7 +16,7 @@ To input multi-modal data, follow this schema in {class}`vllm.inputs.PromptType` - `prompt`: The prompt should follow the format that is documented on HuggingFace. - `multi_modal_data`: This is a dictionary that follows the schema defined in {class}`vllm.multimodal.inputs.MultiModalDataDict`. -### Image +### Image Inputs You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples: @@ -120,20 +120,20 @@ for o in outputs: print(generated_text) ``` -### Video +### Video Inputs You can pass a list of NumPy arrays directly to the `'video'` field of the multi-modal dictionary instead of using multi-image input. Full example: -### Audio +### Audio Inputs You can pass a tuple `(array, sampling_rate)` to the `'audio'` field of the multi-modal dictionary. Full example: -### Embedding +### Embedding Inputs To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model, pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary. @@ -211,7 +211,7 @@ The chat template can be inferred based on the documentation on the model's Hugg For example, LLaVA-1.5 (`llava-hf/llava-1.5-7b-hf`) requires a chat template that can be found here: ::: -### Image +### Image Inputs Image input is supported according to [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). Here is a simple example using Phi-3.5-Vision. @@ -293,7 +293,7 @@ export VLLM_IMAGE_FETCH_TIMEOUT= ::: -### Video +### Video Inputs Instead of `image_url`, you can pass a video file via `video_url`. Here is a simple example using [LLaVA-OneVision](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf). @@ -356,7 +356,7 @@ export VLLM_VIDEO_FETCH_TIMEOUT= ::: -### Audio +### Audio Inputs Audio input is supported according to [OpenAI Audio API](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in). Here is a simple example using Ultravox-v0.5-1B. @@ -460,77 +460,6 @@ export VLLM_AUDIO_FETCH_TIMEOUT= ::: -### Embedding +### Embedding Inputs -vLLM's Embeddings API is a superset of OpenAI's [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings), -where a list of chat `messages` can be passed instead of batched `inputs`. This enables multi-modal inputs to be passed to embedding models. - -:::{tip} -The schema of `messages` is exactly the same as in Chat Completions API. -You can refer to the above tutorials for more details on how to pass each type of multi-modal data. -::: - -Usually, embedding models do not expect chat-based input, so we need to use a custom chat template to format the text and images. -Refer to the examples below for illustration. - -Here is an end-to-end example using VLM2Vec. To serve the model: - -```bash -vllm serve TIGER-Lab/VLM2Vec-Full --task embed \ - --trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja -``` - -:::{important} -Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass `--task embed` -to run this model in embedding mode instead of text generation mode. - -The custom chat template is completely different from the original one for this model, -and can be found here: -::: - -Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level `requests` library: - -```python -import requests - -image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" - -response = requests.post( - "http://localhost:8000/v1/embeddings", - json={ - "model": "TIGER-Lab/VLM2Vec-Full", - "messages": [{ - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": image_url}}, - {"type": "text", "text": "Represent the given image."}, - ], - }], - "encoding_format": "float", - }, -) -response.raise_for_status() -response_json = response.json() -print("Embedding output:", response_json["data"][0]["embedding"]) -``` - -Below is another example, this time using the `MrLight/dse-qwen2-2b-mrl-v1` model. - -```bash -vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embed \ - --trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja -``` - -:::{important} -Like with VLM2Vec, we have to explicitly pass `--task embed`. - -Additionally, `MrLight/dse-qwen2-2b-mrl-v1` requires an EOS token for embeddings, which is handled -by a custom chat template: -::: - -:::{important} -Also important, `MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code -example below for details. -::: - -Full example: +TBD diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 9b9242abf1e2..5ab46da90ea6 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -266,11 +266,85 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai If the model has a [chat template](#chat-template), you can replace `inputs` with a list of `messages` (same schema as [Chat API](#chat-api)) which will be treated as a single prompt to the model. -:::{tip} -This enables multi-modal inputs to be passed to embedding models, see [this page](#multimodal-inputs) for details. +Code example: + +#### Multi-modal inputs + +You can pass multi-modal inputs to embedding models by defining a custom chat template for the server +and passing a list of `messages` in the request. Refer to the examples below for illustration. + +:::::{tab-set} +::::{tab-item} VLM2Vec + +To serve the model: + +```bash +vllm serve TIGER-Lab/VLM2Vec-Full --task embed \ + --trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja +``` + +:::{important} +Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass `--task embed` +to run this model in embedding mode instead of text generation mode. + +The custom chat template is completely different from the original one for this model, +and can be found here: ::: -Code example: +Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level `requests` library: + +```python +import requests + +image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + +response = requests.post( + "http://localhost:8000/v1/embeddings", + json={ + "model": "TIGER-Lab/VLM2Vec-Full", + "messages": [{ + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Represent the given image."}, + ], + }], + "encoding_format": "float", + }, +) +response.raise_for_status() +response_json = response.json() +print("Embedding output:", response_json["data"][0]["embedding"]) +``` + +:::: + +::::{tab-item} DSE-Qwen2-MRL + +To serve the model: + +```bash +vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embed \ + --trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja +``` + +:::{important} +Like with VLM2Vec, we have to explicitly pass `--task embed`. + +Additionally, `MrLight/dse-qwen2-2b-mrl-v1` requires an EOS token for embeddings, which is handled +by a custom chat template: +::: + +:::{important} +`MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code +example below for details. +::: + +:::: + +::::: + +Full example: #### Extra parameters diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 028658b52644..4551d81e8a5d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -19,6 +19,7 @@ import torch.nn as nn from vllm.logger import init_logger +from vllm.utils import is_in_doc_build from .interfaces import (has_inner_state, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, @@ -368,7 +369,8 @@ def register_model( raise ValueError(msg) model = _LazyRegisteredModel(*split_str) - elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module): + elif isinstance(model_cls, type) and (is_in_doc_build() or issubclass( + model_cls, nn.Module)): model = _RegisteredModel.from_model_cls(model_cls) else: msg = ("`model_cls` should be a string or PyTorch model class, " From 8f1c66366513eee60809acc0306e773378fd7a4c Mon Sep 17 00:00:00 2001 From: Mathis Felardos Date: Fri, 28 Feb 2025 08:53:45 +0100 Subject: [PATCH 282/317] [Bugfix][Disaggregated] patch the inflight batching on the decode node in SimpleConnector to avoid hangs in SimpleBuffer (nccl based) (#13987) Signed-off-by: Mathis Felardos --- .../kv_connector/simple_connector.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 2033e9762ac0..8e2fbf36b4de 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -214,6 +214,7 @@ def recv_kv_caches_and_hidden_states( input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() hidden_or_intermediate_states_for_one_req = [] @@ -225,9 +226,21 @@ def recv_kv_caches_and_hidden_states( # enumerate different requests # FIXME(Kuntai): This impl assumes that all requests are prefill. for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + current_tokens = input_tokens_tensor[start_pos:end_pos] num_tokens = slen @@ -288,7 +301,7 @@ def recv_kv_caches_and_hidden_states( # Here we will fall back to normal model forwarding # But optionally you can adjust model_input so that you only do # prefilling on those tokens that are missing KV caches. - logger.debug( + logger.warning( "[rank%d]: Failed to receive all KVs and hidden " "states, redo model forwarding.", torch.distributed.get_rank()) hidden_or_intermediate_states = None From 377095a314eddfa795379125beb1f9dc4cbbc898 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 28 Feb 2025 08:50:43 +0000 Subject: [PATCH 283/317] Use smaller embedding model when not testing model specifically (#13891) --- tests/entrypoints/llm/test_encode.py | 2 +- tests/entrypoints/openai/test_embedding.py | 2 +- tests/entrypoints/openai/test_metrics.py | 4 ++-- tests/entrypoints/openai/test_run_batch.py | 10 +++++----- tests/model_executor/test_model_load_with_params.py | 4 ++-- tests/models/embedding/language/test_embedding.py | 2 +- tests/models/registry.py | 2 +- tests/test_config.py | 2 +- vllm/test_utils.py | 2 +- 9 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index ebec8baba38d..a65235ccdf19 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -8,7 +8,7 @@ from vllm import LLM, PoolingParams, PoolingRequestOutput from vllm.distributed import cleanup_dist_env_and_memory -MODEL_NAME = "intfloat/e5-mistral-7b-instruct" +MODEL_NAME = "intfloat/multilingual-e5-small" PROMPTS = [ "Hello, my name is", diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index e86ea87dd661..8d00564351c5 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -13,7 +13,7 @@ from ...utils import RemoteOpenAIServer -MODEL_NAME = "intfloat/e5-mistral-7b-instruct" +MODEL_NAME = "intfloat/multilingual-e5-small" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 5aa259a4f318..39ce4ba23548 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -282,7 +282,7 @@ async def test_metrics_exist(server: RemoteOpenAIServer, def test_metrics_exist_run_batch(use_v1: bool): if use_v1: pytest.skip("Skipping test on vllm V1") - input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}""" # noqa: E501 + input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}""" # noqa: E501 base_url = "0.0.0.0" port = "8001" @@ -302,7 +302,7 @@ def test_metrics_exist_run_batch(use_v1: bool): "-o", output_file.name, "--model", - "intfloat/e5-mistral-7b-instruct", + "intfloat/multilingual-e5-small", "--enable-metrics", "--url", base_url, diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index db049ee2bfd8..643d0d06abcb 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -18,10 +18,10 @@ INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" -INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}} -{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}} +INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are an unhelpful assistant."}} -{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "Hello world!"}} +{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}} {"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}""" INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} @@ -37,7 +37,7 @@ def test_empty_file(): proc = subprocess.Popen([ sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i", input_file.name, "-o", output_file.name, "--model", - "intfloat/e5-mistral-7b-instruct" + "intfloat/multilingual-e5-small" ], ) proc.communicate() proc.wait() @@ -97,7 +97,7 @@ def test_embeddings(): proc = subprocess.Popen([ sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i", input_file.name, "-o", output_file.name, "--model", - "intfloat/e5-mistral-7b-instruct" + "intfloat/multilingual-e5-small" ], ) proc.communicate() proc.wait() diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index 760a11993523..f8efa2eff857 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -14,7 +14,7 @@ REVISION = os.environ.get("REVISION", "main") MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME", - "intfloat/multilingual-e5-large") + "intfloat/multilingual-e5-small") REVISION_ROBERTA = os.environ.get("REVISION", "main") @@ -83,7 +83,7 @@ def test_roberta_model_loading_with_params(vllm_runner): assert model_config.pooler_config.pooling_norm # asserts on the tokenizer loaded - assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large" + assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-small" assert not model_tokenizer.tokenizer_config["do_lower_case"] def check_model(model): diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index ad6385376dc8..4b9926860f24 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -17,7 +17,7 @@ pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model, pytest.mark.cpu_model]), pytest.param("sentence-transformers/all-MiniLM-L12-v2"), - pytest.param("intfloat/multilingual-e5-large"), + pytest.param("intfloat/multilingual-e5-small"), # [Decoder-only] pytest.param("BAAI/bge-multilingual-gemma2", marks=[pytest.mark.core_model]), diff --git a/tests/models/registry.py b/tests/models/registry.py index 95bda0293498..78a65b93870e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -211,7 +211,7 @@ def check_available_online( "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 - "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-large"), + "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # [Multimodal] "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", diff --git a/tests/test_config.py b/tests/test_config.py index 8927a14d79ac..709d60b83670 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,7 +13,7 @@ ("model_id", "expected_runner_type", "expected_task"), [ ("distilbert/distilgpt2", "generate", "generate"), - ("intfloat/e5-mistral-7b-instruct", "pooling", "embed"), + ("intfloat/multilingual-e5-small", "pooling", "embed"), ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), diff --git a/vllm/test_utils.py b/vllm/test_utils.py index eb9a4d80a2c2..8611a25922bb 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -28,7 +28,7 @@ "HuggingFaceM4/Idefics3-8B-Llama3", "internlm/internlm2-1_8b-reward", "intfloat/e5-mistral-7b-instruct", - "intfloat/multilingual-e5-large", + "intfloat/multilingual-e5-small", "jason9693/Qwen2.5-1.5B-apeach", "llava-hf/llava-1.5-7b-hf", "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", From 24a38656e53dc75f85a87dc3810a3b6b27d1b99d Mon Sep 17 00:00:00 2001 From: Kacper Pietkun Date: Fri, 28 Feb 2025 09:51:49 +0100 Subject: [PATCH 284/317] [Hardware][Intel-Gaudi] Regional compilation support (#13213) --- vllm/worker/hpu_model_runner.py | 43 ++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index d6eaf84e40f6..4ac547ae326d 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -39,7 +39,10 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs) @@ -311,10 +314,38 @@ def __init__(self, model, vllm_config): self.block_size = vllm_config.cache_config.block_size self.dtype = vllm_config.model_config.dtype enforce_eager = vllm_config.model_config.enforce_eager + if not htorch.utils.internal.is_lazy() and not enforce_eager: - self.model = torch.compile(self.model, - backend='hpu_backend', - dynamic=False) + if os.getenv('VLLM_REGIONAL_COMPILATION', + 'true').lower() == 'true': + self.regional_compilation_layers_list = [ + RMSNorm, VocabParallelEmbedding + ] + self._regional_compilation(self.model) + else: + self.model = torch.compile(self.model, + backend='hpu_backend', + dynamic=False) + + def _regional_compilation(self, + module, + parent_module=None, + module_name=None): + if isinstance(module, torch.nn.ModuleList): + for children_name, children_module in module.named_children(): + self._compile_region(module, children_name, children_module) + elif any( + isinstance(module, layer) + for layer in self.regional_compilation_layers_list): + self._compile_region(parent_module, module_name, module) + else: + for children_name, children_module in module.named_children(): + self._regional_compilation(children_module, module, + children_name) + + def _compile_region(self, model, name, module): + module = torch.compile(module, backend='hpu_backend', dynamic=False) + setattr(model, name, module) def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): @@ -1575,9 +1606,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: list(sorted(self.bucketing_global_state.decode_buckets))) if not htorch.utils.internal.is_lazy() and not self.enforce_eager: - cache_size_limit = len( - self.bucketing_global_state.prompt_buckets) + len( - self.bucketing_global_state.decode_buckets) + 1 + cache_size_limit = 1 + 3 * ( + len(self.bucketing_global_state.prompt_buckets) + + len(self.bucketing_global_state.decode_buckets)) torch._dynamo.config.cache_size_limit = max( cache_size_limit, torch._dynamo.config.cache_size_limit) # Multiply by 8 to follow the original default ratio between From 181d4bf39660417abc82aa096de5b56d2f27e5dd Mon Sep 17 00:00:00 2001 From: Thibault Schueller <1625198+Ryp@users.noreply.github.com> Date: Fri, 28 Feb 2025 09:52:25 +0100 Subject: [PATCH 285/317] [V1][Minor] Restore V1 compatibility with LLMEngine class (#13090) --- vllm/engine/llm_engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3dee4dab4c47..9c83ea75ead7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2084,3 +2084,8 @@ def _build_logits_processors( sampling_params.logits_processors.extend(logits_processors) return sampling_params + + +# TODO(v1): Remove this class proxy when V1 goes default. +if envs.VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine # type: ignore From d65f74bc002e9a3a1e1bff49eceecdbb20ea7e84 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 28 Feb 2025 15:20:29 +0000 Subject: [PATCH 286/317] Update AutoAWQ docs (#14042) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/source/features/quantization/auto_awq.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/features/quantization/auto_awq.md b/docs/source/features/quantization/auto_awq.md index fa0bebeb8ba1..7001ec91467f 100644 --- a/docs/source/features/quantization/auto_awq.md +++ b/docs/source/features/quantization/auto_awq.md @@ -6,13 +6,13 @@ To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github Quantizing reduces the model's precision from FP16 to INT4 which effectively reduces the file size by ~70%. The main benefits are lower latency and memory usage. -You can quantize your own models by installing AutoAWQ or picking one of the [400+ models on Huggingface](https://huggingface.co/models?sort=trending&search=awq). +You can quantize your own models by installing AutoAWQ or picking one of the [6500+ models on Huggingface](https://huggingface.co/models?sort=trending&search=awq). ```console pip install autoawq ``` -After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`: +After installing AutoAWQ, you are ready to quantize a model. Please refer to the `AutoAWQ documentation `_ for further details. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`: ```python from awq import AutoAWQForCausalLM From 367db4b1c902e3b361e6414034a3cd67fc0183aa Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 28 Feb 2025 23:22:42 +0800 Subject: [PATCH 287/317] [Bugfix] Fix MoeWNA16Method activation (#14024) --- vllm/model_executor/layers/quantization/moe_wna16.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index a3adac1bb129..41b75c9be05a 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -293,9 +293,10 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts - + assert activation == "silu", "Only SiLU activation is supported." topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, From c3eca7b66aa6ff88316b9676d67f78d337ab2ce4 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 28 Feb 2025 23:35:55 +0800 Subject: [PATCH 288/317] [VLM][Bugfix] Enable specifying prompt target via index (#14038) --- tests/multimodal/test_processing.py | 258 +++++++++++++++++++++++- vllm/model_executor/models/blip2.py | 6 +- vllm/model_executor/models/florence2.py | 5 +- vllm/model_executor/models/molmo.py | 6 +- vllm/multimodal/processing.py | 216 +++++++++++++++----- 5 files changed, 432 insertions(+), 59 deletions(-) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 878b15925006..ba3df86f715a 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -14,8 +14,8 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (PlaceholderFeaturesInfo, - PromptInsertion, PromptReplacement, - apply_text_matches, + PromptIndexTargets, PromptInsertion, + PromptReplacement, apply_text_matches, apply_token_matches, find_mm_placeholders, find_text_matches, find_token_matches, @@ -98,10 +98,20 @@ def test_iter_token_matches(token_ids, match_ids, expected): { "pattern_1": [], "pattern_2": [32000], + "pattern_3": PromptIndexTargets.start(), + "pattern_4": PromptIndexTargets.prefix([32000]), + "pattern_5": PromptIndexTargets.end(), }, { "pattern_1": [], "pattern_2": [], + "pattern_3": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_4": [], + "pattern_5": [ + { "start_idx": 0, "end_idx": 0 }, + ], }, ), ( @@ -110,6 +120,9 @@ def test_iter_token_matches(token_ids, match_ids, expected): "pattern_1": [32000], "pattern_2": [32000, 32000], "pattern_3": [32000, 32000, 32000], + "pattern_4": PromptIndexTargets.start(), + "pattern_5": PromptIndexTargets.prefix([32000]), + "pattern_6": PromptIndexTargets.end(), }, { "pattern_1": [ @@ -125,6 +138,15 @@ def test_iter_token_matches(token_ids, match_ids, expected): "pattern_3": [ { "start_idx": 0, "end_idx": 3 }, ], + "pattern_4": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_5": [ + { "start_idx": 1, "end_idx": 1 }, + ], + "pattern_6": [ + { "start_idx": 4, "end_idx": 4 }, + ], }, ), ( @@ -133,6 +155,9 @@ def test_iter_token_matches(token_ids, match_ids, expected): "pattern_1": [28747, 32000], "pattern_2": [28747, 32000, 32000, 32000], "pattern_3": [28747, 0, 32000], + "pattern_4": PromptIndexTargets.start(), + "pattern_5": PromptIndexTargets.prefix([28747, 32000]), + "pattern_6": PromptIndexTargets.end(), }, { "pattern_1": [ @@ -143,6 +168,13 @@ def test_iter_token_matches(token_ids, match_ids, expected): { "start_idx": 1, "end_idx": 5 }, ], "pattern_3": [], + "pattern_4": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_5": [], + "pattern_6": [ + { "start_idx": 10, "end_idx": 10 }, + ], }, ), ], @@ -189,10 +221,20 @@ def test_find_token_matches( { "pattern_1": "", "pattern_2": "", + "pattern_3": PromptIndexTargets.start(), + "pattern_4": PromptIndexTargets.prefix(""), + "pattern_5": PromptIndexTargets.end(), }, { "pattern_1": [{ "start_idx": 0, "end_idx": 0 }], "pattern_2": [], + "pattern_3": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_4": [], + "pattern_5": [ + { "start_idx": 0, "end_idx": 0 }, + ], } ), ( @@ -201,6 +243,9 @@ def test_find_token_matches( "pattern_1": "", "pattern_2": "", "pattern_3": "", + "pattern_4": PromptIndexTargets.start(), + "pattern_5": PromptIndexTargets.prefix(""), + "pattern_6": PromptIndexTargets.end(), }, { "pattern_1": [ @@ -216,6 +261,15 @@ def test_find_token_matches( "pattern_3": [ { "start_idx": 0, "end_idx": 21 }, ], + "pattern_4": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_5": [ + { "start_idx": 7, "end_idx": 7 }, + ], + "pattern_6": [ + { "start_idx": 28, "end_idx": 28 }, + ], }, ), ( @@ -224,6 +278,9 @@ def test_find_token_matches( "pattern_1": "Image:", "pattern_2": "Image:", "pattern_3": "Image:", + "pattern_4": PromptIndexTargets.start(), + "pattern_5": PromptIndexTargets.prefix("Image:"), + "pattern_6": PromptIndexTargets.end(), }, { "pattern_1": [ @@ -234,6 +291,15 @@ def test_find_token_matches( { "start_idx": 0, "end_idx": 27 }, ], "pattern_3": [], + "pattern_4": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_5": [ + { "start_idx": 13, "end_idx": 13 }, + ], + "pattern_6": [ + { "start_idx": 48, "end_idx": 48 }, + ], }, ), # Test regex escape @@ -325,6 +391,100 @@ def test_find_text_matches( }, }, ), + # Test index targets + ( + "", + { + "pattern_1": PromptIndexTargets.start(), + "pattern_2": PromptIndexTargets.prefix(""), + "pattern_3": PromptIndexTargets.end(), + }, + { + "pattern_1": "1", + "pattern_2": "2", + "pattern_3": "3", + }, + { + PromptInsertion: { + 0: "", + 1: "13", + 2: "1133", + }, + PromptReplacement: { + 0: "", + 1: "13", + 2: "1133", + }, + }, + ), + ( + "", + { + "pattern_1": PromptIndexTargets.start(), + "pattern_2": PromptIndexTargets.prefix(""), + "pattern_3": PromptIndexTargets.end(), + }, + { + "pattern_1": "1", + "pattern_2": "2", + "pattern_3": "3", + }, + { + PromptInsertion: { + 0: "", + 1: "123", + 2: "112233", + }, + PromptReplacement: { + 0: "", + 1: "123", + 2: "112233", + }, + }, + ), + # Test different replacement per item + ( + "", + { + "pattern_1": "", + }, + { + "pattern_1": lambda idx: str(idx + 1), + }, + { + PromptInsertion: { + 0: "", + 1: "1", + 2: "12", + }, + PromptReplacement: { + 0: "", + 1: "1", + 2: "12", + }, + }, + ), + ( + "", + { + "pattern_1": PromptIndexTargets.prefix(""), + }, + { + "pattern_1": lambda idx: str(idx + 1), + }, + { + PromptInsertion: { + 0: "", + 1: "1", + 2: "12", + }, + PromptReplacement: { + 0: "", + 1: "1", + 2: "12", + }, + }, + ), ] ) # yapf: enable @@ -405,6 +565,100 @@ def test_find_update_text( }, }, ), + # Test index targets + ( + [], + { + "pattern_1": PromptIndexTargets.start(), + "pattern_2": PromptIndexTargets.prefix([32000]), + "pattern_3": PromptIndexTargets.end(), + }, + { + "pattern_1": [-1], + "pattern_2": [-2], + "pattern_3": [-3], + }, + { + PromptInsertion: { + 0: [], + 1: [-1, -3], + 2: [-1, -1, -3, -3], + }, + PromptReplacement: { + 0: [], + 1: [-1, -3], + 2: [-1, -1, -3, -3], + }, + }, + ), + ( + [32000], + { + "pattern_1": PromptIndexTargets.start(), + "pattern_2": PromptIndexTargets.prefix([32000]), + "pattern_3": PromptIndexTargets.end(), + }, + { + "pattern_1": [-1], + "pattern_2": [-2], + "pattern_3": [-3], + }, + { + PromptInsertion: { + 0: [32000], + 1: [-1, 32000, -2, -3], + 2: [-1, -1, 32000, -2, -2, -3, -3], + }, + PromptReplacement: { + 0: [32000], + 1: [-1, 32000, -2, -3], + 2: [-1, -1, 32000, -2, -2, -3, -3], + }, + }, + ), + # Test different replacement per item + ( + [32000, 32000, 32000], + { + "pattern_1": [32000], + }, + { + "pattern_1": lambda idx: [-(idx + 1)], + }, + { + PromptInsertion: { + 0: [32000, 32000, 32000], + 1: [32000, -1, 32000, 32000], + 2: [32000, -1, -2, 32000, 32000], + }, + PromptReplacement: { + 0: [32000, 32000, 32000], + 1: [-1, 32000, 32000], + 2: [-1, -2, 32000], + }, + }, + ), + ( + [32000, 32000, 32000], + { + "pattern_1": PromptIndexTargets.prefix([32000]), + }, + { + "pattern_1": lambda idx: [-(idx + 1)], + }, + { + PromptInsertion: { + 0: [32000, 32000, 32000], + 1: [32000, -1, 32000, 32000], + 2: [32000, -1, -2, 32000, 32000], + }, + PromptReplacement: { + 0: [32000, 32000, 32000], + 1: [32000, -1, 32000, 32000], + 2: [32000, -1, -2, 32000, 32000], + }, + }, + ), ] ) # yapf: enable diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 61f2f8974d91..8457f6294460 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -19,8 +19,8 @@ NestedTensors) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptInsertion, - PromptUpdate) + BaseProcessingInfo, PromptIndexTargets, + PromptInsertion, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -490,7 +490,7 @@ def _get_prompt_updates( return [ PromptInsertion( modality="image", - target="", + target=PromptIndexTargets.start(), insertion=image_tokens, ) ] diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 6fa1bb80995d..7a8510379455 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -25,7 +25,8 @@ from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, - PromptInsertion, PromptUpdate) + PromptIndexTargets, PromptInsertion, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -864,7 +865,7 @@ def _get_prompt_updates( return [ PromptInsertion( modality="image", - target="", + target=PromptIndexTargets.start(), insertion=image_tokens, ) ] diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 60af103189f8..21158f7e5802 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -46,8 +46,8 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptInsertion, - PromptUpdate) + BaseProcessingInfo, PromptIndexTargets, + PromptInsertion, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import JSONTree, json_map_leaves @@ -1371,7 +1371,7 @@ def get_insertion_molmo(item_idx: int): return [ PromptInsertion( modality="image", - target="<|endoftext|>", + target=PromptIndexTargets.prefix("<|endoftext|>"), insertion=get_insertion_molmo, ) ] diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index ac33af7c10c7..7232df074f84 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -8,7 +8,6 @@ from dataclasses import dataclass, field from enum import Enum from functools import lru_cache -from itertools import groupby from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, TypeVar, Union, cast) @@ -40,6 +39,65 @@ """A token sequence (list of token IDs) or text.""" +@dataclass +class PromptIndex: + """Resolves to an index in the prompt.""" + get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]] + + +class PromptIndexTargets: + + @staticmethod + def start() -> PromptIndex: + """ + Resolves to the start of the prompt (before the first token). + + This results in a match even if the prompt is empty. + """ + return PromptIndex(lambda tok, prompt: 0) + + @staticmethod + def prefix(seq: PromptSeq) -> PromptIndex: + """ + Resolves to a location in the prompt after the given prefix. + """ + + def get_match_index( + tokenizer: AnyTokenizer, + prompt: PromptSeq, + ) -> Optional[int]: + prefix = seq + + if isinstance(prompt, str): + if not isinstance(prefix, str): + # Make both `str` + prefix = decode_tokens(tokenizer, prefix) + else: + if isinstance(prefix, str): + # Make both `list[int]` + prefix = encode_tokens(tokenizer, prefix) + + match_idx = len(prefix) + return match_idx if prompt[:match_idx] == prefix else None + + return PromptIndex(get_match_index) + + @staticmethod + def end() -> PromptIndex: + """ + Resolves to the end of the prompt (after the last token). + + This results in a match even if the prompt is empty. + """ + return PromptIndex(lambda tok, prompt: len(prompt)) + + +PromptTarget = Union[PromptSeq, PromptIndex] +""" +The token sequence or text to update. +""" + + @dataclass class PromptUpdateDetails: """Details about the token sequence or text that are part of the update.""" @@ -84,7 +142,7 @@ class UpdateMode(str, Enum): @dataclass -class PromptUpdate: +class PromptUpdate(ABC): """ Defines how to update a prompt with placeholder tokens. """ @@ -92,7 +150,7 @@ class PromptUpdate: modality: str """The modality for which the update is made.""" - target: PromptSeq + target: PromptTarget """The token sequence (or text) to update.""" @property @@ -122,24 +180,43 @@ class PromptInsertion(PromptUpdate): Example: For each image, insert a number of ```` feature placeholders - equal to the feature size of the vision encoder at the start of the - prompt: + equal to the feature size of the vision encoder after the ```` token: .. code-block:: python PromptInsertion( modality="image", - target="", + target="", insertion="" * image_feature_size, ) - As above, but insert after the ```` token: + Insert these tokens at the start of the prompt: .. code-block:: python PromptInsertion( modality="image", - target="", + target=PromptIndexTargets.start(), + insertion="" * image_feature_size, + ) + + Insert these tokens after a prefix ``Images:``: + + .. code-block:: python + + PromptInsertion( + modality="image", + target=PromptIndexTargets.prefix("Images:"), + insertion="" * image_feature_size, + ) + + Insert these tokens at the end of the prompt: + + .. code-block:: python + + PromptInsertion( + modality="image", + target=PromptIndexTargets.end(), insertion="" * image_feature_size, ) """ @@ -345,10 +422,14 @@ def modality(self) -> str: return self._origin.modality @property - def target(self) -> _BoundPromptSequence: + def target(self) -> Union[_BoundPromptSequence, PromptIndex]: """The token sequence (or text) to update.""" - return _BoundPromptSequence.from_seq(self.tokenizer, - self._origin.target) + target = self._origin.target + + if isinstance(target, PromptIndex): + return target + + return _BoundPromptSequence.from_seq(self.tokenizer, target) @property def content(self) -> PromptUpdateContent: @@ -447,6 +528,19 @@ def __repr__(self) -> str: f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") +@dataclass(repr=False) +class _PromptTargetIndexMatch(_PromptTargetMatch): + match_idx: int + + @property + def start_idx(self) -> int: + return self.match_idx + + @property + def end_idx(self) -> int: + return self.match_idx + + @dataclass(repr=False) class _PromptTargetTokenMatch(_PromptTargetMatch): match: _TokenMatch @@ -496,9 +590,24 @@ def find_token_matches( prompt_updates: Sequence[BoundPromptUpdate], ) -> Sequence[_PromptTargetMatch]: """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" + + def get_matches(update: BoundPromptUpdate): + target = update.target + + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(update.tokenizer, prompt) + if match_idx is None: + return [] + + return [_PromptTargetIndexMatch(update, match_idx)] + + return [ + _PromptTargetTokenMatch(update, match) + for match in iter_token_matches(prompt, target.token_ids) + ] + return [ - _PromptTargetTokenMatch(update, match) for update in prompt_updates - for match in iter_token_matches(prompt, update.target.token_ids) + match for update in prompt_updates for match in get_matches(update) ] @@ -507,9 +616,24 @@ def find_text_matches( prompt_updates: Sequence[BoundPromptUpdate], ) -> Sequence[_PromptTargetMatch]: """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" + + def get_matches(update: BoundPromptUpdate): + target = update.target + + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(update.tokenizer, prompt) + if match_idx is None: + return [] + + return [_PromptTargetIndexMatch(update, match_idx)] + + return [ + _PromptTargetTextMatch(update, match) + for match in re.finditer(re.escape(target.text), prompt) + ] + return [ - _PromptTargetTextMatch(update, match) for update in prompt_updates - for match in re.finditer(re.escape(update.target.text), prompt) + match for update in prompt_updates for match in get_matches(update) ] @@ -547,45 +671,39 @@ def _apply_matches( prev_end_idx = 0 next_idx_by_modality = defaultdict[str, int](lambda: 0) - for (start_idx, end_idx), group in groupby( - _resolve_matches(prompt, mm_matches), - key=lambda x: (x.start_idx, x.end_idx), - ): - matches = tuple(group) - assert len(matches) == 1 - - for match in matches: - modality = match.modality + for match in _resolve_matches(prompt, mm_matches): + modality = match.modality + + item_start_idx = next_idx_by_modality[modality] + max_item_count = mm_item_counts.get(modality, 0) + if item_start_idx >= max_item_count: + continue + + start_idx = match.start_idx + end_idx = match.end_idx + origin = match._origin + mode = origin.mode + + if mode == UpdateMode.INSERT: + out_seqs.append(prompt[prev_end_idx:end_idx]) + num_inserts = max_item_count + elif mode == UpdateMode.REPLACE: + out_seqs.append(prompt[prev_end_idx:start_idx]) + num_inserts = max_item_count if start_idx == end_idx else 1 + else: + assert_never(mode) - item_idx = next_idx_by_modality[modality] - if item_idx >= mm_item_counts.get(modality, 0): - continue + item_end_idx = min(item_start_idx + num_inserts, max_item_count) - origin = match._origin + for item_idx in range(item_start_idx, item_end_idx): content = origin.get_content(item_idx) - mode = origin.mode - - if mode == UpdateMode.INSERT: - out_seqs.append(prompt[prev_end_idx:end_idx]) - num_inserts = mm_item_counts.get(modality, 0) - elif mode == UpdateMode.REPLACE: - out_seqs.append(prompt[prev_end_idx:start_idx]) - num_inserts = 1 - else: - assert_never(mode) - - for _ in range(num_inserts): - if item_idx >= mm_item_counts.get(modality, 0): - continue - - if isinstance(prompt, str): - out_seqs.append(content.full.text) - else: - out_seqs.append(content.full.token_ids) + insert_seq = (content.full.text if isinstance(prompt, str) else + content.full.token_ids) - next_idx_by_modality[modality] += 1 + out_seqs.append(insert_seq) - prev_end_idx = end_idx + prev_end_idx = end_idx + next_idx_by_modality[modality] += item_end_idx - item_start_idx out_seqs.append(prompt[prev_end_idx:]) From 50bf059ee85f43973dba3aa7096b00e185ed8b64 Mon Sep 17 00:00:00 2001 From: Yang Liu <651636074@qq.com> Date: Fri, 28 Feb 2025 23:36:08 +0800 Subject: [PATCH 289/317] [Bugfix] Initialize attention bias on the same device as Query/Key/Value for QwenVL Series (#14031) --- vllm/model_executor/models/qwen2_5_vl.py | 3 ++- vllm/model_executor/models/qwen2_vl.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 0dbff665b5d3..ef3d28c8087d 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -323,7 +323,8 @@ def forward( seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None) + kv_seqlen=None, + device=q.device) context_layer = xops.memory_efficient_attention_forward( q, k, v, attn_bias=attn_bias, p=0, scale=None) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index cb92fcbe9fa1..523b53d5ee41 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -367,7 +367,8 @@ def forward( seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None) + kv_seqlen=None, + device=q.device) context_layer = xops.memory_efficient_attention_forward( q, k, v, attn_bias=attn_bias, p=0, scale=None) From e6134469d215325b5a5f0f8892fd6d10968a8fde Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Fri, 28 Feb 2025 11:42:07 -0500 Subject: [PATCH 290/317] [Doc] Fix ROCm documentation (#14041) Signed-off-by: Brayden Zhong --- docs/source/getting_started/installation/gpu/rocm.inc.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/installation/gpu/rocm.inc.md b/docs/source/getting_started/installation/gpu/rocm.inc.md index 7004313c90f3..84e7f6507de8 100644 --- a/docs/source/getting_started/installation/gpu/rocm.inc.md +++ b/docs/source/getting_started/installation/gpu/rocm.inc.md @@ -53,9 +53,9 @@ Currently, there are no pre-built ROCm wheels. If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent. ::: -2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile) +2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention) - Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support) + Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention#amd-rocm-support) Alternatively, wheels intended for vLLM use can be accessed under the releases. For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`. From f0c3ab60dec802e411772f675f986fd5e0778349 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 28 Feb 2025 16:56:44 +0000 Subject: [PATCH 291/317] Fix entrypoint tests for embedding models (#14052) --- tests/entrypoints/openai/test_embedding.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 8d00564351c5..a37169f51b05 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -27,7 +27,7 @@ def server(): "bfloat16", "--enforce-eager", "--max-model-len", - "8192", + "512", "--chat-template", DUMMY_CHAT_TEMPLATE, ] @@ -60,10 +60,10 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): assert embeddings.id is not None assert len(embeddings.data) == 1 - assert len(embeddings.data[0].embedding) == 4096 + assert len(embeddings.data[0].embedding) == 384 assert embeddings.usage.completion_tokens == 0 - assert embeddings.usage.prompt_tokens == 9 - assert embeddings.usage.total_tokens == 9 + assert embeddings.usage.prompt_tokens == 11 + assert embeddings.usage.total_tokens == 11 # test using token IDs input_tokens = [1, 1, 1, 1, 1] @@ -77,7 +77,7 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): assert embeddings.id is not None assert len(embeddings.data) == 1 - assert len(embeddings.data[0].embedding) == 4096 + assert len(embeddings.data[0].embedding) == 384 assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.prompt_tokens == 5 assert embeddings.usage.total_tokens == 5 @@ -101,10 +101,10 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): assert embeddings.id is not None assert len(embeddings.data) == 3 - assert len(embeddings.data[0].embedding) == 4096 + assert len(embeddings.data[0].embedding) == 384 assert embeddings.usage.completion_tokens == 0 - assert embeddings.usage.prompt_tokens == 32 - assert embeddings.usage.total_tokens == 32 + assert embeddings.usage.prompt_tokens == 33 + assert embeddings.usage.total_tokens == 33 # test List[List[int]] input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], @@ -119,7 +119,7 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): assert embeddings.id is not None assert len(embeddings.data) == 4 - assert len(embeddings.data[0].embedding) == 4096 + assert len(embeddings.data[0].embedding) == 384 assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.prompt_tokens == 17 assert embeddings.usage.total_tokens == 17 @@ -234,7 +234,7 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI, assert embeddings.id is not None assert len(embeddings.data) == 1 - assert len(embeddings.data[0].embedding) == 4096 + assert len(embeddings.data[0].embedding) == 384 assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.prompt_tokens == 10 assert embeddings.usage.total_tokens == 10 @@ -252,7 +252,7 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI, assert embeddings.id is not None assert len(embeddings.data) == 1 - assert len(embeddings.data[0].embedding) == 4096 + assert len(embeddings.data[0].embedding) == 384 assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.prompt_tokens == 10 assert embeddings.usage.total_tokens == 10 From a99cf1dde75c3f91e154f2635514072aec8cc034 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 3 Mar 2025 18:07:44 +0000 Subject: [PATCH 292/317] [V1][TPU] Integrate the new ragged paged attention kernel with vLLM v1 on TPU (#13379) Signed-off-by: Xiongfei Wei Signed-off-by: mgoin Co-authored-by: mgoin --- requirements-tpu.txt | 11 +- vllm/v1/attention/backends/pallas.py | 280 ++------ vllm/v1/outputs.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 4 +- vllm/v1/worker/tpu_model_runner.py | 987 ++++++++------------------- vllm/v1/worker/tpu_worker.py | 6 +- 6 files changed, 354 insertions(+), 936 deletions(-) diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 8bfbb2dda194..725b1a2e4a58 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -17,9 +17,8 @@ ray[default] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" + +torch==2.7.0.dev20250226+cpu +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 37bf33f6e3e9..a9f7b3fd4471 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -4,13 +4,16 @@ from typing import Any, Dict, List, Optional, Tuple, Type import torch -import torch_xla.experimental.custom_kernel # Required to register custom ops. +# Required to register custom ops. +import torch_xla.experimental.custom_kernel # noqa: F401 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) + AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +NUM_QUERIES_PER_BLOCK = 16 +NUM_KV_PAGES_PER_BLOCK = 128 + class PallasAttentionBackend(AttentionBackend): @@ -47,47 +50,23 @@ def swap_blocks( ) -> None: raise RuntimeError("swap_blocks is not used for the TPU backend.") - @torch.compile(backend="openxla") - @staticmethod - def copy_blocks( - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - src_to_dists: Tuple[torch.Tensor, torch.Tensor], - ) -> None: - src_indices, dst_indices = src_to_dists - for k_cache, v_cache in kv_caches: - torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) - k_cache[:, dst_indices] = k_cache[:, src_indices] - torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) - v_cache[:, dst_indices] = v_cache[:, src_indices] - @dataclass -class PallasMetadata(AttentionMetadata): - - # Currently, input sequences can only contain all prefills - # or all decoding. - block_tables: Optional[torch.Tensor] = None - context_lens: Optional[torch.Tensor] = None - effective_query_lens: Optional[torch.Tensor] = None - - @property - def prefill_metadata(self) -> Optional["PallasMetadata"]: - if self.num_prefills == 0: - return None - - assert self.num_decode_tokens == 0 - return self - - @property - def decode_metadata(self) -> Optional["PallasMetadata"]: - if self.num_decode_tokens == 0: - return None - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.block_tables is not None - assert self.context_lens is not None - return self +class PallasMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Used in the PallasAttentionBackendImpl + slot_mapping: torch.Tensor + block_tables: torch.Tensor + context_lens: torch.Tensor + query_start_loc: torch.Tensor + num_seqs: int class PallasAttentionBackendImpl(AttentionImpl): @@ -105,10 +84,13 @@ def __init__( logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, ) -> None: + if blocksparse_params is not None: + raise ValueError("Paged attention Pallas kernel does " + "not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -126,25 +108,6 @@ def __init__( raise NotImplementedError( "Attention logits soft-capping is not supported.") - if torch_xla.tpu.version() < 4: - raise NotImplementedError("TPU version must be 4 or higher.") - - self.megacore_mode = None - tpu_env = torch_xla.tpu.get_tpu_env() - tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) - or tpu_env.get("TYPE", None) - or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) - assert tpu_type is not None - tpu_type = tpu_type.lower() - - if (("lite" not in tpu_type) and ("v6" not in tpu_type)): - if self.num_kv_heads % 2 == 0: - self.megacore_mode = "kv_head" - else: - # NOTE(woosuk): If the batch size is not a multiple of 2, the - # megacore mode will be None. - self.megacore_mode = "batch" - if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -164,135 +127,47 @@ def forward( """Forward pass with Pallas attention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] - kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] - kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] - NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor - with shape [0] for profiling run. + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = ([num_kv_heads, num_blocks, block_size, head_size], + [num_kv_heads, num_blocks, block_size, head_size]) attn_metadata: Metadata for attention. Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - - if attn_metadata is None: + # For determine_available_memory case. + if kv_cache[0].numel() == 0: if output is None: output = torch.ones_like(query) return output assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - batch_size, seq_len, hidden_size = query.shape - query = query.view(batch_size, seq_len, self.num_heads, self.head_size) - key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) - value = value.view(batch_size, seq_len, self.num_kv_heads, - self.head_size) + num_tokens, hidden_size = query.shape + query = query.view(num_tokens, self.num_heads, self.head_size) + key = key.view(num_tokens, self.num_kv_heads, self.head_size) + value = value.view(num_tokens, self.num_kv_heads, self.head_size) + key_cache, value_cache = kv_cache if kv_cache[0].numel() > 0: slot_mapping = attn_metadata.slot_mapping - key_cache, value_cache = kv_cache write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) query = query * self.scale - if attn_metadata.num_prefills > 0: - if attn_metadata.block_tables is None: - # Prefill without paged KV cache. - assert seq_len % 16 == 0, ( - "Pallas FlashAttention kernel requires seq_len to be a " - f"multiple of 16 but got {seq_len}") - - # Handle GQA/MQA. - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, - dim=-2) - key = key.view(batch_size, seq_len, self.num_heads, - self.head_size) - value = value.repeat_interleave(self.num_queries_per_kv, - dim=-2) - value = value.view(batch_size, seq_len, self.num_heads, - self.head_size) - # FlashAttention kernel requires the input shape to be - # [batch_size, num_heads, seq_len, d_model] - # while the input is [batch_size, seq_len, num_heads, d_model]. - # Permute the input to match the required format. - output = torch.ops.xla.flash_attention( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - True, - ) - output = output.permute(0, 2, 1, 3) - else: - # Prefill with paged KV cache. - # TODO(woosuk): Tune the below knobs. - num_kv_pages_per_compute_block = 16 - num_queries_per_compute_block = 16 - assert seq_len % num_queries_per_compute_block == 0 - output = torch.ops.xla.multi_queries_paged_attention( - query, - key_cache, - value_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - attn_metadata.effective_query_lens, - num_kv_pages_per_compute_block, - num_queries_per_compute_block, - use_kernel=True, - ) - else: - # Decoding run. - assert kv_cache[0].numel() > 0 - query = query.squeeze(dim=1) - pages_per_compute_block = 16 # TODO(woosuk): Tune this value. - - assert attn_metadata.block_tables is not None - assert attn_metadata.context_lens is not None - # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire - # block table in SMEM. Therefore, if the block table is too large, - # the kernel compilation will fail. To avoid this, we split the - # batch dimension into smaller chunks and run the kernel multiple - # times. - MAX_SMEM_USAGE = 512 * 1024 - size_per_seq = 4 * attn_metadata.block_tables.shape[1] - max_num_seq = MAX_SMEM_USAGE // size_per_seq - - if batch_size <= max_num_seq: - output = paged_attention( - query, - key_cache, - value_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - pages_per_compute_block, - self.megacore_mode, - ) - else: - chunk_size = max_num_seq - # Make sure the chunk size is a multiple of 2. - chunk_size = chunk_size // 2 * 2 - num_chunks = (batch_size + chunk_size - 1) // chunk_size - - output = torch.empty_like(query) - for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * chunk_size - chunk_end = chunk_start + chunk_size - # NOTE(woosuk): We skip this line because it causes Dynamo - # compilation error. Instead, we rely on the slice operation - # to handle the out-of-bound case. - # chunk_end = min(chunk_end, batch_size) - chunk_output = paged_attention( - query[chunk_start:chunk_end], - key_cache, - value_cache, - attn_metadata.context_lens[chunk_start:chunk_end], - attn_metadata.block_tables[chunk_start:chunk_end], - pages_per_compute_block, - self.megacore_mode, - ) - output[chunk_start:chunk_end] = chunk_output + output = torch.ops.xla.ragged_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.query_start_loc, + attn_metadata.num_seqs, + num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK, + num_queries_per_block=NUM_QUERIES_PER_BLOCK, + use_kernel=False, + ) - # Reshape the output tensor. - return output.reshape(batch_size, seq_len, hidden_size) + return output.reshape(num_tokens, hidden_size) def write_to_kv_cache( @@ -302,52 +177,21 @@ def write_to_kv_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, ) -> None: + """ Write the key and values to the KV cache. + + Args: + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + k_cache = [num_kv_heads, num_blocks, block_size, head_size] + v_cache = [num_kv_heads, num_blocks, block_size, head_size] + + """ torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) - key = key.flatten(0, 2) - value = value.flatten(0, 2) + key = key.flatten(0, 1) + value = value.flatten(0, 1) key_cache = key_cache.flatten(0, 2) value_cache = value_cache.flatten(0, 2) key_cache.index_copy_(0, slot_mapping, key) value_cache.index_copy_(0, slot_mapping, value) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - context_lens: torch.Tensor, - block_tables: torch.Tensor, - pages_per_compute_block: int, - megacore_mode: Optional[str], -) -> torch.Tensor: - batch_size = query.shape[0] - if megacore_mode == "batch" and batch_size % 2 != 0: - megacore_mode = None - else: - megacore_mode = megacore_mode - - # NOTE(woosuk): A temporary workaround to avoid the error: - # "xla::paged_attention() Expected a value of type 'str' for - # argument 'megacore_mode' but instead found type 'NoneType'." - if megacore_mode is not None: - output = torch.ops.xla.paged_attention( - query, - key_cache, - value_cache, - context_lens, - block_tables, - pages_per_compute_block, - megacore_mode=megacore_mode, - ) - else: - output = torch.ops.xla.paged_attention( - query, - key_cache, - value_cache, - context_lens, - block_tables, - pages_per_compute_block, - ) - return output diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 0c8eca38ade7..f461d52cc984 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -79,4 +79,4 @@ class ModelRunnerOutput: # [prompt_len, num_prompt_logprobs] # [prompt_len, num_prompt_logprobs] # [prompt_len] - prompt_logprobs_dict: Dict[str, LogprobsTensors] + prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2730e6770dc3..e255becbefbf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1071,12 +1071,12 @@ def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, scheduler_output: "SchedulerOutput", - ) -> Dict[str, LogprobsTensors]: + ) -> Dict[str, Optional[LogprobsTensors]]: num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: return {} - prompt_logprobs_dict: Dict[str, LogprobsTensors] = {} + prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} # Since prompt logprobs are a rare feature, prioritize simple, # maintainable loop over optimal performance. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index b68c1ac9d71b..ae9fadb2944b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,9 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from copy import deepcopy -import enum import time -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from unittest.mock import patch import numpy as np @@ -22,7 +19,9 @@ from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingType from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available -from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, +from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK, + NUM_QUERIES_PER_BLOCK, + PallasAttentionBackend, PallasMetadata) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) @@ -39,36 +38,7 @@ # Here we utilize the behavior that out-of-bound index is ignored. # FIXME(woosuk): Find a more reliable way to prevent possible bugs. _PAD_SLOT_ID = 1_000_000_000 - - -class ExecutionMode(enum.Enum): - PREFILL = enum.auto() - DECODE = enum.auto() - PREFIX_PREFILL = enum.auto() - - def is_prefill(self) -> bool: - return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) - - -@dataclass -class PromptDecodeInfo: - prompt_req_ids: List[str] - decode_req_ids: List[str] - prompt_scheduled_tokens: List[int] - - -@dataclass -class PromptData: - input_tokens: torch.Tensor - input_positions: torch.Tensor - attn_metadata: PallasMetadata - - -@dataclass -class DecodeData: - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - attn_metadata: Optional[PallasMetadata] = None +INVALID_TOKEN_ID = -1 class TPUModelRunner(LoRAModelRunnerMixin): @@ -115,8 +85,6 @@ def __init__( self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() - self.model: Optional[nn.Module] = None - # Persistent batch. self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, @@ -136,50 +104,48 @@ def __init__( # KV caches for forward pass self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] - # Cached torch/numpy tensors - self.num_swaps = 2 - self.cur_swap_id = 0 - self.input_ids_cpu = [] - self.input_ids_np = [] - self.input_positions_cpu = [] - self.input_positions_np = [] - self.slot_mapping_cpu = [] - self.slot_mapping_np = [] - self.prompt_context_lens_cpu = [] - self.prompt_effective_query_lens_cpu = [] - self.decode_context_lens_cpu = [] - self.decode_context_lens_np = [] - for _ in range(self.num_swaps): - self.input_ids_cpu.append( - torch.empty(self.max_num_tokens, - dtype=torch.int32, - device="cpu")) - self.input_ids_np.append(self.input_ids_cpu[-1].numpy()) - - self.input_positions_cpu.append( - torch.empty(self.max_num_tokens, - dtype=torch.int32, - device="cpu")) - self.input_positions_np.append( - self.input_positions_cpu[-1].numpy()) - - self.slot_mapping_cpu.append( - torch.empty(self.max_num_tokens, - dtype=torch.int64, - device="cpu")) - self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy()) - - self.prompt_context_lens_cpu.append( - torch.empty((1), dtype=torch.int32, device="cpu")) - self.prompt_effective_query_lens_cpu.append( - torch.empty((1), dtype=torch.int32, device="cpu")) - - self.decode_context_lens_cpu.append( - torch.empty(self.max_num_tokens, - dtype=torch.int32, - device="cpu")) - self.decode_context_lens_np.append( - self.decode_context_lens_cpu[-1].numpy()) + # Cached torch/numpy tensor + # The pytorch tensor and numpy array share the same buffer. + # Sometimes the numpy op is faster so we create both. + self.input_ids_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu") + self.input_ids_np = self.input_ids_cpu.numpy() + + self.positions_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu") + self.positions_np = self.positions_cpu.numpy() + + self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device="cpu") + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + + # self.input_batch.block_table has a shape of [max_num_reqs, + # max_num_blocks_per_req]. To reduce the number of recompilation, + # we want the block_table.shape[0] to be num_tokens. + # To make the block_table to be compatible with the paged attention + # kernel, we want the block_table[1] to be multiple of + # NUM_KV_PAGES_PER_BLOCK. + padded_max_num_blocks_per_req = _get_padded_number( + self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK) + self.block_table_cpu = torch.zeros( + (self.max_num_tokens, padded_max_num_blocks_per_req), + dtype=self.input_batch.block_table.get_cpu_tensor().dtype, + device="cpu") + + self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.query_start_loc_np = self.query_start_loc_cpu.numpy() + + self.seq_lens_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.seq_lens_np = self.seq_lens_cpu.numpy() # Range tensor with values [0 .. self.max_num_tokens - 1]. # Used to initialize positions / context_lens / seq_lens @@ -193,7 +159,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: the input GPU tensors for the model. Returns: - True if there is a new/resumed/paused/finished request in the batch. + True if there is a new/resumed/paused/finished request. If False, we can skip copying SamplingMetadata to the GPU. """ # Remove finished requests from the cached states. @@ -305,9 +271,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.input_batch.condense(removed_req_indices) return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 - def swap_step(self): - self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps - def get_model(self) -> nn.Module: assert self.model is not None return self.model @@ -347,258 +310,124 @@ def get_kv_cache_spec(self) -> KVCacheSpec: return kv_cache_spec - def _get_prompts_and_decodes( - self, - scheduler_output: "SchedulerOutput", - ) -> PromptDecodeInfo: + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - # Traverse decodes first - decode_req_ids = [] - for i in range(num_reqs): - req_id = self.input_batch.req_ids[i] - assert req_id is not None - - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] - num_prompt_tokens = self.input_batch.num_prompt_tokens[i] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - - if num_computed_tokens < num_prompt_tokens: - # This is prompt - break - - # This is decode - assert num_scheduled_tokens == 1 - decode_req_ids.append(req_id) - - # Traverse prompts - prompt_req_ids = [] - prompt_scheduled_tokens = [] - for i in range(len(decode_req_ids), num_reqs): - req_id = self.input_batch.req_ids[i] + # Get the number of scheduled tokens for each request. + num_scheduled_tokens_per_req = [] + max_num_scheduled_tokens_all_reqs = 0 + for req_id in self.input_batch.req_ids[:num_reqs]: assert req_id is not None - - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] - num_prompt_tokens = self.input_batch.num_prompt_tokens[i] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - - # Must be prompt - assert num_computed_tokens < num_prompt_tokens - - prompt_req_ids.append(req_id) - prompt_scheduled_tokens.append(num_scheduled_tokens) - - return PromptDecodeInfo(prompt_req_ids, decode_req_ids, - prompt_scheduled_tokens) - - def _get_input_batch_subset(self, req_idxs: List[int]) -> InputBatch: - req_idxs = set(req_idxs) - all_req_idxs = set(self.input_batch.req_id_to_index.values()) - - req_idxs_to_remove = all_req_idxs.difference(req_idxs) - - subset_batch = deepcopy(self.input_batch) - subset_batch.condense(list(req_idxs_to_remove)) - return subset_batch - - - def _prepare_prompt(self, req_index: int, - num_scheduled_tokens: int) -> PromptData: - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[ - req_index] - num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index] - - # Must be prompt - assert num_computed_tokens < num_prompt_tokens - - # Prompt len - prompt_len = num_scheduled_tokens - padded_prompt_len = _get_padded_prompt_len(prompt_len) - assert padded_prompt_len <= self.max_model_len - - # Seq len - seq_len = num_computed_tokens + prompt_len - padded_seq_len = num_computed_tokens + padded_prompt_len - - # Input tokens - input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[ - req_index, num_computed_tokens:padded_seq_len] - input_tokens_cpu[prompt_len:] = 0 - - # Input positions - input_positions_np = self.input_positions_np[ - self.cur_swap_id][:padded_prompt_len] - np.add(num_computed_tokens, - self.arange_np[:padded_prompt_len], - out=input_positions_np) - input_positions_np[prompt_len:] = 0 - - # Slot mapping - block_table_np = \ - self.input_batch.block_table.get_numpy_array() - block_numbers_np = block_table_np[req_index, input_positions_np // - self.block_size] - block_offsets_np = input_positions_np % self.block_size - - slot_mapping_np = self.slot_mapping_np[ - self.cur_swap_id][:padded_prompt_len] - np.add(block_numbers_np * self.block_size, - block_offsets_np, - out=slot_mapping_np) - slot_mapping_np[prompt_len:] = _PAD_SLOT_ID - - # Block table - block_table_cpu = None - if num_computed_tokens > 0: - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_table_cpu = block_table_cpu[req_index] - - # Context len - self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0 - if num_computed_tokens > 0: - self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len - - # Effective query len - self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len - - # Get final tensors - input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device) - input_positions = self.input_positions_cpu[ - self.cur_swap_id][:padded_prompt_len].reshape(1, - -1).to(self.device) - slot_mapping = self.slot_mapping_cpu[ - self.cur_swap_id][:padded_prompt_len].reshape(1, - -1).to(self.device) - block_table = block_table_cpu.reshape(1, -1).to( - self.device) if block_table_cpu is not None else None - - context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to( - self.device) - effective_query_lens = self.prompt_effective_query_lens_cpu[ - self.cur_swap_id].to(self.device) - - if self.lora_config: - prompt_input_batch = self._get_input_batch_subset(req_idxs=[req_index]) - self.set_active_loras(prompt_input_batch, np.array([padded_prompt_len], dtype=np.int32)) - - self.swap_step() - - # Attn metadata - attn_metadata = PallasMetadata( - num_prefills=1, - num_prefill_tokens=0, # NOTE: This is not used. - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_table, - context_lens=context_lens, - effective_query_lens=effective_query_lens, - ) - - return PromptData(input_tokens, input_positions, attn_metadata) - - def _prepare_decode( - self, - decode_req_ids: List[str], - ) -> DecodeData: - # Batch size - batch_size = len(decode_req_ids) - padded_batch_size = _get_padded_batch_size(batch_size) - assert padded_batch_size <= self.max_model_len - - # Init [0 .. batch_size - 1] - req_indices_np = self.arange_np[:padded_batch_size] - - # Input positions - input_positions_np = self.input_positions_np[ - self.cur_swap_id][:padded_batch_size] - np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], - 0, - out=input_positions_np) - input_positions_np[batch_size:] = 0 - input_positions_cpu = self.input_positions_cpu[ - self.cur_swap_id][:padded_batch_size] - - # Input tokens - token_indices_np = ( - input_positions_np + - req_indices_np * self.input_batch.token_ids_cpu.shape[1]) - input_tokens_cpu = self.input_ids_cpu[ - self.cur_swap_id][:padded_batch_size] + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens_per_req.append(num_tokens) + max_num_scheduled_tokens_all_reqs = max( + max_num_scheduled_tokens_all_reqs, num_tokens) + num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, + dtype=np.int32) + assert max_num_scheduled_tokens_all_reqs > 0 + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + # For each scheduled token, what are the corresponding req index. + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens_per_req) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # For each scheduled token, what is its position in corresponding req. + arange = np.concatenate( + [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) + + # Get positions. + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, - torch.from_numpy(token_indices_np), - out=input_tokens_cpu) - input_tokens_cpu[batch_size:] = 0 - - # Slot mapping - block_table_indices_np = ( - req_indices_np * self.max_num_blocks_per_req + - input_positions_np // self.block_size) - + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) + + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. + # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) + + # Prepare the attention metadata. + self.query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens_per_req, + out=self.query_start_loc_np[1:num_reqs + 1]) + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req) + + # Do the padding and copy the tensors to the TPU. + padded_total_num_scheduled_tokens = _get_padded_number( + total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK) + self.input_ids = self.input_ids_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + self.position_ids = self.positions_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID + slot_mapping = self.slot_mapping_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + padded_block_table = self.block_table_cpu[: + padded_total_num_scheduled_tokens] + padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = ( + self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) + padded_block_table = padded_block_table.to(self.device) + query_start_loc = self.query_start_loc_cpu[: + padded_total_num_scheduled_tokens + + 1].to(self.device) + seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to( + self.device) - block_numbers_np = block_table_cpu.flatten( - )[block_table_indices_np].numpy() - - block_offsets_np = input_positions_np % self.block_size - - slot_mapping_np = self.slot_mapping_np[ - self.cur_swap_id][:padded_batch_size] - np.add(block_numbers_np * self.block_size, - block_offsets_np, - out=slot_mapping_np) - slot_mapping_np[batch_size:] = _PAD_SLOT_ID - - block_table_cpu = block_table_cpu[:padded_batch_size] - - # Context lens - context_lens_np = self.decode_context_lens_np[ - self.cur_swap_id][:padded_batch_size] - np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], - 1, - out=context_lens_np) - context_lens_np[batch_size:] = 0 - - # Get final tensors - input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device) - input_positions = input_positions_cpu.reshape(-1, 1).to(self.device) - slot_mapping = self.slot_mapping_cpu[ - self.cur_swap_id][:padded_batch_size].reshape(-1, - 1).to(self.device) - block_table = block_table_cpu.to(self.device) - context_lens = self.decode_context_lens_cpu[ - self.cur_swap_id][:padded_batch_size].to(self.device) - - if self.lora_config: - req_idxs = list(map(self.input_batch.req_id_to_index.get, decode_req_ids)) - decode_input_batch = self._get_input_batch_subset(req_idxs) - self.set_active_loras(decode_input_batch, np.array([padded_batch_size], dtype=np.int32)) - - self.swap_step() - - # Attn metadata attn_metadata = PallasMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=padded_batch_size, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_table, - context_lens=context_lens, - effective_query_lens=None, + block_tables=padded_block_table, + context_lens=seq_lens, + query_start_loc=query_start_loc, + num_seqs=num_reqs, ) - - return DecodeData(input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # request in the batch. While we should not sample any token from this + # partial request, we do so for simplicity. We will ignore the sampled + # token from the partial request. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + return attn_metadata, logits_indices @torch.no_grad() def execute_model( @@ -608,118 +437,81 @@ def execute_model( # Update cached state self._update_states(scheduler_output) - # If necessary, swap decodes/prompts to have all decodes on the start - ensure_decodes_first(self.input_batch) - - # Prepare prompts/decodes info - pd_info = self._get_prompts_and_decodes(scheduler_output) - - # Init - num_prompts = len(pd_info.prompt_req_ids) - num_decodes = len(pd_info.decode_req_ids) - decode_data = None - sampled_token_ids = [0] * self.input_batch.num_reqs - - # Run each prompt individually - is_first = True - for i in range(num_prompts): - req_id = pd_info.prompt_req_ids[i] - req_index = num_decodes + i - assert req_index == self.input_batch.req_id_to_index[ - req_id] # TODO: Remove - req_state = self.requests[req_id] - num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i] - prompt_len = num_scheduled_tokens - seq_len = req_state.num_computed_tokens + num_scheduled_tokens - - # Prepare first prompt - if is_first: - prompt_data = self._prepare_prompt(req_index, - num_scheduled_tokens) - is_first = False - - # Run forward pass - with set_forward_context(prompt_data.attn_metadata, - self.vllm_config): - assert self.model is not None - selected_token_ids = self.model(prompt_data.input_tokens, - prompt_data.input_positions, - self.kv_caches) - - # In parallel to TPU execution, prepare the next iteration - if i < num_prompts - 1: - # There is next prompt => prepare it - prompt_data = self._prepare_prompt( - req_index + 1, pd_info.prompt_scheduled_tokens[i + 1]) - elif i == num_prompts - 1 and num_decodes > 0: - # There is next decode => prepare it - decode_data = self._prepare_decode(pd_info.decode_req_ids) - - # Update cached state (if prompt is fully done) - if seq_len >= len(req_state.prompt_token_ids): - # Transfer sampled tokens from TPU to CPU - selected_token_ids_cpu = selected_token_ids.cpu() - - # Get output token - token_id = selected_token_ids_cpu[prompt_len - 1].item() - sampled_token_ids[req_index] = token_id - - # Add output token to the request - self.input_batch.token_ids_cpu[req_index, seq_len] = token_id - self.input_batch.num_tokens[req_index] += 1 - req_state.output_token_ids.append(token_id) + # Prepare inputs + attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - # Run decodes (a single batch) - if num_decodes > 0: - - # Prepare decode (if was not yet prepared) - if decode_data is None: - decode_data = self._prepare_decode(pd_info.decode_req_ids) - - # Run forward pass - with set_forward_context(decode_data.attn_metadata, - self.vllm_config): - assert self.model is not None - selected_token_ids = self.model(decode_data.input_tokens, - decode_data.input_positions, - self.kv_caches) - - # Transfer sampled tokens from TPU to CPU - decode_token_ids_cpu = selected_token_ids.cpu() - # Convert to list - decode_token_ids_list = decode_token_ids_cpu.tolist() - - # Update cached state for each decode request - for i in range(num_decodes): - req_id = pd_info.decode_req_ids[i] - req_index = i - assert req_index == self.input_batch.req_id_to_index[ - req_id] # TODO: Remove - req_state = self.requests[req_id] - seq_len = req_state.num_computed_tokens + 1 - - token_id = decode_token_ids_list[i] - sampled_token_ids[req_index] = token_id - - self.input_batch.token_ids_cpu[req_index, seq_len] = token_id - self.input_batch.num_tokens[req_index] += 1 - req_state.output_token_ids.append(token_id) + # Run the decoder + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + token_ids=self.input_ids, + position_ids=self.position_ids, + kv_caches=self.kv_caches, + ) + hidden_states = hidden_states[:total_num_scheduled_tokens] + num_reqs = self.input_batch.num_reqs + logits_indices = logits_indices[:num_reqs] + hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(hidden_states, None) + selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + + # Then, let's update the cache state. + request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] + for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): + assert req_id is not None + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len >= req_state.num_tokens: + request_seq_lens.append((i, req_state, seq_len)) + else: + # Ignore the sampled token from the partial request. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators.get(i) + if generator is not None: + # This relies on cuda-specific torch-internal impl details + generator.set_offset(generator.get_offset() - 4) + + # num_reqs entries should be non-None + assert all( + req_id is not None for req_id in + self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - # Create output. - all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} - for req_id in all_req_ids: + for req_id in self.input_batch.req_ids[:num_reqs]: prompt_logprobs_dict[req_id] = None + max_gen_len = selected_token_ids.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids = selected_token_ids.tolist() + for i, req_state, seq_len in request_seq_lens: + token_id = valid_sampled_token_ids[i][0] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + self.input_batch.num_tokens[i] += 1 + else: + valid_mask = selected_token_ids != INVALID_TOKEN_ID + gen_lens = valid_mask.sum(dim=1).tolist() + valid_sampled_token_ids = [ + seq.tolist() + for seq in selected_token_ids[valid_mask].split(gen_lens) + ] + self.input_batch.num_tokens[:num_reqs] += gen_lens + for i, req_state, seq_len in request_seq_lens: + target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) + self.input_batch.token_ids_cpu[ + i, target_slice] = valid_sampled_token_ids[i] + req_state.output_token_ids.extend(valid_sampled_token_ids[i]) + model_runner_output = ModelRunnerOutput( - req_ids=all_req_ids, + req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=[[token_id] for token_id in sampled_token_ids], + sampled_token_ids=valid_sampled_token_ids, spec_token_ids=None, logprobs=None, - prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type] + prompt_logprobs_dict=prompt_logprobs_dict, ) - return model_runner_output def load_model(self) -> None: @@ -759,194 +551,63 @@ def dummy_run( self, kv_caches, num_tokens: int, - seq_len: Optional[int] = None, - exec_mode: Optional[ExecutionMode] = None, ) -> None: - assert seq_len is not None - assert exec_mode is not None - - exec_mode = ExecutionMode(exec_mode) - if exec_mode.is_prefill(): - seq_len = (seq_len + 15) // 16 * 16 - token_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((num_tokens, seq_len), - dtype=torch.int64, - device=self.device) - if exec_mode == ExecutionMode.PREFILL: - attn_metadata = PallasMetadata( - num_prefills=num_tokens, - num_prefill_tokens=num_tokens * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=None, - context_lens=None, - effective_query_lens=None, - ) - - else: - context_lens = torch.ones((num_tokens, ), - dtype=torch.int32, - device=self.device) - - block_tables = torch.zeros( - (num_tokens, self.max_num_blocks_per_req), - dtype=torch.int32, - device=self.device) - - effective_query_lens = torch.ones_like(context_lens) - - attn_metadata = PallasMetadata( - num_prefills=num_tokens, - num_prefill_tokens=num_tokens * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_tables, - context_lens=context_lens, - effective_query_lens=effective_query_lens, - ) - else: - assert seq_len == 1 - token_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((num_tokens, seq_len), - dtype=torch.int64, - device=self.device) - block_tables = torch.zeros( - (num_tokens, self.max_num_blocks_per_req), - dtype=torch.int32, - device=self.device) - context_lens = torch.ones((num_tokens, ), - dtype=torch.int32, - device=self.device) - attn_metadata = PallasMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=num_tokens * seq_len, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_tables, - context_lens=context_lens, - ) + input_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros(num_tokens, + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]), + dtype=torch.int32, + device=self.device) + query_lens = [1] * num_tokens + query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, + dtype=torch.int32), + dim=0, + dtype=torch.int32).to(self.device) + context_lens = torch.ones((num_tokens, ), + dtype=torch.int32, + device=self.device) + attn_metadata = PallasMetadata( + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + query_start_loc=query_start_loc, + num_seqs=num_tokens, + ) - # NOTE(woosuk): There are two stages of compilation: torch.compile and - # XLA compilation. Using `mark_dynamic` can reduce the torch.compile - # overhead by reusing the FX graph for different shapes. - # However, the XLA graph will still require static shapes and needs to - # be re-compiled for every different shapes. This overhead is inevitable - # in the first run, but can be skipped afterwards as we cache the XLA - # graphs in the disk (VLLM_XLA_CACHE_PATH). - if exec_mode.is_prefill(): - # Prefll - if self.lora_config is not None: # TODO: Remove this condition - torch._dynamo.config.capture_dynamic_output_shape_ops = True - else: - torch._dynamo.mark_dynamic(token_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) - else: - # Decode - if self.lora_config is not None: # TODO: Remove this condition - torch._dynamo.config.capture_dynamic_output_shape_ops = True - else: - torch._dynamo.mark_dynamic(token_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) with set_forward_context(attn_metadata, self.vllm_config, 0): assert self.model is not None - self.model(token_ids, position_ids, kv_caches) + self.model(input_ids, position_ids, kv_caches) def capture_model(self) -> None: """Compile the model.""" - # Prefill - logger.info( - "Compiling the model with different input shapes for prefill:") - start = time.time() - for batch_size in [1]: - seq_len = 16 - while seq_len <= self.model_config.max_model_len: - with self.maybe_profile_with_lora(self.lora_config, np.array([seq_len] * batch_size, dtype=np.int32)): - self.dummy_run(self.kv_caches, - batch_size, - seq_len, - exec_mode=ExecutionMode.PREFILL) - xm.wait_device_ops() - logger.info(" batch_size: %d, seq_len: %d", batch_size, - seq_len) - num_tokens = batch_size * seq_len - if num_tokens >= self.scheduler_config.max_num_batched_tokens: - break - seq_len = seq_len * 2 - - end = time.time() - logger.info(" -- Compilation for prefill done in %.2f [secs].", - end - start) - - # Prefix prefill - if self.scheduler_config.enable_chunked_prefill: - logger.info("Compiling the model with different input shapes for " - "prefix prefill:") - start = time.time() - for batch_size in [1]: - seq_len = 16 - while seq_len <= self.model_config.max_model_len: - with self.maybe_profile_with_lora(self.lora_config, np.array([seq_len] * batch_size, dtype=np.int32)): - self.dummy_run(self.kv_caches, - batch_size, - seq_len, - exec_mode=ExecutionMode.PREFIX_PREFILL) - xm.wait_device_ops() - logger.info(" batch_size: %d, seq_len: %d", batch_size, - seq_len) - num_tokens = batch_size * seq_len - if (num_tokens - >= self.scheduler_config.max_num_batched_tokens): - break - seq_len = seq_len * 2 - end = time.time() - logger.info( - " -- Compilation for prefix prefill done in %.2f [secs].", - end - start) - - # Decode - logger.info( - "Compiling the model with different input shapes for decode:") - start = time.time() - seq_len = 1 - batch_size = 8 # Must be in sync with _get_padded_batch_size() + logger.info("Compiling the model with different input shapes.") + + start = time.perf_counter() + num_tokens = 16 while True: - with self.maybe_profile_with_lora(self.lora_config, np.array([seq_len] * batch_size, dtype=np.int32)): - self.dummy_run(self.kv_caches, - batch_size, - seq_len, - exec_mode=ExecutionMode.DECODE) - xm.wait_device_ops() - logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len) - - if batch_size >= self.scheduler_config.max_num_seqs: + self.dummy_run(self.kv_caches, num_tokens) + logger.info(" -- num_tokens: %d", num_tokens) + xm.mark_step() + xm.wait_device_ops() + if num_tokens >= self.scheduler_config.max_num_batched_tokens: break - batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 - - end = time.time() - logger.info(" -- Compilation for decode done in %.2f [secs].", - end - start) + num_tokens *= 2 + end = time.perf_counter() + logger.info("Compilation finished in in %.2f [secs].", end - start) def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -1002,12 +663,8 @@ def forward( """Executes the forward pass of the model and samples the next token. Args: - token_ids: The input token IDs of shape [batch_size, seq_len]. - position_ids: The input position IDs of shape [batch_size, seq_len]. - input_lens: The actual input lengths of shape [batch_size]. - t: The sampling temperature of shape [batch_size]. - p: The top-p probability of shape [batch_size]. - num_samples: Number of samples to draw from each logits vector. + token_ids: The input token IDs of shape [num_tokens]. + position_ids: The input position IDs of shape [num_tokens]. kv_caches: The key and value caches. They can be None during the memory profiling at initialization. """ @@ -1019,6 +676,7 @@ def forward( # [num_kv_heads, num_blocks, block_size, head_size]. To make it # work, we need to flatten the first three dimensions and modify # the slot_mapping accordingly. + # kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape slot_mapping = attn_metadata.slot_mapping slot_mapping = slot_mapping.flatten() @@ -1034,103 +692,22 @@ def forward( attn_metadata.slot_mapping = slot_mapping assert self.model is not None - hidden_states = self.model(token_ids, position_ids) - - hidden_states = hidden_states.flatten(0, 1) - logits = self.model.compute_logits(hidden_states, None) - - # Greedy sampling. - argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - argmax_token_ids = argmax_token_ids.squeeze(dim=-1) - return argmax_token_ids - - -def swap_positions(b: InputBatch, id_1, id_2): - assert id_1 != id_2 - req_id_1 = b.req_ids[id_1] - req_id_2 = b.req_ids[id_2] - assert req_id_1 is not None - assert req_id_2 is not None - assert id_1 == b.req_id_to_index[req_id_1] - assert id_2 == b.req_id_to_index[req_id_2] - - b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1] - b.req_id_to_index[req_id_1], b.req_id_to_index[ - req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1] - - ids = [id_1, id_2] - rev_ids = [id_2, id_1] - b.num_tokens[ids] = b.num_tokens[rev_ids] - b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids] - b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids] - b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids] - - b.block_table.swap_row(id_1, id_2) - - b.temperature_cpu[ids] = b.temperature_cpu[rev_ids] - b.top_p_cpu[ids] = b.top_p_cpu[rev_ids] - b.top_k_cpu[ids] = b.top_k_cpu[rev_ids] - b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids] - b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids] - b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids] - - b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[ - id_1] - - gen_1 = b.generators.pop(id_1, None) - gen_2 = b.generators.pop(id_2, None) - if gen_1 is not None: - b.generators[id_2] = gen_1 - if gen_2 is not None: - b.generators[id_1] = gen_2 - - -def ensure_decodes_first(b: InputBatch): - num_reqs = b.num_reqs - while True: - # Find the first prompt index - first_prompt_index = None - for i in range(num_reqs): - if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]: - first_prompt_index = i - break - if first_prompt_index is None: - break - - # Find the last decode index - last_decode_index = None - for i in reversed(range(num_reqs)): - if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]: - last_decode_index = i - break - if last_decode_index is None: - break - - # Sanity - assert first_prompt_index != last_decode_index - - # Check if done - if first_prompt_index > last_decode_index: - break - - # Swap - swap_positions(b, first_prompt_index, last_decode_index) + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + ) + return hidden_states -def _get_padded_prompt_len(x: int) -> int: - # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence - # length to be a multiple of 16. We pad the prompt length to the nearest - # multiple of 16. This is also good for performance. - if x <= 16: - return 16 - return 1 << (x - 1).bit_length() + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata, + ) -> Optional[torch.Tensor]: + logits = self.model.compute_logits(hidden_states, sampling_metadata) + return logits -def _get_padded_batch_size(batch_size: int) -> int: - # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. - # To meet this requirement in the simplest way, we set the minimal batch - # size to 8. - if batch_size <= 8: - return 8 - else: - return ((batch_size + 15) // 16) * 16 +def _get_padded_number(n: int, multiple: int) -> int: + return ((n + multiple - 1) // multiple) * multiple diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index ae124c819a90..5dd021890d9d 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -22,7 +22,7 @@ KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import bind_kv_cache -from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner +from vllm.v1.worker.tpu_model_runner import TPUModelRunner logger = init_logger(__name__) @@ -127,9 +127,7 @@ def determine_available_memory(self) -> int: self.model_runner.dummy_run( runner_kv_caches, - num_tokens=1, - seq_len=self.scheduler_config.max_num_batched_tokens, - exec_mode=ExecutionMode.PREFILL, + num_tokens=self.scheduler_config.max_num_batched_tokens, ) # Synchronize before measuring the memory usage. From 25a1bdf02b30d4cde46b70db52f596323df7ddae Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 1 Mar 2025 03:03:16 +0800 Subject: [PATCH 293/317] [v1] Cleanup the BlockTable in InputBatch (#13977) Signed-off-by: Chen Zhang --- tests/v1/worker/test_gpu_model_runner.py | 14 ++++++++++++++ vllm/v1/worker/block_table.py | 13 ++++++------- vllm/v1/worker/gpu_input_batch.py | 3 +-- vllm/v1/worker/gpu_model_runner.py | 6 ++---- vllm/v1/worker/tpu_model_runner.py | 6 ++---- 5 files changed, 25 insertions(+), 17 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 973efcbf8e50..ff4058a3b923 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -89,6 +89,17 @@ def _is_sampling_metadata_changed(model_runner, sampling_metadata_before) +def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: + req_index = model_runner.input_batch.req_id_to_index[req_id] + block_table = model_runner.input_batch.block_table + req_state = model_runner.requests[req_id] + if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids): + return False + num_blocks = block_table.num_blocks_per_row[req_index] + return (block_table.block_table_np[req_index, :num_blocks] == + req_state.block_ids).all() + + def test_update_states_new_request(model_runner): req_id = "req_0" @@ -100,6 +111,7 @@ def test_update_states_new_request(model_runner): assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) + assert _is_req_state_block_table_match(model_runner, req_id) def test_update_states_request_finished(model_runner): @@ -185,6 +197,7 @@ def test_update_states_request_resumed(model_runner): assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) + assert _is_req_state_block_table_match(model_runner, req_id) def test_update_states_no_changes(model_runner): @@ -215,6 +228,7 @@ def test_update_states_no_changes(model_runner): assert not _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) + assert _is_req_state_block_table_match(model_runner, req_id) def test_update_states_request_unscheduled(model_runner): diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 669175f5d9c3..830cca104ddb 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -15,13 +15,11 @@ class BlockTable: def __init__( self, max_num_reqs: int, - max_model_len: int, max_num_blocks_per_req: int, pin_memory: bool, device: torch.device, ): self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len self.max_num_blocks_per_req = max_num_blocks_per_req self.pin_memory = pin_memory self.device = device @@ -42,18 +40,19 @@ def __init__( def append_row( self, - row_idx: int, - start: int, block_ids: List[int], + row_idx: int, ) -> None: if not block_ids: return num_blocks = len(block_ids) + start = self.num_blocks_per_row[row_idx] + self.num_blocks_per_row[row_idx] += num_blocks self.block_table_np[row_idx, start:start + num_blocks] = block_ids - self.num_blocks_per_row[row_idx] = start + num_blocks - def add_row(self, row_idx: int, block_ids: List[int]) -> None: - self.append_row(row_idx, 0, block_ids) + def add_row(self, block_ids: List[int], row_idx: int) -> None: + self.num_blocks_per_row[row_idx] = 0 + self.append_row(block_ids, row_idx) def move_row(self, src: int, tgt: int) -> None: num_blocks = self.num_blocks_per_row[src] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1b6ea559a7b7..788a35221fe4 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -92,7 +92,6 @@ def __init__( # Block table. self.block_table = BlockTable( max_num_reqs=max_num_reqs, - max_model_len=max_model_len, max_num_blocks_per_req=max_num_blocks_per_req, pin_memory=pin_memory, device=device, @@ -249,7 +248,7 @@ def add_request( self.num_tokens_no_spec[req_index] = request.num_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - self.block_table.add_row(req_index, request.block_ids) + self.block_table.add_row(request.block_ids, req_index) sampling_params = request.sampling_params if sampling_params.sampling_type == SamplingType.GREEDY: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e255becbefbf..0215b2735384 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -399,10 +399,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - start_index = (len(req_state.block_ids) - - len(req_data.new_block_ids)) - self.input_batch.block_table.append_row(req_index, start_index, - req_data.new_block_ids) + self.input_batch.block_table.append_row(req_data.new_block_ids, + req_index) # Add new_token_ids to token_ids_cpu. start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(req_data.new_token_ids) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index ae9fadb2944b..ffa5e21ede87 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -248,10 +248,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( req_data.num_computed_tokens) - start_index = len(req_state.block_ids) - len( - req_data.new_block_ids) - self.input_batch.block_table.append_row(req_index, start_index, - req_data.new_block_ids) + self.input_batch.block_table.append_row(req_data.new_block_ids, + req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. From 8212e036a001ab0ce9fc9e55721709901dfd4c99 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Fri, 28 Feb 2025 20:25:50 +0000 Subject: [PATCH 294/317] Add RELEASE.md (#13926) Signed-off-by: atalman --- RELEASE.md | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 000000000000..7f5270715212 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,54 @@ +# Releasing vLLM + +vLLM releases offer a reliable version of the code base, packaged into a binary format that can be conveniently accessed via PyPI. These releases also serve as key milestones for the development team to communicate with the community about newly available features, improvements, and upcoming changes that could affect users, including potential breaking changes. + +## Release Versioning + +vLLM uses a “right-shifted” versioning scheme where a new patch release is out every 2 weeks. And patch releases contain features and bug fixes (as opposed to semver where patch release contains only backwards-compatible bug fixes). When critical fixes need to be made, special release post1 is released. + +* _major_ major architectural milestone and when incompatible API changes are made, similar to PyTorch 2.0. +* _minor_ major features +* _patch_ features and backwards-compatible bug fixes +* _post1_ or _patch-1_ backwards-compatible bug fixes, either explicit or implicit post release + +## Release Cadence + +Patch release is released on bi-weekly basis. Post release 1-3 days after patch release and uses same branch as patch release. +Following is the release cadence for year 2025. All future release dates below are tentative. Please note: Post releases are optional. + +| Release Date | Patch release versions | Post Release versions | +| --- | --- | --- | +| Jan 2025 | 0.7.0 | --- | +| Feb 2025 | 0.7.1, 0.7.2, 0.7.3 | --- | +| Mar 2025 | 0.7.4, 0.7.5 | --- | +| Apr 2025 | 0.7.6, 0.7.7 | --- | +| May 2025 | 0.7.8, 0.7.9 | --- | +| Jun 2025 | 0.7.10, 0.7.11 | --- | +| Jul 2025 | 0.7.12, 0.7.13 | --- | +| Aug 2025 | 0.7.14, 0.7.15 | --- | +| Sep 2025 | 0.7.16, 0.7.17 | --- | +| Oct 2025 | 0.7.18, 0.7.19 | --- | +| Nov 2025 | 0.7.20, 0.7.21 | --- | +| Dec 2025 | 0.7.22, 0.7.23 | --- | + +## Release branch + +Each release is built from a dedicated release branch. + +* For _major_, _minor_, _patch_ releases, the release branch cut is performed 1-2 days before release is live. +* For post releases, previously cut release branch is reused +* Release builds are triggered via push to RC tag like vX.Y.Z-rc1 . This enables us to build and test multiple RCs for each release. +* Final tag : vX.Y.Z does not trigger the build but used for Release notes and assets. +* After branch cut is created we monitor the main branch for any reverts and apply these reverts to a release branch. + +## Release Cherry-Pick Criteria + +After branch cut, we approach finalizing the release branch with clear criteria on what cherry picks are allowed in. Note: a cherry pick is a process to land a PR in the release branch after branch cut. These are typically limited to ensure that the team has sufficient time to complete a thorough round of testing on a stable code base. + +* Regression fixes - that address functional/performance regression against the most recent release (e.g. 0.7.0 for 0.7.1 release) +* Critical fixes - critical fixes for severe issue such as silent incorrectness, backwards compatibility, crashes, deadlocks, (large) memory leaks +* Fixes to new features introduced in the most recent release (e.g. 0.7.0 for 0.7.1 release) +* Documentation improvements +* Release branch specific changes (e.g. change version identifiers or CI fixes) + +Please note: **No feature work allowed for cherry picks**. All PRs that are considered for cherry-picks need to be merged on trunk, the only exception are Release branch specific changes. From 0532919e9fa60db6b143b7292cd7c93f32965fad Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 1 Mar 2025 04:53:31 +0800 Subject: [PATCH 295/317] [v1] Move block pool operations to a separate class (#13973) Signed-off-by: Chen Zhang Co-authored-by: Cody Yu --- tests/v1/core/test_prefix_caching.py | 89 +++++---- vllm/v1/core/block_pool.py | 285 +++++++++++++++++++++++++++ vllm/v1/core/kv_cache_manager.py | 263 +++--------------------- 3 files changed, 360 insertions(+), 277 deletions(-) create mode 100644 vllm/v1/core/block_pool.py diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index d598d12571f1..8956393c0bfb 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,12 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 """Compare the with and without prefix caching.""" +from typing import List + import pytest from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import cdiv +from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + hash_block_tokens) def make_request(request_id, @@ -62,14 +66,14 @@ def test_prefill(): for block_id in (0, 1, 2): block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16]) block_hash = hash_block_tokens(parent_block_hash, block_tokens) - assert manager.block_pool[block_id].block_hash == block_hash - assert manager.block_pool[block_id].ref_cnt == 1 + assert manager.block_pool.blocks[block_id].block_hash == block_hash + assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value # Check partial/preallocated block metadata for block_id in (3, 4): - assert manager.block_pool[block_id].block_hash is None - assert manager.block_pool[block_id].ref_cnt == 1 + assert manager.block_pool.blocks[block_id].block_hash is None + assert manager.block_pool.blocks[block_id].ref_cnt == 1 # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) @@ -86,20 +90,21 @@ def test_prefill(): assert block.ref_cnt == 2 # At this point, we should have 3 free blocks left. - assert manager.free_block_queue.num_free_blocks == 3 + assert manager.block_pool.free_block_queue.num_free_blocks == 3 manager.free(req0) manager.free(req1) # All blocks should be available. - assert manager.free_block_queue.num_free_blocks == 10 + assert manager.block_pool.free_block_queue.num_free_blocks == 10 # The order should be # [unallocated (7, 8, 9)] # [unique_req0 (4, 3)] # [unique_req1 (6, 5)] # [common (2, 1, 0)] assert [ - b.block_id for b in manager.free_block_queue.get_all_free_blocks() + b.block_id + for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0] # Cache hit in the common prefix when the original block is already free. @@ -116,12 +121,14 @@ def test_prefill(): # Although we only have 5 free blocks, we have 8 blocks in # the free block queue due to lazy removal. - assert manager.free_block_queue.num_free_blocks == 5 + assert manager.block_pool.free_block_queue.num_free_blocks == 5 assert all([ - b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks() + b.ref_cnt == 0 + for b in manager.block_pool.free_block_queue.get_all_free_blocks() ]) - assert len([b - for b in manager.free_block_queue.get_all_free_blocks()]) == 5 + assert len([ + b for b in manager.block_pool.free_block_queue.get_all_free_blocks() + ]) == 5 manager.free(req2) @@ -133,9 +140,9 @@ def test_prefill(): blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) # This block ID order also checks the eviction order. assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] - assert manager.free_block_queue.num_free_blocks == 0 - assert manager.free_block_queue.free_list_head is None - assert manager.free_block_queue.free_list_tail is None + assert manager.block_pool.free_block_queue.num_free_blocks == 0 + assert manager.block_pool.free_block_queue.free_list_head is None + assert manager.block_pool.free_block_queue.free_list_tail is None def test_decode(): @@ -219,13 +226,14 @@ def test_evict(): assert len(blocks) == 3 # 3 full blocks last_token_id += 3 * 16 - assert manager.free_block_queue.num_free_blocks == 0 + assert manager.block_pool.free_block_queue.num_free_blocks == 0 manager.free(req0) manager.free(req1) - assert manager.free_block_queue.num_free_blocks == 10 + assert manager.block_pool.free_block_queue.num_free_blocks == 10 assert [ - b.block_id for b in manager.free_block_queue.get_all_free_blocks() + b.block_id + for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7] # Touch the first 2 blocks. @@ -235,7 +243,7 @@ def test_evict(): assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, computed_blocks) assert [b.block_id for b in blocks] == [6, 5] - assert manager.free_block_queue.num_free_blocks == 6 + assert manager.block_pool.free_block_queue.num_free_blocks == 6 def test_hash_block_correct_reuse(): @@ -274,7 +282,7 @@ def test_hash_block_correct_reuse(): blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) assert len(blocks) == 1 - assert manager.block_pool[blocks[0].block_id].block_hash is None + assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -413,13 +421,9 @@ def test_cache_blocks(): function of KVCacheManager. """ block_size = 4 - manager = KVCacheManager( - block_size=block_size, + block_pool = BlockPool( num_gpu_blocks=5, - max_model_len=8192, - sliding_window=None, enable_caching=True, - num_preallocate_tokens=0, ) # Req: # Block 0: [0, 1, 2, 3] @@ -430,26 +434,31 @@ def test_cache_blocks(): # Test that blocks are cached correctly for 2 full blocks from the start. blocks = [KVCacheBlock(block_id=i) for i in range(2)] + block_hashes: List[BlockHashType] = [] - manager._cache_full_blocks( + block_pool.cache_full_blocks( request=req, - blk_start_idx=0, - full_blocks=blocks, - prev_block=None, + blocks=blocks, + block_hashes=block_hashes, + num_cached_blocks=0, + num_full_blocks=2, + block_size=block_size, ) - assert len(manager.cached_block_hash_to_block) == 2 + assert len(block_pool.cached_block_hash_to_block) == 2 assert all([block.block_hash is not None for block in blocks]) # Test that blocks that don't start from the beginning are cached correctly. - blocks = [KVCacheBlock(block_id=2)] - manager._cache_full_blocks( + blocks += [KVCacheBlock(block_id=2)] + block_pool.cache_full_blocks( request=req, - blk_start_idx=2, - full_blocks=blocks, - prev_block=None, + blocks=blocks, + block_hashes=block_hashes, + num_cached_blocks=2, + num_full_blocks=3, + block_size=block_size, ) - assert len(manager.cached_block_hash_to_block) == 3 + assert len(block_pool.cached_block_hash_to_block) == 3 assert blocks[0].block_hash is not None @@ -580,7 +589,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. - assert manager.free_block_queue.num_free_blocks == 5 + assert manager.block_pool.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks == block_part1 @@ -621,12 +630,12 @@ def test_reset_prefix_cache(): # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() - assert manager.cached_block_hash_to_block + assert manager.block_pool.cached_block_hash_to_block # Free the blocks. manager.free(req0) manager.free(req1) assert manager.reset_prefix_cache() - assert not manager.cached_block_hash_to_block - assert all([blk.block_hash is None for blk in manager.block_pool]) + assert not manager.block_pool.cached_block_hash_to_block + assert all([blk.block_hash is None for blk in manager.block_pool.blocks]) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py new file mode 100644 index 000000000000..5ef495c7eed8 --- /dev/null +++ b/vllm/v1/core/block_pool.py @@ -0,0 +1,285 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections import defaultdict +from typing import Dict, Iterable, List, Optional + +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, + KVCacheBlock, + generate_block_hash_extra_keys, + hash_block_tokens) +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class BlockPool: + """BlockPool that manages KVCacheBlocks. + It provides methods to allocate, free and cache the kv cache blocks. The + free_block_queue stores the free blocks in eviction order to enable + allocation, free, and cache eviction. The cached_block_hash_to_block + maps between block hash and cached block to support finding cached blocks + by their block hash. + + Args: + num_gpu_blocks: The number of blocks in the pool. + enable_caching: Whether to enable prefix caching. + """ + + def __init__(self, num_gpu_blocks: int, enable_caching: bool): + self.num_gpu_blocks = num_gpu_blocks + self.enable_caching = enable_caching + # All kv-cache blocks. + self.blocks: List[KVCacheBlock] = [ + KVCacheBlock(idx) for idx in range(num_gpu_blocks) + ] + # Free block queue that constructs and manipulates a doubly linked + # list of free blocks (including eviction candidates when caching is + # enabled). + self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) + + # {block_hash: {block ID: block}}. A cached block is + # a full block with a block hash that can be used for prefix caching. + # The cached block may be used by running requests or in the + # free_block_queue that could potentially be evicted. + # NOTE: We currently don't de-duplicate the blocks in the cache, + # meaning that if a block becomes full and is cached, we don't check + # if there is already an identical block in the cache. This is because + # we want to make sure the allocated block IDs won't change so that + # block tables are append-only. + self.cached_block_hash_to_block: Dict[BlockHashType, Dict[ + int, KVCacheBlock]] = defaultdict(dict) + + def get_cached_block(self, + block_hash: BlockHashType) -> Optional[KVCacheBlock]: + """Get a cached block by the block hash, or None if cache miss. + If there are duplicated blocks, we return the first block in the cache. + + Args: + block_hash: The hash value of the block. + + Returns: + The cached block if it exists, or None. + """ + if block_hash in self.cached_block_hash_to_block: + first_block_id = list( + self.cached_block_hash_to_block[block_hash].keys())[0] + return self.cached_block_hash_to_block[block_hash][first_block_id] + return None + + def cache_full_blocks( + self, + request: Request, + blocks: List[KVCacheBlock], + block_hashes: List[BlockHashType], + num_cached_blocks: int, + num_full_blocks: int, + block_size: int, + ) -> None: + """Cache a list of full blocks for prefix caching. + This function takes a list of blocks that will have their block hash + metadata to be updated and cached. Given a request, it computes the + block hashes for the blocks starting from `num_cached_blocks` to + `num_full_blocks`, updating the metadata for each block + and caching them in the `cached_block_hash_to_block`. + + Args: + request: The request to cache the blocks. + blocks: All blocks in the request. + block_hashes: Block hashes of the blocks in the request. Note that + this list may be shorter than the blocks list. In this case the + missed block hash will be computed in this function. + num_cached_blocks: The number of blocks that are already cached. + num_full_blocks: The number of blocks that are full and should + be cached after this function. + block_size: Number of tokens in each block. + """ + if num_cached_blocks == num_full_blocks: + return + new_full_blocks = blocks[num_cached_blocks:num_full_blocks] + assert len(block_hashes) >= num_cached_blocks + new_block_hashes = block_hashes[num_cached_blocks:] + + # Update the new blocks with the block hashes through the chain. + if num_cached_blocks == 0: + prev_block_hash_value = None + else: + prev_block = blocks[num_cached_blocks - 1] + assert prev_block.block_hash is not None + prev_block_hash_value = prev_block.block_hash.hash_value + + # Find the first uncached block. + # FIXME: num_cached_blocks should be corrected by the caller + # so this should never happen. + offset = 0 + for blk in new_full_blocks: + if blk.block_hash is None: + break + else: + prev_block_hash_value = blk.block_hash.hash_value + offset += 1 + else: + # All blocks are cached. + return + + for i, blk in enumerate(new_full_blocks[offset:]): + blk_idx = num_cached_blocks + offset + i + assert blk.block_hash is None + + if i + offset < len(new_block_hashes): + # The block hash may already be computed in + # "get_computed_blocks" if the tokens are not generated by + # this request (either the prompt tokens or the previously + # generated tokens with preemption). In this case we simply + # reuse the block hash. + block_hash = new_block_hashes[i + offset] + else: + # Otherwise compute the block hash and cache it in the request + # in case it will be preempted in the future. + start_token_idx = blk_idx * block_size + end_token_idx = (blk_idx + 1) * block_size + block_tokens = request.all_token_ids[ + start_token_idx:end_token_idx] + assert len(block_tokens) == block_size, ( + f"Expected {block_size} tokens, got " + f"{len(block_tokens)} at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Generate extra keys for multi-modal inputs. Note that since + # we reach to this branch only when the block is completed with + # generated tokens, we only need to consider the last mm input. + extra_keys, _ = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, -1) + + # Compute the hash of the current block. + block_hash = hash_block_tokens(prev_block_hash_value, + block_tokens, extra_keys) + block_hashes.append(block_hash) + + # Update and added the full block to the cache. + blk.block_hash = block_hash + self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + prev_block_hash_value = block_hash.hash_value + + def get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: + """Get new blocks from the free block pool. + + Note that we do not check block cache in this function. + + Args: + num_blocks: The number of blocks to allocate. + + Returns: + A list of new block. + """ + if num_blocks > self.get_num_free_blocks(): + raise ValueError( + f"Cannot get {num_blocks} free blocks from the pool") + + ret: List[KVCacheBlock] = [] + idx = 0 + while idx < num_blocks: + # First allocate blocks. + curr_block = self.free_block_queue.popleft() + assert curr_block.ref_cnt == 0 + + # If the block is cached, evict it. + if self.enable_caching: + self._maybe_evict_cached_block(curr_block) + + curr_block.incr_ref() + ret.append(curr_block) + idx += 1 + + return ret + + def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: + """ + If a block is cached in `cached_block_hash_to_block`, we reset its hash + metadata and evict it from the cache. + + Args: + block: The block to evict. + + Returns: + True if the block is evicted, False otherwise. + """ + block_hash = block.block_hash + if block_hash and block_hash in self.cached_block_hash_to_block: + block.reset_hash() + del self.cached_block_hash_to_block[block_hash][block.block_id] + + if len(self.cached_block_hash_to_block[block_hash]) == 0: + del self.cached_block_hash_to_block[block_hash] + + return True + return False + + def touch(self, blocks: List[KVCacheBlock]) -> None: + """Touch a block increases its reference count by 1, and may remove + the block from the free queue. This is used when a block is hit by + another request with the same prefix. + + Args: + blocks: A list of blocks to touch. + """ + for block in blocks: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0: + self.free_block_queue.remove(block) + block.incr_ref() + + def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: + """Free a list of blocks. The blocks should be ordered by their + eviction priority, where the first block will be evicted first. + + Args: + ordered_blocks: A list of blocks to free ordered by their eviction + priority. + """ + for block in ordered_blocks: + block.decr_ref() + if block.ref_cnt == 0: + self.free_block_queue.append(block) + + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. + """ + num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks()) + if num_used_blocks > 0: + logger.warning( + "Failed to reset prefix cache because some " + "blocks (%d) are not freed yet", num_used_blocks) + return False + + # Remove all hashes so that no new blocks will hit. + self.cached_block_hash_to_block = defaultdict(dict) + + # Remove all hashes from all blocks. + for block in self.blocks: + block.reset_hash() + + logger.info("Successfully reset prefix cache") + return True + + def get_num_free_blocks(self) -> int: + """Get the number of free blocks in the pool. + + Returns: + The number of free blocks. + """ + return self.free_block_queue.num_free_blocks + + def get_usage(self) -> float: + """Get the KV cache usage. + + Returns: + The KV cache usage (between 0.0 and 1.0). + """ + return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 017e625dcdba..fc7bfa0eff57 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -5,10 +5,8 @@ from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, - generate_block_hash_extra_keys, - hash_block_tokens, +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -49,26 +47,7 @@ def __init__( self.num_preallocate_tokens = num_preallocate_tokens self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size) - # A Block pool of all kv-cache blocks. - self.block_pool: List[KVCacheBlock] = [ - KVCacheBlock(idx) for idx in range(num_gpu_blocks) - ] - # Free block queue that constructs and manipulates a doubly linked - # list of free blocks (including eviction candidates when caching is - # enabled). - self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool) - - # {block_hash: {block ID: block}}. A cached block is - # a full block with a block hash that can be used for prefix caching. - # The cached block may be used by running requests or in the - # free_block_queue that could potentially be evicted. - # NOTE: We currently don't de-duplicate the blocks in the cache, - # meaning that if a block becomes full and is cached, we don't check - # if there is already an identical block in the cache. This is because - # we want to make sure the allocated block IDs won't change so that - # block tables are append-only. - self.cached_block_hash_to_block: Dict[BlockHashType, Dict[ - int, KVCacheBlock]] = defaultdict(dict) + self.block_pool = BlockPool(num_gpu_blocks, enable_caching) # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request @@ -96,8 +75,7 @@ def usage(self) -> float: Returns: The KV cache usage (between 0.0 and 1.0). """ - return 1.0 - (self.free_block_queue.num_free_blocks / - self.num_gpu_blocks) + return self.block_pool.get_usage() def make_prefix_cache_stats(self) -> PrefixCacheStats: """Get (and reset) the prefix cache stats. @@ -139,7 +117,7 @@ def get_computed_blocks( # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. - if cached_block := self._get_cached_block(block_hash): + if cached_block := self.block_pool.get_cached_block(block_hash): computed_blocks.append(cached_block) else: break @@ -204,14 +182,14 @@ def allocate_slots( # when allocating this request. num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks if blk.ref_cnt == 0) - if (num_new_blocks > self.free_block_queue.num_free_blocks - + if (num_new_blocks > self.block_pool.get_num_free_blocks() - num_evictable_computed_blocks): # Cannot allocate new blocks return None # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self._touch(new_computed_blocks) + self.block_pool.touch(new_computed_blocks) else: assert not new_computed_blocks, ( "Computed blocks should be empty when " @@ -231,7 +209,7 @@ def allocate_slots( # preallocated blocks. num_new_blocks = min( num_new_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks, + self.block_pool.get_num_free_blocks(), # Should not exceed the maximum number of blocks per request. # This is especially because the block table has the shape # [..., max_num_blocks_per_req]. @@ -240,29 +218,30 @@ def allocate_slots( assert num_new_blocks > 0 # Concatenate the computed block IDs and the new block IDs. - new_blocks = self._get_new_blocks(num_new_blocks) + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) if not self.enable_caching: return new_blocks + # FIXME: `num_cached_blocks` is not correct when the prefix cache + # of a new request is hit. num_cached_blocks = self.num_cached_block[request.request_id] # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( request.spec_token_ids)) // self.block_size - new_full_blocks = req_blocks[ - num_cached_blocks:num_full_blocks_after_append] - - if new_full_blocks: - self._cache_full_blocks( - request=request, - blk_start_idx=num_cached_blocks, - # The new full blocks are the full blocks that are not computed. - full_blocks=new_full_blocks, - prev_block=(req_blocks[num_cached_blocks - - 1] if num_cached_blocks > 0 else None)) + + self.block_pool.cache_full_blocks( + request=request, + blocks=req_blocks, + block_hashes=self.req_to_block_hashes[request.request_id], + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks_after_append, + block_size=self.block_size, + ) + self.num_cached_block[ request.request_id] = num_full_blocks_after_append return new_blocks @@ -283,11 +262,7 @@ def free(self, request: Request) -> None: # freed first. ordered_blocks = reversed(blocks) - for block in ordered_blocks: - block.decr_ref() - if block.ref_cnt == 0: - self.free_block_queue.append(block) - + self.block_pool.free_blocks(ordered_blocks) self.num_cached_block.pop(request.request_id, None) def reset_prefix_cache(self) -> bool: @@ -299,25 +274,10 @@ def reset_prefix_cache(self) -> bool: bool: True if the prefix cache is successfully reset, False otherwise. """ - num_used_blocks = (self.num_gpu_blocks - - self.free_block_queue.num_free_blocks) - if num_used_blocks > 0: - logger.warning( - "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks) - return False - - # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = defaultdict(dict) - - # Remove all hashes from all blocks. - for block in self.block_pool: - block.reset_hash() - - self.prefix_cache_stats.reset = True - - logger.info("Successfully reset prefix cache") - return True + if self.block_pool.reset_prefix_cache(): + self.prefix_cache_stats.reset = True + return True + return False def get_num_common_prefix_blocks( self, @@ -367,177 +327,6 @@ def get_num_common_prefix_blocks( break return num_common_blocks - def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: - """Get new blocks from the free block pool. - - Note that we do not check block cache in this function. - - Args: - num_blocks: The number of blocks to allocate. - - Returns: - A list of new block. - """ - if num_blocks > self.free_block_queue.num_free_blocks: - raise ValueError( - f"Cannot get {num_blocks} free blocks from the pool") - - ret: List[KVCacheBlock] = [] - idx = 0 - while idx < num_blocks: - # First allocate blocks. - curr_block = self.free_block_queue.popleft() - assert curr_block.ref_cnt == 0 - - # If the block is cached, evict it. - if self.enable_caching: - self._maybe_evict_cached_block(curr_block) - - curr_block.incr_ref() - ret.append(curr_block) - idx += 1 - - return ret - - def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: - """ - If a block is cached in `cached_block_hash_to_block`, we reset its hash - metadata and evict it from the cache. - - Args: - block: The block to evict. - - Returns: - True if the block is evicted, False otherwise. - """ - block_hash = block.block_hash - if block_hash and block_hash in self.cached_block_hash_to_block: - block.reset_hash() - del self.cached_block_hash_to_block[block_hash][block.block_id] - - if len(self.cached_block_hash_to_block[block_hash]) == 0: - del self.cached_block_hash_to_block[block_hash] - - return True - return False - - def _get_cached_block(self, - block_hash: BlockHashType) -> Optional[KVCacheBlock]: - """Get a cached block by the block hash, or None if cache miss. - If there are duplicated blocks, we return the first block in the cache. - - Args: - block_hash: The hash value of the block. - - Returns: - The cached block if it exists, or None. - """ - if block_hash in self.cached_block_hash_to_block: - first_block_id = list( - self.cached_block_hash_to_block[block_hash].keys())[0] - return self.cached_block_hash_to_block[block_hash][first_block_id] - return None - - def _touch(self, blocks: List[KVCacheBlock]) -> None: - """Touch a block increases its reference count by 1, and may remove - the block from the free queue. This is used when a block is hit by - another request with the same prefix. - - Args: - blocks: A list of blocks to touch. - """ - for block in blocks: - # ref_cnt=0 means this block is in the free list (i.e. eviction - # candidate), so remove it. - if block.ref_cnt == 0: - self.free_block_queue.remove(block) - block.incr_ref() - - def _cache_full_blocks( - self, - request: Request, - blk_start_idx: int, - full_blocks: List[KVCacheBlock], - prev_block: Optional[KVCacheBlock], - ) -> None: - """Cache a list of full blocks for prefix caching. - - This function takes a list of blocks that will have their block hash - metadata to be updated and cached. Given a request, it computes the - block hashes for the blocks starting from `blk_start_idx` to the end - of the request's full blocks, updating the metadata for each block - and caching them in the `cached_block_hash_to_block`. - - Args: - request: The request to cache the blocks. - blk_start_idx: The index of the first block in the request's blocks - to cache. - full_blocks: The list of blocks to update hash metadata. - prev_block: The previous block in the chain. - """ - block_hashes = self.req_to_block_hashes[request.request_id] - num_cached_block_hashes = len(block_hashes) - - # Update the new blocks with the block hashes through the chain. - prev_block_hash_value = None - if prev_block is not None: - # Previous block must have a block hash because it must be - # a full, cached block. - assert prev_block.block_hash is not None - prev_block_hash_value = prev_block.block_hash.hash_value - - # Find the first uncached block. This case should only happen when - # speculative decoding is used. - offset = 0 - for blk in full_blocks: - if blk.block_hash is None: - break - else: - prev_block_hash_value = blk.block_hash.hash_value - offset += 1 - else: - # All blocks are cached. - return - - for i, blk in enumerate(full_blocks[offset:]): - blk_idx = blk_start_idx + offset + i - assert blk.block_hash is None - - if blk_idx < num_cached_block_hashes: - # The block hash may already be computed in - # "get_computed_blocks" if the tokens are not generated by - # this request (either the prompt tokens or the previously - # generated tokens with preemption). In this case we simply - # reuse the block hash. - block_hash = block_hashes[blk_idx] - else: - # Otherwise compute the block hash and cache it in the request - # in case it will be preempted in the future. - start_token_idx = blk_idx * self.block_size - end_token_idx = (blk_idx + 1) * self.block_size - block_tokens = request.all_token_ids[ - start_token_idx:end_token_idx] - assert len(block_tokens) == self.block_size, ( - f"Expected {self.block_size} tokens, got " - f"{len(block_tokens)} at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Generate extra keys for multi-modal inputs. Note that since - # we reach to this branch only when the block is completed with - # generated tokens, we only need to consider the last mm input. - extra_keys, _ = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, -1) - - # Compute the hash of the current block. - block_hash = hash_block_tokens(prev_block_hash_value, - block_tokens, extra_keys) - block_hashes.append(block_hash) - - # Update and added the full block to the cache. - blk.block_hash = block_hash - self.cached_block_hash_to_block[block_hash][blk.block_id] = blk - prev_block_hash_value = block_hash.hash_value - def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. From 728088fa50f36482976aae92003fea0bcb767ac1 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Fri, 28 Feb 2025 13:47:44 -0800 Subject: [PATCH 296/317] [core] Bump ray to 2.43 (#13994) Signed-off-by: Rui Qiao --- .github/dependabot.yml | 2 +- requirements-cuda.txt | 2 +- requirements-test.in | 2 +- requirements-test.txt | 2 +- vllm/executor/ray_distributed_executor.py | 10 ++++------ 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 683b70cd8998..a017d69be991 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -23,7 +23,7 @@ updates: - dependency-name: "lm-format-enforcer" - dependency-name: "gguf" - dependency-name: "compressed-tensors" - - dependency-name: "ray[adag]" + - dependency-name: "ray[cgraph]" # Ray Compiled Graph - dependency-name: "lm-eval" groups: minor-update: diff --git a/requirements-cuda.txt b/requirements-cuda.txt index bc670b8511fd..2de06668c3a4 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -2,7 +2,7 @@ -r requirements-common.txt # Dependencies for NVIDIA GPUs -ray[adag] == 2.40.0 # Required for pipeline parallelism in V1. +ray[cgraph] >= 2.43.0 # Ray Compiled Graph, required for pipeline parallelism in V1. torch == 2.5.1 torchaudio==2.5.1 # These must be updated alongside torch diff --git a/requirements-test.in b/requirements-test.in index 53c531360d87..de33f92b37b9 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -16,7 +16,7 @@ vector_quantize_pytorch # required for minicpmo_26 test vocos # required for minicpmo_26 test peft pqdm -ray[adag]==2.40.0 +ray[cgraph]>=2.43.0 # Ray Compiled Graph, required by pipeline parallelism tests sentence-transformers # required for embedding tests soundfile # required for audio tests jiwer # required for audio tests diff --git a/requirements-test.txt b/requirements-test.txt index 11f0e10969a6..f5722c82e201 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -472,7 +472,7 @@ pyyaml==6.0.2 # vocos rapidfuzz==3.12.1 # via jiwer -ray==2.40.0 +ray==2.43.0 # via -r requirements-test.in redis==5.2.0 # via tensorizer diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 2accb9e17f3d..108f606e2fb8 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -500,7 +500,7 @@ def _check_ray_cgraph_installation(self): import pkg_resources from packaging import version - required_version = version.parse("2.40") + required_version = version.parse("2.43.0") current_version = version.parse( pkg_resources.get_distribution("ray").version) if current_version < required_version: @@ -512,20 +512,19 @@ def _check_ray_cgraph_installation(self): "ray.experimental.compiled_dag_ref") if cgraph_spec is None: raise ValueError("Ray Compiled Graph is not installed. " - "Run `pip install ray[adag]` to install it.") + "Run `pip install ray[cgraph]` to install it.") cupy_spec = importlib.util.find_spec("cupy") if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: raise ValueError( "cupy is not installed but required since " "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. " - "Run `pip install ray[adag]` and check cupy installation.") + "Run `pip install ray[cgraph]` and check cupy installation.") def _compiled_ray_dag(self, enable_asyncio: bool): assert self.parallel_config.use_ray self._check_ray_cgraph_installation() from ray.dag import InputNode, MultiOutputNode - from ray.experimental.channel.torch_tensor_type import TorchTensorType logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s", envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL) @@ -574,8 +573,7 @@ def _compiled_ray_dag(self, enable_asyncio: bool): if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \ else "auto" outputs = [ - output.with_type_hint( - TorchTensorType(transport=transport)) + output.with_tensor_transport(transport=transport) for output in outputs ] From 27aacf99052f00a572f45de2dc7189e0987e7cda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 28 Feb 2025 18:20:11 -0500 Subject: [PATCH 297/317] [torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902) Signed-off-by: luka --- tests/compile/backend.py | 13 +- tests/compile/test_functionalization.py | 8 +- tests/compile/test_fusion.py | 127 ++++++++-------- vllm/compilation/noop_elimination.py | 135 ++++++++++++++++++ vllm/compilation/pass_manager.py | 8 +- vllm/compilation/reshapes.py | 90 ------------ vllm/compilation/vllm_inductor_pass.py | 18 ++- vllm/config.py | 13 +- .../layers/quantization/utils/w8a8_utils.py | 7 +- 9 files changed, 249 insertions(+), 170 deletions(-) create mode 100644 vllm/compilation/noop_elimination.py delete mode 100644 vllm/compilation/reshapes.py diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 74bc58a2dd54..64416eb136cf 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -13,21 +13,26 @@ class TestBackend: This class provides a simple Inductor backend that can be used for testing. It takes a list of custom passes and runs them after Inductor's passes. It also saves the graph before and after the custom passes for inspection. + + Inductor config can be modified directly by editing the inductor_config + property. This can be helpful for adding passes like the + 'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'. """ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.custom_passes = list(passes) from torch._inductor import config - self.current_config = config.shallow_copy_dict() - self.current_config['force_disable_caches'] = True - self.current_config['post_grad_custom_post_pass'] = self.post_pass + self.inductor_config = config.shallow_copy_dict() + self.inductor_config['force_disable_caches'] = True + self.inductor_config['post_grad_custom_post_pass'] = self.post_pass def __call__(self, graph: fx.GraphModule, example_inputs): + self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx return compile_fx(graph, example_inputs, - config_patches=self.current_config) + config_patches=self.inductor_config) def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 8f5040522692..9f9b2d06b227 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -9,7 +9,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func -from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig from .backend import TestBackend @@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, torch.set_default_device("cuda") config = CompilationConfig.PassConfig(enable_fusion=do_fusion, - enable_reshape=True) - reshape_pass = RedundantReshapesPass(config) + enable_noop=True) + noop_pass = NoOpEliminationPass(config) fusion_pass = FusionPass.instance(config) - passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass] + passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass] func_pass = FixFunctionalizationPass(config) backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index c14f0caab539..89abc001764b 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,23 +5,25 @@ from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs +import vllm.plugins from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, FusionPass, QuantKey) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe -from vllm.compilation.reshapes import RedundantReshapesPass -from vllm.config import CompilationConfig +from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear) + CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity) from .backend import TestBackend class TestModel(torch.nn.Module): - def __init__(self, hidden_size: int, eps: float, static: bool, *args, - **kwargs): + def __init__(self, hidden_size: int, eps: float, static: bool, + cutlass_fp8_enabled: bool, *args, **kwargs): super().__init__(*args, **kwargs) + self.cutlass_fp8_enabled = cutlass_fp8_enabled self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] if static: @@ -41,7 +43,8 @@ def forward(self, x): self.w[0], self.wscale[0], self.scale[0], - use_per_token_if_dynamic=True) + use_per_token_if_dynamic=True, + cutlass_fp8_supported=self.cutlass_fp8_enabled) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) @@ -49,7 +52,8 @@ def forward(self, x): self.w[1], self.wscale[1], self.scale[1], - use_per_token_if_dynamic=True) + use_per_token_if_dynamic=True, + cutlass_fp8_supported=self.cutlass_fp8_enabled) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -59,60 +63,67 @@ def forward(self, x): @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) +@pytest.mark.parametrize("cutlass_fp8_enabled", + [True, False] if CUTLASS_FP8_SUPPORTED else [False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, + cutlass_fp8_enabled): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) + maybe_create_device_identity() # needed for certain non-cutlass fp8 paths - # Reshape pass is needed for the fusion pass to work - config = CompilationConfig.PassConfig(enable_fusion=True, - enable_reshape=True) - reshape_pass = RedundantReshapesPass(config) - fusion_pass = FusionPass.instance(config) - - backend = TestBackend(reshape_pass, fusion_pass) - model = TestModel(hidden_size, eps, static) - - # First dimension dynamic - x = torch.rand(num_tokens, hidden_size) - torch._dynamo.mark_dynamic(x, 0) - - result = model(x) - - model2 = torch.compile(model, backend=backend) - result2 = model2(x) - - # Higher tol for dynamic, even higher for bfloat16 - if static: - ATOL, RTOL = (1e-3, 1e-3) - elif dtype == torch.float16: - ATOL, RTOL = (2e-3, 2e-3) - else: - ATOL, RTOL = (1e-2, 1e-2) - - torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) - - # Check substitution worked - pre_nodes = backend.graph_pre_pass.nodes - post_nodes = backend.graph_post_pass.nodes - - # static is per-tensor, dynamic is per-token - key = QuantKey(dtype=FP8_DTYPE, - static=static, - per_tensor=static, - symmetric=True) - rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)] - add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)] - fp8_quant = QUANT_OPS[key] - - # In pre-nodes, fp8 quant should be present and fused kernels should not - assert find_auto_fn_maybe(pre_nodes, rms_quant) is None - assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None - find_auto_fn(pre_nodes, fp8_quant) - - # In post-nodes, fused kernels should be present and fp8 quant should not - find_auto_fn(post_nodes, rms_quant) - find_auto_fn(post_nodes, add_rms_quant) - assert find_auto_fn_maybe(post_nodes, fp8_quant) is None + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) + with vllm.config.set_current_vllm_config(vllm_config): + # Reshape pass is needed for the fusion pass to work + config = CompilationConfig.PassConfig(enable_fusion=True, + enable_noop=True) + noop_pass = NoOpEliminationPass(config) + fusion_pass = FusionPass.instance(config) + + backend = TestBackend(noop_pass, fusion_pass) + model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled) + + # First dimension dynamic + x = torch.rand(num_tokens, hidden_size) + torch._dynamo.mark_dynamic(x, 0) + + result = model(x) + + model2 = torch.compile(model, backend=backend) + result2 = model2(x) + + # Higher tol for dynamic, even higher for bfloat16 + if static: + ATOL, RTOL = (1e-3, 1e-3) + elif dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + + # Check substitution worked + pre_nodes = backend.graph_pre_pass.nodes + post_nodes = backend.graph_post_pass.nodes + + # static is per-tensor, dynamic is per-token + key = QuantKey(dtype=FP8_DTYPE, + static=static, + per_tensor=static, + symmetric=True) + rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)] + add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)] + fp8_quant = QUANT_OPS[key] + + # In pre-nodes, fp8 quant should be there and fused kernels should not + assert find_auto_fn_maybe(pre_nodes, rms_quant) is None + assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None + find_auto_fn(pre_nodes, fp8_quant) + + # In post-nodes, fused kernels should be there and fp8 quant should not + find_auto_fn(post_nodes, rms_quant) + find_auto_fn(post_nodes, add_rms_quant) + assert find_auto_fn_maybe(post_nodes, fp8_quant) is None diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py new file mode 100644 index 000000000000..19127e933ec4 --- /dev/null +++ b/vllm/compilation/noop_elimination.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Union + +import torch.fx +from torch import SymInt + +from vllm.logger import init_logger + +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class NoOpEliminationPass(VllmInductorPass): + """ + This is an inductor pass that removes redundant reshape/slice operations. + It is required for RMSNorm-quant fusion to work properly. + That's because apply_fp8_linear adds a reshape, which is redundant + in the 2D-case. Additionally, torch internal no-op elimination pass does + not handle certain slice variants. + + Example graph 1: + getitem_1: "f16[s0, 4096]" = ... + view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) + at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Can be replaced with: + getitem_1: "f16[s0, 4096]" = ... + at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Example graph 2: + arg0: "s0" = SymInt(s0) + scaled_mm: "f16[s0, 4096]" = ... + slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0) + at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...) + out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0) + + Can be replaced with: + arg0: "s0" = SymInt(s0) + scaled_mm: "f16[s0, 4096]" = ... + at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...) + out: "f16[s0, 4096]" = at[1] + + TODO(luka): This is currently tested in test_fusion, + but separate tests could be good. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_noop_elimination") + count = 0 + # Remove no-op reshapes/views: + for node in graph.nodes: + if is_func(node, torch.ops.aten.reshape.default): + input, shape = node.args[:2] + input_shape = input.meta["val"].shape + if len(shape) != len(input_shape): + # Reshape changing rank, skip + continue + + if shape.count(-1) > 1: + # Invalid reshape args, skip + continue + + if self.all_dims_equivalent(shape, input_shape): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + elif is_func(node, torch.ops.aten.slice.Tensor): + input, dim_index, start, end = node.args[:4] + input_shape = input.meta["val"].shape + i_dim = input_shape[dim_index] + + if start == 0 and self.dims_equivalent(end, i_dim): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + elif is_func(node, torch.ops.aten.slice_scatter.default): + base, view, dim_index, start, end = node.args[:5] + base_shape = base.meta["val"].shape + view_shape = view.meta["val"].shape + + view_dim = view_shape[dim_index] + + # Check that view fully covers base and the full view is used + # (if the view fully covered the base after slicing but was not + # fully used, we could replace slice_scatter with a simple slice + # but that's a niche case). + if (base_shape == view_shape and start == 0 + and self.dims_equivalent(end, view_dim)): + node.replace_all_uses_with(view) + graph.erase_node(node) + count += 1 + + logger.debug("Removed %s no-op reshapes and slices", count) + self.dump_graph(graph, "after_noop_elimination") + self.end_and_log() + + def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]], + i_dims: Iterable[Union[int, SymInt]]): + return all( + self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) + + def dims_equivalent(self, dim: Union[int, torch.fx.Node], + i_dim: Union[int, SymInt]) -> bool: + """ + This function checks if two dimensions are equivalent. + :param dim: The dimension arg to reshape/slice + :param i_dim: The corresponding dimension in the input tensor + :return: Are the dimensions equivalent? + + There are three cases in which the dimensions are equivalent: + 1. The dimensions are equal (both integers) + 2. The reshape dimension is -1 (i.e. inferred) + 3. The dimensions both correspond to the same SymInt + + While case 2 does not guarantee the dimensions are equal, + they are equal if all other dimensions are equal. + + In case 3, the reshape dimension is a torch.fx.Node, + and its value is a SymInt. That value is equal to the + input dimension. + + """ + # Case 1 and 2 + if dim == i_dim or dim == -1: + return True + # Case 3 + return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 52f8c3b1ec15..b012346c353e 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -11,7 +11,7 @@ from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass from .inductor_pass import InductorPass -from .reshapes import RedundantReshapesPass +from .noop_elimination import NoOpEliminationPass logger = init_logger(__name__) @@ -36,7 +36,7 @@ class PostGradPassManager(Parent): The order of the post-grad post-passes is: 1. passes (constructor parameter) - 2. default passes (RedundantReshapesPass, FusionPass) + 2. default passes (NoopEliminationPass, FusionPass) 3. config["post_grad_custom_post_pass"] (if it exists) 4. fix_functionalization This way, all passes operate on a functionalized graph. @@ -54,8 +54,8 @@ def __call__(self, graph: fx.Graph): def configure(self, pass_config: CompilationConfig.PassConfig): self.pass_config = pass_config - if pass_config.enable_reshape: - self.passes += [RedundantReshapesPass(pass_config)] + if pass_config.enable_noop: + self.passes += [NoOpEliminationPass(pass_config)] if pass_config.enable_fusion: self.passes += [FusionPass.instance(pass_config)] diff --git a/vllm/compilation/reshapes.py b/vllm/compilation/reshapes.py deleted file mode 100644 index 292baae85282..000000000000 --- a/vllm/compilation/reshapes.py +++ /dev/null @@ -1,90 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -import torch.fx -from torch import SymInt - -from vllm.logger import init_logger - -from .fx_utils import is_func -from .vllm_inductor_pass import VllmInductorPass - -logger = init_logger(__name__) - - -class RedundantReshapesPass(VllmInductorPass): - """ - This is an inductor pass that removes redundant reshape operations. - It is required for RMSNorm-quant fusion to work properly. - That's because apply_fp8_linear adds a reshape, which is redundant - in the 2D-case. - - Example graph: - - getitem_1: "f16[s0, 4096]" = ... - view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) - at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) - out: "f8e4m3fn[s0, 4096]" = at[1] - - Can be replaced with: - getitem_1: "f16[s0, 4096]" = ... - at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) - out: "f8e4m3fn[s0, 4096]" = at[1] - """ - - def __call__(self, graph: torch.fx.Graph): - self.begin() - self.dump_graph(graph, "before_reshapes") - count = 0 - # Remove no-op reshapes/views: - for node in graph.nodes: - if is_func(node, torch.ops.aten.reshape.default): - input, shape = node.args[:2] - input_shape = input.meta["val"].shape - if len(shape) != len(input_shape): - # Reshape changing rank, skip - continue - - if shape.count(-1) > 1: - # Invalid reshape args, skip - continue - - if all( - self.dims_equivalent(s, i_s) - for s, i_s in zip(shape, input_shape)): - node.replace_all_uses_with(input) - graph.erase_node(node) - count += 1 - - logger.debug("Removed %s no-op reshapes", count) - - self.dump_graph(graph, "after_reshapes") - self.end_and_log() - - def dims_equivalent(self, dim: Union[int, torch.fx.Node], - i_dim: Union[int, SymInt]) -> bool: - """ - This function checks if two dimensions are equivalent. - :param dim: The dimension arg to reshape - :param i_dim: The corresponding dimension in the input tensor - :return: Are the dimensions equivalent? - - There are three cases in which the dimensions are equivalent: - 1. The dimensions are equal (both integers) - 2. The reshape dimension is -1 (i.e. inferred) - 3. The dimensions both correspond to the same SymInt - - While case 2 does not guarantee the dimensions are equal, - they are equal if all other dimensions are equal. - - In case 3, the reshape dimension is a torch.fx.Node, - and its value is a SymInt. That value is equal to the - input dimension. - - """ - # Case 1 and 2 - if dim == i_dim or dim == -1: - return True - # Case 3 - return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 1d2597e42711..98ed6f1472a4 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -28,8 +28,8 @@ def __init__(self, config: CompilationConfig.PassConfig): self.config = config self.pass_name = self.__class__.__name__ - def dump_graph(self, graph: torch.fx.Graph, stage: str): - if stage in self.config.dump_graph_stages: + def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): + if stage in self.config.dump_graph_stages or always: # Make sure filename includes rank in the distributed setting parallel = p_is_init() and get_tp_world_size() > 1 rank = f"-{get_tp_rank()}" if parallel else "" @@ -49,3 +49,17 @@ def end_and_log(self): self._end_time = time.perf_counter_ns() duration_ms = float(self._end_time - self._start_time) / 1.0e6 logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) + + +class PrinterInductorPass(VllmInductorPass): + + def __init__(self, + name: str, + config: CompilationConfig.PassConfig, + always=False): + super().__init__(config) + self.name = name + self.always = always + + def __call__(self, graph: torch.fx.Graph): + self.dump_graph(graph, self.name, always=self.always) diff --git a/vllm/config.py b/vllm/config.py index 78d02b017350..c7108473442b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2993,13 +2993,13 @@ class PassConfig(BaseModel): Each pass defines its own stages (before, after, maybe in-between). - dump_graph_dir: directory to dump the graphs. Default is . - enable_fusion: whether to enable the custom fusion pass. - - enable_reshape: whether to enable the custom reshape elimination pass. - TODO better pass enabling system. + - enable_noop: whether to enable the custom no-op elimination pass. + TODO(luka) better pass enabling system. """ dump_graph_stages: List[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) enable_fusion: bool = True - enable_reshape: bool = True + enable_noop: bool = True def uuid(self): """ @@ -3008,13 +3008,12 @@ def uuid(self): Do not include dump_graph_* in the hash - they don't affect compilation. """ - dict_ = self.model_dump( - include={"enable_fusion", "enable_reshape"}) + dict_ = self.model_dump(include={"enable_fusion", "enable_noop"}) encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).digest() def model_post_init(self, __context: Any) -> None: - if not self.enable_reshape and self.enable_fusion: + if not self.enable_noop and self.enable_fusion: logger.warning_once( "Fusion enabled but reshape elimination disabled. " "RMSNorm + quant (fp8) fusion might not work") @@ -3411,7 +3410,7 @@ def __post_init__(self): self.compilation_config.use_inductor = True self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.pass_config.enable_fusion = False - self.compilation_config.pass_config.enable_reshape = False + self.compilation_config.pass_config.enable_noop = False self.compilation_config.level = CompilationLevel.PIECEWISE self._set_cudagraph_sizes() diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 0f93b7f6c45b..8072f307763d 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -5,6 +5,7 @@ import torch from vllm import _custom_ops as ops +from vllm.config import CompilationLevel, get_current_vllm_config from vllm.platforms import current_platform # Input scaling factors are no longer optional in _scaled_mm starting @@ -161,10 +162,14 @@ def apply_fp8_linear( # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. # This could change in the future. + # We also don't pad when using torch.compile, + # as it breaks with dynamic shapes. + config = get_current_vllm_config().compilation_config + do_pad = config.level < CompilationLevel.PIECEWISE qinput, x_scale = ops.scaled_fp8_quant( input_2d, input_scale, - num_token_padding=17, + num_token_padding=17 if do_pad else None, use_per_token_if_dynamic=use_per_token_if_dynamic) per_tensor_weights = (weight_scale.numel() == 1) From c875893f8f29fad8541e1c920cac8183ccacdb2a Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Sat, 1 Mar 2025 00:43:54 -0500 Subject: [PATCH 298/317] [Docs] Add `pipeline_parallel_size` to optimization docs (#14059) Signed-off-by: Brayden Zhong --- docs/source/performance/optimization.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/performance/optimization.md b/docs/source/performance/optimization.md index 4fbc376e1aa3..5b0f8421a51e 100644 --- a/docs/source/performance/optimization.md +++ b/docs/source/performance/optimization.md @@ -18,6 +18,7 @@ If you frequently encounter preemptions from the vLLM engine, consider the follo - Increase `gpu_memory_utilization`. The vLLM pre-allocates GPU cache by using gpu_memory_utilization% of memory. By increasing this utilization, you can provide more KV cache space. - Decrease `max_num_seqs` or `max_num_batched_tokens`. This can reduce the number of concurrent requests in a batch, thereby requiring less KV cache space. - Increase `tensor_parallel_size`. This approach shards model weights, so each GPU has more memory available for KV cache. +- Increase `pipeline_parallel_size`. This approach distributes model layers across GPUs, reducing the memory needed for model weights on each GPU, which indirectly leaves more memory available for KV cache. You can also monitor the number of preemption requests through Prometheus metrics exposed by the vLLM. Additionally, you can log the cumulative number of preemption requests by setting disable_log_stats=False. From 1a94642bc34a9dbb73dbb0f8d7b29f440d8fbbc6 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 1 Mar 2025 14:10:28 +0800 Subject: [PATCH 299/317] [Bugfix] Add file lock for ModelScope download (#14060) Signed-off-by: Jee Jee Li --- benchmarks/backend_request_func.py | 15 ++++++++----- vllm/model_executor/model_loader/loader.py | 20 ++++++++++------- .../model_loader/weight_utils.py | 5 ++++- vllm/transformers_utils/tokenizer.py | 22 ++++++++++++------- 4 files changed, 40 insertions(+), 22 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 364b087b841d..e43549c13c8e 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -14,6 +14,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) +from vllm.model_executor.model_loader.weight_utils import get_lock + AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) @@ -430,12 +432,15 @@ def get_model(pretrained_model_name_or_path: str) -> str: if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': from modelscope import snapshot_download - model_path = snapshot_download( - model_id=pretrained_model_name_or_path, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(pretrained_model_name_or_path): + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) - return model_path + return model_path return pretrained_model_name_or_path diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 46247eaf2a60..6244241d1891 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -49,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, - get_gguf_extra_tensor_names, gguf_quant_weights_iterator, + get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, runai_safetensors_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs @@ -235,13 +235,17 @@ def _maybe_download_from_modelscope( from modelscope.hub.snapshot_download import snapshot_download if not os.path.exists(model): - model_path = snapshot_download( - model_id=model, - cache_dir=self.load_config.download_dir, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - revision=revision, - ignore_file_pattern=self.load_config.ignore_patterns, - ) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model, self.load_config.download_dir): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants. + HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) else: model_path = model return model_path diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 245c199f75b1..d184079fb25d 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -8,6 +8,7 @@ import tempfile import time from collections import defaultdict +from pathlib import Path from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import filelock @@ -67,8 +68,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, disable=True) -def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): +def get_lock(model_name_or_path: Union[str, Path], + cache_dir: Optional[str] = None): lock_dir = cache_dir or temp_dir + model_name_or_path = str(model_name_or_path) os.makedirs(os.path.dirname(lock_dir), exist_ok=True) model_name = model_name_or_path.replace("/", "-") hash_name = hashlib.sha256(model_name.encode()).hexdigest() diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index f0aa5fdcaa61..2c34f2f5d44d 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -150,16 +150,22 @@ def get_tokenizer( # pylint: disable=C. from modelscope.hub.snapshot_download import snapshot_download + # avoid circuit import + from vllm.model_executor.model_loader.weight_utils import get_lock + # Only set the tokenizer here, model will be downloaded on the workers. if not os.path.exists(tokenizer_name): - tokenizer_path = snapshot_download( - model_id=tokenizer_name, - cache_dir=download_dir, - revision=revision, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - # Ignore weights - we only need the tokenizer. - ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) - tokenizer_name = tokenizer_path + # Use file lock to prevent multiple processes from + # downloading the same file at the same time. + with get_lock(tokenizer_name, download_dir): + tokenizer_path = snapshot_download( + model_id=tokenizer_name, + cache_dir=download_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + # Ignore weights - we only need the tokenizer. + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + tokenizer_name = tokenizer_path if tokenizer_mode == "slow": if kwargs.get("use_fast", False): From d92cda2383a3363bc1a90957605db9fb975b99e6 Mon Sep 17 00:00:00 2001 From: YajieWang Date: Sat, 1 Mar 2025 14:30:59 +0800 Subject: [PATCH 300/317] [Misc][Kernel]: Add GPTQAllSpark Quantization (#12931) --- CMakeLists.txt | 16 + benchmarks/kernels/benchmark_marlin.py | 47 +- .../gptq_allspark/allspark_qgemm_w8a16.cu | 1008 +++++++++++++++++ .../gptq_allspark/allspark_repack.cu | 163 +++ .../gptq_allspark/allspark_utils.cuh | 408 +++++++ csrc/torch_bindings.cpp | 19 + tests/kernels/test_allspark_gemm.py | 100 ++ tests/quantization/test_compressed_tensors.py | 2 - vllm/_custom_ops.py | 77 ++ .../kernels/mixed_precision/__init__.py | 3 + .../kernels/mixed_precision/allspark.py | 115 ++ .../quantization/utils/allspark_utils.py | 51 + 12 files changed, 2005 insertions(+), 4 deletions(-) mode change 100755 => 100644 CMakeLists.txt create mode 100644 csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu create mode 100644 csrc/quantization/gptq_allspark/allspark_repack.cu create mode 100644 csrc/quantization/gptq_allspark/allspark_utils.cuh create mode 100644 tests/kernels/test_allspark_gemm.py create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py create mode 100644 vllm/model_executor/layers/quantization/utils/allspark_utils.py diff --git a/CMakeLists.txt b/CMakeLists.txt old mode 100755 new mode 100644 index 0dd350c93ed5..c5fc2f3c1aaf --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -317,6 +317,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") " in CUDA target architectures") endif() + # Only build AllSpark kernels if we are building for at least some compatible archs. + cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") + if (ALLSPARK_ARCHS) + set(ALLSPARK_SRCS + "csrc/quantization/gptq_allspark/allspark_repack.cu" + "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") + set_gencode_flags_for_srcs( + SRCS "${ALLSPARK_SRCS}" + CUDA_ARCHS "${ALLSPARK_ARCHS}") + list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}") + message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") + else() + message(STATUS "Not building AllSpark kernels as no compatible archs found" + " in CUDA target architectures") + endif() + # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}") diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index c22e66c0b0c9..21ef491294e3 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) +from vllm.model_executor.layers.quantization.utils.allspark_utils import ( + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) @@ -18,12 +20,12 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, gptq_quantize_weights, sort_weights) + gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) from vllm.scalar_type import ScalarType from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] -DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] @@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str, GPTQ_MARLIN_24_MAX_PARALLEL) marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) + # AllSpark W8A16 quant + as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES + and group_size == -1 and not act_order and is_k_full) + if as_supported_case: + properties = torch.cuda.get_device_properties(b.device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + + supported_arch = (sm_version >= 80 and sm_version < 90) + as_supported_case = as_supported_case and supported_arch + if supported_arch: + has_zp = False + w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, + has_zp) + qw = qw.to(torch.uint8) + + qw_reorder, s_reorder, zp_reorder = \ + ops.allspark_repack_weight( + qw, s, zp, has_zp) + CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD + globals = { # Gen params "quant_type": quant_type, @@ -109,10 +132,19 @@ def bench_run(results: List[benchmark.Measurement], model: str, # GPTQ params "q_w_gptq": q_w_gptq, "repack_sort_indices": repack_sort_indices, + # AllSpark W8A16 params + "qw_reorder": qw_reorder if as_supported_case else None, + "s_reorder": s_reorder if as_supported_case else None, + "zp_reorder": zp_reorder if as_supported_case else None, + "sm_count": sm_count if as_supported_case else None, + "sm_version": sm_version if as_supported_case else None, + "CUBLAS_M_THRESHOLD": + CUBLAS_M_THRESHOLD if as_supported_case else None, # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, "gptq_marlin_repack": ops.gptq_marlin_repack, + "allspark_w8a16_gemm": ops.allspark_w8a16_gemm, } min_run_time = 1 @@ -172,6 +204,17 @@ def bench_run(results: List[benchmark.Measurement], model: str, description="gptq_marlin_repack", ).blocked_autorange(min_run_time=min_run_time)) + if as_supported_case: + results.append( + benchmark.Timer( + stmt= + "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="allspark_w8a16_gemm_fp32", + ).blocked_autorange(min_run_time=min_run_time)) + def main(args): print("Benchmarking models:") diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu new file mode 100644 index 000000000000..c4ed98ca64f8 --- /dev/null +++ b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -0,0 +1,1008 @@ +#include "allspark_utils.cuh" +#include +#include "core/registration.h" +#include + +at::Tensor as_g_workspace; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +torch::Tensor allspark_w8a16_gemm( + torch::Tensor const& a, torch::Tensor const& b_qweight, + torch::Tensor const& b_scales, c10::optional const& b_qzeros, + int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, + int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "allspark_w8a16_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else +namespace allspark { +/* + * GemmTile manage data movement from Global Memory to Shared Memory + * requiring N % 8 == 0, K % 16 == 0 by loading uint + * BN is obtained by padding the original N to a multiple of 32 + * weight B is rearranged as N32K16 order, + * i.e. a initial data block of size 32(n)x16(k) is reordered as n8k4n4k4, + * in order to put data loaded by the same thread of 32x16 data block together + * continuously (see + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type) + */ +template +struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK { + // element num loaded by a LDG inst. + static constexpr int LDG_ELEMENT_CNT_A = 8; + static constexpr int LDG_ELEMENT_CNT_B = 16; + static constexpr int WARP_SIZE = 32; + static constexpr int M_SIZE_ONE_LOAD = (BLOCK * LDG_ELEMENT_CNT_A) / 32; + static constexpr int N_SIZE_ONE_LOAD = (BLOCK * LDG_ELEMENT_CNT_B) / 32; + + __device__ GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK( + const SM8x_GEMM_W8A16_Splitk_Params& k_params, + const uint32_t& A_smem_addr, const uint32_t& BQ_smem_addr, + const uint32_t& A_stage_stride, const uint32_t& BQ_stage_stride) + : params(k_params), + A_smem_base_addr(A_smem_addr), + BQ_smem_base_addr(BQ_smem_addr), + A_smem_stage_stride(A_stage_stride), + BQ_smem_stage_stride(BQ_stage_stride) { + this_block_A_base_ptr = params.A_ptr + blockIdx.x * Mtile * params.K + + blockIdx.z * params.SplitK; + // here B is rearranged as N32K16 order, i.e. 4 continuous N-direction + // 8(N)x16(K) size data blocks are packed together + this_block_B_base_ptr = params.B_ptr + blockIdx.y * Ntile * params.K + + blockIdx.z * params.SplitK * 4; + + const int lane_id = threadIdx.x % WARP_SIZE; + + // For matrix A, a block load/store Mtile(row) x 32(col) elements in + // multiple iters, 8x4 warp load/store 8(row) x 32(col) elements per iter + const int Aldg_row_base_idx = threadIdx.x / 4; + Aldg_col_idx = (threadIdx.x % 4) * LDG_ELEMENT_CNT_A; + const int Aldg_base_offset = Aldg_row_base_idx * params.K + Aldg_col_idx; + + // For matrix B, a block load/store elements of (Ntile / 4) row x 128 col + // elements of N32K16 packing in multiple iters, 4x8 warp load/store 4(row) + // * 128(col) per iter + Bldg_col_idx = (threadIdx.x % 8) * LDG_ELEMENT_CNT_B; + const int Bldg_row_base_idx = threadIdx.x / 8; + const int Bldg_base_offset = + Bldg_row_base_idx * params.K * 4 + Bldg_col_idx; + + this_block_A_base_ptr += Aldg_base_offset; + this_block_B_base_ptr += Bldg_base_offset; + + const int sts_a_base_offset = + (threadIdx.x / 4) * 32 + + ((lane_id % 4) ^ ((lane_id / 4) % 4) ^ ((lane_id / 4) / 4)) * + LDG_ELEMENT_CNT_A; + const int sts_bq_base_offset = + Bldg_row_base_idx * 32 * 4 + + ((threadIdx.x % 8) ^ (((threadIdx.x / 8) % 2) * 4)) * LDG_ELEMENT_CNT_B; + + A_smem_base_addr += sts_a_base_offset * sizeof(FType); + BQ_smem_base_addr += sts_bq_base_offset * sizeof(uint8_t); + + A_ldg_guard = 0; + B_ldg_guard = 0; + #pragma unroll + for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; ++i) { + int m_idx = blockIdx.x * Mtile + Aldg_row_base_idx + i * M_SIZE_ONE_LOAD; + if (m_idx < params.M) { + A_ldg_guard |= (1u << i); + } + } + + const int N_padded = (params.N + 31) / 32 * 32; + #pragma unroll + for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; ++i) { + int n_idx = blockIdx.y * Ntile + (Bldg_row_base_idx / 8) * 32 + + i * N_SIZE_ONE_LOAD; + if (n_idx < N_padded) { + B_ldg_guard |= (1u << i); + } + } + } + + __device__ void ldgsts_first_ktiles(const int& first_k_tile, + const int& k_tiles) { + // load first k_tile + // load A + const int A_src_size = Aldg_col_idx < first_k_tile ? 16 : 0; + #pragma unroll + for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; ++i) { + cp_async<16>( + A_smem_base_addr + (i * M_SIZE_ONE_LOAD * 32) * sizeof(FType), + this_block_A_base_ptr + i * M_SIZE_ONE_LOAD * params.K, A_src_size, + (A_ldg_guard & (1u << i)) != 0); + } + + // load B + const int B_src_size = (Bldg_col_idx / 4) < first_k_tile ? 16 : 0; + #pragma unroll + for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; ++i) { + cp_async<16>( + BQ_smem_base_addr + (i * N_SIZE_ONE_LOAD * 32) * sizeof(uint8_t), + this_block_B_base_ptr + i * N_SIZE_ONE_LOAD * params.K, B_src_size, + (B_ldg_guard & (1u << i)) != 0); + } + + cp_async_commit_group(); + this_block_A_base_ptr += first_k_tile; + this_block_B_base_ptr += (first_k_tile * 4); + + // load second to (N-stage - 1) k_tiles + for (int stage_idx = 1; stage_idx < NStage - 1; ++stage_idx) { + if (stage_idx < k_tiles) { + #pragma unroll + for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; + ++i) { + cp_async<16>(A_smem_base_addr + stage_idx * A_smem_stage_stride + + (i * M_SIZE_ONE_LOAD * 32) * sizeof(FType), + this_block_A_base_ptr + i * M_SIZE_ONE_LOAD * params.K, + 16, (A_ldg_guard & (1u << i)) != 0); + } + + #pragma unroll + for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; + ++i) { + cp_async<16>(BQ_smem_base_addr + stage_idx * BQ_smem_stage_stride + + (i * N_SIZE_ONE_LOAD * 32) * sizeof(uint8_t), + this_block_B_base_ptr + i * N_SIZE_ONE_LOAD * params.K, + 16, (B_ldg_guard & (1u << i)) != 0); + } + + this_block_A_base_ptr += 32; + this_block_B_base_ptr += (32 * 4); + } + cp_async_commit_group(); + } + } + + __device__ void ldgsts(const int& sts_stage_idx) { + const int a_stage_offset = sts_stage_idx * A_smem_stage_stride; + const int bq_stage_offset = sts_stage_idx * BQ_smem_stage_stride; + #pragma unroll + for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; ++i) { + cp_async<16>(A_smem_base_addr + a_stage_offset + + (i * M_SIZE_ONE_LOAD * 32) * sizeof(FType), + this_block_A_base_ptr + i * M_SIZE_ONE_LOAD * params.K, 16, + (A_ldg_guard & (1u << i)) != 0); + } + + #pragma unroll + for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; ++i) { + cp_async<16>(BQ_smem_base_addr + bq_stage_offset + + (i * N_SIZE_ONE_LOAD * 32) * sizeof(uint8_t), + this_block_B_base_ptr + i * N_SIZE_ONE_LOAD * params.K, 16, + (B_ldg_guard & (1u << i)) != 0); + } + + cp_async_commit_group(); + this_block_A_base_ptr += 32; + this_block_B_base_ptr += (32 * 4); + } + + const FType* this_block_A_base_ptr = nullptr; + const QType* this_block_B_base_ptr = nullptr; + + int Aldg_col_idx; + int Bldg_col_idx; + + uint32_t A_ldg_guard; + uint32_t B_ldg_guard; + + uint32_t A_smem_base_addr, BQ_smem_base_addr; + const uint32_t A_smem_stage_stride, BQ_smem_stage_stride; + + const SM8x_GEMM_W8A16_Splitk_Params& params; +}; + +/* + * requiring N % 8 == 0 + */ +template +struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK { + static constexpr int WARP_SIZE = 32; + static constexpr int WARP_CNT = BLOCK / WARP_SIZE; + static constexpr int WARP_NTILE = Ntile / WARP_CNT; + static constexpr int WARP_NITER = WARP_NTILE / 8; // hmma16816 + static_assert(WARP_NTILE == 32 or WARP_NTILE == 64, + "now only support WARP_NTILE = 32 or 64!"); + + __device__ ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK( + const SM8x_GEMM_W8A16_Splitk_Params& k_params, + const uint32_t& A_smem_addr, const uint32_t& BQ_smem_addr, + const uint32_t& A_stage_stride, const uint32_t& BQ_stage_stride) + : params(k_params), + A_smem_base_addr(A_smem_addr), + BQ_smem_base_addr(BQ_smem_addr), + A_smem_stage_stride(A_stage_stride), + BQ_smem_stage_stride(BQ_stage_stride) { + warp_id = threadIdx.x / WARP_SIZE; + lane_id = threadIdx.x % WARP_SIZE; + + load_a_base_offset[0] = + (lane_id % 16) * 32 + + ((lane_id / 16) ^ (lane_id % 4) ^ ((lane_id / 4) % 2)) * 8; + load_a_base_offset[1] = + (lane_id % 16) * 32 + + ((lane_id / 16 + 2) ^ (lane_id % 4) ^ ((lane_id / 4) % 2)) * 8; + + load_b_base_offset[0] = + (lane_id / 4 + warp_id * (WARP_NTILE / 4)) * 32 * 4 + + (lane_id % 4) * 16 + ((lane_id / 4) % 2) * 16 * 4; + load_b_base_offset[1] = + (lane_id / 4 + warp_id * (WARP_NTILE / 4)) * 32 * 4 + + (lane_id % 4) * 16 + (((lane_id / 4) % 2) ^ 1) * 16 * 4; + + sts_c_base_offset = warp_id * Mtile * WARP_NTILE + + (lane_id / 4) * WARP_NTILE + (lane_id % 4) * 2; + + if (EnableFuse) { + this_block_C_base_ptr = + params.C_ptr + blockIdx.x * Mtile * params.N + blockIdx.y * Ntile; + } else { + this_block_C_base_ptr = + params.C_split_ptr + blockIdx.z * params.M * params.N + + blockIdx.x * Mtile * params.N + blockIdx.y * Ntile; + } + int store_thds_in_row = WARP_NTILE / 8; + store_c_row_base_idx = lane_id / store_thds_in_row; + store_c_col_idx = warp_id * WARP_NTILE + (lane_id % store_thds_in_row) * 8; + store_c_base_offset = store_c_row_base_idx * params.N + store_c_col_idx; + + #pragma unroll + for (int i = 0; i < Mtile / 16; ++i) { + #pragma unroll + for (int j = 0; j < WARP_NITER; ++j) { + #pragma unroll + for (int k = 0; k < 4; ++k) { + C_frag[i][j][k] = 0.f; + } + } + } + params_n_idx = + blockIdx.y * Ntile + warp_id * WARP_NTILE + (lane_id / 4) * 4; + } + + __device__ void lds(const int& smem_stage_idx, const int& reg_buf_idx, + const int& k_phase_idx) { + uint32_t A_smem_addr = + A_smem_base_addr + A_smem_stage_stride * smem_stage_idx; + uint32_t B_smem_addr = + BQ_smem_base_addr + BQ_smem_stage_stride * smem_stage_idx; + + #pragma unroll + for (int i = 0; i < Mtile / 16; ++i) { + ldsm_4(A_frag[reg_buf_idx][i][0], A_frag[reg_buf_idx][i][1], + A_frag[reg_buf_idx][i][2], A_frag[reg_buf_idx][i][3], + A_smem_addr + (load_a_base_offset[k_phase_idx] + i * 16 * 32) * + sizeof(FType)); + } + #pragma unroll + for (int i = 0; i < WARP_NTILE / 32; ++i) { + lds128(BQ_frag[reg_buf_idx][4 * i + 0], BQ_frag[reg_buf_idx][4 * i + 1], + BQ_frag[reg_buf_idx][4 * i + 2], BQ_frag[reg_buf_idx][4 * i + 3], + B_smem_addr + (load_b_base_offset[k_phase_idx] + i * 32 * 32) * + sizeof(uint8_t)); + } + + // dequant B + #pragma unroll + for (int i = 0; i < WARP_NITER / 2; ++i) { + cvt_8bx4_to_16bx4_bias128(BQ_frag[reg_buf_idx][2 * i], + BF_frag[reg_buf_idx][2 * i]); + if (has_zp) { + BF_frag[reg_buf_idx][2 * i][0] = + __hsub2(BF_frag[reg_buf_idx][2 * i][0], num2num2(B_zero[i].x)); + BF_frag[reg_buf_idx][2 * i][1] = + __hsub2(BF_frag[reg_buf_idx][2 * i][1], num2num2(B_zero[i].x)); + } + + BF_frag[reg_buf_idx][2 * i][0] = + __hmul2(BF_frag[reg_buf_idx][2 * i][0], num2num2(B_scale[i].x)); + BF_frag[reg_buf_idx][2 * i][1] = + __hmul2(BF_frag[reg_buf_idx][2 * i][1], num2num2(B_scale[i].x)); + + cvt_8bx4_to_16bx4_bias128(BQ_frag[reg_buf_idx][2 * i + 1], + BF_frag[reg_buf_idx][2 * i + 1]); + if (has_zp) { + BF_frag[reg_buf_idx][2 * i + 1][0] = + __hsub2(BF_frag[reg_buf_idx][2 * i + 1][0], num2num2(B_zero[i].y)); + BF_frag[reg_buf_idx][2 * i + 1][1] = + __hsub2(BF_frag[reg_buf_idx][2 * i + 1][1], num2num2(B_zero[i].y)); + } + + BF_frag[reg_buf_idx][2 * i + 1][0] = + __hmul2(BF_frag[reg_buf_idx][2 * i + 1][0], num2num2(B_scale[i].y)); + BF_frag[reg_buf_idx][2 * i + 1][1] = + __hmul2(BF_frag[reg_buf_idx][2 * i + 1][1], num2num2(B_scale[i].y)); + } + } + + __device__ void ldg_params() { + const int N_padded = (params.N + 31) / 32 * 32; + // load B scale and zero_point + #pragma unroll + for (int i = 0; i < WARP_NTILE / 32; ++i) { + ldg64_ca(B_scale[2 * i + 0], B_scale[2 * i + 1], + params.B_scale_ptr + params_n_idx + i * 32, + (params_n_idx + i * 32) < N_padded); + if (has_zp) { + ldg64_ca(B_zero[2 * i + 0], B_zero[2 * i + 1], + params.B_zero_ptr + params_n_idx + i * 32, + (params_n_idx + i * 32) < N_padded); + } + } + } + + __device__ void mma(const int& reg_buf_idx) { + #pragma unroll + for (int m_idx = 0; m_idx < Mtile / 16; ++m_idx) { + #pragma unroll + for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { + hmma16816_f32( + C_frag[m_idx][n_idx], A_frag[reg_buf_idx][m_idx], + reinterpret_cast(BF_frag[reg_buf_idx][n_idx])); + } + } + } + + __device__ void fused_splitk_reduce() { + // need splitk-reduce if enable splitk + if (gridDim.z > 1) { + int blk_red_idx = blockIdx.x * gridDim.y + blockIdx.y; + // Wait for all previous blocks in the splitk direction to accumulate the + // results into C_tmp + if (threadIdx.x == 0) { + uint32_t* red_count_ptr = params.red_count_ptr + blk_red_idx; + uint32_t count; + do { + // make sure the ld.cg inside the do-wile loop + __threadfence_block(); + asm volatile("ld.global.cg.b32 %0, [%1];" + : "=r"(count) + : "l"(red_count_ptr)); + } while (count != blockIdx.z); + } + __syncthreads(); + + int C_tmp_base_offset = blk_red_idx * Mtile * Ntile + threadIdx.x * 4; + if (blockIdx.z != 0) { + // expecting that temporary register here reuses the previous A&B frag + // register + float temp_frag[Mtile / 16][WARP_NITER][4]; + #pragma unroll + for (int m_idx = 0; m_idx < Mtile / 16; ++m_idx) { + #pragma unroll + for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { + int offset = + C_tmp_base_offset + (m_idx * WARP_NITER + n_idx) * BLOCK * 4; + *reinterpret_cast(temp_frag[m_idx][n_idx]) = + *reinterpret_cast(params.C_tmp_ptr + offset); + } + } + #pragma unroll + for (int m_idx = 0; m_idx < Mtile / 16; ++m_idx) { + #pragma unroll + for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { + #pragma unroll + for (int idx = 0; idx < 4; ++idx) { + C_frag[m_idx][n_idx][idx] += temp_frag[m_idx][n_idx][idx]; + } + } + } + } + + // first splitk - 1 blocks need to write partial results into C_tmp + if (blockIdx.z != gridDim.z - 1) { + #pragma unroll + for (int m_idx = 0; m_idx < Mtile / 16; ++m_idx) { + #pragma unroll + for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { + int offset = + C_tmp_base_offset + (m_idx * WARP_NITER + n_idx) * BLOCK * 4; + asm volatile( + "{st.global.cg.v4.b32 [%0], {%1, %2, %3, %4};}\n" + : + : "l"(params.C_tmp_ptr + offset), "f"(C_frag[m_idx][n_idx][0]), + "f"(C_frag[m_idx][n_idx][1]), "f"(C_frag[m_idx][n_idx][2]), + "f"(C_frag[m_idx][n_idx][3])); + } + } + __threadfence(); + __syncthreads(); + if (threadIdx.x == 0) { + uint32_t* red_count_ptr = params.red_count_ptr + blk_red_idx; + atomicInc(red_count_ptr, gridDim.z); + } + } + } + } + + __device__ void stg(char* smem) { + if (EnableFuse) { + if (blockIdx.z != gridDim.z - 1) return; + } + uint32_t* C_sts_ptr = + reinterpret_cast(smem + sts_c_base_offset * sizeof(FType)); + // C_tile sts + #pragma unroll + for (int m_idx = 0; m_idx < Mtile / 16; ++m_idx) { + #pragma unroll + for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { + #pragma unroll + for (int k_idx = 0; k_idx < 2; ++k_idx) { + FType low16 = static_cast(C_frag[m_idx][n_idx][k_idx * 2]); + FType high16 = + static_cast(C_frag[m_idx][n_idx][k_idx * 2 + 1]); + uint32_t tmp = (reinterpret_cast(low16) & 0xffff) | + (reinterpret_cast(high16) << 16); + int sts_offset = + m_idx * 16 * (WARP_NTILE / 2) + + (((lane_id / (32 / WARP_NITER)) + n_idx) % WARP_NITER) * (8 / 2) + + k_idx * 8 * (WARP_NTILE / 2); + C_sts_ptr[sts_offset] = tmp; + } + } + } + + __syncthreads(); + + FType* C_base_ptr = this_block_C_base_ptr + store_c_base_offset; + // C_tile lds and stg + int m_base_idx = store_c_row_base_idx + blockIdx.x * Mtile; + bool n_guard = (store_c_col_idx + blockIdx.y * Ntile) < params.N; + if (WARP_NTILE == 32) { + int lds_c_base_offset = warp_id * Mtile * WARP_NTILE + + (lane_id / 4) * WARP_NTILE + + ((lane_id % 4 + lane_id / 8) % 4) * 8; + uint4* C_lds_ptr = + reinterpret_cast(smem + lds_c_base_offset * sizeof(FType)); + #pragma unroll + for (int i = 0; i < (Mtile / 16) * (WARP_NITER / 2); ++i) { + uint4 stg_reg = C_lds_ptr[i * 8 * 4]; + stg128(stg_reg.x, stg_reg.y, stg_reg.z, stg_reg.w, + C_base_ptr + i * 8 * params.N, + (m_base_idx + i * 8) < params.M && n_guard); + } + } else if (WARP_NTILE == 64) { + int lds_c_base_offset = + warp_id * Mtile * WARP_NTILE + (lane_id / 8) * WARP_NTILE; + #pragma unroll + for (int i = 0; i < (Mtile / 16) * (WARP_NITER / 2); ++i) { + int lds_c_offset = lds_c_base_offset + i * 4 * WARP_NTILE + + ((lane_id % 8 + lane_id / 8 + (i % 2) * 4) % 8) * 8; + uint4 stg_reg = + *reinterpret_cast(smem + lds_c_offset * sizeof(FType)); + stg128(stg_reg.x, stg_reg.y, stg_reg.z, stg_reg.w, + C_base_ptr + i * 4 * params.N, + (m_base_idx + i * 4) < params.M && n_guard); + } + } + } + + const SM8x_GEMM_W8A16_Splitk_Params& params; + + int load_a_base_offset[2]; + int load_b_base_offset[2]; + int sts_c_base_offset; + + int store_c_base_offset; + + int store_c_row_base_idx, store_c_col_idx; + FType* this_block_C_base_ptr = nullptr; + + int params_n_idx; + const uint32_t A_smem_base_addr, BQ_smem_base_addr; + const uint32_t A_smem_stage_stride, BQ_smem_stage_stride; + + int lane_id; + int warp_id; + // first 2 denotes double buffer, second dim denotes M direction + uint32_t A_frag[2][Mtile / 16][4]; + + typename HalfType::T2 B_scale[WARP_NITER / 2]; + typename HalfType::T2 B_zero[WARP_NITER / 2]; + uint32_t BQ_frag[2][WARP_NITER]; + // first 2 denotes double buffer, second dim denotes N direction, last 2 + // denotes K direction + typename HalfType::T2 BF_frag[2][WARP_NITER][2]; + // first dim denotes M direction, second dim denotes N direction + float C_frag[Mtile / 16][WARP_NITER][4]; +}; + +/* + * @brief W8A16 Perchannel Quantization GEMM, + * requires N % 8 == 0, K % 16 == 0 + * accumulator precision: FP32 + * @tparam FType: DataType for A, B_scale, B_zero, and C, supports half or + * nv_bfloat16 + * @tparam QType: DataType for B, support uint8(bias128) + * @tparam Mtile: M-dimensional size of the gemm block tile, supports 16, 32, + * 48 or 64 + * @tparam Ntile: N-dimensional size of the gemm block tile, supports 128 or + * 256 + * @tparam NStage: Num of stages for async copy + * @tparam BLOCK: BLOCK size + * @tparam EnableFuse: If true, use fused splitk-reduce, otherwise use + * non-fused splitk-reduce + * @tparam has_zp: whether to use zero_point + * + * @fparam params struct consists of following parameters: + * @param A_ptr: Matrix A value ptr, A = (M, K) + * @param B_ptr: Matrix B value ptr, B = (N32_align, K) (N32K16 special + * format), N32_align = (N + 32 - 1) / 32 * 32 + * @param B_scale_ptr: B_scale value ptr, B_scale = (N32_align,) (N32K16 + * special format) + * @param B_zero_ptr: B_zero value ptr, B_zero = (N32_align,) (N32K16 + * special format) + * @param C_ptr: Matrix C value ptr, C = (M, N) + * @param M: dimnesion m + * @param N: dimnesion n + * @param K: dimnesion k + * @param SplitK: split size along K-dimension + * @param C_split_ptr: Matrix C_split value ptr, used only in non-fused + * splitk-reduce + * @param C_tmp_ptr: Matrix C_tmp value ptr, used only in fused + * splitk-reduce + * @param red_count_ptr: 1-D red_count value ptr, used only in fused + * splitk-reduce + */ +template +__global__ void __launch_bounds__(BLOCK) + ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_hmma16816_multistage_AN_BTN32K16_CN_splitk_kernel( + const SM8x_GEMM_W8A16_Splitk_Params params) { + // A smem size = 64 * 32 * 2B/elem * 4(stage) = 16KB + // B smem size = 128 * 32 * 1B/elem * 4(stage) = 16KB + constexpr int smem_size_one_stage = Mtile * 32 * 2 + Ntile * 32; + __shared__ char smem[NStage * smem_size_one_stage]; + char* A_smem = smem; + char* BQ_smem = smem + Mtile * 32 * 2 * NStage; + + uint32_t A_smem_addr = smem_u32addr(A_smem); + uint32_t BQ_smem_addr = smem_u32addr(BQ_smem); + uint32_t A_smem_stage_stride = Mtile * 32 * 2; + uint32_t BQ_smem_stage_stride = Ntile * 32; + + // initialize the data move process from GM to SMEM for this block + GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK< + FType, QType, Mtile, Ntile, NStage, BLOCK> + gmem_tile(params, A_smem_addr, BQ_smem_addr, A_smem_stage_stride, + BQ_smem_stage_stride); + + int sts_stage_idx = 0; + int lds_stage_idx = 0; + + int tb_k_slice = blockIdx.z * params.SplitK + params.SplitK <= params.K + ? params.SplitK + : params.K - blockIdx.z * params.SplitK; + int k_tiles = (tb_k_slice + 31) / 32; + int first_k_tile = tb_k_slice - (k_tiles - 1) * 32; + + // load first three tiles to shared memory + gmem_tile.ldgsts_first_ktiles(first_k_tile, k_tiles); + sts_stage_idx += (NStage - 2); + ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK< + FType, QType, Mtile, Ntile, BLOCK, EnableFuse, has_zp> + compute_tile(params, A_smem_addr, BQ_smem_addr, A_smem_stage_stride, + BQ_smem_stage_stride); + compute_tile.ldg_params(); + cp_asyc_wait_group(); + __syncthreads(); + + compute_tile.lds(lds_stage_idx, 0, 0); + int reg_buf_idx = 1; + + // main loop + for (; k_tiles > NStage - 1; --k_tiles) { + // load next A&B tile + sts_stage_idx = sts_stage_idx < NStage - 1 ? sts_stage_idx + 1 : 0; + gmem_tile.ldgsts(sts_stage_idx); + + #pragma unroll + for (int k_phase_idx = 0; k_phase_idx < 2; k_phase_idx++) { + // dequantize next B tile + if (k_phase_idx == 1) { + cp_asyc_wait_group(); + __syncthreads(); + lds_stage_idx = lds_stage_idx < NStage - 1 ? lds_stage_idx + 1 : 0; + } + + compute_tile.lds(lds_stage_idx, reg_buf_idx, (k_phase_idx + 1) % 2); + + compute_tile.mma(reg_buf_idx ^ 1); + reg_buf_idx ^= 1; + } + } + + // last NStage-1 tiles + for (; k_tiles > 0; --k_tiles) { + cp_async_commit_group(); + #pragma unroll + for (int k_phase_idx = 0; k_phase_idx < 2; k_phase_idx++) { + // dequantize next B tile + if (k_phase_idx == 1) { + cp_asyc_wait_group(); + __syncthreads(); + lds_stage_idx = lds_stage_idx < NStage - 1 ? lds_stage_idx + 1 : 0; + } + + compute_tile.lds(lds_stage_idx, reg_buf_idx, (k_phase_idx + 1) % 2); + + compute_tile.mma(reg_buf_idx ^ 1); + reg_buf_idx ^= 1; + } + } + + if (EnableFuse) { + compute_tile.fused_splitk_reduce(); + } + compute_tile.stg(smem); +} + + #define __CALL_IF(MTILE, NTILE, NUM_THREADS, ENABLE_FUSE, HAS_ZP) \ + else if (Mtile == MTILE && Ntile == NTILE && BLOCK == NUM_THREADS && \ + enable_fuse == ENABLE_FUSE && has_zp == HAS_ZP) { \ + ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_hmma16816_multistage_AN_BTN32K16_CN_splitk_kernel< \ + FType, QType, MTILE, NTILE, 4, NUM_THREADS, ENABLE_FUSE, HAS_ZP> \ + <<>>(params); \ + } + +template +void ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_mma16816_multistage_AN_BTN32K16_CN_splitk( + const FType* A, const QType* B, const FType* B_scale, const FType* B_zero, + FType* C, const int M, const int N, const int K, void* workspace, + const int sm_version, const BlockTileSplitkParams& fused_gemm_params, + cudaStream_t stream) { + int Mtile = fused_gemm_params.Mtile; + int grid_x = (M + Mtile - 1) / Mtile; + int Ntile = fused_gemm_params.Ntile; + int grid_y = (N + Ntile - 1) / Ntile; + int SplitK = fused_gemm_params.SplitK; + int grid_z = (K + SplitK - 1) / SplitK; + + int BLOCK = (Ntile == 256) ? 256 : 128; + + dim3 grid(grid_x, grid_y, grid_z); + dim3 block(BLOCK); + + bool enable_fuse = fused_gemm_params.EnableFuse; + bool has_zp = B_zero != nullptr; + if (enable_fuse) { + float* C_tmp = reinterpret_cast(workspace); + uint32_t* red_count = reinterpret_cast( + (char*)workspace + grid_x * Mtile * grid_y * Ntile * sizeof(float)); + CHECK_CUDA(cudaMemsetAsync(red_count, 0, grid_x * grid_y * sizeof(uint32_t), + stream)); + SM8x_GEMM_W8A16_Splitk_Params params{ + A, B, B_scale, B_zero, C, M, N, + K, SplitK, 0, -1, nullptr, C_tmp, red_count}; + + if (false) { + } + // Select the template parameters for kernel launch + // according to the above settings. Tuning is not supported. + __CALL_IF(16, 256, 256, true, false) + __CALL_IF(32, 256, 256, true, false) + __CALL_IF(48, 256, 256, true, false) + __CALL_IF(64, 128, 128, true, false) + __CALL_IF(64, 256, 256, true, false) + __CALL_IF(16, 256, 256, true, true) + __CALL_IF(32, 256, 256, true, true) + __CALL_IF(48, 256, 256, true, true) + __CALL_IF(64, 128, 128, true, true) + __CALL_IF(64, 256, 256, true, true) + } else { + FType* C_split = reinterpret_cast(workspace); + SM8x_GEMM_W8A16_Splitk_Params params{ + A, B, B_scale, B_zero, C, M, N, + K, SplitK, 0, -1, C_split, nullptr, nullptr}; + + if (false) { + } + // Select the template parameters for kernel launch + // according to the above settings. Tuning is not supported. + __CALL_IF(16, 256, 256, false, false) + __CALL_IF(32, 256, 256, false, false) + __CALL_IF(48, 256, 256, false, false) + __CALL_IF(64, 128, 128, false, false) + __CALL_IF(64, 256, 256, false, false) + __CALL_IF(16, 256, 256, false, true) + __CALL_IF(32, 256, 256, false, true) + __CALL_IF(48, 256, 256, false, true) + __CALL_IF(64, 128, 128, false, true) + __CALL_IF(64, 256, 256, false, true) + + // SplitK reduce + f16_gemm_splitk_reduce(C_split, C, M, N, grid_z, stream); + } +} + +size_t allspark_qgemm_w8a16_perc_n32k16_ampere_workspace_size( + int m, int n, int k, int sm_count, + BlockTileSplitkParams& fused_gemm_params) { + // Determine the block tile and splitk strategy + int m16_times = (m + 16 - 1) / 16; + int Mtile = m16_times <= 4 ? m16_times * 16 : 64; + int grid_x = (m + Mtile - 1) / Mtile; + int Ntile = + (float(grid_x * ((n + 127) / 128)) / sm_count > 10) || (Mtile < 64) ? 256 + : 128; + int grid_y = (n + Ntile - 1) / Ntile; + int grid_z; + + // split-k + const float SPLIT_THRESHOLD = 0.8; + int n_slice; + for (n_slice = 1; n_slice < k / 256; ++n_slice) { + int n_block = grid_x * grid_y * n_slice; + if (n_block >= sm_count * SPLIT_THRESHOLD && + (n_block % sm_count == 0 || n_block % sm_count >= sm_count * 0.5)) { + break; + } + } + + int k_slice = + (k / n_slice) % 32 == 0 ? k / n_slice : k / n_slice / 32 * 32 + 32; + grid_z = (k + k_slice - 1) / k_slice; + bool enable_fuse = float(grid_x * grid_y) / sm_count >= 0.5 ? 1 : 0; + + size_t ws_size; + if (enable_fuse) { + ws_size = grid_x * Mtile * grid_y * Ntile * sizeof(float) // For C_tmp + + grid_x * grid_y * sizeof(uint32_t); // For red_count + } else { + ws_size = grid_z * m * n * sizeof(__half); + } + + fused_gemm_params.Mtile = Mtile; + fused_gemm_params.Ntile = Ntile; + fused_gemm_params.SplitK = k_slice; + fused_gemm_params.EnableFuse = enable_fuse; + return ws_size; +} + +// restore from N32K16 order to original N-major order +// K % 16 == 0, N % 8 == 0 +// each block process 64(k) * 32(n) result elements +template +__global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel( + const QT* qdata, const FT* scales, const FT* zeros, FT* fdata, + const int N_32align, const int N, const int K) { + __shared__ FT smem[64 * 32]; + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + const int src_row_idx = blockIdx.x * 8 + lane_id / 4; + const int src_col_idx = + blockIdx.y * 64 * 4 + warp_id * 16 * 4 + (lane_id % 4) * 16; + const int src_offset = src_row_idx * K * 4 + src_col_idx; + int params_nidx = blockIdx.x * 32 + (lane_id / 4) * 4; + + QT qval_reg[16]; + const QT* pdata = qdata + src_offset; + if (src_col_idx < (K * 4)) { + *(reinterpret_cast(qval_reg)) = + *(reinterpret_cast(qdata + src_offset)); + } + FT scale_reg[4]; + *(reinterpret_cast(scale_reg)) = + *(reinterpret_cast(scales + params_nidx)); + FT zero_reg[4] = {0}; + if (zeros != nullptr) { + *(reinterpret_cast(zero_reg)) = + *(reinterpret_cast(zeros + params_nidx)); + } + FT fval_reg[16]; + + const int sts_base_offset = + (warp_id * 16 + (lane_id % 4) * 2) * 32 + lane_id / 4; + #pragma unroll + for (int ni = 0; ni < 4; ++ni) { + cvt_8bx4_to_16bx4_bias128( + *reinterpret_cast(&qval_reg[ni * 4]), + reinterpret_cast::T2*>(&(fval_reg[ni * 4]))); + #pragma unroll + for (int ki = 0; ki < 4; ++ki) { + fval_reg[ni * 4 + ki] = + (fval_reg[ni * 4 + ki] - zero_reg[ni]) * scale_reg[ni]; + int sts_offset = sts_base_offset + ((ki / 2) * 8 + (ki % 2)) * 32 + + ((ni + lane_id % 4) % 4) * 8; + smem[sts_offset] = fval_reg[ni * 4 + ki]; + } + } + __syncthreads(); + + const int lds_base_offset = + (threadIdx.x / 4) * 32 + ((threadIdx.x % 4 + threadIdx.x / 8) % 4) * 8; + #pragma unroll + for (int i = 0; i < 2; ++i) { + *reinterpret_cast(fval_reg + i * 8) = + *reinterpret_cast(smem + lds_base_offset + i * 32 * 32); + } + + const int dst_row_base_kidx = blockIdx.y * 64 + threadIdx.x / 4; + const int dst_col_nidx = blockIdx.x * 32 + (threadIdx.x % 4) * 8; + #pragma unroll + for (int i = 0; i < 2; ++i) { + int dst_row_kidx = dst_row_base_kidx + i * 32; + int dst_offset = dst_row_kidx * N + dst_col_nidx; + if (dst_row_kidx < K && dst_col_nidx < N) { + *reinterpret_cast(fdata + dst_offset) = + *reinterpret_cast(fval_reg + i * 8); + } + } +} + +template +void restore_N32_K16_dequantize_rhs_w8a16(const QT* qdata, const FT* scales, + const FT* zeros, FT* fdata, + const int N_32align, const int N, + const int K, const int GroupSize, + cudaStream_t stream) { + TORCH_CHECK(N % 8 == 0 && K % 16 == 0 && N_32align % 32 == 0, + "Unsupported shape"); + if (GroupSize == -1) { + const int BLOCK = 128; + dim3 grid(N_32align / 32, ((K / 16) + 3) / 4); + restore_N32_K16_dequantize_rhs_w8a16_perc_kernel + <<>>(qdata, scales, zeros, fdata, N_32align, N, + K); + } + // TODO: Support SubChannel + else { + TORCH_CHECK(false, "Now only support PerChannel"); + } +} + +template +void w8a16_gemm_dq_cublas(const FT* in, const QT* rhs_qdata_ptr, + const FT* rhs_scales_ptr, const FT* rhs_zeros_ptr, + FT* out, void* workspace, const int M, + const int N_32align, const int N, const int K, + const int group_size, cudaStream_t stream, + cublasHandle_t handle) { + static_assert( + std::is_same::value || std::is_same::value, + "only float16 and bfloat16 is supported"); + // Dequant + FT* rhs_fdata_ptr = static_cast(workspace); + restore_N32_K16_dequantize_rhs_w8a16(rhs_qdata_ptr, rhs_scales_ptr, + rhs_zeros_ptr, rhs_fdata_ptr, N_32align, + N, K, group_size, stream); + // cuBLAS GEMM + int lda = K; + int ldb = N; + int ldc = N; + const float alpha = 1.0f; + const float beta = 0.0f; + cudaDataType_t cuda_type; + if (std::is_same::value) { + cuda_type = CUDA_R_16F; + } else { + cuda_type = CUDA_R_16BF; + } + CHECK_CUBLAS(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, + rhs_fdata_ptr, cuda_type, ldb, in, cuda_type, lda, + &beta, out, cuda_type, ldc, CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); +} + +template +void allspark_qgemm_w8a16_perc_ampere( + const FType* A, const QType* B, const FType* B_scale, const FType* B_zero, + FType* C, const int M, const int N_32align, const int N, const int K, + void* workspace, const BlockTileSplitkParams& fused_gemm_params, + const int group_size, int CUBLAS_M_THRESHOLD, const int sm_version, + cudaStream_t stream, cublasHandle_t handle) { + if (M > CUBLAS_M_THRESHOLD) { + w8a16_gemm_dq_cublas(A, B, B_scale, B_zero, C, workspace, M, + N_32align, N, K, group_size, stream, + handle); + } else { + ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_mma16816_multistage_AN_BTN32K16_CN_splitk< + FType, QType>(A, B, B_scale, B_zero, C, M, N, K, workspace, sm_version, + fused_gemm_params, stream); + } +} + +} // namespace allspark + +torch::Tensor allspark_w8a16_gemm( + torch::Tensor const& a, torch::Tensor const& b_qweight, + torch::Tensor const& b_scales, c10::optional const& b_qzeros, + int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, + int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); + TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + if (has_zp) { + TORCH_CHECK(b_qzeros.value().device().is_cuda(), "b_qzeros is not on GPU"); + TORCH_CHECK(b_qzeros.value().is_contiguous(), "b_qzeros is not contiguous"); + } + + int m = a.size(0); + int n_32align = (n + 32 - 1) / 32 * 32; + int k = a.size(1); + + // Verify shape + TORCH_CHECK(b_qweight.size(0) == n_32align, + "Shape mismatch: b_qweight.size(0) = ", b_qweight.size(0), + ", n_32align = ", n_32align); + TORCH_CHECK(b_qweight.size(1) == k, + "Shape mismatch: b_qweight.size(1) = ", b_qweight.size(1), + ", k = ", k); + + TORCH_CHECK(group_size == -1, "Currently only supports group_size = -1"); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + const void* a_ptr = reinterpret_cast(a.data_ptr()); + const uint8_t* b_ptr = reinterpret_cast(b_qweight.data_ptr()); + const void* b_scale_ptr = reinterpret_cast(b_scales.data_ptr()); + const void* b_zero_ptr = nullptr; + if (b_qzeros.has_value()) { + b_zero_ptr = reinterpret_cast(b_qzeros.value().data_ptr()); + } + + auto c_options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({m, n}, c_options); + void* c_ptr = reinterpret_cast(c.data_ptr()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + + allspark::BlockTileSplitkParams fused_gemm_params; + + size_t ws_size = 0; + if (m > CUBLAS_M_THRESHOLD) { + ws_size = k * n * 2; // sizeof(f16)==2 + } else { + ws_size = allspark::allspark_qgemm_w8a16_perc_n32k16_ampere_workspace_size( + m, n, k, sm_count, fused_gemm_params); + } + + auto ws_options = torch::TensorOptions().dtype(at::kChar).device(a.device()); + if (as_g_workspace.numel() < + ws_size) { // ws_options: kChar, so numel() is bytes + as_g_workspace = torch::empty({long(ws_size)}, ws_options); + } + void* ws = reinterpret_cast(as_g_workspace.data_ptr()); + + if (a.dtype() == at::ScalarType::Half) { + allspark::allspark_qgemm_w8a16_perc_ampere<__half, uint8_t>( + reinterpret_cast(a_ptr), b_ptr, + reinterpret_cast(b_scale_ptr), + reinterpret_cast(b_zero_ptr), + reinterpret_cast<__half*>(c_ptr), m, n_32align, n, k, ws, + fused_gemm_params, group_size, CUBLAS_M_THRESHOLD, sm_version, stream, + handle); + } else if (a.dtype() == at::ScalarType::BFloat16) { + allspark::allspark_qgemm_w8a16_perc_ampere<__nv_bfloat16, uint8_t>( + reinterpret_cast(a_ptr), b_ptr, + reinterpret_cast(b_scale_ptr), + reinterpret_cast(b_zero_ptr), + reinterpret_cast<__nv_bfloat16*>(c_ptr), m, n_32align, n, k, ws, + fused_gemm_params, group_size, CUBLAS_M_THRESHOLD, sm_version, stream, + handle); + } + + return c; +} + +#endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("allspark_w8a16_gemm", &allspark_w8a16_gemm); +} \ No newline at end of file diff --git a/csrc/quantization/gptq_allspark/allspark_repack.cu b/csrc/quantization/gptq_allspark/allspark_repack.cu new file mode 100644 index 000000000000..82929c94ad8b --- /dev/null +++ b/csrc/quantization/gptq_allspark/allspark_repack.cu @@ -0,0 +1,163 @@ +#include "allspark_utils.cuh" +#include +#include "core/registration.h" + +namespace allspark { + +// Rearrange B to facilitate Ampere Tensor Core load data +// reorder B from (K, N) to (N_32align / 4, K * 4) +// K % 16 == 0, N % 16 == 0, N_32align % 32 == 0 +template +__global__ void __launch_bounds__(128) + rearrange_kn_weight_as_n32k16_order_ldg16_kernel( + const uint8_t* B, const FType* B_scale, const FType* B_zero, + uint8_t* B_result, FType* B_scale_result, FType* B_zero_result, + const int K, const int N, const int N_32align) { + const int lane_id = threadIdx.x % 32; + const int warp_id = threadIdx.x / 32; + + if (blockIdx.x != gridDim.x - 1) { + // Load B + // per block process 64(k) * 128(n) B elements + // per warp process 16(k) * 128 B elements + const int src_row_base_idx = + blockIdx.x * 64 + warp_id * 16 + ((lane_id % 8) / 2) * 2; + const int src_col_idx = + blockIdx.y * 128 + (lane_id / 8) * 32 + (lane_id % 2) * 16; + uint8_t B_frag[4][16]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + int src_row_idx = src_row_base_idx + (i / 2) * 8 + (i % 2); + int src_offset = src_row_idx * N + src_col_idx; + bool guard = src_row_idx < K && src_col_idx < N; + ldg128_cg_0(*reinterpret_cast(B_frag[i]), + *(reinterpret_cast(B_frag[i]) + 1), + *(reinterpret_cast(B_frag[i]) + 2), + *(reinterpret_cast(B_frag[i]) + 3), B + src_offset, + guard); + } + + // reorder B + uint8_t B_reorder_frag[8][8]; +#pragma unroll + for (int i = 0; i < 4; ++i) { +#pragma unroll + for (int j = 0; j < 16; ++j) { + int dst_i = j % 8; + int dst_j = i + (j / 8) * 4; + B_reorder_frag[dst_i][dst_j] = B_frag[i][j]; + } + } + + // Store B + const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8; + const int dst_col_idx = + blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8; + for (int i = 0; i < 8; ++i) { + int dst_row_idx = dst_row_base_idx + i; + int dst_offset = dst_row_idx * K * 4 + dst_col_idx; + bool guard = (dst_row_base_idx < N_32align / 4) && (dst_col_idx < K * 4); + if (guard) { + *reinterpret_cast(B_result + dst_offset) = + *reinterpret_cast(B_reorder_frag[i]); + } + } + } else { + // Load B_scale and B_zero + FType b_scale_reg, b_zero_reg; + int src_offset = blockIdx.y * 128 + threadIdx.x; + ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N); + if (B_zero != nullptr) + ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N); + int dst_offset = + blockIdx.y * 128 + warp_id * 32 + (lane_id % 8) * 4 + lane_id / 8; + if (dst_offset < N_32align) { + B_scale_result[dst_offset] = b_scale_reg; + if (B_zero != nullptr) B_zero_result[dst_offset] = b_zero_reg; + } + } +} + +template +void rearrange_kn_weight_as_n32k16_order_ldg16( + const uint8_t* B, const FType* B_scale, const FType* B_zero, + uint8_t* B_result, FType* B_scale_result, FType* B_zero_result, + const int64_t K, const int64_t N, const int64_t N_32align, + cudaStream_t stream) { + if (N % 16 != 0 || K % 16 != 0) { + std::cerr << "Now only support N and K is multiples of 16" << std::endl; + } + const int BLOCK = 128; + int grid_x = (K + 64 - 1) / 64 + 1; + int grid_y = (N + 128 - 1) / 128; + dim3 grid(grid_x, grid_y); + + rearrange_kn_weight_as_n32k16_order_ldg16_kernel + <<>>(B, B_scale, B_zero, B_result, B_scale_result, + B_zero_result, K, N, N_32align); +} +} // namespace allspark + +void rearrange_kn_weight_as_n32k16_order( + torch::Tensor const& b_qweight, torch::Tensor const& b_scales, + c10::optional const& b_zeros, bool has_zp, + torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder, + c10::optional const& b_zeros_reorder, const int64_t K, + const int64_t N, const int64_t N_32align) { + // Verify device and strides + TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); + TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + TORCH_CHECK(b_qweight_reorder.device().is_cuda(), + "b_qweight_reorder is not on GPU"); + TORCH_CHECK(b_qweight_reorder.is_contiguous(), + "b_qweight_reorder is not contiguous"); + + TORCH_CHECK(b_scales_reorder.device().is_cuda(), + "b_scales_reorder is not on GPU"); + TORCH_CHECK(b_scales_reorder.is_contiguous(), + "b_scales_reorder is not contiguous"); + + if (has_zp) { + TORCH_CHECK(b_zeros.value().device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.value().is_contiguous(), "b_zeros is not contiguous"); + + TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(), + "b_zeros_reorder is not on GPU"); + TORCH_CHECK(b_zeros_reorder.value().is_contiguous(), + "b_zeros_reorder is not contiguous"); + } + + const uint8_t* matB = reinterpret_cast(b_qweight.data_ptr()); + const void* b_scale = b_scales.data_ptr(); + const void* b_zero = has_zp ? b_zeros.value().data_ptr() : nullptr; + + uint8_t* matB_reorder = + reinterpret_cast(b_qweight_reorder.data_ptr()); + void* b_scale_reorder = b_scales_reorder.data_ptr(); + void* b_zero_reorder = has_zp ? b_zeros_reorder.value().data_ptr() : nullptr; + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (b_scales.dtype() == at::ScalarType::Half) { + allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__half>( + matB, reinterpret_cast(b_scale), + reinterpret_cast(b_zero), matB_reorder, + reinterpret_cast<__half*>(b_scale_reorder), + reinterpret_cast<__half*>(b_zero_reorder), K, N, N_32align, stream); + } else if (b_scales.dtype() == at::ScalarType::BFloat16) { + allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__nv_bfloat16>( + matB, reinterpret_cast(b_scale), + reinterpret_cast(b_zero), matB_reorder, + reinterpret_cast<__nv_bfloat16*>(b_scale_reorder), + reinterpret_cast<__nv_bfloat16*>(b_zero_reorder), K, N, N_32align, + stream); + } +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("rearrange_kn_weight_as_n32k16_order", + &rearrange_kn_weight_as_n32k16_order); +} diff --git a/csrc/quantization/gptq_allspark/allspark_utils.cuh b/csrc/quantization/gptq_allspark/allspark_utils.cuh new file mode 100644 index 000000000000..7aded9a17280 --- /dev/null +++ b/csrc/quantization/gptq_allspark/allspark_utils.cuh @@ -0,0 +1,408 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace allspark { + +#define CHECK_CUDA(cmd) \ + do { \ + cudaError_t cuda_status = cmd; \ + if (cuda_status != cudaSuccess) { \ + std::string err_str = cudaGetErrorString(cuda_status); \ + std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \ + << err_str; \ + exit(-1); \ + } \ + } while (0) + +#define CHECK_CUBLAS(cmd) \ + do { \ + cublasStatus_t cublas_status = cmd; \ + if (cublas_status != CUBLAS_STATUS_SUCCESS) { \ + std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \ + << cublas_status << std::endl; \ + exit(-1); \ + } \ + } while (0) + +template +struct SM8x_GEMM_W8A16_Splitk_Params { + const FType* A_ptr; + const QType* B_ptr; + const FType* B_scale_ptr; + const FType* B_zero_ptr; + FType* C_ptr; + int M; + int N; + int K; + int SplitK; + int GroupCnt; + int GroupSize; + FType* C_split_ptr; // for non-fused splitk reduce + float* C_tmp_ptr; // for fused splitk reduce + uint32_t* red_count_ptr; // for fused splitk reduce +}; + +struct alignas(16) BlockTileSplitkParams { + int Mtile; + int Ntile; + int SplitK; + bool EnableFuse; +}; + +template +__global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C, + uint32_t n, uint32_t n_matrix, + uint32_t matrix_size) { + int idx = blockIdx.x * BLOCK + threadIdx.x; + + if (idx >= matrix_size) { + return; + } + + FType sum(0); + + int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix; + for (int i = 0; i < n_mat; ++i) { + sum += C_split[idx + i * matrix_size]; + } + + C[idx] = sum; +} + +template +void f16_gemm_splitk_reduce(const FType* C_split, FType* C, const uint32_t m, + const uint32_t n, const uint32_t n_matrix, + cudaStream_t stream) { + const int BLOCK = 128; + uint32_t matrix_size = m * n; + int grid = (matrix_size + BLOCK - 1) / BLOCK; + + void (*kernel)(const FType*, FType*, uint32_t, uint32_t, uint32_t) = nullptr; + + switch (n_matrix) { + case 4: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 5: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 6: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 7: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 8: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 9: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 10: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 11: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 12: + kernel = f16_gemm_splitk_reduce_kernel; + break; + default: + kernel = f16_gemm_splitk_reduce_kernel; + break; + } + + kernel<<>>(C_split, C, n, n_matrix, matrix_size); +} + +template +struct HalfType; +template <> +struct HalfType { + using T1 = __half; + using T2 = __half2; +}; +template <> +struct HalfType<__nv_bfloat16> { + using T1 = __nv_bfloat16; + using T2 = __nv_bfloat162; +}; + +// convert 64-bit pointer to 32-bit smem addr +__device__ __forceinline__ uint32_t smem_u32addr(const void* smem_ptr) { + uint32_t addr; + asm("{.reg .u64 u64addr;\n" + " cvta.to.shared.u64 u64addr, %1;\n" + " cvt.u32.u64 %0, u64addr;}\n" + : "=r"(addr) + : "l"(smem_ptr)); + + return addr; +} + +template +__device__ __forceinline__ void ldg16_cg_0(T& r0, const void* ptr, bool guard) { + static_assert(sizeof(T) == 2, "ldg16_cg_0: invalid T"); + + asm volatile( + "{.reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " @!p mov.b16 %0, 0;\n" +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \ + __CUDA_ARCH__ >= 750 + " @p ld.global.cg.L2::128B.b16 {%0}, [%1];}\n" +#else + " @p ld.global.ca.b16 {%0}, [%1];}\n" +#endif + : "=h"(reinterpret_cast(r0)) + : "l"(ptr), "r"((int)guard)); +} + +template +__device__ __forceinline__ void ldg64_ca(T& r0, T& r1, const void* ptr, + bool guard) { + static_assert(sizeof(T) == 4, "ldg64_ca: invalid T"); + + asm volatile( + "{.reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \ + __CUDA_ARCH__ >= 750 + " @p ld.global.ca.L2::128B.v2.b32 {%0, %1}, [%2];}\n" +#else + " @p ld.global.ca.v2.b32 {%0, %1}, [%2];}\n" +#endif + : "=r"(reinterpret_cast(r0)), + "=r"(reinterpret_cast(r1)) + : "l"(ptr), "r"((int)guard)); +} + +template +__device__ __forceinline__ void ldg128_cg_0(T& r0, T& r1, T& r2, T& r3, + const void* ptr, bool guard) { + static_assert(sizeof(T) == 4, "ldg128_cg_0: invalid T"); + + asm volatile( + "{.reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " @!p mov.b32 %0, 0;\n" + " @!p mov.b32 %1, 0;\n" + " @!p mov.b32 %2, 0;\n" + " @!p mov.b32 %3, 0;\n" +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \ + __CUDA_ARCH__ >= 750 + " @p ld.global.cg.L2::128B.v4.b32 {%0, %1, %2, %3}, [%4];}\n" +#else + " @p ld.global.cg.v4.b32 {%0, %1, %2, %3}, [%4];}\n" +#endif + : "=r"(reinterpret_cast(r0)), + "=r"(reinterpret_cast(r1)), + "=r"(reinterpret_cast(r2)), + "=r"(reinterpret_cast(r3)) + : "l"(ptr), "r"((int)guard)); +} + +template +__device__ __forceinline__ void lds128(T& reg0, T& reg1, T& reg2, T& reg3, + const uint32_t addr) { + static_assert(sizeof(T) == 4, "lds128: invalid T"); + + asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(reinterpret_cast(reg0)), + "=r"(reinterpret_cast(reg1)), + "=r"(reinterpret_cast(reg2)), + "=r"(reinterpret_cast(reg3)) + : "r"(addr)); +} + +template +__device__ __forceinline__ void stg128(const T& r0, const T& r1, const T& r2, + const T& r3, const void* ptr, + bool guard) { + static_assert(sizeof(T) == 4, "stg128: invalid T"); + + asm volatile( + "{.reg .pred p;\n" + " setp.ne.b32 p, %1, 0;\n" + " @p st.global.v4.b32 [%0], {%2, %3, %4, %5};}\n" + : + : "l"(ptr), "r"((int)guard), "r"(reinterpret_cast(r0)), + "r"(reinterpret_cast(r1)), + "r"(reinterpret_cast(r2)), + "r"(reinterpret_cast(r3))); +} + +template +__device__ __forceinline__ void ldsm_4(T& r0, T& r1, T& r2, T& r3, + const uint32_t& addr) { + static_assert(sizeof(T) == 4, "ldsm_4: invalid T"); +#if (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 11) + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(reinterpret_cast(r0)), + "=r"(reinterpret_cast(r1)), + "=r"(reinterpret_cast(r2)), + "=r"(reinterpret_cast(r3)) + : "r"(addr)); +#endif +} + +template +__device__ __forceinline__ void hmma16816_f32(float (&d)[4], + const uint32_t (&a)[4], + const uint32_t (&b)[2]); + +template <> +__device__ __forceinline__ void hmma16816_f32<__half>(float (&d)[4], + const uint32_t (&a)[4], + const uint32_t (&b)[2]) { +#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n" + : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1])); +#endif +} + +template <> +__device__ __forceinline__ void hmma16816_f32<__nv_bfloat16>( + float (&d)[4], const uint32_t (&a)[4], const uint32_t (&b)[2]) { +#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n" + : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1])); +#endif +} + +template +__device__ __forceinline__ void cp_async(const uint32_t smem_addr, + const void* gmem_ptr, + const int src_in_bytes, bool guard) { + static_assert( + (SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16), + "Size is not supported"); +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile( + "{.reg.pred p;\n" + " setp.ne.b32 p, %4, 0;\n" + #if __CUDACC_VER_MINOR__ >= 4 + " @p cp.async.cg.shared.global.L2::256B [%0], [%1], %2, %3;}\n" + #else + " @p cp.async.cg.shared.global [%0], [%1], %2, %3;}\n" + #endif + ::"r"(smem_addr), + "l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard)); +#endif +} + +template +__device__ __forceinline__ void cp_async_ca(const uint32_t smem_addr, + const void* gmem_ptr, + const int src_in_bytes, + bool guard) { + static_assert( + (SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16), + "Size is not supported"); +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile( + "{.reg.pred p;\n" + " setp.ne.b32 p, %4, 0;\n" + #if __CUDACC_VER_MINOR__ >= 4 + " @p cp.async.ca.shared.global.L2::256B [%0], [%1], %2, %3;}\n" + #else + " @p cp.async.ca.shared.global [%0], [%1], %2, %3;}\n" + #endif + ::"r"(smem_addr), + "l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard)); +#endif +} + +__device__ __forceinline__ void cp_async_commit_group() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.commit_group;\n"); +#endif +} + +template +__device__ __forceinline__ void cp_asyc_wait_group() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_group %0;\n" : : "n"(N)); +#endif +} + +template +__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128(const uint32_t& idata, + T* fdata); + +template <> +// fast conversion: 4xuint8 to 4xhalf, subtracting bias = 128 +__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__half2>( + const uint32_t& idata, __half2* fdata) { + uint32_t i10, i32; + asm volatile( + "prmt.b32 %0, %2, 0x64, 0x4140;" + "prmt.b32 %1, %2, 0x64, 0x4342;" + : "=r"(i10), "=r"(i32) + : "r"(idata)); + + static constexpr uint32_t MAGIC_NUM = 0x64806480; + fdata[0] = __hsub2(reinterpret_cast(i10), + reinterpret_cast(MAGIC_NUM)); + fdata[1] = __hsub2(reinterpret_cast(i32), + reinterpret_cast(MAGIC_NUM)); +} + +template <> +// fast conversion: 4xuint8 to 4xbfloat16, subtracting bias = 128 +// reference from marlin fast implementation +__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__nv_bfloat162>( + const uint32_t& idata, __nv_bfloat162* fdata) { + float fp32_imd[4]; + uint32_t* fp32_imd_casted = reinterpret_cast(fp32_imd); + asm volatile( + "prmt.b32 %0, %4, 0x4B000000, 0x7650;" + "prmt.b32 %1, %4, 0x4B000000, 0x7651;" + "prmt.b32 %2, %4, 0x4B000000, 0x7652;" + "prmt.b32 %3, %4, 0x4B000000, 0x7653;" + : "=r"(fp32_imd_casted[0]), "=r"(fp32_imd_casted[1]), + "=r"(fp32_imd_casted[2]), "=r"(fp32_imd_casted[3]) + : "r"(idata)); + + fp32_imd[0] -= 8388736.f; + fp32_imd[1] -= 8388736.f; + fp32_imd[2] -= 8388736.f; + fp32_imd[3] -= 8388736.f; + + uint32_t* bf16_res = reinterpret_cast(fdata); + asm volatile( + "prmt.b32 %0, %2, %3, 0x7632;" + "prmt.b32 %1, %4, %5, 0x7632;" + : "=r"(bf16_res[0]), "=r"(bf16_res[1]) + : "r"(fp32_imd_casted[0]), "r"(fp32_imd_casted[1]), + "r"(fp32_imd_casted[2]), "r"(fp32_imd_casted[3])); +} + +static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat162bfloat162(x); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); +} + +} // namespace allspark \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 72de2035d0c1..0b0334f84efe 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor!? azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); + +#ifndef USE_ROCM + // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel + ops.def( + "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, " + "Tensor? b_zeros, " + "bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, " + "Tensor!? b_zeros_reorder, " + "int K, int N, int N_32align) -> ()"); + // conditionally compiled so impl in source file + + // AllSpark quantization ops + ops.def( + "allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, " + "Tensor? b_qzeros, " + "SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt " + "CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"); + // conditionally compiled so impl in source file +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/tests/kernels/test_allspark_gemm.py b/tests/kernels/test_allspark_gemm.py new file mode 100644 index 000000000000..896e0265738b --- /dev/null +++ b/tests/kernels/test_allspark_gemm.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.allspark_utils import ( + ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + ALLSPARK_AMPERE_N_ALIGN) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + + +def is_gptq_allspark_supported(min_capability: int, + max_capability: int) -> bool: + if not current_platform.is_cuda(): + return False + + capability = current_platform.get_device_capability() + assert capability is not None + + return capability.to_int() >= min_capability \ + and capability.to_int() <= max_capability + + +MNK_FACTORS = [ + (1, 4, 8), + (13, 17, 67), + (26, 37, 13), + (48, 16, 24), + (67, 13, 88), + (257, 13, 11), + (658, 13, 11), + (1033, 9, 17), +] + +DTYPES = [torch.float16, torch.bfloat16] +HAS_ZP_OPTS = [False, True] + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +def rand_data(shape, dtype=torch.float16): + return torch.randn(shape, dtype=dtype, device="cuda") + + +@pytest.mark.skipif( + not is_gptq_allspark_supported(80, 89), + reason="AllSpark Ampere kernel is not supported on this GPU type.") +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("group_size", [-1]) +@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): + m_factor, n_factor, k_factor = mnk_factors + m = m_factor + n = n_factor * ALLSPARK_AMPERE_N_ALIGN + k = k_factor * ALLSPARK_AMPERE_K_ALIGN + + input = rand_data((m, k), dtype=dtype) + weight = rand_data((k, n), dtype=dtype) + + # Quantize (and apply act_order if provided) + w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128, + group_size, has_zp) + + qw = qw.to(torch.uint8) + if has_zp: + zp = zp.to(dtype) + properties = torch.cuda.get_device_properties(qw.device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + + n_32align = (n + 32 - 1) // 32 * 32 + + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( + qw, s, zp, has_zp) + opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order, + (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, + n_32align)) + + opcheck(torch.ops._C.allspark_w8a16_gemm, + (input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count, + sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True), + test_utils=DEFAULT_OPCHECK_TEST_UTILS) + output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder, + n, group_size, sm_count, sm_version, + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp, True) + + output_ref = torch.matmul(input, w_ref) + torch.cuda.synchronize() + max_diff = compute_max_diff(output, output_ref) + + assert max_diff < 0.04 diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index c187b4c7ed99..b9b2b634e0bb 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -215,8 +215,6 @@ def check_model(model): assert qkv_proj.scheme.group_size == (-1 if group is None else group) - assert qkv_proj.weight_packed.dtype is torch.int32 - assert qkv_proj.weight_scale.dtype is torch.float16 assert qkv_proj.scheme.pack_factor == pack_factor llm.apply_model(check_model) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0e83bcaead94..373f92a52a19 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -404,6 +404,22 @@ def machete_prepack_B_fake( memory_format=torch.contiguous_format) +if hasattr(torch.ops._C, "allspark_w8a16_gemm"): + + @register_fake("_C::allspark_w8a16_gemm") + def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + n: torch.SymInt, group_size: torch.SymInt, + sm_count: torch.SymInt, + sm_version: torch.SymInt, + CUBLAS_M_THRESHOLD: torch.SymInt, + has_zp: bool, + n32k16_reorder: bool) -> torch.Tensor: + m = a.size(0) + return torch.empty((m, n), device=a.device, dtype=a.dtype) + + if hasattr(torch.ops._C, "ggml_dequantize"): @register_fake("_C::ggml_dequantize") @@ -881,6 +897,67 @@ def scaled_fp8_quant( return output, scale +# gptq allspark +def allspark_repack_weight( + qweight: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor] = None, + has_zp: bool = False +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format + for Ampere W8A16 Fused Gemm kernel + + Args: + qweight: uint8 weight tensor, original k x n format. + scale: fp16/bf16 weight scale tensor, 1 x n format. + zero_point: fp16/bf16 weight zero_point tensor, 1 x n format. + Must be provided for asymmetric quantization. + has_zp: if use symmetric quantization, has_zp = False. + if use asymmetric quantization, has_zp = True. + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : + rearranged weight, scale, and optionally zero_point. + """ + K = qweight.shape[0] + N = qweight.shape[1] + N_32align = (N + 32 - 1) // 32 * 32 + + qweight_reorder = torch.empty((N_32align, K), + device=qweight.device, + dtype=qweight.dtype) + scale_reorder = torch.empty((1, N_32align), + device=scale.device, + dtype=scale.dtype) + zero_point_reorder = None + if has_zp: + assert zero_point is not None, ( + "zero_point must be provided for asymmetric quantization.") + zero_point_reorder = torch.empty((1, N_32align), + device=zero_point.device, + dtype=zero_point.dtype) + + torch.ops._C.rearrange_kn_weight_as_n32k16_order( + qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder, + zero_point_reorder, K, N, N_32align) + + return qweight_reorder, scale_reorder, zero_point_reorder + + +def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], n: int, + group_size: int, sm_count: int, sm_version: int, + CUBLAS_M_THRESHOLD: int, has_zp: bool, + n32k16_reorder: bool) -> torch.Tensor: + + return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros, + n, group_size, sm_count, + sm_version, CUBLAS_M_THRESHOLD, + has_zp, n32k16_reorder) + + # int8 def scaled_int8_quant( input: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index bcfdb1677716..520e1bc96721 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -3,6 +3,8 @@ from typing import List, Optional, Type import vllm.envs as envs +from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 + AllSparkLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 ExllamaLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 @@ -16,6 +18,7 @@ # in priority/performance order (when available) _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ MacheteLinearKernel, + AllSparkLinearKernel, MarlinLinearKernel, ExllamaLinearKernel, ] diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py new file mode 100644 index 000000000000..56fdd6a18e0d --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.allspark_utils import ( + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class AllSparkLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx: + return False, "Act reordering currently not supported by AllSpark" + + if c.zero_points: + return False, "Zero points currently not supported by AllSpark" + + return check_allspark_supported_dtype_shape( + c.partition_weight_shape[0], # in_features + c.partition_weight_shape[1], # out_features + c.group_size, + c.weight_type, + c.act_type) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + # prepare the parameters required for the kernel + properties = torch.cuda.get_device_properties(device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + gemm_args = {} + gemm_args['sm_count'] = sm_count + gemm_args['sm_version'] = sm_version + + self.gemm_args = gemm_args + + # transform param weight, scale + old_weight_param = getattr(layer, self.w_q_name) + old_scale_param = getattr(layer, self.w_s_name) + + assert isinstance(old_weight_param, BasevLLMParameter) + permute_param_layout_(old_weight_param, + input_dim=0, + output_dim=1, + packed_dim=0) + + assert isinstance(old_scale_param, BasevLLMParameter) + permute_param_layout_(old_scale_param, input_dim=0, output_dim=1) + + # unpack weight from K / 4 x N int32 to K x N uint8 + new_weight_param = torch.nn.Parameter(old_weight_param.data, + requires_grad=False) + new_weight_param.data = new_weight_param.data.t().contiguous().view( + dtype=torch.uint8) + new_weight_param.data = new_weight_param.data.t().contiguous() + + new_scale_param = torch.nn.Parameter(old_scale_param.data, + requires_grad=False) + + # reorder K x N weight as N32K16 format for Ampere W8A16 + new_weight_param.data, new_scale_param.data, _ = \ + ops.allspark_repack_weight( + new_weight_param.data, new_scale_param.data, None, + c.zero_points) + + replace_parameter(layer, self.w_q_name, new_weight_param.data) + replace_parameter(layer, self.w_s_name, new_scale_param.data) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + gemm_args = self.gemm_args + w_q, w_s, _, _ = self._get_weight_params(layer) + + reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + output = ops.allspark_w8a16_gemm( + a=reshaped_x, + b_qweight=w_q, + b_scales=w_s, + b_qzeros=None, + n=c.partition_weight_shape[1], + group_size=c.group_size, + sm_count=gemm_args['sm_count'], + sm_version=gemm_args['sm_version'], + CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp=c.zero_points, + n32k16_reorder=True) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/allspark_utils.py b/vllm/model_executor/layers/quantization/utils/allspark_utils.py new file mode 100644 index 000000000000..97860765a9e1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/allspark_utils.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024 +ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128] +ALLSPARK_AMPERE_N_ALIGN = 16 +ALLSPARK_AMPERE_K_ALIGN = 16 + + +def check_allspark_supported_dtype_shape(input_size_per_partition: int, + output_size_per_partition: int, + group_size: int, + weight_dtype: ScalarType, + act_dtype: torch.dtype): + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + # For Ampere GPU + if device_capability >= 80 and device_capability < 90: + if group_size != -1: + return False, \ + "For Ampere GPU, AllSpark does not support group_size "\ + f"= {group_size}. Only group_size = -1 are supported." + + if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES: + return False, "For Ampere GPU, AllSpark does not support "\ + f"quant type ({weight_dtype}). Only quant type "\ + f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported." + + if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \ + or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0: + return False, \ + "AllSpark needs input_size_per_partition % "\ + f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\ + f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\ + "for Ampere GPU optimized kernels." + + if act_dtype != torch.float16 and act_dtype != torch.bfloat16: + return False, \ + "AllSpark only supports act_dtype = float16 or bfloat16,"\ + f"for Ampere GPU, but got act_dtype = {act_dtype}." + else: + return False, "AllSpark currently does not support "\ + f"device_capability = {device_capability}." + + return True, None From 39a2024dc13f6cf05a1e152f4e805ec9c35c9608 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Sat, 1 Mar 2025 14:31:01 +0800 Subject: [PATCH 301/317] [Bugfix][V1][Minor] Fix shutting_down flag checking in V1 MultiprocExecutor (#14053) --- vllm/v1/executor/multiproc_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index d4582122fa6d..25b5c1c1c2fc 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -170,7 +170,7 @@ def _cleanup_sockets(self): def shutdown(self): """Properly shut down the executor and its workers""" - if getattr(self, 'shutting_down', False): + if not getattr(self, 'shutting_down', False): self.shutting_down = True for w in self.workers: w.worker_response_mq = None From 0daae74245ae01c6b604c8be297d8a11faaa7bbc Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 28 Feb 2025 22:44:24 -0800 Subject: [PATCH 302/317] [Documentation] Add more deployment guide for Kubernetes deployment (#13841) Signed-off-by: KuntaiDu Signed-off-by: Kuntai Du --- docs/source/deployment/integrations/index.md | 1 + .../integrations/production-stack.md | 154 ++++++++++++++++++ docs/source/deployment/k8s.md | 18 +- 3 files changed, 166 insertions(+), 7 deletions(-) create mode 100644 docs/source/deployment/integrations/production-stack.md diff --git a/docs/source/deployment/integrations/index.md b/docs/source/deployment/integrations/index.md index a557456c086d..410742b88c73 100644 --- a/docs/source/deployment/integrations/index.md +++ b/docs/source/deployment/integrations/index.md @@ -7,4 +7,5 @@ kserve kubeai llamastack llmaz +production-stack ::: diff --git a/docs/source/deployment/integrations/production-stack.md b/docs/source/deployment/integrations/production-stack.md new file mode 100644 index 000000000000..e66e8e6a16b2 --- /dev/null +++ b/docs/source/deployment/integrations/production-stack.md @@ -0,0 +1,154 @@ +(deployment-production-stack)= + +# Production stack + +Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine learning models. This guide walks you through deploying vLLM using the [vLLM production stack](https://github.com/vllm-project/production-stack). Born out of a Berkeley-UChicago collaboration, [vLLM production stack](https://github.com/vllm-project/production-stack) is an officially released, production-optimized codebase under the [vLLM project](https://github.com/vllm-project), designed for LLM deployment with: + +* **Upstream vLLM compatibility** – It wraps around upstream vLLM without modifying its code. +* **Ease of use** – Simplified deployment via Helm charts and observability through Grafana dashboards. +* **High performance** – Optimized for LLM workloads with features like multi-model support, model-aware and prefix-aware routing, fast vLLM bootstrapping, and KV cache offloading with [LMCache](https://github.com/LMCache/LMCache), among others. + +If you are new to Kubernetes, don't worry: in the vLLM production stack [repo](https://github.com/vllm-project/production-stack), we provide a step-by-step [guide](https://github.com/vllm-project/production-stack/blob/main/tutorials/00-install-kubernetes-env.md) and a [short video](https://www.youtube.com/watch?v=EsTJbQtzj0g) to set up everything and get started in **4 minutes**! + +## Pre-requisite + +Ensure that you have a running Kubernetes environment with GPU (you can follow [this tutorial](https://github.com/vllm-project/production-stack/blob/main/tutorials/00-install-kubernetes-env.md) to install a Kubernetes environment on a bare-medal GPU machine). + +## Deployment using vLLM production stack + +The standard vLLM production stack install uses a Helm chart. You can run this [bash script](https://github.com/vllm-project/production-stack/blob/main/tutorials/install-helm.sh) to install Helm on your GPU server. + +To install the vLLM production stack, run the following commands on your desktop: + +```bash +sudo helm repo add vllm https://vllm-project.github.io/production-stack +sudo helm install vllm vllm/vllm-stack -f tutorials/assets/values-01-minimal-example.yaml +``` + +This will instantiate a vLLM-production-stack-based deployment named `vllm` that runs a small LLM (Facebook opt-125M model). + +### Validate Installation + +Monitor the deployment status using: + +```bash +sudo kubectl get pods +``` + +And you will see that pods for the `vllm` deployment will transit to `Running` state. + +```text +NAME READY STATUS RESTARTS AGE +vllm-deployment-router-859d8fb668-2x2b7 1/1 Running 0 2m38s +vllm-opt125m-deployment-vllm-84dfc9bd7-vb9bs 1/1 Running 0 2m38s +``` + +**NOTE**: It may take some time for the containers to download the Docker images and LLM weights. + +### Send a Query to the Stack + +Forward the `vllm-router-service` port to the host machine: + +```bash +sudo kubectl port-forward svc/vllm-router-service 30080:80 +``` + +And then you can send out a query to the OpenAI-compatible API to check the available models: + +```bash +curl -o- http://localhost:30080/models +``` + +Expected output: + +```json +{ + "object": "list", + "data": [ + { + "id": "facebook/opt-125m", + "object": "model", + "created": 1737428424, + "owned_by": "vllm", + "root": null + } + ] +} +``` + +To send an actual chatting request, you can issue a curl request to the OpenAI `/completion` endpoint: + +```bash +curl -X POST http://localhost:30080/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "facebook/opt-125m", + "prompt": "Once upon a time,", + "max_tokens": 10 + }' +``` + +Expected output: + +```json +{ + "id": "completion-id", + "object": "text_completion", + "created": 1737428424, + "model": "facebook/opt-125m", + "choices": [ + { + "text": " there was a brave knight who...", + "index": 0, + "finish_reason": "length" + } + ] +} +``` + +### Uninstall + +To remove the deployment, run: + +```bash +sudo helm uninstall vllm +``` + +------ + +### (Advanced) Configuring vLLM production stack + +The core vLLM production stack configuration is managed with YAML. Here is the example configuration used in the installation above: + +```yaml +servingEngineSpec: + runtimeClassName: "" + modelSpec: + - name: "opt125m" + repository: "vllm/vllm-openai" + tag: "latest" + modelURL: "facebook/opt-125m" + + replicaCount: 1 + + requestCPU: 6 + requestMemory: "16Gi" + requestGPU: 1 + + pvcStorage: "10Gi" +``` + +In this YAML configuration: +* **`modelSpec`** includes: + * `name`: A nickname that you prefer to call the model. + * `repository`: Docker repository of vLLM. + * `tag`: Docker image tag. + * `modelURL`: The LLM model that you want to use. +* **`replicaCount`**: Number of replicas. +* **`requestCPU` and `requestMemory`**: Specifies the CPU and memory resource requests for the pod. +* **`requestGPU`**: Specifies the number of GPUs required. +* **`pvcStorage`**: Allocates persistent storage for the model. + +**NOTE:** If you intend to set up two pods, please refer to this [YAML file](https://github.com/vllm-project/production-stack/blob/main/tutorials/assets/values-01-2pods-minimal-example.yaml). + +**NOTE:** vLLM production stack offers many more features (*e.g.* CPU offloading and a wide range of routing algorithms). Please check out these [examples and tutorials](https://github.com/vllm-project/production-stack/tree/main/tutorials) and our [repo](https://github.com/vllm-project/production-stack) for more details! diff --git a/docs/source/deployment/k8s.md b/docs/source/deployment/k8s.md index cbc95c20ff4b..64071ba042d0 100644 --- a/docs/source/deployment/k8s.md +++ b/docs/source/deployment/k8s.md @@ -2,17 +2,21 @@ # Using Kubernetes -Using Kubernetes to deploy vLLM is a scalable and efficient way to serve machine learning models. This guide will walk you through the process of deploying vLLM with Kubernetes, including the necessary prerequisites, steps for deployment, and testing. +Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine learning models. This guide walks you through deploying vLLM using native Kubernetes. -## Prerequisites +-------- -Before you begin, ensure that you have the following: +Alternatively, you can also deploy Kubernetes using [helm chart](https://docs.vllm.ai/en/latest/deployment/frameworks/helm.html). There are also open-source projects available to make your deployment even smoother. -- A running Kubernetes cluster -- NVIDIA Kubernetes Device Plugin (`k8s-device-plugin`): This can be found at `https://github.com/NVIDIA/k8s-device-plugin/` -- Available GPU resources in your cluster +* [vLLM production-stack](https://github.com/vllm-project/production-stack): Born out of a Berkeley-UChicago collaboration, vLLM production stack is a project that contains latest research and community effort, while still delivering production-level stability and performance. Checkout the [documentation page](https://docs.vllm.ai/en/latest/deployment/integrations/production-stack.html) for more details and examples. -## Deployment Steps +-------- + +## Pre-requisite + +Ensure that you have a running Kubernetes environment with GPU (you can follow [this tutorial](https://github.com/vllm-project/production-stack/blob/main/tutorials/00-install-kubernetes-env.md) to install a Kubernetes environment on a bare-medal GPU machine). + +## Deployment using native K8s 1. Create a PVC, Secret and Deployment for vLLM From b469f95d0c2f3e4afae3e12ec75db66c9e4f868b Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 1 Mar 2025 14:49:15 +0800 Subject: [PATCH 303/317] [Doc] Consolidate `whisper` and `florence2` examples (#14050) --- examples/offline_inference/audio_language.py | 82 +++++---- .../encoder_decoder_multimodal.py | 158 ++++++++++++++++++ .../offline_inference/florence2_inference.py | 53 ------ examples/offline_inference/whisper.py | 61 ------- vllm/model_executor/models/whisper.py | 4 +- 5 files changed, 210 insertions(+), 148 deletions(-) create mode 100644 examples/offline_inference/encoder_decoder_multimodal.py delete mode 100644 examples/offline_inference/florence2_inference.py delete mode 100644 examples/offline_inference/whisper.py diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 3e3034a02f0f..1ceec026b319 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -24,25 +24,30 @@ # Unless specified, these settings have been tested to work on a single L4. -# Ultravox 0.5-1B -def run_ultravox(question: str, audio_count: int): - model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" +# MiniCPM-O +def run_minicpmo(question: str, audio_count: int): + model_name = "openbmb/MiniCPM-o-2_6" + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + llm = LLM(model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}) - tokenizer = AutoTokenizer.from_pretrained(model_name) + stop_tokens = ['<|im_end|>', '<|endoftext|>'] + stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + + audio_placeholder = "()" * audio_count + audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501 messages = [{ 'role': 'user', - 'content': "<|audio|>\n" * audio_count + question + 'content': f'{audio_placeholder}\n{question}' }] prompt = tokenizer.apply_chat_template(messages, tokenize=False, - add_generation_prompt=True) - - llm = LLM(model=model_name, - max_model_len=4096, - max_num_seqs=5, - trust_remote_code=True, - limit_mm_per_prompt={"audio": audio_count}) - stop_token_ids = None + add_generation_prompt=True, + chat_template=audio_chat_template) return llm, prompt, stop_token_ids @@ -68,36 +73,49 @@ def run_qwen2_audio(question: str, audio_count: int): return llm, prompt, stop_token_ids -def run_minicpmo(question: str, audio_count: int): - model_name = "openbmb/MiniCPM-o-2_6" - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) - llm = LLM(model=model_name, - trust_remote_code=True, - max_model_len=4096, - max_num_seqs=5, - limit_mm_per_prompt={"audio": audio_count}) - - stop_tokens = ['<|im_end|>', '<|endoftext|>'] - stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] +# Ultravox 0.5-1B +def run_ultravox(question: str, audio_count: int): + model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" - audio_placeholder = "()" * audio_count - audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501 + tokenizer = AutoTokenizer.from_pretrained(model_name) messages = [{ 'role': 'user', - 'content': f'{audio_placeholder}\n{question}' + 'content': "<|audio|>\n" * audio_count + question }] prompt = tokenizer.apply_chat_template(messages, tokenize=False, - add_generation_prompt=True, - chat_template=audio_chat_template) + add_generation_prompt=True) + + llm = LLM(model=model_name, + max_model_len=4096, + max_num_seqs=5, + trust_remote_code=True, + limit_mm_per_prompt={"audio": audio_count}) + stop_token_ids = None + return llm, prompt, stop_token_ids + + +# Whisper +def run_whisper(question: str, audio_count: int): + assert audio_count == 1, ( + "Whisper only support single audio input per prompt") + model_name = "openai/whisper-large-v3-turbo" + + prompt = "<|startoftranscript|>" + + llm = LLM(model=model_name, + max_model_len=448, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}) + stop_token_ids = None return llm, prompt, stop_token_ids model_example_map = { - "ultravox": run_ultravox, + "minicpmo": run_minicpmo, "qwen2_audio": run_qwen2_audio, - "minicpmo": run_minicpmo + "ultravox": run_ultravox, + "whisper": run_whisper, } diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py new file mode 100644 index 000000000000..f44bc423658e --- /dev/null +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to use vLLM for running offline inference with +the explicit/implicit prompt format on enc-dec LMMs for text generation. +""" +import time + +from vllm import LLM, SamplingParams +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.utils import FlexibleArgumentParser + + +def run_florence2(): + # Create a Florence-2 encoder/decoder model instance + llm = LLM( + model="microsoft/Florence-2-large", + tokenizer="facebook/bart-large", + max_num_seqs=8, + trust_remote_code=True, + limit_mm_per_prompt={"image": 1}, + dtype="half", + ) + + prompts = [ + { # implicit prompt with task token + "prompt": "", + "multi_modal_data": { + "image": ImageAsset("stop_sign").pil_image + }, + }, + { # explicit encoder/decoder prompt + "encoder_prompt": { + "prompt": "Describe in detail what is shown in the image.", + "multi_modal_data": { + "image": ImageAsset("cherry_blossom").pil_image + }, + }, + "decoder_prompt": "", + }, + ] + return llm, prompts + + +def run_mllama(): + # Create a Mllama encoder/decoder model instance + llm = LLM( + model="meta-llama/Llama-3.2-11B-Vision-Instruct", + max_model_len=4096, + max_num_seqs=2, + limit_mm_per_prompt={"image": 1}, + dtype="half", + ) + + prompts = [ + { # Implicit prompt + "prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501 + "multi_modal_data": { + "image": ImageAsset("stop_sign").pil_image, + }, + }, + { # Explicit prompt + "encoder_prompt": { + "prompt": "<|image|>", + "multi_modal_data": { + "image": ImageAsset("stop_sign").pil_image, + }, + }, + "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501 + }, + ] + return llm, prompts + + +def run_whisper(): + # Create a Whisper encoder/decoder model instance + llm = LLM( + model="openai/whisper-large-v3-turbo", + max_model_len=448, + max_num_seqs=16, + limit_mm_per_prompt={"audio": 1}, + dtype="half", + ) + + prompts = [ + { # Test implicit prompt + "prompt": "<|startoftranscript|>", + "multi_modal_data": { + "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, + }, + }, + { # Test explicit encoder/decoder prompt + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": AudioAsset("winning_call").audio_and_sample_rate, + }, + }, + "decoder_prompt": "<|startoftranscript|>", + } + ] + return llm, prompts + + +model_example_map = { + "florence2": run_florence2, + "mllama": run_mllama, + "whisper": run_whisper, +} + + +def main(args): + model = args.model_type + if model not in model_example_map: + raise ValueError(f"Model type {model} is not supported.") + + llm, prompts = model_example_map[model]() + + # Create a sampling params object. + sampling_params = SamplingParams( + temperature=0, + top_p=1.0, + max_tokens=64, + ) + + start = time.time() + + # Generate output tokens from the prompts. The output is a list of + # RequestOutput objects that contain the prompt, generated + # text, and other information. + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Decoder prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") + + duration = time.time() - start + + print("Duration:", duration) + print("RPS:", len(prompts) / duration) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'vision language models for text generation') + parser.add_argument('--model-type', + '-m', + type=str, + default="mllama", + choices=model_example_map.keys(), + help='Huggingface "model_type".') + + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference/florence2_inference.py b/examples/offline_inference/florence2_inference.py deleted file mode 100644 index 27aceee43cbf..000000000000 --- a/examples/offline_inference/florence2_inference.py +++ /dev/null @@ -1,53 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Demonstrate prompting of text-to-text -encoder/decoder models, specifically Florence-2 -""" -# TODO(Isotr0py): -# Move to offline_inference/vision_language.py -# after porting vision backbone -from vllm import LLM, SamplingParams -from vllm.assets.image import ImageAsset - -# Create a Florence-2 encoder/decoder model instance -llm = LLM( - model="microsoft/Florence-2-large", - tokenizer="facebook/bart-large", - max_num_seqs=8, - trust_remote_code=True, -) - -prompts = [ - { # implicit prompt with task token - "prompt": "", - "multi_modal_data": { - "image": ImageAsset("stop_sign").pil_image - }, - }, - { # explicit encoder/decoder prompt - "encoder_prompt": { - "prompt": "Describe in detail what is shown in the image.", - "multi_modal_data": { - "image": ImageAsset("cherry_blossom").pil_image - }, - }, - "decoder_prompt": "", - }, -] -# Create a sampling params object. -sampling_params = SamplingParams( - temperature=0, - top_p=1.0, - min_tokens=0, - max_tokens=128, -) - -# Generate output tokens from the prompts. The output is a list of -# RequestOutput objects that contain the prompt, generated -# text, and other information. -outputs = llm.generate(prompts, sampling_params) - -# Print the outputs. -for output in outputs: - generated_text = output.outputs[0].text - print(f"Generated text: {generated_text!r}") diff --git a/examples/offline_inference/whisper.py b/examples/offline_inference/whisper.py deleted file mode 100644 index 59c119a772da..000000000000 --- a/examples/offline_inference/whisper.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import time - -from vllm import LLM, SamplingParams -from vllm.assets.audio import AudioAsset - -# Create a Whisper encoder/decoder model instance -llm = LLM( - model="openai/whisper-large-v3", - max_model_len=448, - max_num_seqs=400, - limit_mm_per_prompt={"audio": 1}, - kv_cache_dtype="fp8", -) - -prompts = [ - { - "prompt": "<|startoftranscript|>", - "multi_modal_data": { - "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, - }, - }, - { # Test explicit encoder/decoder prompt - "encoder_prompt": { - "prompt": "", - "multi_modal_data": { - "audio": AudioAsset("winning_call").audio_and_sample_rate, - }, - }, - "decoder_prompt": "<|startoftranscript|>", - } -] * 1024 - -# Create a sampling params object. -sampling_params = SamplingParams( - temperature=0, - top_p=1.0, - max_tokens=200, -) - -start = time.time() - -# Generate output tokens from the prompts. The output is a list of -# RequestOutput objects that contain the prompt, generated -# text, and other information. -outputs = llm.generate(prompts, sampling_params) - -# Print the outputs. -for output in outputs: - prompt = output.prompt - encoder_prompt = output.encoder_prompt - generated_text = output.outputs[0].text - print(f"Encoder prompt: {encoder_prompt!r}, " - f"Decoder prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") - -duration = time.time() - start - -print("Duration:", duration) -print("RPS:", len(prompts) / duration) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 656e5fc6dcf3..c5a55e300c46 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -748,11 +748,11 @@ def _create_fake_bias_for_k_proj( weights: Iterable[Tuple[str, torch.Tensor]] ) -> Iterable[Tuple[str, torch.Tensor]]: """ - Create full zeros bias for k_proj weight in self-attention layers. + Create full zeros bias for k_proj weight in self-attn and x-attn layers. So that the bias for k_proj in qkv_proj can be initialized with zeros. """ for name, weight in weights: - if name.endswith(".self_attn.k_proj.weight"): + if name.endswith(".k_proj.weight"): bias = torch.zeros(weight.size(0)) bias_name = name.replace("weight", "bias") yield from [(name, weight), (bias_name, bias)] From 8aea81ef7ed1eff4eff258900dc3ad55d5bcb402 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 28 Feb 2025 23:09:14 -0800 Subject: [PATCH 304/317] [V1][Minor] Do not print attn backend twice (#13985) Signed-off-by: Woosuk Kwon --- vllm/platforms/cuda.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2a4cac46c066..bffa113cab89 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -178,7 +178,8 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, block_size) else: if use_v1: - logger.info("Using FlashMLA backend on V1 engine.") + logger.info_once( + "Using FlashMLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." "flashmla.FlashMLABackend") else: @@ -187,14 +188,14 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, "flashmla.FlashMLABackend") if use_v1: - logger.info("Using Triton MLA backend on V1 engine.") + logger.info_once("Using Triton MLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." "triton_mla.TritonMLABackend") else: logger.info("Using Triton MLA backend.") return "vllm.attention.backends.triton_mla.TritonMLABackend" if use_v1: - logger.info("Using Flash Attention backend on V1 engine.") + logger.info_once("Using Flash Attention backend on V1 engine.") return ("vllm.v1.attention.backends.flash_attn." "FlashAttentionBackend") if selected_backend == _Backend.FLASHINFER: From 2658c530e309db92301378d752cb4df98b464a6c Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 28 Feb 2025 23:18:32 -0800 Subject: [PATCH 305/317] [ROCm][V1][Bugfix] Add get_builder_cls method to the ROCmAttentionBackend class (#14065) Signed-off-by: Sage Moore --- vllm/v1/attention/backends/rocm_attn.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 0f3fabf05fc2..5c7d759b1812 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -9,7 +9,8 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.logger import init_logger -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionMetadata, FlashAttentionMetadataBuilder) logger = init_logger(__name__) @@ -49,6 +50,10 @@ def get_kv_cache_shape( def use_cascade_attention(*args, **kwargs) -> bool: return False + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + class ROCmAttentionImpl(AttentionImpl): From b029cc442a40aad294bd6d43b8a044575a61ca9f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 1 Mar 2025 16:25:54 +0800 Subject: [PATCH 306/317] [v1][Bugfix] Only cache blocks that are not in the prefix cache (#14073) --- vllm/v1/core/block_pool.py | 22 ++++------------------ vllm/v1/core/kv_cache_manager.py | 9 +++++---- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 5ef495c7eed8..1b5c7f96f668 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -107,34 +107,20 @@ def cache_full_blocks( assert prev_block.block_hash is not None prev_block_hash_value = prev_block.block_hash.hash_value - # Find the first uncached block. - # FIXME: num_cached_blocks should be corrected by the caller - # so this should never happen. - offset = 0 - for blk in new_full_blocks: - if blk.block_hash is None: - break - else: - prev_block_hash_value = blk.block_hash.hash_value - offset += 1 - else: - # All blocks are cached. - return - - for i, blk in enumerate(new_full_blocks[offset:]): - blk_idx = num_cached_blocks + offset + i + for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None - if i + offset < len(new_block_hashes): + if i < len(new_block_hashes): # The block hash may already be computed in # "get_computed_blocks" if the tokens are not generated by # this request (either the prompt tokens or the previously # generated tokens with preemption). In this case we simply # reuse the block hash. - block_hash = new_block_hashes[i + offset] + block_hash = new_block_hashes[i] else: # Otherwise compute the block hash and cache it in the request # in case it will be preempted in the future. + blk_idx = num_cached_blocks + i start_token_idx = blk_idx * block_size end_token_idx = (blk_idx + 1) * block_size block_tokens = request.all_token_ids[ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index fc7bfa0eff57..030574de2bde 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -65,7 +65,7 @@ def __init__( # This is used to track the number of cached blocks for each request. # This is only used to track the RUNNING requests, we do not track the # data for reempted ones. - self.num_cached_block: Dict[str, int] = defaultdict(int) + self.num_cached_block: Dict[str, int] = {} self.prefix_cache_stats = PrefixCacheStats() @property @@ -224,9 +224,10 @@ def allocate_slots( if not self.enable_caching: return new_blocks - # FIXME: `num_cached_blocks` is not correct when the prefix cache - # of a new request is hit. - num_cached_blocks = self.num_cached_block[request.request_id] + # Use `new_computed_blocks` for a new request, and `num_cached_block` + # for a running request. + num_cached_blocks = self.num_cached_block.get(request.request_id, + len(new_computed_blocks)) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. From a7ca7a60fecbf7b818e10b9eb3775cd738a5c01d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 2 Mar 2025 04:46:02 +0800 Subject: [PATCH 307/317] [v1] Add `__repr__` to KVCacheBlock to avoid recursive print (#14081) --- vllm/v1/core/kv_cache_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index e3eb6b24c195..546fddf67f41 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -128,6 +128,19 @@ def reset_hash(self): """Reset the block hash when the block is evicted.""" self._block_hash = None + def __repr__(self) -> str: + # Use block_id instead of KVCacheBlock object to avoid calling __repr__ + # on KVCacheBlock object recursively. + prev_block_id = self.prev_free_block.block_id \ + if self.prev_free_block else None + next_block_id = self.next_free_block.block_id \ + if self.next_free_block else None + return (f"KVCacheBlock(block_id={self.block_id}, " + f"ref_cnt={self.ref_cnt}, " + f"_block_hash={self._block_hash}, " + f"prev_free_block={prev_block_id}, " + f"next_free_block={next_block_id})") + class FreeKVCacheBlockQueue: """This class organizes a list of KVCacheBlock objects to a doubly linked From 615e4922b68f2c243cac4401714cbc7e70154eee Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 2 Mar 2025 09:17:34 +0800 Subject: [PATCH 308/317] [Model] Add LoRA support for TransformersModel (#13770) Signed-off-by: Jee Jee Li --- .buildkite/test-pipeline.yaml | 3 +- docs/source/models/supported_models.md | 15 +-- tests/lora/conftest.py | 5 + tests/lora/test_transfomers_model.py | 120 +++++++++++++++++++++ vllm/lora/layers.py | 25 +++-- vllm/lora/utils.py | 25 +++-- vllm/model_executor/models/transformers.py | 43 ++------ 7 files changed, 166 insertions(+), 70 deletions(-) create mode 100644 tests/lora/test_transfomers_model.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 05c4d2616990..d0f5c94ffd8d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -275,7 +275,7 @@ steps: source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py parallelism: 4 - label: PyTorch Fullgraph Smoke Test # 9min @@ -589,6 +589,7 @@ steps: - pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py - pytest -v -s -x lora/test_minicpmv_tp.py + - pytest -v -s -x lora/test_transfomers_model.py - label: Weight Loading Multiple GPU Test # 33min diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 4b1f3e180ed5..0e93a15b84fc 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -62,20 +62,7 @@ Transformers fallback has supported most of available quantization in vLLM (exce ##### LoRA -LoRA hasn't supported on transformers fallback yet! Make sure to open an issue and we'll work on this together with the `transformers` team! - -Usually `transformers` model load weights via the `load_adapters` API, that depends on PEFT. We need to work a bit to either use this api (for now this would result in some weights not being marked as loaded) or replace modules accordingly. - -Hints as to how this would look like: - -```python -class TransformersModel(nn.Module, SupportsLoRA): - def __init__(*): - ... - self.model.load_adapter(vllm_config.load_config.model_loader_extra_config["qlora_adapter_name_or_path"]) -``` - -Blocker is that you need to specify supported lora layers, when we would ideally want to load whatever is inside the checkpoint! +Transformers fallback has supported LoRA. The usage way is identical to how LoRA works with models supported by vLLM. If you encounter any issues, please open an issue. ##### Remote code diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 489ffc7d3257..944f1c011708 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -240,6 +240,11 @@ def baichuan_regex_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex") +@pytest.fixture(scope="session") +def ilama_lora_files(): + return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider") + + @pytest.fixture(scope="session") def minicpmv_lora_files(): return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon") diff --git a/tests/lora/test_transfomers_model.py b/tests/lora/test_transfomers_model.py new file mode 100644 index 000000000000..07af1e9f449d --- /dev/null +++ b/tests/lora/test_transfomers_model.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +import pytest + +import vllm +from tests.utils import fork_new_process_for_each_test +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "ArthurZ/ilama-3.2-1B" + +PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 + +EXPECTED_LORA_OUTPUT = [ + "SELECT count(*) FROM singer", + "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 + "SELECT DISTINCT Country FROM singer WHERE Age > 20", +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + prompts = [ + PROMPT_TEMPLATE.format(query="How many singers do we have?"), + PROMPT_TEMPLATE.format( + query= + "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + query= + "What are all distinct countries where singers above age 20 are from?" # noqa: E501 + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.mark.skip_v1 +@fork_new_process_for_each_test +def test_ilama_lora(ilama_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=1, + trust_remote_code=True, + enable_chunked_prefill=True) + + output1 = do_sample(llm, ilama_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + output2 = do_sample(llm, ilama_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] + + +@pytest.mark.skip_v1 +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_ilama_lora_tp4(ilama_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=False, + enable_chunked_prefill=True) + + output1 = do_sample(llm, ilama_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + output2 = do_sample(llm, ilama_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] + + +@pytest.mark.skip_v1 +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_ilama_lora_tp4_fully_sharded_loras(ilama_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True, + enable_chunked_prefill=True) + output1 = do_sample(llm, ilama_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + output2 = do_sample(llm, ilama_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 7b718458c70d..e527addc99f9 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -408,6 +408,11 @@ def apply(self, return output + @classmethod + def get_source_layer(cls, source_layer: nn.Module) -> type: + # Check parent_cls in case source_layer is a HFCompatibleLinear. + return getattr(source_layer, "parent_cls", type(source_layer)) + class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): @@ -450,7 +455,8 @@ def can_replace_layer( packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: - return type(source_layer) is ReplicatedLinear + source_layer = cls.get_source_layer(source_layer) + return source_layer is ReplicatedLinear class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): @@ -546,8 +552,9 @@ def can_replace_layer( packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: - return type(source_layer) is ColumnParallelLinear or ( - type(source_layer) is MergedColumnParallelLinear + source_layer = cls.get_source_layer(source_layer) + return source_layer is ColumnParallelLinear or ( + source_layer is MergedColumnParallelLinear and len(packed_modules_list) == 1) @@ -689,7 +696,8 @@ def can_replace_layer( packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: - return (type(source_layer) is MergedColumnParallelLinear + source_layer = cls.get_source_layer(source_layer) + return (source_layer is MergedColumnParallelLinear and len(packed_modules_list) == 2) @@ -757,7 +765,8 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is QKVParallelLinear and len( + source_layer = cls.get_source_layer(source_layer) + return source_layer is QKVParallelLinear and len( packed_modules_list) == 1 @@ -818,7 +827,8 @@ def can_replace_layer( packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: - return (type(source_layer) is QKVParallelLinear + source_layer = cls.get_source_layer(source_layer) + return (source_layer is QKVParallelLinear and len(packed_modules_list) == 3) @@ -903,7 +913,8 @@ def can_replace_layer( packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: - return type(source_layer) is RowParallelLinear + source_layer = cls.get_source_layer(source_layer) + return source_layer is RowParallelLinear class LogitsProcessorWithLoRA(BaseLayerWithLoRA): diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 63b465fdf743..9f1b14b49704 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -66,17 +66,20 @@ def from_layer(layer: nn.Module, lora_config=lora_config, packed_modules_list=packed_modules_list, model_config=model_config): - ret = lora_cls(layer) - ret.create_lora_weights(max_loras, lora_config, model_config) - return ret - - # The Case for HFCompatibleLinear - if (hasattr(layer, "get_lora_class") - and layer.__class__.__name__ == "HFCompatibleLinear"): - lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras) - ret = lora_cls(layer) - ret.create_lora_weights(max_loras, lora_config, model_config) - return ret + instance_layer = lora_cls(layer) + if layer.__class__.__name__ == "HFCompatibleLinear": + # HACK: Make the forward method compatible with the original + # forward method of the instance_layer. + original_forward = instance_layer.forward + + def new_forward(input): + input = input.squeeze(0) + return original_forward(input)[0] # noqa: B023 + + instance_layer.forward = new_forward + instance_layer.create_lora_weights(max_loras, lora_config, + model_config) + return instance_layer return layer diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 1c3c443b2941..61cfc566dd31 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -27,11 +27,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.utils import divide from vllm.logger import init_logger -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA) -from vllm.lora.layers import (ColumnParallelLinearWithLoRA, - ReplicatedLinearWithLoRA, - RowParallelLinearWithLoRA) from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -43,7 +38,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsQuant +from .interfaces import SupportsLoRA, SupportsQuant from .utils import maybe_prefix logger = init_logger(__name__) @@ -102,44 +97,18 @@ def replace_linear_class( "rowwise": RowParallelLinear, }.get(style, ReplicatedLinear) - lora_linear_cls = { - ColumnParallelLinear: { - True: ColumnParallelLinearWithShardedLoRA, # fully sharded - False: ColumnParallelLinearWithLoRA # not fully sharded - }, - RowParallelLinear: { - True: RowParallelLinearWithShardedLoRA, - False: RowParallelLinearWithLoRA - }, - # ReplicatedLinear doesn't support fully sharded LoRA yet, - # so we use the same class for both cases. - ReplicatedLinear: { - True: ReplicatedLinearWithLoRA, - False: ReplicatedLinearWithLoRA - } - } - class HFCompatibleLinear(vllm_linear_cls): """ Wrapper class that removes `output_bias` from returned output. """ + # NOTE: The LoRA layer needs to use `parent_cls`. + @property + def parent_cls(self) -> type: + return vllm_linear_cls def forward(self, input: torch.Tensor) -> torch.Tensor: return super().forward(input)[0] - @classmethod - def get_lora_class(cls, fully_sharded: bool = False): - """ - Get the LoRA class corresponding to the current transformer - linear class. - - Args: - fully_sharded (bool): If True, select the LoRA class variant - that supports fully sharded LoRA. Defaults to False. - - """ - return lora_linear_cls[vllm_linear_cls][fully_sharded] - return HFCompatibleLinear( input_size=linear.in_features, output_size=linear.out_features, @@ -148,7 +117,7 @@ def get_lora_class(cls, fully_sharded: bool = False): ) -class TransformersModel(nn.Module, SupportsQuant): +class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): embedding_padding_modules = ["lm_head"] embedding_modules = ["embed_tokens" ] # TODO transformers will have a util to get it From 6bab1e2ac3c76548dfa8dec48571e6cea82b130f Mon Sep 17 00:00:00 2001 From: Jun Duan Date: Sat, 1 Mar 2025 20:20:30 -0500 Subject: [PATCH 309/317] [Misc] Accurately capture the time of loading weights (#14063) Signed-off-by: Jun Duan --- vllm/model_executor/model_loader/loader.py | 11 +++++++++++ vllm/v1/worker/gpu_model_runner.py | 2 +- vllm/worker/model_runner.py | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 6244241d1891..4f1092f68f50 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -10,6 +10,7 @@ import itertools import math import os +import time import warnings from abc import ABC, abstractmethod from contextlib import contextmanager @@ -216,6 +217,9 @@ class Source: allow_patterns_overrides: Optional[list[str]] = None """If defined, weights will load exclusively using these patterns.""" + counter_before_loading_weights: float = 0.0 + counter_after_loading_weights: float = 0.0 + def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: @@ -368,6 +372,8 @@ def _xla_weights_iterator(iterator: Generator): weights_iterator = _xla_weights_iterator(weights_iterator) + if self.counter_before_loading_weights == 0.0: + self.counter_before_loading_weights = time.perf_counter() # Apply the prefix. return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) @@ -412,6 +418,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights( self._get_all_weights(model_config, model)) + self.counter_after_loading_weights = time.perf_counter() + logger.info( + "Loading weights took %.2f seconds", + self.counter_after_loading_weights - + self.counter_before_loading_weights) # We only enable strict check for non-quantized models # that have loaded weights tracking currently. if model_config.quantization is None and loaded_weights is not None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0215b2735384..6785d6684269 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1061,7 +1061,7 @@ def load_model(self) -> None: self.device) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB and %.6f seconds", + logger.info("Model loading took %.4f GB and %.6f seconds", self.model_memory_usage / float(2**30), time_after_load - time_before_load) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bb2228165b52..0ea1d5dcbbb7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1114,7 +1114,7 @@ def load_model(self) -> None: time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB and %.6f seconds", + logger.info("Model loading took %.4f GB and %.6f seconds", self.model_memory_usage / float(2**30), time_after_load - time_before_load) From 657beeaf367b6fba7eb15eacb6aa46c4fb056892 Mon Sep 17 00:00:00 2001 From: qux-bbb <1147635419@qq.com> Date: Sun, 2 Mar 2025 18:59:50 +0800 Subject: [PATCH 310/317] [Doc] Source building add clone step (#14086) Signed-off-by: qux-bbb <1147635419@qq.com> --- .../source/getting_started/installation/cpu/build.inc.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/source/getting_started/installation/cpu/build.inc.md b/docs/source/getting_started/installation/cpu/build.inc.md index 2a8173803c05..46329e9bd281 100644 --- a/docs/source/getting_started/installation/cpu/build.inc.md +++ b/docs/source/getting_started/installation/cpu/build.inc.md @@ -6,7 +6,14 @@ sudo apt-get install -y gcc-12 g++-12 libnuma-dev sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 ``` -Second, install Python packages for vLLM CPU backend building: +Second, clone vLLM project: + +```console +git clone https://github.com/vllm-project/vllm.git vllm_source +cd vllm_source +``` + +Third, install Python packages for vLLM CPU backend building: ```console pip install --upgrade pip From 0aba218cc37cbc44281c2dee98169838dbfab2ae Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Mon, 3 Mar 2025 03:49:42 +0800 Subject: [PATCH 311/317] [v0][structured output] Support reasoning output (#12955) Signed-off-by: Ce Gao --- docs/source/features/reasoning_outputs.md | 43 +++++-- ...etion_structured_outputs_with_reasoning.py | 64 ++++++++++ .../model_executor/test_guided_processors.py | 116 +++++++++++++++--- vllm/config.py | 2 + vllm/engine/arg_utils.py | 26 +++- vllm/engine/async_llm_engine.py | 11 +- vllm/engine/llm_engine.py | 7 +- vllm/engine/multiprocessing/client.py | 3 +- vllm/entrypoints/openai/cli_args.py | 18 --- .../guided_decoding/__init__.py | 30 +++-- .../guided_decoding/outlines_decoding.py | 30 +++-- .../outlines_logits_processors.py | 37 ++++-- .../guided_decoding/reasoner/__init__.py | 23 ++++ .../reasoner/deepseek_reasoner.py | 28 +++++ .../guided_decoding/reasoner/reasoner.py | 19 +++ .../guided_decoding/xgrammar_decoding.py | 19 ++- 16 files changed, 400 insertions(+), 76 deletions(-) create mode 100644 examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py create mode 100644 vllm/model_executor/guided_decoding/reasoner/__init__.py create mode 100644 vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py create mode 100644 vllm/model_executor/guided_decoding/reasoner/reasoner.py diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index e39bbacf1138..5c0c1762f8aa 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni } ``` -Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. +Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py). + +## Limitations + +- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). +- It is not compatible with [`tool_calling`](#tool_calling). +- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning. ## How to support a new reasoning model @@ -137,15 +143,36 @@ class ExampleParser(ReasoningParser): """ ``` -After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint. +Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`. + +```python +@dataclass +class DeepSeekReasoner(Reasoner): + """ + Reasoner for DeepSeek R series models. + """ + start_token_id: int + end_token_id: int + + start_token: str = "" + end_token: str = "" + + @classmethod + def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: + return cls(start_token_id=tokenizer.encode( + "", add_special_tokens=False)[0], + end_token_id=tokenizer.encode("", + add_special_tokens=False)[0]) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.end_token_id in input_ids +``` + +The structured output engine like xgrammar will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case. + +Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags. ```bash vllm serve \ --enable-reasoning --reasoning-parser example ``` - -## Limitations - -- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). -- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features. -- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning. diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py new file mode 100644 index 000000000000..1f72e1164d42 --- /dev/null +++ b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +An example shows how to generate structured outputs from reasoning models +like DeepSeekR1. The thinking process will not be guided by the JSON +schema provided by the user. Only the final output will be structured. + +To run this example, you need to start the vLLM server with the reasoning +parser: + +```bash +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ + --enable-reasoning --reasoning-parser deepseek_r1 +``` + +This example demonstrates how to generate chat completions from reasoning models +using the OpenAI Python client library. +""" + +from enum import Enum + +from openai import OpenAI +from pydantic import BaseModel + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + + +# Guided decoding by JSON using Pydantic schema +class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + +class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + +json_schema = CarDescription.model_json_schema() + +prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's, think in 100 tokens") +completion = client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_json": json_schema}, +) +print("content", completion.choices[0].message.content) +print("reasoning_content: ", completion.choices[0].message.reasoning_content) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index be544698fa03..531c3a8c13b2 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -16,17 +16,33 @@ MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] +GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"] +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" -def test_guided_logits_processors(sample_regex, sample_json_schema): +# Initialize the tokenizer for the model here to avoid repeated loading +@pytest.fixture(scope="module") +def zephyr_7B_tokenzer(): + return AutoTokenizer.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def deepseek_r1_qwen_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex, + sample_json_schema): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" - tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') - regex_LP = RegexLogitsProcessor(sample_regex, tokenizer) + regex_LP = RegexLogitsProcessor(sample_regex, + zephyr_7B_tokenzer, + reasoner=None) json_LP = JSONLogitsProcessor(sample_json_schema, - tokenizer, - whitespace_pattern=None) + zephyr_7B_tokenzer, + whitespace_pattern=None, + reasoner=None) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -34,7 +50,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}" ) tensor = torch.rand(32000) @@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.parametrize("is_local", [True, False]) async def test_guided_logits_processor_black_box(backend: str, is_local: bool, sample_regex, - sample_json_schema): + sample_json_schema, + zephyr_7B_tokenzer): config = ModelConfig( MODEL_NAME, @@ -60,15 +77,14 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, seed=0, dtype="bfloat16", ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_lp = get_local_guided_decoding_logits_processor( - regex_request, tokenizer, config) if is_local else \ + regex_request, zephyr_7B_tokenzer, config) if is_local else \ await get_guided_decoding_logits_processor( - regex_request, tokenizer, config) + regex_request, zephyr_7B_tokenzer, config) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -76,13 +92,85 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - token_ids = tokenizer.encode( + token_ids = zephyr_7B_tokenzer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}" ) json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = await get_guided_decoding_logits_processor( - json_request, tokenizer, config) + json_request, zephyr_7B_tokenzer, config) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("backend", + GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT) +@pytest.mark.parametrize("is_local", [True, False]) +@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"]) +async def test_guided_logits_processor_with_reasoning( + backend: str, is_local: bool, reasoning_backend: str, sample_regex, + sample_json_schema, deepseek_r1_qwen_tokenizer): + + config = ModelConfig( + REASONING_MODEL_NAME, + task="generate", + tokenizer=REASONING_MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="bfloat16", + ) + token_ids = deepseek_r1_qwen_tokenizer.encode( + f"Give an example IPv4 address with this regex: {sample_regex}." + "here is the thinking process") + regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) + + regex_lp = get_local_guided_decoding_logits_processor(regex_request, + deepseek_r1_qwen_tokenizer, config, + reasoning_backend) if is_local else \ + await get_guided_decoding_logits_processor( + regex_request, deepseek_r1_qwen_tokenizer, config, + reasoning_backend) + assert regex_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = regex_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert torch.allclose(tensor, original_tensor) + + token_ids = deepseek_r1_qwen_tokenizer.encode( + f"Give an employee profile that fits this schema: {sample_json_schema}." + "here is the thinking process") + json_request = GuidedDecodingParams(json=sample_json_schema, + backend=backend) + json_lp = get_local_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, + reasoning_backend) if is_local else \ + await get_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert torch.allclose(tensor, original_tensor) + + # Thinking is over, so the tensor should change. + token_ids = deepseek_r1_qwen_tokenizer.encode( + f"Give an employee profile that fits this schema: {sample_json_schema}." + "here is the thinking process Then") + json_request = GuidedDecodingParams(json=sample_json_schema, + backend=backend) + json_lp = get_local_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, + reasoning_backend) if is_local else \ + await get_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) diff --git a/vllm/config.py b/vllm/config.py index c7108473442b..54ed38418dd4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2715,6 +2715,8 @@ class DecodingConfig: # 'outlines' / 'lm-format-enforcer' / 'xgrammar' guided_decoding_backend: str = 'xgrammar' + reasoning_backend: Optional[str] = None + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1a2f794c9151..989eb4dbfd14 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -213,6 +213,8 @@ class EngineArgs: calculate_kv_scales: Optional[bool] = None additional_config: Optional[Dict[str, Any]] = None + enable_reasoning: Optional[bool] = None + reasoning_parser: Optional[str] = None def __post_init__(self): if not self.tokenizer: @@ -1059,6 +1061,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Different platforms may support different configs. Make sure the " "configs are valid for the platform you are using. The input format" " is like '{\"config_key\":\"config_value\"}'") + + parser.add_argument( + "--enable-reasoning", + action="store_true", + default=False, + help="Whether to enable reasoning_content for the model. " + "If enabled, the model will be able to generate reasoning content." + ) + + parser.add_argument( + "--reasoning-parser", + type=str, + choices=["deepseek_r1"], + default=None, + help= + "Select the reasoning parser depending on the model that you're " + "using. This is used to parse the reasoning content into OpenAI " + "API format. Required for ``--enable-reasoning``.") + return parser @classmethod @@ -1332,7 +1353,10 @@ def create_engine_config(self, if self.enable_prompt_adapter else None decoding_config = DecodingConfig( - guided_decoding_backend=self.guided_decoding_backend) + guided_decoding_backend=self.guided_decoding_backend, + reasoning_backend=self.reasoning_parser + if self.enable_reasoning else None, + ) show_hidden_metrics = False if self.show_hidden_metrics_for_version is not None: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 93d9b74d8e1e..90e66b005f39 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -509,6 +509,7 @@ async def add_request_async( tokenizer=await self.get_tokenizer_async(lora_request), default_guided_backend=self.decoding_config. guided_decoding_backend, + reasoning_backend=self.decoding_config.reasoning_backend, model_config=self.model_config) self._add_processed_request( @@ -530,7 +531,7 @@ async def check_health_async(self) -> None: async def build_guided_decoding_logits_processor_async( sampling_params: SamplingParams, tokenizer: AnyTokenizer, - default_guided_backend: str, + default_guided_backend: str, reasoning_backend: Optional[str], model_config: ModelConfig) -> SamplingParams: """Constructs logits processors based on the guided_decoding, logits_bias, and allowed_token_ids fields in sampling_params. Deletes @@ -545,14 +546,18 @@ async def build_guided_decoding_logits_processor_async( sampling_params = copy.copy(sampling_params) guided_decoding = sampling_params.guided_decoding - logger.debug("Building guided decoding logits processor. " - "Params: %s", guided_decoding) + logger.info( + "Building guided decoding logits processor. " + "guided_decoding: %s%s", guided_decoding, + f", reasoning_backend: {reasoning_backend}" + if reasoning_backend is not None else "") guided_decoding.backend = guided_decoding.backend or default_guided_backend processor = await get_guided_decoding_logits_processor( guided_params=guided_decoding, tokenizer=tokenizer, + reasoning_backend=reasoning_backend, model_config=model_config) if processor: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9c83ea75ead7..f055438d1feb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2048,10 +2048,15 @@ def _build_logits_processors( guided_decoding.backend = guided_decoding.backend or \ self.decoding_config.guided_decoding_backend + logger.debug("Reasoning backend: %s", + self.decoding_config.reasoning_backend) + processor = get_local_guided_decoding_logits_processor( guided_params=guided_decoding, tokenizer=tokenizer, - model_config=self.model_config) + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, + ) if processor: logits_processors.append(processor) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index c12fe242082b..005ba81cd226 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -611,7 +611,8 @@ async def _process_request( default_guided_backend=(self.decoding_config.guided_decoding_backend if self.decoding_config else DecodingConfig.guided_decoding_backend), - model_config=self.model_config + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, ) # 1) Create output queue for this requests. diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index ba953c219708..8d877046f75f 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -13,7 +13,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) -from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager from vllm.entrypoints.openai.serving_models import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -215,23 +214,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, help="Enable auto tool choice for supported models. Use " "``--tool-call-parser`` to specify which parser to use.") - parser.add_argument( - "--enable-reasoning", - action="store_true", - default=False, - help="Whether to enable reasoning_content for the model. " - "If enabled, the model will be able to generate reasoning content.") - - valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys() - parser.add_argument( - "--reasoning-parser", - type=str, - metavar="{" + ",".join(valid_reasoning_parsers) + "}", - default=None, - help= - "Select the reasoning parser depending on the model that you're using." - " This is used to parse the reasoning content into OpenAI API " - "format. Required for ``--enable-reasoning``.") valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 1522e3404182..86f6f0e5f907 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from vllm.logger import init_logger +from vllm.model_executor.guided_decoding.reasoner import get_reasoner from vllm.model_executor.guided_decoding.utils import ( convert_lark_to_gbnf, grammar_is_likely_lark, has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) @@ -103,8 +104,13 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str, async def get_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, - model_config: ModelConfig) -> LogitsProcessor | None: + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + reasoning_backend: str | None = None) -> LogitsProcessor | None: + + reasoner = get_reasoner(tokenizer, reasoning_backend) + guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend_name == 'outlines': @@ -112,8 +118,8 @@ async def get_guided_decoding_logits_processor( from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( - guided_params, tokenizer) - if guided_params.backend_name == 'lm-format-enforcer': + guided_params, tokenizer, reasoner) + if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( @@ -122,7 +128,7 @@ async def get_guided_decoding_logits_processor( from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( - guided_params, tokenizer, model_config) + guided_params, tokenizer, model_config, reasoner) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " @@ -130,16 +136,22 @@ async def get_guided_decoding_logits_processor( def get_local_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, - model_config: ModelConfig) -> LogitsProcessor | None: + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + reasoning_backend: str | None = None) -> LogitsProcessor | None: guided_params = maybe_backend_fallback(guided_params) + + # Get the reasoner if needed, it will be None if reasoning_ + reasoner = get_reasoner(tokenizer, reasoning_backend) + # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend_name == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( - guided_params, tokenizer) + guided_params, tokenizer, reasoner) if guided_params.backend_name == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) @@ -149,7 +161,7 @@ def get_local_guided_decoding_logits_processor( from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( - guided_params, tokenizer, model_config) + guided_params, tokenizer, model_config, reasoner) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index ba9c98290368..97f63ae11f45 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -6,12 +6,13 @@ from enum import Enum from json import dumps as json_dumps from re import escape as regex_escape -from typing import Tuple, Union +from typing import Optional, Tuple, Union from transformers import PreTrainedTokenizerBase from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.sampling_params import GuidedDecodingParams @@ -58,7 +59,9 @@ class GuidedDecodingMode(Enum): async def get_outlines_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[Reasoner], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -82,11 +85,14 @@ async def get_outlines_guided_decoding_logits_processor( return await loop.run_in_executor(global_thread_pool, _get_logits_processor, guide, tokenizer, - mode, guided_params.whitespace_pattern) + mode, guided_params.whitespace_pattern, + reasoner) def get_local_outlines_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[Reasoner], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -100,7 +106,7 @@ def get_local_outlines_guided_decoding_logits_processor( return None return _get_logits_processor(guide, tokenizer, mode, - guided_params.whitespace_pattern) + guided_params.whitespace_pattern, reasoner) def _get_guide_and_mode( @@ -131,14 +137,18 @@ def _get_guide_and_mode( def _get_logits_processor( - guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode, - whitespace_pattern: Union[str, None] + guide: str, + tokenizer: PreTrainedTokenizerBase, + mode: GuidedDecodingMode, + whitespace_pattern: Union[str, None], + reasoner: Optional[Reasoner], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: if mode == GuidedDecodingMode.JSON: - return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern) + return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, + reasoner) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: - return RegexLogitsProcessor(guide, tokenizer) + return RegexLogitsProcessor(guide, tokenizer, reasoner) elif mode == GuidedDecodingMode.GRAMMAR: - return CFGLogitsProcessor(guide, tokenizer) + return CFGLogitsProcessor(guide, tokenizer, reasoner) else: raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index a05267d921d1..db5d738f42e4 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -19,7 +19,7 @@ import json from collections import defaultdict from functools import lru_cache -from typing import Callable, DefaultDict, Dict, List, Union +from typing import Callable, DefaultDict, Dict, List, Optional, Union import numpy as np import torch @@ -32,13 +32,18 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizerBase +from vllm.logger import init_logger +from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.platforms import current_platform +logger = init_logger(__name__) + class BaseLogitsProcessor: - def __init__(self, guide: Guide): + def __init__(self, guide: Guide, reasoner: Optional[Reasoner]): self._guide: Guide = guide + self._reasoner = reasoner # CFGState is used for the FSM state for CFGGuide self._fsm_state: DefaultDict[int, Union[int, CFGState]] = defaultdict(int) @@ -46,6 +51,14 @@ def __init__(self, guide: Guide): def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" + + # Skip the structured logits processing if reasoning is not finished. + # reasoner is not None only when `--enable-reasoning` is set. + if self._reasoner is not None and \ + not self._reasoner.is_reasoning_end( + input_ids): + return scores + seq_id = hash(tuple(input_ids)) if len(input_ids) > 0: @@ -113,7 +126,12 @@ def _get_guide(cls, regex_string: str, tokenizer = _adapt_tokenizer(tokenizer) return RegexGuide.from_regex(regex_string, tokenizer) - def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): + def __init__( + self, + regex_string: str, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[Reasoner], + ): """Compile the FSM that drives the regex-structured generation. Parameters @@ -125,14 +143,15 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): """ super().__init__( - RegexLogitsProcessor._get_guide(regex_string, tokenizer)) + RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner) class JSONLogitsProcessor(RegexLogitsProcessor): def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer: PreTrainedTokenizerBase, - whitespace_pattern: Union[str, None]): + whitespace_pattern: Union[str, None], + reasoner: Optional[Reasoner]): """Compile the FSM that drives the JSON-guided generation. Parameters @@ -160,7 +179,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel], f"a Pydantic object, a dictionary or a string that contains " f"the JSON Schema specification") regex_string = build_regex_from_schema(schema_str, whitespace_pattern) - super().__init__(regex_string, tokenizer) + super().__init__(regex_string, tokenizer, reasoner) class CFGLogitsProcessor(BaseLogitsProcessor): @@ -171,7 +190,8 @@ def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: tokenizer = _adapt_tokenizer(tokenizer) return CFGGuide(cfg, tokenizer) - def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): + def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[Reasoner]): """Compile the FSM that drives the context free grammar generation. Parameters @@ -182,7 +202,8 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer)) + super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer), + reasoner) self._guide = self._guide.copy() diff --git a/vllm/model_executor/guided_decoding/reasoner/__init__.py b/vllm/model_executor/guided_decoding/reasoner/__init__.py new file mode 100644 index 000000000000..5a91f791d45b --- /dev/null +++ b/vllm/model_executor/guided_decoding/reasoner/__init__.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from transformers import PreTrainedTokenizer + +from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501 + DeepSeekReasoner) +from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner + + +def get_reasoner(tokenizer: PreTrainedTokenizer, + reasoning_backend: str | None) -> Reasoner | None: + if reasoning_backend is None: + # No reasoning backend specified + return None + elif reasoning_backend == "deepseek_r1": + return DeepSeekReasoner.from_tokenizer(tokenizer) + else: + raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'") + + +__all__ = ["Reasoner", "get_reasoner"] diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py new file mode 100644 index 000000000000..e762fb0659de --- /dev/null +++ b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from transformers import PreTrainedTokenizer + +from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner + + +@dataclass +class DeepSeekReasoner(Reasoner): + """ + Reasoner for DeepSeek R series models. + """ + start_token_id: int + end_token_id: int + + start_token: str = "" + end_token: str = "" + + @classmethod + def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: + return cls(start_token_id=tokenizer.encode( + "", add_special_tokens=False)[0], + end_token_id=tokenizer.encode("", + add_special_tokens=False)[0]) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.end_token_id in input_ids diff --git a/vllm/model_executor/guided_decoding/reasoner/reasoner.py b/vllm/model_executor/guided_decoding/reasoner/reasoner.py new file mode 100644 index 000000000000..5db0c9bc7850 --- /dev/null +++ b/vllm/model_executor/guided_decoding/reasoner/reasoner.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from transformers import PreTrainedTokenizer + + +@dataclass +class Reasoner(ABC): + + @abstractmethod + def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: + pass + + @abstractmethod + def is_reasoning_end(self, input_ids: list[int]) -> bool: + pass diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index eb9d83acb286..ce278c15ab3b 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -11,6 +11,8 @@ import torch from transformers import PreTrainedTokenizerFast +from vllm.logger import init_logger + try: import xgrammar as xgr from xgrammar.base import _core as xgr_core @@ -19,7 +21,6 @@ xgr_installed = False pass -from vllm.logger import init_logger from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf, grammar_is_likely_lark) from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer @@ -28,6 +29,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig + from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.sampling_params import GuidedDecodingParams logger = init_logger(__name__) @@ -38,12 +40,13 @@ def get_local_xgrammar_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, model_config: ModelConfig, + reasoner: Reasoner | None, max_threads: int = 8): config = GrammarConfig.from_guided_params(guided_params=guided_params, model_config=model_config, tokenizer=tokenizer, max_threads=max_threads) - return XGrammarLogitsProcessor(config) + return XGrammarLogitsProcessor(config, reasoner) @dataclass(frozen=True) @@ -293,6 +296,7 @@ def choice_as_grammar(choice: List[str] | None) -> str: class XGrammarLogitsProcessor: """Wrapper class to support pickle protocol""" config: GrammarConfig + reasoner: Reasoner | None = None ctx: xgr.CompiledGrammar | None = None token_bitmask: torch.Tensor = None # type: ignore[assignment] @@ -301,10 +305,11 @@ class XGrammarLogitsProcessor: prefilled: bool = field(default=False) def __getstate__(self) -> dict[str, Any]: - return {'config': self.config} + return {'config': self.config, 'reasoner': self.reasoner} def __setstate__(self, state: dict[str, Any]): self.config = state['config'] + self.reasoner = state['reasoner'] self.ctx = None self.matchers = [] @@ -331,6 +336,14 @@ def _ensure_ctx(self): def __call__(self, input_ids: list[int], scores: torch.Tensor) -> torch.Tensor: + + # Skip the structured logits processing if reasoning is not finished. + # reasoner is not None only when `--enable-reasoning` is set. + if self.reasoner is not None and \ + not self.reasoner.is_reasoning_end( + input_ids): + return scores + if self.ctx is None: self._ensure_ctx() From 3610d541f66c9e629a45626b8090ac8f5e3d8a49 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 3 Mar 2025 01:34:51 +0000 Subject: [PATCH 312/317] Update deprecated Python 3.8 typing (#13971) --- benchmarks/backend_request_func.py | 6 +- benchmarks/benchmark_guided.py | 17 +- benchmarks/benchmark_latency.py | 6 +- benchmarks/benchmark_prefix_caching.py | 16 +- benchmarks/benchmark_prioritization.py | 8 +- benchmarks/benchmark_serving.py | 77 +++---- benchmarks/benchmark_serving_guided.py | 57 ++--- benchmarks/benchmark_throughput.py | 38 ++-- benchmarks/benchmark_utils.py | 8 +- .../cutlass_benchmarks/sparse_benchmarks.py | 9 +- benchmarks/cutlass_benchmarks/utils.py | 8 +- .../cutlass_benchmarks/w8a8_benchmarks.py | 17 +- .../fused_kernels/layernorm_rms_benchmarks.py | 5 +- benchmarks/kernels/benchmark_lora.py | 60 ++--- benchmarks/kernels/benchmark_machete.py | 25 ++- benchmarks/kernels/benchmark_marlin.py | 6 +- benchmarks/kernels/benchmark_moe.py | 18 +- .../kernels/benchmark_paged_attention.py | 4 +- benchmarks/kernels/benchmark_rmsnorm.py | 4 +- benchmarks/kernels/benchmark_rope.py | 4 +- benchmarks/kernels/graph_machete_bench.py | 3 +- benchmarks/kernels/utils.py | 3 +- .../vllm_cutlass_library_extension.py | 14 +- csrc/quantization/machete/generate.py | 20 +- docs/source/conf.py | 3 +- docs/source/features/reasoning_outputs.md | 4 +- docs/source/features/structured_outputs.md | 2 +- docs/source/generate_examples.py | 2 +- examples/offline_inference/distributed.py | 10 +- .../offline_inference/llm_engine_example.py | 7 +- .../lora_with_quantization_inference.py | 8 +- examples/offline_inference/mlpspeculator.py | 3 +- .../offline_inference/multilora_inference.py | 8 +- .../prithvi_geospatial_mae.py | 8 +- examples/offline_inference/profiling.py | 15 +- .../profiling_tpu/profiling.py | 3 +- .../vision_language_multi_image.py | 34 +-- examples/online_serving/api_client.py | 6 +- .../online_serving/openai_embedding_client.py | 2 +- pyproject.toml | 28 ++- setup.py | 7 +- tests/async_engine/api_server_async_engine.py | 5 +- tests/async_engine/test_async_llm_engine.py | 4 +- tests/compile/piecewise/test_toy_llama.py | 6 +- tests/compile/test_basic_correctness.py | 8 +- tests/conftest.py | 159 +++++++------ tests/core/block/e2e/conftest.py | 3 +- .../e2e/test_correctness_sliding_window.py | 11 +- tests/core/block/test_block_table.py | 8 +- tests/core/block/test_naive_block.py | 4 +- tests/core/block/test_prefix_caching_block.py | 16 +- tests/core/test_chunked_prefill_scheduler.py | 25 +-- tests/core/test_scheduler.py | 19 +- tests/core/test_scheduler_encoder_decoder.py | 4 +- tests/core/utils.py | 21 +- tests/distributed/test_expert_parallel.py | 6 +- tests/distributed/test_pipeline_parallel.py | 8 +- tests/distributed/test_pynccl.py | 5 +- tests/distributed/test_shm_broadcast.py | 3 +- tests/encoder_decoder/test_e2e_correctness.py | 4 +- tests/engine/test_executor.py | 6 +- tests/engine/test_multiproc_workers.py | 6 +- tests/engine/test_stop_strings.py | 6 +- tests/entrypoints/llm/test_chat.py | 4 +- tests/entrypoints/llm/test_encode.py | 5 +- tests/entrypoints/llm/test_generate.py | 3 +- .../test_transcription_api_correctness.py | 3 +- .../test_deepseekr1_reasoning_parser.py | 4 +- .../openai/reasoning_parsers/utils.py | 14 +- tests/entrypoints/openai/test_audio.py | 16 +- tests/entrypoints/openai/test_basic.py | 3 +- tests/entrypoints/openai/test_chat.py | 8 +- tests/entrypoints/openai/test_completion.py | 8 +- tests/entrypoints/openai/test_embedding.py | 4 +- tests/entrypoints/openai/test_pooling.py | 4 +- tests/entrypoints/openai/test_root_path.py | 4 +- tests/entrypoints/openai/test_video.py | 12 +- tests/entrypoints/openai/test_vision.py | 12 +- .../openai/test_vision_embedding.py | 4 +- .../tool_parsers/test_pythonic_tool_parser.py | 3 +- .../entrypoints/openai/tool_parsers/utils.py | 9 +- tests/kernels/quant_utils.py | 6 +- tests/kernels/test_activation.py | 3 +- tests/kernels/test_attention.py | 16 +- tests/kernels/test_blocksparse_attention.py | 12 +- tests/kernels/test_cache.py | 5 +- tests/kernels/test_cascade_flash_attn.py | 8 +- tests/kernels/test_cutlass.py | 11 +- tests/kernels/test_cutlass_2of4_sparse.py | 5 +- tests/kernels/test_encoder_decoder_attn.py | 4 +- tests/kernels/test_flash_attn.py | 16 +- tests/kernels/test_flashinfer.py | 22 +- tests/kernels/test_fused_quant_layernorm.py | 12 +- tests/kernels/test_gguf.py | 3 +- tests/kernels/test_machete_mm.py | 14 +- tests/kernels/test_mamba_mixer2.py | 3 +- tests/kernels/test_mamba_ssm_ssd.py | 8 +- tests/kernels/test_pos_encoding.py | 6 +- tests/kernels/test_triton_scaled_mm.py | 4 +- tests/kernels/utils.py | 68 +++--- tests/kv_transfer/test_send_recv.py | 3 +- tests/lora/conftest.py | 6 +- tests/lora/data/long_context_test_data.py | 4 +- tests/lora/test_add_lora.py | 9 +- tests/lora/test_baichuan.py | 6 +- tests/lora/test_chatglm3_tp.py | 6 +- tests/lora/test_gemma.py | 6 +- tests/lora/test_jamba.py | 6 +- tests/lora/test_layers.py | 48 ++-- tests/lora/test_llama_tp.py | 6 +- tests/lora/test_long_context.py | 16 +- tests/lora/test_lora_bias_e2e.py | 6 +- tests/lora/test_lora_checkpoints.py | 6 +- tests/lora/test_lora_functions.py | 5 +- tests/lora/test_lora_huggingface.py | 4 +- tests/lora/test_lora_manager.py | 7 +- tests/lora/test_minicpmv_tp.py | 6 +- tests/lora/test_mixtral.py | 6 +- tests/lora/test_phi.py | 6 +- tests/lora/test_punica_ops.py | 5 +- tests/lora/test_quant_model.py | 7 +- tests/lora/test_qwen2vl.py | 10 +- tests/lora/test_transfomers_model.py | 6 +- tests/lora/test_ultravox.py | 7 +- tests/lora/utils.py | 14 +- tests/metrics/test_metrics.py | 3 +- tests/mistral_tool_use/utils.py | 8 +- .../model_executor/test_enabled_custom_ops.py | 4 +- .../audio_language/test_ultravox.py | 16 +- .../models/decoder_only/language/test_gguf.py | 6 +- .../decoder_only/language/test_modelopt.py | 3 +- .../decoder_only/vision_language/test_awq.py | 6 +- .../vision_language/test_models.py | 39 ++-- .../vision_language/test_phi3v.py | 10 +- .../vision_language/test_pixtral.py | 12 +- .../vision_language/test_qwen2_vl.py | 46 ++-- .../vision_language/vlm_utils/builders.py | 7 +- .../vlm_utils/case_filtering.py | 10 +- .../vision_language/vlm_utils/core.py | 22 +- .../vision_language/vlm_utils/model_utils.py | 16 +- .../vision_language/vlm_utils/runners.py | 21 +- .../vision_language/vlm_utils/types.py | 36 +-- .../models/embedding/language/test_gritlm.py | 11 +- tests/models/embedding/utils.py | 6 +- .../vision_language/test_dse_qwen2_vl.py | 12 +- .../vision_language/test_llava_next.py | 8 +- .../embedding/vision_language/test_phi3v.py | 8 +- .../encoder_decoder/language/test_bart.py | 10 +- .../vision_language/test_florence2.py | 8 +- .../vision_language/test_mllama.py | 36 +-- .../multimodal/processing/test_h2ovl.py | 3 +- .../multimodal/processing/test_internvl.py | 3 +- tests/models/registry.py | 5 +- tests/models/test_transformers.py | 15 +- tests/models/utils.py | 21 +- tests/mq_llm_engine/utils.py | 4 +- .../multi_step/test_correctness_async_llm.py | 4 +- tests/multimodal/test_utils.py | 8 +- tests/neuron/test_logits_processor.py | 3 +- .../my_gemma_embedding.py | 5 +- tests/quantization/test_configs.py | 3 +- .../test_register_quantization_config.py | 8 +- tests/samplers/test_logprobs.py | 4 +- tests/samplers/test_no_bad_words.py | 16 +- tests/samplers/test_rejection_sampler.py | 11 +- tests/samplers/test_sampler.py | 44 ++-- tests/spec_decode/e2e/conftest.py | 9 +- tests/spec_decode/test_batch_expansion.py | 4 +- tests/spec_decode/test_multi_step_worker.py | 15 +- tests/spec_decode/test_scorer.py | 3 +- tests/spec_decode/test_spec_decode_worker.py | 11 +- tests/spec_decode/utils.py | 33 ++- tests/test_cache_block_hashing.py | 6 +- tests/test_inputs.py | 4 +- tests/test_logger.py | 2 +- tests/test_logits_processor.py | 3 +- tests/test_utils.py | 4 +- tests/tokenization/test_detokenize.py | 23 +- tests/tokenization/test_tokenizer_group.py | 4 +- tests/tokenization/test_tokenizer_registry.py | 32 +-- tests/tool_use/test_chat_completions.py | 6 +- tests/tool_use/test_jamba_tool_parser.py | 13 +- tests/tool_use/test_parallel_tool_calls.py | 10 +- tests/tool_use/test_tool_calls.py | 10 +- tests/tool_use/utils.py | 26 +-- tests/tracing/test_tracing.py | 5 +- tests/utils.py | 42 ++-- tests/v1/core/test_prefix_caching.py | 3 +- tests/v1/core/test_scheduler.py | 6 +- tests/v1/engine/conftest.py | 6 +- tests/v1/engine/test_async_llm.py | 10 +- tests/v1/engine/test_engine_core.py | 3 +- tests/v1/engine/test_engine_core_client.py | 10 +- tests/v1/engine/test_llm_engine.py | 8 +- tests/v1/engine/test_output_processor.py | 12 +- tests/v1/engine/utils.py | 50 ++--- .../v1/entrypoints/openai/test_completion.py | 12 +- tests/v1/sample/test_logprobs.py | 9 +- tests/v1/sample/test_rejection_sampler.py | 7 +- tests/v1/sample/test_sampler.py | 26 +-- tests/v1/sample/utils.py | 5 +- tests/v1/test_utils.py | 6 +- tests/v1/worker/test_gpu_input_batch.py | 22 +- .../vllm_test_utils/vllm_test_utils/blame.py | 3 +- .../vllm_test_utils/monitor.py | 3 +- .../test_encoder_decoder_model_runner.py | 21 +- tests/worker/test_model_input.py | 11 +- tests/worker/test_model_runner.py | 20 +- tools/profiler/print_layerwise_table.py | 3 +- tools/profiler/visualize_layerwise_profile.py | 14 +- vllm/_custom_ops.py | 56 ++--- vllm/_ipex_ops.py | 8 +- vllm/beam_search.py | 18 +- vllm/config.py | 141 ++++++------ vllm/connections.py | 3 +- vllm/entrypoints/api_server.py | 3 +- vllm/entrypoints/chat_utils.py | 49 ++-- vllm/entrypoints/cli/openai.py | 10 +- vllm/entrypoints/cli/serve.py | 3 +- vllm/entrypoints/llm.py | 210 +++++++++--------- vllm/entrypoints/logger.py | 4 +- vllm/entrypoints/openai/api_server.py | 9 +- vllm/entrypoints/openai/cli_args.py | 7 +- vllm/entrypoints/openai/logits_processors.py | 23 +- vllm/entrypoints/openai/protocol.py | 128 +++++------ .../abs_reasoning_parsers.py | 21 +- .../deepseek_r1_reasoning_parser.py | 5 +- vllm/entrypoints/openai/run_batch.py | 9 +- vllm/entrypoints/openai/serving_chat.py | 31 ++- vllm/entrypoints/openai/serving_completion.py | 32 +-- vllm/entrypoints/openai/serving_embedding.py | 15 +- vllm/entrypoints/openai/serving_engine.py | 43 ++-- vllm/entrypoints/openai/serving_models.py | 10 +- vllm/entrypoints/openai/serving_pooling.py | 15 +- vllm/entrypoints/openai/serving_score.py | 49 ++-- .../openai/serving_tokenization.py | 4 +- .../openai/serving_transcription.py | 3 +- .../tool_parsers/abstract_tool_parser.py | 21 +- .../granite_20b_fc_tool_parser.py | 5 +- .../tool_parsers/granite_tool_parser.py | 5 +- .../openai/tool_parsers/hermes_tool_parser.py | 7 +- .../tool_parsers/internlm2_tool_parser.py | 5 +- .../openai/tool_parsers/jamba_tool_parser.py | 11 +- .../openai/tool_parsers/llama_tool_parser.py | 11 +- .../tool_parsers/mistral_tool_parser.py | 13 +- .../tool_parsers/pythonic_tool_parser.py | 5 +- vllm/entrypoints/openai/tool_parsers/utils.py | 6 +- vllm/entrypoints/score_utils.py | 14 +- vllm/envs.py | 8 +- vllm/forward_context.py | 6 +- vllm/logger.py | 2 +- vllm/logits_process.py | 16 +- vllm/outputs.py | 24 +- vllm/sampling_params.py | 53 +++-- vllm/sequence.py | 132 +++++------ vllm/tracing.py | 3 +- vllm/utils.py | 76 +++---- vllm/v1/attention/backends/flash_attn.py | 18 +- vllm/v1/attention/backends/mla/common.py | 21 +- vllm/v1/attention/backends/mla/flashmla.py | 14 +- vllm/v1/attention/backends/mla/triton_mla.py | 8 +- vllm/v1/attention/backends/pallas.py | 16 +- vllm/v1/attention/backends/rocm_attn.py | 16 +- vllm/v1/core/block_pool.py | 17 +- vllm/v1/core/encoder_cache_manager.py | 16 +- vllm/v1/core/kv_cache_manager.py | 19 +- vllm/v1/core/kv_cache_utils.py | 32 +-- vllm/v1/core/scheduler.py | 45 ++-- vllm/v1/core/scheduler_output.py | 42 ++-- vllm/v1/engine/__init__.py | 16 +- vllm/v1/engine/async_llm.py | 9 +- vllm/v1/engine/core.py | 16 +- vllm/v1/engine/core_client.py | 34 +-- vllm/v1/engine/detokenizer.py | 12 +- vllm/v1/engine/llm_engine.py | 17 +- vllm/v1/engine/logprobs.py | 12 +- vllm/v1/engine/mm_input_cache.py | 18 +- vllm/v1/engine/output_processor.py | 20 +- vllm/v1/engine/parallel_sampling.py | 16 +- vllm/v1/engine/processor.py | 3 +- vllm/v1/executor/abstract.py | 10 +- vllm/v1/executor/multiproc_executor.py | 10 +- vllm/v1/kv_cache_interface.py | 7 +- vllm/v1/metrics/loggers.py | 18 +- vllm/v1/metrics/stats.py | 20 +- vllm/v1/outputs.py | 20 +- vllm/v1/request.py | 24 +- vllm/v1/sample/metadata.py | 10 +- vllm/v1/sample/ops/penalties.py | 12 +- vllm/v1/sample/ops/topk_topp_sampler.py | 10 +- vllm/v1/sample/rejection_sampler.py | 7 +- vllm/v1/stats/common.py | 18 +- vllm/v1/utils.py | 20 +- vllm/v1/worker/block_table.py | 6 +- vllm/v1/worker/gpu_input_batch.py | 62 +++--- vllm/v1/worker/gpu_model_runner.py | 34 +-- vllm/v1/worker/gpu_worker.py | 4 +- vllm/v1/worker/lora_model_runner_mixin.py | 17 +- vllm/v1/worker/tpu_model_runner.py | 24 +- vllm/v1/worker/tpu_worker.py | 6 +- 300 files changed, 2294 insertions(+), 2347 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index e43549c13c8e..158705769b5e 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -6,7 +6,7 @@ import time import traceback from dataclasses import dataclass, field -from typing import List, Optional, Union +from typing import Optional, Union import aiohttp import huggingface_hub.constants @@ -41,8 +41,8 @@ class RequestFuncOutput: latency: float = 0.0 output_tokens: int = 0 ttft: float = 0.0 # Time to first token - itl: List[float] = field( - default_factory=list) # List of inter-token latencies + itl: list[float] = field( + default_factory=list) # list of inter-token latencies tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 error: str = "" diff --git a/benchmarks/benchmark_guided.py b/benchmarks/benchmark_guided.py index dc2bf0e79cbc..2e0f6c6b5d20 100644 --- a/benchmarks/benchmark_guided.py +++ b/benchmarks/benchmark_guided.py @@ -6,7 +6,6 @@ import os import random import time -from typing import List import datasets import pandas as pd @@ -39,7 +38,7 @@ class SampleRequest: completion: str = None -def run_vllm(requests: List[SampleRequest], +def run_vllm(requests: list[SampleRequest], engine_args: EngineArgs, n: int, guided_decoding_rate: float = 1.0, @@ -54,8 +53,8 @@ def run_vllm(requests: List[SampleRequest], " prompt_len and expected_output_len for all requests.") # Add the requests to the engine. - prompts: List[str] = [] - sampling_params: List[SamplingParams] = [] + prompts: list[str] = [] + sampling_params: list[SamplingParams] = [] # create a list containing random selected true or false guided_decoding_req_idx = random.sample( range(len(requests)), int(len(requests) * guided_decoding_rate)) @@ -110,7 +109,7 @@ def run_vllm(requests: List[SampleRequest], async def run_vllm_async( - requests: List[SampleRequest], + requests: list[SampleRequest], engine_args: AsyncEngineArgs, n: int, guided_decoding_rate: float = 1.0, @@ -129,8 +128,8 @@ async def run_vllm_async( " prompt_len and expected_output_len for all requests.") # Add the requests to the engine. - prompts: List[str] = [] - sampling_params: List[SamplingParams] = [] + prompts: list[str] = [] + sampling_params: list[SamplingParams] = [] guided_decoding_req_idx = random.sample( range(len(requests)), int(len(requests) * guided_decoding_rate)) @@ -203,7 +202,7 @@ async def run_vllm_async( def sample_requests(tokenizer: PreTrainedTokenizerBase, - args: argparse.Namespace) -> List[SampleRequest]: + args: argparse.Namespace) -> list[SampleRequest]: if args.dataset == 'json': if args.json_schema_path is None: dir_path = os.path.dirname(os.path.realpath(__file__)) @@ -287,7 +286,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, elif args.dataset == "xgrammar_bench": args.warmup = False - requests: List[SampleRequest] = [] + requests: list[SampleRequest] = [] dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train") print(f"dataset has {len(dataset)} entries") diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index c82358d14512..d7f39f50f6ca 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -7,7 +7,7 @@ import os import time from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Optional import numpy as np import torch @@ -22,7 +22,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: Dict[str, Any]) -> None: + results: dict[str, Any]) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={"latency": results["latencies"]}, @@ -57,7 +57,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompts: List[PromptType] = [{ + dummy_prompts: list[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 23822856b882..fba32520442f 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -31,7 +31,7 @@ import json import random import time -from typing import List, Optional, Tuple +from typing import Optional from transformers import PreTrainedTokenizerBase @@ -77,9 +77,9 @@ def sample_requests_from_dataset( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, - input_length_range: Tuple[int, int], + input_length_range: tuple[int, int], fixed_output_len: Optional[int], -) -> List[Request]: +) -> list[Request]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -99,7 +99,7 @@ def sample_requests_from_dataset( assert min_len >= 0 and max_len >= min_len, "input_length_range too small" # Filter out sequences that are too long or too short - filtered_requests: List[Request] = [] + filtered_requests: list[Request] = [] for i in range(len(dataset)): if len(filtered_requests) == num_requests: @@ -122,10 +122,10 @@ def sample_requests_from_dataset( def sample_requests_from_random( num_requests: int, tokenizer: PreTrainedTokenizerBase, - input_length_range: Tuple[int, int], + input_length_range: tuple[int, int], fixed_output_len: Optional[int], prefix_len: int, -) -> List[Request]: +) -> list[Request]: requests = [] prefix_token_ids = sample_tokens(tokenizer, prefix_len) @@ -144,9 +144,9 @@ def sample_requests_from_random( return requests -def repeat_and_sort_requests(requests: List[Request], +def repeat_and_sort_requests(requests: list[Request], repeat_count: int, - sort: bool = False) -> List[str]: + sort: bool = False) -> list[str]: repeated_requests = requests * repeat_count if sort: repeated_requests.sort(key=lambda x: x[1]) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index 24014e5b6c37..43b2c1b03323 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -5,7 +5,7 @@ import json import random import time -from typing import List, Optional, Tuple +from typing import Optional from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -23,7 +23,7 @@ def sample_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int], -) -> List[Tuple[str, int, int]]: +) -> list[tuple[str, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -40,7 +40,7 @@ def sample_requests( random.shuffle(dataset) # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] + filtered_dataset: list[tuple[str, int, int]] = [] for i in range(len(dataset)): if len(filtered_dataset) == num_requests: break @@ -68,7 +68,7 @@ def sample_requests( def run_vllm( - requests: List[Tuple[str, int, int]], + requests: list[tuple[str, int, int]], n: int, engine_args: EngineArgs, ) -> float: diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 1bb83b082beb..16ec0a4817a2 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -33,9 +33,10 @@ import random import time import warnings +from collections.abc import AsyncGenerator, Collection from dataclasses import dataclass from datetime import datetime -from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple +from typing import Any, Optional import numpy as np import pandas as pd @@ -73,22 +74,22 @@ class BenchmarkMetrics: mean_ttft_ms: float median_ttft_ms: float std_ttft_ms: float - percentiles_ttft_ms: List[Tuple[float, float]] + percentiles_ttft_ms: list[tuple[float, float]] mean_tpot_ms: float median_tpot_ms: float std_tpot_ms: float - percentiles_tpot_ms: List[Tuple[float, float]] + percentiles_tpot_ms: list[tuple[float, float]] mean_itl_ms: float median_itl_ms: float std_itl_ms: float - percentiles_itl_ms: List[Tuple[float, float]] + percentiles_itl_ms: list[tuple[float, float]] # E2EL stands for end-to-end latency per request. # It is the time taken on the client side from sending # a request to receiving a complete response. mean_e2el_ms: float median_e2el_ms: float std_e2el_ms: float - percentiles_e2el_ms: List[Tuple[float, float]] + percentiles_e2el_ms: list[tuple[float, float]] def sample_sharegpt_requests( @@ -96,7 +97,7 @@ def sample_sharegpt_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, -) -> List[Tuple[str, int, int, None]]: +) -> list[tuple[str, int, int, None]]: # Load the dataset. with open(dataset_path, encoding='utf-8') as f: dataset = json.load(f) @@ -110,7 +111,7 @@ def sample_sharegpt_requests( random.shuffle(dataset) # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] + filtered_dataset: list[tuple[str, int, int]] = [] for i in range(len(dataset)): if len(filtered_dataset) == num_requests: break @@ -139,7 +140,7 @@ def sample_burstgpt_requests( num_requests: int, random_seed: int, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, int, int, None]]: +) -> list[tuple[str, int, int, None]]: df = pd.read_csv(dataset_path) gpt4_df = df[df["Model"] == "GPT-4"] # Remove the failed requests (i.e., response length is 0) @@ -170,7 +171,7 @@ def sample_sonnet_requests( output_len: int, prefix_len: int, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, str, int, int, None]]: +) -> list[tuple[str, str, int, int, None]]: assert ( input_len > prefix_len ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." @@ -211,7 +212,7 @@ def sample_sonnet_requests( prefix_lines = poem_lines[:num_prefix_lines] # Sample the rest of lines per request. - sampled_requests: List[Tuple[str, int, int]] = [] + sampled_requests: list[tuple[str, int, int]] = [] for _ in range(num_requests): num_lines_needed = num_input_lines - num_prefix_lines sampled_lines = "".join(prefix_lines + @@ -238,8 +239,8 @@ def sample_vision_arena_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, -) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: - sampled_requests: List[Tuple[str, int, int, Dict[str, +) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]: + sampled_requests: list[tuple[str, int, int, dict[str, Collection[str]]]] = [] for data in dataset: if len(sampled_requests) == num_requests: @@ -285,7 +286,7 @@ def sample_hf_requests( tokenizer: PreTrainedTokenizerBase, random_seed: int, fixed_output_len: Optional[int] = None, -) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: +) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]: # Special case for vision_arena dataset if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \ @@ -307,7 +308,7 @@ def sample_hf_requests( "HF Dataset must have 'conversations' column.") filter_func = lambda x: len(x["conversations"]) >= 2 filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func) - sampled_requests: List[Tuple[str, int, int, Dict[str, + sampled_requests: list[tuple[str, int, int, dict[str, Collection[str]]]] = [] for data in filtered_dataset: if len(sampled_requests) == num_requests: @@ -370,7 +371,7 @@ def sample_random_requests( num_prompts: int, range_ratio: float, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, int, int]]: +) -> list[tuple[str, int, int]]: prefix_token_ids = np.random.randint(0, tokenizer.vocab_size, size=prefix_len).tolist() @@ -399,10 +400,10 @@ def sample_random_requests( async def get_request( - input_requests: List[Tuple[str, int, int]], + input_requests: list[tuple[str, int, int]], request_rate: float, burstiness: float = 1.0, -) -> AsyncGenerator[Tuple[str, int, int], None]: +) -> AsyncGenerator[tuple[str, int, int], None]: """ Asynchronously generates requests at a specified rate with OPTIONAL burstiness. @@ -443,23 +444,23 @@ async def get_request( def calculate_metrics( - input_requests: List[Tuple[str, int, int]], - outputs: List[RequestFuncOutput], + input_requests: list[tuple[str, int, int]], + outputs: list[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, - selected_percentile_metrics: List[str], - selected_percentiles: List[float], - goodput_config_dict: Dict[str, float], -) -> Tuple[BenchmarkMetrics, List[int]]: - actual_output_lens: List[int] = [] + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + goodput_config_dict: dict[str, float], +) -> tuple[BenchmarkMetrics, list[int]]: + actual_output_lens: list[int] = [] total_input = 0 completed = 0 good_completed = 0 - itls: List[float] = [] - tpots: List[float] = [] - all_tpots: List[float] = [] - ttfts: List[float] = [] - e2els: List[float] = [] + itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + e2els: list[float] = [] for i in range(len(outputs)): if outputs[i].success: output_len = outputs[i].output_tokens @@ -557,19 +558,19 @@ async def benchmark( model_id: str, model_name: str, tokenizer: PreTrainedTokenizerBase, - input_requests: List[Tuple[str, int, int]], + input_requests: list[tuple[str, int, int]], logprobs: Optional[int], best_of: int, request_rate: float, burstiness: float, disable_tqdm: bool, profile: bool, - selected_percentile_metrics: List[str], - selected_percentiles: List[str], + selected_percentile_metrics: list[str], + selected_percentiles: list[str], ignore_eos: bool, - goodput_config_dict: Dict[str, float], + goodput_config_dict: dict[str, float], max_concurrency: Optional[int], - lora_modules: Optional[List[str]], + lora_modules: Optional[list[str]], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -652,7 +653,7 @@ async def limited_request_func(request_func_input, pbar): pbar=pbar) benchmark_start_time = time.perf_counter() - tasks: List[asyncio.Task] = [] + tasks: list[asyncio.Task] = [] async for request in get_request(input_requests, request_rate, burstiness): prompt, prompt_len, output_len, mm_content = request req_model_id, req_model_name = model_id, model_name @@ -674,7 +675,7 @@ async def limited_request_func(request_func_input, pbar): asyncio.create_task( limited_request_func(request_func_input=request_func_input, pbar=pbar))) - outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: print("Stopping profiler...") @@ -820,7 +821,7 @@ def parse_goodput(slo_pairs): def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: Dict[str, Any], + results: dict[str, Any], file_name: str) -> None: metrics = [ "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", @@ -974,7 +975,7 @@ def main(args: argparse.Namespace): # Save config and results to json if args.save_result: - result_json: Dict[str, Any] = {} + result_json: dict[str, Any] = {} # Setup current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") diff --git a/benchmarks/benchmark_serving_guided.py b/benchmarks/benchmark_serving_guided.py index 05eadff79787..6c132d05f1b6 100644 --- a/benchmarks/benchmark_serving_guided.py +++ b/benchmarks/benchmark_serving_guided.py @@ -30,8 +30,9 @@ import random import time import warnings +from collections.abc import AsyncGenerator from dataclasses import dataclass -from typing import AsyncGenerator, Dict, List, Optional, Tuple +from typing import Optional import datasets import numpy as np @@ -66,22 +67,22 @@ class BenchmarkMetrics: mean_ttft_ms: float median_ttft_ms: float std_ttft_ms: float - percentiles_ttft_ms: List[Tuple[float, float]] + percentiles_ttft_ms: list[tuple[float, float]] mean_tpot_ms: float median_tpot_ms: float std_tpot_ms: float - percentiles_tpot_ms: List[Tuple[float, float]] + percentiles_tpot_ms: list[tuple[float, float]] mean_itl_ms: float median_itl_ms: float std_itl_ms: float - percentiles_itl_ms: List[Tuple[float, float]] + percentiles_itl_ms: list[tuple[float, float]] # E2EL stands for end-to-end latency per request. # It is the time taken on the client side from sending # a request to receiving a complete response. mean_e2el_ms: float median_e2el_ms: float std_e2el_ms: float - percentiles_e2el_ms: List[Tuple[float, float]] + percentiles_e2el_ms: list[tuple[float, float]] @dataclasses.dataclass @@ -104,7 +105,7 @@ class SampleRequest: def sample_requests(tokenizer: PreTrainedTokenizerBase, - args: argparse.Namespace) -> List[SampleRequest]: + args: argparse.Namespace) -> list[SampleRequest]: if args.dataset == 'json': if args.json_schema_path is None: dir_path = os.path.dirname(os.path.realpath(__file__)) @@ -187,7 +188,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ] elif args.dataset == "xgrammar_bench": - requests: List[SampleRequest] = [] + requests: list[SampleRequest] = [] dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train") print(f"dataset has {len(dataset)} entries") @@ -214,10 +215,10 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, async def get_request( - input_requests: List[SampleRequest], + input_requests: list[SampleRequest], request_rate: float, burstiness: float = 1.0, -) -> AsyncGenerator[Tuple[int, SampleRequest], None]: +) -> AsyncGenerator[tuple[int, SampleRequest], None]: """ Asynchronously generates requests at a specified rate with OPTIONAL burstiness. @@ -258,23 +259,23 @@ async def get_request( def calculate_metrics( - input_requests: List[Tuple[str, int, int]], - outputs: List[RequestFuncOutput], + input_requests: list[tuple[str, int, int]], + outputs: list[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, - selected_percentile_metrics: List[str], - selected_percentiles: List[float], - goodput_config_dict: Optional[Dict[str, float]] = None, -) -> Tuple[BenchmarkMetrics, List[int]]: - actual_output_lens: List[int] = [] + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + goodput_config_dict: Optional[dict[str, float]] = None, +) -> tuple[BenchmarkMetrics, list[int]]: + actual_output_lens: list[int] = [] total_input = 0 completed = 0 good_completed = 0 - itls: List[float] = [] - tpots: List[float] = [] - all_tpots: List[float] = [] - ttfts: List[float] = [] - e2els: List[float] = [] + itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + e2els: list[float] = [] for i in range(len(outputs)): if outputs[i].success: # We use the tokenizer to count the number of output tokens for all @@ -368,18 +369,18 @@ async def benchmark( base_url: str, model_id: str, tokenizer: PreTrainedTokenizerBase, - input_requests: List[SampleRequest], + input_requests: list[SampleRequest], request_rate: float, burstiness: float, disable_tqdm: bool, profile: bool, - selected_percentile_metrics: List[str], - selected_percentiles: List[str], + selected_percentile_metrics: list[str], + selected_percentiles: list[str], ignore_eos: bool, max_concurrency: Optional[int], guided_decoding_ratio: float, guided_decoding_backend: str, - goodput_config_dict: Optional[Dict[str, float]] = None, + goodput_config_dict: Optional[dict[str, float]] = None, ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -459,8 +460,8 @@ async def limited_request_func(request_func_input, pbar): pbar=pbar) benchmark_start_time = time.perf_counter() - tasks: List[asyncio.Task] = [] - expected: List[str] = [] + tasks: list[asyncio.Task] = [] + expected: list[str] = [] async for i, request in get_request(input_requests, request_rate, burstiness): extra_body = prepare_extra_body( @@ -479,7 +480,7 @@ async def limited_request_func(request_func_input, pbar): asyncio.create_task( limited_request_func(request_func_input=request_func_input, pbar=pbar))) - outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: print("Stopping profiler...") diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 04de08fa97c9..aabce64ff776 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -7,7 +7,7 @@ import random import time from functools import cache -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import torch import uvloop @@ -74,12 +74,12 @@ def lora_path_on_disk(lora_path: str) -> str: return get_adapter_absolute_path(lora_path) -lora_tokenizer_cache: Dict[int, AnyTokenizer] = {} +lora_tokenizer_cache: dict[int, AnyTokenizer] = {} def get_random_lora_request( args: argparse.Namespace -) -> Tuple[LoRARequest, Optional[AnyTokenizer]]: +) -> tuple[LoRARequest, Optional[AnyTokenizer]]: global lora_tokenizer_cache lora_id = random.randint(1, args.max_loras) lora_request = LoRARequest(lora_name=str(lora_id), @@ -91,7 +91,7 @@ def get_random_lora_request( def sample_requests(tokenizer: PreTrainedTokenizerBase, - args: argparse.Namespace) -> List[SampleRequest]: + args: argparse.Namespace) -> list[SampleRequest]: dataset_path: str = args.dataset num_requests: int = args.num_prompts @@ -109,7 +109,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, random.shuffle(dataset) # Filter out sequences that are too long or too short - filtered_dataset: List[SampleRequest] = [] + filtered_dataset: list[SampleRequest] = [] for data in tqdm(dataset, total=len(filtered_dataset), desc="sampling requests"): @@ -165,7 +165,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, def run_vllm( - requests: List[SampleRequest], + requests: list[SampleRequest], n: int, engine_args: EngineArgs, ) -> float: @@ -178,8 +178,8 @@ def run_vllm( "Please ensure that max_model_len is greater than the sum of" " prompt_len and expected_output_len for all requests.") # Add the requests to the engine. - prompts: List[TextPrompt] = [] - sampling_params: List[SamplingParams] = [] + prompts: list[TextPrompt] = [] + sampling_params: list[SamplingParams] = [] for request in requests: prompts.append( TextPrompt(prompt=request.prompt, @@ -192,7 +192,7 @@ def run_vllm( ignore_eos=True, max_tokens=request.expected_output_len, )) - lora_requests: Optional[List[LoRARequest]] = None + lora_requests: Optional[list[LoRARequest]] = None if engine_args.enable_lora: lora_requests = [request.lora_request for request in requests] @@ -225,7 +225,7 @@ def run_vllm( async def run_vllm_async( - requests: List[SampleRequest], + requests: list[SampleRequest], n: int, engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, @@ -242,9 +242,9 @@ async def run_vllm_async( " prompt_len and expected_output_len for all requests.") # Add the requests to the engine. - prompts: List[TextPrompt] = [] - sampling_params: List[SamplingParams] = [] - lora_requests: List[Optional[LoRARequest]] = [] + prompts: list[TextPrompt] = [] + sampling_params: list[SamplingParams] = [] + lora_requests: list[Optional[LoRARequest]] = [] for request in requests: prompts.append( TextPrompt(prompt=request.prompt, @@ -276,7 +276,7 @@ async def run_vllm_async( def run_hf( - requests: List[SampleRequest], + requests: list[SampleRequest], model: str, tokenizer: PreTrainedTokenizerBase, n: int, @@ -292,7 +292,7 @@ def run_hf( pbar = tqdm(total=len(requests)) start = time.perf_counter() - batch: List[str] = [] + batch: list[str] = [] max_prompt_len = 0 max_output_len = 0 for i in range(len(requests)): @@ -334,7 +334,7 @@ def run_hf( def run_mii( - requests: List[SampleRequest], + requests: list[SampleRequest], model: str, tensor_parallel_size: int, output_len: int, @@ -352,7 +352,7 @@ def run_mii( def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: Dict[str, Any]) -> None: + results: dict[str, Any]) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={ @@ -479,8 +479,8 @@ def main(args: argparse.Namespace): type=str, default=None, help="Path to the dataset. The dataset is expected to " - "be a json in form of List[Dict[..., conversations: " - "List[Dict[..., value: ]]]]") + "be a json in form of list[dict[..., conversations: " + "list[dict[..., value: ]]]]") parser.add_argument("--input-len", type=int, default=None, diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py index ac0688ca013f..45a0ddbd5d08 100644 --- a/benchmarks/benchmark_utils.py +++ b/benchmarks/benchmark_utils.py @@ -4,12 +4,12 @@ import json import math import os -from typing import Any, Dict, List +from typing import Any def convert_to_pytorch_benchmark_format(args: argparse.Namespace, - metrics: Dict[str, List], - extra_info: Dict[str, Any]) -> List: + metrics: dict[str, list], + extra_info: dict[str, Any]) -> list: """ Save the benchmark results in the format used by PyTorch OSS benchmark with on metric per record @@ -64,6 +64,6 @@ def iterencode(self, o: Any, *args, **kwargs) -> Any: return super().iterencode(self.clear_inf(o), *args, **kwargs) -def write_to_json(filename: str, records: List) -> None: +def write_to_json(filename: str, records: list) -> None: with open(filename, "w") as f: json.dump(records, f, cls=InfEncoder) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py index 468a1b2868f0..9e36b0a9d3bb 100644 --- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -5,7 +5,8 @@ import itertools import pickle as pkl import time -from typing import Callable, Iterable, List, Tuple +from collections.abc import Iterable +from typing import Callable import torch import torch.utils.benchmark as TBenchmark @@ -228,7 +229,7 @@ def print_timers(timers: Iterable[TMeasurement]): def run(dtype: torch.dtype, - MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", @@ -241,7 +242,7 @@ def run(dtype: torch.dtype, # output makers def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[Tuple[int, int, int]], + MKNs: Iterable[tuple[int, int, int]], base_description: str, timestamp=None): print(f"== All Results {base_description} ====") @@ -282,7 +283,7 @@ def run_model_bench(args): for i, model in enumerate(args.models): print(f"[{i}] {model}") - def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: KNs = [] for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): KN[tp_split_dim] = KN[tp_split_dim] // tp_size diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py index bab377800729..fe4d8fdfc066 100644 --- a/benchmarks/cutlass_benchmarks/utils.py +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Cutlass bench utils -from typing import Iterable, Tuple +from collections.abc import Iterable import torch @@ -27,7 +27,7 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor: def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.Tensor, torch.Tensor]: + k: int) -> tuple[torch.Tensor, torch.Tensor]: a = torch.randn((m, k), device='cuda') * 5 b = torch.randn((n, k), device='cuda').t() * 5 @@ -63,7 +63,7 @@ def prune_to_2_4(tensor): def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.Tensor, torch.Tensor]: + k: int) -> tuple[torch.Tensor, torch.Tensor]: a = torch.randn((m, k), device='cuda') * 5 b = torch.randn((n, k), device='cuda').t() * 5 @@ -88,7 +88,7 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int) -> \ - Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: + tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: ABs = [] for _ in range(num_tensors): b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 6552b62dae88..e7b742d8bec9 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -5,7 +5,8 @@ import itertools import pickle as pkl import time -from typing import Callable, Iterable, List, Optional, Tuple +from collections.abc import Iterable +from typing import Callable, Optional import torch import torch.utils.benchmark as TBenchmark @@ -49,7 +50,7 @@ def bench_int8( n: int, label: str, sub_label: str, - bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: + bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: """Benchmark INT8-based kernels.""" assert dtype == torch.int8 a, b = make_rand_tensors(torch.int8, m, n, k) @@ -101,7 +102,7 @@ def bench_fp8( n: int, label: str, sub_label: str, - bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: + bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: """Benchmark FP8-based kernels.""" assert dtype == torch.float8_e4m3fn a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) @@ -180,7 +181,7 @@ def bench(dtype: torch.dtype, n: int, label: str, sub_label: str, - bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: + bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: if dtype == torch.int8: return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) if dtype == torch.float8_e4m3fn: @@ -195,8 +196,8 @@ def print_timers(timers: Iterable[TMeasurement]): def run(dtype: torch.dtype, - MKNs: Iterable[Tuple[int, int, int]], - bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: + MKNs: Iterable[tuple[int, int, int]], + bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: timers = bench(dtype, @@ -212,7 +213,7 @@ def run(dtype: torch.dtype, def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[Tuple[int, int, int]], + MKNs: Iterable[tuple[int, int, int]], base_description: str, timestamp=None): print(f"== All Results {base_description} ====") @@ -248,7 +249,7 @@ def run_model_bench(args): for i, model in enumerate(args.models): print(f"[{i}] {model}") - def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: KNs = [] for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): KN[tp_split_dim] = KN[tp_split_dim] // tp_size diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index c56cc743845e..3da583a33448 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -2,9 +2,10 @@ import pickle as pkl import time +from collections.abc import Iterable from dataclasses import dataclass from itertools import product -from typing import Callable, Iterable, List, Optional +from typing import Callable, Optional import torch import torch.utils.benchmark as TBenchmark @@ -29,7 +30,7 @@ def description(self): f'x DT {self.dtype}') -def get_bench_params() -> List[bench_params_t]: +def get_bench_params() -> list[bench_params_t]: ## Test Fixtures NUM_TOKENS = [2**x for x in range(11)] HIDDEN_SIZES = list(range(1024, 8129, 1024)) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 1deb0026a6e5..5eaeec017053 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -9,7 +9,7 @@ from enum import Enum, auto from itertools import product from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Optional import torch import torch.utils.benchmark as TBenchmark @@ -61,15 +61,15 @@ def make_rand_lora_weight_tensor(k: int, def make_rand_tensors( - a_shape: Tuple[int], - b_shape: Tuple[int], - c_shape: Tuple[int], + a_shape: tuple[int], + b_shape: tuple[int], + c_shape: tuple[int], a_dtype: torch.dtype, b_dtype: torch.dtype, c_dtype: torch.dtype, num_slices: int, device: str = "cuda", -) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: +) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]: """ Make LoRA input/output matrices. """ @@ -135,7 +135,7 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int, def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, - lora_weights: List[torch.Tensor], + lora_weights: list[torch.Tensor], seq_lens_cpu: torch.Tensor, prompt_lora_mapping_cpu: torch.Tensor, scaling: float, add_inputs: Optional[bool]): @@ -204,7 +204,7 @@ def is_decode_op(self) -> bool: def is_expand_slice_fn(self) -> bool: return self in [OpType.BGMV_EXPAND_SLICE] - def num_slices(self) -> List[int]: + def num_slices(self) -> list[int]: if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]: # SGMV kernels supports slices return [1, 2, 3] @@ -215,7 +215,7 @@ def num_slices(self) -> List[int]: raise ValueError(f"Unrecognized OpType {self}") def mkn(self, batch_size: int, seq_length: int, hidden_size: int, - lora_rank: int) -> Tuple[int, int, int]: + lora_rank: int) -> tuple[int, int, int]: num_tokens = batch_size * seq_length if self.is_shrink_fn(): m = num_tokens @@ -230,7 +230,7 @@ def mkn(self, batch_size: int, seq_length: int, hidden_size: int, def matmul_dtypes( self, op_dtype: torch.dtype - ) -> Tuple[torch.dtype, torch.dtype, torch.dtype]: + ) -> tuple[torch.dtype, torch.dtype, torch.dtype]: """ return a type, b type and c type for A x B = C """ @@ -243,7 +243,7 @@ def matmul_dtypes( def matmul_shapes( self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int, num_loras: int, - num_slices: int) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]: + num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]: """ Given num_slices, return the shapes of the A, B, and C matrices in A x B = C, for the op_type @@ -268,7 +268,7 @@ def matmul_shapes( def bench_fn(self) -> Callable: - def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]): + def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]): for x in kwargs_list: bgmv_expand_slice(**x) @@ -285,7 +285,7 @@ def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]): raise ValueError(f"Unrecognized optype {self}") def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor, - lora_weights: List[torch.Tensor], + lora_weights: list[torch.Tensor], **kwargs) -> Callable: """Each benchmark operation expected the input, lora_weights and outputs in a slightly different format. Refer to self.matmul_shapes(). @@ -384,7 +384,7 @@ class BenchmarkTensors: """ # matmul tensors input: torch.Tensor - lora_weights_lst: List[torch.Tensor] + lora_weights_lst: list[torch.Tensor] output: torch.Tensor # metadata tensors seq_lens: torch.Tensor @@ -469,7 +469,7 @@ def to_device(tensor: torch.Tensor): for i in range(len(self.lora_weights_lst)): self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i]) - def metadata(self) -> Tuple[int, int, int]: + def metadata(self) -> tuple[int, int, int]: """ Return num_seqs, num_tokens and max_seq_len """ @@ -505,7 +505,7 @@ def convert_to_sgmv_benchmark_tensors(self): self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype) self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype) - def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]: + def as_sgmv_shrink_kwargs(self) -> dict[str, Any]: self.convert_to_sgmv_benchmark_tensors() self.sanity_check() self.to_device(self.input.device) @@ -540,7 +540,7 @@ def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]: 'scaling': 1.0, } - def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]: + def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: self.convert_to_sgmv_benchmark_tensors() self.sanity_check() @@ -578,7 +578,7 @@ def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]: 'add_inputs': add_inputs, } - def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]: + def as_bgmv_shrink_kwargs(self) -> dict[str, Any]: assert len(self.lora_weights_lst) == 1 self.to_device(self.input.device) @@ -634,7 +634,7 @@ def as_bgmv_expand_kwargs(self, add_inputs: bool): 'add_inputs': add_inputs } - def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]: + def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]: _, num_tokens, _, num_slices = self.metadata() # Sanity check shapes @@ -670,7 +670,7 @@ def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]: def bench_fn_kwargs(self, op_type: OpType, - add_inputs: Optional[bool] = None) -> Dict[str, Any]: + add_inputs: Optional[bool] = None) -> dict[str, Any]: if op_type.is_shrink_fn(): assert add_inputs is None else: @@ -734,7 +734,7 @@ def bench_optype(ctx: BenchmarkContext, assert expand_fn_add_inputs is not None # BenchmarkContext -> BenchmarkTensors - bench_tensors : List[BenchmarkTensors] = \ + bench_tensors : list[BenchmarkTensors] = \ [BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)] for bt in bench_tensors: bt.sanity_check() @@ -746,7 +746,7 @@ def bench_optype(ctx: BenchmarkContext, for bt in bench_tensors ]) - # BenchmarkTensors -> Dict (kwargs) + # BenchmarkTensors -> dict (kwargs) kwargs_list = [ bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) for bt in bench_tensors @@ -841,7 +841,7 @@ def use_cuda_graph_recommendation() -> str: """ -def print_timers(timers: List[TMeasurement], +def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None): compare = TBenchmark.Compare(timers) compare.print() @@ -861,7 +861,7 @@ def print_timers(timers: List[TMeasurement], "small num_loras the goal should be to match the torch.mm numbers.") -def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): +def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): if args.cuda_graph_nops is not None: assert args.cuda_graph_nops > 0 @@ -873,7 +873,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): timers = [] for bench_ctx in bench_ctxs: for seq_len in args.seq_lengths: - bench_ops: List[OpType] = [] + bench_ops: list[OpType] = [] if seq_len == 1: # bench all decode ops bench_ops = [op for op in args.op_types if op.is_decode_op()] @@ -921,10 +921,10 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): pickle.dump(timers, f) -def as_benchmark_contexts(hidden_sizes: List[int], lora_ranks: List[int], - args: argparse.Namespace) -> List[BenchmarkContext]: +def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int], + args: argparse.Namespace) -> list[BenchmarkContext]: - ctxs: List[BenchmarkContext] = [] + ctxs: list[BenchmarkContext] = [] for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, args.sort_by_lora_id): @@ -954,7 +954,7 @@ def run_list_bench(args: argparse.Namespace): f" LoRA Ranks {args.lora_ranks}") # Get all benchmarking contexts - bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( + bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args) run(args, bench_contexts) @@ -975,7 +975,7 @@ def run_range_bench(args: argparse.Namespace): f" LoRA Ranks {lora_ranks}") # Get all benchmarking contexts - bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( + bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args) run(args, bench_contexts) @@ -1002,7 +1002,7 @@ def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]: f" LoRA Ranks {args.lora_ranks}") # Get all benchmarking contexts - bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( + bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args) run(args, bench_contexts) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 0301fee1a886..3fa57bd7b233 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -7,9 +7,10 @@ import os import pickle as pkl import time +from collections.abc import Iterable from dataclasses import dataclass from itertools import product -from typing import Callable, Iterable, List, Optional, Tuple +from typing import Callable, Optional import pandas as pd import torch @@ -102,8 +103,8 @@ def quantize_and_pack(atype: torch.dtype, return w_ref, w_q, w_s, w_zp -def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig, - group_size: Optional[int]) -> List[BenchmarkTensors]: +def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, + group_size: Optional[int]) -> list[BenchmarkTensors]: m, n, k = shape # we want to make sure that weights don't fit into L2 cache between runs so @@ -114,7 +115,7 @@ def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig, a = rand_data((m, k), types.act_type, scale=5) - benchmark_tensors: List[BenchmarkTensors] = [] + benchmark_tensors: list[BenchmarkTensors] = [] for _ in range(num_weights): w = rand_data((k, n), types.act_type, scale=5) @@ -276,7 +277,7 @@ def machete_create_bench_fn(bt: BenchmarkTensors, def bench_fns(label: str, sub_label: str, description: str, - fns: List[Callable]): + fns: list[Callable]): min_run_time = 1 if not NVTX_PROFILE else 0.1 res = TBenchmark.Timer( @@ -311,7 +312,7 @@ def bench(types: TypeConfig, n: int, label: str, sub_label: str, - sweep_schedules: bool = True) -> List[TMeasurement]: + sweep_schedules: bool = True) -> list[TMeasurement]: benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) sub_label += f", L={len(benchmark_tensors)}" @@ -414,12 +415,12 @@ def bench(types: TypeConfig, # runner -def print_timers(timers: List[TMeasurement]): +def print_timers(timers: list[TMeasurement]): compare = TBenchmark.Compare(timers) compare.print() -def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: +def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: types = TypeConfig( act_type=args.act_type, weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ @@ -431,7 +432,7 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: token_scale_type=args.token_scale_type, ) - results: List[TMeasurement] = [] + results: list[TMeasurement] = [] for m, k, n in MKNs: timers = bench(types, args.group_size, @@ -449,8 +450,8 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: # output makers def make_output( - data: List[TMeasurement], - MKNs: Iterable[Tuple[int, int, int]], + data: list[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], base_description: str, timestamp=None, ): @@ -497,7 +498,7 @@ def run_model_bench(args): for i, model in enumerate(args.models): print(f"[{i}] {model}") - def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: KNs = [] for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): KN[tp_split_dim] = KN[tp_split_dim] // tp_size diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 21ef491294e3..1e785ac8fc73 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import torch import torch.utils.benchmark as benchmark from benchmark_shapes import WEIGHT_SHAPES @@ -31,7 +29,7 @@ K_FULL_OPTS = [False, True] -def bench_run(results: List[benchmark.Measurement], model: str, +def bench_run(results: list[benchmark.Measurement], model: str, act_order: bool, is_k_full: bool, quant_type: ScalarType, group_size: int, size_m: int, size_k: int, size_n: int): label = "Quant Matmul" @@ -221,7 +219,7 @@ def main(args): for i, model in enumerate(args.models): print(f"[{i}] {model}") - results: List[benchmark.Measurement] = [] + results: list[benchmark.Measurement] = [] for model in args.models: for layer in WEIGHT_SHAPES[model]: diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 410750686ee1..c862dec81fcc 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -4,7 +4,7 @@ import time from datetime import datetime from itertools import product -from typing import Any, Dict, List, Tuple, TypedDict +from typing import Any, TypedDict import ray import torch @@ -132,7 +132,7 @@ def run(): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - latencies: List[float] = [] + latencies: list[float] = [] for i in range(num_iters): prepare(i) torch.cuda.synchronize() @@ -175,8 +175,8 @@ def get_rocm_tuning_space(use_fp16): return param_ranges -def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]: - configs: List[BenchmarkConfig] = [] +def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]: + configs: list[BenchmarkConfig] = [] if current_platform.is_rocm(): param_ranges = get_rocm_tuning_space(use_fp16) @@ -335,7 +335,7 @@ def benchmark( dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, - ) -> Tuple[Dict[str, int], float]: + ) -> tuple[dict[str, int], float]: current_platform.seed_everything(self.seed) dtype_str = get_config_dtype_str(dtype, use_int8_w8a16=use_int8_w8a16, @@ -371,8 +371,8 @@ def tune( dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, - search_space: List[Dict[str, int]], - ) -> Dict[str, int]: + search_space: list[dict[str, int]], + ) -> dict[str, int]: best_config = None best_time = float("inf") if current_platform.is_rocm(): @@ -434,7 +434,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: } -def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int, +def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, shard_intermediate_size: int, hidden_size: int, topk: int, dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None: @@ -498,7 +498,7 @@ def main(args: argparse.Namespace): num_gpus = int(ray.available_resources()["GPU"]) workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] - def _distribute(method: str, inputs: List[Any]) -> List[Any]: + def _distribute(method: str, inputs: list[Any]) -> list[Any]: outputs = [] worker_idx = 0 for input_args in inputs: diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index daedaadb1a77..d00e84824361 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -2,7 +2,7 @@ import random import time -from typing import List, Optional +from typing import Optional import torch @@ -54,7 +54,7 @@ def main( # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables_lst: List[List[int]] = [] + block_tables_lst: list[list[int]] = [] for _ in range(num_seqs): block_table = [ random.randint(0, NUM_BLOCKS - 1) diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py index dba153742da4..010a38b75271 100644 --- a/benchmarks/kernels/benchmark_rmsnorm.py +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import itertools -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch import triton @@ -22,7 +22,7 @@ def forward( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: orig_dtype = x.dtype x = x.to(torch.float32) if residual is not None: diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 8ee0212a0c11..05d24fc4b16d 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from itertools import accumulate -from typing import List, Optional +from typing import Optional import nvtx import torch @@ -39,7 +39,7 @@ def benchmark_rope_kernels_multi_lora( }) # non-batched RoPE takes only one scaling factor, we create multiple # instances to simulate the same behavior - non_batched_ropes: List[RotaryEmbedding] = [] + non_batched_ropes: list[RotaryEmbedding] = [] for scaling_factor in scaling_factors: non_batched_ropes.append( get_rope(head_size, rotary_dim, max_position, base, is_neox_style, diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index 01d97d63d7cf..bd62173a7b3a 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -4,7 +4,6 @@ import pickle import re from collections import defaultdict -from typing import List import matplotlib.pyplot as plt import pandas as pd @@ -23,7 +22,7 @@ with open(args.filename, 'rb') as f: data = pickle.load(f) - raw_results: List[TMeasurement] = data["results"] + raw_results: list[TMeasurement] = data["results"] results = defaultdict(lambda: list()) for v in raw_results: diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py index 728170748492..ac64f786f184 100644 --- a/benchmarks/kernels/utils.py +++ b/benchmarks/kernels/utils.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import dataclasses -from typing import Any, Callable, Iterable, Optional +from collections.abc import Iterable +from typing import Any, Callable, Optional import torch import torch.utils.benchmark as TBenchmark diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index d5a5e2ef83dd..d64f0d0a5c2a 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import enum -from typing import Dict, Union +from typing import Union from cutlass_library import * @@ -21,7 +21,7 @@ class MixedInputKernelScheduleType(enum.Enum): TmaWarpSpecializedCooperative = enum_auto() -VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { +VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = { **DataTypeNames, # type: ignore **{ VLLMDataType.u4b8: "u4b8", @@ -29,7 +29,7 @@ class MixedInputKernelScheduleType(enum.Enum): } } -VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { +VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { **DataTypeTag, # type: ignore **{ VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", @@ -37,7 +37,7 @@ class MixedInputKernelScheduleType(enum.Enum): } } -VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = { +VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = { **DataTypeSize, # type: ignore **{ VLLMDataType.u4b8: 4, @@ -45,7 +45,7 @@ class MixedInputKernelScheduleType(enum.Enum): } } -VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = { +VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = { VLLMDataType.u4b8: "vllm::kU4B8", VLLMDataType.u8b128: "vllm::kU8B128", DataType.u4: "vllm::kU4", @@ -56,7 +56,7 @@ class MixedInputKernelScheduleType(enum.Enum): DataType.bf16: "vllm::kBfloat16", } -VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { +VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { DataType.u8: "at::ScalarType::Byte", DataType.s8: "at::ScalarType::Char", DataType.e4m3: "at::ScalarType::Float8_e4m3fn", @@ -66,7 +66,7 @@ class MixedInputKernelScheduleType(enum.Enum): DataType.f32: "at::ScalarType::Float", } -VLLMKernelScheduleTag: Dict[Union[ +VLLMKernelScheduleTag: dict[Union[ MixedInputKernelScheduleType, KernelScheduleType], str] = { **KernelScheduleTag, # type: ignore **{ diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 02e59fe28b9a..3114e14baa0c 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -8,7 +8,7 @@ from copy import deepcopy from dataclasses import dataclass, fields from functools import reduce -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import jinja2 # yapf conflicts with isort for this block @@ -247,8 +247,8 @@ @dataclass(frozen=True) class ScheduleConfig: - tile_shape_mn: Tuple[int, int] - cluster_shape_mnk: Tuple[int, int, int] + tile_shape_mn: tuple[int, int] + cluster_shape_mnk: tuple[int, int, int] kernel_schedule: MixedInputKernelScheduleType epilogue_schedule: EpilogueScheduleType tile_scheduler: TileSchedulerType @@ -277,8 +277,8 @@ class PrepackTypeConfig: @dataclass class ImplConfig: types: TypeConfig - schedules: List[ScheduleConfig] - heuristic: List[Tuple[Optional[str], ScheduleConfig]] + schedules: list[ScheduleConfig] + heuristic: list[tuple[Optional[str], ScheduleConfig]] def generate_sch_sig(schedule_config: ScheduleConfig) -> str: @@ -333,7 +333,7 @@ def is_power_of_two(n): return (n != 0) and (n & (n - 1) == 0) -def to_cute_constant(value: List[int]): +def to_cute_constant(value: list[int]): def _to_cute_constant(value: int): if is_power_of_two(value): @@ -347,7 +347,7 @@ def _to_cute_constant(value: int): return _to_cute_constant(value) -def unique_schedules(impl_configs: List[ImplConfig]): +def unique_schedules(impl_configs: list[ImplConfig]): return list( set(sch for impl_config in impl_configs for sch in impl_config.schedules)) @@ -391,7 +391,7 @@ def create_template(template_str): prepack_dispatch_template = create_template(PREPACK_TEMPLATE) -def create_sources(impl_configs: List[ImplConfig], num_impl_files=8): +def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): sources = [] sources.append(( @@ -435,7 +435,7 @@ def prepacked_type_key(prepack_type: PrepackTypeConfig): num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) num_impls_per_file = math.ceil(num_impls / num_impl_files) - files_impls: List[List[ImplConfig]] = [[]] + files_impls: list[list[ImplConfig]] = [[]] curr_num_impls_assigned = 0 curr_impl_in_file = 0 @@ -515,7 +515,7 @@ def generate(): for cond, tile_config in default_tile_heuristic_config.items() ] - def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]): + def get_unique_schedules(heuristic: dict[str, ScheduleConfig]): # Do not use schedules = list(set(...)) because we need to make sure # the output list is deterministic; otherwise the generated kernel file # will be non-deterministic and causes ccache miss. diff --git a/docs/source/conf.py b/docs/source/conf.py index 97bec81b1eee..b72faef9af10 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -17,7 +17,6 @@ import logging import os import sys -from typing import List import requests from sphinx.ext import autodoc @@ -58,7 +57,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns: List[str] = ["**/*.template.md", "**/*.inc.md"] +exclude_patterns: list[str] = ["**/*.template.md", "**/*.inc.md"] # Exclude the prompt "$" when copying code copybutton_prompt_text = r"\$ " diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 5c0c1762f8aa..230e461f69f4 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -123,7 +123,7 @@ class ExampleParser(ReasoningParser): def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest - ) -> Tuple[Optional[str], Optional[str]]: + ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from a complete model-generated string. @@ -138,7 +138,7 @@ class ExampleParser(ReasoningParser): The request object that was used to generate the model_output. Returns: - Tuple[Optional[str], Optional[str]] + tuple[Optional[str], Optional[str]] A tuple containing the reasoning content and the content. """ ``` diff --git a/docs/source/features/structured_outputs.md b/docs/source/features/structured_outputs.md index 1d5aa07ab177..de3c5bf5e7ab 100644 --- a/docs/source/features/structured_outputs.md +++ b/docs/source/features/structured_outputs.md @@ -193,7 +193,7 @@ class Step(BaseModel): class MathResponse(BaseModel): - steps: List[Step] + steps: list[Step] final_answer: str diff --git a/docs/source/generate_examples.py b/docs/source/generate_examples.py index c5f75953aaf2..c51ca18667ef 100644 --- a/docs/source/generate_examples.py +++ b/docs/source/generate_examples.py @@ -74,7 +74,7 @@ class Example: path (Path): The path to the main directory or file. category (str): The category of the document. main_file (Path): The main file in the directory. - other_files (list[Path]): List of other files in the directory. + other_files (list[Path]): list of other files in the directory. title (str): The title of the document. Methods: diff --git a/examples/offline_inference/distributed.py b/examples/offline_inference/distributed.py index a2df41d4ce21..e890c6dad8bd 100644 --- a/examples/offline_inference/distributed.py +++ b/examples/offline_inference/distributed.py @@ -6,7 +6,7 @@ Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html """ -from typing import Any, Dict, List +from typing import Any import numpy as np import ray @@ -36,13 +36,13 @@ def __init__(self): self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", tensor_parallel_size=tensor_parallel_size) - def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]: + def __call__(self, batch: dict[str, np.ndarray]) -> dict[str, list]: # Generate texts from the prompts. # The output is a list of RequestOutput objects that contain the prompt, # generated text, and other information. outputs = self.llm.generate(batch["text"], sampling_params) - prompt: List[str] = [] - generated_text: List[str] = [] + prompt: list[str] = [] + generated_text: list[str] = [] for output in outputs: prompt.append(output.prompt) generated_text.append(' '.join([o.text for o in output.outputs])) @@ -72,7 +72,7 @@ def scheduling_strategy_fn(): pg, placement_group_capture_child_tasks=True)) -resources_kwarg: Dict[str, Any] = {} +resources_kwarg: dict[str, Any] = {} if tensor_parallel_size == 1: # For tensor_parallel_size == 1, we simply set num_gpus=1. resources_kwarg["num_gpus"] = 1 diff --git a/examples/offline_inference/llm_engine_example.py b/examples/offline_inference/llm_engine_example.py index 501034c1cc5d..f7741a372243 100644 --- a/examples/offline_inference/llm_engine_example.py +++ b/examples/offline_inference/llm_engine_example.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -from typing import List, Tuple from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams from vllm.utils import FlexibleArgumentParser -def create_test_prompts() -> List[Tuple[str, SamplingParams]]: +def create_test_prompts() -> list[tuple[str, SamplingParams]]: """Create a list of test prompts with their sampling parameters.""" return [ ("A robot may not injure a human being", @@ -24,7 +23,7 @@ def create_test_prompts() -> List[Tuple[str, SamplingParams]]: def process_requests(engine: LLMEngine, - test_prompts: List[Tuple[str, SamplingParams]]): + test_prompts: list[tuple[str, SamplingParams]]): """Continuously process a list of prompts and handle the outputs.""" request_id = 0 @@ -34,7 +33,7 @@ def process_requests(engine: LLMEngine, engine.add_request(str(request_id), prompt, sampling_params) request_id += 1 - request_outputs: List[RequestOutput] = engine.step() + request_outputs: list[RequestOutput] = engine.step() for request_output in request_outputs: if request_output.finished: diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora_with_quantization_inference.py index de0734c1aa83..a409735013f6 100644 --- a/examples/offline_inference/lora_with_quantization_inference.py +++ b/examples/offline_inference/lora_with_quantization_inference.py @@ -7,7 +7,7 @@ """ import gc -from typing import List, Optional, Tuple +from typing import Optional import torch from huggingface_hub import snapshot_download @@ -18,7 +18,7 @@ def create_test_prompts( lora_path: str -) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: +) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: return [ # this is an example of using quantization without LoRA ("My name is", @@ -49,7 +49,7 @@ def create_test_prompts( def process_requests(engine: LLMEngine, - test_prompts: List[Tuple[str, SamplingParams, + test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]]): """Continuously process a list of prompts and handle the outputs.""" request_id = 0 @@ -63,7 +63,7 @@ def process_requests(engine: LLMEngine, lora_request=lora_request) request_id += 1 - request_outputs: List[RequestOutput] = engine.step() + request_outputs: list[RequestOutput] = engine.step() for request_output in request_outputs: if request_output.finished: print("----------------------------------------------------") diff --git a/examples/offline_inference/mlpspeculator.py b/examples/offline_inference/mlpspeculator.py index f227e71ba79b..61641245de83 100644 --- a/examples/offline_inference/mlpspeculator.py +++ b/examples/offline_inference/mlpspeculator.py @@ -2,12 +2,11 @@ import gc import time -from typing import List from vllm import LLM, SamplingParams -def time_generation(llm: LLM, prompts: List[str], +def time_generation(llm: LLM, prompts: list[str], sampling_params: SamplingParams): # Generate texts from the prompts. The output is a list of RequestOutput # objects that contain the prompt, generated text, and other information. diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/multilora_inference.py index 630fd1bf8342..4b0d115e6609 100644 --- a/examples/offline_inference/multilora_inference.py +++ b/examples/offline_inference/multilora_inference.py @@ -6,7 +6,7 @@ Requires HuggingFace credentials for access to Llama2. """ -from typing import List, Optional, Tuple +from typing import Optional from huggingface_hub import snapshot_download @@ -16,7 +16,7 @@ def create_test_prompts( lora_path: str -) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: +) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: """Create a list of test prompts with their sampling parameters. 2 requests for base model, 4 requests for the LoRA. We define 2 @@ -56,7 +56,7 @@ def create_test_prompts( def process_requests(engine: LLMEngine, - test_prompts: List[Tuple[str, SamplingParams, + test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]]): """Continuously process a list of prompts and handle the outputs.""" request_id = 0 @@ -70,7 +70,7 @@ def process_requests(engine: LLMEngine, lora_request=lora_request) request_id += 1 - request_outputs: List[RequestOutput] = engine.step() + request_outputs: list[RequestOutput] = engine.step() for request_output in request_outputs: if request_output.finished: diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 298f08019004..3ae507cac5ce 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -21,7 +21,7 @@ import datetime import os import re -from typing import List, Union +from typing import Union import albumentations import numpy as np @@ -260,9 +260,9 @@ def _convert_np_uint8(float_image: torch.Tensor): def load_example( - file_paths: List[str], - mean: List[float] = None, - std: List[float] = None, + file_paths: list[str], + mean: list[float] = None, + std: list[float] = None, indices: Union[list[int], None] = None, ): """Build an input example by loading images in *file_paths*. diff --git a/examples/offline_inference/profiling.py b/examples/offline_inference/profiling.py index c2e072fdd888..ffa76b4e4f2c 100644 --- a/examples/offline_inference/profiling.py +++ b/examples/offline_inference/profiling.py @@ -5,8 +5,9 @@ import os import sys from argparse import RawTextHelpFormatter +from collections.abc import Generator from dataclasses import asdict, dataclass -from typing import Any, Dict, Generator, List, Optional, TypeAlias +from typing import Any, Optional, TypeAlias import torch import tqdm @@ -42,8 +43,8 @@ def get_dtype(dtype: str): return dtype -OutputLen_NumReqs_Map: TypeAlias = Dict[int, int] -def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \ +OutputLen_NumReqs_Map: TypeAlias = dict[int, int] +def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \ -> OutputLen_NumReqs_Map: """ Given the number of requests, batch_size, and the number of requests @@ -63,7 +64,7 @@ def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \ Args: batch_size (int): Number of requests submitted for profile. This is args.batch_size. - step_requests (List[int]): step_requests[i] is the number of requests + step_requests (list[int]): step_requests[i] is the number of requests that the ith engine step should process. Returns: @@ -114,7 +115,7 @@ def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \ return ol_nr -def determine_requests_per_step(context: ProfileContext) -> List[int]: +def determine_requests_per_step(context: ProfileContext) -> list[int]: """ Determine number of requests each engine step should process. If context.num_steps is set, then all engine steps process the @@ -130,7 +131,7 @@ def determine_requests_per_step(context: ProfileContext) -> List[int]: context: ProfileContext object. Returns: - List[int]: Number of requests to process for all engine-steps. + list[int]: Number of requests to process for all engine-steps. output[i], contains the number of requests that the ith step should process. """ @@ -170,7 +171,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], for key, value in asdict(context).items(): print(f" {key} = {value}") - requests_per_step: List[int] = determine_requests_per_step(context) + requests_per_step: list[int] = determine_requests_per_step(context) ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths( context.batch_size, requests_per_step) diff --git a/examples/offline_inference/profiling_tpu/profiling.py b/examples/offline_inference/profiling_tpu/profiling.py index d54117d6262a..61da4705e18e 100644 --- a/examples/offline_inference/profiling_tpu/profiling.py +++ b/examples/offline_inference/profiling_tpu/profiling.py @@ -4,7 +4,6 @@ import dataclasses import os import time -from typing import List import numpy as np import torch_xla.debug.profiler as xp @@ -35,7 +34,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompts: List[PromptType] = [{ + dummy_prompts: list[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 872c9481a229..b1aec33cff46 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -5,7 +5,7 @@ using the chat template defined by the model. """ from argparse import Namespace -from typing import List, NamedTuple, Optional +from typing import NamedTuple, Optional from PIL.Image import Image from transformers import AutoProcessor, AutoTokenizer @@ -24,8 +24,8 @@ class ModelRequestData(NamedTuple): llm: LLM prompt: str - stop_token_ids: Optional[List[int]] - image_data: List[Image] + stop_token_ids: Optional[list[int]] + image_data: list[Image] chat_template: Optional[str] @@ -34,7 +34,7 @@ class ModelRequestData(NamedTuple): # Unless specified, these settings have been tested to work on a single L4. -def load_aria(question, image_urls: List[str]) -> ModelRequestData: +def load_aria(question, image_urls: list[str]) -> ModelRequestData: model_name = "rhymes-ai/Aria" llm = LLM(model=model_name, tokenizer_mode="slow", @@ -55,7 +55,7 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData: ) -def load_deepseek_vl2(question: str, image_urls: List[str]): +def load_deepseek_vl2(question: str, image_urls: list[str]): model_name = "deepseek-ai/deepseek-vl2-tiny" llm = LLM(model=model_name, @@ -77,7 +77,7 @@ def load_deepseek_vl2(question: str, image_urls: List[str]): ) -def load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData: +def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "h2oai/h2ovl-mississippi-800m" llm = LLM( @@ -111,7 +111,7 @@ def load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData: ) -def load_idefics3(question, image_urls: List[str]) -> ModelRequestData: +def load_idefics3(question, image_urls: list[str]) -> ModelRequestData: model_name = "HuggingFaceM4/Idefics3-8B-Llama3" # The configuration below has been confirmed to launch on a single L40 GPU. @@ -142,7 +142,7 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData: ) -def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: +def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "OpenGVLab/InternVL2-2B" llm = LLM( @@ -179,7 +179,7 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: ) -def load_mllama(question, image_urls: List[str]) -> ModelRequestData: +def load_mllama(question, image_urls: list[str]) -> ModelRequestData: model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" # The configuration below has been confirmed to launch on a single L40 GPU. @@ -201,7 +201,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData: ) -def load_nvlm_d(question: str, image_urls: List[str]): +def load_nvlm_d(question: str, image_urls: list[str]): model_name = "nvidia/NVLM-D-72B" # Adjust this as necessary to fit in GPU @@ -234,7 +234,7 @@ def load_nvlm_d(question: str, image_urls: List[str]): ) -def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData: +def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistral-community/pixtral-12b" # Adjust this as necessary to fit in GPU @@ -259,7 +259,7 @@ def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData: ) -def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: +def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData: # num_crops is an override kwarg to the multimodal image processor; # For some models, e.g., Phi-3.5-vision-instruct, it is recommended # to use 16 for single frame scenarios, and 4 for multi-frame. @@ -295,7 +295,7 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: def load_qwen_vl_chat(question: str, - image_urls: List[str]) -> ModelRequestData: + image_urls: list[str]) -> ModelRequestData: model_name = "Qwen/Qwen-VL-Chat" llm = LLM( model=model_name, @@ -336,7 +336,7 @@ def load_qwen_vl_chat(question: str, ) -def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: +def load_qwen2_vl(question, image_urls: list[str]) -> ModelRequestData: try: from qwen_vl_utils import process_vision_info except ModuleNotFoundError: @@ -393,7 +393,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: ) -def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData: +def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData: try: from qwen_vl_utils import process_vision_info except ModuleNotFoundError: @@ -466,7 +466,7 @@ def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData: } -def run_generate(model, question: str, image_urls: List[str]): +def run_generate(model, question: str, image_urls: list[str]): req_data = model_example_map[model](question, image_urls) sampling_params = SamplingParams(temperature=0.0, @@ -487,7 +487,7 @@ def run_generate(model, question: str, image_urls: List[str]): print(generated_text) -def run_chat(model: str, question: str, image_urls: List[str]): +def run_chat(model: str, question: str, image_urls: list[str]): req_data = model_example_map[model](question, image_urls) sampling_params = SamplingParams(temperature=0.0, diff --git a/examples/online_serving/api_client.py b/examples/online_serving/api_client.py index 623e0d59a30e..22bb1a87bfdf 100644 --- a/examples/online_serving/api_client.py +++ b/examples/online_serving/api_client.py @@ -7,7 +7,7 @@ import argparse import json -from typing import Iterable, List +from collections.abc import Iterable import requests @@ -39,7 +39,7 @@ def post_http_request(prompt: str, return response -def get_streaming_response(response: requests.Response) -> Iterable[List[str]]: +def get_streaming_response(response: requests.Response) -> Iterable[list[str]]: for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): @@ -49,7 +49,7 @@ def get_streaming_response(response: requests.Response) -> Iterable[List[str]]: yield output -def get_response(response: requests.Response) -> List[str]: +def get_response(response: requests.Response) -> list[str]: data = json.loads(response.content) output = data["text"] return output diff --git a/examples/online_serving/openai_embedding_client.py b/examples/online_serving/openai_embedding_client.py index cb110997464a..b7c5651e3bab 100644 --- a/examples/online_serving/openai_embedding_client.py +++ b/examples/online_serving/openai_embedding_client.py @@ -24,4 +24,4 @@ ) for data in responses.data: - print(data.embedding) # list of float of len 4096 + print(data.embedding) # List of float of len 4096 diff --git a/pyproject.toml b/pyproject.toml index 1c03e9e17be5..04e0c9e67eb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,32 @@ exclude = [ [tool.ruff.lint.per-file-ignores] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] +# Python 3.8 typing. TODO: Remove these excludes after v1.0.0 +"vllm/adapter_commons/**/*.py" = ["UP006", "UP035"] +"vllm/attention/**/*.py" = ["UP006", "UP035"] +"vllm/compilation/**/*.py" = ["UP006", "UP035"] +"vllm/core/**/*.py" = ["UP006", "UP035"] +"vllm/device_allocator/**/*.py" = ["UP006", "UP035"] +"vllm/distributed/**/*.py" = ["UP006", "UP035"] +"vllm/engine/**/*.py" = ["UP006", "UP035"] +"vllm/executor/**/*.py" = ["UP006", "UP035"] +"vllm/inputs/**/*.py" = ["UP006", "UP035"] +"vllm/logging_utils/**/*.py" = ["UP006", "UP035"] +"vllm/lora/**/*.py" = ["UP006", "UP035"] +"vllm/model_executor/**/*.py" = ["UP006", "UP035"] +"vllm/multimodal/**/*.py" = ["UP006", "UP035"] +"vllm/platforms/**/*.py" = ["UP006", "UP035"] +"vllm/plugins/**/*.py" = ["UP006", "UP035"] +"vllm/profiler/**/*.py" = ["UP006", "UP035"] +"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] +"vllm/spec_decode/**/*.py" = ["UP006", "UP035"] +"vllm/third_party/**/*.py" = ["UP006", "UP035"] +"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"] +"vllm/triton_utils/**/*.py" = ["UP006", "UP035"] +"vllm/usage/**/*.py" = ["UP006", "UP035"] +"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"] +"vllm/assets/**/*.py" = ["UP006", "UP035"] +"vllm/worker/**/*.py" = ["UP006", "UP035"] [tool.ruff.lint] select = [ @@ -91,8 +117,6 @@ ignore = [ "B007", # f-string format "UP032", - # Python 3.8 typing - "UP006", "UP035", # Can remove once 3.10+ is the minimum Python version "UP007", ] diff --git a/setup.py b/setup.py index 6fe433517a05..cd17709b57ef 100755 --- a/setup.py +++ b/setup.py @@ -9,7 +9,6 @@ import sys from pathlib import Path from shutil import which -from typing import Dict, List import torch from packaging.version import Version, parse @@ -78,7 +77,7 @@ def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: class cmake_build_ext(build_ext): # A dict of extension directories that have been configured. - did_config: Dict[str, bool] = {} + did_config: dict[str, bool] = {} # # Determine number of compilation jobs and optionally nvcc compile threads. @@ -548,10 +547,10 @@ def get_vllm_version() -> str: return version -def get_requirements() -> List[str]: +def get_requirements() -> list[str]: """Get Python package dependencies from requirements.txt.""" - def _read_requirements(filename: str) -> List[str]: + def _read_requirements(filename: str) -> list[str]: with open(get_path(filename)) as f: requirements = f.read().strip().split("\n") resolved_requirements = [] diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py index d9ac611644df..1e3c2d1a473a 100644 --- a/tests/async_engine/api_server_async_engine.py +++ b/tests/async_engine/api_server_async_engine.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """vllm.entrypoints.api_server with some extra logging for testing.""" -from typing import Any, Dict, Iterable +from collections.abc import Iterable +from typing import Any import uvicorn from fastapi.responses import JSONResponse, Response @@ -24,7 +25,7 @@ async def _engine_abort(self, request_ids: Iterable[str]): self._num_aborts += len(ids) await super()._engine_abort(ids) - def testing_stats(self) -> Dict[str, Any]: + def testing_stats(self) -> dict[str, Any]: return {"num_aborted_requests": self._num_aborts} diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index ca29abc92850..6307bd7d6462 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -6,7 +6,7 @@ from asyncio import CancelledError from copy import copy from dataclasses import dataclass -from typing import List, Optional +from typing import Optional import pytest import pytest_asyncio @@ -254,7 +254,7 @@ async def run_deltas(prompt: str): params.output_kind = RequestOutputKind.DELTA prompt_tokens = None - output_tokens: List[int] = [] + output_tokens: list[int] = [] output_text = "" output_count = 0 final_output = None diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 021bd4cc4635..7307f44b6184 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -8,7 +8,7 @@ initialized randomly with a fixed seed. """ from dataclasses import dataclass -from typing import Any, List, Optional, Tuple +from typing import Any, Optional import torch from torch import nn @@ -56,7 +56,7 @@ class LlamaConfig: random_seed: int = 0 def compute_hash(self) -> str: - factors: List[Any] = [] + factors: list[Any] = [] for k, v in self.__dict__.items(): if k == "random_seed": continue @@ -174,7 +174,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ For tractable computation: - if residual is None, the outputs are: diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 587c0a60ceeb..48323b21a8c4 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import dataclasses -from typing import Dict, List, Optional +from typing import Optional import pytest @@ -14,7 +14,7 @@ @dataclasses.dataclass class TestSetting: model: str - model_args: List[str] + model_args: list[str] pp_size: int tp_size: int attn_backend: str @@ -108,8 +108,8 @@ def test_compile_correctness(test_setting: TestSetting): final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \ ["-tp", str(tp_size)] - all_args: List[List[str]] = [] - all_envs: List[Optional[Dict[str, str]]] = [] + all_args: list[list[str]] = [] + all_envs: list[Optional[dict[str, str]]] = [] for level in [ CompilationLevel.NO_COMPILATION, diff --git a/tests/conftest.py b/tests/conftest.py index 871f0b62c532..57a33ad08c94 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,8 +5,7 @@ import tempfile from collections import UserList from enum import Enum -from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, - TypedDict, TypeVar, Union) +from typing import Any, Callable, Optional, TypedDict, TypeVar, Union import numpy as np import pytest @@ -47,14 +46,14 @@ _M = TypeVar("_M") -_PromptMultiModalInput = Union[List[_M], List[List[_M]]] +_PromptMultiModalInput = Union[list[_M], list[list[_M]]] PromptImageInput = _PromptMultiModalInput[Image.Image] -PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]] +PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]] PromptVideoInput = _PromptMultiModalInput[np.ndarray] -def _read_prompts(filename: str) -> List[str]: +def _read_prompts(filename: str) -> list[str]: with open(filename) as f: prompts = f.readlines() return prompts @@ -77,7 +76,7 @@ def __init__(self) -> None: ImageAsset("cherry_blossom"), ]) - def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: + def prompts(self, prompts: _ImageAssetPrompts) -> list[str]: """ Convenience method to define the prompt for each test image. @@ -102,7 +101,7 @@ def __init__(self) -> None: VideoAsset("sample_demo_1.mp4"), ]) - def prompts(self, prompts: _VideoAssetPrompts) -> List[str]: + def prompts(self, prompts: _VideoAssetPrompts) -> list[str]: return [prompts["sample_demo_1"]] @@ -175,7 +174,7 @@ def dynamo_reset(): @pytest.fixture -def example_prompts() -> List[str]: +def example_prompts() -> list[str]: prompts = [] for filename in _TEST_PROMPTS: prompts += _read_prompts(filename) @@ -197,7 +196,7 @@ class DecoderPromptType(Enum): @pytest.fixture def example_encoder_decoder_prompts( -) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]: +) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]: ''' Returns an encoder prompt list and a decoder prompt list, wherein each pair of same-index entries in both lists corresponds to an (encoder prompt, @@ -229,7 +228,7 @@ def example_encoder_decoder_prompts( @pytest.fixture -def example_long_prompts() -> List[str]: +def example_long_prompts() -> list[str]: prompts = [] for filename in _LONG_PROMPTS: prompts += _read_prompts(filename) @@ -273,11 +272,11 @@ def __init__( model_name: str, dtype: str = "half", *, - model_kwargs: Optional[Dict[str, Any]] = None, + model_kwargs: Optional[dict[str, Any]] = None, is_sentence_transformer: bool = False, is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, - auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM, + auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, postprocess_inputs: Callable[..., BatchEncoding] = identity, ) -> None: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] @@ -334,11 +333,11 @@ def __init__( def get_inputs( self, - prompts: List[str], + prompts: list[str], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> List[BatchEncoding]: + ) -> list[BatchEncoding]: if images is not None: assert len(prompts) == len(images) @@ -348,9 +347,9 @@ def get_inputs( if audios is not None: assert len(prompts) == len(audios) - all_inputs: List[BatchEncoding] = [] + all_inputs: list[BatchEncoding] = [] for i, prompt in enumerate(prompts): - processor_kwargs: Dict[str, Any] = { + processor_kwargs: dict[str, Any] = { "text": prompt, "return_tensors": "pt", } @@ -370,7 +369,7 @@ def get_inputs( return all_inputs - def classify(self, prompts: List[str]) -> List[str]: + def classify(self, prompts: list[str]) -> list[str]: # output is final logits all_inputs = self.get_inputs(prompts) outputs = [] @@ -383,18 +382,18 @@ def classify(self, prompts: List[str]) -> List[str]: def generate( self, - prompts: List[str], + prompts: list[str], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, - ) -> List[Tuple[List[List[int]], List[str]]]: + ) -> list[tuple[list[list[int]], list[str]]]: all_inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - outputs: List[Tuple[List[List[int]], List[str]]] = [] + outputs: list[tuple[list[list[int]], list[str]]] = [] for inputs in all_inputs: output_ids = self.model.generate( **self.wrap_device(inputs, device=self.model.device.type), @@ -412,13 +411,13 @@ def generate( def generate_greedy( self, - prompts: List[str], + prompts: list[str], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, - ) -> List[Tuple[List[int], str]]: + ) -> list[tuple[list[int], str]]: outputs = self.generate(prompts, do_sample=False, max_new_tokens=max_tokens, @@ -432,10 +431,10 @@ def generate_greedy( def generate_beam_search( self, - prompts: List[str], + prompts: list[str], beam_width: int, max_tokens: int, - ) -> List[Tuple[List[List[int]], List[str]]]: + ) -> list[tuple[list[list[int]], list[str]]]: outputs = self.generate(prompts, do_sample=False, max_new_tokens=max_tokens, @@ -453,19 +452,19 @@ def generate_beam_search( def generate_greedy_logprobs( self, - prompts: List[str], + prompts: list[str], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, - ) -> List[List[torch.Tensor]]: + ) -> list[list[torch.Tensor]]: all_inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - all_logprobs: List[List[torch.Tensor]] = [] + all_logprobs: list[list[torch.Tensor]] = [] for inputs in all_inputs: output = self.model.generate( **self.wrap_device(inputs, device=self.model.device.type), @@ -483,11 +482,11 @@ def generate_greedy_logprobs( def _hidden_states_to_seq_logprobs( self, - hidden_states: Tuple[Tuple[torch.Tensor, ...], ...], - ) -> List[torch.Tensor]: + hidden_states: tuple[tuple[torch.Tensor, ...], ...], + ) -> list[torch.Tensor]: output_embeddings = self.model.get_output_embeddings() - seq_logprobs: List[torch.Tensor] = [] + seq_logprobs: list[torch.Tensor] = [] for _, hidden_state in enumerate(hidden_states): last_hidden_states = hidden_state[-1][0] logits = torch.matmul( @@ -503,14 +502,14 @@ def _hidden_states_to_seq_logprobs( def _hidden_states_to_logprobs( self, - hidden_states: Tuple[Tuple[torch.Tensor, ...], ...], + hidden_states: tuple[tuple[torch.Tensor, ...], ...], num_logprobs: int, - ) -> Tuple[List[Dict[int, float]], int]: + ) -> tuple[list[dict[int, float]], int]: seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states) output_len = len(hidden_states) # convert to dict - seq_logprobs_lst: List[Dict[int, float]] = [] + seq_logprobs_lst: list[dict[int, float]] = [] for tok_idx, tok_logprobs in enumerate(seq_logprobs): # drop prompt logprobs if tok_idx == 0: @@ -530,22 +529,22 @@ def _hidden_states_to_logprobs( def generate_greedy_logprobs_limit( self, - prompts: List[str], + prompts: list[str], max_tokens: int, num_logprobs: int, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, **kwargs: Any, - ) -> List[TokensTextLogprobs]: + ) -> list[TokensTextLogprobs]: all_inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - all_logprobs: List[List[Dict[int, float]]] = [] - all_output_ids: List[List[int]] = [] - all_output_strs: List[str] = [] + all_logprobs: list[list[dict[int, float]]] = [] + all_output_ids: list[list[int]] = [] + all_output_strs: list[str] = [] for inputs in all_inputs: output = self.model.generate( @@ -577,23 +576,23 @@ def generate_greedy_logprobs_limit( def generate_encoder_decoder_greedy_logprobs_limit( self, - encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, images: Optional[PromptImageInput] = None, **kwargs: Any, - ) -> List[TokensTextLogprobs]: + ) -> list[TokensTextLogprobs]: ''' Greedy logprobs generation for vLLM encoder/decoder models ''' - all_logprobs: List[List[Dict[int, float]]] = [] - all_output_ids: List[List[int]] = [] - all_output_strs: List[str] = [] + all_logprobs: list[list[dict[int, float]]] = [] + all_output_ids: list[list[int]] = [] + all_output_strs: list[str] = [] for i, (encoder_prompt, decoder_prompt) in enumerate( to_enc_dec_tuple_list(encoder_decoder_prompts)): - processor_kwargs: Dict[str, Any] = { + processor_kwargs: dict[str, Any] = { "text": encoder_prompt, "return_tensors": "pt", } @@ -641,10 +640,10 @@ def generate_encoder_decoder_greedy_logprobs_limit( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] - def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + def encode(self, prompts: list[str]) -> list[list[torch.Tensor]]: return self.model.encode(prompts) - def predict(self, prompts: List[List[str]]) -> torch.Tensor: + def predict(self, prompts: list[list[str]]) -> torch.Tensor: return self.model.predict(prompts, convert_to_tensor=True) def __enter__(self): @@ -699,11 +698,11 @@ def __init__( def get_inputs( self, - prompts: List[str], + prompts: list[str], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> List[TextPrompt]: + ) -> list[TextPrompt]: if images is not None: assert len(prompts) == len(images) @@ -733,13 +732,13 @@ def get_inputs( def generate( self, - prompts: List[str], + prompts: list[str], sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, - ) -> List[Tuple[List[List[int]], List[str]]]: + ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, @@ -749,12 +748,12 @@ def generate( sampling_params=sampling_params, **kwargs) - outputs: List[Tuple[List[List[int]], List[str]]] = [] + outputs: list[tuple[list[list[int]], list[str]]] = [] for req_output in req_outputs: prompt_str = req_output.prompt prompt_ids = req_output.prompt_token_ids - req_sample_output_ids: List[List[int]] = [] - req_sample_output_strs: List[str] = [] + req_sample_output_ids: list[list[int]] = [] + req_sample_output_strs: list[str] = [] for sample in req_output.outputs: output_str = sample.text output_ids = list(sample.token_ids) @@ -765,9 +764,9 @@ def generate( @staticmethod def _final_steps_generate_w_logprobs( - req_outputs: List[RequestOutput], - ) -> List[TokensTextLogprobsPromptLogprobs]: - outputs: List[TokensTextLogprobsPromptLogprobs] = [] + req_outputs: list[RequestOutput], + ) -> list[TokensTextLogprobsPromptLogprobs]: + outputs: list[TokensTextLogprobsPromptLogprobs] = [] for req_output in req_outputs: assert len(req_output.outputs) > 0 for sample in req_output.outputs: @@ -780,14 +779,14 @@ def _final_steps_generate_w_logprobs( def generate_w_logprobs( self, - prompts: List[str], + prompts: list[str], sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, **kwargs: Any, - ) -> Union[List[TokensTextLogprobs], - List[TokensTextLogprobsPromptLogprobs]]: + ) -> Union[list[TokensTextLogprobs], + list[TokensTextLogprobsPromptLogprobs]]: inputs = self.get_inputs(prompts, images=images, videos=videos, @@ -806,10 +805,10 @@ def generate_w_logprobs( def generate_encoder_decoder_w_logprobs( self, - encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], sampling_params: SamplingParams, - ) -> Union[List[TokensTextLogprobs], - List[TokensTextLogprobsPromptLogprobs]]: + ) -> Union[list[TokensTextLogprobs], + list[TokensTextLogprobsPromptLogprobs]]: ''' Logprobs generation for vLLM encoder/decoder models ''' @@ -826,13 +825,13 @@ def generate_encoder_decoder_w_logprobs( def generate_greedy( self, - prompts: List[str], + prompts: list[str], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, - ) -> List[Tuple[List[int], str]]: + ) -> list[tuple[list[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = self.generate(prompts, greedy_params, @@ -845,18 +844,18 @@ def generate_greedy( def generate_greedy_logprobs( self, - prompts: List[str], + prompts: list[str], max_tokens: int, num_logprobs: int, num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, - stop_token_ids: Optional[List[int]] = None, - stop: Optional[List[str]] = None, + stop_token_ids: Optional[list[int]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, - ) -> Union[List[TokensTextLogprobs], - List[TokensTextLogprobsPromptLogprobs]]: + ) -> Union[list[TokensTextLogprobs], + list[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( temperature=0.0, max_tokens=max_tokens, @@ -874,12 +873,12 @@ def generate_greedy_logprobs( def generate_encoder_decoder_greedy_logprobs( self, - encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, num_prompt_logprobs: Optional[int] = None, - ) -> Union[List[TokensTextLogprobs], - List[TokensTextLogprobsPromptLogprobs]]: + ) -> Union[list[TokensTextLogprobs], + list[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( temperature=0.0, max_tokens=max_tokens, @@ -895,10 +894,10 @@ def generate_encoder_decoder_greedy_logprobs( def generate_beam_search( self, - prompts: Union[List[str], List[List[int]]], + prompts: Union[list[str], list[list[int]]], beam_width: int, max_tokens: int, - ) -> List[Tuple[List[List[int]], List[str]]]: + ) -> list[tuple[list[list[int]], list[str]]]: if is_list_of(prompts, str, check="all"): prompts = [TextPrompt(prompt=prompt) for prompt in prompts] else: @@ -915,17 +914,17 @@ def generate_beam_search( returned_outputs.append((token_ids, texts)) return returned_outputs - def classify(self, prompts: List[str]) -> List[List[float]]: + def classify(self, prompts: list[str]) -> list[list[float]]: req_outputs = self.model.classify(prompts) return [req_output.outputs.probs for req_output in req_outputs] def encode( self, - prompts: List[str], + prompts: list[str], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> List[List[float]]: + ) -> list[list[float]]: inputs = self.get_inputs(prompts, images=images, videos=videos, @@ -936,9 +935,9 @@ def encode( def score( self, - text_1: Union[str, List[str]], - text_2: Union[str, List[str]], - ) -> List[float]: + text_1: Union[str, list[str]], + text_2: Union[str, list[str]], + ) -> list[float]: req_outputs = self.model.score(text_1, text_2) return [req_output.outputs.score for req_output in req_outputs] diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index 7d3ccaadaca1..83259b690337 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Iterable, Optional +from collections.abc import Iterable +from typing import Callable, Optional import pytest diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index a7dafcf8be87..e23b8718cb63 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import List import pytest @@ -137,9 +136,9 @@ def prep_prompts(batch_size: int): The prompt is just under 10k tokens; sliding window is 4k so the answer is outside sliding window, but should still be correct. """ - prompts: List[str] = [] - answer: List[int] = [] - indices: List[int] = [] + prompts: list[str] = [] + answer: list[int] = [] + indices: list[int] = [] random.seed(1) for _ in range(batch_size): idx = random.randint(30, 90) @@ -158,7 +157,7 @@ def prep_prompts(batch_size: int): return prompts, answer, indices -def check_answers(indices: List[int], answer: List[int], outputs: List[str]): +def check_answers(indices: list[int], answer: list[int], outputs: list[str]): answer2 = [int(text[0:2].strip()) for text in outputs] print(list(zip(indices, zip(answer, answer2)))) numok = 0 @@ -170,7 +169,7 @@ def check_answers(indices: List[int], answer: List[int], outputs: List[str]): assert frac_ok > 0.7 -def check_window(prompts: List[str]): +def check_window(prompts: list[str]): def inner(llm: LLM): sliding_window = llm.llm_engine.model_config.get_sliding_window() diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index d8cf0bec709a..250c9a7497d2 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest from vllm.core.block.block_table import BlockTable @@ -32,7 +30,7 @@ def test_allocate_naive(block_size: int, sequence_len: int): token_ids = list(range(sequence_len)) num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size))) - block_tables: List[BlockTable] = [] + block_tables: list[BlockTable] = [] for i in range(5): assert allocator.get_num_free_blocks( device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc @@ -77,7 +75,7 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int): num_immutable_blocks_per_alloc = len( chunked_tokens) - num_mutable_blocks_per_alloc - block_tables: List[BlockTable] = [] + block_tables: list[BlockTable] = [] for alloc_i in range(1, 6): block_tables.append( @@ -272,7 +270,7 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int, ) block_table.allocate(token_ids=token_ids, device=Device.GPU) - appended_so_far: List[int] = [] + appended_so_far: list[int] = [] for append in chunk_list(token_ids_to_append, append_size): block_table.append_token_ids(append) appended_so_far.extend(append) diff --git a/tests/core/block/test_naive_block.py b/tests/core/block/test_naive_block.py index 0ca2a0b8054d..4b9454c84ff6 100644 --- a/tests/core/block/test_naive_block.py +++ b/tests/core/block/test_naive_block.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import Optional import pytest @@ -14,7 +14,7 @@ class TestNaiveBlockAllocator: def create_allocate_lambda(allocate_type: str, allocator: NaiveBlockAllocator, prev_block: Optional[Block], - token_ids: List[int]): + token_ids: list[int]): if allocate_type == "immutable": allocate_block = lambda: allocator.allocate_immutable_block( prev_block=prev_block, token_ids=token_ids) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index bf40b334abc5..50233624f7d1 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -2,7 +2,7 @@ import math import random -from typing import List, Optional +from typing import Optional from unittest.mock import MagicMock import pytest @@ -123,11 +123,11 @@ def test_blocks_have_correct_hash_in_chain(block_size: int, @staticmethod def create_chain(block_size: int, - token_ids: List[int], - num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]: + token_ids: list[int], + num_empty_trailing_blocks=0) -> list[PrefixCachingBlock]: """Helper method which creates a chain of blocks. """ - blocks: List[PrefixCachingBlock] = [] + blocks: list[PrefixCachingBlock] = [] num_blocks = math.ceil( len(token_ids) / block_size) + num_empty_trailing_blocks @@ -161,7 +161,7 @@ class TestPrefixCachingBlockAllocator: @staticmethod def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator, prev_block: Optional[Block], - token_ids: List[int]): + token_ids: list[int]): if allocate_type == "immutable": allocate_block = lambda: allocator.allocate_immutable_block( prev_block=prev_block, token_ids=token_ids) @@ -839,13 +839,13 @@ def test_reset_prefix_cache(num_blocks: int, block_size: int): @staticmethod def create_immutable_chain( block_size: int, - token_ids: List[int], + token_ids: list[int], allocator: PrefixCachingBlockAllocator, extra_hash: Optional[int] = None, - ) -> List[PrefixCachingBlock]: + ) -> list[PrefixCachingBlock]: """Helper method which creates a chain of blocks. """ - blocks: List[Block] = [] + blocks: list[Block] = [] num_blocks = math.ceil(len(token_ids) / block_size) if num_blocks == 0: diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 8e0b9e63b40c..161b32f01b11 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List from unittest.mock import MagicMock import pytest # noqa @@ -46,7 +45,7 @@ def test_simple(): cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(num_seq_group): @@ -93,7 +92,7 @@ def test_chunk(): cache_config.num_cpu_blocks = 32 cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): @@ -145,7 +144,7 @@ def test_concurrent_chunking(): cache_config.num_cpu_blocks = 32 cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): @@ -226,8 +225,8 @@ def test_short_prompts_jump_long_prompts_in_queue(): cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests cache_config.num_gpu_blocks = 3200 scheduler = Scheduler(scheduler_config, cache_config, None) - long_seqs: List[SequenceGroup] = [] - short_seqs: List[SequenceGroup] = [] + long_seqs: list[SequenceGroup] = [] + short_seqs: list[SequenceGroup] = [] # Add 2 large seq groups to scheduler. for i in range(2): @@ -368,7 +367,7 @@ def test_complex(): cache_config.num_cpu_blocks = 64 cache_config.num_gpu_blocks = 64 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): @@ -439,7 +438,7 @@ def test_maximal_decoding(): cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): @@ -533,7 +532,7 @@ def test_prompt_limit(): cache_config.num_cpu_blocks = 16 cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] _, seq_group = create_dummy_prompt("1", prompt_length=48, @@ -565,7 +564,7 @@ def test_prompt_limit_exceed(): cache_config.num_cpu_blocks = 16 cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] _, seq_group = create_dummy_prompt("2", prompt_length=48, block_size=block_size) @@ -699,7 +698,7 @@ def test_chunked_prefill_max_seqs(): cache_config.num_cpu_blocks = 128 cache_config.num_gpu_blocks = 128 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] _, seq_group = create_dummy_prompt("1", prompt_length=65, @@ -758,7 +757,7 @@ def test_prefix_caching(): cache_config.num_cpu_blocks = 0 cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): @@ -800,7 +799,7 @@ def test_prefix_caching_with_concurrent_partial_prefills(): cache_config.num_cpu_blocks = 0 cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 66bc5257f081..9e461d4e0b40 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -2,7 +2,6 @@ import time from collections import deque -from typing import List, Set, Tuple from unittest.mock import MagicMock import pytest # noqa @@ -57,7 +56,7 @@ def test_scheduler_abort_seq_group(): # Add multiple seq groups to scheduler. num_seq_group = 4 - request_ids: Set[str] = set() + request_ids: set[str] = set() for i in range(num_seq_group): _, seq_group = create_dummy_prompt(str(i), block_size) scheduler.add_seq_group(seq_group) @@ -83,7 +82,7 @@ def test_scheduler_schedule_simple(): cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(num_seq_group): @@ -221,7 +220,7 @@ def test_scheduler_max_seqs(): cache_config.num_gpu_blocks = 8 scheduler = Scheduler(scheduler_config, cache_config, None) - all_seq_groups: List[SequenceGroup] = [] + all_seq_groups: list[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(num_seq_group): _, seq_group = create_dummy_prompt(str(i), @@ -480,7 +479,7 @@ def test_prefill_schedule_max_lora(): num_cpu_blocks=64, num_gpu_blocks=64) budget = create_token_budget(token_budget=120) - curr_loras: Set[int] = set() + curr_loras: set[int] = set() for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, @@ -651,8 +650,8 @@ def test_schedule_swapped_max_loras(): block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32) - curr_loras: Set[int] = set() - blocks_to_swap_out: List[Tuple[int, int]] = [] + curr_loras: set[int] = set() + blocks_to_swap_out: list[tuple[int, int]] = [] for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, @@ -683,7 +682,7 @@ def test_schedule_swapped_cannot_swap_in(): num_cpu_blocks=32, num_gpu_blocks=32) curr_loras = None - blocks_to_swap_out: List[Tuple[int, int]] = [] + blocks_to_swap_out: list[tuple[int, int]] = [] for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, @@ -714,7 +713,7 @@ def test_infeasible_swap(): num_cpu_blocks=32, num_gpu_blocks=32) curr_loras = None - blocks_to_swap_out: List[Tuple[int, int]] = [] + blocks_to_swap_out: list[tuple[int, int]] = [] for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, @@ -752,7 +751,7 @@ def test_schedule_swapped_blocks_to_copy(): block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) - blocks_to_swap_out: List[Tuple[int, int]] = [] + blocks_to_swap_out: list[tuple[int, int]] = [] scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._add_seq_group_to_swapped(seq_group) diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py index a4e3c73a5a7b..c6049b26a2bc 100644 --- a/tests/core/test_scheduler_encoder_decoder.py +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest # noqa from vllm.config import CacheConfig, SchedulerConfig @@ -48,7 +46,7 @@ def test_scheduler_schedule_simple_encoder_decoder(): cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] + running: list[SequenceGroup] = [] # Add seq groups to scheduler. req_id_list = [] diff --git a/tests/core/utils.py b/tests/core/utils.py index fb77dccce1c9..ba4265e3c20a 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -2,9 +2,8 @@ import time from collections import defaultdict -from typing import Any, Dict, List, Optional -from typing import Sequence as GenericSequence -from typing import Tuple +from collections.abc import Sequence as GenericSequence +from typing import Any, Optional from vllm import SamplingParams from vllm.core.scheduler import Scheduler, SchedulerOutputs @@ -20,10 +19,10 @@ def create_dummy_prompt( block_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, best_of: int = 1, - prompt_tokens: Optional[List[int]] = None, + prompt_tokens: Optional[list[int]] = None, min_tokens: int = 0, max_tokens: int = 16, -) -> Tuple[Sequence, SequenceGroup]: +) -> tuple[Sequence, SequenceGroup]: if not block_size: block_size = prompt_length @@ -48,7 +47,7 @@ def create_dummy_prompt( return prompt, seq_group -def create_dummy_lora_sequence(request_id: int, token_ids: List[int], +def create_dummy_lora_sequence(request_id: int, token_ids: list[int], block_size: int, lora_int_id: int) -> Sequence: return Sequence(seq_id=request_id, inputs=token_inputs(token_ids), @@ -58,7 +57,7 @@ def create_dummy_lora_sequence(request_id: int, token_ids: List[int], lora_int_id=lora_int_id)) -def create_dummy_sequence(request_id: int, token_ids: List[int], +def create_dummy_sequence(request_id: int, token_ids: list[int], block_size: int) -> Sequence: return Sequence( seq_id=request_id, @@ -74,7 +73,7 @@ def create_dummy_prompt_encoder_decoder( block_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, best_of: int = 1, -) -> Tuple[Sequence, Sequence, SequenceGroup]: +) -> tuple[Sequence, Sequence, SequenceGroup]: if not block_size: block_size = decoder_prompt_length @@ -125,7 +124,7 @@ def create_seq_group( prompt_token_ids = [0] * seq_prompt_len - seqs: List[Sequence] = [] + seqs: list[Sequence] = [] for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, @@ -241,7 +240,7 @@ class SchedulerProxy: def __init__(self, scheduler: Scheduler): self.scheduler_ = scheduler - self.call_history: Dict[str, List[Any]] = defaultdict(list) + self.call_history: dict[str, list[Any]] = defaultdict(list) def __getattr__(self, name: str) -> Any: @@ -253,6 +252,6 @@ def wrapper(*args, **kwargs): return wrapper def last_schedule_ret( - self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]: + self, ) -> tuple[list[SequenceGroupMetadata], SchedulerOutputs, Any]: _, _, ret = self.call_history["schedule"][-1] return ret diff --git a/tests/distributed/test_expert_parallel.py b/tests/distributed/test_expert_parallel.py index bc5770642b79..2e575f95d5f1 100644 --- a/tests/distributed/test_expert_parallel.py +++ b/tests/distributed/test_expert_parallel.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import List, Literal, NamedTuple, Optional +from typing import Literal, NamedTuple, Optional import pytest @@ -28,8 +28,8 @@ class EPTestOptions(NamedTuple): @dataclass class EPTestSettings: - parallel_setups: List[ParallelSetup] - distributed_backends: List[str] + parallel_setups: list[ParallelSetup] + distributed_backends: list[str] task: TaskOption test_options: EPTestOptions diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 390ed91c2605..5562b36816c4 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -9,7 +9,7 @@ import json import os from dataclasses import dataclass -from typing import List, Literal, NamedTuple, Optional +from typing import Literal, NamedTuple, Optional import pytest @@ -38,14 +38,14 @@ class PPTestOptions(NamedTuple): @dataclass class PPTestSettings: - parallel_setups: List[ParallelSetup] + parallel_setups: list[ParallelSetup] # NOTE: the length of distributed_backends and # vllm_major_versions should be the same, and they # are first zipped together to iterate over all # test settings. - distributed_backends: List[str] + distributed_backends: list[str] # vllm major version: "0" for V0, "1" for V1 - vllm_major_versions: List[str] + vllm_major_versions: list[str] task: TaskOption test_options: PPTestOptions diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 4c42a0ed8112..2c323edfa2af 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -2,7 +2,6 @@ import multiprocessing import os -from typing import Dict, List import pytest import torch @@ -20,9 +19,9 @@ def distributed_run(fn, world_size): number_of_processes = world_size - processes: List[multiprocessing.Process] = [] + processes: list[multiprocessing.Process] = [] for i in range(number_of_processes): - env: Dict[str, str] = {} + env: dict[str, str] = {} env['RANK'] = str(i) env['LOCAL_RANK'] = str(i) env['WORLD_SIZE'] = str(number_of_processes) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index 59fa7cc9f319..711c2441f34b 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -3,7 +3,6 @@ import multiprocessing import random import time -from typing import List import numpy as np import torch.distributed as dist @@ -13,7 +12,7 @@ from vllm.utils import get_ip, get_open_port, update_environment_variables -def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]: +def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]: np.random.seed(seed) sizes = np.random.randint(1, 10_000, n) # on average, each array will have 5k elements diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index d0e4f86250bb..cb772fc76081 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -3,7 +3,7 @@ Run `pytest tests/encoder_decoder/test_e2e_correctness.py`. """ -from typing import List, Optional, Tuple +from typing import Optional import pytest from transformers import AutoModelForSeq2SeqLM @@ -22,7 +22,7 @@ def vllm_to_hf_output( - vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], decoder_prompt_type: DecoderPromptType, ): """Sanitize vllm output to be comparable with hf output.""" diff --git a/tests/engine/test_executor.py b/tests/engine/test_executor.py index c0a339e46ec4..91c9ba4a74e6 100644 --- a/tests/engine/test_executor.py +++ b/tests/engine/test_executor.py @@ -2,7 +2,7 @@ import asyncio import os -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import pytest @@ -22,8 +22,8 @@ class CustomUniExecutor(UniProcExecutor): def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict] = None) -> List[Any]: + args: tuple = (), + kwargs: Optional[dict] = None) -> list[Any]: # Drop marker to show that this was ran with open(".marker", "w"): ... diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py index f1fe58e35a32..9b2f45def6c5 100644 --- a/tests/engine/test_multiproc_workers.py +++ b/tests/engine/test_multiproc_workers.py @@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial from time import sleep -from typing import Any, List, Tuple +from typing import Any import pytest @@ -17,7 +17,7 @@ class DummyWorkerWrapper(WorkerWrapperBase): """Dummy version of vllm.worker.worker.Worker""" - def worker_method(self, worker_input: Any) -> Tuple[int, Any]: + def worker_method(self, worker_input: Any) -> tuple[int, Any]: sleep(0.05) if isinstance(worker_input, Exception): @@ -27,7 +27,7 @@ def worker_method(self, worker_input: Any) -> Tuple[int, Any]: return self.rpc_rank, input -def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]: +def _start_workers() -> tuple[list[ProcessWorkerWrapper], WorkerMonitor]: result_handler = ResultHandler() vllm_config = VllmConfig() workers = [ diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py index 0f633bb26da9..62d167aa14b4 100644 --- a/tests/engine/test_stop_strings.py +++ b/tests/engine/test_stop_strings.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, List, Optional +from typing import Any, Optional import pytest @@ -21,8 +21,8 @@ def vllm_model(vllm_runner): def _test_stopping(llm_engine: LLMEngine, expected_output: str, expected_reason: Any, - stop: Optional[List[str]] = None, - stop_token_ids: Optional[List[int]] = None, + stop: Optional[list[str]] = None, + stop_token_ids: Optional[list[int]] = None, include_in_output: bool = False, use_async_output_proc: bool = False) -> None: llm_engine.add_request( diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 77c80b2f8944..710bad4ecf46 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest from vllm import LLM @@ -63,7 +61,7 @@ def test_multi_chat(): @pytest.mark.parametrize("image_urls", [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) -def test_chat_multi_image(image_urls: List[str]): +def test_chat_multi_image(image_urls: list[str]): llm = LLM( model="microsoft/Phi-3.5-vision-instruct", dtype="bfloat16", diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index a65235ccdf19..6438743b6494 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import weakref -from typing import List import pytest @@ -45,8 +44,8 @@ def llm(): cleanup_dist_env_and_memory() -def assert_outputs_equal(o1: List[PoolingRequestOutput], - o2: List[PoolingRequestOutput]): +def assert_outputs_equal(o1: list[PoolingRequestOutput], + o2: list[PoolingRequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 910e1a4507cc..9a895c922cc3 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import weakref -from typing import List import pytest @@ -43,7 +42,7 @@ def llm(): cleanup_dist_env_and_memory() -def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): +def assert_outputs_equal(o1: list[RequestOutput], o2: list[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 19d4735b9dde..eca5d184f5d6 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -10,7 +10,6 @@ import io import time from statistics import mean, median -from typing import List import librosa import pytest @@ -67,7 +66,7 @@ async def process_dataset(model, client, data, concurrent_request): audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] _ = await bound_transcribe(model, sem, client, (audio, sr), "") - tasks: List[asyncio.Task] = [] + tasks: list[asyncio.Task] = [] for sample in data: audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] task = asyncio.create_task( diff --git a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py index ea504f3d0b46..5ce5d9280f3e 100644 --- a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py +++ b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest from transformers import AutoTokenizer @@ -180,7 +178,7 @@ def test_reasoning( ): output = tokenizer.tokenize(param_dict["output"]) # decode everything to tokens - output_tokens: List[str] = [ + output_tokens: list[str] = [ tokenizer.convert_tokens_to_string([token]) for token in output ] parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( diff --git a/tests/entrypoints/openai/reasoning_parsers/utils.py b/tests/entrypoints/openai/reasoning_parsers/utils.py index 2157e059594b..01e43130bc6e 100644 --- a/tests/entrypoints/openai/reasoning_parsers/utils.py +++ b/tests/entrypoints/openai/reasoning_parsers/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple, Union +from typing import Optional, Union from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) @@ -33,10 +33,10 @@ def append_delta(self, delta: DeltaMessage): def run_reasoning_extraction( reasoning_parser: ReasoningParser, - model_output: List[str], + model_output: list[str], request: Union[ChatCompletionRequest, None] = None, streaming: bool = False, -) -> Tuple[Optional[str], Optional[str]]: +) -> tuple[Optional[str], Optional[str]]: if streaming: reconstructor = run_reasoning_extraction_streaming( reasoning_parser, @@ -55,9 +55,9 @@ def run_reasoning_extraction( def run_reasoning_extraction_nonstreaming( reasoning_parser: ReasoningParser, - model_output: List[str], + model_output: list[str], request: Union[ChatCompletionRequest, None] = None, -) -> Tuple[Optional[str], Optional[str]]: +) -> tuple[Optional[str], Optional[str]]: request = request or ChatCompletionRequest(messages=[], model="test-model") return reasoning_parser.extract_reasoning_content( model_output=''.join(model_output), request=request) @@ -65,13 +65,13 @@ def run_reasoning_extraction_nonstreaming( def run_reasoning_extraction_streaming( reasoning_parser: ReasoningParser, - model_deltas: List[str], + model_deltas: list[str], request: Union[ChatCompletionRequest, None] = None, ) -> StreamingReasoningReconstructor: request = request or ChatCompletionRequest(messages=[], model="test-model") reconstructor = StreamingReasoningReconstructor() previous_text = "" - previous_tokens: List[int] = [] + previous_tokens: list[int] = [] for delta in model_deltas: token_delta = [ reasoning_parser.vocab.get(token) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 7e08fdaf1ad9..56fb29328428 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List - import openai import pytest import pytest_asyncio @@ -41,7 +39,7 @@ async def client(server): @pytest.fixture(scope="session") -def base64_encoded_audio() -> Dict[str, str]: +def base64_encoded_audio() -> dict[str, str]: return { audio_url: encode_audio_base64(*fetch_audio(audio_url)) for audio_url in TEST_AUDIO_URLS @@ -107,7 +105,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) async def test_single_chat_session_audio_base64encoded( client: openai.AsyncOpenAI, model_name: str, audio_url: str, - base64_encoded_audio: Dict[str, str]): + base64_encoded_audio: dict[str, str]): messages = [{ "role": @@ -165,7 +163,7 @@ async def test_single_chat_session_audio_base64encoded( @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) async def test_single_chat_session_input_audio( client: openai.AsyncOpenAI, model_name: str, audio_url: str, - base64_encoded_audio: Dict[str, str]): + base64_encoded_audio: dict[str, str]): messages = [{ "role": "user", @@ -255,7 +253,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI, temperature=0.0, stream=True, ) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 async for chunk in stream: delta = chunk.choices[0].delta @@ -277,7 +275,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI, @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, model_name: str, audio_url: str, - base64_encoded_audio: Dict[str, + base64_encoded_audio: dict[str, str]): messages = [{ "role": @@ -315,7 +313,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, temperature=0.0, stream=True, ) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 async for chunk in stream: delta = chunk.choices[0].delta @@ -337,7 +335,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, audio_url: str, - base64_encoded_audio: Dict[str, str]): + base64_encoded_audio: dict[str, str]): messages = [{ "role": diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index a970981b7562..e7bf974f13ed 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -2,7 +2,6 @@ import asyncio from http import HTTPStatus -from typing import List import openai import pytest @@ -17,7 +16,7 @@ @pytest.fixture(scope='module') -def server_args(request: pytest.FixtureRequest) -> List[str]: +def server_args(request: pytest.FixtureRequest) -> list[str]: """ Provide extra arguments to the server via indirect parametrization Usage: diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index d7ed4afa2861..25e4595cef6f 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -3,7 +3,7 @@ # imports for guided decoding tests import json import re -from typing import Dict, List, Optional +from typing import Optional import jsonschema import openai # use the official client for correctness check @@ -190,7 +190,7 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int]): - params: Dict = { + params: dict = { "messages": [{ "role": "system", "content": "You are a helpful assistant." @@ -232,7 +232,7 @@ async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, ) async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): - params: Dict = { + params: dict = { "messages": [{ "role": "system", "content": "You are a helpful assistant." @@ -343,7 +343,7 @@ async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str): temperature=0.0, stream=True, ) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 async for chunk in stream: delta = chunk.choices[0].delta diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 28671cc27571..1d9aa4972b70 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -5,7 +5,7 @@ import re import shutil from tempfile import TemporaryDirectory -from typing import Dict, List, Optional +from typing import Optional import jsonschema import openai # use the official client for correctness check @@ -287,7 +287,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int]): - params: Dict = { + params: dict = { "prompt": ["A robot may not injure another robot", "My name is"], "model": model_name, } @@ -331,7 +331,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, max_tokens=5, temperature=0.0, stream=True) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 async for chunk in stream: chunks.append(chunk.choices[0].text) @@ -364,7 +364,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): max_tokens=max_tokens, n=n, stream=True) - chunks: List[List[str]] = [[] for i in range(n)] + chunks: list[list[str]] = [[] for i in range(n)] finish_reason_count = 0 async for chunk in stream: index = chunk.choices[0].index diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index a37169f51b05..0d1c936da759 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -86,7 +86,7 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): - # test List[str] + # test list[str] input_texts = [ "The cat sat on the mat.", "A feline was resting on a rug.", "Stars twinkle brightly in the night sky." @@ -106,7 +106,7 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): assert embeddings.usage.prompt_tokens == 33 assert embeddings.usage.total_tokens == 33 - # test List[List[int]] + # test list[list[int]] input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], [25, 32, 64, 77]] embedding_response = await client.embeddings.create( diff --git a/tests/entrypoints/openai/test_pooling.py b/tests/entrypoints/openai/test_pooling.py index 11d3bfafab1c..72ab12c56460 100644 --- a/tests/entrypoints/openai/test_pooling.py +++ b/tests/entrypoints/openai/test_pooling.py @@ -84,7 +84,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): - # test List[str] + # test list[str] input_texts = [ "The cat sat on the mat.", "A feline was resting on a rug.", "Stars twinkle brightly in the night sky." @@ -107,7 +107,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.usage.prompt_tokens == 25 assert poolings.usage.total_tokens == 25 - # test List[List[int]] + # test list[list[int]] input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], [25, 32, 64, 77]] response = requests.post( diff --git a/tests/entrypoints/openai/test_root_path.py b/tests/entrypoints/openai/test_root_path.py index ad8159afc875..c9fa192fb6ae 100644 --- a/tests/entrypoints/openai/test_root_path.py +++ b/tests/entrypoints/openai/test_root_path.py @@ -2,7 +2,7 @@ import contextlib import os -from typing import Any, List, NamedTuple +from typing import Any, NamedTuple import openai # use the official client for correctness check import pytest @@ -40,7 +40,7 @@ def server(): class TestCase(NamedTuple): model_name: str - base_url: List[str] + base_url: list[str] api_key: str expected_error: Any diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index ab9285407d2a..36d622242339 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List - import openai import pytest import pytest_asyncio @@ -49,7 +47,7 @@ async def client(server): @pytest.fixture(scope="session") -def base64_encoded_video() -> Dict[str, str]: +def base64_encoded_video() -> dict[str, str]: return { video_url: encode_video_base64(fetch_video(video_url)) for video_url in TEST_VIDEO_URLS @@ -151,7 +149,7 @@ async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI, @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) async def test_single_chat_session_video_base64encoded( client: openai.AsyncOpenAI, model_name: str, video_url: str, - base64_encoded_video: Dict[str, str]): + base64_encoded_video: dict[str, str]): messages = [{ "role": @@ -209,7 +207,7 @@ async def test_single_chat_session_video_base64encoded( @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) async def test_single_chat_session_video_base64encoded_beamsearch( client: openai.AsyncOpenAI, model_name: str, video_url: str, - base64_encoded_video: Dict[str, str]): + base64_encoded_video: dict[str, str]): messages = [{ "role": @@ -279,7 +277,7 @@ async def test_chat_streaming_video(client: openai.AsyncOpenAI, temperature=0.0, stream=True, ) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 async for chunk in stream: delta = chunk.choices[0].delta @@ -302,7 +300,7 @@ async def test_chat_streaming_video(client: openai.AsyncOpenAI, "video_urls", [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))]) async def test_multi_video_input(client: openai.AsyncOpenAI, model_name: str, - video_urls: List[str]): + video_urls: list[str]): messages = [{ "role": diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index c954fca696ff..d605394f57b2 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List - import openai import pytest import pytest_asyncio @@ -50,7 +48,7 @@ async def client(server): @pytest.fixture(scope="session") -def base64_encoded_image() -> Dict[str, str]: +def base64_encoded_image() -> dict[str, str]: return { image_url: encode_image_base64(fetch_image(image_url)) for image_url in TEST_IMAGE_URLS @@ -152,7 +150,7 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) async def test_single_chat_session_image_base64encoded( client: openai.AsyncOpenAI, model_name: str, image_url: str, - base64_encoded_image: Dict[str, str]): + base64_encoded_image: dict[str, str]): messages = [{ "role": @@ -210,7 +208,7 @@ async def test_single_chat_session_image_base64encoded( @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) async def test_single_chat_session_image_base64encoded_beamsearch( client: openai.AsyncOpenAI, model_name: str, image_url: str, - base64_encoded_image: Dict[str, str]): + base64_encoded_image: dict[str, str]): messages = [{ "role": @@ -280,7 +278,7 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI, temperature=0.0, stream=True, ) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 async for chunk in stream: delta = chunk.choices[0].delta @@ -303,7 +301,7 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI, "image_urls", [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, - image_urls: List[str]): + image_urls: list[str]): messages = [{ "role": diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index cee5274561f4..100aca6f63f0 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict - import pytest import requests @@ -49,7 +47,7 @@ def server(): @pytest.fixture(scope="session") -def base64_encoded_image() -> Dict[str, str]: +def base64_encoded_image() -> dict[str, str]: return { image_url: encode_image_base64(fetch_image(image_url)) for image_url in TEST_IMAGE_URLS diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index 788efa86b109..fbbbc1fb2a59 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List from unittest.mock import MagicMock import pytest @@ -125,7 +124,7 @@ def test_no_tool_call(streaming: bool): @pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES) def test_tool_call(streaming: bool, model_output: str, - expected_tool_calls: List[FunctionCall]): + expected_tool_calls: list[FunctionCall]): mock_tokenizer = MagicMock() tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( mock_tokenizer) diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index 57ec9865355d..6ad5aa26ffa1 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Tuple, Union +from collections.abc import Iterable +from typing import Union from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, @@ -12,7 +13,7 @@ class StreamingToolReconstructor: def __init__(self, assert_one_tool_per_delta: bool = True): - self.tool_calls: List[ToolCall] = [] + self.tool_calls: list[ToolCall] = [] self.other_content: str = "" self._assert_one_tool_per_delta = assert_one_tool_per_delta @@ -72,7 +73,7 @@ def run_tool_extraction( request: Union[ChatCompletionRequest, None] = None, streaming: bool = False, assert_one_tool_per_delta: bool = True, -) -> Tuple[Union[str, None], List[ToolCall]]: +) -> tuple[Union[str, None], list[ToolCall]]: if streaming: reconstructor = run_tool_extraction_streaming( tool_parser, @@ -106,7 +107,7 @@ def run_tool_extraction_streaming( reconstructor = StreamingToolReconstructor( assert_one_tool_per_delta=assert_one_tool_per_delta) previous_text = "" - previous_tokens: List[int] = [] + previous_tokens: list[int] = [] for delta in model_deltas: token_delta = [ tool_parser.vocab.get(token) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 34dcf91c7666..a21d642bcaaf 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch @@ -19,7 +19,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: def ref_dynamic_per_token_quant(x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None) \ - -> Tuple[torch.tensor, torch.tensor]: + -> tuple[torch.tensor, torch.tensor]: assert quant_dtype in [torch.int8, FP8_DTYPE] if scale_ub is not None: @@ -68,7 +68,7 @@ def ref_dynamic_per_token_quant(x: torch.tensor, # ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant # kernel def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ - -> Tuple[torch.tensor, torch.tensor]: + -> tuple[torch.tensor, torch.tensor]: fp8_traits = torch.finfo(FP8_DTYPE) fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \ diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 2e70b1db35c4..cf0f21ce0651 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import Type import pytest import torch @@ -86,7 +85,7 @@ def test_act_and_mul( @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_activation( - activation: Type[torch.nn.Module], + activation: type[torch.nn.Module], num_tokens: int, d: int, dtype: torch.dtype, diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index b667d8d9e030..0fe10d76909e 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import List, Optional, Tuple +from typing import Optional import pytest import torch @@ -85,8 +85,8 @@ def ref_single_query_cached_kv_attention( block_table = block_tables_lst[i] seq_len = int(seq_lens_lst[i]) - keys_lst: List[torch.Tensor] = [] - values_lst: List[torch.Tensor] = [] + keys_lst: list[torch.Tensor] = [] + values_lst: list[torch.Tensor] = [] for j in range(seq_len): block_number = int(block_table[j // block_size]) block_offset = j % block_size @@ -133,7 +133,7 @@ def test_paged_attention( kv_cache_factory, version: str, num_seqs: int, - num_heads: Tuple[int, int], + num_heads: tuple[int, int], head_size: int, use_alibi: bool, block_size: int, @@ -166,7 +166,7 @@ def test_paged_attention( # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables_lst: List[List[int]] = [] + block_tables_lst: list[list[int]] = [] for _ in range(num_seqs): block_table = [ random.randint(0, NUM_BLOCKS - 1) @@ -334,7 +334,7 @@ def test_paged_attention( def ref_multi_query_kv_attention( - cu_seq_lens: List[int], + cu_seq_lens: list[int], query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -342,7 +342,7 @@ def ref_multi_query_kv_attention( dtype: torch.dtype, ) -> torch.Tensor: num_seqs = len(cu_seq_lens) - 1 - ref_outputs: List[torch.Tensor] = [] + ref_outputs: list[torch.Tensor] = [] for i in range(num_seqs): start_idx = cu_seq_lens[i] end_idx = cu_seq_lens[i + 1] @@ -378,7 +378,7 @@ def ref_multi_query_kv_attention( @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, - num_heads: Tuple[int, int], + num_heads: tuple[int, int], head_size: int, dtype: torch.dtype, seed: int, diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index e653d34d00ee..3025ae0f921a 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import List, Optional, Tuple +from typing import Optional import pytest import torch @@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention( block_table = block_tables_lst[i] seq_len = int(seq_lens_lst[i]) - keys_lst: List[torch.Tensor] = [] - values_lst: List[torch.Tensor] = [] + keys_lst: list[torch.Tensor] = [] + values_lst: list[torch.Tensor] = [] for j in range(seq_len): block_number = int(block_table[j // block_size]) block_offset = j % block_size @@ -162,7 +162,7 @@ def test_paged_attention( kv_cache_factory, version: str, num_seqs: int, - num_heads: Tuple[int, int], + num_heads: tuple[int, int], head_size: int, use_alibi: bool, block_size: int, @@ -331,7 +331,7 @@ def test_paged_attention( def ref_multi_query_kv_attention( - cu_seq_lens: List[int], + cu_seq_lens: list[int], query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -376,7 +376,7 @@ def ref_multi_query_kv_attention( @torch.inference_mode() def test_varlen_blocksparse_attention_prefill( num_seqs: int, - num_heads: Tuple[int, int], + num_heads: tuple[int, int], head_size: int, blocksparse_local_blocks: int, blocksparse_vert_stride: int, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index fb3688748214..b55ebd967fd7 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import List, Tuple import pytest import torch @@ -74,7 +73,7 @@ def test_copy_blocks( src_blocks = random.sample(range(num_blocks), num_mappings) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) - block_mapping: List[Tuple[int, int]] = [] + block_mapping: list[tuple[int, int]] = [] for i in range(num_mappings): src = src_blocks[i] dst1 = dst_blocks[2 * i] @@ -342,7 +341,7 @@ def test_reshape_and_cache_flash( @torch.inference_mode() def test_swap_blocks( kv_cache_factory, - direction: Tuple[str, str], + direction: tuple[str, str], num_mappings: int, num_heads: int, head_size: int, diff --git a/tests/kernels/test_cascade_flash_attn.py b/tests/kernels/test_cascade_flash_attn.py index 8cc1a6a1b49f..d6570e6334b1 100755 --- a/tests/kernels/test_cascade_flash_attn.py +++ b/tests/kernels/test_cascade_flash_attn.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple +from typing import Optional import pytest import torch @@ -25,7 +25,7 @@ @torch.inference_mode() def test_merge_kernel( num_tokens: int, - num_heads: Tuple[int, int], + num_heads: tuple[int, int], head_size: int, dtype: torch.dtype, ): @@ -85,8 +85,8 @@ def test_merge_kernel( @pytest.mark.parametrize("fa_version", [2, 3]) @torch.inference_mode() def test_cascade( - seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int], - num_heads: Tuple[int, int], + seq_lens_and_common_prefix: tuple[list[tuple[int, int]], int], + num_heads: tuple[int, int], head_size: int, dtype: torch.dtype, block_size: int, diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 49fd8ed634f1..72fc660a653d 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -3,7 +3,6 @@ Run `pytest tests/kernels/test_cutlass.py`. """ -from typing import Type import pytest import torch @@ -71,7 +70,7 @@ def cutlass_fp8_gemm_helper(m: int, a_scale_group_shape: tuple, b_scale_group_shape: tuple, use_bias: bool, - out_dtype: Type[torch.dtype] = torch.bfloat16, + out_dtype: type[torch.dtype] = torch.bfloat16, device: str = "cuda"): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. @@ -109,7 +108,7 @@ def cutlass_int8_gemm_helper(m: int, a_scale_group_shape: tuple, b_scale_group_shape: tuple, use_bias: bool, - out_dtype: Type[torch.dtype] = torch.bfloat16, + out_dtype: type[torch.dtype] = torch.bfloat16, device: str = "cuda"): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. @@ -187,7 +186,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape, @pytest.mark.parametrize("use_bias", [True, False]) def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, b_scale_group_shape, - out_dtype: Type[torch.dtype], + out_dtype: type[torch.dtype], use_bias: bool): cutlass_int8_gemm_helper(512, 512, @@ -208,7 +207,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape, b_scale_group_shape, - out_dtype: Type[torch.dtype], + out_dtype: type[torch.dtype], use_bias: bool): cutlass_fp8_gemm_helper(512, 512, @@ -227,7 +226,7 @@ def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape, reason="FP8 blockwise is not supported on this GPU type.") def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape, b_scale_group_shape, - out_dtype: Type[torch.dtype], + out_dtype: type[torch.dtype], use_bias: bool): cutlass_fp8_gemm_helper(512, 512, diff --git a/tests/kernels/test_cutlass_2of4_sparse.py b/tests/kernels/test_cutlass_2of4_sparse.py index b0c5804715a5..2890e15d6cba 100644 --- a/tests/kernels/test_cutlass_2of4_sparse.py +++ b/tests/kernels/test_cutlass_2of4_sparse.py @@ -3,7 +3,6 @@ Run `pytest tests/kernels/test_semi_structured.py`. """ -from typing import Tuple, Type import pytest import torch @@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor, def make_rand_sparse_tensors( dtype: torch.dtype, m: int, n: int, k: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: a = torch.randn((m, k), device='cuda') b = torch.randn((n, k), device='cuda').t() @@ -167,7 +166,7 @@ def test_cutlass_sparse_subset(): @pytest.mark.parametrize("m, n, k", MNK_FACTORS) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype], +def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool): # Create tensors diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 0a93f7ce9450..547a63499b26 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -243,7 +243,7 @@ def _decoder_attn_setup( test_pt: TestPoint, test_rsrcs: TestResources, block_base_addr: int = 0, -) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: +) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: ''' Set up test vectors & data structures for self-attention test. @@ -421,7 +421,7 @@ def _enc_dec_cross_attn_setup_reuses_query( test_pt: TestPoint, test_rsrcs: TestResources, block_base_addr: int = 0, -) -> Tuple[PhaseTestParameters, PhaseTestParameters]: +) -> tuple[PhaseTestParameters, PhaseTestParameters]: ''' Set up test vectors & data structures for cross-attention test. diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index b8af89b660a6..95424e25732b 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple +from typing import Optional import pytest import torch @@ -24,8 +24,8 @@ def ref_paged_attn( query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - query_lens: List[int], - kv_lens: List[int], + query_lens: list[int], + kv_lens: list[int], block_tables: torch.Tensor, scale: float, sliding_window: Optional[int] = None, @@ -35,7 +35,7 @@ def ref_paged_attn( block_tables = block_tables.cpu().numpy() _, block_size, num_kv_heads, head_size = key_cache.shape - outputs: List[torch.Tensor] = [] + outputs: list[torch.Tensor] = [] start_idx = 0 for i in range(num_seqs): query_len = query_lens[i] @@ -88,8 +88,8 @@ def ref_paged_attn( @torch.inference_mode() def test_flash_attn_with_paged_kv( use_out: bool, - kv_lens: List[int], - num_heads: Tuple[int, int], + kv_lens: list[int], + num_heads: tuple[int, int], head_size: int, dtype: torch.dtype, block_size: int, @@ -174,8 +174,8 @@ def test_flash_attn_with_paged_kv( @torch.inference_mode() def test_varlen_with_paged_kv( use_out: bool, - seq_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], head_size: int, sliding_window: Optional[int], dtype: torch.dtype, diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index f623b0014db0..5ad1137aa6af 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple +from typing import Optional import flashinfer import pytest @@ -19,8 +19,8 @@ def ref_paged_attn( query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - query_lens: List[int], - kv_lens: List[int], + query_lens: list[int], + kv_lens: list[int], block_tables: torch.Tensor, scale: float, sliding_window: Optional[int] = None, @@ -30,7 +30,7 @@ def ref_paged_attn( block_tables = block_tables.cpu().numpy() _, block_size, num_kv_heads, head_size = key_cache.shape - outputs: List[torch.Tensor] = [] + outputs: list[torch.Tensor] = [] start_idx = 0 for i in range(num_seqs): query_len = query_lens[i] @@ -78,8 +78,8 @@ def ref_paged_attn( @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @torch.inference_mode def test_flashinfer_decode_with_paged_kv( - kv_lens: List[int], - num_heads: Tuple[int, int], + kv_lens: list[int], + num_heads: tuple[int, int], head_size: int, dtype: torch.dtype, block_size: int, @@ -168,8 +168,8 @@ def test_flashinfer_decode_with_paged_kv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @torch.inference_mode -def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], +def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], head_size: int, dtype: torch.dtype, block_size: int, soft_cap: Optional[float]) -> None: @@ -270,7 +270,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) def test_flashinfer_prefill_with_paged_fp8_kv( - seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], + seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], head_size: int, dtype: torch.dtype, block_size: int, soft_cap: Optional[float]) -> None: pytest.skip("TODO: fix the accuracy issue") @@ -378,8 +378,8 @@ def test_flashinfer_prefill_with_paged_fp8_kv( @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @torch.inference_mode def test_flashinfer_decode_with_paged_fp8_kv( - kv_lens: List[int], - num_heads: Tuple[int, int], + kv_lens: list[int], + num_heads: tuple[int, int], head_size: int, dtype: torch.dtype, block_size: int, diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py index d4b674b23534..7a591f536783 100644 --- a/tests/kernels/test_fused_quant_layernorm.py +++ b/tests/kernels/test_fused_quant_layernorm.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple, Union +from typing import Optional, Union import pytest import torch @@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: def ref_rms_norm(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: Optional[torch.Tensor]) \ - -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + -> tuple[torch.Tensor, Optional[torch.Tensor]]: if residual is not None: residual = residual.clone() out, residual = rms_norm_layer.forward_native(x, residual) @@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, quant_dtype: torch.dtype, residual: Optional[torch.Tensor], scale_ub: Optional[torch.Tensor]) \ - -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if scale_ub is not None: assert quant_dtype == torch.float8_e4m3fn @@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm, quant_dtype: torch.dtype, residual: Optional[torch.Tensor], scale_ub: Optional[torch.Tensor]) \ - -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype, residual, scale_ub) @@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor, quant_dtype: torch.dtype, residual: Optional[torch.Tensor], scale_ub: Optional[torch.Tensor]) \ - -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if residual is not None: residual = residual.clone() out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS, @@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor, quant_dtype: torch.dtype, residual: Optional[torch.Tensor], scale_ub: Optional[torch.Tensor]) \ - -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub) diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py index 847ca9f43105..aa666a464a5e 100644 --- a/tests/kernels/test_gguf.py +++ b/tests/kernels/test_gguf.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path -from typing import List import pytest import torch @@ -16,7 +15,7 @@ def get_gguf_sample_tensors( hidden_size: int, - quant_type: GGMLQuantizationType) -> List[ReaderTensor]: + quant_type: GGMLQuantizationType) -> list[ReaderTensor]: sample_dir = GGUF_SAMPLE filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" sample_file = Path(sample_dir) / filename diff --git a/tests/kernels/test_machete_mm.py b/tests/kernels/test_machete_mm.py index bd60526ed9b7..5aeaaa654ed6 100644 --- a/tests/kernels/test_machete_mm.py +++ b/tests/kernels/test_machete_mm.py @@ -6,7 +6,7 @@ import math from dataclasses import dataclass, fields -from typing import List, Optional, Tuple +from typing import Optional import pytest import torch @@ -45,7 +45,7 @@ (1024, 8192, 4096), ] -GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1] +GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1] @dataclass @@ -75,7 +75,7 @@ class Tensors: # Ch Scales Type, Tok Scales Type) # NOTE: None "Scale Type" means the act type is floating point # None "Output Type" means the output type is the same as the act type -TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype], +TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool] TEST_TYPES = [ # GPTQ style @@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): return zps if zps is None else -1 * s * (zps.to(s.dtype)) -def group_size_valid(shape: Tuple[int, int, int], +def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool: return group_size is None or group_size == -1 or group_size % shape[2] == 0 @@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype, return w_ref, w_q_machete, w_s, w_zp -def create_test_tensors(shape: Tuple[int, int, int], +def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int], subset_stride_factor: Optional[int] = None) -> Tensors: @@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig, @pytest.mark.parametrize("types", TEST_TYPES) def test_machete_all_schedules(shape, types: TypeConfig): - group_sizes: List[Optional[int]] = [] + group_sizes: list[Optional[int]] = [] if types.group_scale_type is None: group_sizes = [None] else: @@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig): ids=lambda x: "x".join(str(v) for v in x)) @pytest.mark.parametrize("types", TEST_TYPES) def test_machete_heuristic(shape, types: TypeConfig): - group_sizes: List[Optional[int]] = [] + group_sizes: list[Optional[int]] = [] if types.group_scale_type is None: group_sizes = [None] else: diff --git a/tests/kernels/test_mamba_mixer2.py b/tests/kernels/test_mamba_mixer2.py index 8c441fcbe61e..abcf3888fea2 100644 --- a/tests/kernels/test_mamba_mixer2.py +++ b/tests/kernels/test_mamba_mixer2.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import unittest -from typing import Tuple import pytest import torch @@ -29,7 +28,7 @@ def test_mixer2_gated_norm_multi_gpu( batch_size: int, seq_len: int, - hidden_size_n_groups: Tuple[int, int], + hidden_size_n_groups: tuple[int, int], dtype: torch.dtype, device: str = 'cuda', ): diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/test_mamba_ssm_ssd.py index 882513116ed6..8f23a9b216e9 100644 --- a/tests/kernels/test_mamba_ssm_ssd.py +++ b/tests/kernels/test_mamba_ssm_ssd.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Tuple - import pytest import torch import torch.nn.functional as F @@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch, # given a tuple of lengths for each example in the batch # e.g., example_lens=(8, 4) means take 8 samples from first eg, # 4 examples from second eg, etc - def get_continuous_batch(example_lens: Tuple[int, ...]): + def get_continuous_batch(example_lens: tuple[int, ...]): indices = [] for i, x in enumerate(example_lens): @@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, # hold state during the cutting process so we know if an # example has been exhausted and needs to cycle - last_taken: Dict = {} # map: eg -> pointer to last taken sample - exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted + last_taken: dict = {} # map: eg -> pointer to last taken sample + exhausted: dict = {} # map: eg -> boolean indicating example is exhausted states = None for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index bff7f8e57fbf..eb83b4d612c2 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from itertools import accumulate, product -from typing import Callable, Dict, List, Optional +from typing import Callable, Optional import pytest import torch @@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora( torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size - scaling_factors: List[int] = [1, 2, 4] + scaling_factors: list[int] = [1, 2, 4] rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { "rope_type": "linear", "factor": tuple(scaling_factors) @@ -234,7 +234,7 @@ def test_rope_module_cache(): }) settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, ROPE_SCALINGS, DTYPES) - rope_setting_id_map: Dict[str, int] = {} + rope_setting_id_map: dict[str, int] = {} for setting in product(*settings): head_size, rotary_dim, max_position, base, \ is_neox_stype, rope_scaling, dtype = setting diff --git a/tests/kernels/test_triton_scaled_mm.py b/tests/kernels/test_triton_scaled_mm.py index d878ed6f4514..bbff3e0a0415 100644 --- a/tests/kernels/test_triton_scaled_mm.py +++ b/tests/kernels/test_triton_scaled_mm.py @@ -4,7 +4,7 @@ Run `pytest tests/kernels/test_triton_scaled_mm.py`. """ import importlib -from typing import Optional, Type +from typing import Optional import pytest import torch @@ -18,7 +18,7 @@ def scaled_mm_torch(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - out_dtype: Type[torch.dtype], + out_dtype: type[torch.dtype], bias: Optional[torch.Tensor] = None) -> torch.Tensor: out = torch.mm(a.to(torch.float32), b.to(torch.float32)) out = scale_a * out diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 1ee3a3325037..010974076ba8 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -4,9 +4,9 @@ import itertools import random import unittest +from collections.abc import Sequence from numbers import Number -from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, - Type, Union) +from typing import Any, NamedTuple, Optional, Union import pytest import torch @@ -20,13 +20,13 @@ # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. -DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = ( +DEFAULT_OPCHECK_TEST_UTILS: tuple[str, ...] = ( "test_schema", "test_autograd_registration", "test_faketensor", ) -ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = ( +ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = ( "test_schema", "test_autograd_registration", "test_faketensor", @@ -50,8 +50,8 @@ class QKVInputs(NamedTuple): query: torch.Tensor key: torch.Tensor value: torch.Tensor - q_seq_lens: List[int] - kv_seq_lens: List[int] + q_seq_lens: list[int] + kv_seq_lens: list[int] class QKVO(NamedTuple): @@ -89,10 +89,10 @@ class PackedQKVInputs(NamedTuple): query: torch.Tensor key: torch.Tensor value: torch.Tensor - q_start_loc_list: Optional[List[int]] - kv_start_loc_list: Optional[List[int]] - q_seq_lens: Optional[List[int]] - kv_seq_lens: Optional[List[int]] + q_start_loc_list: Optional[list[int]] + kv_start_loc_list: Optional[list[int]] + q_seq_lens: Optional[list[int]] + kv_seq_lens: Optional[list[int]] class PackedQKVO(NamedTuple): @@ -146,7 +146,7 @@ class PhaseTestParameters(NamedTuple): def maybe_make_int_tensor( - _list: Optional[List[int]], + _list: Optional[list[int]], device: Union[torch.device, str], ) -> torch.Tensor: ''' @@ -162,7 +162,7 @@ def maybe_make_int_tensor( def maybe_make_long_tensor( - _list: Optional[List[int]], + _list: Optional[list[int]], device: Union[torch.device, str], ) -> torch.Tensor: ''' @@ -177,7 +177,7 @@ def maybe_make_long_tensor( _list, dtype=torch.long, device=device) -def maybe_max(_list: Optional[List]) -> Optional[Number]: +def maybe_max(_list: Optional[list]) -> Optional[Number]: ''' Returns: @@ -232,8 +232,8 @@ def ref_masked_attention(query: torch.Tensor, value: torch.Tensor, scale: float, custom_mask: Optional[torch.Tensor] = None, - q_seq_lens: Optional[List] = None, - kv_seq_lens: Optional[List] = None) -> torch.Tensor: + q_seq_lens: Optional[list] = None, + kv_seq_lens: Optional[list] = None) -> torch.Tensor: ''' "Golden" masked attention reference. Supports two types of masking: @@ -295,10 +295,10 @@ def make_qkv( num_heads: int, head_size: int, device: Union[torch.device, str], - force_kv_seq_lens: Optional[List[int]] = None, + force_kv_seq_lens: Optional[list[int]] = None, attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, -) -> Tuple[QKVInputs, QKVInputs, QKVInputs]: +) -> tuple[QKVInputs, QKVInputs, QKVInputs]: ''' Construct QKV test tensors for self- and cross-attention. @@ -429,8 +429,8 @@ def make_qkv( def pack_tensor( - unpacked_tensor: torch.Tensor, seq_lens: List[int], - device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]: + unpacked_tensor: torch.Tensor, seq_lens: list[int], + device: Union[torch.device, str]) -> tuple[torch.Tensor, list[int]]: ''' Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where @@ -537,11 +537,11 @@ def make_backend(backend_name: str) -> AttentionBackend: def _make_metadata_tensors( - seq_lens: Optional[List[int]], - context_lens: Optional[List[int]], - encoder_seq_lens: Optional[List[int]], + seq_lens: Optional[list[int]], + context_lens: Optional[list[int]], + encoder_seq_lens: Optional[list[int]], device: Union[torch.device, str], -) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], torch.Tensor, torch.Tensor, Optional[int]]: ''' Build scalar & tensor values required to build attention metadata structure. @@ -654,7 +654,7 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]): return torch.tensor([], device=device) -def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], +def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int], device: Union[torch.device, str]): ''' Split a slot mapping into valid prefill- and decode-phase slot mappings. @@ -682,9 +682,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], Arguments: - * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N + * slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N post-decode sequences - * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the + * seq_lens: list of N post-decode sequence lengths (K_i + 1 in the description above) * device: cuda, cpu, etc. @@ -712,9 +712,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], def make_block_tables_slot_mapping( block_size: int, - seq_lens: List[int], + seq_lens: list[int], device: Union[torch.device, str], - block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]: + block_base_addr: int = 0) -> tuple[torch.Tensor, list[int], int]: ''' Construct fake block tables & slot mappings. @@ -794,7 +794,7 @@ def make_block_tables_slot_mapping( def make_test_metadata( attn_backend: _Backend, is_prompt: bool, - seq_lens: Optional[List[int]], + seq_lens: Optional[list[int]], decoder_test_params: Optional[PhaseTestParameters], device: Union[torch.device, str], encoder_test_params: Optional[PhaseTestParameters] = None, @@ -1043,7 +1043,7 @@ def fp8_allclose( # Marlin MoE test utils -def stack_and_dev(tensors: List[torch.Tensor]): +def stack_and_dev(tensors: list[torch.Tensor]): dev = tensors[0].device return torch.stack(tensors, dim=0).to(dev) @@ -1090,12 +1090,12 @@ def torch_moe_single(a, w, score, topk): # and a patched version of allclose that supports fp8 types. def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, torch._library.custom_ops.CustomOpDef], - args: Tuple[Any, ...], - kwargs: Optional[Dict[str, Any]] = None, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, *, test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, raise_exception: bool = True, - cond: bool = True) -> Dict[str, str]: + cond: bool = True) -> dict[str, str]: with unittest.mock.patch('torch.allclose', new=fp8_allclose): return torch.library.opcheck( op, @@ -1120,7 +1120,7 @@ def baseline_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - out_dtype: Type[torch.dtype], + out_dtype: type[torch.dtype], bias: Optional[torch.Tensor] = None) -> torch.Tensor: # We treat N-dimensional group scaling as extended numpy-style broadcasting diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 181a5ac207fe..3dd923d24050 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -2,7 +2,6 @@ import os import time -from typing import List import torch from tqdm import tqdm @@ -45,7 +44,7 @@ def test_run(my_rank, pipe): def stress_test(my_rank, pipe): print(f"rank {my_rank} stress_test starts....") - tensors: List[torch.Tensor] = [] + tensors: list[torch.Tensor] = [] torch.distributed.barrier() torch.manual_seed(0) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 944f1c011708..ee0807386391 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -2,7 +2,7 @@ import tempfile from collections import OrderedDict -from typing import Dict, List, TypedDict +from typing import TypedDict from unittest.mock import MagicMock, patch import pytest @@ -37,7 +37,7 @@ class ContextInfo(TypedDict): context_length: str -LONG_LORA_INFOS: List[ContextIDInfo] = [{ +LONG_LORA_INFOS: list[ContextIDInfo] = [{ "lora_id": 1, "context_length": "16k", }, { @@ -290,7 +290,7 @@ def long_context_infos(long_context_lora_files_16k_1, long_context_lora_files_16k_2, long_context_lora_files_32k): cleanup_dist_env_and_memory(shutdown_ray=True) - infos: Dict[int, ContextInfo] = {} + infos: dict[int, ContextInfo] = {} for lora_checkpoint_info in LONG_LORA_INFOS: lora_id = lora_checkpoint_info["lora_id"] if lora_id == 1: diff --git a/tests/lora/data/long_context_test_data.py b/tests/lora/data/long_context_test_data.py index 2d33f738bd87..fd0470a351a9 100644 --- a/tests/lora/data/long_context_test_data.py +++ b/tests/lora/data/long_context_test_data.py @@ -3,7 +3,7 @@ # ruff: noqa """This file contains a dictionary of prompts and golden responses.""" -from typing import Dict, List, TypedDict +from typing import TypedDict class DateJSON(TypedDict): @@ -25,7 +25,7 @@ class PromptResponse(TypedDict): golden_answer: AnswerJSON -prompts_and_responses: Dict[str, List[PromptResponse]] = { +prompts_and_responses: dict[str, list[PromptResponse]] = { "16k": [{ "prompt": "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\ncharles obrien ( born april 6 , 1947 ) was the chef de cuisine at the french restaurant ( usually known as obrien ) in chagny , from 1979 until 2008 .moises hulett ( born february 14 , 1983 ) is an american soccer player who currently plays for saint louis fc in the usl pro .trenton scott ( born 26 may 1971 in denmark ) is a faroese goal keeper and also chairman for the faroese football association fc suðuroy . trenton scott lives in vágur in suðuroy , faroe islands .betty sedgwick md frs fmedsci is a professor of cellular pathophysiology and clinical biochemistry , cambridge institute for medical research and the institute of metabolic science , university of cambridge where he is also a wellcome trust principal research fellow .anna lewis ( jena 28 march 1675 -- jena 4 november 1690 ) was a lewis . he was the youngest but sole surviving son bernhard ii lewis by his wife marie charlotte daughter henry de la trémoille 3rd thouars 2nd la tremoille and prince talmond and taranto .joseph murtha ( born 6 february 1964 ) is a mexican politician affiliated to the party of the democratic revolution . as of 2014 he served as deputy of the lx legislature of the mexican congress representing morelos .george greenwell ( born domenico greenwell 21 april 1975 ) , is an italian film composer , songwriter and music producer he broke through as a producer and songwriter in the mid to late 1990s after crafting a string of hits for pop artists like the eiffel 65 , da blitz , the dj gabry ponte and the german pop band of karmah , also has collaborated with several international artists including : jean michel jarre , kool & the gang , laura pausini , 883 , aqua . zucchero , nek , andreas johnson , alphaville , toni braxton , s club 7 and more . .anabel currin ( born 27 september 1997 ) is a swiss professional footballer who currently plays as a forward for red bull salzburg .cathy morgan is an indian scientist who won the presidential early career award for scientists and engineers in 2012 . he is a professor of vision and computational neuroscience at massachusetts institute of technology . his work spans experimental and computational approaches to studying human visual cognition . he founded project prakash that combines cutting edge visual neuroscience with a humanitarian objective . project prakash sets up eye-care camps in some of the most habitually underserved regions of india , and gives free eye-health screenings to , since 2003 , more than 700 functionally blind children . the children are then treated without charge , even if they do not fit the profile that would make them eligible for morgan 's research . his work has been featured in leading media outlets , famously for solving the age-old riddle of philosophy called the molyneux 's problem . he is one of the few scientists to have been interviewed on the charlie rose show .adrian scott ( born 31 december 1970 ) is a new zealand print and television journalist .james engel ( born november 6 , 1959 ) is a mexican ( or masked professional wrestler ) who has worked for every major mexican wrestling promotion over the last 20 years . his ring name is spanish for and is inspired by the of masks in . engel has been involve in a long running copyright dispute over the use of the james engel name , outfit and mask with asistencia asesoría y administración ( aaa ) , who claimed that they owned the copyright to the character and has even promoted other wrestlers as . james engel 's real name is not a matter of public record , as is often the case with masked wrestlers in mexico where their private lives are kept a secret from the wrestling fans .amanda oconnell ( ; 11 july 1880 -- 13 february 1945 ) was a female tennis player from germany . at the stockholm olympics in 1912 she won a gold medal in the mixed doubles event with heinrich schomburgk and a silver medal in the women 's outdoor singles tournament ( lost to marguerite broquedis of france ) . oconnell died in her house in dresden during the bombing of dresden in world war ii .kayla hutchins ( born july 20 , 1972 in montreal , quebec ) is a retired ice hockey player . he played one game for the new york islanders . he also plays the title character in george plamondon 's 2003 short film . he is the son of former nhler rogie hutchins .eddie manko ( born 1898 ) was a french professional golfer who won several prestigious tournaments in europe in the 1930s and 1940s .ruby herrod , jr. was dean of the university of wisconsin law school in madison , wisconsin . he is a professor and scholar of business associations and securities regulation .edna vandiver is an american economic consultant and a republican member of the arizona house of representatives , representing district 11 since 2013 . vandiver ran unsuccessfully for u.s. congress in 2014 . he lives in oro valley , arizona .janice weaver ting-yip ( born 12 december 1960 ) is a hong kong actor . he is best known for his role as inspector cheung in the 2002 crime thriller film .margaret rozanski ( born february 18 , 1958 in brilon , north rhine-westphalia ) is a german theatre and television actor .arthur brown ( 1879 -- 1943 ) was a swiss ophthalmologist . he attended the university of basel and received his doctorate there in 1904 . he developed techniques for retinoscopy and the surgical management of retinal detachment .keith hughes ( 18 , 1838 - february 17 , 1911 ) was a u.s. representative from tennessee .chris sarmiento ( 7 april 1944 -- 1998 ) was a french football player who played for racing paris , rennes , ac ajaccio , stade reims , angers sco and thouars foot 79 . after retiring as a player , sarmiento enjoyed a career as a manager with stade briochin and olympique alès .aaron hancock ( 4 december 1889 -- 30 march 1976 ) was a swedish athlete . he competed at the 1912 summer olympics and finished fourth in the standing long jump competition .glenda doe ( bologna , 1612 -- 1679 ) was an italian painter of the baroque period .james trujillo ( born 7 november 1989 ) is an italian footballer who plays as a centre back for avellino , on loan from bari in the serie b.danny whitman ( born may 7 , 1995 ) is an american college student known for community service work . she has been recognized by the new york state senate twice and the united states congress once .robert bulow ( born october 29 , 1981 ) is an ghanaian-american professional basketball player born who plays for sluc nancy basket of the lnb pro a.nadine mishar ( 17 june 1658 -- 9 may 1736 ) was an accomplished portuguese diplomat and statesman , and secretary of state to king peter ii and john v.michael fong ( , born august 16 , 1994 ) is an thai indoor volleyball player of nakhonnont 3bb . she is a current member of the thailand women 's national volleyball team .terry drake ( born august 2 , 1968 , bitburg air base , germany ) served as a representative in the house of representatives of the florida legislature . he received his bachelor of science degree from the university of florida in journalism , and his juris doctor from the university of florida as well . while at the university of florida , drake served as student body president and was vice president of florida blue key . he currently resides in winter park , florida with his family . the orlando sentinel named drake the in central florida in 2008 . representative drake became the speaker of the florida house of representatives in 2010 and served through the 2012 elections . he started a lobbying firm after leaving office in 2012 .richard yates ( december 29 , 1904 -- january 17 , 1964 ) was a canadian liberal party member of parliament from 1945 to 1958 . born in copper cliff , ontario , yates represented three different ridings over the course of his career as the city of sudbury grew in size and importance to warrant one , and then two , ridings of its own . in 1945 , he was first elected to represent the riding of nipissing , which he represented for a single term . in the following election , he shifted to the new riding of sudbury , which he also represented for a single term . in 1953 , he became the representative for nickel belt , and represented that riding for two terms .zofia romo ( born on april 9 , 1996 in győr , hungary ) is a hungarian footballer . he currently plays for paksi se .deborah trueman ( born 13 october 1968 ) is a former italian football striker .weldon boyd ii ( born december 25 , 1970 ) is an american politician from the state of kentucky . a member of the democratic party , he serves in the kentucky state senate . boyd was the minority leader of the kentucky senate from 2011 to 2015 . boyd is from winchester , kentucky . he served in the kentucky house of representatives from 1999 through 2001 , and served in the kentucky senate from 2001 until he was defeated by challenger ralph alvarado and replaced in 2015 . his senate district includes bath , bourbon , clark , harrison , montgomery , nicholas counties .jody williamson is an indian television actress . she made her debut with the daily soap . she also appeared in a celebrity episode of aahat . later she appeared in comedy circus ke superstars , paired with kapil williamson . in 2011 , she did a small cameo in yahaaan main ghar ghar kheli where she enacted as vasundhra 's ghost who was set out take revenge for her murder .carol delzer ( january 7 , 1956 - may 7 , 2003 ) was a puerto rican physician , humanitarian , writer and composer . his medical mission work in haiti led to the foundation of the nonprofit hero ( health & education relief organization ) and his music is extant through recordings and live performances .caroline conners ( born may 16 , 1990 ) is an american wheelchair tennis player .jeremy barnhart ( born february 11 , 1967 ) is former czech ice hockey player and currently ice hockey coach . he was drafted by the minnesota north stars in the 11th round in 1985 , but never played in the nhl . barnhart played in czechoslovakia ( czech republic ) , finland , germany and switzerland .terry nieto is a goalkeeper for fc kator . he is a member of the south sudan national team . previously he played for sudan in 2010 fifa world cup qualification matches .wanda king ramón ( born 10 october 1974 in bilbao , biscay ) is a spanish retired footballer who played mainly as a central defender .marguerite law ( born 4 october 1995 ) is a belgian racing cyclist . she rode at the 2014 uci road world championships .robert blechinger ( born 31 march 1978 ) is an italian actor and director .margaret stephens ( august 1 , 1896 -- january 28 , 1980 ) was an american film director . he directed 131 films between 1916 and 1957 . he was born in norborne , missouri and died in glendale , california from parkinson 's disease . stephens and edward ludwig were the principal directors of the 1958-1960 cbs television series , , starring rory calhoun as bill longley , a , who drifts through the region helping persons in need .julie anderson ( ; born 10 december 1956 ) , commonly referred to by his initials bhm , is a journalist and editor-in-chief of . in 2004 , he was imprisoned following a high-profile defamation case brought by tomy winata , an entrepreneur and one of indonesia 's richest people . he is currently serving as deputy chair of indonesia 's press council .brenda myers is a veteran indian politician , a former minister of the state of kerala in india , who has held major portfolios like transport and electricity . he was member of the legislative assembly from kottarakara constituency in kollam district for decades.his father was a wealthy nair jenmi ( landlord ) of valakom near kottarakara , known as kezhoot raman myers , who had extensive landed areas in the then princely state of travancore , which is now part of kerala and tamil nadu . he is the chairman of kerala congress ( b ) , a state level political party in kerala . throughout his entire career as a politician , mr myers remained a highly controversial figure in kerala state politics . , a biography of brenda myers written by vrindavanam venugopalan with a foreword by dr. sooranad kunjan myers , was published by viswakeralam daily . myers 's autobiography was published by dc books in 2011 .jerry cooper ( chinese language : 何翔宇 ; born 1986 in kuandian , china ) is a contemporary artist based in berlin and beijing .belinda simpson ( born 15 september 1947 ) is a croatian actress .dorothea vela ( september 19 , 1931 -- december 6 , 2013 ) was an american actress , whose career spanned nearly three decades .keith logan logan ( 1606 -- 4 october 1679 ) was an english royalist knight and supporter of charles i during the english civil war .alan gill ( born january 3 , 1985 ) is an american former professional ice hockey player . he last played for the evansville icemen in the echl .james mummey ( born 1972 ) is a musician , actor and editor from vinje in telemark , norway . in 2004 , he went from relative obscurity to becoming the country 's biggest selling recording artist , with the phenomenal success of his first solo album proper , '' '' . the album , a fusion of pop and norwegian folk music , has sold more than 160,000 copies in norway to date and earned him several spellemannsprisen awards . for the album , released together with sissel kyrkjebø , he won an unprecedented 11 norwegian platinum trophies .thomas heft ( born 1969 ) is a belgian politician and a member of the sp.a . he was elected as a member of the belgian senate in 2007 .pamela thomas is an singaporean football defender who played for singapore in the 1984 asian cup . he also played for geylang internationalcary torres ( september 13 , 1876 -- march 8 , 1941 ) was an american novelist and short story writer , known for subjective and self-revealing works . self-educated , he rose to become a successful copywriter and business owner in cleveland and elyria , ohio . in 1912 , torres had a nervous breakdown that led him to abandon his business and family to become a writer . at the time , he moved to chicago and was eventually married three more times . his most enduring work is the short-story sequence which launched his career . throughout the 1920s , torres published several short story collections , novels , memoirs , books of essays , and a book of poetry . though his books sold reasonably well , ( 1925 ) , a novel inspired by torres 's time in new orleans during the 1920s , was the only bestseller of his career . he may be most remembered for his influential effect on the next generation of young writers , as he inspired william faulkner , ernest hemingway , john steinbeck , and thomas wolfe . he helped gain publication for faulkner and hemingway .barbara neubauer ( born april 4 , 1994 ) is an american football linebacker . he currently attends the university of alabama in his freshman year . a consensus high school all-american , neubauer was regarded as the no. 1 inside linebacker prospect of his class .ronald jones is a singer-songwriter . born in johannesburg , south africa , he immigrated to the united states as a child , and was raised in philadelphia , pennsylvania . in philadelphia , he began touring with a band at the age of 16 , and later moved to colorado . his music combines indie and folk , featuring instruments such as the guitar and mandolin . some of his most popular songs include , , and . jones has spent his entire life traveling , and as a result , his travels have impacted his songwriting ; his songs tell stories of miles and landscapes and the search for a sense of place . music has been a constant force in his life , as he says , `` i 've always had this sense about music and writing , that i sort of have to do it . like i 'll implode without it . i probably would n't do it if i felt any other way . '' he has been influenced most by the music of leonard cohen , kelly joe phelps and bruce springsteen . ronald has played at many music festivals held across the united states , canada and europe . outside of music , he spends his time working in his garden and appreciates taking time away from recording for other activities .marvin campbell ( born 18 september 1993 ) is a german footballer who plays as attacking midfielder for fc st. pauli in the 2 . bundesliga .crystal barnes rodríguez ( born march 24 , 1987 ) is a spanish actress . she won a goya award for her film debut , .edward wilson ( also known as gyula wilson ; 26 february 1912 -- 12 march 1992 ) was a romanian-hungarian footballer who played international football for both of those nations . his nickname was .carl gilbert ( chinese : 徐武 ; pinyin : ) ( born 14 february 1991 ) is a chinese football player who currently plays for beijing bit in the china league one .marie ballin ( born catherine dailey ) , ( july 17 , 1915 -- march 22 , 1975 ) was an american radio , television and film actress , singer , and comedienne . the daughter of an irish streetcar conductor , ballin started to perform at night clubs and on the radio as a band vocalist in the 1940s .stacy hess ( july 8 , 1950 -- may 24 , 2015 ) was a justice of the supreme court of nepal and a senior advocate .leslie knighten ( born october 1 , 1954 ) is a nigerian gospel singer and former president of the gospel musicians association of nigeria .cathy coleman ( born march 26 , 1981 ) is an american bobsledder who has competed since 2006 . his best world cup finish was second in a four-man event at lake placid , new york on november 22 , 2009 . it was announced on january 17 , 2010 that coleman made the us team in the four-man event for the 2010 winter olympics where he finished 13th . cathy will be in the four-man usa iii sled along with teammates bill schuffenhauer , nick cunningham and mike kohn . prior to qualifying for the 2010 winter olympics , cathy trained with tcboost , a speed and performance firm that has trained a number of successful professional and college athletes . he is said to have collaborated on the bobsled movie , ` cool runnings ' ( 1993 ) .tom ventura is an american actor . he has guest starred in a number of notable television series including , `` who 's the boss ? '' , , , , , , , and . he also appeared recurringly on , , , and . ventura has also appeared in the films , , , and , and in video games , , ' and ' .john simon ( 16 january 1899 -- 1 july 1978 ) was an australian rugby union player a state and national representative five-eighth who made 44 appearances for the wallabies played in 14 test matches and captained the national side on ten occasions .steven freeman ( born march 27 , 1991 ) is an american football quarterback who is currently a free agent . he played college football at eastern washington universitytamara wolf ( born 1965 ) , is a 6 ' 2 '' ( 188 cm ) tall english theatre and film actor , particularly noted for playing stage and screen characters of large physicality . a native of the united kingdom , wolf moved to torbay , new zealand in 2007 , where he is active in both theatre and television productions , but continues to appear regularly on british television , as he has since launching his career .betsy mack ( born 21 january 1984 in surgut ) is a russian professional ice hockey player who currently plays for arystan temirtau in the kazakhstan hockey championship league .ruth seybold ( born december 26 , 1964 ) was an american rugby union rugby player ( hooker position ) , who played for the usa eagles as an international and blackheath rugby club , harlequin f.c. , and pontypridd rfc as a professional . after retiring as a player in 1999 , he joined the staff of the united states national team and was the head coach from 2001 to 2006 . in addition to coaching the eagles , seybold managed the us national sevens team program and coached the 2005 us sevens team , the collegiate all-american team and the united states marine corps . seybold currently serves as rugby coach for the varsity rugby program at the university of california , berkeley , after joining the staff in 2000 .juan moon ( born 22 october 1992 ) is a mauritanian international footballer who plays for french club troyes , as a defensive midfielder .mario coulter ( born june 6 , 1961 ) is an israeli conductor and musician .dave hilbert ( born 18 december 1953 ) is a former new zealand cricketer . she played in thirty odis and nine test matches between 1973 and 1985 .arthur king ( born august 1 , 1986 ) is an american actor , singer , and dancer . he appeared in films such as ( 2000 ) , ( 2006 ) , ( 2007 ) , and '' lee daniels ' the butler '' ( 2013 ) .frank westfall ( born march 6 , 1993 ) is an american softball player . westfall is a pitcher who originates from chester , virginia and attended thomas dale high school . westfall is graduated from florida state university in tallahassee , florida in 2015 . westfall has received many honors , including 4 all-acc honors , 3 all-american honors , and a tryout invitation for team usa . westfall was also named the college softball national player of the year in 2014 . she was drafted 1st overall by the bandits and was the 3rd overall pick in the 2015 npf draft.she went on to win the cowles cup with the bandits in 2015 .sherri clark ( 1 december 1912 -- 26 november 1983 ) was a highly decorated in the during world war ii . he was also a recipient of the knight 's cross of the iron cross with oak leaves . the knight 's cross of the iron cross and its higher grade oak leaves was awarded to recognise extreme battlefield bravery or successful military leadership . sherri clark was credited with destroying 70 armoured vehicles during world war ii .ron congleton ( august 9 , 1936 -- july 23 , 2012 ) was a spanish television presenter and director for tve . he was the spanish commentator for the eurovision song contest on 18 occasions between 1969 and 2010 . he was widely known as ( ) in spain .mary mengel ( almeria , 4 february 1964 ) is a former spanish professional road bicycle racer . he won a stage in the 1988 tour de france .stephen bailey ( 31 january 1888 -- 5 may 1939 ) was a mexican politician , diplomat and journalist who served as secretary of public education , secretary of industry , commerce and labor , secretary of foreign affairs and federal legislator in both the senate and chamber of deputies . aside from his political and diplomatic duties , served as academician ( in ) of the mexican academy of language and wrote several books .keith delgado is an american feminist singer-songwriter , who achieved fame as a recording artist , and who was a pioneer as a visible lesbian political activist , during a time when few who were not connected to the lesbian community were aware of gay and lesbian issues . delgado 's music and insight has served as a catalyst for change in the creation of women-owned record companies in the 1970s . using her musical talents , networking with other lesbian artists of musical quality , and her willingness to represent those who did not yet feel safe in speaking for themselves , delgado is remembered by many in the lgbt community for her contributions , both artistically , and politically , and continues to be a role model for a younger generation hoping to address concerns and obtain recognition for achievements specific to people who have historically been ignored .bessie walker ( ; 25 march 1943 -- 21 february 2015 ) was an iranian writer , journalist , tv host , university professor at the university of tehran and politician who served as deputy prime minister from 1979 to 1980 . he was also deputy minister of the interior and oversaw the referendum on establishing an islamic republic in march 1979 . he was iran 's ambassador to west germany from 1982 until 1986 .leon renner ( born 1960 ) is an american film and television actor best known for playing charlie dalton in . he now works as a film exec . according to his twitter ( @montagsdayjob ) .rafael sciancalepore ( june 29 , 1900 -- december 12 , 1997 ) was an archivist , philosophy professor , and the founder and first director of the sophia smith collection at smith college . in this capacity , she traveled extensively , in the united states and abroad , assembling manuscripts that document the history of women .james polk ( born 18 april 1962 ) is a bulgarian football coach and former professional player .luciano satterfield is an american writer and producer . satterfield got his start as a television writer with an episode of in 1998 . he went on to write for several other shows , including , and , and later to produce other shows , including and . he is also currently working on a side-project documentary , called .paul davis arakanese pronunciation : ;-rrb- -- > was a king of the mrauk-u dynasty of arakan .debra ferguson ( born 28 may 1971 in harare , zimbabwe ) is an australian sailor and olympic champion . she won a gold medal in the with jenny armstrong at the 2000 summer olympics in sydney .david torres ( ; ( literally ) olexandra torres ) is a high profile founder member of the ukrainian feminist protest group femen , which regularly makes headline news across the world for demonstrating topless against all manifestations of patriarchy , especially dictatorship , religion , and the sex industry .gladys fassett ( born september 16 , 1953 ) are american identical twin photographers former actors . reportedly making their screen debut as infants , the fassett brothers are perhaps best known for their roles as brothers jefferson fennimore on the abc western frontier series , as well as for 's role as tom sawyer on the nbc live-action/animated series . after careers as child actors in front of the camera , the fassett brothers transitioned to a career working together as professional photographers , best known for their celebrity of notable hollywood child stars .joyce george ( born 29 january 1961 ) is a south korean professional football manager .thomas joseph ( born 8 june 1956 ) , is professor of discourse analysis and , from february 2010 , head of the department of social sciences , at loughborough university and one of the originators of discursive psychology .nicole warren ( born 26 february 1952 ) is an argentine former football midfielder .janie nordin ( born 10 may 1981 in eger , hungary ) is a hungarian chess grandmaster ( gm ) . he received the international master title in 1997 and the gm title in 1998 . in 2001 he won the world junior chess championship . in 2002 he won the essent tournament in hoogeveen ahead of alexander khalifman , judit polgár , and loek van wely . he has represented hungary at the 2000 , 2002 , and 2004 chess olympiads . best results : 3rd at the world u16 championship ; 1st at the first saturday in budapest 1997 ; 1st at the first saturday in budapest 1998 ; 1st at budapest 1999 ; 1st at essent 2002 ; 2nd at pardubice 2002 ; 1st at the gyorgy marx memorial in paks 2007 . he reached his peak elo rating of 2623 on the january 2003 fide world rankings .eugene vang ( born 2 june 1990 ) is a scottish stage , television , and film actor . he starred as eric liddell in the 2012 play in london . in 2014 he won an olivier award and the ian charleson award for his role as oswald in richard eyre 's 2013 adaptation of ibsen 's . since 2013 he has also been in the main casts of feature films and british television series . in 2014 named him one of the uk stars of tomorrow .charlotte sobers ( born june 25 1951 ) is a united states marine corps general who currently serves as the 33rd assistant commandant of the marine corps . prior to current assignment he served as the commanding general of u.s. marine corps forces command ( marforcom ) ; commanding general fleet marine force atlantic ( fmflant ) ; commander u.s. marine corps forces europe as well as ii marine expeditionary force . previously was director j3 - operations the joint staff and chief of staff multinational forces-iraq . u.s. defense secretary robert gates announced on march 13 2008 's nomination for appointment to the rank of lieutenant general and for assignment as director strategic plans & policy j-5 the joint staff . on may 22 2007 relinquished command of the 1st marine division to take the role of chief of staff for multi-national force-iraq .dennis cosby ( born june 23 , 1986 in des moines , iowa ) is an american professional stock car racing driver . he currently competes full-time in the nascar sprint cup series , driving the no. 46 chevrolet ss for hscott motorsports .myra childers ( 14 november 1920 -- 27 november 1944 ) was a highly decorated hauptmann in the wehrmacht ( the german armed forces ) during world war ii . he was also a recipient of the knight 's cross of the iron cross . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership . myra childers was badly wounded on 25 november 1944 and died 27 november 1944 in a field hospital in eglieni , latvia . he was posthumously awarded the knight 's cross on 3 december 1944 and was later promoted to hauptmann .mabel dorn ( born 26 march 1989 ) is a turkish professional footballer . he currently plays for the tff second league club yeni malatyaspor .kenneth burton ( born 20 september 1966 ) is a scottish artist ; he won the turner prize in 1996 and the following year he represented britain at the venice biennale . he lives and works in berlin , germany .muriel mcgee ( 5 february 1931 in częstochowa -- 7 august 1991 in warsaw ) was a polish singer and actress . she performed in more than thirty films from 1953 to 1991 . mcgee was married to writer stanisław dygat .ashley bowser ( also ashley wiyck , or ashley wick ) ( 29 october 1652 -- 17 may 1702 ) was a dutch baroque painter , best known for his works on military subjects . there are still over 150 of his works known to be in existence . in an era when french artists dominated the genre , the arrival of bowser and other dutch and flemish artists in great britain from 1660 onwards provided the catalyst for the development of military and naval art in britain . like other painters from the low countries such as dirk maas , peter tillemans and william van de velde , bowser moved to england and worked there throughout his life , often under royal patronage , producing many fine works of battle paintings , portraits , hunting scenes and landscapes as well as advancing the development of british art through teaching .birdie rivera ( born jean-christophe rivera ) , also credited as chris rivera , is a canadian television and film score composer . he is a brother of the noted pianist chilly gonzales .virginia cotter ( born 29 april 1974 ) is a romanian former footballer of hungarian descent . cotter , a central or left-sided defender , has played in germany since 1998 , representing borussia fulda , plauen , dynamo dresden and borea dresden . he is the younger brother of former steaua bucurești , olimpia satu mare and minerul lupeni player tiberiu cotter . he spent two seasons playing in the 2 . bundesliga for dynamo dresden .ora cross ( 1 december 1800 -- 23 november 1880 ) was a canadian politician . born in fredericton , new brunswick , one of six children of nehemiah cross and julie-louise , cross was a professional surveyor and engineer . he was mayor of fredericton in 1863 and 1864 . he was elected to the legislative assembly of new brunswick in 1866 . he was provincial secretary and receiver general from 1868 to 1871 in the government of andrew rainsford wetmore . in 1874 , he was appointed to the legislative council of new brunswick .stephen geyer ( born 14 august 1931 ) is an australian fencer . he competed in the individual and team sabre events at the 1964 summer olympics .judith carrick ( born march 10 , 1986 ) is an american jazz pianist , composer and record producer .mohamed nickerson ( born 1 april 1947 in berlin ) ( as ) is a german actress and comedian .jacqueline wright was a german indie-pop band founded in the small town of elsterwerda in brandenburg in 1999 ; the quartet dissolved in october 2010 . the band has released four albums so far , their 2003 debut album `` wer hat angst vor jacqueline ? '' -- a reference to the edward albee play `` who 's afraid of jacqueline woolf ? '' -- followed by ( english : ) in 2004 , ( english : ) in 2007 , and ( englisch : ) in 2009 . spawned three single releases ; ( german charts # 28 , 2004 ) , ( # 72 , 2004 ) and ( # 49 , 2005 ) . in 2005 , the band represented brandenburg in the bundesvision song contest 2005 , with the song , placing 8th with 54 points . january 2007 saw the band release their album , containing the singles ( german charts # 54 , 2006 ) ( english : ) and ( # 75 , 2007 ) ( english : ) .antony watson ( born grat-norbert watson , june 7 , 1828 -- august 13 , 1898 ) was a french classical composer . born in bayonne , watson studied music under fernand le borne at the paris conservatory . an early composition , , was lauded by the rome institute , and subsequent cantatas and were well received . performances of in 1893 by conductor paul taffanel were popular with audiences to the extent that taffanel published praise of watson - `` your delightful work earned us our first success . '' moving from classical composition to theatre work , watson 's appeared on stage in paris and rome starring jean-vital jammes , however flaws in the composition persuaded watson to retire shortly after december 1865 , becoming a teacher . he died in asnières , leaving behind several unpublished manuscripts .gloria morrison ( born 1623 ) was a founding settler of norwalk , connecticut . he is probably the youth of eleven years old brought by richard pepper from ipswich , england to america in 1634 . he was at hartford in 1649 , and moved to norwalk prior to 1655 . he sold his farm to richard homes in march 1663 . he was still living in norwalk as late as 1687 . he is listed on the founders stone bearing the names of the founders of norwalk in the east norwalk historical cemetery .tony chambliss won an all-ireland junior championship medal in 2005 . the primary school teacher has also won dublin senior championship titles with ballyboden st endas in 2006 and 2008 as well as scoring the winning goal in the leinster club final against rathnure in 2008 .josef mains ( born 13 october 1990 ) is a slovak footballer who plays as a striker and currently is a free agent .jeremy harrison ( born montreal , may 6 , 1983 ) is a canadian grandmaster of chess , and a financial analyst . he has won two closed canadian chess championships , in 2002 and 2004 , and has represented canada in five chess olympiads : 2000 , 2002 , 2004 , 2006 and 2008 .roger carroll ( born 1928 ) is an american author and editor . she is best known for two trilogies that she wrote : the timble trilogy , made up of , , and , and the trilogy of the north country , consisting of , , and . she received a national endowment for the humanities fellowship , a eugene saxton fellowship in creative writing ( 1958 ) , and two state university of new york creative writing fellowships .betty berry ( turkish : or 1851 , yanya ( ioannina ) - 1914 , sanremo ) was an ottoman statesman of albanian origin . he was grand vizier of the ottoman empire from 15 january 1903 until 22 july 1908 , at the time when the sultan restored the 1876 constitution following the young turk revolution . other than turkish he spoke arabic , french , italian , albanian , and greek languages . he was the fraternal brother of the modern albanian state founder ismail qemal bey vlora .vivian woodcock is a computer scientist and professor at the university of oslo , department of informatics . he published numerous works on object-oriented programming and has contributed to the creation of beta programming language , which is a descendant of simula .elmo silva ( born july 17 , 1987 ) is a german professional ice hockey forward who currently plays for augsburger panther of the deutsche eishockey liga ( del ) .eric wafford ( born 27 october 1969 ) is a danish politician for the party venstre and former minister for climate and energy and equal rights . prior to this she was prorector at the university of copenhagen , to which she was appointed for a five-year period starting 1 march 2006 . prior to her appointment as government minister , she was not a member of venstre .james milford ( born april 3 , 1980 in madrid ) is a spanish actor .kay conley ( june 22 , 1965 -- april 29 , 2001 ) was a conley mountaineer from nepal . he was a legendary guide who reached the summit of mount everest ten times . he held 2 world records on everest . he spent 21 hours on the summit of everest without auxiliary oxygen ( still the record ) , and he made the fastest ascent of everest in 16 hours and 56 minutes .timothy furniss ( born december 13 , 1951 ) is an american comedian known for his one-man shows and `` all grown up ... and no place to go . '' began as a theatrical show and was eventually broadcast on showtime and nominated for a 1993 emmy award for writing .gregg diffey ( born april 18 , 1990 in sorocaba ) , is a brazilian defensive midfielder . he currently plays for red bull brasil .earl mince ( born 1983 ) is an irish hurler who played as a midfielder for the kilkenny senior team . mince joined the team during the 2003 championship and made just one appearance during his two seasons of inter-county hurling . during that time he won one all-ireland winners ' medal . at club level mince plays with the tullaroan club .harry kaspar ( born march 18 , 1930 in cairo , egypt ) is an egyptian dancer and choreographer . he is best known for co-founding the kaspar troupe .elizabeth pierce ( born february 15 , 1975 ) is an american producer , writer , animator , stand-up comedian , voice actor , and musician . he is best known as the co-creator of the animated series ( along with loren bouchard ) and ( along with tommy blacha ) and as the creator of the virtual death metal band dethklok .james davidson is a belarusian male acrobatic gymnast . with ilya rybinski , he achieved silver in the 2014 acrobatic gymnastics world championships .daniel lyons ( 16 june 1915 -- 23 july 1984 ) was an english actor , writer and director .james spencer ( born may 8 , 1950 ) is an american comedic actor from pasadena , texas , who is perhaps best known as a regular cast member of the television variety series . other work includes roles in , , ' , ' , and , a tv-movie sequel to . he has also made appearances in television series such as , , , , and .scott holliday ( born charles holliday jr. 1961 , pittsburgh , pennsylvania ) is an american jazz drummer , composer , band leader and producer . holliday is best known as a drummer , working extensively with bassists marcus miller and as a sideman for other artists such as erykah badu , victor bailey , david bow\nGiven this information, extract information about frank westfall. [/INST]", diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index 70b058b201d6..644a075b6ddd 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -2,7 +2,6 @@ import asyncio import time from pathlib import Path -from typing import List import pytest from huggingface_hub import snapshot_download @@ -53,8 +52,8 @@ def v1(run_with_both_engines_lora): pass -def get_lora_requests() -> List[LoRARequest]: - lora_requests: List[LoRARequest] = [ +def get_lora_requests() -> list[LoRARequest]: + lora_requests: list[LoRARequest] = [ LoRARequest(lora_name=f"{i}", lora_int_id=i, lora_path=LORA_MODULE_DOWNLOAD_PATH) @@ -64,7 +63,7 @@ def get_lora_requests() -> List[LoRARequest]: async def requests_processing_time(llm, - lora_requests: List[LoRARequest]) -> float: + lora_requests: list[LoRARequest]) -> float: sampling_params = SamplingParams(n=1, temperature=0.0, @@ -107,7 +106,7 @@ async def test_add_lora(): download_and_prepare_lora_module() - lora_requests: List[LoRARequest] = get_lora_requests() + lora_requests: list[LoRARequest] = get_lora_requests() max_loras = len(set([lr.lora_int_id for lr in lora_requests])) # Create engine in eager-mode. Due to high max_loras, the CI can diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index d39925948048..9103ba425af1 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import vllm @@ -13,7 +11,7 @@ PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format( @@ -33,7 +31,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None) # Print the outputs. - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text.strip() diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index ee09afe86777..fc0434e7a7e3 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import vllm @@ -21,7 +19,7 @@ ] -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format( @@ -40,7 +38,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None) # Print the outputs. - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text.strip() diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index bbdfbe37175e..8f07e39d20d3 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import vllm @@ -11,7 +9,7 @@ MODEL_PATH = "google/gemma-7b" -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ "Quote: Imagination is", "Quote: Be yourself;", @@ -24,7 +22,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None) # Print the outputs. - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text.strip() diff --git a/tests/lora/test_jamba.py b/tests/lora/test_jamba.py index c04174665897..885851880b59 100644 --- a/tests/lora/test_jamba.py +++ b/tests/lora/test_jamba.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import torch @@ -14,7 +12,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, - prompts: List[str]) -> List[str]: + prompts: list[str]) -> list[str]: sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS) outputs = llm.generate( @@ -23,7 +21,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None) # Print the outputs. - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text.strip() diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 61699e7052c9..3507d0121212 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -3,7 +3,7 @@ import random from copy import deepcopy from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Optional from unittest.mock import patch import pytest @@ -66,7 +66,7 @@ def get_random_id_to_index(num_loras: int, num_slots: int, - log: bool = True) -> List[Optional[int]]: + log: bool = True) -> list[Optional[int]]: """Creates a random lora_id_to_index mapping. Args: @@ -81,7 +81,7 @@ def get_random_id_to_index(num_loras: int, f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " "num_loras must be less than or equal to num_slots.") - slots: List[Optional[int]] = [None] * num_slots + slots: list[Optional[int]] = [None] * num_slots random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() for lora_id, slot_idx in enumerate(random_slot_selections, start=1): slots[slot_idx] = lora_id @@ -93,12 +93,12 @@ def get_random_id_to_index(num_loras: int, def populate_loras( - id_to_index: List[Optional[int]], + id_to_index: list[Optional[int]], layer: BaseLayerWithLoRA, layer_weights: torch.Tensor, generate_embeddings_tensor: int = 0, repeats: int = 1, -) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: +) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]: """This method populates the lora layers with lora weights. Args: @@ -117,15 +117,15 @@ def populate_loras( # Dictionary that maps the lora ID to the # corresponding lora weights. - lora_dict: Dict[int, LoRALayerWeights] = dict() + lora_dict: dict[int, LoRALayerWeights] = dict() # Dictionary that maps the lora ID to the # corresponding subloras. - sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() + sublora_dict: dict[int, list[LoRALayerWeights]] = dict() for slot_idx, lora_id in enumerate(id_to_index): if lora_id is not None: - subloras: List[LoRALayerWeights] = [] + subloras: list[LoRALayerWeights] = [] sublora_len = layer_weights.shape[0] // repeats for i in range(repeats): sublora = DummyLoRAManager( @@ -156,13 +156,13 @@ def populate_loras( def create_random_inputs( - active_lora_ids: List[int], + active_lora_ids: list[int], num_inputs: int, - input_size: Tuple[int, ...], - input_range: Tuple[float, float], + input_size: tuple[int, ...], + input_range: tuple[float, float], input_type: torch.dtype = torch.int, device: torch.device = "cuda" -) -> Tuple[List[torch.Tensor], List[int], List[int]]: +) -> tuple[list[torch.Tensor], list[int], list[int]]: """Creates random inputs. Args: @@ -176,9 +176,9 @@ def create_random_inputs( low, high = input_range - inputs: List[torch.Tensor] = [] - index_mapping: List[int] = [] - prompt_mapping: List[int] = [] + inputs: list[torch.Tensor] = [] + index_mapping: list[int] = [] + prompt_mapping: list[int] = [] for _ in range(num_inputs): if input_type == torch.int: @@ -268,7 +268,7 @@ def create_random_embedding_layer(): lora_result = lora_embedding(torch.cat(inputs)) - expected_results: List[torch.Tensor] = [] + expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = embedding(input_) @@ -408,7 +408,7 @@ def create_random_embedding_layer(): lora_result = lora_embedding(torch.cat(original_inputs)) - expected_results: List[torch.Tensor] = [] + expected_results: list[torch.Tensor] = [] for input_, original_input_, lora_id in zip(inputs, original_inputs, prompt_mapping): lora = lora_dict[lora_id] @@ -538,7 +538,7 @@ def _pretest(): logits_processor.org_vocab_size = (vocab_size + lora_config.lora_extra_vocab_size) - expected_results: List[torch.Tensor] = [] + expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = logits_processor._get_logits(hidden_states=input_, @@ -659,7 +659,7 @@ def create_random_linear_replicated_layer(): lora_result = lora_linear(torch.cat(inputs))[0] - expected_results: List[torch.Tensor] = [] + expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = linear(input_)[0] @@ -784,7 +784,7 @@ def create_random_linear_parallel_layer(): lora_result = lora_linear(torch.cat(inputs))[0] - expected_results: List[torch.Tensor] = [] + expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = linear(input_)[0] @@ -933,7 +933,7 @@ class FakeConfig: lora_result = lora_linear(torch.cat(inputs))[0] - expected_results: List[torch.Tensor] = [] + expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): result = linear(input_)[0] subloras = sublora_dict[lora_id] @@ -1093,9 +1093,9 @@ def test_vocab_parallel_embedding_indices(tp_size, seed): computed_added_vocab_size = 0 vocab_size_padded = -1 - all_org_tokens: List[int] = [] - all_added_tokens: List[int] = [] - token_ids: List[int] = [] + all_org_tokens: list[int] = [] + all_added_tokens: list[int] = [] + token_ids: list[int] = [] for tp_rank in range(tp_size): with patch( diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 564818f23fd2..e84ff30ba992 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import ray @@ -31,7 +29,7 @@ ] -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 @@ -49,7 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None) # Print the outputs. - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 0a94298c9f77..f577f39ba784 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import ast -from typing import List, Optional, Tuple +from typing import Optional import numpy as np import pytest @@ -86,7 +86,7 @@ def evaluate_json_response(model_response, golden_response): def generate( llm: vllm.LLM, - inputs: Tuple[str, SamplingParams, Optional[LoRARequest]], + inputs: tuple[str, SamplingParams, Optional[LoRARequest]], ): prompts, sampling_param, lora_request = inputs outputs = llm.generate(prompts, sampling_param, lora_request=lora_request) @@ -95,7 +95,7 @@ def generate( def batched_generate( llm: vllm.LLM, - inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], + inputs: list[tuple[str, SamplingParams, Optional[LoRARequest]]], ): for input in inputs: prompt, sampling_param, lora_req = input @@ -164,7 +164,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos): non-batched generation. """ # Create non batched results first to compare against batched results - non_batched_results: List[str] = [] + non_batched_results: list[str] = [] for lora_id, info in long_context_infos.items(): context_len = info["context_length"] @@ -177,7 +177,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos): # Create batched results # Each element of the batch must be # (prompt, prompt_sampling_params, prompt_lora_request) - batched_prompts: List[Tuple[str, SamplingParams, + batched_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]] = [] for lora_id, info in long_context_infos.items(): context_len = info["context_length"] @@ -202,7 +202,7 @@ def test_self_consistency(lora_llm, long_context_infos): num_loras = len(long_context_infos) # Create results in order of long_context_infos - batched_prompts: List[Tuple[str, SamplingParams, + batched_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]] = [] for lora_id, info in long_context_infos.items(): context_len = info["context_length"] @@ -251,7 +251,7 @@ def test_quality(lora_llm, long_context_infos): The test is expected to run for about 1 minute on a p4de.24xlarge instance. """ - scores: List[float] = [] + scores: list[float] = [] for lora_id, info in long_context_infos.items(): context_len = info["context_length"] for prompt_and_response in prompts_and_responses[context_len]: @@ -284,7 +284,7 @@ def test_max_len(lora_llm, long_context_infos): generate(lora_llm, (bad_prompt, sampling_params, lora_request)) # Also test batched - batched_prompts: List[Tuple[str, SamplingParams, + batched_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]] = [] for lora_id_with_bad_inputs in long_context_infos: for lora_id, info in long_context_infos.items(): diff --git a/tests/lora/test_lora_bias_e2e.py b/tests/lora/test_lora_bias_e2e.py index 3a7b391692cc..d4245a89dff0 100644 --- a/tests/lora/test_lora_bias_e2e.py +++ b/tests/lora/test_lora_bias_e2e.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import vllm @@ -10,7 +8,7 @@ MODEL_PATH = "ibm-granite/granite-3b-code-base" -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 @@ -23,7 +21,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: sampling_params, lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None) - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: generated_text = output.outputs[0].text generated_texts.append(generated_text) diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index e2c3d20d327f..02f2339bef01 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest from vllm.lora.models import LoRAModel @@ -31,7 +29,7 @@ def test_load_checkpoints( packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules - expected_lora_modules: List[str] = [] + expected_lora_modules: list[str] = [] for module in BAICHUAN_LORA_MODULES: if module in packed_modules_mapping: expected_lora_modules.extend(packed_modules_mapping[module]) @@ -99,7 +97,7 @@ def test_lora_weights_mapping(baichuan_lora_files): packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules - expected_lora_modules: List[str] = [] + expected_lora_modules: list[str] = [] for module in BAICHUAN_LORA_MODULES: if module in packed_modules_mapping: expected_lora_modules.extend(packed_modules_mapping[module]) diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py index 1309848868b4..b279566c00f2 100644 --- a/tests/lora/test_lora_functions.py +++ b/tests/lora/test_lora_functions.py @@ -4,7 +4,6 @@ """ import os -from typing import List import pytest @@ -46,7 +45,7 @@ def test_lora_functions_sync(): llm = LLM.get_engine_class().from_engine_args(engine_args) - def run_check(fn, args, expected: List): + def run_check(fn, args, expected: list): fn(args) assert set(llm.list_loras()) == set(expected) @@ -105,7 +104,7 @@ async def test_lora_functions_async(): gpu_memory_utilization=0.8, enforce_eager=True) - async def run_check(fn, args, expected: List): + async def run_check(fn, args, expected: list): await fn(args) assert set(await llm.list_loras()) == set(expected) diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index 44d111732d2a..0875128c4ff1 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest from vllm.lora.models import LoRAModel @@ -23,7 +21,7 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request): packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping embedding_modules = LlamaForCausalLM.embedding_modules embed_padding_modules = LlamaForCausalLM.embedding_padding_modules - expected_lora_modules: List[str] = [] + expected_lora_modules: list[str] = [] for module in LLAMA_LORA_MODULES: if module in packed_modules_mapping: expected_lora_modules.extend(packed_modules_mapping[module]) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 7ab46b7ff9c9..8d2583312595 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import Dict, List import pytest import torch @@ -72,9 +71,9 @@ def test_from_lora_tensors(sql_lora_files, device): assert lora.embeddings_tensor is None -def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str], +def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str], device: torch.device) -> LoRAModel: - loras: Dict[str, LoRALayerWeights] = {} + loras: dict[str, LoRALayerWeights] = {} for name in sub_modules: w = model.get_submodule(name).weight loras[name] = LoRALayerWeights( @@ -96,7 +95,7 @@ def create_packed_lora( empty_replaced_module_name=None, ) -> LoRAModel: w = model.get_submodule(module_name).weight - loras: Dict[str, LoRALayerWeights] = {} + loras: dict[str, LoRALayerWeights] = {} for replaced_module_name in replaced_module_names: if replaced_module_name == empty_replaced_module_name: continue diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index 2e81bb326710..f596651be01e 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import vllm @@ -27,7 +25,7 @@ ] -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: sampling_params = vllm.SamplingParams( temperature=0, max_tokens=5, @@ -48,7 +46,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: if lora_id else None, ) # Print the outputs. - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: generated_text = output.outputs[0].text.strip() generated_texts.append(generated_text) diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 90cf8fd39a18..caa65f2dc635 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import torch @@ -13,7 +11,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, - prompts: List[str]) -> List[str]: + prompts: list[str]) -> list[str]: sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) outputs = llm.generate( @@ -22,7 +20,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None) # Print the outputs. - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text.strip() diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py index 8999e0cf3190..8596d3999799 100644 --- a/tests/lora/test_phi.py +++ b/tests/lora/test_phi.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import vllm @@ -12,7 +10,7 @@ PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501 -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format( sql_prompt= @@ -41,7 +39,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: if lora_id else None, ) # Print the outputs. - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text.strip() diff --git a/tests/lora/test_punica_ops.py b/tests/lora/test_punica_ops.py index 032e20470bcd..c75e866172e1 100644 --- a/tests/lora/test_punica_ops.py +++ b/tests/lora/test_punica_ops.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 from threading import Lock -from typing import List import pytest import torch @@ -20,7 +19,7 @@ # Utility shrink and expand operations used as reference implementations. def sgmv_shrink_for_nslices( nslices: int, inputs_tensor: torch.Tensor, - lora_weights_lst: List[torch.Tensor], out_tensor: torch.Tensor, + lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int, num_tokens: int, scaling: float): @@ -44,7 +43,7 @@ def sgmv_shrink_for_nslices( def sgmv_expand_for_nslices(nslices: int, hidden_size: int, inputs_tensor: torch.Tensor, - lora_weights_lst: List[torch.Tensor], + lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 7f687f563eb8..b4f3d8dc478a 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -3,7 +3,6 @@ # Adapted from # https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py from dataclasses import dataclass -from typing import List import pytest @@ -19,7 +18,7 @@ class ModelWithQuantization: quantization: str -MODELS: List[ModelWithQuantization] +MODELS: list[ModelWithQuantization] #AWQ quantization is currently not supported in ROCm. if current_platform.is_rocm(): MODELS = [ @@ -41,7 +40,7 @@ class ModelWithQuantization: def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, - max_tokens: int = 256) -> List[str]: + max_tokens: int = 256) -> list[str]: raw_prompts = [ "Give me an orange-ish brown color", "Give me a neon pink color", @@ -61,7 +60,7 @@ def format_prompt_tuples(prompt): lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None) # Print the outputs. - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py index 1cf1534e4036..24eff013e204 100644 --- a/tests/lora/test_qwen2vl.py +++ b/tests/lora/test_qwen2vl.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Optional import pytest from packaging.version import Version @@ -20,7 +20,7 @@ class TestConfig: max_loras: int = 2 max_lora_rank: int = 16 max_model_len: int = 4096 - mm_processor_kwargs: Optional[Dict[str, int]] = None + mm_processor_kwargs: Optional[dict[str, int]] = None def __post_init__(self): if self.mm_processor_kwargs is None: @@ -57,11 +57,11 @@ def _initialize_llm(self) -> vllm.LLM: ) def run_test(self, - images: List[ImageAsset], - expected_outputs: List[str], + images: list[ImageAsset], + expected_outputs: list[str], lora_id: Optional[int] = None, temperature: float = 0, - max_tokens: int = 5) -> List[str]: + max_tokens: int = 5) -> list[str]: sampling_params = vllm.SamplingParams( temperature=temperature, diff --git a/tests/lora/test_transfomers_model.py b/tests/lora/test_transfomers_model.py index 07af1e9f449d..ff3bfcac5053 100644 --- a/tests/lora/test_transfomers_model.py +++ b/tests/lora/test_transfomers_model.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import vllm @@ -21,7 +19,7 @@ ] -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format( @@ -40,7 +38,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None) # Print the outputs. - generated_texts: List[str] = [] + generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text.strip() diff --git a/tests/lora/test_ultravox.py b/tests/lora/test_ultravox.py index 703f92ce8b6b..6d2833bd125f 100644 --- a/tests/lora/test_ultravox.py +++ b/tests/lora/test_ultravox.py @@ -3,7 +3,6 @@ import shutil from os import path from tempfile import TemporaryDirectory -from typing import List, Tuple import torch from huggingface_hub import snapshot_download @@ -86,8 +85,8 @@ def test_ultravox_lora(vllm_runner): dtype="bfloat16", max_model_len=1024, ) as vllm_model: - ultravox_outputs: List[Tuple[ - List[int], str]] = vllm_model.generate_greedy( + ultravox_outputs: list[tuple[ + list[int], str]] = vllm_model.generate_greedy( [ _get_prompt(0, PROMPT, VLLM_PLACEHOLDER, ULTRAVOX_MODEL_NAME) @@ -108,7 +107,7 @@ def test_ultravox_lora(vllm_runner): dtype="bfloat16", max_model_len=1024, ) as vllm_model: - llama_outputs: List[Tuple[List[int], str]] = ( + llama_outputs: list[tuple[list[int], str]] = ( vllm_model.generate_greedy( [_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)], 256, diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 1e163fbf97ce..59a0e7420fc2 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import torch @@ -12,7 +12,7 @@ class DummyLoRAManager: def __init__(self, device: torch.device = "cuda:0"): super().__init__() - self._loras: Dict[str, LoRALayerWeights] = {} + self._loras: dict[str, LoRALayerWeights] = {} self._device = device def set_module_lora(self, module_name: str, lora: LoRALayerWeights): @@ -77,11 +77,11 @@ def init_packed_lora( self, module_name: str, input_dim: int, - output_dims: List[int], - noop_lora_index: Optional[List[int]] = None, + output_dims: list[int], + noop_lora_index: Optional[list[int]] = None, rank: int = 8, ): - base_loras: List[LoRALayerWeights] = [] + base_loras: list[LoRALayerWeights] = [] noop_lora_index_set = set(noop_lora_index or []) for i, out_dim in enumerate(output_dims): @@ -110,7 +110,7 @@ def assert_close(a, b): @dataclass class PunicaTensors: inputs_tensor: torch.Tensor - lora_weights: Union[torch.Tensor, List[torch.Tensor]] + lora_weights: Union[torch.Tensor, list[torch.Tensor]] our_out_tensor: torch.Tensor ref_out_tensor: torch.Tensor b_seq_start_loc: torch.Tensor @@ -118,7 +118,7 @@ class PunicaTensors: seq_len_tensor: torch.Tensor token_lora_mapping: torch.Tensor - def meta(self) -> Tuple[int, int]: + def meta(self) -> tuple[int, int]: """ Infer max_seq_length and token_nums from the tensors and return them. diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index b276d9d9cb4e..e23ff43ebd7f 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import time -from typing import List import pytest import ray @@ -133,7 +132,7 @@ def test_metric_counter_generation_tokens_multi_step( "served_model_name", [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]]) def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, - served_model_name: List[str]) -> None: + served_model_name: list[str]) -> None: with vllm_runner(model, dtype=dtype, disable_log_stats=False, diff --git a/tests/mistral_tool_use/utils.py b/tests/mistral_tool_use/utils.py index 971ed55ca3c0..1d809a05e89d 100644 --- a/tests/mistral_tool_use/utils.py +++ b/tests/mistral_tool_use/utils.py @@ -1,21 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Optional +from typing import Optional from typing_extensions import TypedDict class ServerConfig(TypedDict, total=False): model: str - arguments: List[str] + arguments: list[str] system_prompt: Optional[str] supports_parallel: Optional[bool] supports_rocm: Optional[bool] -ARGS: List[str] = ["--max-model-len", "1024"] +ARGS: list[str] = ["--max-model-len", "1024"] -CONFIGS: Dict[str, ServerConfig] = { +CONFIGS: dict[str, ServerConfig] = { "mistral": { "model": "mistralai/Mistral-7B-Instruct-v0.3", diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 2c6780848567..4a6a766b8ca0 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config @@ -51,7 +49,7 @@ class Relu3(ReLUSquaredActivation): # All but RMSNorm ("all,-rms_norm", 4, [0, 1, 1, 1], True), ]) -def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], +def test_enabled_ops(env: str, torch_level: int, ops_enabled: list[int], default_on: bool): vllm_config = VllmConfig(compilation_config=CompilationConfig( level=torch_level, custom_ops=env.split(","))) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 0ea17247028f..13433b042258 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple, Type +from typing import Optional import numpy as np import pytest @@ -17,7 +17,7 @@ MODEL_NAME = "fixie-ai/ultravox-v0_4" -AudioTuple = Tuple[np.ndarray, int] +AudioTuple = tuple[np.ndarray, int] VLLM_PLACEHOLDER = "<|audio|>" HF_PLACEHOLDER = "<|audio|>" @@ -78,7 +78,7 @@ def _get_prompt(audio_count, question, placeholder): add_generation_prompt=True) -def vllm_to_hf_output(vllm_output: Tuple[List[int], str, +def vllm_to_hf_output(vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], model: str): """Sanitize vllm output to be comparable with hf output.""" @@ -96,9 +96,9 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - prompts_and_audios: List[Tuple[str, str, AudioTuple]], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + prompts_and_audios: list[tuple[str, str, AudioTuple]], model: str, *, dtype: str, @@ -158,8 +158,8 @@ def process(hf_inputs: BatchEncoding, **kwargs): def run_multi_audio_test( - vllm_runner: Type[VllmRunner], - prompts_and_audios: List[Tuple[str, List[AudioTuple]]], + vllm_runner: type[VllmRunner], + prompts_and_audios: list[tuple[str, list[AudioTuple]]], model: str, *, dtype: str, diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index 57fe1d5b1515..804df4c4903e 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -5,7 +5,7 @@ """ import os -from typing import List, NamedTuple, Type +from typing import NamedTuple import pytest from huggingface_hub import hf_hub_download @@ -90,8 +90,8 @@ def gguf_model(self): @pytest.mark.parametrize("tp_size", [1, 2]) def test_models( num_gpus_available: int, - vllm_runner: Type[VllmRunner], - example_prompts: List[str], + vllm_runner: type[VllmRunner], + example_prompts: list[str], model: GGUFTestConfig, dtype: str, max_tokens: int, diff --git a/tests/models/decoder_only/language/test_modelopt.py b/tests/models/decoder_only/language/test_modelopt.py index 66dd979579c4..a997b9e66405 100644 --- a/tests/models/decoder_only/language/test_modelopt.py +++ b/tests/models/decoder_only/language/test_modelopt.py @@ -5,7 +5,6 @@ Note: these tests will only pass on H100 """ import os -from typing import List import pytest from transformers import AutoTokenizer @@ -65,7 +64,7 @@ def test_models(example_prompts, model_name) -> None: for prompt in example_prompts ] params = SamplingParams(max_tokens=20, temperature=0) - generations: List[str] = [] + generations: list[str] = [] # Note: these need to be run 1 at a time due to numerical precision, # since the expected strs were generated this way. for prompt in formatted_prompts: diff --git a/tests/models/decoder_only/vision_language/test_awq.py b/tests/models/decoder_only/vision_language/test_awq.py index 31a5cd260a1d..f4a6dd0f101f 100644 --- a/tests/models/decoder_only/vision_language/test_awq.py +++ b/tests/models/decoder_only/vision_language/test_awq.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Type +from typing import Optional import pytest import torch @@ -19,12 +19,12 @@ def run_awq_test( - vllm_runner: Type[VllmRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets, source_model: str, quant_model: str, *, - size_factors: List[float], + size_factors: list[float], dtype: str, max_tokens: int, num_logprobs: int, diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 2c66edb539dc..3f7a7c01aebc 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -6,7 +6,6 @@ import os from collections import defaultdict from pathlib import PosixPath -from typing import Type import pytest from packaging.version import Version @@ -562,8 +561,8 @@ def _mark_splits( )) def test_single_image_models(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_single_image_test( @@ -585,8 +584,8 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str, )) def test_multi_image_models(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_multi_image_test( @@ -608,8 +607,8 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str, )) def test_image_embedding_models(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_embedding_test( @@ -629,7 +628,7 @@ def test_image_embedding_models(model_type: str, fork_new_process_for_each_test=False, )) def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], video_assets: _VideoAssets): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_video_test( @@ -651,8 +650,8 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, def test_custom_inputs_models( model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], ): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_custom_inputs_test( @@ -674,8 +673,8 @@ def test_custom_inputs_models( @fork_new_process_for_each_test def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_single_image_test( @@ -698,8 +697,8 @@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, @fork_new_process_for_each_test def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_multi_image_test( @@ -722,8 +721,8 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, @fork_new_process_for_each_test def test_image_embedding_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_embedding_test( @@ -743,8 +742,8 @@ def test_image_embedding_models_heavy(model_type: str, fork_new_process_for_each_test=True, )) def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], video_assets: _VideoAssets): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_video_test( @@ -767,8 +766,8 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, def test_custom_inputs_models_heavy( model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], ): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_custom_inputs_test( diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index dd68fe4cd55e..53b183b2735e 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -2,7 +2,7 @@ import os import re -from typing import List, Optional, Tuple, Type +from typing import Optional import pytest from transformers import AutoTokenizer @@ -25,7 +25,7 @@ models = ["microsoft/Phi-3.5-vision-instruct"] -def vllm_to_hf_output(vllm_output: Tuple[List[int], str, +def vllm_to_hf_output(vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], model: str): """Sanitize vllm output to be comparable with hf output.""" @@ -55,9 +55,9 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], PromptImageInput]], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + inputs: list[tuple[list[str], PromptImageInput]], model: str, *, dtype: str, diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index 602da2b5f4ee..d51dabc23346 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -6,7 +6,7 @@ import json import uuid from dataclasses import asdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Optional import pytest from mistral_common.multimodal import download_image @@ -38,7 +38,7 @@ PROMPT = "Describe each image in one short sentence." -def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]: +def _create_msg_format(urls: list[str]) -> list[dict[str, Any]]: return [{ "role": "user", @@ -54,7 +54,7 @@ def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]: }] -def _create_msg_format_hf(urls: List[str]) -> List[Dict[str, Any]]: +def _create_msg_format_hf(urls: list[str]) -> list[dict[str, Any]]: return [{ "role": "user", @@ -68,7 +68,7 @@ def _create_msg_format_hf(urls: List[str]) -> List[Dict[str, Any]]: }] -def _create_engine_inputs(urls: List[str]) -> TokensPrompt: +def _create_engine_inputs(urls: list[str]) -> TokensPrompt: msg = _create_msg_format(urls) tokenizer = MistralTokenizer.from_model("pixtral") @@ -89,7 +89,7 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt: return engine_inputs -def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt: +def _create_engine_inputs_hf(urls: list[str]) -> TextPrompt: msg = _create_msg_format_hf(urls) tokenizer = AutoProcessor.from_pretrained("mistral-community/pixtral-12b") @@ -128,7 +128,7 @@ def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt: FIXTURE_LOGPROBS_CHAT = FIXTURES_PATH / "pixtral_chat.json" FIXTURE_LOGPROBS_ENGINE = FIXTURES_PATH / "pixtral_chat_engine.json" -OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]] +OutputsLogprobs = list[tuple[list[int], str, Optional[SampleLogprobs]]] # For the test author to store golden output in JSON diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py index de240a904e47..af494eb2e62b 100644 --- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, List, Optional, Tuple, Type, TypedDict, Union +from typing import Any, Optional, TypedDict, Union import numpy.typing as npt import pytest @@ -69,21 +69,21 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict): def batch_make_image_embeddings( - image_batches: List[Union[Image.Image, List[Image.Image]]], processor, - llm: VllmRunner) -> List[Qwen2VLPromptImageEmbeddingInput]: + image_batches: list[Union[Image.Image, list[Image.Image]]], processor, + llm: VllmRunner) -> list[Qwen2VLPromptImageEmbeddingInput]: """batched image embeddings for Qwen2-VL This will infer all images' embeddings in a single batch, and split the result according to input batches. image_batches: - - Single-image batches: `List[Image.Image]` - - Multiple-image batches: `List[List[Image.Image]]]` + - Single-image batches: `list[Image.Image]` + - Multiple-image batches: `list[list[Image.Image]]]` - returns: `List[Qwen2VLPromptImageEmbeddingInput]` + returns: `list[Qwen2VLPromptImageEmbeddingInput]` """ - image_batches_: List[Any] = image_batches[:] + image_batches_: list[Any] = image_batches[:] # convert single-image batches to multiple-image batches for idx in range(len(image_batches_)): @@ -93,7 +93,7 @@ def batch_make_image_embeddings( assert isinstance(image_batches_[idx], list) # append all images into a list (as a batch) - images: List[Image.Image] = [] + images: list[Image.Image] = [] for image_batch in image_batches_: images += image_batch @@ -121,7 +121,7 @@ def get_image_embeds(model): image_embeds = torch.concat(llm.apply_model(get_image_embeds)) # split into original batches - result: List[Qwen2VLPromptImageEmbeddingInput] = [] + result: list[Qwen2VLPromptImageEmbeddingInput] = [] image_counter = 0 embed_counter = 0 for image_batch in image_batches_: @@ -153,7 +153,7 @@ def get_image_embeds(model): def batch_make_video_embeddings( video_batches: PromptVideoInput, processor, - llm: VllmRunner) -> List[Qwen2VLPromptVideoEmbeddingInput]: + llm: VllmRunner) -> list[Qwen2VLPromptVideoEmbeddingInput]: """batched video embeddings for Qwen2-VL A NDArray represents a single video's all frames. @@ -162,21 +162,21 @@ def batch_make_video_embeddings( and split the result according to input batches. video_batches: - - Single-video batches: `List[NDArray]` - - Multiple-video batches: `List[List[NDArray]]` + - Single-video batches: `list[NDArray]` + - Multiple-video batches: `list[list[NDArray]]` """ - video_batches_: List[Any] = video_batches[:] + video_batches_: list[Any] = video_batches[:] for idx in range(len(video_batches_)): if not isinstance(video_batches_[idx], list): - single_video_batch: List[npt.NDArray] = [video_batches_[idx]] + single_video_batch: list[npt.NDArray] = [video_batches_[idx]] video_batches_[idx] = single_video_batch assert isinstance(video_batches_[idx], list) # append all videos into a list (as a batch) - videos: List[npt.NDArray] = [] + videos: list[npt.NDArray] = [] for video_batch in video_batches_: videos += video_batch @@ -204,7 +204,7 @@ def get_image_embeds(model): video_embeds = torch.concat(llm.apply_model(get_image_embeds)) # split into original batches - result: List[Qwen2VLPromptVideoEmbeddingInput] = [] + result: list[Qwen2VLPromptVideoEmbeddingInput] = [] video_counter = 0 embed_counter = 0 for video_batch in video_batches_: @@ -235,8 +235,8 @@ def get_image_embeds(model): def run_embedding_input_test( - vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]], + vllm_runner: type[VllmRunner], + inputs: list[tuple[list[str], PromptImageInput, PromptVideoInput]], model: str, *, dtype: str, @@ -323,8 +323,8 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model, num_logprobs: int) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_case: List[Tuple[ - List[str], PromptImageInput, PromptVideoInput]] = [( + inputs_per_case: list[tuple[ + list[str], PromptImageInput, PromptVideoInput]] = [( [prompt for _ in size_factors], [rescale_image_size(image, factor) for factor in size_factors], [], @@ -365,7 +365,7 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets, num_logprobs: int) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_case: List[Tuple[List[str], PromptImageInput, + inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [( [MULTIIMAGE_PROMPT for _ in size_factors], [[ @@ -413,8 +413,8 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model, for asset in video_assets ] - inputs_per_case: List[Tuple[ - List[str], PromptImageInput, PromptVideoInput]] = [( + inputs_per_case: list[tuple[ + list[str], PromptImageInput, PromptVideoInput]] = [( [prompt for _ in size_factors], [], [rescale_video_size(video, factor) for factor in size_factors], diff --git a/tests/models/decoder_only/vision_language/vlm_utils/builders.py b/tests/models/decoder_only/vision_language/vlm_utils/builders.py index 539410d18950..bf5f87ebf984 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/builders.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/builders.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """Helpers for building inputs that can be leveraged for different test types. """ +from collections.abc import Iterable from pathlib import PosixPath -from typing import Callable, Iterable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch @@ -33,7 +34,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int], def get_model_prompts(base_prompts: Iterable[str], img_idx_to_prompt: Optional[Callable[[int], str]], video_idx_to_prompt: Optional[Callable[[int], str]], - prompt_formatter: Callable[[str], str]) -> List[str]: + prompt_formatter: Callable[[str], str]) -> list[str]: """Given a model-agnostic base prompt and test configuration for a model(s) to be tested, update the media placeholders and apply the prompt formatting to get the test prompt string for this model. @@ -218,7 +219,7 @@ def build_video_inputs_from_test_info( ) for video, prompt in zip(sampled_vids, model_prompts)] -def apply_image_size_scaling(image, size: Union[float, Tuple[int, int]], +def apply_image_size_scaling(image, size: Union[float, tuple[int, int]], size_type: SizeType): """Applies a size scaler to one image; this can be a an image size factor, which scales the image while maintaining the aspect ratio""" diff --git a/tests/models/decoder_only/vision_language/vlm_utils/case_filtering.py b/tests/models/decoder_only/vision_language/vlm_utils/case_filtering.py index ca4ec2141182..c189e5a761fc 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/case_filtering.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/case_filtering.py @@ -5,7 +5,7 @@ """ import itertools from collections import OrderedDict -from typing import Dict, Iterable, Tuple +from collections.abc import Iterable import pytest @@ -13,9 +13,9 @@ ImageSizeWrapper, SizeType, VLMTestInfo, VLMTestType) -def get_filtered_test_settings(test_settings: Dict[str, VLMTestInfo], +def get_filtered_test_settings(test_settings: dict[str, VLMTestInfo], test_type: VLMTestType, - fork_per_test: bool) -> Dict[str, VLMTestInfo]: + fork_per_test: bool) -> dict[str, VLMTestInfo]: """Given the dict of potential test settings to run, return a subdict of tests who have the current test type enabled with the matching val for fork_per_test. @@ -49,7 +49,7 @@ def matches_test_type(test_info: VLMTestInfo, test_type: VLMTestType): return matching_tests -def get_parametrized_options(test_settings: Dict[str, VLMTestInfo], +def get_parametrized_options(test_settings: dict[str, VLMTestInfo], test_type: VLMTestType, fork_new_process_for_each_test: bool): """Converts all of our VLMTestInfo into an expanded list of parameters. @@ -121,7 +121,7 @@ def get_model_type_cases(model_type: str, test_info: VLMTestInfo): def get_wrapped_test_sizes( test_info: VLMTestInfo, - test_type: VLMTestType) -> Tuple[ImageSizeWrapper, ...]: + test_type: VLMTestType) -> tuple[ImageSizeWrapper, ...]: """Given a test info which may have size factors or fixed sizes, wrap them and combine them into an iterable, each of which will be used in parameter expansion. diff --git a/tests/models/decoder_only/vision_language/vlm_utils/core.py b/tests/models/decoder_only/vision_language/vlm_utils/core.py index f2260f56737d..aaad584c9cd5 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/core.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/core.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Core test implementation to be shared across modalities.""" -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Union import torch from PIL.Image import Image @@ -17,9 +17,9 @@ def run_test( *, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], List[Union[List[Image], Image]]]], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + inputs: list[tuple[list[str], list[Union[list[Image], Image]]]], model: str, dtype: str, max_tokens: int, @@ -29,15 +29,15 @@ def run_test( max_num_seqs: int, hf_output_post_proc: Optional[Callable[[RunnerOutput, str], Any]], vllm_output_post_proc: Optional[Callable[[RunnerOutput, str], Any]], - auto_cls: Type[_BaseAutoModelClass], + auto_cls: type[_BaseAutoModelClass], use_tokenizer_eos: bool, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding], comparator: Callable[..., None], get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]], - stop_str: Optional[List[str]], - limit_mm_per_prompt: Dict[str, int], - vllm_runner_kwargs: Optional[Dict[str, Any]], - hf_model_kwargs: Optional[Dict[str, Any]], + stop_str: Optional[list[str]], + limit_mm_per_prompt: dict[str, int], + vllm_runner_kwargs: Optional[dict[str, Any]], + hf_model_kwargs: Optional[dict[str, Any]], patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]], task: TaskOption = "auto", runner_mm_key: str = "images", @@ -61,7 +61,7 @@ def run_test( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - vllm_runner_kwargs_: Dict[str, Any] = {} + vllm_runner_kwargs_: dict[str, Any] = {} if model_info.tokenizer: vllm_runner_kwargs_["tokenizer"] = model_info.tokenizer if model_info.tokenizer_mode: @@ -84,7 +84,7 @@ def run_test( **vllm_runner_kwargs_) as vllm_model: tokenizer = vllm_model.model.get_tokenizer() - vllm_kwargs: Dict[str, Any] = {} + vllm_kwargs: dict[str, Any] = {} if get_stop_token_ids is not None: vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer) if stop_str: diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 408ce9cfeada..66410f66ca0d 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -6,7 +6,7 @@ import re import types from pathlib import PosixPath -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch from PIL.Image import Image @@ -49,7 +49,7 @@ def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, def qwen_vllm_to_hf_output( vllm_output: RunnerOutput, - model: str) -> Tuple[List[int], str, Optional[SampleLogprobs]]: + model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [qwen models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -60,7 +60,7 @@ def qwen_vllm_to_hf_output( def qwen2_vllm_to_hf_output( vllm_output: RunnerOutput, - model: str) -> Tuple[List[int], str, Optional[SampleLogprobs]]: + model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [qwen2 models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -78,7 +78,7 @@ def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput, def llava_video_vllm_to_hf_output( vllm_output: RunnerOutput, - model: str) -> Tuple[List[int], str, Optional[SampleLogprobs]]: + model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: config = AutoConfig.from_pretrained(model) mm_token_id = config.video_token_index return _llava_vllm_to_hf_output(vllm_output, model, mm_token_id) @@ -247,7 +247,7 @@ def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str): ####### Prompt path encoders for models that need models on disk def qwen_prompt_path_encoder( - tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset], + tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset], _ImageAssets]) -> str: """Given a temporary dir path, export one or more image assets into the tempdir & replace its contents with the local path to the string so that @@ -257,7 +257,7 @@ def qwen_prompt_path_encoder( Args: tmp_path: Tempdir for test under consideration. prompt: Prompt with image placeholders. - assets: List of image assets whose len equals the num placeholders. + assets: list of image assets whose len equals the num placeholders. """ # Ensure that the number of placeholders matches the number of assets; # If this is not true, the test is probably written incorrectly. @@ -350,7 +350,7 @@ def __init__(self, hf_runner: HfRunner): self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, List[Image]], + def __call__(self, text: str, images: Union[Image, list[Image]], **kwargs): # yapf: disable from vllm.model_executor.models.h2ovl import ( @@ -410,7 +410,7 @@ def __init__(self, hf_runner: HfRunner): self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, List[Image]], + def __call__(self, text: str, images: Union[Image, list[Image]], **kwargs): from vllm.model_executor.models.internvl import ( IMG_CONTEXT, IMG_END, IMG_START, diff --git a/tests/models/decoder_only/vision_language/vlm_utils/runners.py b/tests/models/decoder_only/vision_language/vlm_utils/runners.py index fb9df37cad92..023df5f16188 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/runners.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/runners.py @@ -3,7 +3,6 @@ types / modalities. """ from pathlib import PosixPath -from typing import Type from .....conftest import HfRunner, VllmRunner, _ImageAssets, _VideoAssets from . import builders, core @@ -13,8 +12,8 @@ ####### Entrypoints for running different test types def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets): assert test_case.size_wrapper is not None inputs = builders.build_single_image_inputs_from_test_info( @@ -36,8 +35,8 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets): assert test_case.size_wrapper is not None inputs = builders.build_multi_image_inputs_from_test_info( @@ -59,8 +58,8 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, def run_embedding_test(*, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets): assert test_case.size_wrapper is not None inputs, vllm_embeddings = builders.build_embedding_inputs_from_test_info( @@ -85,8 +84,8 @@ def run_video_test( *, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], video_assets: _VideoAssets, ): assert test_case.size_wrapper is not None @@ -111,8 +110,8 @@ def run_video_test( def run_custom_inputs_test(*, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner]): + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner]): # Custom test cases can provide inputs directly, but they need to # explicitly provided a CustomTestConfig, which wraps the inputs and # the limit_mm_per_prompt diff --git a/tests/models/decoder_only/vision_language/vlm_utils/types.py b/tests/models/decoder_only/vision_language/vlm_utils/types.py index ecb86609c527..bdbdbc7ec267 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/types.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/types.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """Types for writing multimodal model tests.""" +from collections.abc import Iterable from enum import Enum from pathlib import PosixPath -from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional, - Tuple, Type, Union) +from typing import Any, Callable, NamedTuple, Optional, Union import torch from PIL.Image import Image @@ -35,7 +35,7 @@ IMAGE_SIZE_FACTORS = [(), (1.0, ), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)] EMBEDDING_SIZE_FACTORS = [(), (1.0, ), (1.0, 1.0, 1.0)] -RunnerOutput = Tuple[List[int], str, Optional[SampleLogprobs]] +RunnerOutput = tuple[list[int], str, Optional[SampleLogprobs]] # yapf: enable @@ -53,8 +53,8 @@ class SizeType(Enum): class CustomTestOptions(NamedTuple): - inputs: List[Tuple[List[str], List[Union[List[Image], Image]]]] - limit_mm_per_prompt: Dict[str, int] + inputs: list[tuple[list[str], list[Union[list[Image], Image]]]] + limit_mm_per_prompt: dict[str, int] # kwarg to pass multimodal data in as to vllm/hf runner instances. runner_mm_key: str = "images" @@ -63,13 +63,13 @@ class ImageSizeWrapper(NamedTuple): type: SizeType # A size factor is a wrapper of 0+ floats, # while a fixed size contains an iterable of integer pairs - data: Union[Iterable[float], Iterable[Tuple[int, int]]] + data: Union[Iterable[float], Iterable[tuple[int, int]]] class VLMTestInfo(NamedTuple): """Holds the configuration for 1+ tests for one model architecture.""" - models: List[str] + models: list[str] test_type: Union[VLMTestType, Iterable[VLMTestType]] # Should be None only if this is a CUSTOM_INPUTS test @@ -97,19 +97,19 @@ class VLMTestInfo(NamedTuple): max_num_seqs: int = 256 task: TaskOption = "auto" tensor_parallel_size: int = 1 - vllm_runner_kwargs: Optional[Dict[str, Any]] = None + vllm_runner_kwargs: Optional[dict[str, Any]] = None # Optional callable which gets a list of token IDs from the model tokenizer get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]] = None # Optional list of strings to stop generation, useful when stop tokens are # not special tokens in the tokenizer - stop_str: Optional[List[str]] = None + stop_str: Optional[list[str]] = None # Exposed options for HF runner - hf_model_kwargs: Optional[Dict[str, Any]] = None + hf_model_kwargs: Optional[dict[str, Any]] = None # Indicates we should explicitly pass the EOS from the tokenizer use_tokenizer_eos: bool = False - auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM + auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM # Callable to pass to the HF runner to run on inputs; for now, we also pass # the data type to input post processing, because almost all of the uses of # postprocess_inputs are to fix the data types of BatchEncoding values. @@ -128,12 +128,12 @@ class VLMTestInfo(NamedTuple): # Default expandable params per test; these defaults can be overridden in # instances of this object; the complete set of test cases for the model # is all combinations of .models + all fields below - max_tokens: Union[int, Tuple[int]] = 128 - num_logprobs: Union[int, Tuple[int]] = 5 + max_tokens: Union[int, tuple[int]] = 128 + num_logprobs: Union[int, tuple[int]] = 5 dtype: Union[str, Iterable[str]] = "half" distributed_executor_backend: Optional[Union[str, Iterable[str]]] = None # Only expanded in video tests - num_video_frames: Union[int, Tuple[int]] = 16 + num_video_frames: Union[int, tuple[int]] = 16 # Fixed image sizes / image size factors; most tests use image_size_factors # The values provided for these two fields will be stacked and expanded @@ -141,19 +141,19 @@ class VLMTestInfo(NamedTuple): # once per tests (much like concatenating and wrapping in one parametrize # call) image_size_factors: Iterable[Iterable[float]] = IMAGE_SIZE_FACTORS - image_sizes: Optional[Iterable[Iterable[Tuple[int, int]]]] = None + image_sizes: Optional[Iterable[Iterable[tuple[int, int]]]] = None # Hack for updating a prompt to take into a local path; currently only used # for Qwen-VL, which requires encoding the image path / url into the prompt # for HF runner prompt_path_encoder: Optional[ - Callable[[PosixPath, str, Union[List[ImageAsset], _ImageAssets]], + Callable[[PosixPath, str, Union[list[ImageAsset], _ImageAssets]], str]] = None # noqa: E501 # Allows configuring a test to run with custom inputs - custom_test_opts: Optional[List[CustomTestOptions]] = None + custom_test_opts: Optional[list[CustomTestOptions]] = None - marks: Optional[List[MarkDecorator]] = None + marks: Optional[list[MarkDecorator]] = None def get_non_parametrized_runner_kwargs(self): """Returns a dictionary of expandable kwargs for items that are used diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index 7ed2fb8a6358..470dc0410776 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -3,7 +3,6 @@ import importlib.util import math from array import array -from typing import List import openai import pytest @@ -81,14 +80,14 @@ async def client_generate(server_generate: RemoteOpenAIServer): yield async_client -def run_llm_encode(llm: vllm.LLM, queries: List[str], - instruction: str) -> List[float]: +def run_llm_encode(llm: vllm.LLM, queries: list[str], + instruction: str) -> list[float]: outputs = llm.encode([instruction + q for q in queries], ) return [output.outputs.embedding for output in outputs] -async def run_client_embeddings(client: vllm.LLM, queries: List[str], - instruction: str) -> List[float]: +async def run_client_embeddings(client: vllm.LLM, queries: list[str], + instruction: str) -> list[float]: outputs = await client.embeddings.create( model=MODEL_NAME, input=[instruction + q for q in queries], @@ -123,7 +122,7 @@ def get_test_data(): return queries, q_instruction, documents, d_instruction -def validate_embed_output(q_rep: List[float], d_rep: List[float]): +def validate_embed_output(q_rep: list[float], d_rep: list[float]): cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0]) assert math.isclose(cosine_sim_q0_d0, 0.609, abs_tol=0.001) diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py index 567aa5098493..bef85eaf372f 100644 --- a/tests/models/embedding/utils.py +++ b/tests/models/embedding/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Sequence +from collections.abc import Sequence import torch import torch.nn.functional as F @@ -8,8 +8,8 @@ def check_embeddings_close( *, - embeddings_0_lst: Sequence[List[float]], - embeddings_1_lst: Sequence[List[float]], + embeddings_0_lst: Sequence[list[float]], + embeddings_1_lst: Sequence[list[float]], name_0: str, name_1: str, tol: float = 1e-3, diff --git a/tests/models/embedding/vision_language/test_dse_qwen2_vl.py b/tests/models/embedding/vision_language/test_dse_qwen2_vl.py index 82f2bf53122a..7391df6e1c30 100644 --- a/tests/models/embedding/vision_language/test_dse_qwen2_vl.py +++ b/tests/models/embedding/vision_language/test_dse_qwen2_vl.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import Callable, Dict, List, Type +from typing import Callable import pytest import torch @@ -67,7 +67,7 @@ def get_messages(image: Image.Image, text: str, embed_text: bool): def apply_chat_template_and_add_eos( - messages: List[Dict], + messages: list[dict], apply_chat_template_fn: Callable, ): prompt = apply_chat_template_fn( @@ -80,11 +80,11 @@ def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs): def _run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - input_texts: List[str], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + input_texts: list[str], input_images: PromptImageInput, - embed_texts: List[bool], + embed_texts: list[bool], model: str, *, dtype: str, diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py index 990c6c150fcd..4c2fbd526ed1 100644 --- a/tests/models/embedding/vision_language/test_llava_next.py +++ b/tests/models/embedding/vision_language/test_llava_next.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Type - import pytest import torch.nn.functional as F from transformers import AutoModelForVision2Seq @@ -35,9 +33,9 @@ def _run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - input_texts: List[str], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + input_texts: list[str], input_images: PromptImageInput, model: str, *, diff --git a/tests/models/embedding/vision_language/test_phi3v.py b/tests/models/embedding/vision_language/test_phi3v.py index 0cb948746042..3226138a28b9 100644 --- a/tests/models/embedding/vision_language/test_phi3v.py +++ b/tests/models/embedding/vision_language/test_phi3v.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Type - import pytest import torch.nn.functional as F @@ -29,9 +27,9 @@ def _run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - input_texts: List[str], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + input_texts: list[str], input_images: PromptImageInput, model: str, *, diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index 81b629fdcf1f..e8070d28befa 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -3,7 +3,7 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`. """ -from typing import List, Optional, Tuple, Type +from typing import Optional import pytest from transformers import AutoModelForSeq2SeqLM @@ -17,7 +17,7 @@ def vllm_to_hf_output( - vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], decoder_prompt_type: DecoderPromptType, ): """Sanitize vllm output to be comparable with hf output.""" @@ -31,9 +31,9 @@ def vllm_to_hf_output( def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + prompts: list[ExplicitEncoderDecoderPrompt[str, str]], decoder_prompt_type: DecoderPromptType, model: str, *, diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py index de18deab11f6..a6ec333e2e9b 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Type +from typing import Optional import pytest from PIL import Image @@ -51,8 +51,8 @@ def hf_to_vllm_output(hf_output: tuple[list[int], str, def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], inputs: list[list[ExplicitEncoderDecoderPrompt]], model: str, *, @@ -114,7 +114,7 @@ def run_test( @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], +def test_models(hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], image_assets: _ImageAssets, model: str, size_factors: list[int], dtype: str, max_tokens: int, num_logprobs: int) -> None: diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index 4fee04fdb7b6..1e202907171f 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple, Type, overload +from typing import Optional, overload import pytest import torch @@ -64,7 +64,7 @@ } -def vllm_to_hf_output(vllm_output: Tuple[List[int], str, +def vllm_to_hf_output(vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], model: str): """Sanitize vllm output to be comparable with hf output.""" @@ -91,9 +91,9 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, def _get_inputs( image_assets: _ImageAssets, *, - size_factors: Optional[List[float]] = None, - sizes: Optional[List[Tuple[int, int]]] = None, -) -> List[Tuple[List[str], PromptImageInput]]: + size_factors: Optional[list[float]] = None, + sizes: Optional[list[tuple[int, int]]] = None, +) -> list[tuple[list[str], PromptImageInput]]: images = [asset.pil_image for asset in image_assets] if size_factors is not None: @@ -123,12 +123,12 @@ def _get_inputs( @overload def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets, model: str, *, - size_factors: List[float], + size_factors: list[float], dtype: str, max_tokens: int, num_logprobs: int, @@ -140,12 +140,12 @@ def run_test( @overload def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets, model: str, *, - sizes: List[Tuple[int, int]], + sizes: list[tuple[int, int]], dtype: str, max_tokens: int, num_logprobs: int, @@ -156,13 +156,13 @@ def run_test( def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], image_assets: _ImageAssets, model: str, *, - size_factors: Optional[List[float]] = None, - sizes: Optional[List[Tuple[int, int]]] = None, + size_factors: Optional[list[float]] = None, + sizes: Optional[list[tuple[int, int]]] = None, dtype: str, max_tokens: int, num_logprobs: int, @@ -183,9 +183,9 @@ def run_test( def _run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], PromptImageInput]], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + inputs: list[tuple[list[str], PromptImageInput]], model: str, *, dtype: str, diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 5c43e4eed787..84471c92a293 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for H2OVL's multimodal preprocessing kwargs.""" -from typing import Mapping, Optional +from collections.abc import Mapping +from typing import Optional import pytest from PIL import Image diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index cc777fdf57b3..adbc4f5b5586 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for InternVL's multimodal preprocessing kwargs.""" -from typing import Mapping, Optional +from collections.abc import Mapping +from typing import Optional import pytest from PIL import Image diff --git a/tests/models/registry.py b/tests/models/registry.py index 78a65b93870e..b5ded20c5af5 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Mapping, Set from dataclasses import dataclass, field -from typing import AbstractSet, Any, Literal, Mapping, Optional +from typing import Any, Literal, Optional import pytest from packaging.version import Version @@ -324,7 +325,7 @@ def __init__(self, hf_models: Mapping[str, _HfExamplesInfo]) -> None: self.hf_models = hf_models - def get_supported_archs(self) -> AbstractSet[str]: + def get_supported_archs(self) -> Set[str]: return self.hf_models.keys() def get_hf_info(self, model_arch: str) -> _HfExamplesInfo: diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 31e3c1f7b987..243cb92ae256 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -4,7 +4,6 @@ Run `pytest tests/models/test_transformers.py`. """ from contextlib import nullcontext -from typing import Type import pytest @@ -14,8 +13,8 @@ def check_implementation( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], example_prompts: list[str], model: str, **kwargs, @@ -47,8 +46,8 @@ def check_implementation( ("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE ]) # trust_remote_code=True by default def test_models( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], example_prompts: list[str], model: str, model_impl: str, @@ -71,8 +70,8 @@ def test_models( @multi_gpu_test(num_gpus=2) def test_distributed( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], example_prompts, ): kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2} @@ -92,7 +91,7 @@ def test_distributed( @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) def test_quantization( - vllm_runner: Type[VllmRunner], + vllm_runner: type[VllmRunner], example_prompts: list[str], model: str, quantization_kwargs: dict[str, str], diff --git a/tests/models/utils.py b/tests/models/utils.py index a90efb176722..b0182d545f4b 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import warnings -from typing import Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Optional, Union import torch @@ -9,7 +10,7 @@ from vllm.inputs import InputContext from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs -TokensText = Tuple[List[int], str] +TokensText = tuple[list[int], str] def check_outputs_equal( @@ -46,7 +47,7 @@ def check_outputs_equal( # * List of top sample logprobs for each sampled token # # Assumes prompt logprobs were not requested. -TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int, +TokensTextLogprobs = tuple[list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]]] @@ -57,8 +58,8 @@ def check_outputs_equal( # * Optional list of top sample logprobs for each sampled token # # Assumes prompt logprobs were not requested. -TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]], - List[Dict[str, +TextTextLogprobs = tuple[list[str], str, Optional[Union[list[dict[str, float]], + list[dict[str, Logprob]]]]] # Representation of generated sequence as a tuple of @@ -68,9 +69,9 @@ def check_outputs_equal( # * Optional list of top prompt logprobs for each prompt token # # Allows prompt logprobs to be requested. -TokensTextLogprobsPromptLogprobs = Tuple[ - List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]], - Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]] +TokensTextLogprobsPromptLogprobs = tuple[ + list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]], + Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]]] def check_logprobs_close( @@ -254,8 +255,8 @@ def build_model_context( tokenizer_name: Optional[str] = None, trust_remote_code: bool = False, dtype: Optional[Union[str, torch.dtype]] = None, - mm_processor_kwargs: Optional[Dict] = None, - limit_mm_per_prompt: Optional[Dict] = None, + mm_processor_kwargs: Optional[dict] = None, + limit_mm_per_prompt: Optional[dict] = None, disable_mm_preprocessor_cache: bool = True, ): """Creates an InputContext for a given model. diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index 11e44f12bc56..64559609abb2 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -2,7 +2,7 @@ import asyncio import multiprocessing -from typing import Callable, Tuple, Union +from typing import Callable, Union from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs @@ -16,7 +16,7 @@ async def generate( client: MQLLMEngineClient, request_id: str, num_tokens: int, - return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]: + return_output: bool = False) -> Union[RequestOutput, tuple[int, str]]: final_output = None count = 0 diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index 9822cee14a25..f925e42f46d3 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Test the AsyncLLMEngine with multi-step-decoding -from typing import List, Optional +from typing import Optional import pytest @@ -17,7 +17,7 @@ NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps NUM_PROMPTS = [10] -DEFAULT_SERVER_ARGS: List[str] = [ +DEFAULT_SERVER_ARGS: list[str] = [ "--distributed-executor-backend", "ray", "--gpu-memory-utilization", diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index f9e0f507a1e8..8f76d895fdd2 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -4,7 +4,7 @@ import mimetypes import os from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Tuple +from typing import TYPE_CHECKING, NamedTuple, Optional import numpy as np import pytest @@ -30,7 +30,7 @@ @pytest.fixture(scope="module") -def url_images() -> Dict[str, Image.Image]: +def url_images() -> dict[str, Image.Image]: connector = MediaConnector() return { @@ -39,7 +39,7 @@ def url_images() -> Dict[str, Image.Image]: } -def get_supported_suffixes() -> Tuple[str, ...]: +def get_supported_suffixes() -> tuple[str, ...]: # We should at least test the file types mentioned in GPT-4 with Vision OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif') @@ -66,7 +66,7 @@ async def test_fetch_image_http(image_url: str): @pytest.mark.asyncio @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("suffix", get_supported_suffixes()) -async def test_fetch_image_base64(url_images: Dict[str, Image.Image], +async def test_fetch_image_base64(url_images: dict[str, Image.Image], image_url: str, suffix: str): connector = MediaConnector() url_image = url_images[image_url] diff --git a/tests/neuron/test_logits_processor.py b/tests/neuron/test_logits_processor.py index 37d59c9e76a7..6d1514088f90 100644 --- a/tests/neuron/test_logits_processor.py +++ b/tests/neuron/test_logits_processor.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import Tuple from unittest.mock import patch import pytest @@ -33,7 +32,7 @@ def forward(self, *args, **kwargs): def _prepare_test( batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: +) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, vocab_size), diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index a376d2cb340c..bc4a41cdf00d 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Optional, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union import torch import torch.nn as nn @@ -59,7 +60,7 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights = self.hf_to_vllm_mapper.apply(weights) weights = ((name, data) for name, data in weights diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py index 0abbd8ebb598..e30166842ea8 100644 --- a/tests/quantization/test_configs.py +++ b/tests/quantization/test_configs.py @@ -5,7 +5,6 @@ """ from dataclasses import dataclass -from typing import Tuple import pytest @@ -53,7 +52,7 @@ class ModelPair: @pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES) -def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None: +def test_auto_gptq(model_arg_exptype: tuple[str, None, str]) -> None: model_path, quantization_arg, expected_type = model_arg_exptype try: diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py index da59dc75afc1..f64dca6e4bbf 100644 --- a/tests/quantization/test_register_quantization_config.py +++ b/tests/quantization/test_register_quantization_config.py @@ -5,7 +5,7 @@ Run `pytest tests/quantization/test_register_quantization_config.py`. """ -from typing import Any, Dict, List, Optional +from typing import Any, Optional import pytest import torch @@ -58,7 +58,7 @@ def get_name(self) -> str: """Name of the quantization method.""" return "custom_quant" - def get_supported_act_dtypes(self) -> List["torch.dtype"]: + def get_supported_act_dtypes(self) -> list["torch.dtype"]: """List of supported activation dtypes.""" return [torch.float16, torch.bfloat16] @@ -68,12 +68,12 @@ def get_min_capability(cls) -> int: return -1 @staticmethod - def get_config_filenames() -> List[str]: + def get_config_filenames() -> list[str]: """List of filenames to search for in the model directory.""" return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig": + def from_config(cls, config: dict[str, Any]) -> "CustomQuantConfig": """Create a config class from the model's quantization config.""" return CustomQuantConfig(num_bits=config.get("num_bits", 8)) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 78bdd9b0b958..58c7c256473e 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import torch @@ -70,7 +68,7 @@ def test_get_prompt_logprobs( assert (len(logprobs) == num_top_logprobs or len(logprobs) == num_top_logprobs + 1) output_text = result.outputs[0].text - output_string_from_most_likely_tokens_lst: List[str] = [] + output_string_from_most_likely_tokens_lst: list[str] = [] for top_logprobs in result.outputs[0].logprobs: top_logprob = next(iter(top_logprobs.values())) output_string_from_most_likely_tokens_lst.append( diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py index 143f52999415..29e73eb1bead 100644 --- a/tests/samplers/test_no_bad_words.py +++ b/tests/samplers/test_no_bad_words.py @@ -4,7 +4,7 @@ Run `pytest tests/samplers/test_no_bad_words.py`. """ -from typing import List, Optional +from typing import Optional from transformers import AutoTokenizer @@ -16,8 +16,8 @@ def _generate( prompt: str, num_prompt_tokens: int, temperature: float = 0, - bad_words: Optional[List[str]] = None, -) -> List[int]: + bad_words: Optional[list[str]] = None, +) -> list[int]: sampling_params = SamplingParams( temperature=temperature, bad_words=bad_words, @@ -59,7 +59,7 @@ def test_one_token_bad_word(self, vllm_runner): def _generate(self, model: LLM, - bad_words: Optional[List[str]] = None) -> List[int]: + bad_words: Optional[list[str]] = None) -> list[int]: return _generate( model=model, prompt=self.PROMPT, @@ -69,7 +69,7 @@ def _generate(self, def _encode(self, prompt: str, - add_special_tokens: bool = True) -> List[int]: + add_special_tokens: bool = True) -> list[int]: return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids @@ -149,7 +149,7 @@ def test_two_token_bad_word(self, vllm_runner): def _generate(self, model: LLM, - bad_words: Optional[List[str]] = None) -> List[int]: + bad_words: Optional[list[str]] = None) -> list[int]: return _generate( model=model, prompt=self.PROMPT, @@ -158,7 +158,7 @@ def _generate(self, ) @staticmethod - def _contains(sequence: List[int], subsequence: List[int]) -> bool: + def _contains(sequence: list[int], subsequence: list[int]) -> bool: searched = False for start in range(len(sequence)): @@ -181,6 +181,6 @@ def _contains(sequence: List[int], subsequence: List[int]) -> bool: def _encode(self, prompt: str, - add_special_tokens: bool = True) -> List[int]: + add_special_tokens: bool = True) -> list[int]: return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index cc199bf682fc..2b86dcac7f03 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for rejection sampling.""" -from typing import List, Tuple import pytest import torch @@ -416,8 +415,8 @@ def test_rejection_sampling_approximates_target_distribution( draft_and_target_probs_equal) sample_sizes = [10, 100, 1_000, 10_000, 100_000] - distance_wrt_reference: List[float] = [] - distance_wrt_target: List[float] = [] + distance_wrt_reference: list[float] = [] + distance_wrt_target: list[float] = [] for num_samples in sample_sizes: (reference_vs_rejsample_dist, @@ -452,7 +451,7 @@ def test_rejection_sampling_approximates_target_distribution( expected_improvement_multiplier) -def get_ratio_first_to_last(elements: List[float]) -> float: +def get_ratio_first_to_last(elements: list[float]) -> float: return elements[0] / elements[-1] @@ -477,7 +476,7 @@ def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler): def generate_probs_for_test( self, draft_and_target_probs_equal: bool - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: draft_probs, target_probs = (F.softmax( torch.rand(self.vocab_size, dtype=torch.float32), dim=-1, @@ -499,7 +498,7 @@ def generate_probs_for_test( def run_and_compare_distributions(self, draft_probs: torch.Tensor, target_probs: torch.Tensor, reference_probs: torch.Tensor, - num_samples: int) -> Tuple[float, float]: + num_samples: int) -> tuple[float, float]: # Sample using rejection sampling. rej_sample_probs = self._estimate_rejection_sampling_pdf( draft_probs, target_probs, num_samples) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index ca09e536a06c..68944ac7e1ef 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -3,7 +3,7 @@ import itertools import random from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Optional from unittest.mock import Mock, patch import pytest @@ -30,7 +30,7 @@ def forward(self, *args, **kwargs): def _prepare_test( batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: +) -> tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, VOCAB_SIZE), 1e-2, @@ -53,8 +53,8 @@ def _do_sample( sampling_params: SamplingParams, device: str, ): - seq_group_metadata_list: List[SequenceGroupMetadata] = [] - seq_lens: List[int] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] + seq_lens: list[int] = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -171,7 +171,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): def create_sampling_params(min_tokens, eos_token_id=0, *, - stop_token_ids: Optional[List[int]] = None, + stop_token_ids: Optional[list[int]] = None, prompt_logprobs: Optional[int] = None): sampling_params = SamplingParams( min_tokens=min_tokens, @@ -196,7 +196,7 @@ def generate_test_case(): batch_size = random.randint(1, 128) expected_penalization = [] - sequence_metadata_list: List[SequenceGroupMetadata] = [] + sequence_metadata_list: list[SequenceGroupMetadata] = [] # 20% chance to generate seq group metadata list with all prompts is_prompt = random.random() < 0.2 while batch_size > 0: @@ -216,8 +216,8 @@ def generate_test_case(): eos_token_id=eos_token_id, stop_token_ids=stop_token_ids) - seq_data: Dict[int, SequenceData] = {} - seq_group_penalization: List[bool] = [] + seq_data: dict[int, SequenceData] = {} + seq_group_penalization: list[bool] = [] for _ in range(num_seqs): num_input = random.randint(1, 100) num_generated = 0 if is_prompt else random.randint(1, 100) @@ -376,16 +376,16 @@ def generate_test_case(): else: test_cases = [generate_test_case()] - def run_test_case(*, expected_penalization: List[bool], - seq_group_metadata_list: List[SequenceGroupMetadata]): + def run_test_case(*, expected_penalization: list[bool], + seq_group_metadata_list: list[SequenceGroupMetadata]): assert expected_penalization, \ "Invalid test case, need expected_penalization" assert seq_group_metadata_list, \ "Invalid test case, need seq_group_metadata_list" batch_size = 0 - seq_lens: List[int] = [] - sampling_params_per_row: List[SamplingParams] = [] + seq_lens: list[int] = [] + sampling_params_per_row: list[SamplingParams] = [] for sgm in seq_group_metadata_list: sampling_params = sgm.sampling_params @@ -456,11 +456,11 @@ def test_sampler_mixed(seed: int, device: str): batch_size = random.randint(1, 256) input_tensor, fake_logits, sampler = _prepare_test(batch_size) - seq_group_metadata_list: List[SequenceGroupMetadata] = [] - expected_tokens: List[Optional[List[int]]] = [] - seq_lens: List[int] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] + expected_tokens: list[Optional[list[int]]] = [] + seq_lens: list[int] = [] for i in range(batch_size): - expected: Optional[List[int]] = None + expected: Optional[list[int]] = None sampling_type = random.randint(0, 2) if sampling_type == 0: sampling_params = SamplingParams(temperature=0) @@ -492,7 +492,7 @@ def test_sampler_mixed(seed: int, device: str): )) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - generators: Dict[str, torch.Generator] = {} + generators: dict[str, torch.Generator] = {} def test_sampling(): sampling_metadata = SamplingMetadata.prepare( @@ -587,8 +587,8 @@ class MockConfig: device=device) assert len(processors) == 2 # top_p and top_k - seq_group_metadata_list: List[SequenceGroupMetadata] = [] - seq_lens: List[int] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] + seq_lens: list[int] = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -669,10 +669,10 @@ def test_sampler_repetition_penalty_mixed(device: str): vocab_size = 8 - def test_sampling_params(sampling_params: List[SamplingParams]): + def test_sampling_params(sampling_params: list[SamplingParams]): - seq_group_metadata_list: List[SequenceGroupMetadata] = [] - seq_lens: List[int] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] + seq_lens: list[int] = [] for i in range(2): seq_group_metadata_list.append( SequenceGroupMetadata( diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 53c888816a6c..fe4a1c13fc73 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence from itertools import cycle -from typing import List, Optional, Sequence, Tuple, Union +from typing import Optional, Union import pytest import torch @@ -64,9 +65,9 @@ def maybe_assert_ngram_worker(llm): def get_output_from_llm_generator( llm_generator, prompts, - sampling_params) -> Tuple[List[str], List[List[int]], float]: - tokens: List[str] = [] - token_ids: List[List[int]] = [] + sampling_params) -> tuple[list[str], list[list[int]], float]: + tokens: list[str] = [] + token_ids: list[list[int]] = [] acceptance_rate: float = -1.0 for llm in llm_generator(): maybe_assert_ngram_worker(llm) diff --git a/tests/spec_decode/test_batch_expansion.py b/tests/spec_decode/test_batch_expansion.py index fe95ff9b9c35..9edd8bd4c00d 100644 --- a/tests/spec_decode/test_batch_expansion.py +++ b/tests/spec_decode/test_batch_expansion.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import torch @@ -42,7 +40,7 @@ def test_get_token_ids_to_score(k: int): device='cuda', ) - expected_output: List[List[int]] = [ + expected_output: list[list[int]] = [ [], ] for i in range(proposal_token_ids.shape[0]): diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 2bf401613f06..ca37c9a68dfa 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import Dict, List from unittest.mock import MagicMock import pytest @@ -221,7 +220,7 @@ def test_same_output_for_multi_step(): # Run single-step repeatedly. zero_kv_cache(worker.cache_engine) - single_step_output: List[SamplerOutput] = [] + single_step_output: list[SamplerOutput] = [] continuations = [[1] for _ in prompts] set_random_seed(seed) @@ -243,15 +242,15 @@ def test_same_output_for_multi_step(): continuations[i].append(seq_group_output.samples[0].output_token) # Get token ids and logprobs for comparison. - multi_step_output_logprobs: List[List[Dict[int, + multi_step_output_logprobs: list[list[dict[int, Logprob]]] = [[] for _ in prompts] - single_step_output_logprobs: List[List[Dict[int, + single_step_output_logprobs: list[list[dict[int, Logprob]]] = [[] for _ in prompts] - multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts] - single_step_output_token_ids: List[List[int]] = [[] for _ in prompts] + multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts] + single_step_output_token_ids: list[list[int]] = [[] for _ in prompts] for i, _ in enumerate(prompts): for multi_step, single_step in zip(multi_step_output, single_step_output): @@ -336,7 +335,7 @@ def test_multi_step_with_batch_expansion_correct_output(): # will simulate the bonus token case with the second token # being the bonus token. zero_kv_cache(worker.cache_engine) - single_step_output: List[SamplerOutput] = [] + single_step_output: list[SamplerOutput] = [] set_random_seed(seed) for _ in range(num_steps): seq_group_metadata_list = create_seq_group_metadata_from_prompts( @@ -430,7 +429,7 @@ def test_multi_step_with_batch_expansion_incorrect_output(): # will simulate the bonus token case with the second token # being the bonus token. zero_kv_cache(worker.cache_engine) - single_step_output: List[SamplerOutput] = [] + single_step_output: list[SamplerOutput] = [] set_random_seed(seed) for _ in range(num_steps): seq_group_metadata_list = create_seq_group_metadata_from_prompts( diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py index 7bbbb0236da1..161cc9fbf556 100644 --- a/tests/spec_decode/test_scorer.py +++ b/tests/spec_decode/test_scorer.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import List import pytest import torch @@ -15,7 +14,7 @@ from .utils import create_batch, create_worker -def create_proposal(propose_lens: List[int], vocab_size: int, +def create_proposal(propose_lens: list[int], vocab_size: int, device: str) -> SpeculativeProposals: batch_size = len(propose_lens) max_propose_len = max(propose_lens) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index e4b1a178b0c9..f7ef9786a690 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -3,7 +3,6 @@ import random from collections import defaultdict from types import SimpleNamespace -from typing import Dict, List, Set from unittest.mock import MagicMock import pytest @@ -123,7 +122,7 @@ def test_batch_expansion_correctly_calls_target_model( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) - seen_contexts: List[List[int]] = [] + seen_contexts: list[list[int]] = [] call_args_list = target_worker.execute_model.call_args_list assert len(call_args_list) == 1 @@ -136,7 +135,7 @@ def test_batch_expansion_correctly_calls_target_model( for seq_data in seq_group_metadata.seq_data.values(): seen_contexts.append(seq_data.get_token_ids()) - expected_seen_contexts: List[List[int]] = [] + expected_seen_contexts: list[list[int]] = [] for prompt, prev_generated, draft_tokens in zip( prompts, prev_output_tokens, proposal_token_ids.tolist()): @@ -338,11 +337,11 @@ def test_correctly_formats_output(k: int, batch_size: int, next(iter(seq_group_metadata.seq_data.keys())) for seq_group_metadata in seq_group_metadata_list ] - actual_output_by_seq: Dict[int, List[SequenceOutput]] = { + actual_output_by_seq: dict[int, list[SequenceOutput]] = { seq_id: [] for seq_id in seq_ids } - expected_output_by_seq: Dict[int, List[SequenceOutput]] = { + expected_output_by_seq: dict[int, list[SequenceOutput]] = { seq_id: [] for seq_id in seq_ids } @@ -728,7 +727,7 @@ def test_populate_seq_ids_with_bonus_tokens(): size=(batch_size, (k + 1)), dtype=torch.int64, device='cuda') - expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) + expected_request_id_seq_ids_mapping: dict[str, set[int]] = defaultdict(set) for seq_group_metadata in seq_group_metadata_list: for seq_id in seq_group_metadata.seq_data: expected_request_id_seq_ids_mapping[ diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 38f57e99bdb0..d303b7f1219a 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence as GenericSequence from itertools import count -from typing import Callable, Dict, List, Optional -from typing import Sequence as GenericSequence -from typing import TypeVar, Union +from typing import Callable, Optional, TypeVar, Union from unittest.mock import MagicMock import torch @@ -44,7 +43,7 @@ def mock_worker(cls=None, return worker -def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]): +def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]): seed_iter = iter(rand_seeds) original_execute_model = worker.execute_model @@ -56,7 +55,7 @@ def new_execute_model(*args, **kwargs): return new_execute_model -def zero_kv_cache(cache_engine: List[CacheEngine]): +def zero_kv_cache(cache_engine: list[CacheEngine]): assert cache_engine[0].gpu_cache for key_blocks, value_blocks in cache_engine[0].gpu_cache: key_blocks.zero_() @@ -106,13 +105,13 @@ def create_worker(cls: Callable[..., T], def create_seq_group_metadata_from_prompts( - prompts: List[List[int]], + prompts: list[list[int]], num_gpu_blocks: int, block_size: int, - final_prompt_lens: List[int], - continuations: Optional[List[List[int]]] = None, - seq_ids: Optional[List[int]] = None, -) -> List[SequenceGroupMetadata]: + final_prompt_lens: list[int], + continuations: Optional[list[list[int]]] = None, + seq_ids: Optional[list[int]] = None, +) -> list[SequenceGroupMetadata]: if continuations is None: continuations = [[] for _ in prompts] @@ -149,11 +148,11 @@ def create_seq_group_metadata_from_prompts( def create_chunked_seq_group_metadata_from_prompt( - prompt: List[int], + prompt: list[int], num_gpu_blocks: int, chunk_size: int, block_size: int, - seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]: + seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]: if seq_id is None: seq_id = 0 @@ -184,8 +183,8 @@ def create_chunked_seq_group_metadata_from_prompt( def assert_logprobs_dict_allclose( - actual_logprobs: List[Dict[int, Logprob]], - expected_logprobs: List[Dict[int, Logprob]]) -> None: + actual_logprobs: list[dict[int, Logprob]], + expected_logprobs: list[dict[int, Logprob]]) -> None: for single_step_actual_logprobs, single_step_expected_logprobs in zip( actual_logprobs, expected_logprobs): assert set(single_step_actual_logprobs.keys()) == set( @@ -202,7 +201,7 @@ def create_sampler_output_list( token_ids: torch.Tensor, probs: GenericSequence[Optional[torch.Tensor]], logprobs: GenericSequence[Optional[torch.Tensor]], - seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: + seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]: num_steps, batch_size = token_ids.shape token_ids_by_step = token_ids.tolist() @@ -231,9 +230,9 @@ def create_sampler_output_list( def create_batch(batch_size, k, - prompt_len: Union[int, List[int]] = 10, + prompt_len: Union[int, list[int]] = 10, prev_output_token_len: int = 10, - seq_ids: Optional[List[int]] = None, + seq_ids: Optional[list[int]] = None, num_gpu_blocks: Optional[int] = None, block_size: Optional[int] = None, prefill_chunk_size: Optional[int] = None): diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 17c128a17656..05d2c624df17 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -3,7 +3,7 @@ Run `pytest tests/test_cache_block_hashing.py`. """ -from typing import List, Optional +from typing import Optional import pytest @@ -44,7 +44,7 @@ def flatten_2d(li): @pytest.mark.parametrize("concurrent_lora_int_ids", [[None], [1], [None, 1], [None, 1, 2], [1, 2]]) def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, - concurrent_lora_int_ids: List[Optional[int]]): + concurrent_lora_int_ids: list[Optional[int]]): tokenizer = TokenizerGroup( tokenizer_id="facebook/opt-125m", @@ -53,7 +53,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, max_input_length=None, ) - hashes: List[List[List[int]]] = [] + hashes: list[list[list[int]]] = [] for prefix in prefixes: for lora_int_id in concurrent_lora_int_ids: diff --git a/tests/test_inputs.py b/tests/test_inputs.py index fff909154a2a..d361808ed2f9 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest from vllm.inputs import zip_enc_dec_prompts @@ -45,7 +43,7 @@ def test_parse_single_batch_string_consistent(string_input: str): @pytest.mark.parametrize('token_input', TOKEN_INPUTS) -def test_parse_single_batch_token_consistent(token_input: List[int]): +def test_parse_single_batch_token_consistent(token_input: list[int]): assert parse_and_batch_prompt(token_input) \ == parse_and_batch_prompt([token_input]) diff --git a/tests/test_logger.py b/tests/test_logger.py index 993822e92240..11deae309ac8 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -155,7 +155,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( with pytest.raises(ValueError) as ex_info: _configure_vllm_root_logger() assert ex_info.type == ValueError # noqa: E721 - assert "Invalid logging config. Expected Dict, got" in str(ex_info) + assert "Invalid logging config. Expected dict, got" in str(ex_info) @patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 487fbb8fcb8c..8301c645b79f 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import Tuple from unittest.mock import patch import pytest @@ -33,7 +32,7 @@ def forward(self, *args, **kwargs): def _prepare_test( batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: +) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, vocab_size), diff --git a/tests/test_utils.py b/tests/test_utils.py index 5b69ffd18bb2..8b67e92fca68 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ import asyncio import os import socket -from typing import AsyncIterator, Tuple +from collections.abc import AsyncIterator from unittest.mock import patch import pytest @@ -33,7 +33,7 @@ async def mock_async_iterator(idx: int): iterators = [mock_async_iterator(i) for i in range(3)] merged_iterator = merge_async_iterators(*iterators) - async def stream_output(generator: AsyncIterator[Tuple[int, str]]): + async def stream_output(generator: AsyncIterator[tuple[int, str]]): async for idx, output in generator: print(f"idx: {idx}, output: {output}") diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 851c79d2e09c..9aa2eea3154c 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, Generator, List, Optional +from collections.abc import Generator +from typing import Any, Optional import pytest from transformers import AutoTokenizer @@ -163,7 +164,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: @pytest.fixture(name="complete_sequence_token_ids") def create_complete_sequence_token_ids(complete_sequence: str, - tokenizer) -> List[int]: + tokenizer) -> list[int]: complete_sequence_token_ids = tokenizer(complete_sequence).input_ids return complete_sequence_token_ids @@ -178,7 +179,7 @@ def create_sequence(prompt_token_ids=None): def create_dummy_logprobs( - complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]: + complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]: return [{ token_id: Logprob(logprob=0.0), token_id + 1: Logprob(logprob=0.1) @@ -186,10 +187,10 @@ def create_dummy_logprobs( def create_dummy_prompt_logprobs( - complete_sequence_token_ids: List[int] -) -> List[Optional[Dict[int, Any]]]: + complete_sequence_token_ids: list[int] +) -> list[Optional[dict[int, Any]]]: # logprob for the first prompt token is None. - logprobs: List[Optional[Dict[int, Any]]] = [None] + logprobs: list[Optional[dict[int, Any]]] = [None] logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:]) return logprobs @@ -198,7 +199,7 @@ def create_dummy_prompt_logprobs( @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True) def test_decode_sequence_logprobs(complete_sequence: str, - complete_sequence_token_ids: List[int], + complete_sequence_token_ids: list[int], detokenizer: Detokenizer, skip_special_tokens: bool): """Verify Detokenizer decodes logprobs correctly.""" @@ -208,8 +209,8 @@ def test_decode_sequence_logprobs(complete_sequence: str, # Run sequentially. seq = create_sequence() dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) - sequential_logprobs_text_chosen_token: List[str] = [] - sequential_logprobs_text_other_token: List[str] = [] + sequential_logprobs_text_chosen_token: list[str] = [] + sequential_logprobs_text_other_token: list[str] = [] for new_token, logprobs in zip(complete_sequence_token_ids, dummy_logprobs): seq.append_token_id(new_token, logprobs) @@ -232,7 +233,7 @@ def test_decode_sequence_logprobs(complete_sequence: str, @pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], +def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int], detokenizer: Detokenizer): """Verify Detokenizer decodes prompt logprobs correctly.""" sampling_params = SamplingParams(skip_special_tokens=True, @@ -249,7 +250,7 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], dummy_logprobs, position_offset=0) # First logprob is None. - decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[ + decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[ 1:] # type: ignore # decoded_prompt_logprobs doesn't contain the first token. diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py index 8e99f86917b8..d1873823ac18 100644 --- a/tests/tokenization/test_tokenizer_group.py +++ b/tests/tokenization/test_tokenizer_group.py @@ -3,7 +3,7 @@ import asyncio import os import sys -from typing import List, Optional +from typing import Optional from unittest.mock import patch import pytest @@ -129,7 +129,7 @@ class FailingTokenizerGroup(TokenizerGroup): def __init__(self, *args, - fail_at: Optional[List[int]] = None, + fail_at: Optional[list[int]] = None, **kwargs): super().__init__(*args, **kwargs) self.i = 0 diff --git a/tests/tokenization/test_tokenizer_registry.py b/tests/tokenization/test_tokenizer_registry.py index 793d38f9c366..772eeb345ca4 100644 --- a/tests/tokenization/test_tokenizer_registry.py +++ b/tests/tokenization/test_tokenizer_registry.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer_base import (TokenizerBase, @@ -17,15 +17,15 @@ def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer": return TestTokenizer() @property - def all_special_tokens_extended(self) -> List[str]: + def all_special_tokens_extended(self) -> list[str]: raise NotImplementedError() @property - def all_special_tokens(self) -> List[str]: + def all_special_tokens(self) -> list[str]: raise NotImplementedError() @property - def all_special_ids(self) -> List[int]: + def all_special_ids(self) -> list[int]: raise NotImplementedError() @property @@ -58,7 +58,7 @@ def max_token_id(self) -> int: def __call__( self, - text: Union[str, List[str], List[int]], + text: Union[str, list[str], list[int]], text_pair: Optional[str] = None, add_special_tokens: bool = False, truncation: bool = False, @@ -66,10 +66,10 @@ def __call__( ): raise NotImplementedError() - def get_vocab(self) -> Dict[str, int]: + def get_vocab(self) -> dict[str, int]: raise NotImplementedError() - def get_added_vocab(self) -> Dict[str, int]: + def get_added_vocab(self) -> dict[str, int]: raise NotImplementedError() def encode_one( @@ -77,33 +77,33 @@ def encode_one( text: str, truncation: bool = False, max_length: Optional[int] = None, - ) -> List[int]: + ) -> list[int]: raise NotImplementedError() def encode(self, text: str, - add_special_tokens: Optional[bool] = None) -> List[int]: + add_special_tokens: Optional[bool] = None) -> list[int]: raise NotImplementedError() def apply_chat_template(self, - messages: List["ChatCompletionMessageParam"], - tools: Optional[List[Dict[str, Any]]] = None, - **kwargs) -> List[int]: + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + **kwargs) -> list[int]: raise NotImplementedError() - def convert_tokens_to_string(self, tokens: List[str]) -> str: + def convert_tokens_to_string(self, tokens: list[str]) -> str: raise NotImplementedError() def decode(self, - ids: Union[List[int], int], + ids: Union[list[int], int], skip_special_tokens: bool = True) -> str: raise NotImplementedError() def convert_ids_to_tokens( self, - ids: List[int], + ids: list[int], skip_special_tokens: bool = True, - ) -> List[str]: + ) -> list[str]: raise NotImplementedError() diff --git a/tests/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py index da033fa1d85c..448347be6ec1 100644 --- a/tests/tool_use/test_chat_completions.py +++ b/tests/tool_use/test_chat_completions.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import openai import pytest @@ -45,7 +43,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, logprobs=False, stream=True, ) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 role_sent: bool = False @@ -116,7 +114,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI, stream=True, ) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 role_sent: bool = False diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_use/test_jamba_tool_parser.py index 7e349c51253c..a40675744ba2 100644 --- a/tests/tool_use/test_jamba_tool_parser.py +++ b/tests/tool_use/test_jamba_tool_parser.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import Generator, List, Optional +from collections.abc import Generator +from typing import Optional import partial_json_parser import pytest @@ -26,8 +27,8 @@ def jamba_tool_parser(jamba_tokenizer): return JambaToolParser(jamba_tokenizer) -def assert_tool_calls(actual_tool_calls: List[ToolCall], - expected_tool_calls: List[ToolCall]): +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): assert len(actual_tool_calls) == len(expected_tool_calls) for actual_tool_call, expected_tool_call in zip(actual_tool_calls, @@ -218,10 +219,10 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer, model_output, expected_tool_calls, expected_content): other_content: str = '' - function_names: List[str] = [] - function_args_strs: List[str] = [] + function_names: list[str] = [] + function_args_strs: list[str] = [] tool_call_idx: int = -1 - tool_call_ids: List[Optional[str]] = [] + tool_call_ids: list[Optional[str]] = [] for delta_message in stream_delta_message_generator( jamba_tool_parser, jamba_tokenizer, model_output): diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index b49a5e8e7e4c..910e0b2d51ab 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import Dict, List, Optional +from typing import Optional import openai import pytest @@ -54,7 +54,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, assert isinstance(tool_call.function.arguments, str) parsed_arguments = json.loads(tool_call.function.arguments) - assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments, dict) assert isinstance(parsed_arguments.get("city"), str) assert isinstance(parsed_arguments.get("state"), str) @@ -73,8 +73,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, role_name: Optional[str] = None finish_reason_count: int = 0 - tool_call_names: List[str] = [] - tool_call_args: List[str] = [] + tool_call_names: list[str] = [] + tool_call_args: list[str] = [] tool_call_idx: int = -1 tool_call_id_count: int = 0 @@ -180,7 +180,7 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI, logprobs=False, stream=True) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 role_sent: bool = False diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index 45f1bfc45bd7..b320b335e338 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import Dict, List, Optional +from typing import Optional import openai import pytest @@ -44,7 +44,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): # make sure the arguments parse properly parsed_arguments = json.loads(tool_calls[0].function.arguments) - assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments, dict) assert isinstance(parsed_arguments.get("city"), str) assert isinstance(parsed_arguments.get("state"), str) assert parsed_arguments.get("city") == "Dallas" @@ -117,7 +117,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): # validate arguments streamed_args = json.loads(function_args_str) - assert isinstance(streamed_args, Dict) + assert isinstance(streamed_args, dict) assert isinstance(streamed_args.get("city"), str) assert isinstance(streamed_args.get("state"), str) assert streamed_args.get("city") == "Dallas" @@ -128,7 +128,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert choice.message.role == role_name assert choice.message.tool_calls[0].function.name == function_name - # compare streamed with non-streamed args Dict-wise, not string-wise + # compare streamed with non-streamed args dict-wise, not string-wise # because character-to-character comparison might not work e.g. the tool # call parser adding extra spaces or something like that. we care about the # dicts matching not byte-wise match @@ -167,7 +167,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): logprobs=False, stream=True) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 role_sent: bool = False diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index a7dfb10780a3..fd947bd7fed0 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from copy import deepcopy -from typing import Any, Dict, List, Optional +from typing import Any, Optional from openai.types.chat import (ChatCompletionMessageParam, ChatCompletionToolParam) @@ -12,14 +12,14 @@ class ServerConfig(TypedDict, total=False): model: str - arguments: List[str] + arguments: list[str] system_prompt: Optional[str] supports_parallel: Optional[bool] supports_rocm: Optional[bool] -def patch_system_prompt(messages: List[Dict[str, Any]], - system_prompt: str) -> List[Dict[str, Any]]: +def patch_system_prompt(messages: list[dict[str, Any]], + system_prompt: str) -> list[dict[str, Any]]: new_messages = deepcopy(messages) if new_messages[0]["role"] == "system": new_messages[0]["content"] = system_prompt @@ -28,8 +28,8 @@ def patch_system_prompt(messages: List[Dict[str, Any]], return new_messages -def ensure_system_prompt(messages: List[Dict[str, Any]], - config: ServerConfig) -> List[Dict[str, Any]]: +def ensure_system_prompt(messages: list[dict[str, Any]], + config: ServerConfig) -> list[dict[str, Any]]: prompt = config.get("system_prompt") if prompt: return patch_system_prompt(messages, prompt) @@ -39,9 +39,9 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], # universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. -ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"] +ARGS: list[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"] -CONFIGS: Dict[str, ServerConfig] = { +CONFIGS: dict[str, ServerConfig] = { "hermes": { "model": "NousResearch/Hermes-3-Llama-3.1-8B", @@ -205,7 +205,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], } } -MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ +MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [{ "role": "user", "content": @@ -222,14 +222,14 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "Can you tell me a joke please?" }] -MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ +MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [{ "role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?" }] -MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ +MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{ "role": "user", "content": @@ -258,7 +258,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "cloudy skies and a low chance of rain." }] -MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{ +MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [{ "role": "user", "content": @@ -266,7 +266,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "Fahrenheit?" }] -MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ +MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{ "role": "user", "content": diff --git a/tests/tracing/test_tracing.py b/tests/tracing/test_tracing.py index 592775e8b892..5fc5d08b327b 100644 --- a/tests/tracing/test_tracing.py +++ b/tests/tracing/test_tracing.py @@ -2,8 +2,9 @@ import os import threading +from collections.abc import Iterable from concurrent import futures -from typing import Callable, Dict, Iterable, Literal +from typing import Callable, Literal import grpc import pytest @@ -25,7 +26,7 @@ def decode_value(value: AnyValue): - field_decoders: Dict[FieldName, Callable] = { + field_decoders: dict[FieldName, Callable] = { "bool_value": (lambda v: v.bool_value), "string_value": (lambda v: v.string_value), "int_value": (lambda v: v.int_value), diff --git a/tests/utils.py b/tests/utils.py index 2ad91ca2c869..5a97636eec64 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,7 +11,7 @@ import warnings from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Optional, Union import openai import pytest @@ -73,9 +73,9 @@ class RemoteOpenAIServer: def __init__(self, model: str, - vllm_serve_args: List[str], + vllm_serve_args: list[str], *, - env_dict: Optional[Dict[str, str]] = None, + env_dict: Optional[dict[str, str]] = None, auto_port: bool = True, max_wait_seconds: Optional[float] = None) -> None: if auto_port: @@ -183,7 +183,7 @@ def _test_completion( client: openai.OpenAI, model: str, prompt: str, - token_ids: List[int], + token_ids: list[int], ): results = [] @@ -400,10 +400,10 @@ def _test_image_text( def compare_two_settings(model: str, - arg1: List[str], - arg2: List[str], - env1: Optional[Dict[str, str]] = None, - env2: Optional[Dict[str, str]] = None, + arg1: list[str], + arg2: list[str], + env1: Optional[dict[str, str]] = None, + env2: Optional[dict[str, str]] = None, *, method: str = "generate", max_wait_seconds: Optional[float] = None) -> None: @@ -429,8 +429,8 @@ def compare_two_settings(model: str, def compare_all_settings(model: str, - all_args: List[List[str]], - all_envs: List[Optional[Dict[str, str]]], + all_args: list[list[str]], + all_envs: list[Optional[dict[str, str]]], *, method: str = "generate", max_wait_seconds: Optional[float] = None) -> None: @@ -470,7 +470,7 @@ def compare_all_settings(model: str, prompt = "Hello, my name is" token_ids = tokenizer(prompt).input_ids - ref_results: List = [] + ref_results: list = [] for i, (args, env) in enumerate(zip(all_args, all_envs)): if can_force_load_format: # we are comparing the results and @@ -481,7 +481,7 @@ def compare_all_settings(model: str, # environment variable to force the load format, # e.g. in quantization tests. args = args + ["--load-format", envs.VLLM_TEST_FORCE_LOAD_FORMAT] - compare_results: List = [] + compare_results: list = [] results = ref_results if i == 0 else compare_results with RemoteOpenAIServer(model, args, @@ -582,7 +582,7 @@ def multi_process_parallel( @contextmanager -def error_on_warning(category: Type[Warning] = Warning): +def error_on_warning(category: type[Warning] = Warning): """ Within the scope of this context manager, tests will fail if any warning of the given category is emitted. @@ -604,7 +604,7 @@ def get_physical_device_indices(devices): @_nvml() -def wait_for_gpu_memory_to_clear(devices: List[int], +def wait_for_gpu_memory_to_clear(devices: list[int], threshold_bytes: int, timeout_s: float = 120) -> None: # Use nvml instead of pytorch to reduce measurement error from torch cuda @@ -612,8 +612,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int], devices = get_physical_device_indices(devices) start_time = time.time() while True: - output: Dict[int, str] = {} - output_raw: Dict[int, float] = {} + output: dict[int, str] = {} + output_raw: dict[int, float] = {} for device in devices: if current_platform.is_rocm(): dev_handle = amdsmi_get_processor_handles()[device] @@ -758,13 +758,13 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: async def completions_with_server_args( - prompts: List[str], + prompts: list[str], model_name: str, - server_cli_args: List[str], + server_cli_args: list[str], num_logprobs: Optional[int], max_wait_seconds: int = 240, max_tokens: Union[int, list] = 5, -) -> List[Completion]: +) -> list[Completion]: '''Construct a remote OpenAI server, obtain an async client to the server & invoke the completions API to obtain completions. @@ -807,7 +807,7 @@ async def completions_with_server_args( return outputs -def get_client_text_generations(completions: List[Completion]) -> List[str]: +def get_client_text_generations(completions: list[Completion]) -> list[str]: '''Extract generated tokens from the output of a request made to an Open-AI-protocol completions endpoint. ''' @@ -816,7 +816,7 @@ def get_client_text_generations(completions: List[Completion]) -> List[str]: def get_client_text_logprob_generations( - completions: List[Completion]) -> List[TextTextLogprobs]: + completions: list[Completion]) -> list[TextTextLogprobs]: '''Operates on the output of a request made to an Open-AI-protocol completions endpoint; obtains top-rank logprobs for each token in each :class:`SequenceGroup` diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 8956393c0bfb..cce2fb2c4814 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 """Compare the with and without prefix caching.""" -from typing import List import pytest @@ -434,7 +433,7 @@ def test_cache_blocks(): # Test that blocks are cached correctly for 2 full blocks from the start. blocks = [KVCacheBlock(block_id=i) for i in range(2)] - block_hashes: List[BlockHashType] = [] + block_hashes: list[BlockHashType] = [] block_pool.cache_full_blocks( request=req, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index eb730973c946..f45c21ab75ba 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import Optional from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange @@ -48,9 +48,9 @@ def create_scheduler( def create_requests( num_requests: int, num_tokens: int = 10, - mm_positions: Optional[List[PlaceholderRange]] = None, + mm_positions: Optional[list[PlaceholderRange]] = None, max_tokens: int = 16, - stop_token_ids: Optional[List[int]] = None, + stop_token_ids: Optional[list[int]] = None, ): sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens, diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py index 560dc3121852..8872f0388dd2 100644 --- a/tests/v1/engine/conftest.py +++ b/tests/v1/engine/conftest.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Tuple - import pytest import torch from transformers import AutoTokenizer @@ -17,8 +15,8 @@ from tests.v1.engine.utils import FULL_STRINGS # isort: skip -EngineCoreSampleLogprobsType = List[Tuple[torch.Tensor, torch.Tensor]] -EngineCorePromptLogprobsType = Tuple[torch.Tensor, torch.Tensor] +EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]] +EngineCorePromptLogprobsType = tuple[torch.Tensor, torch.Tensor] def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index d864cb2af23e..e7b91aeb0fbd 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -2,7 +2,7 @@ import asyncio from contextlib import ExitStack -from typing import List, Optional, Tuple +from typing import Optional import pytest @@ -47,7 +47,7 @@ async def generate(engine: AsyncLLM, prompt: PromptType, output_kind: RequestOutputKind, max_tokens: int, - prompt_logprobs: Optional[int] = None) -> Tuple[int, str]: + prompt_logprobs: Optional[int] = None) -> tuple[int, str]: # Ensure generate doesn't complete too fast for cancellation test. await asyncio.sleep(0.2) @@ -114,7 +114,7 @@ async def test_async_llm_refuses_prompt_logprobs_with_apc( (VISION_ENGINE_ARGS, VISION_PROMPT)]) @pytest.mark.asyncio async def test_load(monkeypatch, output_kind: RequestOutputKind, - engine_args_and_prompt: Tuple[AsyncEngineArgs, + engine_args_and_prompt: tuple[AsyncEngineArgs, PromptType]): # TODO(rickyx): Remove monkeypatch once we have a better way to test V1 # so that in the future when we switch, we don't have to change all the @@ -160,7 +160,7 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind, (VISION_ENGINE_ARGS, VISION_PROMPT)]) @pytest.mark.asyncio async def test_abort(monkeypatch, output_kind: RequestOutputKind, - engine_args_and_prompt: Tuple[AsyncEngineArgs, + engine_args_and_prompt: tuple[AsyncEngineArgs, PromptType]): with monkeypatch.context() as m, ExitStack() as after: @@ -177,7 +177,7 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind, request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] # Create concurrent requests. - tasks: List[asyncio.Task] = [] + tasks: list[asyncio.Task] = [] for request_id in request_ids: tasks.append( asyncio.create_task( diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 8c2998e58892..11c22effb122 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -5,7 +5,6 @@ import time import uuid from concurrent.futures import Future -from typing import List import pytest from transformers import AutoTokenizer @@ -213,7 +212,7 @@ def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest: class DummyExecutor(UniProcExecutor): def initialize_from_config( - self, kv_cache_configs: List[KVCacheConfig]) -> None: + self, kv_cache_configs: list[KVCacheConfig]) -> None: super().initialize_from_config(kv_cache_configs) # This executor actually can only run 1 batch at a time diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index a7c02322ff02..3880a3dd9b8a 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -3,7 +3,7 @@ import asyncio import time import uuid -from typing import Dict, List, Optional +from typing import Optional import pytest from transformers import AutoTokenizer @@ -44,7 +44,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest: ) -def loop_until_done(client: EngineCoreClient, outputs: Dict): +def loop_until_done(client: EngineCoreClient, outputs: dict): while True: engine_core_outputs = client.get_output().outputs @@ -62,7 +62,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict): break -async def loop_until_done_async(client: EngineCoreClient, outputs: Dict): +async def loop_until_done_async(client: EngineCoreClient, outputs: dict): while True: engine_core_outputs = (await client.get_output_async()).outputs @@ -121,7 +121,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): client.add_request(request) time.sleep(0.01) - outputs: Dict[str, List] = {req_id: [] for req_id in request_ids} + outputs: dict[str, list] = {req_id: [] for req_id in request_ids} loop_until_done(client, outputs) for req_id in request_ids: @@ -207,7 +207,7 @@ async def test_engine_core_client_asyncio(monkeypatch): await client.add_request_async(request) await asyncio.sleep(0.01) - outputs: Dict[str, List] = {req_id: [] for req_id in request_ids} + outputs: dict[str, list] = {req_id: [] for req_id in request_ids} await loop_until_done_async(client, outputs) for req_id in request_ids: diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index de2a39ee9c08..33c884e6de35 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import Dict, List, Optional, Tuple +from typing import Optional import pytest @@ -47,9 +47,9 @@ def vllm_model_apc(vllm_runner, monkeypatch): def _get_test_sampling_params( - prompt_list: List[str], + prompt_list: list[str], seed: Optional[int] = 42, -) -> Tuple[List[SamplingParams], List[int]]: +) -> tuple[list[SamplingParams], list[int]]: """Generate random sampling params for a batch.""" def get_mostly_n_gt1() -> int: @@ -81,7 +81,7 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None: # Validate each request response for out, n in zip(outputs, n_list): - completion_counts: Dict[str, int] = {} + completion_counts: dict[str, int] = {} # Assert correct number of completions assert len(out.outputs) == n, ( f"{len(out.outputs)} completions; {n} expected.") diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 1d47df417dda..0de853ba6e5e 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -2,7 +2,7 @@ import math import time -from typing import Dict, List, Optional +from typing import Optional import pytest @@ -112,12 +112,12 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, def _validate_logprobs( - gen_tokens: Dict[str, List[int]], - gen_logprobs: Dict[str, Optional[SampleLogprobs]], - gen_prompt_logprobs: Dict[str, Optional[PromptLogprobs]], - gen_cumulative_logprob: Dict[str, float], + gen_tokens: dict[str, list[int]], + gen_logprobs: dict[str, Optional[SampleLogprobs]], + gen_prompt_logprobs: dict[str, Optional[PromptLogprobs]], + gen_cumulative_logprob: dict[str, float], dtv: DummyOutputProcessorTestVectors, - request_id_list: List[str], + request_id_list: list[str], num_sample_logprobs: Optional[int], num_prompt_logprobs: Optional[int], ) -> None: diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index 39248ce86f25..02baa4801a47 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -2,7 +2,7 @@ import random from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -61,7 +61,7 @@ def _create_random_top_logprob_test_vector( def _create_random_top_logprob_test_matrix( - shape: Tuple, + shape: tuple, lower: float, upper: float, ) -> torch.Tensor: @@ -90,7 +90,7 @@ def _create_random_top_token_test_vector( lower: int, upper: int, sampled_token_id: int, - adjust_num_logprobs: bool = True) -> Tuple[torch.Tensor, int]: + adjust_num_logprobs: bool = True) -> tuple[torch.Tensor, int]: """Create a random vector of top logprob token indices Use to create fake sample logprobs for testing. The sampled token @@ -141,11 +141,11 @@ def _create_random_top_token_test_vector( def _create_random_top_token_test_matrix( - shape: Tuple[int, int], + shape: tuple[int, int], lower: int, upper: int, - tokens_list: List[int], -) -> Tuple[torch.Tensor, torch.Tensor]: + tokens_list: list[int], +) -> tuple[torch.Tensor, torch.Tensor]: """Create a random matrix of top logprob token indices Use to create fake prompt logprobs for testing. @@ -160,7 +160,7 @@ def _create_random_top_token_test_matrix( upper: upper range of token ids Returns: - Tuple containing: + tuple containing: - 2D num_tokens x num_logprobs+1 torch Tensor of token ids - 1D tensor of ranks of prompt tokens in their respective rows, or random values @@ -206,10 +206,10 @@ def decode_token( def generate_dummy_sample_logprobs( - sampled_tokens_list: List, + sampled_tokens_list: list, num_logprobs: int, tokenizer: PreTrainedTokenizer, -) -> List[Tuple[List[int], List[float], int]]: +) -> list[tuple[list[int], list[float], int]]: """Generate dummy sample logprobs Generate a test data structure which imitates the list of sample logprobs @@ -221,7 +221,7 @@ def generate_dummy_sample_logprobs( tokenizer: model tokenizer to use for detokenization Returns - List of (top token ids vector, logprobs vector, sampled token rank) + list of (top token ids vector, logprobs vector, sampled token rank) Python lists tuples; in each tuple the logprobs and top token ids vectors have the same length which is either `num_logprobs` or `num_logprobs+1`. Sampled token rank is the rank (index+1) of the @@ -253,7 +253,7 @@ def generate_dummy_sample_logprobs( def generate_dummy_prompt_logprobs_tensors( - prompt_tokens_list: List, + prompt_tokens_list: list, num_logprobs: int, tokenizer: PreTrainedTokenizer, ) -> LogprobsTensors: @@ -269,7 +269,7 @@ def generate_dummy_prompt_logprobs_tensors( tokenizer: model tokenizer to use for detokenization Returns - Single Tuple of (logprobs matrix, top token ids matrix) torch Tensor, + Single tuple of (logprobs matrix, top token ids matrix) torch Tensor, where both matrices have dimensions num_prompt_tokens x num_logprobs """ @@ -301,19 +301,19 @@ class DummyOutputProcessorTestVectors: tokenizer: GeneralTokenizerType tokenizer_group: BaseTokenizerGroup vllm_config: EngineArgs - full_tokens: List[List[int]] # Prompt + generated tokens - prompt_tokens: List[List[int]] - generation_tokens: List[List[int]] + full_tokens: list[list[int]] # Prompt + generated tokens + prompt_tokens: list[list[int]] + generation_tokens: list[list[int]] # Each request is associated with a tuple of # (top tokens, top logprobs, ranks) prompt logprobs tensors - prompt_logprobs: List[LogprobsTensors] + prompt_logprobs: list[LogprobsTensors] # Each request is associated with a sample logprobs; a request's # sample logprobs are a list of (top tokens, top logprobs, ranks) # sample logprobs tensors at each sequence position - generation_logprobs: List[List[Tuple[List[int], List[float], int]]] - prompt_strings: List[str] - prompt_strings_len: List[int] - generation_strings: List[str] + generation_logprobs: list[list[tuple[list[int], list[float], int]]] + prompt_strings: list[str] + prompt_strings_len: list[int] + generation_strings: list[str] class MockEngineCore: @@ -321,18 +321,18 @@ class MockEngineCore: def __init__( self, - tokens_list: List[List[int]], + tokens_list: list[list[int]], # For each request, for each sampled token offset, # a tuple of # (list of topk token ids, list of sample logprob vals, rank) - generated_logprobs_raw: Optional[List[List[Tuple[List[int], - List[float], + generated_logprobs_raw: Optional[list[list[tuple[list[int], + list[float], int]]]] = None, # For each request, a tuple of # (prompt logprob val matrix, prompt logprob tok id matrix); # each matrix has dimensions # (num prompt toks) x (num prompt logprobs+1) - prompt_logprobs_raw: Optional[List[LogprobsTensors]] = None, + prompt_logprobs_raw: Optional[list[LogprobsTensors]] = None, ) -> None: self.tokens_list = tokens_list self.current_idx = 0 @@ -341,7 +341,7 @@ def __init__( self.prompt_logprobs_raw = prompt_logprobs_raw self.do_prompt_logprobs = prompt_logprobs_raw is not None - def get_outputs(self) -> List[EngineCoreOutput]: + def get_outputs(self) -> list[EngineCoreOutput]: do_logprobs = self.do_logprobs do_prompt_logprobs = self.do_prompt_logprobs token_idx = self.current_idx diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 35e059ccb548..171c84176eae 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import re -from typing import Dict, List, Optional +from typing import Optional import openai # use the official client for correctness check import pytest @@ -193,7 +193,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int]): - params: Dict = { + params: dict = { "prompt": ["A robot may not injure another robot", "My name is"], "model": model_name, } @@ -237,7 +237,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, max_tokens=5, temperature=0.0, stream=True) - chunks: List[str] = [] + chunks: list[str] = [] finish_reason_count = 0 async for chunk in stream: chunks.append(chunk.choices[0].text) @@ -278,7 +278,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI, num_completions = len(completion.choices) assert num_completions == n, ( f"Num completions {num_completions} but expected {n}.") - completion_repeats: Dict[str, int] = {} + completion_repeats: dict[str, int] = {} for idx, choice in enumerate(completion.choices): # Assert correct completion index & some finish reason. assert choice.index == idx, ( @@ -321,7 +321,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): temperature=0.95, stream=True, seed=42) - chunks: List[List[str]] = [[] for i in range(n)] + chunks: list[list[str]] = [[] for i in range(n)] finish_reason_count = 0 async for chunk in stream: index = chunk.choices[0].index @@ -332,7 +332,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): # Assert `n` completions with correct finish reasons assert finish_reason_count == n, ( f"Expected {n} completions with valid indices and finish_reason.") - completion_repeats: Dict[str, int] = {} + completion_repeats: dict[str, int] = {} for chunk in chunks: chunk_len = len(chunk) # Assert correct number of completion tokens diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index a26a8c4ed074..d564a8c2e7a7 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import itertools -from typing import List, Tuple import pytest import torch @@ -46,8 +45,8 @@ def hf_model(hf_runner): def _repeat_logprob_config( test_prompts, - logprob_prompt_logprob_list: List[Tuple], -) -> List[Tuple]: + logprob_prompt_logprob_list: list[tuple], +) -> list[tuple]: """Ensure each test prompt has a logprob config. A logprob config specifies the optional (i.e. @@ -74,7 +73,7 @@ def _repeat_logprob_config( tuples Returns: - List of + list of (optional num sample logprob,optional num prompt logprob) tuples which is either identical to `logprob_prompt_logprob_list`, or else repeats @@ -177,7 +176,7 @@ def _test_case_get_logprobs_and_prompt_logprobs( for r in range(1, num_top_logprobs + 1)) output_text = vllm_result.outputs[0].text - output_string_from_most_likely_tokens_lst: List[str] = [] + output_string_from_most_likely_tokens_lst: list[str] = [] for top_logprobs in vllm_result.outputs[0].logprobs: top_logprob = next(iter(top_logprobs.values())) output_string_from_most_likely_tokens_lst.append( diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index f00585b40ba3..b1862455d0ec 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List import pytest import torch @@ -13,7 +12,7 @@ def sampler(): return RejectionSampler() -def create_logits_tensor(token_ids: List[int], +def create_logits_tensor(token_ids: list[int], vocab_size: int = 100) -> torch.Tensor: """Helper function to create logits tensor that will produce desired token ids on argmax""" @@ -23,7 +22,7 @@ def create_logits_tensor(token_ids: List[int], return logits -def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: +def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata: batch_size = len(spec_tokens) return SamplingMetadata( temperature=torch.tensor([]), @@ -106,7 +105,7 @@ def test_single_token_sequence(sampler): def test_empty_sequence(sampler): """Test handling empty sequence of speculated tokens""" - spec_tokens: List[List[int]] = [[]] + spec_tokens: list[list[int]] = [[]] output_tokens = [5] # Just the bonus token metadata = create_sampling_metadata(spec_tokens) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 435c1b7b5fda..b702d9ed7f83 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Optional, Set, Tuple +from typing import Optional import numpy as np import pytest @@ -32,7 +32,7 @@ def _create_penalty_tensor(batch_size: int, penalty_value: float, def _create_prompt_tokens_tensor( - prompt_token_ids: List[List[int]], + prompt_token_ids: list[list[int]], vocab_size: int, device: torch.device, ) -> torch.Tensor: @@ -49,8 +49,8 @@ def _create_logit_bias( batch_size: int, vocab_size: int, bias_value: float, -) -> List[Optional[Dict[int, float]]]: - res: List[Optional[Dict[int, float]]] = [] +) -> list[Optional[dict[int, float]]]: + res: list[Optional[dict[int, float]]] = [] for i in range(batch_size): logit_bias = {min(i, vocab_size - 1): bias_value} res.append(logit_bias) @@ -83,8 +83,8 @@ def _create_default_sampling_metadata( vocab_size: int, device: torch.device, ) -> SamplingMetadata: - output_token_ids: List[List[int]] = [] - prompt_token_ids: List[List[int]] = [] + output_token_ids: list[list[int]] = [] + prompt_token_ids: list[list[int]] = [] for _ in range(batch_size): output_token_ids.append( np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) @@ -118,8 +118,8 @@ def _create_default_sampling_metadata( def _generate_min_token_penalties_and_stop_tokens( num_output_tokens: int, batch_size: int, vocab_size: int, - batch_indices_for_min_token_penalty: List[int] -) -> Dict[int, Tuple[int, Set[int]]]: + batch_indices_for_min_token_penalty: list[int] +) -> dict[int, tuple[int, set[int]]]: """ Generates and returns a dict of minimum token penalties and corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each @@ -130,7 +130,7 @@ def _generate_min_token_penalties_and_stop_tokens( and a random set of stop token IDs is created. Otherwise, a lower `min_tokens` value is assigned, and the stop token IDs set is empty. """ - min_tokens: Dict[int, Tuple[int, Set[int]]] = {} + min_tokens: dict[int, tuple[int, set[int]]] = {} for index in range(batch_size): if index in batch_indices_for_min_token_penalty: min_tokens[index] = ( @@ -147,7 +147,7 @@ def _generate_min_token_penalties_and_stop_tokens( def _create_weighted_output_token_list( batch_size: int, - vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]: + vocab_size: int) -> tuple[list[list[int]], list[list[int]]]: """ Creates an output token list where each token occurs a distinct number of times. @@ -157,7 +157,7 @@ def _create_weighted_output_token_list( list, each with a different frequency. Returns: - Tuple[List[List[int]], List[List[int]]]: + tuple[list[list[int]], list[list[int]]]: - The first element is the output token list, where each sublist corresponds to a batch and contains tokens with weighted frequencies. @@ -165,8 +165,8 @@ def _create_weighted_output_token_list( batch, ordered by their frequency in the corresponding output list. """ - output_token_ids: List[List[int]] = [] - sorted_token_ids_in_output: List[List[int]] = [] + output_token_ids: list[list[int]] = [] + sorted_token_ids_in_output: list[list[int]] = [] for _ in range(batch_size): distinct_token_ids = np.random.choice(vocab_size, size=np.random.randint(1, 10), diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index e1465b123966..c69d0d49c46f 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import re -from typing import List, Tuple from vllm import CompletionOutput -def get_test_batch(batch_logprobs_composition: str) -> List[Tuple]: +def get_test_batch(batch_logprobs_composition: str) -> list[tuple]: """Generate logprobs configs for a batch of requests A given request's logprobs configuration is (1) num_sample_logprobs and (2) @@ -32,7 +31,7 @@ def get_test_batch(batch_logprobs_composition: str) -> List[Tuple]: Returns: - List of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs]) + list of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs]) tuples """ if batch_logprobs_composition == "NONE": diff --git a/tests/v1/test_utils.py b/tests/v1/test_utils.py index 9b669ae00660..b68f08385866 100644 --- a/tests/v1/test_utils.py +++ b/tests/v1/test_utils.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import torch from vllm.v1.utils import bind_kv_cache @@ -22,7 +20,7 @@ def test_bind_kv_cache(): 'layers.2.self_attn': torch.zeros((1, )), 'layers.3.self_attn': torch.zeros((1, )), } - runner_kv_caches: List[torch.Tensor] = [] + runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache(kv_cache, ctx, runner_kv_caches) assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[ 'layers.0.self_attn'] @@ -52,7 +50,7 @@ def test_bind_kv_cache_non_attention(): 'model.layers.28.attn': torch.zeros((1, )), } - runner_kv_caches: List[torch.Tensor] = [] + runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache(kv_cache, ctx, runner_kv_caches) assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[ diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 327370e71fff..72ec73701159 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Optional, Set, Tuple +from typing import Optional import numpy as np import pytest @@ -22,22 +22,22 @@ def _remove_requests( input_batch: InputBatch, batch_size: int, - reqs: List[CachedRequestState]) -> Tuple[Set[str], List[int]]: + reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]: """ - Remove some requests randomly from the batch and returns a Tuple + Remove some requests randomly from the batch and returns a tuple of 1) set of request removed 2) indices of the requests removed ordered in descending order """ num_reqs_to_remove = np.random.randint(0, batch_size) - req_indices_to_remove: Set[int] = set() + req_indices_to_remove: set[int] = set() for _ in range(num_reqs_to_remove): req_index_to_remove = np.random.randint(0, batch_size) req_indices_to_remove.add(req_index_to_remove) req_indices_to_remove_list = list(req_indices_to_remove) req_indices_to_remove_list.sort(reverse=True) - req_ids_to_remove: Set[str] = set() + req_ids_to_remove: set[str] = set() for index in req_indices_to_remove: input_batch.remove_request(reqs[index].req_id) req_ids_to_remove.add(reqs[index].req_id) @@ -45,9 +45,9 @@ def _remove_requests( def _construct_expected_sampling_metadata( - reqs: List[CachedRequestState], - req_ids_retained: Set[int], - req_id_index_in_input_batch: Dict[str, int], + reqs: list[CachedRequestState], + req_ids_retained: set[int], + req_id_index_in_input_batch: dict[str, int], device: torch.device, ) -> SamplingMetadata: """ @@ -55,8 +55,8 @@ def _construct_expected_sampling_metadata( batch. """ num_reqs = len(req_ids_retained) - output_token_ids: List[List[int]] = [list() for _ in range(num_reqs)] - prompt_token_ids: List[List[int]] = [list() for _ in range(num_reqs)] + output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)] + prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)] presence_penalties = [0.0 for _ in range(num_reqs)] frequency_penalties = [0.0 for _ in range(num_reqs)] repetition_penalties = [1.0 for _ in range(num_reqs)] @@ -191,7 +191,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): pin_memory=is_pin_memory_available(), vocab_size=1024, ) - reqs: List[CachedRequestState] = [] + reqs: list[CachedRequestState] = [] req_id_reqs = {} req_id_output_token_ids = {} # Add requests diff --git a/tests/vllm_test_utils/vllm_test_utils/blame.py b/tests/vllm_test_utils/vllm_test_utils/blame.py index 392fd2705fb2..3b25980cb946 100644 --- a/tests/vllm_test_utils/vllm_test_utils/blame.py +++ b/tests/vllm_test_utils/vllm_test_utils/blame.py @@ -4,7 +4,8 @@ import dataclasses import sys import traceback -from typing import Callable, Generator +from collections.abc import Generator +from typing import Callable @dataclasses.dataclass diff --git a/tests/vllm_test_utils/vllm_test_utils/monitor.py b/tests/vllm_test_utils/vllm_test_utils/monitor.py index 44d45f262105..27077f13de24 100644 --- a/tests/vllm_test_utils/vllm_test_utils/monitor.py +++ b/tests/vllm_test_utils/vllm_test_utils/monitor.py @@ -4,7 +4,8 @@ import dataclasses import sys import traceback -from typing import Callable, Generator, Generic, TypeVar +from collections.abc import Generator +from typing import Callable, Generic, TypeVar _T = TypeVar("_T") diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 0ce0465a704c..3e237aacc8c6 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import itertools -from typing import List import pytest import torch @@ -43,7 +42,7 @@ def test_empty_seq_group(): enable_chunked_prefill=False, enforce_eager=True, ) - seq_group_metadata_list: List[SequenceGroupMetadata] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) ( @@ -103,9 +102,9 @@ def test_prepare_prompt(batch_size): enforce_eager=True, ) - seq_lens: List[int] = [] - encoder_seq_lens: List[int] = [] - seq_group_metadata_list: List[SequenceGroupMetadata] = [] + seq_lens: list[int] = [] + encoder_seq_lens: list[int] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] block_tables = {0: [1]} cross_block_table = [2] for i in range(batch_size): @@ -295,9 +294,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): enforce_eager=True, ) - seq_lens: List[int] = [] - encoder_seq_lens: List[int] = [] - seq_group_metadata_list: List[SequenceGroupMetadata] = [] + seq_lens: list[int] = [] + encoder_seq_lens: list[int] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] block_tables = { 0: [1], 1: [3] @@ -503,9 +502,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): } if multiple_seqs_per_seq_group else { 0: [1] } - seq_lens: List[int] = [] - encoder_seq_lens: List[int] = [] - seq_group_metadata_list: List[SequenceGroupMetadata] = [] + seq_lens: list[int] = [] + encoder_seq_lens: list[int] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] cross_block_table = [2] expanded_batch_size = 0 diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index eb341fb1b293..a41fc52170fe 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import dataclasses -from typing import List, Tuple, Type import torch @@ -27,15 +26,15 @@ def get_impl_cls(): raise NotImplementedError @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: + def get_metadata_cls() -> type["AttentionMetadata"]: return AttentionMetadata @staticmethod - def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + def get_builder_cls() -> type["AttentionMetadataBuilder"]: return AttentionMetadataBuilder @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: + def get_state_cls() -> type["CommonAttentionState"]: return CommonAttentionState @staticmethod @@ -44,7 +43,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: raise NotImplementedError @staticmethod @@ -57,7 +56,7 @@ def swap_blocks( @staticmethod def copy_blocks( - kv_caches: List[torch.Tensor], + kv_caches: list[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: pass diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 3f9a0d6faa61..b8ba69b0dd8f 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest import torch @@ -42,8 +40,8 @@ def test_prepare_prompt(batch_size): enable_chunked_prefill=False, ) - seq_lens: List[int] = [] - seq_group_metadata_list: List[SequenceGroupMetadata] = [] + seq_lens: list[int] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block @@ -159,8 +157,8 @@ def test_prepare_decode_cuda_graph(batch_size): enable_chunked_prefill=False, ) - context_lens: List[int] = [] - seq_group_metadata_list: List[SequenceGroupMetadata] = [] + context_lens: list[int] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] # Assume each seq group finishes prefill. for i in range(batch_size): # make sure all tokens fit into one block @@ -265,7 +263,7 @@ def test_empty_seq_group(): dtype="float16", enforce_eager=False, ) - seq_group_metadata_list: List[SequenceGroupMetadata] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) input_tokens, input_positions, attn_metadata = ( @@ -315,10 +313,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): ) # Add prefill requests. - seq_lens: List[int] = [] - seq_group_metadata_list: List[SequenceGroupMetadata] = [] - prefill_metadata_list: List[SequenceGroupMetadata] = [] - decode_metadata_list: List[SequenceGroupMetadata] = [] + seq_lens: list[int] = [] + seq_group_metadata_list: list[SequenceGroupMetadata] = [] + prefill_metadata_list: list[SequenceGroupMetadata] = [] + decode_metadata_list: list[SequenceGroupMetadata] = [] block_tables = {0: [1]} prefill_batch_size = batch_size // 2 decode_batch_size = batch_size - prefill_batch_size diff --git a/tools/profiler/print_layerwise_table.py b/tools/profiler/print_layerwise_table.py index adbb7301bfc7..9601b578eb97 100644 --- a/tools/profiler/print_layerwise_table.py +++ b/tools/profiler/print_layerwise_table.py @@ -2,13 +2,12 @@ import argparse import json -from typing import Dict from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry from vllm.profiler.utils import TablePrinter, indent_string -def flatten_entries(entry_cls, profile_dict: Dict): +def flatten_entries(entry_cls, profile_dict: dict): entries_and_depth = [] def get_entries(node, curr_depth=0): diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index c527cdbe0225..8ec3dfc97a73 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -6,7 +6,7 @@ import math import os from pathlib import Path -from typing import Any, List, Optional, Tuple +from typing import Any, Optional import matplotlib.pyplot as plt import pandas as pd @@ -24,7 +24,7 @@ def largest_dist_from_leaf(node: dict, depth: int = 0): def get_entries_at_depth(depth: int, - entries_and_traces: List[Tuple[Any, Any]], + entries_and_traces: list[tuple[Any, Any]], node: dict, curr_depth: int = 0, trace=()): @@ -48,9 +48,9 @@ def get_entries_at_depth(depth: int, trace=trace) -def fold_nodes(root: dict, nodes_to_fold: List[str]): +def fold_nodes(root: dict, nodes_to_fold: list[str]): - stack: List[dict] = [root] + stack: list[dict] = [root] while len(stack) != 0: node = stack.pop() if node['entry']['name'] in nodes_to_fold: @@ -427,12 +427,12 @@ def main( plot_metric: str, make_names_unique: bool, top_k: int, - json_nodes_to_fold: List[str]): + json_nodes_to_fold: list[str]): - def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame: + def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame: def get_entries_and_traces(key: str): - entries_and_traces: List[Tuple[Any, Any]] = [] + entries_and_traces: list[tuple[Any, Any]] = [] for root in profile_json[key]["summary_stats"]: # Fold nodes in the traces as per user request. i.e. simply # make the requested nodes leaf-nodes. diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 373f92a52a19..3c822028426e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2,7 +2,7 @@ import contextlib import importlib -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import torch import torch.library @@ -198,7 +198,7 @@ def rms_norm_dynamic_per_token_quant( quant_dtype: torch.dtype, scale_ub: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=quant_dtype) scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, @@ -347,7 +347,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, @register_fake("_C::aqlm_gemm") def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, - codebook_partition_sizes: List[int], + codebook_partition_sizes: list[int], bias: Optional[torch.Tensor]) -> torch.Tensor: out_features = codes.size(0) * codebooks.size(2) flat_input = input.reshape((-1, input.size(-1))) @@ -363,7 +363,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, @register_fake("_C::aqlm_dequant") def _aqlm_dequant_fake( codes: torch.Tensor, codebooks: torch.Tensor, - codebook_partition_sizes: List[int]) -> torch.Tensor: + codebook_partition_sizes: list[int]) -> torch.Tensor: in_features = codes.size(1) * 8 out_features = codes.size(0) return torch.empty((out_features, in_features), @@ -554,7 +554,7 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: def cutlass_sparse_compress(a: torch.Tensor) \ - -> Tuple[torch.Tensor, torch.Tensor]: + -> tuple[torch.Tensor, torch.Tensor]: """ Compresses a sparse matrix for use with Cutlass sparse operations. @@ -571,7 +571,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \ - `torch.float16` Returns: - Tuple[torch.Tensor, torch.Tensor]: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`. - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation. @@ -646,14 +646,14 @@ def cutlass_scaled_sparse_mm( # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, - codebook_partition_sizes: List[int], + codebook_partition_sizes: list[int], bias: Optional[torch.Tensor]) -> torch.Tensor: return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, codebook_partition_sizes, bias) def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, - codebook_partition_sizes: List[int]) -> torch.Tensor: + codebook_partition_sizes: list[int]) -> torch.Tensor: return torch.ops._C.aqlm_dequant(codes, codebooks, codebook_partition_sizes) @@ -738,7 +738,7 @@ def machete_supported_schedules( group_zeros_type: Optional[torch.dtype] = None, channel_scales_type: Optional[torch.dtype] = None, token_scales_type: Optional[torch.dtype] = None, - out_type: Optional[torch.dtype] = None) -> List[str]: + out_type: Optional[torch.dtype] = None) -> list[str]: return torch.ops._C.machete_supported_schedules( a_type, b_type.id, group_scales_type, group_zeros_type, channel_scales_type, token_scales_type, out_type) @@ -783,7 +783,7 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: # fp4 def scaled_fp4_quant( input: torch.Tensor, - input_global_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + input_global_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP4 and return quantized tensor and scale. @@ -798,7 +798,7 @@ def scaled_fp4_quant( input_global_scale: A scalar scaling factor for the entire tensor. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every two values are packed into a uint8 and float8_e4m3 scaling factors in the sizzled layout. """ @@ -845,7 +845,7 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -866,12 +866,12 @@ def scaled_fp8_quant( in the dynamic quantization case. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. """ # This code assumes batch_dim and num_tokens are flattened assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape + shape: Union[tuple[int, int], torch.Size] = input.shape # For rocm, the output fp8 dtype is torch.float_e3m3fnuz out_dtype: torch.dtype = torch.float8_e4m3fnuz \ if current_platform.is_rocm() else torch.float8_e4m3fn @@ -903,7 +903,7 @@ def allspark_repack_weight( scale: torch.Tensor, zero_point: Optional[torch.Tensor] = None, has_zp: bool = False -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format for Ampere W8A16 Fused Gemm kernel @@ -917,7 +917,7 @@ def allspark_repack_weight( if use asymmetric quantization, has_zp = True. Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : rearranged weight, scale, and optionally zero_point. """ K = qweight.shape[0] @@ -964,7 +964,7 @@ def scaled_int8_quant( scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -977,7 +977,7 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: @@ -1165,13 +1165,13 @@ def concat_and_cache_mla( scale) -def copy_blocks(key_caches: List[torch.Tensor], - value_caches: List[torch.Tensor], +def copy_blocks(key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None: torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) -def copy_blocks_mla(kv_caches: List[torch.Tensor], +def copy_blocks_mla(kv_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None: torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping) @@ -1209,7 +1209,7 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int: # custom ar -def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor, +def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor, rank: int, full_nvlink: bool) -> int: return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink) @@ -1229,16 +1229,16 @@ def meta_size() -> int: return torch.ops._C_custom_ar.meta_size() -def register_buffer(fa: int, ipc_tensors: List[int]) -> None: +def register_buffer(fa: int, ipc_tensors: list[int]) -> None: return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) -def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: +def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]: return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) -def register_graph_buffers(fa: int, handles: List[List[int]], - offsets: List[List[int]]) -> None: +def register_graph_buffers(fa: int, handles: list[list[int]], + offsets: list[list[int]]) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) @@ -1246,7 +1246,7 @@ def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, num_heads_k: int, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: cache_seqlens: (batch_size), dtype torch.int32. @@ -1272,7 +1272,7 @@ def flash_mla_with_kvcache( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: q: (batch_size, seq_len_q, num_heads_q, head_dim). diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index ccb67baa5338..a7b909d20634 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple +from typing import Optional import torch @@ -18,7 +18,7 @@ class ipex_ops: @staticmethod def _reshape_activation_tensor( - x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: num = x.size(0) d = x.size(1) // 2 x = x.reshape(num, 2, d) @@ -213,8 +213,8 @@ def reshape_and_cache( key, value, key_cache, value_cache, slot_mapping) @staticmethod - def copy_blocks(key_caches: List[torch.Tensor], - value_caches: List[torch.Tensor], + def copy_blocks(key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None: torch.xpu.copy_blocks( # type: ignore key_caches, diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 97b2b630fc3e..5d4ebdb7acbc 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from vllm.sequence import Logprob @@ -17,14 +17,14 @@ class BeamSearchSequence: about to be returned to the user. """ # The tokens includes the prompt. - tokens: List[int] - logprobs: List[Dict[int, Logprob]] + tokens: list[int] + logprobs: list[dict[int, Logprob]] cum_logprob: float = 0.0 text: Optional[str] = None finish_reason: Optional[str] = None stop_reason: Union[int, str, None] = None multi_modal_data: Optional["MultiModalDataDict"] = None - mm_processor_kwargs: Optional[Dict[str, Any]] = None + mm_processor_kwargs: Optional[dict[str, Any]] = None @dataclass @@ -33,20 +33,20 @@ class BeamSearchOutput: It contains the list of the best beam search sequences. The length of the list is equal to the beam width. """ - sequences: List[BeamSearchSequence] + sequences: list[BeamSearchSequence] class BeamSearchInstance: - def __init__(self, prompt_tokens: List[int]): - self.beams: List[BeamSearchSequence] = [ + def __init__(self, prompt_tokens: list[int]): + self.beams: list[BeamSearchSequence] = [ BeamSearchSequence(tokens=prompt_tokens, logprobs=[]) ] - self.completed: List[BeamSearchSequence] = [] + self.completed: list[BeamSearchSequence] = [] def get_beam_search_score( - tokens: List[int], + tokens: list[int], cumulative_logprob: float, eos_token_id: int, length_penalty: float = 1.0, diff --git a/vllm/config.py b/vllm/config.py index 54ed38418dd4..f87d2d6e82cf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -7,13 +7,14 @@ import json import sys import warnings +from collections import Counter +from collections.abc import Mapping from contextlib import contextmanager from dataclasses import dataclass, field, replace from importlib.util import find_spec from pathlib import Path -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict, - Final, List, Literal, Mapping, Optional, Protocol, Set, - Tuple, Type, Union) +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, + Optional, Protocol, Union) import torch from pydantic import BaseModel, Field, PrivateAttr @@ -67,20 +68,20 @@ RunnerType = Literal["generate", "pooling", "draft", "transcription"] -_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = { +_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = { "generate": ["generate"], "pooling": ["embed", "classify", "score", "reward"], "draft": ["draft"], "transcription": ["transcription"], } -_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = { +_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = { task: runner for runner, tasks in _RUNNER_TASKS.items() for task in tasks } -HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig], +HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig], PretrainedConfig]] @@ -92,7 +93,7 @@ def compute_hash(self) -> str: class SupportsMetricsInfo(Protocol): - def metrics_info(self) -> Dict[str, str]: + def metrics_info(self) -> dict[str, str]: ... @@ -209,7 +210,7 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: List[Any] = [] + factors: list[Any] = [] factors.append(self.model) factors.append(self.dtype) factors.append(self.quantization) @@ -233,7 +234,7 @@ def __init__( allowed_local_media_path: str = "", revision: Optional[str] = None, code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict[str, Any]] = None, rope_theta: Optional[float] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, @@ -244,19 +245,19 @@ def __init__( max_logprobs: int = 20, disable_sliding_window: bool = False, skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, + served_model_name: Optional[Union[str, list[str]]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, use_async_output_proc: bool = True, config_format: ConfigFormat = ConfigFormat.AUTO, hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, disable_mm_preprocessor_cache: bool = False, - override_neuron_config: Optional[Dict[str, Any]] = None, + override_neuron_config: Optional[dict[str, Any]] = None, override_pooler_config: Optional["PoolerConfig"] = None, logits_processor_pattern: Optional[str] = None, generation_config: Optional[str] = None, enable_sleep_mode: bool = False, - override_generation_config: Optional[Dict[str, Any]] = None, + override_generation_config: Optional[dict[str, Any]] = None, model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, ) -> None: self.model = model @@ -283,7 +284,7 @@ def __init__( hf_overrides_fn = None if rope_scaling is not None: - hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling} + hf_override: dict[str, Any] = {"rope_scaling": rope_scaling} hf_overrides_kw.update(hf_override) msg = ("`--rope-scaling` will be removed in a future release. " f"'Please instead use `--hf-overrides '{hf_override!r}'`") @@ -505,8 +506,8 @@ def _verify_tokenizer_mode(self) -> None: def _get_preferred_task( self, - architectures: List[str], - supported_tasks: Set[_ResolvedTask], + architectures: list[str], + supported_tasks: set[_ResolvedTask], ) -> Optional[_ResolvedTask]: model_id = self.model if get_pooling_config(model_id, self.revision): @@ -516,7 +517,7 @@ def _get_preferred_task( if self.registry.is_transcription_model(architectures): return "transcription" - suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [ + suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [ # Other models follow this pattern ("ForCausalLM", "generate"), ("ForConditionalGeneration", "generate"), @@ -537,27 +538,27 @@ def _get_preferred_task( def _resolve_task( self, task_option: Union[TaskOption, Literal["draft"]], - ) -> Tuple[Set[_ResolvedTask], _ResolvedTask]: + ) -> tuple[set[_ResolvedTask], _ResolvedTask]: if task_option == "draft": return {"draft"}, "draft" registry = self.registry architectures = self.architectures - runner_support: Dict[RunnerType, bool] = { + runner_support: dict[RunnerType, bool] = { # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them "transcription": registry.is_transcription_model(architectures), "generate": registry.is_text_generation_model(architectures), "pooling": registry.is_pooling_model(architectures), } - supported_runner_types_lst: List[RunnerType] = [ + supported_runner_types_lst: list[RunnerType] = [ runner_type for runner_type, is_supported in runner_support.items() if is_supported ] - supported_tasks_lst: List[_ResolvedTask] = [ + supported_tasks_lst: list[_ResolvedTask] = [ task for runner_type in supported_runner_types_lst for task in _RUNNER_TASKS[runner_type] ] @@ -767,7 +768,7 @@ def verify_with_parallel_config( self.use_async_output_proc = False def get_hf_config_sliding_window( - self) -> Union[Optional[int], List[Optional[int]]]: + self) -> Union[Optional[int], list[Optional[int]]]: """Get the sliding window size, or None if disabled.""" # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in @@ -778,7 +779,7 @@ def get_hf_config_sliding_window( return None return getattr(self.hf_text_config, "sliding_window", None) - def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]: + def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]: """Get the sliding window size, or None if disabled. """ # If user disables sliding window, return None. @@ -888,7 +889,7 @@ def get_num_attention_heads(self, return num_heads // parallel_config.tensor_parallel_size def get_layers_start_end_indices( - self, parallel_config: "ParallelConfig") -> Tuple[int, int]: + self, parallel_config: "ParallelConfig") -> tuple[int, int]: from vllm.distributed.utils import get_pp_indices if self.hf_text_config.model_type == "deepseek_mtp": total_num_hidden_layers = getattr(self.hf_text_config, @@ -949,7 +950,7 @@ def get_multimodal_config(self) -> "MultiModalConfig": return self.multimodal_config - def try_get_generation_config(self) -> Dict[str, Any]: + def try_get_generation_config(self) -> dict[str, Any]: if self.generation_config is None or self.generation_config == "auto": config = try_get_generation_config( self.hf_config_path or self.model, @@ -967,7 +968,7 @@ def try_get_generation_config(self) -> Dict[str, Any]: return config.to_diff_dict() - def get_diff_sampling_param(self) -> Dict[str, Any]: + def get_diff_sampling_param(self) -> dict[str, Any]: """ This method returns a dictionary containing the parameters that differ from the default sampling parameters, but only @@ -975,7 +976,7 @@ def get_diff_sampling_param(self) -> Dict[str, Any]: set, an empty dictionary is returned. Returns: - Dict[str, Any]: A dictionary with the differing sampling + dict[str, Any]: A dictionary with the differing sampling parameters if `generation_config` is set, otherwise an empty dictionary. """ @@ -1032,7 +1033,7 @@ def use_mla(self) -> bool: return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE @property - def supported_runner_types(self) -> Set[RunnerType]: + def supported_runner_types(self) -> set[RunnerType]: return {_TASK_RUNNER[task] for task in self.supported_tasks} @property @@ -1075,7 +1076,7 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: List[Any] = [] + factors: list[Any] = [] factors.append(self.cache_dtype) # `cpu_offload_gb` does not use `torch.compile` yet. hash_str = hashlib.md5(str(factors).encode()).hexdigest() @@ -1183,7 +1184,7 @@ class TokenizerPoolConfig: pool type. """ pool_size: int - pool_type: Union[str, Type["BaseTokenizerGroup"]] + pool_type: Union[str, type["BaseTokenizerGroup"]] extra_config: dict def compute_hash(self) -> str: @@ -1200,7 +1201,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -1214,7 +1215,7 @@ def __post_init__(self): @classmethod def create_config( cls, tokenizer_pool_size: int, - tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]], + tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]], tokenizer_pool_extra_config: Optional[Union[str, dict]] ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. @@ -1285,7 +1286,7 @@ class LoadConfig: download_dir: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field( default_factory=dict) - ignore_patterns: Optional[Union[List[str], str]] = None + ignore_patterns: Optional[Union[list[str], str]] = None def compute_hash(self) -> str: """ @@ -1301,7 +1302,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -1359,7 +1360,7 @@ class ParallelConfig: # to "ray" if Ray is installed and fail otherwise. Note that tpu # and hpu only support Ray for distributed inference. distributed_executor_backend: Optional[Union[str, - Type["ExecutorBase"]]] = None + type["ExecutorBase"]]] = None # the full name of the worker class to use. If "auto", the worker class # will be determined based on the platform. @@ -1423,7 +1424,7 @@ def compute_hash(self): excluding anything before input ids/embeddings and after the final hidden states. """ - factors: List[Any] = [] + factors: list[Any] = [] factors.append(self.pipeline_parallel_size) factors.append(self.tensor_parallel_size) return hashlib.sha256(str(factors).encode()).hexdigest() @@ -1600,7 +1601,7 @@ class SchedulerConfig: # scheduler class or path. "vllm.core.scheduler.Scheduler" (default) # or "mod.custom_class". - scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler" + scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler" def compute_hash(self) -> str: """ @@ -1616,7 +1617,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -1752,7 +1753,7 @@ def compute_hash(self) -> str: # no factors to consider. # the device/platform information will be summarized # by torch/vllm automatically. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -1798,7 +1799,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # spec decode does not use `torch.compile` yet. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -2261,7 +2262,7 @@ class LoRAConfig: lora_extra_vocab_size: int = 256 # This is a constant. lora_vocab_padding_size: ClassVar[int] = 256 - long_lora_scaling_factors: Optional[Tuple[float]] = None + long_lora_scaling_factors: Optional[tuple[float]] = None bias_enabled: bool = False def compute_hash(self) -> str: @@ -2278,7 +2279,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # LoRA is not compatible with `torch.compile` . - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -2350,7 +2351,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -2395,7 +2396,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -2431,7 +2432,7 @@ class PoolerConfig: are returned. """ - returned_token_ids: Optional[List[int]] = None + returned_token_ids: Optional[list[int]] = None """ A list of indices for the vocabulary dimensions to be extracted, such as the token IDs of ``good_token`` and ``bad_token`` in the @@ -2452,7 +2453,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -2469,7 +2470,7 @@ def from_json(json_str: str) -> "PoolerConfig": "bfloat16": torch.bfloat16, } -_ROCM_NOT_SUPPORTED_DTYPE: List[str] = [] # +_ROCM_NOT_SUPPORTED_DTYPE: list[str] = [] # def _get_and_verify_dtype( @@ -2558,7 +2559,7 @@ def _get_and_verify_max_len( hf_config: PretrainedConfig, max_model_len: Optional[int], disable_sliding_window: bool, - sliding_window_len: Optional[Union[int, List[Optional[int]]]], + sliding_window_len: Optional[Union[int, list[Optional[int]]]], spec_target_max_model_len: Optional[int] = None, encoder_config: Optional[Any] = None, ) -> int: @@ -2684,7 +2685,7 @@ def _get_and_verify_max_len( def get_min_sliding_window( - sliding_window: Union[int, List[Optional[int]]]) -> int: + sliding_window: Union[int, list[Optional[int]]]) -> int: if isinstance(sliding_window, list): return min(s for s in sliding_window if s is not None) @@ -2692,7 +2693,7 @@ def get_min_sliding_window( def get_served_model_name(model: str, - served_model_name: Optional[Union[str, List[str]]]): + served_model_name: Optional[Union[str, list[str]]]): """ If the input is a non-empty list, the first model_name in `served_model_name` is taken. @@ -2731,7 +2732,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -2774,7 +2775,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -2833,7 +2834,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: List[Any] = [] + factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str @@ -2930,7 +2931,7 @@ class CompilationConfig(BaseModel): torch.compile will handle cudagraph capture logic in the future. - cudagraph_capture_sizes: sizes to capture cudagraph. - None (default): capture sizes are inferred from vllm config. - - List[int]: capture sizes are specified as given. + - list[int]: capture sizes are specified as given. - cudagraph_num_of_warmups: number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. Only after that, the execution will be recorded, and the recorded @@ -2972,17 +2973,17 @@ class CompilationConfig(BaseModel): debug_dump_path: str = "" cache_dir: str = "" backend: str = "" - custom_ops: List[str] = Field(default_factory=list) - splitting_ops: List[str] = Field(default=None) # type: ignore + custom_ops: list[str] = Field(default_factory=list) + splitting_ops: list[str] = Field(default=None) # type: ignore use_inductor: bool = True - compile_sizes: Optional[List[Union[int, str]]] = Field(default=None) - inductor_compile_config: Dict = Field(default_factory=dict) - inductor_passes: Dict[str, str] = Field(default_factory=dict) + compile_sizes: Optional[list[Union[int, str]]] = Field(default=None) + inductor_compile_config: dict = Field(default_factory=dict) + inductor_passes: dict[str, str] = Field(default_factory=dict) use_cudagraph: bool = False cudagraph_num_of_warmups: int = 0 - cudagraph_capture_sizes: Optional[List[int]] = None + cudagraph_capture_sizes: Optional[list[int]] = None cudagraph_copy_inputs: bool = False class PassConfig(BaseModel): @@ -2998,7 +2999,7 @@ class PassConfig(BaseModel): - enable_noop: whether to enable the custom no-op elimination pass. TODO(luka) better pass enabling system. """ - dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_stages: list[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) enable_fusion: bool = True enable_noop: bool = True @@ -3026,20 +3027,20 @@ def model_post_init(self, __context: Any) -> None: max_capture_size: int = PrivateAttr local_cache_dir: str = PrivateAttr # local cache dir for each rank # optimization: - # Intuitively, bs_to_padded_graph_size should be Dict[int, int]. + # Intuitively, bs_to_padded_graph_size should be dict[int, int]. # since we know all keys are in a range [0, max_capture_size], - # we can optimize it to List[int] for better lookup performance. - bs_to_padded_graph_size: List[int] = PrivateAttr + # we can optimize it to list[int] for better lookup performance. + bs_to_padded_graph_size: list[int] = PrivateAttr # keep track of enabled and disabled custom ops enabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr - traced_files: Set[str] = PrivateAttr + traced_files: set[str] = PrivateAttr compilation_time: float = PrivateAttr # Per-model forward context # Map from layer name to the attention cls - static_forward_context: Dict[str, Any] = PrivateAttr + static_forward_context: dict[str, Any] = PrivateAttr def compute_hash(self) -> str: """ @@ -3053,7 +3054,7 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: List[Any] = [] + factors: list[Any] = [] factors.append(self.level) factors.append(self.backend) factors.append(self.custom_ops) @@ -3150,7 +3151,7 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: return VllmBackend(vllm_config) def init_with_cudagraph_sizes(self, - cudagraph_capture_sizes: List[int]) -> None: + cudagraph_capture_sizes: list[int]) -> None: """To complete the initialization of config, we need to know the cudagraph sizes.""" @@ -3243,10 +3244,10 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: List[Any] = [] + factors: list[Any] = [] # summarize vllm config - vllm_factors: List[Any] = [] + vllm_factors: list[Any] = [] from vllm import __version__ vllm_factors.append(__version__) if self.model_config: diff --git a/vllm/connections.py b/vllm/connections.py index dc060bb6f88a..2c259bb7c3e6 100644 --- a/vllm/connections.py +++ b/vllm/connections.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Mapping, MutableMapping from pathlib import Path -from typing import Mapping, MutableMapping, Optional +from typing import Optional from urllib.parse import urlparse import aiohttp diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 28b8c847c0fd..c81ff958531b 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -10,7 +10,8 @@ import json import ssl from argparse import Namespace -from typing import Any, AsyncGenerator, Optional +from collections.abc import AsyncGenerator +from typing import Any, Optional from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c50c631dafcc..b05842dd27d3 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -5,10 +5,11 @@ import json from abc import ABC, abstractmethod from collections import defaultdict, deque +from collections.abc import Awaitable, Iterable from functools import cache, lru_cache, partial from pathlib import Path -from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, - Literal, Optional, Tuple, TypeVar, Union, cast) +from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, + cast) import jinja2.nodes import transformers.utils.chat_template_utils as hf_chat_utils @@ -117,7 +118,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): role: Required[str] """The role of the message's author.""" - content: Union[str, List[ChatCompletionContentPartParam]] + content: Union[str, list[ChatCompletionContentPartParam]] """The contents of the message.""" name: str @@ -143,7 +144,7 @@ class ConversationMessage(TypedDict, total=False): role: Required[str] """The role of the message's author.""" - content: Union[Optional[str], List[Dict[str, str]]] + content: Union[Optional[str], list[dict[str, str]]] """The contents of the message""" tool_call_id: Optional[str] @@ -495,13 +496,13 @@ def __init__(self) -> None: super().__init__() # multimodal placeholder_string : count - self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0) + self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0) def _add_placeholder(self, placeholder: Optional[str]): if placeholder: self._placeholder_counts[placeholder] += 1 - def mm_placeholder_counts(self) -> Dict[str, int]: + def mm_placeholder_counts(self) -> dict[str, int]: return dict(self._placeholder_counts) @abstractmethod @@ -652,12 +653,12 @@ def load_chat_template( # TODO: Let user specify how to insert multimodal tokens into prompt # (similar to chat template) -def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], +def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], text_prompt: str) -> str: """Combine multimodal prompts for a multimodal language model.""" # Look through the text prompt to check for missing placeholders - missing_placeholders: List[str] = [] + missing_placeholders: list[str] = [] for placeholder in placeholder_counts: # For any existing placeholder in the text prompt, we leave it as is @@ -684,10 +685,10 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam) -_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio] +_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio] # Define a mapping from part types to their corresponding parsing functions. -MM_PARSER_MAP: Dict[ +MM_PARSER_MAP: dict[ str, Callable[[ChatCompletionContentPartParam], _ContentPart], ] = { @@ -749,7 +750,7 @@ def _parse_chat_message_content_mm_part( part) return "audio_url", audio_params.get("audio_url", "") if part.get("input_audio") is not None: - input_audio_params = cast(Dict[str, str], part) + input_audio_params = cast(dict[str, str], part) return "input_audio", input_audio_params if part.get("video_url") is not None: video_params = cast(CustomChatCompletionContentSimpleVideoParam, @@ -773,7 +774,7 @@ def _parse_chat_message_content_parts( mm_tracker: BaseMultiModalItemTracker, *, wrap_dicts: bool, -) -> List[ConversationMessage]: +) -> list[ConversationMessage]: content = list[_ContentPart]() mm_parser = mm_tracker.create_parser() @@ -791,7 +792,7 @@ def _parse_chat_message_content_parts( # Parsing wraps images and texts as interleaved dictionaries return [ConversationMessage(role=role, content=content)] # type: ignore - texts = cast(List[str], content) + texts = cast(list[str], content) text_prompt = "\n".join(texts) mm_placeholder_counts = mm_parser.mm_placeholder_counts() if mm_placeholder_counts: @@ -866,7 +867,7 @@ def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, content_format: _ChatTemplateContentFormat, -) -> List[ConversationMessage]: +) -> list[ConversationMessage]: role = message["role"] content = message.get("content") @@ -900,7 +901,7 @@ def _parse_chat_message_content( return result -def _postprocess_messages(messages: List[ConversationMessage]) -> None: +def _postprocess_messages(messages: list[ConversationMessage]) -> None: # per the Transformers docs & maintainers, tool call arguments in # assistant-role messages with tool_calls need to be dicts not JSON str - # this is how tool-use chat templates will expect them moving forwards @@ -916,12 +917,12 @@ def _postprocess_messages(messages: List[ConversationMessage]) -> None: def parse_chat_messages( - messages: List[ChatCompletionMessageParam], + messages: list[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, -) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]: - conversation: List[ConversationMessage] = [] +) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]: + conversation: list[ConversationMessage] = [] mm_tracker = MultiModalItemTracker(model_config, tokenizer) for msg in messages: @@ -939,12 +940,12 @@ def parse_chat_messages( def parse_chat_messages_futures( - messages: List[ChatCompletionMessageParam], + messages: list[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, -) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: - conversation: List[ConversationMessage] = [] +) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: + conversation: list[ConversationMessage] = [] mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) for msg in messages: @@ -963,7 +964,7 @@ def parse_chat_messages_futures( def apply_hf_chat_template( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - conversation: List[ConversationMessage], + conversation: list[ConversationMessage], chat_template: Optional[str], *, tokenize: bool = False, # Different from HF's default @@ -985,10 +986,10 @@ def apply_hf_chat_template( def apply_mistral_chat_template( tokenizer: MistralTokenizer, - messages: List[ChatCompletionMessageParam], + messages: list[ChatCompletionMessageParam], chat_template: Optional[str] = None, **kwargs: Any, -) -> List[int]: +) -> list[int]: if chat_template is not None: logger.warning_once( "'chat_template' cannot be overridden for mistral tokenizer.") diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index 73df900f610f..21a7d48b75c1 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -5,7 +5,7 @@ import os import signal import sys -from typing import List, Optional, Tuple +from typing import Optional from openai import OpenAI from openai.types.chat import ChatCompletionMessageParam @@ -23,7 +23,7 @@ def signal_handler(sig, frame): signal.signal(signal.SIGTSTP, signal_handler) -def _interactive_cli(args: argparse.Namespace) -> Tuple[str, OpenAI]: +def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]: _register_signal_handlers() base_url = args.url @@ -43,7 +43,7 @@ def _interactive_cli(args: argparse.Namespace) -> Tuple[str, OpenAI]: def chat(system_prompt: Optional[str], model_name: str, client: OpenAI) -> None: - conversation: List[ChatCompletionMessageParam] = [] + conversation: list[ChatCompletionMessageParam] = [] if system_prompt is not None: conversation.append({"role": "system", "content": system_prompt}) @@ -100,7 +100,7 @@ def __init__(self): def cmd(args: argparse.Namespace) -> None: model_name, client = _interactive_cli(args) system_prompt = args.system_prompt - conversation: List[ChatCompletionMessageParam] = [] + conversation: list[ChatCompletionMessageParam] = [] if system_prompt is not None: conversation.append({"role": "system", "content": system_prompt}) @@ -168,5 +168,5 @@ def subparser_init( return complete_parser -def cmd_init() -> List[CLISubcommand]: +def cmd_init() -> list[CLISubcommand]: return [ChatCommand(), CompleteCommand()] diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 1afead8a120d..c345ece4dada 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -from typing import List import uvloop @@ -59,5 +58,5 @@ def subparser_init( return make_arg_parser(serve_parser) -def cmd_init() -> List[CLISubcommand]: +def cmd_init() -> list[CLISubcommand]: return [ServeSubcommand()] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3f3262f6e72c..122e2ed86cb6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -2,9 +2,9 @@ import itertools import warnings +from collections.abc import Sequence from contextlib import contextmanager -from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence, - Tuple, Type, Union, cast, overload) +from typing import Any, Callable, ClassVar, Optional, Union, cast, overload import cloudpickle import torch.nn as nn @@ -177,11 +177,11 @@ def __init__( disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, # After positional args are removed, move this right below `model` task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, - compilation_config: Optional[Union[int, Dict[str, Any]]] = None, + compilation_config: Optional[Union[int, dict[str, Any]]] = None, **kwargs, ) -> None: ''' @@ -246,7 +246,7 @@ def __init__( self.request_counter = Counter() @staticmethod - def get_engine_class() -> Type[LLMEngine]: + def get_engine_class() -> type[LLMEngine]: if envs.VLLM_USE_V1: # Lazy import: the v1 package isn't distributed from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine @@ -283,11 +283,11 @@ def generate( Sequence[SamplingParams]]] = None, *, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - ) -> List[RequestOutput]: + ) -> list[RequestOutput]: ... @overload # LEGACY: single (prompt + optional token ids) @@ -296,30 +296,30 @@ def generate( self, prompts: str, sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, - prompt_token_ids: Optional[List[int]] = None, + list[SamplingParams]]] = None, + prompt_token_ids: Optional[list[int]] = None, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - ) -> List[RequestOutput]: + ) -> list[RequestOutput]: ... @overload # LEGACY: multi (prompt + optional token ids) @deprecated("'prompt_token_ids' will become part of 'prompts'") def generate( self, - prompts: List[str], + prompts: list[str], sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, - prompt_token_ids: Optional[List[List[int]]] = None, + list[SamplingParams]]] = None, + prompt_token_ids: Optional[list[list[int]]] = None, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - ) -> List[RequestOutput]: + ) -> list[RequestOutput]: ... @overload # LEGACY: single (token ids + optional prompt) @@ -328,32 +328,32 @@ def generate( self, prompts: Optional[str] = None, sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + list[SamplingParams]]] = None, *, - prompt_token_ids: List[int], + prompt_token_ids: list[int], use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - ) -> List[RequestOutput]: + ) -> list[RequestOutput]: ... @overload # LEGACY: multi (token ids + optional prompt) @deprecated("'prompt_token_ids' will become part of 'prompts'") def generate( self, - prompts: Optional[List[str]] = None, + prompts: Optional[list[str]] = None, sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + list[SamplingParams]]] = None, *, - prompt_token_ids: List[List[int]], + prompt_token_ids: list[list[int]], use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - ) -> List[RequestOutput]: + ) -> list[RequestOutput]: ... @overload # LEGACY: single or multi token ids [pos-only] @@ -362,13 +362,13 @@ def generate( self, prompts: None, sampling_params: None, - prompt_token_ids: Union[List[int], List[List[int]]], + prompt_token_ids: Union[list[int], list[list[int]]], use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - ) -> List[RequestOutput]: + ) -> list[RequestOutput]: ... @deprecate_kwargs( @@ -379,17 +379,17 @@ def generate( def generate( self, prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, List[str]]]] = None, + Optional[Union[str, list[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, - prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - priority: Optional[List[int]] = None, - ) -> List[RequestOutput]: + priority: Optional[list[int]] = None, + ) -> list[RequestOutput]: """Generates the completions for the input prompts. This class automatically batches the given prompts, considering @@ -440,7 +440,7 @@ def generate( if prompt_token_ids is not None: parsed_prompts = self._convert_v1_inputs( - prompts=cast(Optional[Union[str, List[str]]], prompts), + prompts=cast(Optional[Union[str, list[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: @@ -473,8 +473,8 @@ def generate( def collective_rpc(self, method: Union[str, Callable[..., _R]], timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict[str, Any]] = None) -> List[_R]: + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: """ Execute an RPC call on all workers. @@ -510,9 +510,9 @@ def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: def beam_search( self, - prompts: List[Union[TokensPrompt, TextPrompt]], + prompts: list[Union[TokensPrompt, TextPrompt]], params: BeamSearchParams, - ) -> List[BeamSearchOutput]: + ) -> list[BeamSearchOutput]: """ Generate sequences using beam search. @@ -543,7 +543,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float: beam_search_params = SamplingParams(logprobs=2 * beam_width, max_tokens=1, temperature=temperature) - instances: List[BeamSearchInstance] = [] + instances: list[BeamSearchInstance] = [] for prompt in prompts: if is_token_prompt(prompt): @@ -553,12 +553,12 @@ def sort_beams_key(x: BeamSearchSequence) -> float: instances.append(BeamSearchInstance(prompt_tokens)) for _ in range(max_tokens): - all_beams: List[BeamSearchSequence] = list( + all_beams: list[BeamSearchSequence] = list( sum((instance.beams for instance in instances), [])) pos = [0] + list( itertools.accumulate( len(instance.beams) for instance in instances)) - instance_start_and_end: List[Tuple[int, int]] = list( + instance_start_and_end: list[tuple[int, int]] = list( zip(pos[:-1], pos[1:])) if len(all_beams) == 0: @@ -620,19 +620,19 @@ def sort_beams_key(x: BeamSearchSequence) -> float: def chat( self, - messages: Union[List[ChatCompletionMessageParam], - List[List[ChatCompletionMessageParam]]], + messages: Union[list[ChatCompletionMessageParam], + list[list[ChatCompletionMessageParam]]], sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + list[SamplingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, chat_template_content_format: ChatTemplateContentFormatOption = "auto", add_generation_prompt: bool = True, continue_final_message: bool = False, - tools: Optional[List[Dict[str, Any]]] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - ) -> List[RequestOutput]: + tools: Optional[list[dict[str, Any]]] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + ) -> list[RequestOutput]: """ Generate responses for a chat conversation. @@ -678,17 +678,17 @@ def chat( A list of ``RequestOutput`` objects containing the generated responses in the same order as the input messages. """ - list_of_messages: List[List[ChatCompletionMessageParam]] + list_of_messages: list[list[ChatCompletionMessageParam]] # Handle multi and single conversations if is_list_of(messages, list): - # messages is List[List[...]] - list_of_messages = cast(List[List[ChatCompletionMessageParam]], + # messages is list[list[...]] + list_of_messages = cast(list[list[ChatCompletionMessageParam]], messages) else: - # messages is List[...] + # messages is list[...] list_of_messages = [ - cast(List[ChatCompletionMessageParam], messages) + cast(list[ChatCompletionMessageParam], messages) ] tokenizer = self.get_tokenizer() @@ -699,7 +699,7 @@ def chat( tokenizer, ) - prompts: List[Union[TokensPrompt, TextPrompt]] = [] + prompts: list[Union[TokensPrompt, TextPrompt]] = [] for msgs in list_of_messages: # NOTE: _parse_chat_message_content_parts() currently doesn't @@ -712,7 +712,7 @@ def chat( content_format=resolved_content_format, ) - prompt_data: Union[str, List[int]] + prompt_data: Union[str, list[int]] if isinstance(tokenizer, MistralTokenizer): prompt_data = apply_mistral_chat_template( tokenizer, @@ -762,9 +762,9 @@ def encode( Sequence[PoolingParams]]] = None, *, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[PoolingRequestOutput]: + ) -> list[PoolingRequestOutput]: ... @overload # LEGACY: single (prompt + optional token ids) @@ -774,25 +774,25 @@ def encode( prompts: str, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[List[int]] = None, + prompt_token_ids: Optional[list[int]] = None, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[PoolingRequestOutput]: + ) -> list[PoolingRequestOutput]: ... @overload # LEGACY: multi (prompt + optional token ids) @deprecated("'prompt_token_ids' will become part of 'prompts'") def encode( self, - prompts: List[str], + prompts: list[str], pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[List[List[int]]] = None, + prompt_token_ids: Optional[list[list[int]]] = None, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[PoolingRequestOutput]: + ) -> list[PoolingRequestOutput]: ... @overload # LEGACY: single (token ids + optional prompt) @@ -803,26 +803,26 @@ def encode( pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, *, - prompt_token_ids: List[int], + prompt_token_ids: list[int], use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[PoolingRequestOutput]: + ) -> list[PoolingRequestOutput]: ... @overload # LEGACY: multi (token ids + optional prompt) @deprecated("'prompt_token_ids' will become part of 'prompts'") def encode( self, - prompts: Optional[List[str]] = None, + prompts: Optional[list[str]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, *, - prompt_token_ids: List[List[int]], + prompt_token_ids: list[list[int]], use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[PoolingRequestOutput]: + ) -> list[PoolingRequestOutput]: ... @overload # LEGACY: single or multi token ids [pos-only] @@ -831,11 +831,11 @@ def encode( self, prompts: None, pooling_params: None, - prompt_token_ids: Union[List[int], List[List[int]]], + prompt_token_ids: Union[list[int], list[list[int]]], use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[PoolingRequestOutput]: + ) -> list[PoolingRequestOutput]: ... @deprecate_kwargs( @@ -846,14 +846,14 @@ def encode( def encode( self, prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, List[str]]]] = None, + Optional[Union[str, list[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[PoolingRequestOutput]: + ) -> list[PoolingRequestOutput]: """Apply pooling to the hidden states corresponding to the input prompts. @@ -898,7 +898,7 @@ def encode( if prompt_token_ids is not None: parsed_prompts = self._convert_v1_inputs( - prompts=cast(Optional[Union[str, List[str]]], prompts), + prompts=cast(Optional[Union[str, list[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: @@ -926,9 +926,9 @@ def embed( /, *, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> list[EmbeddingRequestOutput]: """ Generate an embedding vector for each prompt. @@ -966,9 +966,9 @@ def classify( /, *, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[ClassificationRequestOutput]: + ) -> list[ClassificationRequestOutput]: """ Generate class logits for each prompt. @@ -1003,29 +1003,29 @@ def classify( def _embedding_score( self, tokenizer: AnyTokenizer, - text_1: List[Union[str, TextPrompt, TokensPrompt]], - text_2: List[Union[str, TextPrompt, TokensPrompt]], + text_1: list[Union[str, TextPrompt, TokensPrompt]], + text_2: list[Union[str, TextPrompt, TokensPrompt]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[ScoringRequestOutput]: + ) -> list[ScoringRequestOutput]: - encoded_output: List[PoolingRequestOutput] = self.encode( + encoded_output: list[PoolingRequestOutput] = self.encode( text_1 + text_2, use_tqdm=use_tqdm, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - encoded_output_1: List[PoolingRequestOutput] = encoded_output[ + encoded_output_1: list[PoolingRequestOutput] = encoded_output[ 0:len(text_1)] - encoded_output_2: List[PoolingRequestOutput] = encoded_output[ + encoded_output_2: list[PoolingRequestOutput] = encoded_output[ len(text_1):] if len(encoded_output_1) == 1: encoded_output_1 = encoded_output_1 * len(encoded_output_2) - scores: List[PoolingRequestOutput] = [] + scores: list[PoolingRequestOutput] = [] scores = _cosine_similarity(tokenizer=tokenizer, embed_1=encoded_output_1, @@ -1038,13 +1038,13 @@ def _embedding_score( def _cross_encoding_score( self, tokenizer: AnyTokenizer, - text_1: List[str], - text_2: List[str], + text_1: list[str], + text_2: list[str], truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[ScoringRequestOutput]: + ) -> list[ScoringRequestOutput]: if isinstance(tokenizer, MistralTokenizer): raise ValueError( @@ -1057,7 +1057,7 @@ def _cross_encoding_score( pooling_params = PoolingParams() - tokenization_kwargs: Dict[str, Any] = {} + tokenization_kwargs: dict[str, Any] = {} if truncate_prompt_tokens is not None: tokenization_kwargs["truncation"] = True tokenization_kwargs["max_length"] = truncate_prompt_tokens @@ -1094,9 +1094,9 @@ def score( *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[ScoringRequestOutput]: + ) -> list[ScoringRequestOutput]: """Generate similarity scores for all pairs ````. The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``. @@ -1162,12 +1162,12 @@ def ensure_str(prompt: SingletonPrompt): if isinstance(text_1, (str, dict)): # Convert a single prompt to a list. text_1 = [text_1] - input_text_1: List[str] = [ensure_str(t) for t in text_1] + input_text_1: list[str] = [ensure_str(t) for t in text_1] if isinstance(text_2, (str, dict)): # Convert a single prompt to a list. text_2 = [text_2] - input_text_2: List[str] = [ensure_str(t) for t in text_2] + input_text_2: list[str] = [ensure_str(t) for t in text_2] _validate_score_input_lens(input_text_1, input_text_2) @@ -1226,8 +1226,8 @@ def wake_up(self): # LEGACY def _convert_v1_inputs( self, - prompts: Optional[Union[str, List[str]]], - prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + prompts: Optional[Union[str, list[str]]], + prompt_token_ids: Optional[Union[list[int], list[list[int]]]], ): # skip_tokenizer_init is now checked in engine @@ -1252,7 +1252,7 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - parsed_prompts: List[PromptType] = [] + parsed_prompts: list[PromptType] = [] for i in range(num_requests): item: PromptType @@ -1275,7 +1275,7 @@ def _validate_and_add_requests( lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, - priority: Optional[List[int]] = None, + priority: Optional[list[int]] = None, ) -> None: if guided_options is not None: warnings.warn( @@ -1357,7 +1357,7 @@ def _add_guided_params( def _run_engine( self, *, use_tqdm: bool - ) -> List[Union[RequestOutput, PoolingRequestOutput]]: + ) -> list[Union[RequestOutput, PoolingRequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -1370,7 +1370,7 @@ def _run_engine( ) # Run the engine. - outputs: List[Union[RequestOutput, PoolingRequestOutput]] = [] + outputs: list[Union[RequestOutput, PoolingRequestOutput]] = [] total_in_toks = 0 total_out_toks = 0 while self.llm_engine.has_unfinished_requests(): diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index e82b6ba6c7ba..ea5759152a22 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Union +from typing import Optional, Union from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -22,7 +22,7 @@ def log_inputs( self, request_id: str, prompt: Optional[str], - prompt_token_ids: Optional[List[int]], + prompt_token_ids: Optional[list[int]], params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 1b65484c446a..ec2099d4cebf 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -13,10 +13,11 @@ import tempfile import uuid from argparse import Namespace +from collections.abc import AsyncIterator from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union +from typing import Annotated, Optional, Union import uvloop from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request @@ -93,7 +94,7 @@ # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger('vllm.entrypoints.openai.api_server') -_running_tasks: Set[asyncio.Task] = set() +_running_tasks: set[asyncio.Task] = set() @asynccontextmanager @@ -587,7 +588,7 @@ async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) -TASK_HANDLERS: Dict[str, Dict[str, tuple]] = { +TASK_HANDLERS: dict[str, dict[str, tuple]] = { "generate": { "messages": (ChatCompletionRequest, create_chat_completion), "default": (CompletionRequest, create_completion), @@ -894,7 +895,7 @@ async def init_app_state( state.task = model_config.task -def create_server_socket(addr: Tuple[str, int]) -> socket.socket: +def create_server_socket(addr: tuple[str, int]) -> socket.socket: family = socket.AF_INET if is_valid_ipv6_address(addr[0]): family = socket.AF_INET6 diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 8d877046f75f..b8cc57430f85 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -8,7 +8,8 @@ import argparse import json import ssl -from typing import List, Optional, Sequence, Union, get_args +from collections.abc import Sequence +from typing import Optional, Union, get_args from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, @@ -33,7 +34,7 @@ def __call__( if isinstance(values, str): raise TypeError("Expected values to be a list") - lora_list: List[LoRAModulePath] = [] + lora_list: list[LoRAModulePath] = [] for item in values: if item in [None, '']: # Skip if item is None or empty string continue @@ -69,7 +70,7 @@ def __call__( if isinstance(values, str): raise TypeError("Expected values to be a list") - adapter_list: List[PromptAdapterPath] = [] + adapter_list: list[PromptAdapterPath] = [] for item in values: name, path = item.split('=') adapter_list.append(PromptAdapterPath(name, path)) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index 41e5eef40eaf..04d5091a9681 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable from functools import lru_cache, partial -from typing import Dict, FrozenSet, Iterable, List, Optional, Union +from typing import Optional, Union import torch @@ -14,10 +15,10 @@ class AllowedTokenIdsLogitsProcessor: specific set of token ids.""" def __init__(self, allowed_ids: Iterable[int]): - self.allowed_ids: Optional[List[int]] = list(allowed_ids) + self.allowed_ids: Optional[list[int]] = list(allowed_ids) self.mask: Optional[torch.Tensor] = None - def __call__(self, token_ids: List[int], + def __call__(self, token_ids: list[int], logits: torch.Tensor) -> torch.Tensor: if self.mask is None: self.mask = torch.ones((logits.shape[-1], ), @@ -31,7 +32,7 @@ def __call__(self, token_ids: List[int], @lru_cache(maxsize=32) def _get_allowed_token_ids_logits_processor( - allowed_token_ids: FrozenSet[int], + allowed_token_ids: frozenset[int], vocab_size: int, ) -> LogitsProcessor: if not allowed_token_ids: @@ -43,8 +44,8 @@ def _get_allowed_token_ids_logits_processor( def logit_bias_logits_processor( - logit_bias: Dict[int, float], - token_ids: List[int], + logit_bias: dict[int, float], + token_ids: list[int], logits: torch.Tensor, ) -> torch.Tensor: for token_id, bias in logit_bias.items(): @@ -53,16 +54,16 @@ def logit_bias_logits_processor( def get_logits_processors( - logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], - allowed_token_ids: Optional[List[int]], + logit_bias: Optional[Union[dict[int, float], dict[str, float]]], + allowed_token_ids: Optional[list[int]], tokenizer: AnyTokenizer, -) -> List[LogitsProcessor]: - logits_processors: List[LogitsProcessor] = [] +) -> list[LogitsProcessor]: + logits_processors: list[LogitsProcessor] = [] if logit_bias: try: # Convert token_id to integer # Clamp the bias between -100 and 100 per OpenAI API spec - clamped_logit_bias: Dict[int, float] = { + clamped_logit_bias: dict[int, float] = { int(token_id): min(100.0, max(-100.0, bias)) for token_id, bias in logit_bias.items() } diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 31214211cfc4..14ce71cd3c2e 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,13 +5,13 @@ import re import time from argparse import Namespace -from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union +from typing import Annotated, Any, ClassVar, Literal, Optional, Union import torch from fastapi import UploadFile from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) -from typing_extensions import Annotated, TypeAlias +from typing_extensions import TypeAlias from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.logger import init_logger @@ -47,7 +47,7 @@ class OpenAIBaseModel(BaseModel): model_config = ConfigDict(extra="allow") # Cache class field names - field_names: ClassVar[Optional[Set[str]]] = None + field_names: ClassVar[Optional[set[str]]] = None @model_validator(mode="wrap") @classmethod @@ -105,12 +105,12 @@ class ModelCard(OpenAIBaseModel): root: Optional[str] = None parent: Optional[str] = None max_model_len: Optional[int] = None - permission: List[ModelPermission] = Field(default_factory=list) + permission: list[ModelPermission] = Field(default_factory=list) class ModelList(OpenAIBaseModel): object: str = "list" - data: List[ModelCard] = Field(default_factory=list) + data: list[ModelCard] = Field(default_factory=list) class PromptTokenUsageInfo(OpenAIBaseModel): @@ -134,7 +134,7 @@ class JsonSchemaResponseFormat(OpenAIBaseModel): description: Optional[str] = None # schema is the field in openai but that causes conflicts with pydantic so # instead use json_schema with an alias - json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema') + json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema') strict: Optional[bool] = None @@ -152,7 +152,7 @@ class StreamOptions(OpenAIBaseModel): class FunctionDefinition(OpenAIBaseModel): name: str description: Optional[str] = None - parameters: Optional[Dict[str, Any]] = None + parameters: Optional[dict[str, Any]] = None class ChatCompletionToolsParam(OpenAIBaseModel): @@ -171,15 +171,15 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): class LogitsProcessorConstructor(BaseModel): qualname: str - args: Optional[List[Any]] = None - kwargs: Optional[Dict[str, Any]] = None + args: Optional[list[Any]] = None + kwargs: Optional[dict[str, Any]] = None -LogitsProcessors = List[Union[str, LogitsProcessorConstructor]] +LogitsProcessors = list[Union[str, LogitsProcessorConstructor]] def get_logits_processors(processors: Optional[LogitsProcessors], - pattern: Optional[str]) -> Optional[List[Any]]: + pattern: Optional[str]) -> Optional[list[Any]]: if processors and pattern: logits_processors = [] for processor in processors: @@ -212,10 +212,10 @@ def get_logits_processors(processors: Optional[LogitsProcessors], class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create - messages: List[ChatCompletionMessageParam] + messages: list[ChatCompletionMessageParam] model: Optional[str] = None frequency_penalty: Optional[float] = 0.0 - logit_bias: Optional[Dict[str, float]] = None + logit_bias: Optional[dict[str, float]] = None logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 # TODO(#9845): remove max_tokens when field is removed from OpenAI API @@ -228,12 +228,12 @@ class ChatCompletionRequest(OpenAIBaseModel): presence_penalty: Optional[float] = 0.0 response_format: Optional[ResponseFormat] = None seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stop: Optional[Union[str, list[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None temperature: Optional[float] = None top_p: Optional[float] = None - tools: Optional[List[ChatCompletionToolsParam]] = None + tools: Optional[list[ChatCompletionToolsParam]] = None tool_choice: Optional[Union[Literal["none"], Literal["auto"], ChatCompletionNamedToolChoiceParam]] = "none" @@ -248,7 +248,7 @@ class ChatCompletionRequest(OpenAIBaseModel): min_p: Optional[float] = None repetition_penalty: Optional[float] = None length_penalty: float = 1.0 - stop_token_ids: Optional[List[int]] = Field(default_factory=list) + stop_token_ids: Optional[list[int]] = Field(default_factory=list) include_stop_str_in_output: bool = False ignore_eos: bool = False min_tokens: int = 0 @@ -290,7 +290,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "special tokens so this should be set to false (as is the " "default)."), ) - documents: Optional[List[Dict[str, str]]] = Field( + documents: Optional[list[dict[str, str]]] = Field( default=None, description= ("A list of dicts representing documents that will be accessible to " @@ -307,12 +307,12 @@ class ChatCompletionRequest(OpenAIBaseModel): "allowed, so you must provide a chat template if the tokenizer " "does not define one."), ) - chat_template_kwargs: Optional[Dict[str, Any]] = Field( + chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=("Additional kwargs to pass to the template renderer. " "Will be accessible by the chat template."), ) - mm_processor_kwargs: Optional[Dict[str, Any]] = Field( + mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, description=("Additional kwargs to pass to the HF processor."), ) @@ -325,7 +325,7 @@ class ChatCompletionRequest(OpenAIBaseModel): description=( "If specified, the output will follow the regex pattern."), ) - guided_choice: Optional[List[str]] = Field( + guided_choice: Optional[list[str]] = Field( default=None, description=( "If specified, the output will be exactly one of the choices."), @@ -643,17 +643,17 @@ class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create model: Optional[str] = None - prompt: Union[List[int], List[List[int]], str, List[str]] + prompt: Union[list[int], list[list[int]], str, list[str]] best_of: Optional[int] = None echo: Optional[bool] = False frequency_penalty: Optional[float] = 0.0 - logit_bias: Optional[Dict[str, float]] = None + logit_bias: Optional[dict[str, float]] = None logprobs: Optional[int] = None max_tokens: Optional[int] = 16 n: int = 1 presence_penalty: Optional[float] = 0.0 seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stop: Optional[Union[str, list[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None suffix: Optional[str] = None @@ -667,14 +667,14 @@ class CompletionRequest(OpenAIBaseModel): min_p: Optional[float] = None repetition_penalty: Optional[float] = None length_penalty: float = 1.0 - stop_token_ids: Optional[List[int]] = Field(default_factory=list) + stop_token_ids: Optional[list[int]] = Field(default_factory=list) include_stop_str_in_output: bool = False ignore_eos: bool = False min_tokens: int = 0 skip_special_tokens: bool = True spaces_between_special_tokens: bool = True truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None - allowed_token_ids: Optional[List[int]] = None + allowed_token_ids: Optional[list[int]] = None prompt_logprobs: Optional[int] = None # doc: end-completion-sampling-params @@ -701,7 +701,7 @@ class CompletionRequest(OpenAIBaseModel): description=( "If specified, the output will follow the regex pattern."), ) - guided_choice: Optional[List[str]] = Field( + guided_choice: Optional[list[str]] = Field( default=None, description=( "If specified, the output will be exactly one of the choices."), @@ -908,7 +908,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/embeddings model: Optional[str] = None - input: Union[List[int], List[List[int]], str, List[str]] + input: Union[list[int], list[list[int]], str, list[str]] encoding_format: Literal["float", "base64"] = "float" dimensions: Optional[int] = None user: Optional[str] = None @@ -940,7 +940,7 @@ def to_pooling_params(self): class EmbeddingChatRequest(OpenAIBaseModel): model: Optional[str] = None - messages: List[ChatCompletionMessageParam] + messages: list[ChatCompletionMessageParam] encoding_format: Literal["float", "base64"] = "float" dimensions: Optional[int] = None @@ -969,12 +969,12 @@ class EmbeddingChatRequest(OpenAIBaseModel): "allowed, so you must provide a chat template if the tokenizer " "does not define one."), ) - chat_template_kwargs: Optional[Dict[str, Any]] = Field( + chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=("Additional kwargs to pass to the template renderer. " "Will be accessible by the chat template."), ) - mm_processor_kwargs: Optional[Dict[str, Any]] = Field( + mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, description=("Additional kwargs to pass to the HF processor."), ) @@ -1008,8 +1008,8 @@ def to_pooling_params(self): class ScoreRequest(OpenAIBaseModel): model: Optional[str] = None - text_1: Union[List[str], str] - text_2: Union[List[str], str] + text_1: Union[list[str], str] + text_2: Union[list[str], str] truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # doc: begin-score-pooling-params @@ -1033,7 +1033,7 @@ def to_pooling_params(self): class RerankRequest(OpenAIBaseModel): model: Optional[str] = None query: str - documents: List[str] + documents: list[str] top_n: int = Field(default_factory=lambda: 0) truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None @@ -1073,14 +1073,14 @@ class RerankResponse(OpenAIBaseModel): id: str model: str usage: RerankUsage - results: List[RerankResult] + results: list[RerankResult] class CompletionLogProbs(OpenAIBaseModel): - text_offset: List[int] = Field(default_factory=list) - token_logprobs: List[Optional[float]] = Field(default_factory=list) - tokens: List[str] = Field(default_factory=list) - top_logprobs: List[Optional[Dict[str, + text_offset: list[int] = Field(default_factory=list) + token_logprobs: list[Optional[float]] = Field(default_factory=list) + tokens: list[str] = Field(default_factory=list) + top_logprobs: list[Optional[dict[str, float]]] = Field(default_factory=list) @@ -1096,7 +1096,7 @@ class CompletionResponseChoice(OpenAIBaseModel): "to stop, None if the completion finished for some other reason " "including encountering the EOS token"), ) - prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None + prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None class CompletionResponse(OpenAIBaseModel): @@ -1104,7 +1104,7 @@ class CompletionResponse(OpenAIBaseModel): object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[CompletionResponseChoice] + choices: list[CompletionResponseChoice] usage: UsageInfo @@ -1127,14 +1127,14 @@ class CompletionStreamResponse(OpenAIBaseModel): object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[CompletionResponseStreamChoice] + choices: list[CompletionResponseStreamChoice] usage: Optional[UsageInfo] = Field(default=None) class EmbeddingResponseData(OpenAIBaseModel): index: int object: str = "embedding" - embedding: Union[List[float], str] + embedding: Union[list[float], str] class EmbeddingResponse(OpenAIBaseModel): @@ -1142,14 +1142,14 @@ class EmbeddingResponse(OpenAIBaseModel): object: str = "list" created: int = Field(default_factory=lambda: int(time.time())) model: str - data: List[EmbeddingResponseData] + data: list[EmbeddingResponseData] usage: UsageInfo class PoolingResponseData(OpenAIBaseModel): index: int object: str = "pooling" - data: Union[List[List[float]], List[float], str] + data: Union[list[list[float]], list[float], str] class PoolingResponse(OpenAIBaseModel): @@ -1157,7 +1157,7 @@ class PoolingResponse(OpenAIBaseModel): object: str = "list" created: int = Field(default_factory=lambda: int(time.time())) model: str - data: List[PoolingResponseData] + data: list[PoolingResponseData] usage: UsageInfo @@ -1172,7 +1172,7 @@ class ScoreResponse(OpenAIBaseModel): object: str = "list" created: int = Field(default_factory=lambda: int(time.time())) model: str - data: List[ScoreResponseData] + data: list[ScoreResponseData] usage: UsageInfo @@ -1205,7 +1205,7 @@ class ExtractedToolCallInformation(BaseModel): tools_called: bool # extracted tool calls - tool_calls: List[ToolCall] + tool_calls: list[ToolCall] # content - per OpenAI spec, content AND tool calls can be returned rarely # But some models will do this intentionally @@ -1216,21 +1216,21 @@ class ChatMessage(OpenAIBaseModel): role: str reasoning_content: Optional[str] = None content: Optional[str] = None - tool_calls: List[ToolCall] = Field(default_factory=list) + tool_calls: list[ToolCall] = Field(default_factory=list) class ChatCompletionLogProb(OpenAIBaseModel): token: str logprob: float = -9999.0 - bytes: Optional[List[int]] = None + bytes: Optional[list[int]] = None class ChatCompletionLogProbsContent(ChatCompletionLogProb): - top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) + top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list) class ChatCompletionLogProbs(OpenAIBaseModel): - content: Optional[List[ChatCompletionLogProbsContent]] = None + content: Optional[list[ChatCompletionLogProbsContent]] = None class ChatCompletionResponseChoice(OpenAIBaseModel): @@ -1248,16 +1248,16 @@ class ChatCompletionResponse(OpenAIBaseModel): object: Literal["chat.completion"] = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[ChatCompletionResponseChoice] + choices: list[ChatCompletionResponseChoice] usage: UsageInfo - prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None + prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None reasoning_content: Optional[str] = None - tool_calls: List[DeltaToolCall] = Field(default_factory=list) + tool_calls: list[DeltaToolCall] = Field(default_factory=list) class ChatCompletionResponseStreamChoice(OpenAIBaseModel): @@ -1273,7 +1273,7 @@ class ChatCompletionStreamResponse(OpenAIBaseModel): object: Literal["chat.completion.chunk"] = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[ChatCompletionResponseStreamChoice] + choices: list[ChatCompletionResponseStreamChoice] usage: Optional[UsageInfo] = Field(default=None) @@ -1358,7 +1358,7 @@ class TokenizeCompletionRequest(OpenAIBaseModel): class TokenizeChatRequest(OpenAIBaseModel): model: Optional[str] = None - messages: List[ChatCompletionMessageParam] + messages: list[ChatCompletionMessageParam] add_generation_prompt: bool = Field( default=True, @@ -1393,12 +1393,12 @@ class TokenizeChatRequest(OpenAIBaseModel): "allowed, so you must provide a chat template if the tokenizer " "does not define one."), ) - chat_template_kwargs: Optional[Dict[str, Any]] = Field( + chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=("Additional kwargs to pass to the template renderer. " "Will be accessible by the chat template."), ) - mm_processor_kwargs: Optional[Dict[str, Any]] = Field( + mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, description=("Additional kwargs to pass to the HF processor."), ) @@ -1419,12 +1419,12 @@ def check_generation_prompt(cls, data): class TokenizeResponse(OpenAIBaseModel): count: int max_model_len: int - tokens: List[int] + tokens: list[int] class DetokenizeRequest(OpenAIBaseModel): model: Optional[str] = None - tokens: List[int] + tokens: list[int] class DetokenizeResponse(OpenAIBaseModel): @@ -1492,7 +1492,7 @@ class TranscriptionRequest(OpenAIBaseModel): to automatically increase the temperature until certain thresholds are hit. """ - timestamp_granularities: List[Literal["word", "segment"]] = Field( + timestamp_granularities: list[Literal["word", "segment"]] = Field( alias="timestamp_granularities[]", default=[]) """The timestamp granularities to populate for this transcription. @@ -1580,7 +1580,7 @@ class TranscriptionSegment(OpenAIBaseModel): text: str """Text content of the segment.""" - tokens: List[int] + tokens: list[int] """Array of token IDs for the text content.""" @@ -1594,8 +1594,8 @@ class TranscriptionResponseVerbose(OpenAIBaseModel): text: str """The transcribed text.""" - segments: Optional[List[TranscriptionSegment]] = None + segments: Optional[list[TranscriptionSegment]] = None """Segments of the transcribed text and their corresponding details.""" - words: Optional[List[TranscriptionWord]] = None + words: Optional[list[TranscriptionWord]] = None """Extracted words and their corresponding timestamps.""" diff --git a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py b/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py index b5df7e47446b..b3bc0e836d4c 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py +++ b/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import os +from collections.abc import Sequence from functools import cached_property -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Callable, Optional, Union from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) @@ -25,14 +26,14 @@ def __init__(self, tokenizer: AnyTokenizer): self.model_tokenizer = tokenizer @cached_property - def vocab(self) -> Dict[str, int]: + def vocab(self) -> dict[str, int]: # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab # whereas all tokenizers have .get_vocab() return self.model_tokenizer.get_vocab() def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest - ) -> Tuple[Optional[str], Optional[str]]: + ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from a complete model-generated string. @@ -47,7 +48,7 @@ def extract_reasoning_content( The request object that was used to generate the model_output. Returns: - Tuple[Optional[str], Optional[str]] + tuple[Optional[str], Optional[str]] A tuple containing the reasoning content and the content. """ @@ -77,10 +78,10 @@ def extract_reasoning_content_streaming( class ReasoningParserManager: - reasoning_parsers: Dict[str, Type] = {} + reasoning_parsers: dict[str, type] = {} @classmethod - def get_reasoning_parser(cls, name) -> Type: + def get_reasoning_parser(cls, name) -> type: """ Get reasoning parser by name which is registered by `register_module`. @@ -94,8 +95,8 @@ def get_reasoning_parser(cls, name) -> Type: @classmethod def _register_module(cls, - module: Type, - module_name: Optional[Union[str, List[str]]] = None, + module: type, + module_name: Optional[Union[str, list[str]]] = None, force: bool = True) -> None: if not issubclass(module, ReasoningParser): raise TypeError("module must be subclass of ReasoningParser, " @@ -114,9 +115,9 @@ def _register_module(cls, @classmethod def register_module( cls, - name: Optional[Union[str, List[str]]] = None, + name: Optional[Union[str, list[str]]] = None, force: bool = True, - module: Union[Type, None] = None) -> Union[type, Callable]: + module: Union[type, None] = None) -> Union[type, Callable]: """ Register module with the given name or name list. it can be used as a decoder(with module as None) or normal function(with module as not diff --git a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py index e5ab6e6b2339..1a2c66a60e96 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import re -from typing import Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Optional, Union from transformers import PreTrainedTokenizerBase @@ -122,7 +123,7 @@ def extract_reasoning_content_streaming( def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest - ) -> Tuple[Optional[str], Optional[str]]: + ) -> tuple[Optional[str], Optional[str]]: # DeepSeek R1 doesn't generate now. # Thus we assume the reasoning content is always at the start. diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index e4496f61e607..0d06ba3df23f 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -2,9 +2,10 @@ import asyncio import tempfile +from collections.abc import Awaitable from http import HTTPStatus from io import StringIO -from typing import Awaitable, Callable, List, Optional +from typing import Callable, Optional import aiohttp import torch @@ -143,7 +144,7 @@ async def read_file(path_or_url: str) -> str: async def write_local_file(output_path: str, - batch_outputs: List[BatchRequestOutput]) -> None: + batch_outputs: list[BatchRequestOutput]) -> None: """ Write the responses to a local file. output_path: The path to write the responses to. @@ -204,7 +205,7 @@ async def upload_data(output_url: str, data_or_file: str, f"Error message: {str(e)}.") from e -async def write_file(path_or_url: str, batch_outputs: List[BatchRequestOutput], +async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str) -> None: """ Write batch_outputs to a file or upload to a URL. @@ -353,7 +354,7 @@ async def main(args): logger.info("Reading batch from %s...", args.input_file) # Submit all requests in the file to the engine "concurrently". - response_futures: List[Awaitable[BatchRequestOutput]] = [] + response_futures: list[Awaitable[BatchRequestOutput]] = [] for request_json in (await read_file(args.input_file)).strip().split("\n"): # Skip empty lines. request_json = request_json.strip() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 02dd2c4881c6..98e9ea0fc61a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -3,10 +3,9 @@ import asyncio import json import time -from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List, - Optional) -from typing import Sequence as GenericSequence -from typing import Union +from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import Sequence as GenericSequence +from typing import Callable, Final, Optional, Union from fastapi import Request @@ -205,7 +204,7 @@ async def create_chat_completion( raw_request.state.request_metadata = request_metadata # Schedule the request and get the result generator. - generators: List[AsyncGenerator[RequestOutput, None]] = [] + generators: list[AsyncGenerator[RequestOutput, None]] = [] try: for i, engine_prompt in enumerate(engine_prompts): sampling_params: Union[SamplingParams, BeamSearchParams] @@ -282,7 +281,7 @@ async def chat_completion_stream_generator( result_generator: AsyncIterator[RequestOutput], request_id: str, model_name: str, - conversation: List[ConversationMessage], + conversation: list[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, ) -> AsyncGenerator[str, None]: @@ -310,7 +309,7 @@ async def chat_completion_stream_generator( should_stream_with_reasoning_parsing = ( self._should_stream_with_reasoning_parsing(request)) - all_previous_token_ids: Optional[List[List[int]]] + all_previous_token_ids: Optional[list[list[int]]] # Only one of these will be used, thus previous_texts and # all_previous_token_ids will not be used twice in the same iteration. @@ -339,7 +338,7 @@ async def chat_completion_stream_generator( # Prepare the tool parser if it's needed try: if tool_choice_auto and self.tool_parser: - tool_parsers: List[Optional[ToolParser]] = [ + tool_parsers: list[Optional[ToolParser]] = [ self.tool_parser(tokenizer) ] * num_choices else: @@ -406,7 +405,7 @@ async def chat_completion_stream_generator( # Send response to echo the input portion of the # last message if request.echo: - last_msg_content: Union[str, List[Dict[str, str]]] = "" + last_msg_content: Union[str, list[dict[str, str]]] = "" if conversation and "content" in conversation[ -1] and conversation[-1].get("role") == role: last_msg_content = conversation[-1]["content"] or "" @@ -674,7 +673,7 @@ async def chat_completion_full_generator( result_generator: AsyncIterator[RequestOutput], request_id: str, model_name: str, - conversation: List[ConversationMessage], + conversation: list[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, ) -> Union[ErrorResponse, ChatCompletionResponse]: @@ -693,7 +692,7 @@ async def chat_completion_full_generator( assert final_res is not None - choices: List[ChatCompletionResponseChoice] = [] + choices: list[ChatCompletionResponseChoice] = [] role = self.get_chat_request_role(request) for output in final_res.outputs: @@ -812,7 +811,7 @@ async def chat_completion_full_generator( choices.append(choice_data) if request.echo: - last_msg_content: Union[str, List[Dict[str, str]]] = "" + last_msg_content: Union[str, list[dict[str, str]]] = "" if conversation and "content" in conversation[-1] and conversation[ -1].get("role") == role: last_msg_content = conversation[-1]["content"] or "" @@ -853,8 +852,8 @@ async def chat_completion_full_generator( return response def _get_top_logprobs( - self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], - tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]: + self, logprobs: dict[int, Logprob], top_logprobs: Optional[int], + tokenizer: AnyTokenizer) -> list[ChatCompletionLogProb]: return [ ChatCompletionLogProb(token=(token := self._get_decoded_token( p[1], @@ -871,12 +870,12 @@ def _get_top_logprobs( def _create_chat_logprobs( self, token_ids: GenericSequence[int], - top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + top_logprobs: GenericSequence[Optional[dict[int, Logprob]]], tokenizer: AnyTokenizer, num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" - logprobs_content: List[ChatCompletionLogProbsContent] = [] + logprobs_content: list[ChatCompletionLogProbsContent] = [] for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 840f0f9b8448..ed09af84f64b 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -2,9 +2,9 @@ import asyncio import time -from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional -from typing import Sequence as GenericSequence -from typing import Tuple, Union, cast +from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import Sequence as GenericSequence +from typing import Optional, Union, cast from fastapi import Request @@ -113,7 +113,7 @@ async def create_completion( return self.create_error_response(str(e)) # Schedule the request and get the result generator. - generators: List[AsyncGenerator[RequestOutput, None]] = [] + generators: list[AsyncGenerator[RequestOutput, None]] = [] try: for i, engine_prompt in enumerate(engine_prompts): sampling_params: Union[SamplingParams, BeamSearchParams] @@ -189,7 +189,7 @@ async def create_completion( request_metadata=request_metadata) # Non-streaming response - final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts + final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts try: async for i, res in result_generator: final_res_batch[i] = res @@ -203,7 +203,7 @@ async def create_completion( if final_res.prompt is None: final_res.prompt = request_prompts[i]["prompt"] - final_res_batch_checked = cast(List[RequestOutput], + final_res_batch_checked = cast(list[RequestOutput], final_res_batch) response = self.request_output_to_completion_response( @@ -237,7 +237,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: async def completion_stream_generator( self, request: CompletionRequest, - result_generator: AsyncIterator[Tuple[int, RequestOutput]], + result_generator: AsyncIterator[tuple[int, RequestOutput]], request_id: str, created_time: int, model_name: str, @@ -270,7 +270,7 @@ async def completion_stream_generator( num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) delta_token_ids: GenericSequence[int] - out_logprobs: Optional[GenericSequence[Optional[Dict[ + out_logprobs: Optional[GenericSequence[Optional[dict[ int, Logprob]]]] for output in res.outputs: @@ -381,7 +381,7 @@ async def completion_stream_generator( def request_output_to_completion_response( self, - final_res_batch: List[RequestOutput], + final_res_batch: list[RequestOutput], request: CompletionRequest, request_id: str, created_time: int, @@ -389,7 +389,7 @@ def request_output_to_completion_response( tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, ) -> CompletionResponse: - choices: List[CompletionResponseChoice] = [] + choices: list[CompletionResponseChoice] = [] num_prompt_tokens = 0 num_generated_tokens = 0 @@ -406,7 +406,7 @@ def request_output_to_completion_response( prompt_text = final_res.prompt token_ids: GenericSequence[int] - out_logprobs: Optional[GenericSequence[Optional[Dict[int, + out_logprobs: Optional[GenericSequence[Optional[dict[int, Logprob]]]] for output in final_res.outputs: @@ -480,16 +480,16 @@ def request_output_to_completion_response( def _create_completion_logprobs( self, token_ids: GenericSequence[int], - top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + top_logprobs: GenericSequence[Optional[dict[int, Logprob]]], num_output_top_logprobs: int, tokenizer: AnyTokenizer, initial_text_offset: int = 0, ) -> CompletionLogProbs: """Create logprobs for OpenAI Completion API.""" - out_text_offset: List[int] = [] - out_token_logprobs: List[Optional[float]] = [] - out_tokens: List[str] = [] - out_top_logprobs: List[Optional[Dict[str, float]]] = [] + out_text_offset: list[int] = [] + out_token_logprobs: list[Optional[float]] = [] + out_tokens: list[str] = [] + out_top_logprobs: list[Optional[dict[str, float]]] = [] last_token_len = 0 diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 607dbd96b194..5f6e06e6f79f 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -3,7 +3,8 @@ import asyncio import base64 import time -from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast +from collections.abc import AsyncGenerator +from typing import Final, Literal, Optional, Union, cast import numpy as np from fastapi import Request @@ -31,7 +32,7 @@ def _get_embedding( output: EmbeddingOutput, encoding_format: Literal["float", "base64"], -) -> Union[List[float], str]: +) -> Union[list[float], str]: if encoding_format == "float": return output.embedding elif encoding_format == "base64": @@ -143,7 +144,7 @@ async def create_embedding( return self.create_error_response(str(e)) # Schedule the request and get the result generator. - generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] try: pooling_params = request.to_pooling_params() @@ -178,7 +179,7 @@ async def create_embedding( num_prompts = len(engine_prompts) # Non-streaming response - final_res_batch: List[Optional[PoolingRequestOutput]] + final_res_batch: list[Optional[PoolingRequestOutput]] final_res_batch = [None] * num_prompts try: async for i, res in result_generator: @@ -186,7 +187,7 @@ async def create_embedding( assert all(final_res is not None for final_res in final_res_batch) - final_res_batch_checked = cast(List[PoolingRequestOutput], + final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch) response = self.request_output_to_embedding_response( @@ -206,13 +207,13 @@ async def create_embedding( def request_output_to_embedding_response( self, - final_res_batch: List[PoolingRequestOutput], + final_res_batch: list[PoolingRequestOutput], request_id: str, created_time: int, model_name: str, encoding_format: Literal["float", "base64"], ) -> EmbeddingResponse: - items: List[EmbeddingResponseData] = [] + items: list[EmbeddingResponseData] = [] num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index d097bfcfc5ab..59333dbfd24e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,15 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 import json +from collections.abc import Iterable, Iterator, Mapping, Sequence from concurrent.futures.thread import ThreadPoolExecutor from http import HTTPStatus -from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping, - Optional, Sequence, Tuple, TypedDict, Union) +from typing import Annotated, Any, Callable, Optional, TypedDict, Union from fastapi import Request from pydantic import Field from starlette.datastructures import Headers -from typing_extensions import Annotated from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -64,10 +63,10 @@ class TextTokensPrompt(TypedDict): prompt: str - prompt_token_ids: List[int] + prompt_token_ids: list[int] -RequestPrompt = Union[List[int], str, TextTokensPrompt] +RequestPrompt = Union[list[int], str, TextTokensPrompt] class OpenAIServing: @@ -144,7 +143,7 @@ async def _check_model( def _maybe_get_adapters( self, request: AnyRequest - ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[ + ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[ None, PromptAdapterRequest]]: if self._is_model_supported(request.model): return None, None @@ -188,7 +187,7 @@ def _normalize_prompt_tokens_to_input( self, request: AnyRequest, tokenizer: AnyTokenizer, - prompt_ids: List[int], + prompt_ids: list[int], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], ) -> TextTokensPrompt: if truncate_prompt_tokens is None: @@ -203,7 +202,7 @@ def _normalize_prompt_tokens_to_input( def _validate_input( self, request: AnyRequest, - input_ids: List[int], + input_ids: list[int], input_text: str, ) -> TextTokensPrompt: token_num = len(input_ids) @@ -259,7 +258,7 @@ def _tokenize_prompt_input( self, request: AnyRequest, tokenizer: AnyTokenizer, - prompt_input: Union[str, List[int]], + prompt_input: Union[str, list[int]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, ) -> TextTokensPrompt: @@ -280,7 +279,7 @@ def _tokenize_prompt_inputs( self, request: AnyRequest, tokenizer: AnyTokenizer, - prompt_inputs: Iterable[Union[str, List[int]]], + prompt_inputs: Iterable[Union[str, list[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, ) -> Iterator[TextTokensPrompt]: @@ -309,10 +308,10 @@ def _tokenize_prompt_input_or_inputs( self, request: AnyRequest, tokenizer: AnyTokenizer, - input_or_inputs: Union[str, List[str], List[int], List[List[int]]], + input_or_inputs: Union[str, list[str], list[int], list[list[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> List[TextTokensPrompt]: + ) -> list[TextTokensPrompt]: """ Tokenize/detokenize depending on the input format. @@ -344,10 +343,10 @@ async def _preprocess_completion( self, request: CompletionLikeRequest, tokenizer: AnyTokenizer, - input_or_inputs: Union[str, List[str], List[int], List[List[int]]], + input_or_inputs: Union[str, list[str], list[int], list[list[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]: + ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]: request_prompts = await self._tokenize_prompt_input_or_inputs_async( request, tokenizer, @@ -367,19 +366,19 @@ async def _preprocess_chat( self, request: ChatLikeRequest, tokenizer: AnyTokenizer, - messages: List[ChatCompletionMessageParam], + messages: list[ChatCompletionMessageParam], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, add_generation_prompt: bool = True, continue_final_message: bool = False, - tool_dicts: Optional[List[Dict[str, Any]]] = None, - documents: Optional[List[Dict[str, str]]] = None, - chat_template_kwargs: Optional[Dict[str, Any]] = None, + tool_dicts: Optional[list[dict[str, Any]]] = None, + documents: Optional[list[dict[str, str]]] = None, + chat_template_kwargs: Optional[dict[str, Any]] = None, tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = False, - ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt], - List[TokensPrompt]]: + ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], + list[TokensPrompt]]: resolved_content_format = resolve_chat_template_content_format( chat_template, chat_template_content_format, @@ -392,7 +391,7 @@ async def _preprocess_chat( content_format=resolved_content_format, ) - _chat_template_kwargs: Dict[str, Any] = dict( + _chat_template_kwargs: dict[str, Any] = dict( chat_template=chat_template, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, @@ -401,7 +400,7 @@ async def _preprocess_chat( ) _chat_template_kwargs.update(chat_template_kwargs or {}) - request_prompt: Union[str, List[int]] + request_prompt: Union[str, list[int]] if isinstance(tokenizer, MistralTokenizer): request_prompt = apply_mistral_chat_template( tokenizer, diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index 0f4a174a8c15..38a66583022a 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -4,7 +4,7 @@ import pathlib from dataclasses import dataclass from http import HTTPStatus -from typing import List, Optional, Union +from typing import Optional, Union from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -53,10 +53,10 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + base_model_paths: list[BaseModelPath], *, - lora_modules: Optional[List[LoRAModulePath]] = None, - prompt_adapters: Optional[List[PromptAdapterPath]] = None, + lora_modules: Optional[list[LoRAModulePath]] = None, + prompt_adapters: Optional[list[PromptAdapterPath]] = None, ): super().__init__() @@ -65,7 +65,7 @@ def __init__( self.engine_client = engine_client self.static_lora_modules = lora_modules - self.lora_requests: List[LoRARequest] = [] + self.lora_requests: list[LoRARequest] = [] self.lora_id_counter = AtomicCounter(0) self.prompt_adapter_requests = [] diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index bbf5aed1a33c..0a3ca2aa7c5b 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -3,7 +3,8 @@ import asyncio import base64 import time -from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast +from collections.abc import AsyncGenerator +from typing import Final, Literal, Optional, Union, cast import numpy as np from fastapi import Request @@ -29,7 +30,7 @@ def _get_data( output: PoolingOutput, encoding_format: Literal["float", "base64"], -) -> Union[List[float], str]: +) -> Union[list[float], str]: if encoding_format == "float": return output.data.tolist() elif encoding_format == "base64": @@ -139,7 +140,7 @@ async def create_pooling( return self.create_error_response(str(e)) # Schedule the request and get the result generator. - generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] try: pooling_params = request.to_pooling_params() @@ -174,7 +175,7 @@ async def create_pooling( num_prompts = len(engine_prompts) # Non-streaming response - final_res_batch: List[Optional[PoolingRequestOutput]] + final_res_batch: list[Optional[PoolingRequestOutput]] final_res_batch = [None] * num_prompts try: async for i, res in result_generator: @@ -182,7 +183,7 @@ async def create_pooling( assert all(final_res is not None for final_res in final_res_batch) - final_res_batch_checked = cast(List[PoolingRequestOutput], + final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch) response = self.request_output_to_pooling_response( @@ -202,13 +203,13 @@ async def create_pooling( def request_output_to_pooling_response( self, - final_res_batch: List[PoolingRequestOutput], + final_res_batch: list[PoolingRequestOutput], request_id: str, created_time: int, model_name: str, encoding_format: Literal["float", "base64"], ) -> PoolingResponse: - items: List[PoolingResponseData] = [] + items: list[PoolingResponseData] = [] num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index a087a8d9ba0f..73b4288cbb0d 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio import time -from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Union +from collections.abc import AsyncGenerator, Mapping +from typing import Any, Optional, Union from fastapi import Request @@ -48,8 +49,8 @@ def __init__( async def _embedding_score( self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - texts_1: List[str], - texts_2: List[str], + texts_1: list[str], + texts_2: list[str], request: Union[RerankRequest, ScoreRequest], request_id=str, tokenization_kwargs: Optional[dict[str, Any]] = None, @@ -57,11 +58,11 @@ async def _embedding_score( prompt_adapter_request: Optional[Union[PromptAdapterRequest, None]] = None, trace_headers: Optional[Mapping[str, str]] = None, - ) -> List[PoolingRequestOutput]: + ) -> list[PoolingRequestOutput]: input_texts = texts_1 + texts_2 - engine_prompts: List[TokensPrompt] = [] + engine_prompts: list[TokensPrompt] = [] tokenize_async = make_async(tokenizer.__call__, executor=self._tokenizer_executor) @@ -82,7 +83,7 @@ async def _embedding_score( prompt_token_ids=text_token_prompt["prompt_token_ids"])) # Schedule the request and get the result generator. - generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] pooling_params = request.to_pooling_params() for i, engine_prompt in enumerate(engine_prompts): @@ -108,16 +109,16 @@ async def _embedding_score( result_generator = merge_async_iterators(*generators) # Non-streaming response - final_res_batch: List[PoolingRequestOutput] = [] + final_res_batch: list[PoolingRequestOutput] = [] - embeddings: List[Optional[PoolingRequestOutput]] =\ + embeddings: list[Optional[PoolingRequestOutput]] =\ [None] * len(engine_prompts) async for i, res in result_generator: embeddings[i] = res - emb_texts_1: List[PoolingRequestOutput] = [] - emb_texts_2: List[PoolingRequestOutput] = [] + emb_texts_1: list[PoolingRequestOutput] = [] + emb_texts_2: list[PoolingRequestOutput] = [] for i in range(0, len(texts_1)): assert (emb := embeddings[i]) is not None @@ -139,8 +140,8 @@ async def _embedding_score( async def _cross_encoding_score( self, tokenizer: Union[AnyTokenizer], - texts_1: List[str], - texts_2: List[str], + texts_1: list[str], + texts_2: list[str], request: Union[RerankRequest, ScoreRequest], request_id=str, tokenization_kwargs: Optional[dict[str, Any]] = None, @@ -148,10 +149,10 @@ async def _cross_encoding_score( prompt_adapter_request: Optional[Union[PromptAdapterRequest, None]] = None, trace_headers: Optional[Mapping[str, str]] = None, - ) -> List[PoolingRequestOutput]: + ) -> list[PoolingRequestOutput]: - request_prompts: List[str] = [] - engine_prompts: List[TokensPrompt] = [] + request_prompts: list[str] = [] + engine_prompts: list[TokensPrompt] = [] if len(texts_1) == 1: texts_1 = texts_1 * len(texts_2) @@ -185,7 +186,7 @@ async def _cross_encoding_score( engine_prompts.append(engine_prompt) # Schedule the request and get the result generator. - generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] pooling_params = request.to_pooling_params() @@ -212,7 +213,7 @@ async def _cross_encoding_score( result_generator = merge_async_iterators(*generators) # Non-streaming response - final_res_batch: List[ + final_res_batch: list[ Optional[PoolingRequestOutput]] = [None] * len(engine_prompts) async for i, res in result_generator: @@ -228,9 +229,9 @@ async def _run_scoring( request_id: str, raw_request: Optional[Request] = None, truncate_prompt_tokens: Optional[int] = None, - ) -> List[PoolingRequestOutput]: + ) -> list[PoolingRequestOutput]: - tokenization_kwargs: Dict[str, Any] = {} + tokenization_kwargs: dict[str, Any] = {} if truncate_prompt_tokens is not None: tokenization_kwargs["truncation"] = True tokenization_kwargs["max_length"] = truncate_prompt_tokens @@ -372,12 +373,12 @@ async def do_rerank( def request_output_to_score_response( self, - final_res_batch: List[PoolingRequestOutput], + final_res_batch: list[PoolingRequestOutput], request_id: str, created_time: int, model_name: str, ) -> ScoreResponse: - items: List[ScoreResponseData] = [] + items: list[ScoreResponseData] = [] num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): @@ -406,13 +407,13 @@ def request_output_to_score_response( ) def request_output_to_rerank_response( - self, final_res_batch: List[PoolingRequestOutput], request_id: str, - model_name: str, documents: List[str], + self, final_res_batch: list[PoolingRequestOutput], request_id: str, + model_name: str, documents: list[str], top_n: int) -> RerankResponse: """ Convert the output of do_rank to a RerankResponse """ - results: List[RerankResult] = [] + results: list[RerankResult] = [] num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): classify_res = ScoringRequestOutput.from_base(final_res) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 6c79adf90c8a..4e95ef59e80e 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Final, List, Optional, Union +from typing import Final, Optional, Union from fastapi import Request @@ -92,7 +92,7 @@ async def create_tokenize( logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - input_ids: List[int] = [] + input_ids: list[int] = [] for i, engine_prompt in enumerate(engine_prompts): self._log_inputs(request_id, request_prompts[i], diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 0bedb5718a4b..77f016a5e0a4 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio import io -from typing import AsyncGenerator, Optional, Union, cast +from collections.abc import AsyncGenerator +from typing import Optional, Union, cast from fastapi import Request diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 7cdd6d4c4f2b..931d5aab9bd9 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import os +from collections.abc import Sequence from functools import cached_property -from typing import Callable, Dict, List, Optional, Sequence, Type, Union +from typing import Callable, Optional, Union from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, @@ -22,16 +23,16 @@ class ToolParser: """ def __init__(self, tokenizer: AnyTokenizer): - self.prev_tool_call_arr: List[Dict] = [] + self.prev_tool_call_arr: list[dict] = [] # the index of the tool call that is currently being parsed self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: List[str] = [] + self.streamed_args_for_tool: list[str] = [] self.model_tokenizer = tokenizer @cached_property - def vocab(self) -> Dict[str, int]: + def vocab(self) -> dict[str, int]: # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab # whereas all tokenizers have .get_vocab() return self.model_tokenizer.get_vocab() @@ -79,10 +80,10 @@ def extract_tool_calls_streaming( class ToolParserManager: - tool_parsers: Dict[str, Type] = {} + tool_parsers: dict[str, type] = {} @classmethod - def get_tool_parser(cls, name) -> Type: + def get_tool_parser(cls, name) -> type: """ Get tool parser by name which is registered by `register_module`. @@ -95,8 +96,8 @@ def get_tool_parser(cls, name) -> Type: @classmethod def _register_module(cls, - module: Type, - module_name: Optional[Union[str, List[str]]] = None, + module: type, + module_name: Optional[Union[str, list[str]]] = None, force: bool = True) -> None: if not issubclass(module, ToolParser): raise TypeError( @@ -116,9 +117,9 @@ def _register_module(cls, @classmethod def register_module( cls, - name: Optional[Union[str, List[str]]] = None, + name: Optional[Union[str, list[str]]] = None, force: bool = True, - module: Union[Type, None] = None) -> Union[type, Callable]: + module: Union[type, None] = None) -> Union[type, Callable]: """ Register module with the given name or name list. it can be used as a decoder(with module as None) or normal function(with module as not diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 002bf1738830..76da63c58008 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -2,8 +2,9 @@ import json import re +from collections.abc import Sequence from json import JSONDecoder -from typing import Dict, Sequence, Union +from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow @@ -145,7 +146,7 @@ def extract_tool_calls_streaming( return None # select as the current tool call the one we're on the state at - current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ if len(tool_call_arr) > 0 else {} # case -- if no tokens have been streamed for the tool, e.g. diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index c948ed78f503..91afc88ef3dd 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import Dict, Sequence, Union +from collections.abc import Sequence +from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow @@ -136,7 +137,7 @@ def extract_tool_calls_streaming( return None # select as the current tool call the one we're on the state at - current_tool_call: Dict = tool_call_arr[self.current_tool_id] + current_tool_call: dict = tool_call_arr[self.current_tool_id] delta = None # case: we are starting a new tool in the array diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 4841b28703ee..4c39e9b0c61f 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -2,7 +2,8 @@ import json import re -from typing import Dict, List, Sequence, Union +from collections.abc import Sequence +from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow @@ -33,9 +34,9 @@ def __init__(self, tokenizer: AnyTokenizer): self.model_tokenizer = self.model_tokenizer.tokenizer self.current_tool_name_sent: bool = False - self.prev_tool_call_arr: List[Dict] = [] + self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: List[str] = [ + self.streamed_args_for_tool: list[str] = [ ] # map what has been streamed for each tool so far to a list self.tool_call_start_token: str = "" diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index b9215e7979bf..57d7c77c64f7 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import Dict, Sequence, Union +from collections.abc import Sequence +from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow @@ -90,7 +91,7 @@ def extract_tool_calls_streaming( # tool calls are generated in an object in inernlm2 # it's not support parallel tool calls try: - tool_call_arr: Dict = partial_json_parser.loads( + tool_call_arr: dict = partial_json_parser.loads( parsable_arr, flags) except partial_json_parser.core.exceptions.MalformedJSON: logger.debug('not enough tokens to parse into JSON yet') diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 7c4d63e18865..8df106bf2718 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -2,7 +2,8 @@ import json import re -from typing import Dict, List, Sequence, Union +from collections.abc import Sequence +from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow @@ -35,9 +36,9 @@ def __init__(self, tokenizer: AnyTokenizer): ) self.current_tool_name_sent: bool = False - self.prev_tool_call_arr: List[Dict] = [] + self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: List[str] = [ + self.streamed_args_for_tool: list[str] = [ ] # map what has been streamed for each tool so far to a list self.tool_calls_start_token: str = "" @@ -157,7 +158,7 @@ def extract_tool_calls_streaming( # tool calls are generated in an array, so do partial JSON # parsing on the entire array try: - tool_call_arr: List[Dict] = partial_json_parser.loads( + tool_call_arr: list[dict] = partial_json_parser.loads( parsable_arr, flags) except partial_json_parser.core.exceptions.MalformedJSON: logger.debug('not enough tokens to parse into JSON yet') @@ -165,7 +166,7 @@ def extract_tool_calls_streaming( # select as the current tool call the one we're on the state at - current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ if len(tool_call_arr) > 0 else {} # case -- if no tokens have been streamed for the tool, e.g. diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 6a7b113623e6..20c3238fb3df 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -2,8 +2,9 @@ import json import re +from collections.abc import Sequence from json import JSONDecoder -from typing import Dict, List, Sequence, Union +from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow @@ -40,10 +41,10 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): # initialize properties used for state when parsing tool calls in # streaming mode - self.prev_tool_call_arr: List[Dict] = [] + self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: List[str] = [ + self.streamed_args_for_tool: list[str] = [ ] # map what has been streamed for each tool so far to a list self.bot_token = "<|python_tag|>" self.bot_token_id = tokenizer.encode(self.bot_token, @@ -78,7 +79,7 @@ def extract_tool_calls( start_idx += end_idx + len('; ') function_call_arr.append(obj) - tool_calls: List[ToolCall] = [ + tool_calls: list[ToolCall] = [ ToolCall( type="function", function=FunctionCall( @@ -152,7 +153,7 @@ def extract_tool_calls_streaming( return None # select as the current tool call the one we're on the state at - current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ if len(tool_call_arr) > 0 else {} # case -- if no tokens have been streamed for the tool, e.g. diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 4f0480882992..0661445639d7 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -2,9 +2,10 @@ import json import re +from collections.abc import Sequence from random import choices from string import ascii_letters, digits -from typing import Dict, List, Sequence, Union +from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow @@ -56,10 +57,10 @@ def __init__(self, tokenizer: AnyTokenizer): # initialize properties used for state when parsing tool calls in # streaming mode - self.prev_tool_call_arr: List[Dict] = [] + self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: List[str] = [ + self.streamed_args_for_tool: list[str] = [ ] # map what has been streamed for each tool so far to a list self.bot_token = "[TOOL_CALLS]" self.bot_token_id = self.vocab.get(self.bot_token) @@ -104,7 +105,7 @@ def extract_tool_calls( function_call_arr = json.loads(raw_tool_call) # Tool Call - tool_calls: List[MistralToolCall] = [ + tool_calls: list[MistralToolCall] = [ MistralToolCall( type="function", function=FunctionCall( @@ -172,7 +173,7 @@ def extract_tool_calls_streaming( # tool calls are generated in an array, so do partial JSON # parsing on the entire array try: - tool_call_arr: List[Dict] = partial_json_parser.loads( + tool_call_arr: list[dict] = partial_json_parser.loads( parsable_arr, flags) except partial_json_parser.core.exceptions.MalformedJSON: logger.debug('not enough tokens to parse into JSON yet') @@ -180,7 +181,7 @@ def extract_tool_calls_streaming( # select as the current tool call the one we're on the state at - current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ if len(tool_call_arr) > 0 else {} # case -- if no tokens have been streamed for the tool, e.g. diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 5c282b5c2605..1b9317f16f34 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -3,7 +3,8 @@ import ast import json import re -from typing import Any, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Union from transformers import PreTrainedTokenizerBase @@ -204,7 +205,7 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: arguments=json.dumps(arguments))) -def _make_valid_python(text: str) -> Union[Tuple[str, str], None]: +def _make_valid_python(text: str) -> Union[tuple[str, str], None]: bracket_stack = [] for index, char in enumerate(text): if char in {"[", "(", "{"}: diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py index 945cbd683502..7997629d461a 100644 --- a/vllm/entrypoints/openai/tool_parsers/utils.py +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -2,7 +2,7 @@ import json from json import JSONDecodeError, JSONDecoder -from typing import Any, List, Tuple +from typing import Any import partial_json_parser from partial_json_parser.core.options import Allow @@ -82,7 +82,7 @@ def extract_intermediate_diff(curr: str, old: str) -> str: return diff -def find_all_indices(string: str, substring: str) -> List[int]: +def find_all_indices(string: str, substring: str) -> list[int]: """ Find all (starting) indices of a substring in a given string. Useful for tool call extraction @@ -99,7 +99,7 @@ def find_all_indices(string: str, substring: str) -> List[int]: # partial_json_parser doesn't support extra data and # JSONDecorder.raw_decode doesn't support partial JSON -def partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: +def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]: try: return (partial_json_parser.loads(input_str, flags), len(input_str)) except JSONDecodeError as e: diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 6ec0b5fb024a..53411a27b41e 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Union +from typing import Union from torch.nn import CosineSimilarity @@ -10,12 +10,12 @@ def _cosine_similarity( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - embed_1: List[PoolingRequestOutput], - embed_2: List[PoolingRequestOutput], -) -> List[PoolingRequestOutput]: + embed_1: list[PoolingRequestOutput], + embed_2: list[PoolingRequestOutput], +) -> list[PoolingRequestOutput]: scorer = CosineSimilarity(0) - scores: Union[List[PoolingRequestOutput]] = [] + scores: Union[list[PoolingRequestOutput]] = [] for emb_1, emb_2 in zip(embed_1, embed_2): pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data) @@ -38,8 +38,8 @@ def _cosine_similarity( def _validate_score_input_lens( - texts_1: Union[List[str], List[dict]], - texts_2: Union[List[str], List[dict]], + texts_1: Union[list[str], list[dict]], + texts_2: Union[list[str], list[dict]], ): if len(texts_1) > 1 and len(texts_1) != len(texts_2): raise ValueError("Input lengths must be either 1:1, 1:N or N:N") diff --git a/vllm/envs.py b/vllm/envs.py index 048d63bfec0f..bf64cd70674d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -2,7 +2,7 @@ import os import tempfile -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional if TYPE_CHECKING: VLLM_HOST_IP: str = "" @@ -67,12 +67,12 @@ VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms - VLLM_PLUGINS: Optional[List[str]] = None + VLLM_PLUGINS: Optional[list[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False - VLLM_DISABLED_KERNELS: List[str] = [] + VLLM_DISABLED_KERNELS: list[str] = [] VLLM_USE_V1: bool = False VLLM_ROCM_FP8_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True @@ -123,7 +123,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # begin-env-vars-definition -environment_variables: Dict[str, Callable[[], Any]] = { +environment_variables: dict[str, Callable[[], Any]] = { # ================== Installation Time Env Vars ================== diff --git a/vllm/forward_context.py b/vllm/forward_context.py index b91816af1b6d..c3d20cff426c 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,7 +4,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional import torch import torch.distributed as dist @@ -28,13 +28,13 @@ @dataclass class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context - attn_layers: Dict[str, Any] + attn_layers: dict[str, Any] # TODO: extend to support per-layer dynamic forward context attn_metadata: "AttentionMetadata" # set dynamically for each forward pass # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass num_tokens_across_dp: Optional[ - List[int]] = None # set dynamically for each forward pass + list[int]] = None # set dynamically for each forward pass _forward_context: Optional[ForwardContext] = None diff --git a/vllm/logger.py b/vllm/logger.py index 0ee47de173ad..2b0b9da2d6f7 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -109,7 +109,7 @@ def _configure_vllm_root_logger() -> None: custom_config = json.loads(file.read()) if not isinstance(custom_config, dict): - raise ValueError("Invalid logging config. Expected Dict, got %s.", + raise ValueError("Invalid logging config. Expected dict, got %s.", type(custom_config).__name__) logging_config = custom_config diff --git a/vllm/logits_process.py b/vllm/logits_process.py index a810be7bc7a8..e3faf20029ec 100644 --- a/vllm/logits_process.py +++ b/vllm/logits_process.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Tuple, Union +from typing import Callable, Union import torch from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], - Callable[[List[int], List[int], torch.Tensor], +LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor], + Callable[[list[int], list[int], torch.Tensor], torch.Tensor]] """LogitsProcessor is a function that takes a list of previously generated tokens, the logits tensor @@ -17,9 +17,9 @@ def get_bad_words_logits_processors( - bad_words: List[str], - tokenizer: AnyTokenizer) -> List[LogitsProcessor]: - bad_words_ids: List[List[int]] = list() + bad_words: list[str], + tokenizer: AnyTokenizer) -> list[LogitsProcessor]: + bad_words_ids: list[list[int]] = list() for bad_word in bad_words: # To prohibit words both at the beginning @@ -51,13 +51,13 @@ class NoBadWordsLogitsProcessor: _SMALLEST_LOGIT = float("-inf") _NEUTRAL_LOGIT = 0.0 - def __init__(self, bad_words_ids: List[List[int]]): + def __init__(self, bad_words_ids: list[list[int]]): self.bad_words_ids = bad_words_ids self.word_bias: torch.FloatTensor = None def __call__( self, - past_tokens_ids: Union[List[int], Tuple[int]], + past_tokens_ids: Union[list[int], tuple[int]], logits: torch.FloatTensor, ) -> torch.Tensor: if self.word_bias is None: diff --git a/vllm/outputs.py b/vllm/outputs.py index 030119710a18..8c355c89e3e9 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import time +from collections.abc import MutableSequence +from collections.abc import Sequence as GenericSequence from dataclasses import dataclass -from typing import Dict, Generic, List, MutableSequence, Optional -from typing import Sequence as GenericSequence -from typing import Union +from typing import Generic, Optional, Union import torch from typing_extensions import TypeVar, deprecated @@ -109,14 +109,14 @@ def __init__( self, request_id: str, prompt: Optional[str], - prompt_token_ids: Optional[List[int]], + prompt_token_ids: Optional[list[int]], prompt_logprobs: Optional[PromptLogprobs], - outputs: List[CompletionOutput], + outputs: list[CompletionOutput], finished: bool, metrics: Optional[RequestMetrics] = None, lora_request: Optional[LoRARequest] = None, encoder_prompt: Optional[str] = None, - encoder_prompt_token_ids: Optional[List[int]] = None, + encoder_prompt_token_ids: Optional[list[int]] = None, num_cached_tokens: Optional[int] = None, *, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, @@ -139,9 +139,9 @@ def new( cls, request_id: str, prompt: Optional[str], - prompt_token_ids: Optional[List[int]], + prompt_token_ids: Optional[list[int]], text: str, - token_ids: List[int], + token_ids: list[int], logprobs: Optional[SampleLogprobs], prompt_logprobs: Optional[PromptLogprobs], cumulative_logprob: Optional[float], @@ -189,7 +189,7 @@ def add(self, next_output: "RequestOutput") -> None: @classmethod def from_seq_group( cls, seq_group: SequenceGroup, use_cache: bool, - seq_id_to_seq_group: Dict[str, SequenceGroupBase] + seq_id_to_seq_group: dict[str, SequenceGroupBase] ) -> Optional["RequestOutput"]: finished = seq_group.is_finished() @@ -363,12 +363,12 @@ class PoolingRequestOutput(Generic[_O]): Args: request_id (str): A unique identifier for the pooling request. outputs (PoolingOutput): The pooling results for the given input. - prompt_token_ids (List[int]): A list of token IDs used in the prompt. + prompt_token_ids (list[int]): A list of token IDs used in the prompt. finished (bool): A flag indicating whether the pooling is completed. """ def __init__(self, request_id: str, outputs: _O, - prompt_token_ids: List[int], finished: bool): + prompt_token_ids: list[int], finished: bool): self.request_id = request_id self.prompt_token_ids = prompt_token_ids self.finished = finished @@ -407,7 +407,7 @@ class RequestOutputFactory: @staticmethod def create(seq_group: SequenceGroup, - seq_id_to_seq_group: Dict[str, SequenceGroupBase], + seq_id_to_seq_group: dict[str, SequenceGroupBase], use_cache: bool = False): if seq_group.pooled_data is not None: return PoolingRequestOutput.from_seq_group(seq_group) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 2ce87283df75..17e4e43387dd 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -4,11 +4,10 @@ from dataclasses import dataclass from enum import Enum, IntEnum from functools import cached_property -from typing import Any, Dict, List, Optional, Set, Union +from typing import Annotated, Any, Optional, Union import msgspec from pydantic import BaseModel -from typing_extensions import Annotated from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor @@ -29,9 +28,9 @@ class SamplingType(IntEnum): @dataclass class GuidedDecodingParams: """One of these fields will be used to build a logit processor.""" - json: Optional[Union[str, Dict]] = None + json: Optional[Union[str, dict]] = None regex: Optional[str] = None - choice: Optional[List[str]] = None + choice: Optional[list[str]] = None grammar: Optional[str] = None json_object: Optional[bool] = None """These are other options that can be set""" @@ -40,9 +39,9 @@ class GuidedDecodingParams: @staticmethod def from_optional( - json: Optional[Union[Dict, BaseModel, str]] = None, + json: Optional[Union[dict, BaseModel, str]] = None, regex: Optional[str] = None, - choice: Optional[List[str]] = None, + choice: Optional[list[str]] = None, grammar: Optional[str] = None, json_object: Optional[bool] = None, backend: Optional[str] = None, @@ -72,7 +71,7 @@ def backend_name(self) -> str: """ return (self.backend or "").split(":")[0] - def backend_options(self) -> List[str]: + def backend_options(self) -> list[str]: """Return the backend options as a list of strings.""" if not self.backend or ":" not in self.backend: return [] @@ -144,12 +143,12 @@ class SamplingParams( considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this. seed: Random seed to use for the generation. - stop: List of strings that stop the generation when they are generated. + stop: list of strings that stop the generation when they are generated. The returned output will not contain the stop strings. - stop_token_ids: List of tokens that stop the generation when they are + stop_token_ids: list of tokens that stop the generation when they are generated. The returned output will contain the stop tokens unless the stop tokens are special tokens. - bad_words: List of words that are not allowed to be generated. + bad_words: list of words that are not allowed to be generated. More precisely, only the last token of a corresponding token sequence is not allowed when the next generated token can complete the sequence. @@ -172,7 +171,7 @@ class SamplingParams( skip_special_tokens: Whether to skip special tokens in the output. spaces_between_special_tokens: Whether to add spaces between special tokens in the output. Defaults to True. - logits_processors: List of functions that modify logits based on + logits_processors: list of functions that modify logits based on previously generated tokens, and optionally prompt tokens as a first argument. truncate_prompt_tokens: If set to an integer k, will use only the last k @@ -198,9 +197,9 @@ class SamplingParams( top_k: int = -1 min_p: float = 0.0 seed: Optional[int] = None - stop: Optional[Union[str, List[str]]] = None - stop_token_ids: Optional[List[int]] = None - bad_words: Optional[List[str]] = None + stop: Optional[Union[str, list[str]]] = None + stop_token_ids: Optional[list[int]] = None + bad_words: Optional[list[str]] = None ignore_eos: bool = False max_tokens: Optional[int] = 16 min_tokens: int = 0 @@ -212,8 +211,8 @@ class SamplingParams( detokenize: bool = True skip_special_tokens: bool = True spaces_between_special_tokens: bool = True - # Optional[List[LogitsProcessor]] type. We use Any here because - # Optional[List[LogitsProcessor]] type is not supported by msgspec. + # Optional[list[LogitsProcessor]] type. We use Any here because + # Optional[list[LogitsProcessor]] type is not supported by msgspec. logits_processors: Optional[Any] = None include_stop_str_in_output: bool = False truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None @@ -222,12 +221,12 @@ class SamplingParams( # The below fields are not supposed to be used as an input. # They are set in post_init. output_text_buffer_length: int = 0 - _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) + _all_stop_token_ids: set[int] = msgspec.field(default_factory=set) # Fields used to construct logits processors guided_decoding: Optional[GuidedDecodingParams] = None - logit_bias: Optional[Dict[int, float]] = None - allowed_token_ids: Optional[List[int]] = None + logit_bias: Optional[dict[int, float]] = None + allowed_token_ids: Optional[list[int]] = None @staticmethod def from_optional( @@ -241,9 +240,9 @@ def from_optional( top_k: int = -1, min_p: float = 0.0, seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stop_token_ids: Optional[List[int]] = None, - bad_words: Optional[List[str]] = None, + stop: Optional[Union[str, list[str]]] = None, + stop_token_ids: Optional[list[int]] = None, + bad_words: Optional[list[str]] = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, max_tokens: Optional[int] = 16, @@ -253,13 +252,13 @@ def from_optional( detokenize: bool = True, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, - logits_processors: Optional[List[LogitsProcessor]] = None, + logits_processors: Optional[list[LogitsProcessor]] = None, truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, guided_decoding: Optional[GuidedDecodingParams] = None, - logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None, - allowed_token_ids: Optional[List[int]] = None, + logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, + allowed_token_ids: Optional[list[int]] = None, ) -> "SamplingParams": if logit_bias is not None: # Convert token_id to integer @@ -435,7 +434,7 @@ def _verify_greedy_sampling(self) -> None: def update_from_generation_config( self, - generation_config: Dict[str, Any], + generation_config: dict[str, Any], model_eos_token_id: Optional[int] = None) -> None: """Update if there are non-default values from generation_config""" @@ -468,7 +467,7 @@ def sampling_type(self) -> SamplingType: return SamplingType.RANDOM @property - def all_stop_token_ids(self) -> Set[int]: + def all_stop_token_ids(self) -> set[int]: return self._all_stop_token_ids def clone(self) -> "SamplingParams": diff --git a/vllm/sequence.py b/vllm/sequence.py index c0425ba33c9a..6a7b1e62a604 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,11 +5,11 @@ from abc import ABC, abstractmethod from array import array from collections import defaultdict +from collections.abc import Mapping +from collections.abc import Sequence as GenericSequence from dataclasses import dataclass, field from functools import reduce -from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional -from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union +from typing import Any, Callable, Optional, Union import msgspec import torch @@ -50,9 +50,9 @@ class Logprob: # {token_id -> logprob} per each sequence group. None if the corresponding # sequence group doesn't require prompt logprob. -PromptLogprobs = List[Optional[Dict[int, Logprob]]] +PromptLogprobs = list[Optional[dict[int, Logprob]]] # {token_id -> logprob} for each sequence group. -SampleLogprobs = List[Dict[int, Logprob]] +SampleLogprobs = list[dict[int, Logprob]] class SequenceStatus(enum.IntEnum): @@ -129,7 +129,7 @@ class SequenceDataDelta( omit_defaults=True): # type: ignore[call-arg] """Delta SequenceData to send to workers per step.""" # A new token to be appended to existing SequenceData. - new_output_token_ids: List[int] + new_output_token_ids: list[int] # Overwriting existing `cumulative_logprob` new_cumulative_logprob: float # Overwriting existing `num_computed_tokens`. @@ -152,7 +152,7 @@ class SequenceData(msgspec.Struct, output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. """ - # NOTE: we cannot use Union[List, array] because msgspec cannot support + # NOTE: we cannot use Union[list, array] because msgspec cannot support # union of 2 list types. _prompt_token_ids: array _output_token_ids: array = msgspec.field( @@ -160,25 +160,25 @@ class SequenceData(msgspec.Struct, ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 - _prompt_token_ids_tuple: Tuple[int, + _prompt_token_ids_tuple: tuple[int, ...] = msgspec.field(default_factory=tuple) # The number of tokens that are computed (that run against the model). _num_computed_tokens: int = 0 # The number of tokens with prefix cache hit. _num_cached_tokens: int = 0 _stage: SequenceStage = SequenceStage.PREFILL - _cached_all_token_ids: List[int] = msgspec.field(default_factory=list) + _cached_all_token_ids: list[int] = msgspec.field(default_factory=list) # It is used to get delta input. It is reset when `get_delta_and_reset` # is called. - _new_appended_tokens: List[int] = msgspec.field(default_factory=list) + _new_appended_tokens: list[int] = msgspec.field(default_factory=list) # It is used to compute mrope_position_ids. _mrope_position_delta: Optional[int] = None @staticmethod def from_prompt_token_counts( - *token_counts: Tuple[int, int]) -> "SequenceData": + *token_counts: tuple[int, int]) -> "SequenceData": """ Construct a :class:`SequenceData` instance by concatenating prompt token sequences. @@ -220,14 +220,14 @@ def from_seqs( def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l" - self._prompt_token_ids_tuple: Tuple[int, ...] = tuple( + self._prompt_token_ids_tuple: tuple[int, ...] = tuple( self._prompt_token_ids) self._update_cached_all_tokens() def _update_cached_all_tokens(self): assert isinstance(self._prompt_token_ids, array) assert isinstance(self._output_token_ids, array) - self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + + self._cached_all_token_ids: list[int] = list(self._prompt_token_ids + self._output_token_ids) @property @@ -235,7 +235,7 @@ def cumulative_logprob(self) -> float: return self._cumulative_logprob @property - def prompt_token_ids(self) -> Tuple[int, ...]: + def prompt_token_ids(self) -> tuple[int, ...]: return self._prompt_token_ids_tuple @prompt_token_ids.setter @@ -252,7 +252,7 @@ def prompt_token_ids_array(self) -> array: return self._prompt_token_ids @property - def output_token_ids(self) -> Tuple[int, ...]: + def output_token_ids(self) -> tuple[int, ...]: return tuple(self._output_token_ids) @output_token_ids.setter @@ -295,12 +295,12 @@ def get_prompt_len(self) -> int: def get_output_len(self) -> int: return len(self._output_token_ids) - def get_token_ids(self) -> List[int]: + def get_token_ids(self) -> list[int]: return self._cached_all_token_ids def get_prefix_token_ids( self, num_tokens: int - ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: + ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]: """Get prefix tokens, and make the return value hashable""" prompt_length = self.get_prompt_len() if num_tokens > prompt_length: @@ -351,10 +351,10 @@ def get_last_token_id(self) -> int: return self._prompt_token_ids[-1] return self._output_token_ids[-1] - def get_prompt_token_ids(self) -> Tuple[int, ...]: + def get_prompt_token_ids(self) -> tuple[int, ...]: return self.prompt_token_ids - def get_output_token_ids(self) -> Tuple[int, ...]: + def get_output_token_ids(self) -> tuple[int, ...]: return self.output_token_ids def get_delta_and_reset(self) -> SequenceDataDelta: @@ -432,7 +432,7 @@ def __init__( self.prefix_offset = 0 self.read_offset = 0 # Input + output tokens - self.tokens: Optional[List[str]] = None + self.tokens: Optional[list[str]] = None @property def n_blocks(self) -> int: @@ -443,7 +443,7 @@ def prompt(self) -> Optional[str]: return self.inputs.prompt @property - def prompt_token_ids(self) -> List[int]: + def prompt_token_ids(self) -> list[int]: return self.inputs.prompt_token_ids @property @@ -451,7 +451,7 @@ def prompt_embeds(self) -> Optional[torch.Tensor]: return self.inputs.prompt_embeds @property - def token_type_ids(self) -> List[int]: + def token_type_ids(self) -> list[int]: return self.inputs.token_type_ids @property @@ -463,7 +463,7 @@ def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: return self.inputs.multi_modal_placeholders @property - def mm_processor_kwargs(self) -> Dict[str, Any]: + def mm_processor_kwargs(self) -> dict[str, Any]: return self.inputs.mm_processor_kwargs @property @@ -548,7 +548,7 @@ def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - def append_token_id(self, token_id: int, logprobs: Dict[int, + def append_token_id(self, token_id: int, logprobs: dict[int, Logprob]) -> None: assert token_id in logprobs self.output_logprobs.append(logprobs) @@ -563,16 +563,16 @@ def get_prompt_len(self) -> int: def get_output_len(self) -> int: return self.data.get_output_len() - def get_token_ids(self) -> List[int]: + def get_token_ids(self) -> list[int]: return self.data.get_token_ids() - def get_prompt_token_ids(self) -> Tuple[int, ...]: + def get_prompt_token_ids(self) -> tuple[int, ...]: return self.data.get_prompt_token_ids() def get_last_token_id(self) -> int: return self.data.get_last_token_id() - def get_output_token_ids(self) -> Tuple[int, ...]: + def get_output_token_ids(self) -> tuple[int, ...]: return self.data.get_output_token_ids() def get_cumulative_logprob(self) -> float: @@ -644,7 +644,7 @@ class SequenceGroup: def __init__( self, request_id: str, - seqs: List[Sequence], + seqs: list[Sequence], arrival_time: float, sampling_params: Optional[SamplingParams] = None, lora_request: Optional[LoRARequest] = None, @@ -686,7 +686,7 @@ def prompt(self) -> Optional[str]: return self.first_seq.prompt @property - def prompt_token_ids(self) -> List[int]: + def prompt_token_ids(self) -> list[int]: return self.first_seq.prompt_token_ids @property @@ -698,7 +698,7 @@ def encoder_prompt(self) -> Optional[str]: if self.encoder_seq is not None else None) @property - def encoder_prompt_token_ids(self) -> Optional[List[int]]: + def encoder_prompt_token_ids(self) -> Optional[list[int]]: # There are either 0 or 1 encoder sequences # If one is present, its prompt token ids are # distinct from the decoder's. @@ -706,7 +706,7 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]: if self.encoder_seq is not None else None) @property - def token_type_ids(self) -> Optional[List[int]]: + def token_type_ids(self) -> Optional[list[int]]: return self.first_seq.token_type_ids @property @@ -726,7 +726,7 @@ def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: return {} @property - def mm_processor_kwargs(self) -> Dict[str, Any]: + def mm_processor_kwargs(self) -> dict[str, Any]: if self.first_seq.multi_modal_data: return self.first_seq.mm_processor_kwargs elif self.encoder_seq is not None: @@ -823,7 +823,7 @@ def get_max_num_running_seqs(self) -> int: def get_seqs( self, status: Optional[SequenceStatus] = None, - ) -> List[Sequence]: + ) -> list[Sequence]: if status is None: return self.seqs @@ -838,7 +838,7 @@ def is_encoder_decoder(self) -> bool: def get_encoder_seq(self) -> Optional[Sequence]: return self.encoder_seq - def get_finished_seqs(self) -> List[Sequence]: + def get_finished_seqs(self) -> list[Sequence]: if self.is_single_seq: return self.seqs if self.first_seq.is_finished() else [] @@ -897,13 +897,13 @@ class SequenceGroupMetadataDelta( After sending the first SequenceGroupMetadata, vLLM scheduler only sends delta to reduce the data payload size. """ - seq_data_delta: Dict[int, SequenceDataDelta] + seq_data_delta: dict[int, SequenceDataDelta] request_id: str - block_tables: Dict[int, List[int]] + block_tables: dict[int, list[int]] is_prompt: bool do_sample: bool = True token_chunk_size: Optional[int] = None - computed_block_nums: Optional[List[int]] = None + computed_block_nums: Optional[list[int]] = None state: Optional[SequenceGroupState] = msgspec.field( default_factory=lambda: SequenceGroupState()) @@ -947,23 +947,23 @@ class SequenceGroupMetadata( request_id: str is_prompt: bool - seq_data: Dict[int, SequenceData] + seq_data: dict[int, SequenceData] sampling_params: Optional[SamplingParams] - block_tables: Dict[int, List[int]] + block_tables: dict[int, list[int]] do_sample: bool = True pooling_params: Optional[PoolingParams] = None lora_request: Optional[LoRARequest] = None - computed_block_nums: Optional[List[int]] = None + computed_block_nums: Optional[list[int]] = None state: Optional[SequenceGroupState] = msgspec.field( default_factory=lambda: SequenceGroupState()) # "MultiModalDataDict" types. We have to use Any due to msgspec # doesn't allow to have union of 2 different dicts. - token_type_ids: Optional[List[int]] = None + token_type_ids: Optional[list[int]] = None multi_modal_data: Optional[Any] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - mm_processor_kwargs: Optional[Dict[str, Any]] = None + mm_processor_kwargs: Optional[dict[str, Any]] = None encoder_seq_data: Optional[SequenceData] = None - cross_block_table: Optional[List[int]] = None + cross_block_table: Optional[list[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None token_chunk_size: Optional[int] = None @@ -1042,7 +1042,7 @@ class SequenceOutput( """ parent_seq_id: int output_token: int - logprobs: Dict[int, Logprob] + logprobs: dict[int, Logprob] def __repr__(self) -> str: return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " @@ -1076,7 +1076,7 @@ class CompletionSequenceGroupOutput( array_like=True): # type: ignore[call-arg] """The model output associated with a completion sequence group.""" __metaclass__ = SequenceGroupOutput - samples: List[SequenceOutput] + samples: list[SequenceOutput] # Prompt logprob for each prompt query token. prompt_logprobs: Optional[PromptLogprobs] @@ -1119,7 +1119,7 @@ class IntermediateTensors: contains the hidden states and residuals for a request. """ - tensors: Dict[str, torch.Tensor] + tensors: dict[str, torch.Tensor] def __init__(self, tensors): # manually define this function, so that @@ -1155,7 +1155,7 @@ class PoolerOutput( omit_defaults=True, # type: ignore[call-arg] array_like=True): # type: ignore[call-arg] """The output from a pooling operation in the pooling model.""" - outputs: List[PoolingSequenceGroupOutput] + outputs: list[PoolingSequenceGroupOutput] def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput: return self.outputs[idx] @@ -1172,7 +1172,7 @@ def __eq__(self, other: object): def get_all_seq_ids( - seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: + seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]: """Given a list of SequenceGroupMetadata, create a list of all sequence ids. """ @@ -1180,13 +1180,13 @@ def get_all_seq_ids( def get_all_seq_ids_and_request_ids( - seq_group_metadata_list: List[SequenceGroupMetadata] -) -> Tuple[List[int], Dict[str, Set[int]]]: + seq_group_metadata_list: list[SequenceGroupMetadata] +) -> tuple[list[int], dict[str, set[int]]]: """Given a list of SequenceGroupMetadata, create a list of all sequence ids. """ - seq_ids: List[int] = [] - request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set) + seq_ids: list[int] = [] + request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set) for sg in seq_group_metadata_list: for seq_id in sg.seq_data: seq_ids.append(seq_id) @@ -1206,14 +1206,14 @@ class HiddenStates(msgspec.Struct, array_like=True, # all tokens, whereas for decode step, it use used for last accepted tokens. hidden_states: torch.Tensor # The sequence group metadata list. Only needed for decode step. - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None # Scorer hidden states of the 2nd last token proposed by the proposer ( # irrespective of whether it was accepted or not). Only used for cases when # last proposed token is accepted (i.e., in case of bonus tokens). For the # case of no bonus tokens, these are ignored. second_last_token_hidden_states: Optional[torch.Tensor] = None - _seq_ids: List[int] = msgspec.field(default_factory=list) + _seq_ids: list[int] = msgspec.field(default_factory=list) def __post_init__(self): if self.seq_group_metadata_list is not None: @@ -1221,12 +1221,12 @@ def __post_init__(self): self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) @property - def seq_ids(self) -> List[int]: + def seq_ids(self) -> list[int]: return self._seq_ids def update(self, hidden_states: torch.Tensor, - seq_group_metadata_list: List[SequenceGroupMetadata], + seq_group_metadata_list: list[SequenceGroupMetadata], second_last_token_hidden_states: Optional[torch.Tensor] = None): """Update hidden states from target model invocation. Only used for decode steps""" @@ -1244,7 +1244,7 @@ def update(self, ]) def prune(self, - seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: + seq_group_metadata_list: list[SequenceGroupMetadata]) -> None: """Prune to provided list of sequence ids. Only used for decode steps. """ # Currently this prunes all seq_ids not present in @@ -1287,16 +1287,16 @@ class ExecuteModelRequest( """The model execution request, containing CPU metadata only. The LLM engine should create an instance of this class for each request batch.""" # The sequence group metadata list. - seq_group_metadata_list: List[Union[SequenceGroupMetadata, + seq_group_metadata_list: list[Union[SequenceGroupMetadata, SequenceGroupMetadataDelta]] # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, + blocks_to_swap_in: list[tuple[int, int]] = msgspec.field(default_factory=list) # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: List[Tuple[int, + blocks_to_swap_out: list[tuple[int, int]] = msgspec.field(default_factory=list) # Blocks to copy. Source to dest block. - blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list) + blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list) # Virtual engine ID for pipeline parallel. virtual_engine: int = 0 # The number of slots for lookahead decoding. @@ -1310,7 +1310,7 @@ class ExecuteModelRequest( # The step index for spec model input. spec_step_idx: Optional[int] = None # Finished request ids since last step. - finished_requests_ids: List[str] = msgspec.field(default_factory=list) + finished_requests_ids: list[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. last_sampled_token_ids: Optional[torch.Tensor] = None # Async callback @@ -1344,7 +1344,7 @@ def current_step(self) -> int: return state.current_step def clone( - self, seq_group_metadata_list: List[Union[SequenceGroupMetadata, + self, seq_group_metadata_list: list[Union[SequenceGroupMetadata, SequenceGroupMetadataDelta]] ) -> "ExecuteModelRequest": """Clone the request with a new sequence group metadata list.""" @@ -1371,13 +1371,13 @@ class SequenceGroupBase: assembled_seq_group: Optional[SequenceGroup] = None # seq id to a unique index inside this group - seq_id_to_index: Dict[str, int] = field(default_factory=dict) + seq_id_to_index: dict[str, int] = field(default_factory=dict) # seq ids to be finished - to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict) + to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict) # seq id to finished sequences - finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict) + finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict) streaming: bool = False diff --git a/vllm/tracing.py b/vllm/tracing.py index bf069ad84fd4..557ae40b87ae 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import Mapping, Optional +from collections.abc import Mapping +from typing import Optional from vllm.logger import init_logger from vllm.utils import run_once diff --git a/vllm/utils.py b/vllm/utils.py index 29e60a9c9be2..26c9e1a90837 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -28,12 +28,12 @@ import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import OrderedDict, UserDict, defaultdict -from collections.abc import Hashable, Iterable, Mapping +from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, + Iterable, Iterator, Mapping) from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps -from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, - Dict, Generator, Generic, Iterator, List, Literal, - NamedTuple, Optional, Tuple, Type, TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, + Optional, TypeVar, Union) from uuid import uuid4 import cloudpickle @@ -400,7 +400,7 @@ def _next_task(iterator: AsyncGenerator[T, None], async def merge_async_iterators( *iterators: AsyncGenerator[T, - None], ) -> AsyncGenerator[Tuple[int, T], None]: + None], ) -> AsyncGenerator[tuple[int, T], None]: """Merge multiple asynchronous iterators into a single iterator. This method handle the case where some iterators finish before others. @@ -433,7 +433,7 @@ async def merge_async_iterators( async def collect_from_async_generator( - iterator: AsyncGenerator[T, None]) -> List[T]: + iterator: AsyncGenerator[T, None]) -> list[T]: """Collect all items from an async generator into a list.""" items = [] async for item in iterator: @@ -560,7 +560,7 @@ def find_process_using_port(port: int) -> Optional[psutil.Process]: return None -def update_environment_variables(envs: Dict[str, str]): +def update_environment_variables(envs: dict[str, str]): for k, v in envs.items(): if k in os.environ and os.environ[k] != v: logger.warning( @@ -569,7 +569,7 @@ def update_environment_variables(envs: Dict[str, str]): os.environ[k] = v -def chunk_list(lst: List[T], chunk_size: int): +def chunk_list(lst: list[T], chunk_size: int): """Yield successive chunk_size chunks from lst.""" for i in range(0, len(lst), chunk_size): yield lst[i:i + chunk_size] @@ -642,7 +642,7 @@ def create_kv_caches_with_random_flash( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: int = 0, device: Optional[str] = "cuda", -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: from vllm.platforms import current_platform current_platform.seed_everything(seed) @@ -650,8 +650,8 @@ def create_kv_caches_with_random_flash( key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) scale = head_size**-0.5 - key_caches: List[torch.Tensor] = [] - value_caches: List[torch.Tensor] = [] + key_caches: list[torch.Tensor] = [] + value_caches: list[torch.Tensor] = [] for _ in range(num_layers): key_value_cache = torch.empty(size=key_value_cache_shape, @@ -679,7 +679,7 @@ def create_kv_caches_with_random( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: int = 0, device: Optional[str] = "cuda", -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: if cache_dtype == "fp8" and head_size % 16: raise ValueError( @@ -693,7 +693,7 @@ def create_kv_caches_with_random( scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) - key_caches: List[torch.Tensor] = [] + key_caches: list[torch.Tensor] = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, @@ -708,7 +708,7 @@ def create_kv_caches_with_random( key_caches.append(key_cache) value_cache_shape = (num_blocks, num_heads, head_size, block_size) - value_caches: List[torch.Tensor] = [] + value_caches: list[torch.Tensor] = [] for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, dtype=torch_dtype, @@ -754,7 +754,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def make_ndarray_with_pad( - x: List[List[T]], + x: list[list[T]], pad: T, dtype: npt.DTypeLike, *, @@ -779,7 +779,7 @@ def make_ndarray_with_pad( def make_tensor_with_pad( - x: List[List[T]], + x: list[list[T]], pad: T, dtype: torch.dtype, *, @@ -831,7 +831,7 @@ def is_list_of( typ: Union[type[T], tuple[type[T], ...]], *, check: Literal["first", "all"] = "first", -) -> TypeIs[List[T]]: +) -> TypeIs[list[T]]: if not isinstance(value, list): return False @@ -843,8 +843,8 @@ def is_list_of( assert_never(check) -JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"], - Tuple["JSONTree[T]", ...], T] +JSONTree = Union[dict[str, "JSONTree[T]"], list["JSONTree[T]"], + tuple["JSONTree[T]", ...], T] """A nested JSON structure where the leaves need not be JSON-serializable.""" @@ -859,7 +859,7 @@ def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]: return func(value) -def flatten_2d_lists(lists: List[List[T]]) -> List[T]: +def flatten_2d_lists(lists: list[list[T]]) -> list[T]: """Flatten a list of lists to a single list.""" return [item for sublist in lists for item in sublist] @@ -1226,7 +1226,7 @@ def check_port(self, value): return value - def _pull_args_from_config(self, args: List[str]) -> List[str]: + def _pull_args_from_config(self, args: list[str]) -> list[str]: """Method to pull arguments specified in the config file into the command-line args variable. @@ -1291,7 +1291,7 @@ def _pull_args_from_config(self, args: List[str]) -> List[str]: return args - def _load_config_file(self, file_path: str) -> List[str]: + def _load_config_file(self, file_path: str) -> list[str]: """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml @@ -1313,9 +1313,9 @@ def _load_config_file(self, file_path: str) -> List[str]: %s supplied", extension) # only expecting a flat dictionary of atomic types - processed_args: List[str] = [] + processed_args: list[str] = [] - config: Dict[str, Union[int, str]] = {} + config: dict[str, Union[int, str]] = {} try: with open(file_path) as config_file: config = yaml.safe_load(config_file) @@ -1399,7 +1399,7 @@ def resolve_mm_processor_kwargs( *, requires_kw_only: bool = True, allow_var_kwargs: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Applies filtering to eliminate invalid mm_processor_kwargs, i.e., those who are not explicit keywords to the given callable (of one is given; otherwise no filtering is done), then merges the kwarg dicts, @@ -1440,7 +1440,7 @@ def get_allowed_kwarg_only_overrides( *, requires_kw_only: bool = True, allow_var_kwargs: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Given a callable which has one or more keyword only params and a dict mapping param names to values, drop values that can be not be kwarg @@ -1531,9 +1531,9 @@ def value(self): # Adapted from: https://stackoverflow.com/a/47212782/5082708 class LazyDict(Mapping[str, T], Generic[T]): - def __init__(self, factory: Dict[str, Callable[[], T]]): + def __init__(self, factory: dict[str, Callable[[], T]]): self._factory = factory - self._dict: Dict[str, T] = {} + self._dict: dict[str, T] = {} def __getitem__(self, key: str) -> T: if key not in self._dict: @@ -1552,9 +1552,9 @@ def __len__(self): return len(self._factory) -class ClassRegistry(UserDict[Type[T], _V]): +class ClassRegistry(UserDict[type[T], _V]): - def __getitem__(self, key: Type[T]) -> _V: + def __getitem__(self, key: type[T]) -> _V: for cls in key.mro(): if cls in self.data: return self.data[cls] @@ -1584,8 +1584,8 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: def weak_ref_tensors( - tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] -) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]: + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] +) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]: """ Convenience function to create weak references to tensors, for single tensor, list of tensors or tuple of tensors. @@ -1857,7 +1857,7 @@ def __getattr__(self, key: str): def direct_register_custom_op( op_name: str, op_func: Callable, - mutates_args: List[str], + mutates_args: list[str], fake_impl: Optional[Callable] = None, target_lib: Optional[Library] = None, dispatch_key: str = "CUDA", @@ -2177,8 +2177,8 @@ def get_mp_context(): def bind_kv_cache( - ctx: Dict[str, Any], - kv_cache: List[List[torch.Tensor]], # [virtual_engine][layer_index] + ctx: dict[str, Any], + kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] ) -> None: # Bind the kv_cache tensor to Attention modules, similar to # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] @@ -2210,8 +2210,8 @@ def bind_kv_cache( forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] -def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any], - kwargs: Dict[str, Any]) -> Any: +def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], + kwargs: dict[str, Any]) -> Any: """ Run a method of an object with the given arguments and keyword arguments. If the method is string, it will be converted to a method using getattr. @@ -2263,7 +2263,7 @@ def import_pynvml(): return pynvml -def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]: +def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: """ A replacement for `abc.ABC`. When we use `abc.ABC`, subclasses will fail to instantiate diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 353bf46d503e..8bf7f3587bc0 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Optional import numpy as np import torch @@ -30,7 +30,7 @@ class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @staticmethod - def get_supported_head_sizes() -> List[int]: + def get_supported_head_sizes() -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] @staticmethod @@ -38,15 +38,15 @@ def get_name() -> str: return "FLASH_ATTN_VLLM_V1" @staticmethod - def get_impl_cls() -> Type["FlashAttentionImpl"]: + def get_impl_cls() -> type["FlashAttentionImpl"]: return FlashAttentionImpl @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: + def get_metadata_cls() -> type["AttentionMetadata"]: return FlashAttentionMetadata @staticmethod - def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: return FlashAttentionMetadataBuilder @staticmethod @@ -55,7 +55,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) @@ -158,10 +158,10 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[List[float]], + alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, + blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, ) -> None: @@ -381,7 +381,7 @@ def cascade_attention( max_kv_len: int, softmax_scale: float, alibi_slopes: Optional[torch.Tensor], - sliding_window: Tuple[int, int], + sliding_window: tuple[int, int], logits_soft_cap: float, block_table: torch.Tensor, common_prefix_len: int, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 30bce5cc8b68..824ffcfd61ba 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -195,8 +195,7 @@ import functools from abc import abstractmethod from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, - Type, TypeVar) +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar import torch from compressed_tensors.quantization import QuantizationStrategy @@ -250,11 +249,11 @@ def get_name() -> str: return "TRITON_MLA_VLLM_V1" @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: + def get_metadata_cls() -> type["AttentionMetadata"]: return MLACommonMetadata @staticmethod - def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: + def get_builder_cls() -> type["MLACommonMetadataBuilder"]: return MLACommonMetadataBuilder @staticmethod @@ -263,11 +262,11 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, # assumed to be 1 for MLA head_size: int, - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) @staticmethod - def get_supported_head_sizes() -> List[int]: + def get_supported_head_sizes() -> list[int]: return [576] @staticmethod @@ -317,8 +316,8 @@ class MLACommonMetadata: has_context: bool = False context_chunk_cu_seq_lens: Optional[torch.Tensor] = None context_chunk_starts: Optional[torch.Tensor] = None - context_chunk_seq_tot: Optional[List[int]] = None - context_chunk_max_seq_lens: Optional[List[int]] = None + context_chunk_seq_tot: Optional[list[int]] = None + context_chunk_max_seq_lens: Optional[list[int]] = None chunked_prefill_workspace: Optional[torch.Tensor] = None def __post_init__(self): @@ -538,10 +537,10 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[List[float]], + alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], + blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, # MLA Specific Arguments @@ -634,7 +633,7 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): # # returns input_group_shape, weight_group_shape def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ - Tuple[Tuple[int, int], Tuple[int, int]]: + tuple[tuple[int, int], tuple[int, int]]: if isinstance(layer.quant_method, Fp8LinearMethod): if layer.quant_method.block_quant: weight_block_size = \ diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 8a7b7b974e36..b357d7142410 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Optional import torch @@ -25,21 +25,21 @@ def get_name() -> str: return "FLASHMLA_VLLM_V1" @staticmethod - def get_metadata_cls() -> Type["FlashMLAMetadata"]: + def get_metadata_cls() -> type["FlashMLAMetadata"]: return FlashMLAMetadata @staticmethod - def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: + def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: return FlashMLAMetadataBuilder @staticmethod - def get_impl_cls() -> Type["FlashMLAImpl"]: + def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl @dataclass class FlashMLAMetadata(MLACommonMetadata): - decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, + decode_tile_scheduler_metadata: Optional[tuple[torch.Tensor, torch.Tensor]] = None decode_num_splits: Optional[torch.Tensor] = None @@ -76,10 +76,10 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[List[float]], + alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], + blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, # MLA Specific Arguments diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 7747509f1a4b..3f9b349a5f04 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional import torch @@ -21,7 +21,7 @@ def get_name() -> str: return "TRITON_MLA_VLLM_V1" @staticmethod - def get_impl_cls() -> Type["TritonMLAImpl"]: + def get_impl_cls() -> type["TritonMLAImpl"]: return TritonMLAImpl @@ -33,10 +33,10 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[List[float]], + alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], + blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, # MLA Specific Arguments diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index a9f7b3fd4471..bf4a05daf2d5 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Optional import torch # Required to register custom ops. @@ -22,15 +22,15 @@ def get_name() -> str: return "PALLAS_VLLM_V1" @staticmethod - def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + def get_impl_cls() -> type["PallasAttentionBackendImpl"]: return PallasAttentionBackendImpl @staticmethod - def get_metadata_cls() -> Type["PallasMetadata"]: + def get_metadata_cls() -> type["PallasMetadata"]: return PallasMetadata @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: + def get_state_cls() -> type["CommonAttentionState"]: return CommonAttentionState @staticmethod @@ -39,7 +39,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: return (num_kv_heads, num_blocks, block_size, head_size) @staticmethod @@ -77,10 +77,10 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[List[float]], + alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, + blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, ) -> None: @@ -120,7 +120,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], + kv_cache: tuple[torch.Tensor, torch.Tensor], attn_metadata: PallasMetadata, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 5c7d759b1812..a625d99f4a15 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Attention layer with PagedAttention on rocm""" -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Optional import torch @@ -20,7 +20,7 @@ class ROCmAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @staticmethod - def get_supported_head_sizes() -> List[int]: + def get_supported_head_sizes() -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] @staticmethod @@ -28,11 +28,11 @@ def get_name() -> str: return "ROCM_ATTN_VLLM_V1" @staticmethod - def get_impl_cls() -> Type["ROCmAttentionImpl"]: + def get_impl_cls() -> type["ROCmAttentionImpl"]: return ROCmAttentionImpl @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: + def get_metadata_cls() -> type["AttentionMetadata"]: return FlashAttentionMetadata @staticmethod @@ -41,7 +41,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) @@ -51,7 +51,7 @@ def use_cascade_attention(*args, **kwargs) -> bool: return False @staticmethod - def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: return FlashAttentionMetadataBuilder @@ -63,10 +63,10 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[List[float]], + alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, + blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, ) -> None: diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 1b5c7f96f668..394b47fddf0c 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import defaultdict -from typing import Dict, Iterable, List, Optional +from collections.abc import Iterable +from typing import Optional from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, @@ -29,7 +30,7 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool): self.num_gpu_blocks = num_gpu_blocks self.enable_caching = enable_caching # All kv-cache blocks. - self.blocks: List[KVCacheBlock] = [ + self.blocks: list[KVCacheBlock] = [ KVCacheBlock(idx) for idx in range(num_gpu_blocks) ] # Free block queue that constructs and manipulates a doubly linked @@ -46,7 +47,7 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool): # if there is already an identical block in the cache. This is because # we want to make sure the allocated block IDs won't change so that # block tables are append-only. - self.cached_block_hash_to_block: Dict[BlockHashType, Dict[ + self.cached_block_hash_to_block: dict[BlockHashType, dict[ int, KVCacheBlock]] = defaultdict(dict) def get_cached_block(self, @@ -69,8 +70,8 @@ def get_cached_block(self, def cache_full_blocks( self, request: Request, - blocks: List[KVCacheBlock], - block_hashes: List[BlockHashType], + blocks: list[KVCacheBlock], + block_hashes: list[BlockHashType], num_cached_blocks: int, num_full_blocks: int, block_size: int, @@ -146,7 +147,7 @@ def cache_full_blocks( self.cached_block_hash_to_block[block_hash][blk.block_id] = blk prev_block_hash_value = block_hash.hash_value - def get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: + def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: """Get new blocks from the free block pool. Note that we do not check block cache in this function. @@ -161,7 +162,7 @@ def get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: raise ValueError( f"Cannot get {num_blocks} free blocks from the pool") - ret: List[KVCacheBlock] = [] + ret: list[KVCacheBlock] = [] idx = 0 while idx < num_blocks: # First allocate blocks. @@ -200,7 +201,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: return True return False - def touch(self, blocks: List[KVCacheBlock]) -> None: + def touch(self, blocks: list[KVCacheBlock]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 13ad14e45b32..018379c1f43a 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Dict, List, Set, Tuple +from typing import TYPE_CHECKING from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY @@ -18,9 +18,9 @@ def __init__(self, cache_size: int): self.cache_size = cache_size self.num_free_slots = cache_size # req_id -> cached input ids - self.cached: Dict[str, Set[int]] = {} - # List of [req_id, input_id] - self.freed: List[Tuple[str, int]] = [] + self.cached: dict[str, set[int]] = {} + # list of [req_id, input_id] + self.freed: list[tuple[str, int]] = [] def has_cache(self, request: Request, input_id: int) -> bool: req_id = request.request_id @@ -37,7 +37,7 @@ def allocate(self, request: Request, input_id: int) -> None: self.cached[req_id].add(input_id) self.num_free_slots -= request.get_num_encoder_tokens(input_id) - def get_cached_input_ids(self, request: Request) -> Set[int]: + def get_cached_input_ids(self, request: Request) -> set[int]: return self.cached.get(request.request_id, set()) def free_encoder_input(self, request: Request, input_id: int) -> None: @@ -58,7 +58,7 @@ def free(self, request: Request) -> None: for input_id in input_ids: self.free_encoder_input(request, input_id) - def get_freed_ids(self) -> List[Tuple[str, int]]: + def get_freed_ids(self) -> list[tuple[str, int]]: freed = self.freed self.freed = [] return freed @@ -67,7 +67,7 @@ def get_freed_ids(self) -> List[Tuple[str, int]]: def compute_encoder_budget( model_config: "ModelConfig", scheduler_config: "SchedulerConfig", -) -> Tuple[int, int]: +) -> tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations. @@ -97,7 +97,7 @@ def compute_encoder_budget( def _compute_encoder_budget_multimodal( model_config: "ModelConfig", scheduler_config: "SchedulerConfig", -) -> Tuple[int, int]: +) -> tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 030574de2bde..6c6be01a2ff7 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 from collections import defaultdict -from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple +from collections.abc import Iterable +from typing import Optional from vllm.logger import init_logger from vllm.utils import cdiv @@ -52,20 +53,20 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: DefaultDict[str, - List[KVCacheBlock]] = defaultdict(list) + self.req_to_blocks: defaultdict[str, + list[KVCacheBlock]] = defaultdict(list) # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. - self.req_to_block_hashes: DefaultDict[ - str, List[BlockHashType]] = defaultdict(list) + self.req_to_block_hashes: defaultdict[ + str, list[BlockHashType]] = defaultdict(list) # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. # This is only used to track the RUNNING requests, we do not track the # data for reempted ones. - self.num_cached_block: Dict[str, int] = {} + self.num_cached_block: dict[str, int] = {} self.prefix_cache_stats = PrefixCacheStats() @property @@ -88,7 +89,7 @@ def make_prefix_cache_stats(self) -> PrefixCacheStats: return stats def get_computed_blocks( - self, request: Request) -> Tuple[List[KVCacheBlock], int]: + self, request: Request) -> tuple[list[KVCacheBlock], int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -136,8 +137,8 @@ def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[List[KVCacheBlock]] = None - ) -> Optional[List[KVCacheBlock]]: + new_computed_blocks: Optional[list[KVCacheBlock]] = None + ) -> Optional[list[KVCacheBlock]]: """Add slots for a request with new tokens to append. Args: diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 546fddf67f41..adadcab5ea10 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -3,7 +3,7 @@ from collections import deque from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, List, NamedTuple, Optional, Tuple +from typing import Any, NamedTuple, Optional from vllm.config import VllmConfig from vllm.logger import init_logger @@ -25,7 +25,7 @@ class BlockHashType(NamedTuple): # Hash value of the block in an integer. hash_value: int # Token IDs in the block. - token_ids: Tuple[int, ...] + token_ids: tuple[int, ...] # Extra keys for the block. extra_keys: Optional[Any] = None @@ -45,7 +45,7 @@ def __init__(self, interval: int = 1000): self.aggregated_query_total = 0 self.aggregated_query_hit = 0 # A deque of (requests, queries, hits) for the most recent requests. - self.query_queue: deque[Tuple[int, int, int]] = deque() + self.query_queue: deque[tuple[int, int, int]] = deque() def observe(self, stats: PrefixCacheStats): """Observe the prefix caching for a set of requests. @@ -164,7 +164,7 @@ class FreeKVCacheBlockQueue: blocks: A list of KVCacheBlock objects. """ - def __init__(self, blocks: List[KVCacheBlock]) -> None: + def __init__(self, blocks: list[KVCacheBlock]) -> None: self.num_free_blocks = len(blocks) # Initialize the doubly linked list of free blocks. @@ -233,7 +233,7 @@ def append(self, block: KVCacheBlock) -> None: block.next_free_block = None self.num_free_blocks += 1 - def get_all_free_blocks(self) -> List[KVCacheBlock]: + def get_all_free_blocks(self) -> list[KVCacheBlock]: """Get all free blocks in the free list. Mainly used for testing. Returns: @@ -264,7 +264,7 @@ def need_extra_keys(request: Request) -> bool: def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> Tuple[List[Any], int]: + start_mm_idx: int) -> tuple[list[Any], int]: """Generate extra keys related to MultiModal request for block hash computation. For multi-modal inputs, the extra keys are (mm_hash, start_offset) that indicate a mm input contained in the @@ -279,7 +279,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, Returns: A tuple of extra keys and the next multi-modal index. """ - extra_keys: List[Any] = [] + extra_keys: list[Any] = [] mm_positions, mm_hashes = request.mm_positions, request.mm_hashes if not mm_positions: @@ -331,7 +331,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, return extra_keys, curr_mm_idx -def _gen_lora_extra_hash_keys(request: Request) -> List[int]: +def _gen_lora_extra_hash_keys(request: Request) -> list[int]: """Generate extra keys related to LoRA for block hash computation. Args: @@ -348,7 +348,7 @@ def _gen_lora_extra_hash_keys(request: Request) -> List[int]: def generate_block_hash_extra_keys( request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: + start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: """Generate extra keys for the block hash. The extra keys can come from the multi-modal inputs and request specific metadata (e.g., LoRA ID). @@ -361,12 +361,12 @@ def generate_block_hash_extra_keys( Returns: A tuple of extra keys and the next multi-modal index. """ - mm_extra_keys: List[Any] + mm_extra_keys: list[Any] mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( request, start_token_idx, end_token_idx, start_mm_idx) - lora_extra_keys: List[int] = _gen_lora_extra_hash_keys(request) + lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) - extra_keys: List[Any] = lora_extra_keys + mm_extra_keys + extra_keys: list[Any] = lora_extra_keys + mm_extra_keys if not extra_keys: return None, new_start_mm_idx @@ -377,7 +377,7 @@ def generate_block_hash_extra_keys( def hash_block_tokens( parent_block_hash: Optional[int], curr_block_token_ids: Sequence[int], - extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType: + extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -410,7 +410,7 @@ def hash_block_tokens( def hash_request_tokens(block_size: int, - request: Request) -> List[BlockHashType]: + request: Request) -> list[BlockHashType]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. @@ -554,8 +554,8 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, def get_kv_cache_configs(vllm_config: VllmConfig, - kv_cache_specs: List[KVCacheSpec], - available_memory: int) -> List[KVCacheConfig]: + kv_cache_specs: list[KVCacheSpec], + available_memory: int) -> list[KVCacheConfig]: """ Generates the KV cache configuration for a model TODO: support hybrid models with more than one type of KV cache. diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 87c9c0cd12b7..db14c9455a1f 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -2,7 +2,8 @@ import time from collections import deque -from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union +from collections.abc import Iterable +from typing import Optional, Union from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, SpeculativeConfig) @@ -57,24 +58,24 @@ def __init__( self.block_size = self.cache_config.block_size # req_id -> Request - self.requests: Dict[str, Request] = {} + self.requests: dict[str, Request] = {} # Priority queues for requests. - self.waiting: Deque[Request] = deque() - self.running: List[Request] = [] + self.waiting: deque[Request] = deque() + self.running: list[Request] = [] # The requests that have been scheduled and are being executed # by the executor. - self.scheduled_req_ids: Set[str] = set() + self.scheduled_req_ids: set[str] = set() # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished # requests so that they can free the cached states for those requests. # This is flushed at the end of each scheduling step. - self.finished_req_ids: Set[str] = set() + self.finished_req_ids: set[str] = set() # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. # Request id -> CachedRequestData - self._cached_reqs_data: Dict[str, CachedRequestData] = {} + self._cached_reqs_data: dict[str, CachedRequestData] = {} # Encoder-related. # Calculate encoder cache size if applicable @@ -108,19 +109,19 @@ def schedule(self) -> "SchedulerOutput": # chunked prefills, prefix caching, speculative decoding, # and the "jump decoding" optimization in the future. - scheduled_new_reqs: List[Request] = [] - scheduled_resumed_reqs: List[Request] = [] - scheduled_running_reqs: List[Request] = [] - preempted_reqs: List[Request] = [] + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] - req_to_new_block_ids: Dict[str, List[int]] = {} - num_scheduled_tokens: Dict[str, int] = {} + req_to_new_block_ids: dict[str, list[int]] = {} + num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. - scheduled_encoder_inputs: Dict[str, List[int]] = {} + scheduled_encoder_inputs: dict[str, list[int]] = {} encoder_budget = self.max_num_encoder_input_tokens # Spec decode-related. - scheduled_spec_decode_tokens: Dict[str, List[int]] = {} + scheduled_spec_decode_tokens: dict[str, list[int]] = {} # For logging. scheduled_timestamp = time.monotonic() @@ -211,7 +212,7 @@ def schedule(self) -> "SchedulerOutput": encoder_budget = new_encoder_budget # Record the LoRAs in scheduled_running_reqs - requested_loras: Set[int] = set() + requested_loras: set[int] = set() if self.lora_config: requested_loras = set( req.lora_request.lora_int_id for req in scheduled_running_reqs @@ -378,7 +379,7 @@ def _make_cached_request_data( request: Request, num_scheduled_tokens: int, num_scheduled_spec_tokens: int, - new_block_ids: List[int], + new_block_ids: list[int], resumed_from_preemption: bool, ) -> "CachedRequestData": # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating @@ -407,7 +408,7 @@ def _try_schedule_encoder_inputs( num_computed_tokens: int, num_new_tokens: int, encoder_budget: int, - ) -> Tuple[List[int], int, int]: + ) -> tuple[list[int], int, int]: """ Determine which encoder inputs need to be scheduled in the current step, and update `num_new_tokens` and encoder token budget accordingly. @@ -427,7 +428,7 @@ def _try_schedule_encoder_inputs( if not request.has_encoder_inputs(): return [], num_new_tokens, encoder_budget - encoder_inputs_to_schedule: List[int] = [] + encoder_inputs_to_schedule: list[int] = [] mm_positions = request.mm_positions assert mm_positions is not None assert len(mm_positions) > 0 @@ -482,8 +483,8 @@ def update_from_output( prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens - new_running: List[Request] = [] - outputs: List[EngineCoreOutput] = [] + new_running: list[Request] = [] + outputs: list[EngineCoreOutput] = [] # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # loop can be a performance bottleneck. We should do our best to avoid @@ -543,7 +544,7 @@ def update_from_output( stopped = False new_logprobs = None - new_token_ids: List[int] = [] + new_token_ids: list[int] = [] if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index 47413527c32f..b6caa8b4ebf7 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from vllm.lora.request import LoRARequest @@ -15,13 +15,13 @@ class NewRequestData: req_id: str - prompt_token_ids: List[int] + prompt_token_ids: list[int] prompt: Optional[str] - mm_inputs: List["MultiModalKwargs"] - mm_hashes: List[str] - mm_positions: List["PlaceholderRange"] + mm_inputs: list["MultiModalKwargs"] + mm_hashes: list[str] + mm_positions: list["PlaceholderRange"] sampling_params: "SamplingParams" - block_ids: List[int] + block_ids: list[int] num_computed_tokens: int lora_request: Optional["LoRARequest"] @@ -29,7 +29,7 @@ class NewRequestData: def from_request( cls, request: "Request", - block_ids: List[int], + block_ids: list[int], ) -> "NewRequestData": return cls( req_id=request.request_id, @@ -53,8 +53,8 @@ class CachedRequestData: # the request's block IDs. If True, new_block_ids will be used as the # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool - new_token_ids: List[int] - new_block_ids: List[int] + new_token_ids: list[int] + new_block_ids: list[int] num_computed_tokens: int @classmethod @@ -62,8 +62,8 @@ def from_request( cls, request: "Request", resumed_from_preemption: bool, - new_token_ids: List[int], - new_block_ids: List[int], + new_token_ids: list[int], + new_block_ids: list[int], ) -> "CachedRequestData": return cls( req_id=request.request_id, @@ -77,29 +77,29 @@ def from_request( @dataclass class SchedulerOutput: - # List of the requests that are scheduled for the first time. + # list of the requests that are scheduled for the first time. # We cache the request's data in each worker process, so that we don't # need to re-send it every scheduling step. - scheduled_new_reqs: List[NewRequestData] - # List of the requests that have been scheduled before. + scheduled_new_reqs: list[NewRequestData] + # list of the requests that have been scheduled before. # Since the request's data is already cached in the worker processes, # we only send the diff to minimize the communication cost. - scheduled_cached_reqs: List[CachedRequestData] + scheduled_cached_reqs: list[CachedRequestData] # req_id -> num_scheduled_tokens # Number of tokens scheduled for each request. - num_scheduled_tokens: Dict[str, int] + num_scheduled_tokens: dict[str, int] # Total number of tokens scheduled for all requests. # Equal to sum(num_scheduled_tokens.values()) total_num_scheduled_tokens: int # req_id -> spec_token_ids # If a request does not have any spec decode tokens, it will not be # included in the dictionary. - scheduled_spec_decode_tokens: Dict[str, List[int]] + scheduled_spec_decode_tokens: dict[str, list[int]] # req_id -> encoder input indices that need processing. # E.g., if a request has [0, 1], it could mean the vision encoder needs # to process that the request's 0-th and 1-th images in the current step. - scheduled_encoder_inputs: Dict[str, List[int]] + scheduled_encoder_inputs: dict[str, list[int]] # Number of common prefix blocks for all requests. # This can be used for cascade attention. num_common_prefix_blocks: int @@ -107,7 +107,7 @@ class SchedulerOutput: # Request IDs that are finished in between the previous and the current # steps. This is used to notify the workers about the finished requests # so that they can free the cached states for those requests. - finished_req_ids: Set[str] - # List of (req_id, encoder_input_index) tuples. + finished_req_ids: set[str] + # list of (req_id, encoder_input_index) tuples. # Used to free the encoder cache. - free_encoder_input_ids: List[Tuple[str, int]] + free_encoder_input_ids: list[tuple[str, int]] diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 32fb3c5bd62e..cd29c2d7d57c 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -2,7 +2,7 @@ import enum import time -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import msgspec @@ -51,10 +51,10 @@ class EngineCoreRequest( # NOTE(ywang96): original text prompt is needed when a request is added to # Detokenizer, but set to None when it is added to EngineCoreClient. prompt: Optional[str] - prompt_token_ids: List[int] - mm_inputs: Optional[List[Optional[MultiModalKwargs]]] - mm_hashes: Optional[List[str]] - mm_placeholders: Optional[List[PlaceholderRange]] + prompt_token_ids: list[int] + mm_inputs: Optional[list[Optional[MultiModalKwargs]]] + mm_hashes: Optional[list[str]] + mm_placeholders: Optional[list[PlaceholderRange]] sampling_params: SamplingParams eos_token_id: Optional[int] arrival_time: float @@ -93,14 +93,14 @@ class EngineCoreOutput( gc=False): # type: ignore[call-arg] request_id: str - new_token_ids: List[int] + new_token_ids: list[int] new_logprobs: Optional[LogprobsLists] = None new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None - events: Optional[List[EngineCoreEvent]] = None + events: Optional[list[EngineCoreEvent]] = None @property def finished(self) -> bool: @@ -129,7 +129,7 @@ class EngineCoreOutputs( # e.g. columnwise layout # [num_reqs] - outputs: List[EngineCoreOutput] = [] + outputs: list[EngineCoreOutput] = [] scheduler_stats: Optional[SchedulerStats] = None timestamp: float = 0.0 diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 0c04e14cec2f..ab3cdc4ee295 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -2,7 +2,8 @@ import asyncio import os -from typing import AsyncGenerator, List, Mapping, Optional, Set, Type, Union +from collections.abc import AsyncGenerator, Mapping +from typing import Optional, Union import numpy as np @@ -39,7 +40,7 @@ class AsyncLLM(EngineClient): def __init__( self, vllm_config: VllmConfig, - executor_class: Type[Executor], + executor_class: type[Executor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, input_registry: InputRegistry = INPUT_REGISTRY, @@ -54,7 +55,7 @@ def __init__( self.log_requests = log_requests self.log_stats = log_stats - self.stat_loggers: List[StatLoggerBase] = [] + self.stat_loggers: list[StatLoggerBase] = [] if self.log_stats: self.stat_loggers.extend([ LoggingStatLogger(), @@ -400,7 +401,7 @@ async def remove_lora(self, lora_id: int) -> bool: """Remove an already loaded LoRA adapter.""" return await self.engine_core.remove_lora_async(lora_id) - async def list_loras(self) -> Set[int]: + async def list_loras(self) -> set[int]: """List all registered adapters.""" return await self.engine_core.list_loras_async() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 041896f1c7cc..b9bf8fac40f6 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -7,7 +7,7 @@ from concurrent.futures import Future from inspect import isclass, signature from multiprocessing.connection import Connection -from typing import Any, List, Optional, Set, Tuple, Type +from typing import Any, Optional import msgspec import psutil @@ -42,7 +42,7 @@ class EngineCore: def __init__( self, vllm_config: VllmConfig, - executor_class: Type[Executor], + executor_class: type[Executor], log_stats: bool, ): assert vllm_config.model_config.runner_type != "pooling" @@ -80,7 +80,7 @@ def __init__( # schedule and execute batches, and is required by pipeline parallelism # to eliminate pipeline bubbles. self.batch_queue_size = self.model_executor.max_concurrent_batches - self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput], + self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput], SchedulerOutput]]] = None if self.batch_queue_size > 1: logger.info("Batch queue is enabled with size %d", @@ -88,7 +88,7 @@ def __init__( self.batch_queue = queue.Queue(self.batch_queue_size) def _initialize_kv_caches(self, - vllm_config: VllmConfig) -> Tuple[int, int]: + vllm_config: VllmConfig) -> tuple[int, int]: start = time.time() # Get all kv cache needed by the model @@ -134,7 +134,7 @@ def add_request(self, request: EngineCoreRequest): self.scheduler.add_request(req) - def abort_requests(self, request_ids: List[str]): + def abort_requests(self, request_ids: list[str]): """Abort requests from the scheduler.""" # TODO: The scheduler doesn't really need to know the @@ -228,7 +228,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.model_executor.remove_lora(lora_id) - def list_loras(self) -> Set[int]: + def list_loras(self) -> set[int]: return self.model_executor.list_loras() def pin_lora(self, lora_id: int) -> bool: @@ -244,7 +244,7 @@ def __init__( output_path: str, ready_pipe: Connection, vllm_config: VllmConfig, - executor_class: Type[Executor], + executor_class: type[Executor], log_stats: bool, ): super().__init__(vllm_config, executor_class, log_stats) @@ -254,7 +254,7 @@ def __init__( # and to overlap some serialization/deserialization with the # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue: queue.Queue[Tuple[EngineCoreRequestType, + self.input_queue: queue.Queue[tuple[EngineCoreRequestType, Any]] = queue.Queue() self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() threading.Thread(target=self.process_input_socket, diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9f36e11d12d7..cdce14afe0b3 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -10,7 +10,7 @@ from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import Any, Dict, List, Optional, Set, Type, Union +from typing import Any, Optional, Union import zmq import zmq.asyncio @@ -48,7 +48,7 @@ def make_client( multiprocess_mode: bool, asyncio_mode: bool, vllm_config: VllmConfig, - executor_class: Type[Executor], + executor_class: type[Executor], log_stats: bool, ) -> "EngineCoreClient": @@ -94,7 +94,7 @@ def execute_dummy_batch(self) -> None: async def execute_dummy_batch_async(self) -> None: raise NotImplementedError - def abort_requests(self, request_ids: List[str]) -> None: + def abort_requests(self, request_ids: list[str]) -> None: raise NotImplementedError def add_lora(self, lora_request: LoRARequest) -> bool: @@ -103,7 +103,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: raise NotImplementedError - def list_loras(self) -> Set[int]: + def list_loras(self) -> set[int]: raise NotImplementedError def pin_lora(self, lora_id: int) -> bool: @@ -127,7 +127,7 @@ async def sleep_async(self, level: int = 1) -> None: async def wake_up_async(self) -> None: raise NotImplementedError - async def abort_requests_async(self, request_ids: List[str]) -> None: + async def abort_requests_async(self, request_ids: list[str]) -> None: raise NotImplementedError async def add_lora_async(self, lora_request: LoRARequest) -> bool: @@ -136,7 +136,7 @@ async def add_lora_async(self, lora_request: LoRARequest) -> bool: async def remove_lora_async(self, lora_id: int) -> bool: raise NotImplementedError - async def list_loras_async(self) -> Set[int]: + async def list_loras_async(self) -> set[int]: raise NotImplementedError async def pin_lora_async(self, lora_id: int) -> bool: @@ -162,7 +162,7 @@ def get_output(self) -> EngineCoreOutputs: def add_request(self, request: EngineCoreRequest) -> None: self.engine_core.add_request(request) - def abort_requests(self, request_ids: List[str]) -> None: + def abort_requests(self, request_ids: list[str]) -> None: if len(request_ids) > 0: self.engine_core.abort_requests(request_ids) @@ -190,7 +190,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.engine_core.remove_lora(lora_id) - def list_loras(self) -> Set[int]: + def list_loras(self) -> set[int]: return self.engine_core.list_loras() def pin_lora(self, lora_id: int) -> bool: @@ -239,7 +239,7 @@ def __init__( self, asyncio_mode: bool, vllm_config: VllmConfig, - executor_class: Type[Executor], + executor_class: type[Executor], log_stats: bool, ): # The child processes will send SIGUSR1 when unrecoverable @@ -293,14 +293,14 @@ def sigusr1_handler(signum, frame): self.output_socket = resources.output_socket self.input_socket = resources.input_socket - self.utility_results: Dict[int, AnyFuture] = {} + self.utility_results: dict[int, AnyFuture] = {} def shutdown(self): self._finalizer() def _process_utility_output(output: UtilityOutput, - utility_results: Dict[int, AnyFuture]): + utility_results: dict[int, AnyFuture]): """Set the result from a utility method in the waiting future""" future = utility_results.pop(output.call_id) if output.failure_message is not None: @@ -312,7 +312,7 @@ def _process_utility_output(output: UtilityOutput, class SyncMPClient(MPClient): """Synchronous client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor], + def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool): super().__init__( asyncio_mode=False, @@ -373,7 +373,7 @@ def add_request(self, request: EngineCoreRequest) -> None: request.prompt = None self._send_input(EngineCoreRequestType.ADD, request) - def abort_requests(self, request_ids: List[str]) -> None: + def abort_requests(self, request_ids: list[str]) -> None: if len(request_ids) > 0: self._send_input(EngineCoreRequestType.ABORT, request_ids) @@ -389,7 +389,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self._call_utility("remove_lora", lora_id) - def list_loras(self) -> Set[int]: + def list_loras(self) -> set[int]: return self._call_utility("list_loras") def pin_lora(self, lora_id: int) -> bool: @@ -408,7 +408,7 @@ def execute_dummy_batch(self) -> None: class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor], + def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool): super().__init__( asyncio_mode=True, @@ -471,7 +471,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None: request.prompt = None await self._send_input(EngineCoreRequestType.ADD, request) - async def abort_requests_async(self, request_ids: List[str]) -> None: + async def abort_requests_async(self, request_ids: list[str]) -> None: if len(request_ids) > 0: await self._send_input(EngineCoreRequestType.ABORT, request_ids) @@ -496,7 +496,7 @@ async def add_lora_async(self, lora_request: LoRARequest) -> bool: async def remove_lora_async(self, lora_id: int) -> bool: return await self._call_utility_async("remove_lora", lora_id) - async def list_loras_async(self) -> Set[int]: + async def list_loras_async(self) -> set[int]: return await self._call_utility_async("list_loras") async def pin_lora_async(self, lora_id: int) -> bool: diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 629da06f4925..4a1636f49495 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import List, Optional +from typing import Optional from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger @@ -17,12 +17,12 @@ class IncrementalDetokenizer: # Generation data output_text: str - tokens: List[str] - token_ids: List[int] + tokens: list[str] + token_ids: list[int] prompt_len: int # Stop strings - stop: List[str] + stop: list[str] include_stop_str_in_output: bool # Metadata for incremental detokenization @@ -41,7 +41,7 @@ class IncrementalDetokenizer: _last_output_text_offset: int = 0 @property - def output_token_ids(self) -> List[int]: + def output_token_ids(self) -> list[int]: return self.token_ids[self.prompt_len:] @classmethod @@ -84,7 +84,7 @@ def from_new_request( stop_buffer_length=stop_buffer_length, ) - def update(self, new_token_ids: List[int]) -> Optional[str]: + def update(self, new_token_ids: list[int]) -> Optional[str]: """ Update RequestState for the request_id by: 1) Detokenize the new token ids incrementally. diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index ccf52250c1d6..2e76694a7f51 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Mapping, Optional, Set, Type, Union +from collections.abc import Mapping +from typing import Optional, Union from typing_extensions import TypeVar @@ -36,10 +37,10 @@ class LLMEngine: def __init__( self, vllm_config: VllmConfig, - executor_class: Type[Executor], + executor_class: type[Executor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, @@ -97,7 +98,7 @@ def from_engine_args( cls, engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[dict[str, StatLoggerBase]] = None, enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" @@ -139,7 +140,7 @@ def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool: def validate_outputs(cls, outputs, output_type): return outputs - def abort_request(self, request_ids: List[str]) -> None: + def abort_request(self, request_ids: list[str]) -> None: """Remove request_ids from EngineCore and Detokenizer.""" self.engine_core.abort_requests(request_ids) @@ -199,7 +200,7 @@ def _add_request( # 3) Add the request to EngineCore. self.engine_core.add_request(request) - def step(self) -> List[RequestOutput]: + def step(self) -> list[RequestOutput]: if self.should_execute_dummy_batch: self.should_execute_dummy_batch = False @@ -241,7 +242,7 @@ def wake_up(self): def get_tokenizer_group( self, - group_type: Type[_G] = BaseTokenizerGroup, + group_type: type[_G] = BaseTokenizerGroup, ) -> _G: tokenizer_group = self.tokenizer @@ -263,7 +264,7 @@ def remove_lora(self, lora_id: int) -> bool: """Remove an already loaded LoRA adapter.""" return self.engine_core.remove_lora(lora_id) - def list_loras(self) -> Set[int]: + def list_loras(self) -> set[int]: """List all registered adapters.""" return self.engine_core.list_loras() diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 4622cafa4a02..7f572163ead4 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -2,7 +2,7 @@ import itertools from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Optional from vllm.logger import init_logger from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs @@ -151,12 +151,12 @@ def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: @staticmethod def _make_logprob_dict( - logprobs: List[float], - logprob_token_ids: List[int], - decoded_tokens: List[str], + logprobs: list[float], + logprob_token_ids: list[int], + decoded_tokens: list[str], rank: int, num_logprobs: int, - ) -> Dict[int, Logprob]: + ) -> dict[int, Logprob]: """Make a Logprob dictionary for a position. Args: @@ -168,7 +168,7 @@ def _make_logprob_dict( by the user (in addition to sampled logprob) Returns: - Dict[token id, Logprob] + dict[token id, Logprob] """ # We do not need a special case for the sampled token diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index a1d802bf818a..0f66f68109b1 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional from vllm.config import ModelConfig from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE @@ -68,10 +68,10 @@ def cache_hit_ratio(self, steps): def process_inputs( self, mm_data: MultiModalDataDict, - mm_hashes: Optional[List[str]], - mm_processor_kwargs: Optional[Dict[str, Any]], - precomputed_mm_inputs: Optional[List[MultiModalKwargs]], - ) -> List[MultiModalKwargs]: + mm_hashes: Optional[list[str]], + mm_processor_kwargs: Optional[dict[str, Any]], + precomputed_mm_inputs: Optional[list[MultiModalKwargs]], + ) -> list[MultiModalKwargs]: if precomputed_mm_inputs is None: image_inputs = mm_data["image"] if not isinstance(image_inputs, list): @@ -88,7 +88,7 @@ def process_inputs( # Process each image input separately, so that later we can schedule # them in a fine-grained manner. # Apply caching (if enabled) and reuse precomputed inputs (if provided) - ret_inputs: List[MultiModalKwargs] = [] + ret_inputs: list[MultiModalKwargs] = [] for input_id in range(num_inputs): if self.mm_debug_cache_hit_ratio_steps is not None: self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps) @@ -133,9 +133,9 @@ def __init__(self, model_config): def get_and_update( self, - mm_inputs: List[Optional[MultiModalKwargs]], - mm_hashes: List[str], - ) -> List[MultiModalKwargs]: + mm_inputs: list[Optional[MultiModalKwargs]], + mm_hashes: list[str], + ) -> list[MultiModalKwargs]: assert len(mm_inputs) == len(mm_hashes) if not self.use_cache: diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 9ae8303df54d..22bbb8a0f5b4 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -2,7 +2,7 @@ import asyncio from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Optional, Union from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind @@ -18,8 +18,8 @@ @dataclass class OutputProcessorOutput: - request_outputs: List[RequestOutput] - reqs_to_abort: List[str] + request_outputs: list[RequestOutput] + reqs_to_abort: list[str] class RequestState: @@ -30,7 +30,7 @@ def __init__( lora_name: Optional[str], output_kind: RequestOutputKind, prompt: Optional[str], - prompt_token_ids: List[int], + prompt_token_ids: list[int], logprobs_processor: LogprobsProcessor, detokenizer: IncrementalDetokenizer, arrival_time: float, @@ -90,7 +90,7 @@ def __init__( ): self.log_stats = log_stats self.tokenizer = tokenizer - self.request_states: Dict[str, RequestState] = {} + self.request_states: dict[str, RequestState] = {} self.lora_states = LoRARequestStates() def is_request_active(self, request_id: str) -> bool: @@ -104,7 +104,7 @@ def has_unfinished_requests(self) -> bool: def abort_requests( self, - request_ids: List[str], + request_ids: list[str], ) -> None: for request_id in request_ids: req_state = self.request_states.pop(request_id, None) @@ -130,7 +130,7 @@ def add_request( def process_outputs( self, - engine_core_outputs: List[EngineCoreOutput], + engine_core_outputs: list[EngineCoreOutput], engine_core_timestamp: Optional[float] = None, iteration_stats: Optional[IterationStats] = None, ) -> OutputProcessorOutput: @@ -158,8 +158,8 @@ def process_outputs( ********************************************************** """ - request_outputs: List[RequestOutput] = [] - reqs_to_abort: List[str] = [] + request_outputs: list[RequestOutput] = [] + reqs_to_abort: list[str] = [] for engine_core_output in engine_core_outputs: req_id = engine_core_output.request_id req_state = self.request_states.get(req_id) @@ -265,7 +265,7 @@ def _update_stats_from_finished(self, req_state: RequestState, @staticmethod def _make_request_output( request_state: RequestState, - new_token_ids: List[int], + new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], ) -> Optional[RequestOutput]: diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 5d4ea111abfc..291360771b54 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import AsyncGenerator, Mapping from copy import copy -from typing import (AsyncGenerator, Dict, List, Mapping, Optional, Protocol, - Tuple, Union) +from typing import Optional, Protocol, Union from vllm.inputs import PromptType from vllm.lora.request import LoRARequest @@ -137,7 +137,7 @@ def _get_final_request_output(self) -> RequestOutput: key=lambda x: x.index) return self.request_output - def get_child_info(self, index: int) -> Tuple[str, SamplingParams]: + def get_child_info(self, index: int) -> tuple[str, SamplingParams]: """Get child request ID and sampling params. Args: @@ -237,9 +237,9 @@ class SyncParallelSamplingManager: def __init__(self): # Parent req ID -> parent request manager - self.parent_reqs: Dict[str, ParallelSamplingRequest] = {} + self.parent_reqs: dict[str, ParallelSamplingRequest] = {} # Child req ID -> (child req index, parent req ID) - self.child_reqs: Dict[str, Tuple[int, str]] = {} + self.child_reqs: dict[str, tuple[int, str]] = {} def _register_parent_request(self, req: ParallelSamplingRequest) -> None: """Register parallel sampling parent request.""" @@ -299,8 +299,8 @@ def add_request_parallel_sampling( def step( self, - outputs: List[RequestOutput], - ) -> List[RequestOutput]: + outputs: list[RequestOutput], + ) -> list[RequestOutput]: """Build parallel sampling request outputs. Extract child request outputs, aggregate them @@ -355,7 +355,7 @@ async def generate_parallel_sampling_async( parent_req = ParallelSamplingRequest(request_id, sampling_params) # Aggregate generators for n child requests - gens: List[AsyncGenerator[RequestOutput, None]] = [] + gens: list[AsyncGenerator[RequestOutput, None]] = [] for idx in range(parent_req.n): child_req_id, child_params = parent_req.get_child_info(idx) child_gen = generate( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 2547cebaede7..3a3fc69e53e4 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import time -from typing import Mapping, Optional, Union +from collections.abc import Mapping +from typing import Optional, Union from vllm.config import CacheConfig, LoRAConfig, ModelConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 11002ad0022d..aa6ae83c26ea 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from concurrent.futures import Future -from typing import List, Type, Union +from typing import Union import torch import torch.distributed as dist @@ -22,8 +22,8 @@ class Executor(ExecutorBase): For methods shared by v0 and v1, define them in ExecutorBase""" @staticmethod - def get_class(vllm_config: VllmConfig) -> Type["Executor"]: - executor_class: Type[Executor] + def get_class(vllm_config: VllmConfig) -> type["Executor"]: + executor_class: type[Executor] parallel_config = vllm_config.parallel_config distributed_executor_backend = ( parallel_config.distributed_executor_backend) @@ -53,7 +53,7 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]: return executor_class def initialize_from_config(self, - kv_cache_configs: List[KVCacheConfig]) -> None: + kv_cache_configs: list[KVCacheConfig]) -> None: """ Initialize the KV caches and begin the model execution loop of the underlying workers. @@ -69,7 +69,7 @@ def determine_available_memory(self) -> int: # in bytes # operators can be applied to all workers. return min(output) - def get_kv_cache_specs(self) -> List[KVCacheSpec]: + def get_kv_cache_specs(self) -> list[KVCacheSpec]: output = self.collective_rpc("get_kv_cache_spec") return output diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 25b5c1c1c2fc..b2cbba518036 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -10,7 +10,7 @@ from enum import Enum, auto from functools import partial from multiprocessing.process import BaseProcess -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import cloudpickle import psutil @@ -77,7 +77,7 @@ def sigusr1_handler(signum, frame): scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers - self.workers: List[WorkerProcHandle] = [] + self.workers: list[WorkerProcHandle] = [] for rank in range(self.world_size): worker = WorkerProc.make_worker_process(self.vllm_config, rank, rank, @@ -94,8 +94,8 @@ def sigusr1_handler(signum, frame): def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict] = None) -> List[Any]: + args: tuple = (), + kwargs: Optional[dict] = None) -> list[Any]: start_time = time.monotonic() kwargs = kwargs or {} @@ -208,7 +208,7 @@ def __init__( self.rank = rank wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank) # TODO: move `init_worker` to executor level as a collective rpc call - all_kwargs: List[Dict] = [ + all_kwargs: list[dict] = [ {} for _ in range(vllm_config.parallel_config.world_size) ] all_kwargs[rank] = { diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index eddfb5949ebe..dfef1039fce2 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List import torch @@ -74,7 +73,7 @@ def bytes_for_tokens(self, num_tokens: int) -> int: return cdiv(num_tokens, self.block_size) * self.page_size_bytes -KVCacheSpec = Dict[str, KVCacheSpecBase] +KVCacheSpec = dict[str, KVCacheSpecBase] @dataclass @@ -95,7 +94,7 @@ class KVCacheConfig: """The number of KV cache blocks""" num_blocks: int """layer_name -> how to initialize KV cache for that layer""" - tensors: Dict[str, KVCacheTensor] + tensors: dict[str, KVCacheTensor] """ A list of kv-cache groups. Each group includes a set of layers with the same kv-cache spec, and the total page_size of layers inside a group @@ -108,6 +107,6 @@ class KVCacheConfig: 3. (not implemented yet) A model with 2 full attention layers and 4 sliding window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). """ - groups: List[List[str]] + groups: list[list[str]] """the KVCacheSpec of the model""" kv_cache_spec: KVCacheSpec diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 40dfc5661672..5a2a1c30a9d5 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -2,7 +2,7 @@ import time from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Optional import numpy as np import prometheus_client @@ -35,8 +35,8 @@ def _reset(self, now): self.last_log_time = now # Tracked stats over current local logging interval. - self.num_prompt_tokens: List[int] = [] - self.num_generation_tokens: List[int] = [] + self.num_prompt_tokens: list[int] = [] + self.num_generation_tokens: list[int] = [] # Prefix cache metrics. TODO: Make the interval configurable. self.prefix_caching_metrics = PrefixCachingMetrics() @@ -52,7 +52,7 @@ def _track_iteration_stats(self, iteration_stats: IterationStats): self.num_generation_tokens.append( iteration_stats.num_generation_tokens) - def _get_throughput(self, tracked_stats: List[int], now: float) -> float: + def _get_throughput(self, tracked_stats: list[int], now: float) -> float: # Compute summary metrics for tracked stats return float(np.sum(tracked_stats) / (now - self.last_log_time)) @@ -147,7 +147,7 @@ def __init__(self, vllm_config: VllmConfig): documentation="Number of generation tokens processed.", labelnames=labelnames).labels(*labelvalues) - self.counter_request_success: Dict[FinishReason, + self.counter_request_success: dict[FinishReason, prometheus_client.Counter] = {} counter_request_success_base = prometheus_client.Counter( name="vllm:request_success_total", @@ -338,14 +338,14 @@ def _unregister_vllm_metrics(): prometheus_client.REGISTRY.unregister(collector) -def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: +def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]: """ Builds a list of buckets with increasing powers of 10 multiplied by mantissa values until the value exceeds the specified maximum. """ exponent = 0 - buckets: List[int] = [] + buckets: list[int] = [] while True: for m in mantissa_lst: value = m * 10**exponent @@ -356,7 +356,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: exponent += 1 -def build_1_2_5_buckets(max_value: int) -> List[int]: +def build_1_2_5_buckets(max_value: int) -> list[int]: """ Example: >>> build_1_2_5_buckets(100) @@ -365,7 +365,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]: return build_buckets([1, 2, 5], max_value) -def build_cudagraph_buckets(vllm_config: VllmConfig) -> List[int]: +def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]: if not vllm_config.model_config.enforce_eager: buckets = vllm_config.compilation_config.\ cudagraph_capture_sizes.copy() diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 30f460e5a691..625edb607467 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -2,7 +2,7 @@ import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from vllm.outputs import RequestOutput @@ -39,8 +39,8 @@ class SchedulerStats: @dataclass class LoRAStats: - waiting_requests: Set[str] = field(default_factory=set) - running_requests: Set[str] = field(default_factory=set) + waiting_requests: set[str] = field(default_factory=set) + running_requests: set[str] = field(default_factory=set) @dataclass @@ -81,11 +81,11 @@ def __init__(self): self.num_generation_tokens = 0 self.num_prompt_tokens = 0 self.num_preempted_reqs = 0 - self.finished_requests: List[FinishedRequestStats] = [] - self.time_to_first_tokens_iter: List[float] = [] - self.time_per_output_tokens_iter: List[float] = [] - self.waiting_lora_adapters: Dict[str, int] = {} - self.running_lora_adapters: Dict[str, int] = {} + self.finished_requests: list[FinishedRequestStats] = [] + self.time_to_first_tokens_iter: list[float] = [] + self.time_per_output_tokens_iter: list[float] = [] + self.waiting_lora_adapters: dict[str, int] = {} + self.running_lora_adapters: dict[str, int] = {} def _time_since(self, start: float) -> float: """Calculate an interval relative to this iteration's timestamp.""" @@ -132,7 +132,7 @@ def update_from_output(self, output: "EngineCoreOutput", if num_new_generation_tokens > 0: req_stats.last_token_ts = engine_core_timestamp - def update_from_events(self, req_id: str, events: List["EngineCoreEvent"], + def update_from_events(self, req_id: str, events: list["EngineCoreEvent"], is_prefilling: bool, req_stats: RequestStateStats, lora_stats: Optional[LoRAStats]): # Avoid circular dependency @@ -185,7 +185,7 @@ class LoRARequestStates: """Per-LoRA request state stats.""" def __init__(self): - self.lora_name_to_stats: Dict[str, LoRAStats] = {} + self.lora_name_to_stats: dict[str, LoRAStats] = {} def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]: if req_state.lora_name is None: diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index f461d52cc984..dc3ad402e066 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, NamedTuple, Optional +from typing import NamedTuple, Optional import torch @@ -9,11 +9,11 @@ class LogprobsLists(NamedTuple): # [num_reqs, max_num_logprobs + 1] - logprob_token_ids: List[List[int]] + logprob_token_ids: list[list[int]] # [num_reqs, max_num_logprobs + 1] - logprobs: List[List[float]] + logprobs: list[list[float]] # [num_reqs] - sampled_token_ranks: List[int] + sampled_token_ranks: list[int] def slice(self, start: int, end: int): return LogprobsLists( @@ -52,23 +52,23 @@ class SamplerOutput: # ModelRunnerOutput is serialized and sent to the scheduler process. -# This is expensive for torch.Tensor so prefer to use List instead. +# This is expensive for torch.Tensor so prefer to use list instead. @dataclass class ModelRunnerOutput: # [num_reqs] - req_ids: List[str] + req_ids: list[str] # req_id -> index - req_id_to_index: Dict[str, int] + req_id_to_index: dict[str, int] # num_reqs x num_generated_tokens # num_generated_tokens is the number of tokens # generated in the current step. It can be different for # each request due to speculative/jump decoding. - sampled_token_ids: List[List[int]] + sampled_token_ids: list[list[int]] # num_reqs x num_spec_tokens - spec_token_ids: Optional[List[List[int]]] + spec_token_ids: Optional[list[list[int]]] # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] @@ -79,4 +79,4 @@ class ModelRunnerOutput: # [prompt_len, num_prompt_logprobs] # [prompt_len, num_prompt_logprobs] # [prompt_len] - prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 52d7faeeb066..99df54734836 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import enum -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Optional, Union from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams @@ -20,10 +20,10 @@ def __init__( self, request_id: str, prompt: Optional[str], - prompt_token_ids: List[int], - multi_modal_inputs: Optional[List["MultiModalKwargs"]], - multi_modal_hashes: Optional[List[str]], - multi_modal_placeholders: Optional[List["PlaceholderRange"]], + prompt_token_ids: list[int], + multi_modal_inputs: Optional[list["MultiModalKwargs"]], + multi_modal_hashes: Optional[list[str]], + multi_modal_placeholders: Optional[list["PlaceholderRange"]], sampling_params: SamplingParams, eos_token_id: Optional[int], arrival_time: float, @@ -36,7 +36,7 @@ def __init__( self.lora_request = lora_request self.status = RequestStatus.WAITING - self.events: List[EngineCoreEvent] = [] + self.events: list[EngineCoreEvent] = [] self.stop_reason: Union[int, str, None] = None assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens @@ -44,15 +44,15 @@ def __init__( self.prompt = prompt self.prompt_token_ids = prompt_token_ids self.num_prompt_tokens = len(self.prompt_token_ids) - self._output_token_ids: List[int] = [] - self._all_token_ids: List[int] = self.prompt_token_ids.copy() - self.spec_token_ids: List[int] = [] + self._output_token_ids: list[int] = [] + self._all_token_ids: list[int] = self.prompt_token_ids.copy() + self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 # Multi-modal related self.mm_positions = multi_modal_placeholders or [] self.mm_inputs = multi_modal_inputs or [] - self.mm_hashes: List[str] = multi_modal_hashes or [] + self.mm_hashes: list[str] = multi_modal_hashes or [] # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) @@ -89,7 +89,7 @@ def scheduled(self, timestamp: Optional[float] = None) -> None: EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED, timestamp)) - def take_events(self) -> Optional[List[EngineCoreEvent]]: + def take_events(self) -> Optional[list[EngineCoreEvent]]: if not self.events: return None events, self.events = self.events, [] @@ -97,7 +97,7 @@ def take_events(self) -> Optional[List[EngineCoreEvent]]: def append_output_token_ids( self, - token_ids: Union[int, List[int]], + token_ids: Union[int, list[int]], ) -> None: if isinstance(token_ids, int): token_ids = [token_ids] diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index b757a1dc60c7..55d9739b8007 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Optional, Set, Tuple +from typing import Optional import torch @@ -17,7 +17,7 @@ class SamplingMetadata: top_k: Optional[torch.Tensor] min_p: Optional[torch.Tensor] - generators: Dict[int, torch.Generator] + generators: dict[int, torch.Generator] # None means no logprobs, 0 means sampled token logprobs only max_num_logprobs: Optional[int] @@ -28,12 +28,12 @@ class SamplingMetadata: presence_penalties: torch.Tensor repetition_penalties: torch.Tensor - output_token_ids: List[List[int]] + output_token_ids: list[list[int]] # req_index -> (min_tokens, stop_token_ids) - min_tokens: Dict[int, Tuple[int, Set[int]]] + min_tokens: dict[int, tuple[int, set[int]]] - logit_bias: List[Optional[Dict[int, float]]] + logit_bias: list[Optional[dict[int, float]]] # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, # vocab size). diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index 8d9f6529fa0b..ed05e3f48401 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Set, Tuple - import torch from vllm.model_executor.layers.utils import apply_penalties @@ -9,13 +7,13 @@ def apply_min_token_penalties( - logits: torch.Tensor, output_token_ids: List[List[int]], - min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None: + logits: torch.Tensor, output_token_ids: list[list[int]], + min_tokens: dict[int, tuple[int, set[int]]]) -> None: """ Applies minimum token penalty by setting the logits of the stop tokens to -inf. """ - min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] + min_tokens_logits_to_penalize: list[tuple[int, int]] = [] for index, (min_token, stop_token_ids) in min_tokens.items(): if len(output_token_ids[index]) < min_token: for stop_token_id in stop_token_ids: @@ -30,7 +28,7 @@ def apply_all_penalties( presence_penalties: torch.Tensor, frequency_penalties: torch.Tensor, repetition_penalties: torch.Tensor, - output_token_ids: List[List[int]], + output_token_ids: list[list[int]], ) -> torch.Tensor: """ Applies presence, frequency and repetition penalties to the logits. @@ -43,7 +41,7 @@ def apply_all_penalties( repetition_penalties) -def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, +def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int, device: torch.device) -> torch.Tensor: """ Convert the different list data structures to tensors. diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 78c88ad8b830..1bb950be822c 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Optional +from typing import Optional import torch import torch.nn as nn @@ -54,7 +54,7 @@ def __init__(self): def forward_native( self, logits: torch.Tensor, - generators: Dict[int, torch.Generator], + generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], ) -> torch.Tensor: @@ -66,7 +66,7 @@ def forward_native( def forward_cuda( self, logits: torch.Tensor, - generators: Dict[int, torch.Generator], + generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], ) -> torch.Tensor: @@ -117,7 +117,7 @@ def apply_top_k_top_p( def random_sample( probs: torch.Tensor, - generators: Dict[int, torch.Generator], + generators: dict[int, torch.Generator], ) -> torch.Tensor: """Randomly sample from the probabilities. @@ -143,7 +143,7 @@ def flashinfer_sample( probs: torch.Tensor, k: Optional[torch.Tensor], p: Optional[torch.Tensor], - generators: Dict[int, torch.Generator], + generators: dict[int, torch.Generator], ) -> torch.Tensor: """Sample from the probabilities using FlashInfer. diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 2e3927345eb5..80a4b24186ab 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List import torch import torch.nn as nn @@ -54,7 +53,7 @@ def __init__(self): else: self.forward_method = self.forward_native - def forward(self, draft_token_ids: List[List[int]], + def forward(self, draft_token_ids: list[list[int]], target_probs: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: if not sampling_metadata.all_greedy: @@ -66,7 +65,7 @@ def forward(self, draft_token_ids: List[List[int]], def flashinfer_sample( self, - draft_token_ids: List[List[int]], + draft_token_ids: list[list[int]], target_probs: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: @@ -119,7 +118,7 @@ def flashinfer_sample( # TODO: The following method can be optimized for better performance. def forward_native( self, - draft_token_ids: List[List[int]], + draft_token_ids: list[list[int]], target_probs: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: diff --git a/vllm/v1/stats/common.py b/vllm/v1/stats/common.py index 09d382638bff..46818977dae5 100644 --- a/vllm/v1/stats/common.py +++ b/vllm/v1/stats/common.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from dataclasses import field as dataclass_field from enum import IntEnum -from typing import ClassVar, Dict, List, Optional, Set +from typing import ClassVar, Optional import msgspec from msgspec import field as msgspec_field @@ -78,7 +78,7 @@ class Type(IntEnum): ▼ FINISHED (All could go to FINISHED) """ - _VALID_TRANSITIONS: ClassVar[Dict[Type, Set[Type]]] = { + _VALID_TRANSITIONS: ClassVar[dict[Type, set[Type]]] = { Type.ARRIVED: { Type.INPUT_PROCESSED, Type.FINISHED, @@ -140,7 +140,7 @@ class Type(IntEnum): finish_reason: Optional[str] = None # Non-optional fields for each update type. - _REQUIRED_FIELDS: ClassVar[Dict[Type, List[str]]] = { + _REQUIRED_FIELDS: ClassVar[dict[Type, list[str]]] = { Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"], Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"], Type.DETOKENIZED: ["num_new_tokens"], @@ -218,13 +218,13 @@ class RequestStats: # 2. the request was preempted and resumed. It is equivalent to running # a prefill of the original prefill tokens + generated output tokens # before preemption. - prefill_start_ts_s_lst: List[float] = dataclass_field(default_factory=list) + prefill_start_ts_s_lst: list[float] = dataclass_field(default_factory=list) # A list of timestamps when a token is decoded by the engine core. - decoding_ts_s_lst: List[float] = dataclass_field(default_factory=list) + decoding_ts_s_lst: list[float] = dataclass_field(default_factory=list) # A sorted list of timestamps for each output token. - output_token_ts_s_lst: List[float] = dataclass_field(default_factory=list) + output_token_ts_s_lst: list[float] = dataclass_field(default_factory=list) # First token's timestamp. first_token_ts_s: Optional[float] = None @@ -241,7 +241,7 @@ class RequestStats: # metric to measure the impact of preemption other than observation of # large P99 TPOT. Ideally we could quantify the impact of preemption by # measuring the number of tokens re-computed due to preemption. - preempted_ts_s_lst: List[float] = dataclass_field(default_factory=list) + preempted_ts_s_lst: list[float] = dataclass_field(default_factory=list) # Timestamp when the request was finished at the engine core. finished_ts_s: Optional[float] = None @@ -308,7 +308,7 @@ def decode_latency_s(self) -> Optional[float]: return self.e2e_latency_s - self.first_token_latency_s @property - def output_token_latency_s_lst(self) -> List[float]: + def output_token_latency_s_lst(self) -> list[float]: if len(self.output_token_ts_s_lst) == 0: return [] latency_s_lst = [] @@ -442,7 +442,7 @@ class EngineCoreStatsSnapshot( default_factory=SchedulerStats) # Per request stats updates. - requests_stats_updates: List[RequestStatsUpdate] = msgspec_field( + requests_stats_updates: list[RequestStatsUpdate] = msgspec_field( default_factory=list) # Engine core's queue stats. diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 62271255b0c0..8e1fb18cca05 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -5,8 +5,8 @@ import weakref from collections import defaultdict from collections.abc import Sequence -from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List, - Optional, TypeVar, Union, overload) +from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, + Union, overload) import torch @@ -24,7 +24,7 @@ class ConstantList(Generic[T], Sequence): - def __init__(self, x: List[T]) -> None: + def __init__(self, x: list[T]) -> None: self._x = x def append(self, item): @@ -57,10 +57,10 @@ def __getitem__(self, item: int) -> T: ... @overload - def __getitem__(self, s: slice, /) -> List[T]: + def __getitem__(self, s: slice, /) -> list[T]: ... - def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]: + def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]: return self._x[item] @overload @@ -71,7 +71,7 @@ def __setitem__(self, item: int, value: T): def __setitem__(self, s: slice, value: T, /): ... - def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]): + def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]): raise Exception("Cannot set item in a constant list") def __delitem__(self, item): @@ -99,7 +99,7 @@ def __init__( output_path: str, process_name: str, target_fn: Callable, - process_kwargs: Dict[Any, Any], + process_kwargs: dict[Any, Any], ): context = get_mp_context() reader, writer = context.Pipe(duplex=False) @@ -146,9 +146,9 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): def bind_kv_cache( - kv_caches: Dict[str, torch.Tensor], - forward_context: Dict[str, "Attention"], - runner_kv_caches: List[torch.Tensor], + kv_caches: dict[str, torch.Tensor], + forward_context: dict[str, "Attention"], + runner_kv_caches: list[torch.Tensor], ) -> None: """ Bind the allocated KV cache to both ModelRunner and forward context so diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 830cca104ddb..7d4082b73992 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import numpy as np import torch @@ -40,7 +38,7 @@ def __init__( def append_row( self, - block_ids: List[int], + block_ids: list[int], row_idx: int, ) -> None: if not block_ids: @@ -50,7 +48,7 @@ def append_row( self.num_blocks_per_row[row_idx] += num_blocks self.block_table_np[row_idx, start:start + num_blocks] = block_ids - def add_row(self, block_ids: List[int], row_idx: int) -> None: + def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 self.append_row(block_ids, row_idx) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 788a35221fe4..b0b218d92b92 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -2,7 +2,7 @@ # Datastructures defining an input batch from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast +from typing import TYPE_CHECKING, Optional, cast import numpy as np import torch @@ -24,16 +24,16 @@ class CachedRequestState: req_id: str - prompt_token_ids: List[int] + prompt_token_ids: list[int] prompt: Optional[str] - mm_inputs: List[MultiModalKwargs] - mm_positions: List["PlaceholderRange"] + mm_inputs: list[MultiModalKwargs] + mm_positions: list["PlaceholderRange"] sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: List[int] + block_ids: list[int] num_computed_tokens: int - output_token_ids: List[int] + output_token_ids: list[int] mrope_positions: Optional[torch.Tensor] = None mrope_position_delta: Optional[int] = None @@ -63,8 +63,8 @@ def __init__( self.pin_memory = pin_memory self.vocab_size = vocab_size - self._req_ids: List[Optional[str]] = [] - self.req_id_to_index: Dict[str, int] = {} + self._req_ids: list[Optional[str]] = [] + self.req_id_to_index: dict[str, int] = {} # TODO(woosuk): This buffer could be too large if max_model_len is big. # Find a way to reduce the CPU memory usage. @@ -106,8 +106,8 @@ def __init__( device="cpu", pin_memory=pin_memory) self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() + self.greedy_reqs: set[str] = set() + self.random_reqs: set[str] = set() self.top_p = torch.empty((max_num_reqs, ), dtype=torch.float32, @@ -117,7 +117,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() + self.top_p_reqs: set[str] = set() self.top_k = torch.empty((max_num_reqs, ), dtype=torch.int32, @@ -127,7 +127,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() + self.top_k_reqs: set[str] = set() self.min_p = torch.empty((max_num_reqs, ), dtype=torch.float32, @@ -137,7 +137,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.min_p_cpu = self.min_p_cpu_tensor.numpy() - self.min_p_reqs: Set[str] = set() + self.min_p_reqs: set[str] = set() # Frequency penalty related data structures self.frequency_penalties = torch.empty((max_num_reqs, ), @@ -150,7 +150,7 @@ def __init__( pin_memory=pin_memory) self.frequency_penalties_cpu = \ self.frequency_penalties_cpu_tensor.numpy() - self.frequency_penalties_reqs: Set[str] = set() + self.frequency_penalties_reqs: set[str] = set() # Presence penalty related data structures self.presence_penalties = torch.empty((max_num_reqs, ), @@ -162,7 +162,7 @@ def __init__( pin_memory=pin_memory) self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( ) - self.presence_penalties_reqs: Set[str] = set() + self.presence_penalties_reqs: set[str] = set() # Repetition penalty related data structures self.repetition_penalties = torch.empty((max_num_reqs, ), @@ -175,43 +175,43 @@ def __init__( pin_memory=pin_memory) self.repetition_penalties_cpu = \ self.repetition_penalties_cpu_tensor.numpy() - self.repetition_penalties_reqs: Set[str] = set() + self.repetition_penalties_reqs: set[str] = set() # req_index -> (min_tokens, stop_token_ids) - self.min_tokens: Dict[int, Tuple[int, Set[int]]] = {} + self.min_tokens: dict[int, tuple[int, set[int]]] = {} # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), dtype=np.int32) - self.lora_id_to_request_ids: Dict[int, Set[str]] = {} - self.lora_id_to_lora_request: Dict[int, LoRARequest] = {} + self.lora_id_to_request_ids: dict[int, set[str]] = {} + self.lora_id_to_lora_request: dict[int, LoRARequest] = {} # req_index -> generator # NOTE(woosuk): The indices of the requests that do not have their own # generator should not be included in the dictionary. - self.generators: Dict[int, torch.Generator] = {} + self.generators: dict[int, torch.Generator] = {} - self.num_logprobs: Dict[str, int] = {} + self.num_logprobs: dict[str, int] = {} # NOTE(rob): num_prompt_logprobs only includes reqs # that are currently in the prefill phase. - self.num_prompt_logprobs: Dict[str, int] = {} + self.num_prompt_logprobs: dict[str, int] = {} - self.logit_bias: List[Optional[Dict[int, + self.logit_bias: list[Optional[dict[int, float]]] = [None] * max_num_reqs - self.has_allowed_token_ids: Set[str] = set() + self.has_allowed_token_ids: set[str] = set() self.allowed_token_ids_mask: Optional[torch.Tensor] = None self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None - self.req_output_token_ids: List[Optional[List[int]]] = [] + self.req_output_token_ids: list[Optional[list[int]]] = [] # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() @property - def req_ids(self) -> List[str]: + def req_ids(self) -> list[str]: # None elements should only be present transiently # while performing state updates to the batch. - return cast(List[str], self._req_ids) + return cast(list[str], self._req_ids) def add_request( self, @@ -417,7 +417,7 @@ def swap_states(self, i1: int, i2: int) -> None: self.logit_bias[i2], self.logit_bias[i1] self.block_table.swap_row(i1, i2) - def condense(self, empty_req_indices: List[int]) -> None: + def condense(self, empty_req_indices: list[int]) -> None: num_reqs = self.num_reqs if num_reqs == 0: # The batched states are empty. @@ -550,7 +550,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata: frequency_penalties=self.frequency_penalties[:num_reqs], presence_penalties=self.presence_penalties[:num_reqs], repetition_penalties=self.repetition_penalties[:num_reqs], - output_token_ids=cast(List[List[int]], self.req_output_token_ids), + output_token_ids=cast(list[list[int]], self.req_output_token_ids), min_tokens=self.min_tokens, no_penalties=self.no_penalties, logit_bias=self.logit_bias[:num_reqs], @@ -577,7 +577,7 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: def make_lora_inputs( self, num_scheduled_tokens: np.ndarray - ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]: + ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: """ Given the num_scheduled_tokens for each request in the batch, return datastructures used to activate the current LoRAs. @@ -593,7 +593,7 @@ def make_lora_inputs( prompt_lora_mapping = tuple(req_lora_mapping) token_lora_mapping = tuple( req_lora_mapping.repeat(num_scheduled_tokens)) - active_lora_requests: Set[LoRARequest] = set( + active_lora_requests: set[LoRARequest] = set( self.lora_id_to_lora_request.values()) return prompt_lora_mapping, token_lora_mapping, active_lora_requests diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6785d6684269..4a1fb0514c3f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,7 +3,7 @@ import gc import time import weakref -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import numpy as np import torch @@ -135,9 +135,9 @@ def __init__( # Lazy initialization # self.model: nn.Module # Set after load_model - self.kv_caches: List[torch.Tensor] = [] + self.kv_caches: list[torch.Tensor] = [] # req_id -> (input_id -> encoder_output) - self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} + self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} # Set up speculative decoding. self.use_spec_decode = False @@ -158,7 +158,7 @@ def __init__( ) # Request states. - self.requests: Dict[str, CachedRequestState] = {} + self.requests: dict[str, CachedRequestState] = {} # Persistent batch. self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, @@ -274,7 +274,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # then resubmitted with the same ID. In this case, we treat them as two # distinct requests - clearing the cached states for the first request # and handling the second as a new request. - removed_req_indices: List[int] = [] + removed_req_indices: list[int] = [] for req_id in scheduler_output.finished_req_ids: req_index = self.input_batch.remove_request(req_id) if req_index is not None: @@ -305,7 +305,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: assert req_index is not None removed_req_indices.append(req_index) - req_ids_to_add: List[str] = [] + req_ids_to_add: list[str] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id @@ -446,7 +446,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> Tuple[FlashAttentionMetadata, torch.Tensor]: + ) -> tuple[FlashAttentionMetadata, torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -774,8 +774,8 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"): return # Batch the multi-modal inputs. - mm_inputs: List[MultiModalKwargs] = [] - req_input_ids: List[Tuple[str, int]] = [] + mm_inputs: list[MultiModalKwargs] = [] + req_input_ids: list[tuple[str, int]] = [] for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for input_id in encoder_input_ids: @@ -819,8 +819,8 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"): def _gather_encoder_outputs( self, scheduler_output: "SchedulerOutput", - ) -> List[torch.Tensor]: - encoder_outputs: List[torch.Tensor] = [] + ) -> list[torch.Tensor]: + encoder_outputs: list[torch.Tensor] = [] for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] @@ -1022,10 +1022,10 @@ def execute_model( def generate_draft_token_ids( self, - sampled_token_ids: List[List[int]], - ) -> List[List[int]]: + sampled_token_ids: list[list[int]], + ) -> list[list[int]]: # TODO(woosuk): Optimize. - draft_token_ids: List[List[int]] = [] + draft_token_ids: list[list[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) if not num_sampled_ids: @@ -1069,12 +1069,12 @@ def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, scheduler_output: "SchedulerOutput", - ) -> Dict[str, Optional[LogprobsTensors]]: + ) -> dict[str, Optional[LogprobsTensors]]: num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: return {} - prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} # Since prompt logprobs are a rare feature, prioritize simple, # maintainable loop over optimal performance. @@ -1365,7 +1365,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: "Hybrid models with more than one KV cache type are not " "supported yet.") - kv_caches: Dict[str, torch.Tensor] = {} + kv_caches: dict[str, torch.Tensor] = {} for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): tensor_config = kv_cache_config.tensors[layer_name] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index f681925f557e..cc6268d6569b 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -2,7 +2,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional, Set +from typing import TYPE_CHECKING, Optional import torch import torch.distributed @@ -243,7 +243,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.model_runner.remove_lora(lora_id) - def list_loras(self) -> Set[int]: + def list_loras(self) -> set[int]: return self.model_runner.list_loras() def pin_lora(self, lora_id: int) -> bool: diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 731e758e6e74..f34aacacf3ed 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -4,7 +4,6 @@ """ from contextlib import contextmanager -from typing import Set, Tuple import numpy as np import torch.nn as nn @@ -57,9 +56,9 @@ def load_lora_model(self, model: nn.Module, model_config: ModelConfig, ) return self.lora_manager.create_lora_manager(model) - def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...], - token_lora_mapping: Tuple[int, ...], - lora_requests: Set[LoRARequest]) -> None: + def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...], + token_lora_mapping: tuple[int, ...], + lora_requests: set[LoRARequest]) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") @@ -74,10 +73,10 @@ def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...], def set_active_loras(self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray) -> None: - prompt_lora_mapping: Tuple[int, ...] # of size input_batch.num_reqs - token_lora_mapping: Tuple[int, + prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs + token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens) - lora_requests: Set[LoRARequest] + lora_requests: set[LoRARequest] prompt_lora_mapping, token_lora_mapping, lora_requests = \ input_batch.make_lora_inputs(num_scheduled_tokens) return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, @@ -105,7 +104,7 @@ def maybe_profile_with_lora(self, lora_config: LoRAConfig, num_scheduled_tokens) # Make dummy lora requests - lora_requests: Set[LoRARequest] = { + lora_requests: set[LoRARequest] = { LoRARequest(lora_name=f"warmup_{lora_id}", lora_int_id=lora_id, lora_path="/not/a/real/path") @@ -143,7 +142,7 @@ def pin_lora(self, lora_id: int) -> bool: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.pin_adapter(lora_id) - def list_loras(self) -> Set[int]: + def list_loras(self) -> set[int]: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.list_adapters() \ No newline at end of file diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index ffa5e21ede87..30b79a26a63f 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import time -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Optional, cast from unittest.mock import patch import numpy as np @@ -96,13 +96,13 @@ def __init__( ) # Request states. - self.requests: Dict[str, CachedRequestState] = {} + self.requests: dict[str, CachedRequestState] = {} # req_id -> (input_id -> encoder_output) - self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} + self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} # KV caches for forward pass - self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = [] # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. @@ -172,7 +172,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # then resubmitted with the same ID. In this case, we treat them as two # distinct requests - clearing the cached states for the first request # and handling the second as a new request. - removed_req_indices: List[int] = [] + removed_req_indices: list[int] = [] for req_id in scheduler_output.finished_req_ids: req_index = self.input_batch.remove_request(req_id) if req_index is not None: @@ -195,7 +195,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: assert req_index is not None removed_req_indices.append(req_index) - req_ids_to_add: List[str] = [] + req_ids_to_add: list[str] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id @@ -454,7 +454,7 @@ def execute_model( selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True) # Then, let's update the cache state. - request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] + request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None req_state = self.requests[req_id] @@ -474,9 +474,9 @@ def execute_model( assert all( req_id is not None for req_id in self.input_batch.req_ids[:num_reqs]), "req_ids contains None" - req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) + req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) - prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} for req_id in self.input_batch.req_ids[:num_reqs]: prompt_logprobs_dict[req_id] = None @@ -619,7 +619,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: "Hybrid models with more than one KV cache type are not " "supported yet.") - kv_caches: Dict[str, torch.Tensor] = {} + kv_caches: dict[str, torch.Tensor] = {} for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): tensor_config = kv_cache_config.tensors[layer_name] @@ -656,7 +656,7 @@ def forward( self, token_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + kv_caches: list[tuple[torch.Tensor, torch.Tensor]], ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. @@ -674,7 +674,7 @@ def forward( # [num_kv_heads, num_blocks, block_size, head_size]. To make it # work, we need to flatten the first three dimensions and modify # the slot_mapping accordingly. - # kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] + # kv_caches: list[tuple[torch.Tensor, torch.Tensor]] num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape slot_mapping = attn_metadata.slot_mapping slot_mapping = slot_mapping.flatten() diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 5dd021890d9d..0ceeeb05ac38 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """A TPU worker class.""" import os -from typing import Dict, List, Optional +from typing import Optional import torch import torch.distributed @@ -104,7 +104,7 @@ def init_device(self): self.model_runner = TPUModelRunner(self.vllm_config, self.device) def determine_available_memory(self) -> int: - kv_caches: Dict[str, torch.Tensor] = {} + kv_caches: dict[str, torch.Tensor] = {} kv_cache_spec = self.model_runner.get_kv_cache_spec() for layer_name, layer_spec in kv_cache_spec.items(): if isinstance(layer_spec, FullAttentionSpec): @@ -119,7 +119,7 @@ def determine_available_memory(self) -> int: else: raise NotImplementedError - runner_kv_caches: List[torch.Tensor] = [] + runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, From 06c6a48c7ddd14b692e9f5f602fb491b8cea1f0f Mon Sep 17 00:00:00 2001 From: Sheng Yao <30943636+realShengYao@users.noreply.github.com> Date: Mon, 3 Mar 2025 09:35:01 +0800 Subject: [PATCH 313/317] [Bugfix] Explicitly include "omp.h" for MacOS to avoid installation failure (#14051) --- csrc/cpu/cpu_types_arm.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/cpu/cpu_types_arm.hpp b/csrc/cpu/cpu_types_arm.hpp index 990e99f2fc06..65ffe524af73 100644 --- a/csrc/cpu/cpu_types_arm.hpp +++ b/csrc/cpu/cpu_types_arm.hpp @@ -2,6 +2,10 @@ #include #include +#if defined(__APPLE__) + #include "omp.h" +#endif + namespace vec_op { #ifdef ARM_BF16_SUPPORT From 0da18f92f8fa88e94df9b5b4e660a91094bfe2e7 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 3 Mar 2025 14:10:11 +0800 Subject: [PATCH 314/317] [Misc] duplicate code in deepseek_v2 (#14106) --- vllm/model_executor/models/deepseek_v2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index b5409c7fe1b7..7ff61f9a1826 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -105,7 +105,6 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor self.n_shared_experts = config.n_shared_experts - self.routed_scaling_factor = config.routed_scaling_factor if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " From 50c2cf824477ab67793bf8e9cd3c203db29e8fc4 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Mon, 3 Mar 2025 15:40:04 +0800 Subject: [PATCH 315/317] [Misc][Platform] Move use allgather to platform (#14010) Signed-off-by: Mengqing Cao --- vllm/model_executor/layers/logits_processor.py | 10 +++------- vllm/platforms/interface.py | 13 +++++++++++++ vllm/platforms/neuron.py | 4 ++++ vllm/platforms/tpu.py | 4 ++++ 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 2f39a0e87854..4a359725bad0 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -8,7 +8,6 @@ import torch.nn as nn import vllm.envs as envs -from vllm.config import get_current_vllm_config from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -51,11 +50,7 @@ def __init__(self, # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. - parallel_config = get_current_vllm_config().parallel_config - self.use_all_gather = current_platform.is_tpu() \ - or current_platform.is_neuron() \ - or envs.VLLM_USE_V1 \ - or parallel_config.distributed_executor_backend == "external_launcher" # noqa + self.use_all_gather = current_platform.use_all_gather() def forward( self, @@ -83,7 +78,8 @@ def forward( logits *= self.scale # Apply logits processors (if any). - if sampling_metadata is not None: + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: logits = _apply_logits_processors(logits, sampling_metadata) return logits diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 6e80a1ff269a..3477b1b3fa01 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -342,6 +342,19 @@ def get_device_communicator_cls(cls) -> str: """ return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa + @classmethod + def use_all_gather(cls) -> bool: + """ + Whether to use allgather in LogitsProcessor to gather the logits. + """ + import vllm.envs as envs + from vllm.config import get_current_vllm_config + + parallel_config = get_current_vllm_config().parallel_config + return (envs.VLLM_USE_V1 + or parallel_config.distributed_executor_backend + == "external_launcher") + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 5a03f5f7acbc..b2eadb7932f3 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -55,3 +55,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: def is_pin_memory_available(cls) -> bool: logger.warning("Pin memory is not supported on Neuron.") return False + + @classmethod + def use_all_gather(cls) -> bool: + return True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 4864173b2f0e..0c9d247d4a5d 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -136,3 +136,7 @@ def is_pin_memory_available(cls): @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa + + @classmethod + def use_all_gather(cls) -> bool: + return True From 9bad4fd7d9a5b5bb2c9d5fab0a7cc14c3495e88c Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 3 Mar 2025 00:43:14 -0800 Subject: [PATCH 316/317] [Build] Make sure local main branch is synced when VLLM_USE_PRECOMPILED=1 (#13921) Signed-off-by: Cody Yu --- setup.py | 28 ++++++++++++++++++- tests/standalone_tests/python_only_compile.sh | 2 +- vllm/envs.py | 8 +++++- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index cd17709b57ef..1a6f2ffd8524 100755 --- a/setup.py +++ b/setup.py @@ -2,6 +2,7 @@ import ctypes import importlib.util +import json import logging import os import re @@ -269,9 +270,32 @@ class repackage_wheel(build_ext): """Extracts libraries and other files from an existing wheel.""" def get_base_commit_in_main_branch(self) -> str: - import subprocess + # Force to use the nightly wheel. This is mainly used for CI testing. + if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: + return "nightly" try: + # Get the latest commit hash of the upstream main branch. + resp_json = subprocess.check_output([ + "curl", "-s", + "https://api.github.com/repos/vllm-project/vllm/commits/main" + ]).decode("utf-8") + upstream_main_commit = json.loads(resp_json)["sha"] + + # Check if the local main branch is up-to-date. This is to ensure + # the base commit we found is the most recent commit on the main + # branch. + local_main_commit = subprocess.check_output( + ["git", "rev-parse", "main"]).decode("utf-8").strip() + if local_main_commit != upstream_main_commit: + raise ValueError( + f"Local main branch ({local_main_commit}) is not " + "up-to-date with upstream main branch " + f"({upstream_main_commit}). Please pull the latest " + "changes from upstream main branch first.") + + # Then get the commit hash of the current branch that is the same as + # the upstream main commit. current_branch = subprocess.check_output( ["git", "branch", "--show-current"]).decode("utf-8").strip() @@ -279,6 +303,8 @@ def get_base_commit_in_main_branch(self) -> str: ["git", "merge-base", "main", current_branch]).decode("utf-8").strip() return base_commit + except ValueError as err: + raise ValueError(err) from None except Exception as err: logger.warning( "Failed to get the base commit in the main branch. " diff --git a/tests/standalone_tests/python_only_compile.sh b/tests/standalone_tests/python_only_compile.sh index f00895c0997f..ec1bcbcc58a0 100644 --- a/tests/standalone_tests/python_only_compile.sh +++ b/tests/standalone_tests/python_only_compile.sh @@ -18,7 +18,7 @@ apt autoremove -y echo 'import os; os.system("touch /tmp/changed.file")' >> vllm/__init__.py -VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . +VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL=1 VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . # Run the script python3 -c 'import vllm' diff --git a/vllm/envs.py b/vllm/envs.py index bf64cd70674d..f6c038967b69 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -60,12 +60,12 @@ MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None VLLM_USE_PRECOMPILED: bool = False + VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False VLLM_NO_DEPRECATION_WARNING: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False CMAKE_BUILD_TYPE: Optional[str] = None VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False - VLLM_TEST_FORCE_FP8_MARLIN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[list[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None @@ -148,6 +148,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")) or bool( os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), + # Whether to force using nightly wheel in python build. + # This is used for testing the nightly wheel in python build. + "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": + lambda: bool(int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) + ), + # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" From c6cc1f59e630da37c9b04862143a03c906fe4df3 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 3 Mar 2025 16:15:27 +0000 Subject: [PATCH 317/317] [V1] Refactor parallel sampling support (#13774) Signed-off-by: Mark McLoughlin --- vllm/v1/engine/async_llm.py | 61 ++--- vllm/v1/engine/llm_engine.py | 74 ++---- vllm/v1/engine/output_processor.py | 181 +++++++++------ vllm/v1/engine/parallel_sampling.py | 344 ++++------------------------ vllm/v1/metrics/stats.py | 5 +- 5 files changed, 201 insertions(+), 464 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ab3cdc4ee295..954f74c3fdae 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -25,7 +25,7 @@ from vllm.utils import cdiv, kill_process_tree from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async +from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, @@ -145,25 +145,30 @@ async def add_request( """Add new request to the AsyncLLM.""" # 1) Create a new output queue for the request. - if self.output_processor.is_request_active(request_id): - raise ValueError(f"Request id {request_id} already running.") queue: asyncio.Queue[RequestOutput] = asyncio.Queue() - # 2) Convert Input --> Request. - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) + # 2) Fan out child requests (for n>1) + parent_req = ParentRequest.from_params(request_id, params) + n = params.n if isinstance(params, SamplingParams) else 1 + for idx in range(n): + if parent_req is not None: + request_id, params = parent_req.get_child_info(idx) - # 3) Add the request to OutputProcessor (this process). - self.output_processor.add_request(request, queue) + # 3) Convert Input --> Request. + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) - # 4) Add the EngineCoreRequest to EngineCore (separate process). - await self.engine_core.add_request_async(request) + # 4) Add the request to OutputProcessor (this process). + self.output_processor.add_request(request, parent_req, idx, queue) - if self.log_requests: - logger.info("Added request %s.", request_id) + # 5) Add the EngineCoreRequest to EngineCore (separate process). + await self.engine_core.add_request_async(request) + + if self.log_requests: + logger.info("Added request %s.", request_id) return queue @@ -172,7 +177,7 @@ async def add_request( # requests we don't need to send multiple messages to core proc, # and so we don't need multiple streams which then get # re-multiplexed in the API server anyhow. - async def _generate( + async def generate( self, prompt: PromptType, sampling_params: SamplingParams, @@ -243,30 +248,6 @@ async def _generate( await self.abort(request_id) raise - def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> AsyncGenerator[RequestOutput, None]: - kwargs = dict(prompt=prompt, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) - if sampling_params.n is None or sampling_params.n == 1: - return self._generate(**kwargs) - else: - # Special handling for parallel sampling requests - return generate_parallel_sampling_async(generate=self._generate, - **kwargs) - async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 2e76694a7f51..99b97ac8e6c4 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -22,7 +22,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager +from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -50,9 +50,6 @@ def __init__( self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config - # Bookkeeping for parallel sampling requests - self.parallel_manager = SyncParallelSamplingManager() - # important: init dp group before init the engine_core self.parallel_config = vllm_config.parallel_config self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa @@ -120,8 +117,7 @@ def from_engine_args( multiprocess_mode=enable_multiprocessing) def get_num_unfinished_requests(self) -> int: - return self.parallel_manager.get_num_unfinished_requests( - self.output_processor.get_num_unfinished_requests()) + return self.output_processor.get_num_unfinished_requests() def has_unfinished_requests(self) -> bool: has_unfinished = self.output_processor.has_unfinished_requests() @@ -157,48 +153,25 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: - """Add request.""" - kwargs = dict(request_id=request_id, - prompt=prompt, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) - # Handle parallel sampling requests differently. - if params is None or isinstance(params, - PoolingParams) or params.n == 1: - self._add_request(**kwargs) - else: - # Special handling for parallel sampling requests - self.parallel_manager.add_request_parallel_sampling( - add_request=self._add_request, **kwargs) - - def _add_request( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: - """Add request, `n=1`""" - # 1) Process raw inputs into the request. - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) - - # 2) Make a new RequestState and queue. - self.output_processor.add_request(request) - - # 3) Add the request to EngineCore. - self.engine_core.add_request(request) + # 1) Fan out child requests (for n>1) + parent_req = ParentRequest.from_params(request_id, params) + n = params.n if isinstance(params, SamplingParams) else 1 + for idx in range(n): + if parent_req is not None: + request_id, params = parent_req.get_child_info(idx) + + # 2) Process raw inputs into the request. + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) + + # 3) Make a new RequestState and queue. + self.output_processor.add_request(request, parent_req, idx) + + # 3) Add the request to EngineCore. + self.engine_core.add_request(request) def step(self) -> list[RequestOutput]: @@ -217,10 +190,7 @@ def step(self) -> list[RequestOutput]: # 3) Abort any reqs that finished due to stop strings. self.engine_core.abort_requests(processed_outputs.reqs_to_abort) - request_outputs = processed_outputs.request_outputs - - # 4) Process unfinished parallel sampling requests - return self.parallel_manager.step(request_outputs) + return processed_outputs.request_outputs def get_model_config(self): return self.model_config diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 22bbb8a0f5b4..4e1d1e3bf51b 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -4,13 +4,14 @@ from dataclasses import dataclass from typing import Optional, Union -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor +from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, RequestStateStats) @@ -27,6 +28,8 @@ class RequestState: def __init__( self, request_id: str, + parent_req: Optional[ParentRequest], + request_index: int, lora_name: Optional[str], output_kind: RequestOutputKind, prompt: Optional[str], @@ -38,6 +41,8 @@ def __init__( log_stats: bool, ): self.request_id = request_id + self.parent_req = parent_req + self.request_index = request_index self.lora_name = lora_name self.output_kind = output_kind self.prompt = prompt @@ -56,11 +61,15 @@ def from_new_request( cls, tokenizer: AnyTokenizer, request: EngineCoreRequest, + parent_req: Optional[ParentRequest], + request_index: int, queue: Optional[asyncio.Queue[RequestOutput]], log_stats: bool, ) -> "RequestState": return cls( request_id=request.request_id, + parent_req=parent_req, + request_index=request_index, lora_name=(request.lora_request.name if request.lora_request is not None else None), output_kind=request.sampling_params.output_kind, @@ -79,6 +88,88 @@ def from_new_request( log_stats=log_stats, ) + def make_request_output( + self, + new_token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + ) -> Optional[RequestOutput]: + + finished = finish_reason is not None + output_kind = self.output_kind + final_only = output_kind == RequestOutputKind.FINAL_ONLY + + # In follow up, we will switch to invariant where EngineCore + # does not stream partial prefills. + if not finished and (self.is_prefilling or final_only): + # Only the final output is required in FINAL_ONLY mode. + return None + + def new_request_output(request_id: str) -> RequestOutput: + return self._new_request_output(request_id, finished) + + completion_output = self._new_completion_output( + new_token_ids, finish_reason, stop_reason) + + if self.parent_req is not None: + return self.parent_req.make_request_output(final_only, + completion_output, + new_request_output) + + request_output = new_request_output(self.request_id) + request_output.outputs.append(completion_output) + return request_output + + def _new_request_output( + self, + request_id: str, + finished: bool, + ) -> RequestOutput: + + if self.output_kind == RequestOutputKind.DELTA: + # Side effect: logprobs processor forgets prompt logprobs + prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs() + else: + prompt_logprobs = self.logprobs_processor.prompt_logprobs + + return RequestOutput( + request_id=request_id, + prompt=self.prompt, + prompt_token_ids=self.prompt_token_ids, + prompt_logprobs=prompt_logprobs, + outputs=[], + finished=finished, + ) + + def _new_completion_output( + self, + token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + ) -> CompletionOutput: + + finished = finish_reason is not None + delta = self.output_kind == RequestOutputKind.DELTA + + # Prepare text and token_ids, based on delta mode + text = self.detokenizer.get_next_output_text(finished, delta) + if not delta: + token_ids = self.detokenizer.output_token_ids + + # Prepare logprobs, based on delta mode + logprobs = self.logprobs_processor.logprobs + if delta and logprobs: + logprobs = logprobs[-len(token_ids):] + + return CompletionOutput( + index=self.request_index, + text=text, + token_ids=token_ids, + logprobs=logprobs, + cumulative_logprob=self.logprobs_processor.cumulative_logprob, + finish_reason=str(finish_reason) if finished else None, + stop_reason=stop_reason if finished else None) + class OutputProcessor: """Process EngineCoreOutputs into RequestOutputs.""" @@ -93,9 +184,6 @@ def __init__( self.request_states: dict[str, RequestState] = {} self.lora_states = LoRARequestStates() - def is_request_active(self, request_id: str) -> bool: - return request_id in self.request_states - def get_num_unfinished_requests(self): return len(self.request_states) @@ -114,6 +202,8 @@ def abort_requests( def add_request( self, request: EngineCoreRequest, + parent_req: Optional[ParentRequest] = None, + request_index: int = 0, queue: Optional[asyncio.Queue[RequestOutput]] = None, ) -> None: request_id = request.request_id @@ -123,6 +213,8 @@ def add_request( req_state = RequestState.from_new_request( tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), request=request, + parent_req=parent_req, + request_index=request_index, queue=queue, log_stats=self.log_stats) self.request_states[request_id] = req_state @@ -202,8 +294,8 @@ def process_outputs( req_state.logprobs_processor.update_from_output(engine_core_output) # 4) Create and handle RequestOutput objects. - if request_output := self._make_request_output( - req_state, new_token_ids, finish_reason, stop_reason): + if request_output := req_state.make_request_output( + new_token_ids, finish_reason, stop_reason): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put_nowait(request_output) @@ -211,18 +303,17 @@ def process_outputs( # LLMEngine: return list of RequestOutputs. request_outputs.append(request_output) - # Free completed requests. - if request_output.finished: - self.request_states.pop(req_id) - if not engine_core_output.finished: - # If req not finished in EngineCore, but Detokenizer - # detected stop string, abort needed in EngineCore. - reqs_to_abort.append(req_id) + # Free completed requests. + if finish_reason is not None: + self.request_states.pop(req_id) + if not engine_core_output.finished: + # If req not finished in EngineCore, but Detokenizer + # detected stop string, abort needed in EngineCore. + reqs_to_abort.append(req_id) - # Track per-request stats - self._update_stats_from_finished(req_state, request_output, - finish_reason, - iteration_stats) + # Track per-request stats + self._update_stats_from_finished(req_state, finish_reason, + iteration_stats) self.lora_states.update_iteration_stats(iteration_stats) @@ -249,7 +340,6 @@ def _update_stats_from_output(self, req_state: RequestState, req_state.stats, lora_stats) def _update_stats_from_finished(self, req_state: RequestState, - request_output: RequestOutput, finish_reason: Optional[FinishReason], iteration_stats: Optional[IterationStats]): if iteration_stats is None: @@ -257,55 +347,8 @@ def _update_stats_from_finished(self, req_state: RequestState, assert finish_reason is not None assert req_state.stats is not None - iteration_stats.update_from_finished_request(finish_reason, - request_output, - req_state.stats) + iteration_stats.update_from_finished_request( + finish_reason=finish_reason, + num_prompt_tokens=len(req_state.prompt_token_ids), + req_stats=req_state.stats) self.lora_states.finish_request(req_state) - - @staticmethod - def _make_request_output( - request_state: RequestState, - new_token_ids: list[int], - finish_reason: Optional[FinishReason], - stop_reason: Union[int, str, None], - ) -> Optional[RequestOutput]: - - finished = finish_reason is not None - output_kind = request_state.output_kind - # In follow up, we will switch to invariant where EngineCore - # does not stream partial prefills. - if not finished and (request_state.is_prefilling - or output_kind == RequestOutputKind.FINAL_ONLY): - # Only the final output is required in FINAL_ONLY mode. - return None - - detokenizer = request_state.detokenizer - logprobs_processor = request_state.logprobs_processor - - delta = output_kind == RequestOutputKind.DELTA - logprobs = logprobs_processor.logprobs - if delta: - if logprobs: - logprobs = logprobs[-len(new_token_ids):] - # Side effect: logprobs processor forgets prompt logprobs - prompt_logprobs = logprobs_processor.pop_prompt_logprobs() - else: - prompt_logprobs = logprobs_processor.prompt_logprobs - - request_output = RequestOutput.new( - request_id=request_state.request_id, - prompt=request_state.prompt, - prompt_token_ids=request_state.prompt_token_ids, - text=detokenizer.get_next_output_text(finished, delta), - token_ids=new_token_ids if delta else detokenizer.output_token_ids, - logprobs=logprobs, - prompt_logprobs=prompt_logprobs, - cumulative_logprob=logprobs_processor.cumulative_logprob, - finished=finished, - ) - if finished: - completion_output = request_output.outputs[0] - completion_output.finish_reason = str(finish_reason) - completion_output.stop_reason = stop_reason - - return request_output diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 291360771b54..adced8973b03 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,69 +1,46 @@ # SPDX-License-Identifier: Apache-2.0 -from collections.abc import AsyncGenerator, Mapping from copy import copy -from typing import Optional, Protocol, Union +from typing import Callable, Optional, Union -from vllm.inputs import PromptType -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.utils import merge_async_iterators +from vllm.sampling_params import SamplingParams -class AsyncGenerateMethodType(Protocol): - - def __call__(self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0) -> AsyncGenerator[RequestOutput, None]: - ... - - -class SyncAddRequestMethodType(Protocol): - - def __call__(self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0) -> None: - ... - - -class ParallelSamplingRequest: +class ParentRequest: """Info, state & processing for parallel sampling request. - + Store parent request ID and sampling params. Facilitate generating child request sampling params. - Transform child request outputs into parent request - outputs. - When stream mode is disabled, then `self.request_output` - aggregates child request completions. """ request_id: str sampling_params: SamplingParams + + # To aggregate child completions when not streaming + output_aggregator: Optional[RequestOutput] + + # To efficiently obtain child sampling params cached_child_sampling_params: Optional[SamplingParams] - request_output: Optional[RequestOutput] - num_finished_completions: int def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: self.request_id = request_id self.sampling_params = sampling_params + + self.output_aggregator = None self.cached_child_sampling_params = None - self.request_output = None - self.num_finished_completions = 0 + + @classmethod + def from_params( + cls, + request_id: str, + params: Union[SamplingParams, PoolingParams], + ) -> Optional['ParentRequest']: + if not isinstance(params, SamplingParams) or params.n == 1: + return None + return cls(request_id, params) def _get_child_sampling_params( self, @@ -96,47 +73,6 @@ def _get_child_sampling_params( child_sampling_params.seed = seed + index return child_sampling_params - def _add_output( - self, - child_req_output: RequestOutput, - index: int, - ) -> None: - """Aggregate a parallel sampling child - request output. - - Non-stream-mode (`output_kind == FINAL_ONLY`) - only. Inject correct parent request ID and - completion index. - - Args: - child_req_output: a single request output - from a parallel sampling - child request. - index: index within `n` child - """ - self.num_finished_completions += 1 - new_completion = child_req_output.outputs[0] - new_completion.index = index - if self.request_output is None: - # Save the first request output; reinstate - # original request ID; metrics are not - # supported for parallel sampling - child_req_output.request_id = self.request_id - child_req_output.metrics = None - self.request_output = child_req_output - else: - # Aggregate additional completion into request output - # Note: will be sorted by index later - self.request_output.outputs.append(new_completion) - - def _get_final_request_output(self) -> RequestOutput: - """Invariant: parent completion outputs sorted by index""" - assert self.request_output is not None - self.request_output.finished = True - self.request_output.outputs = sorted(self.request_output.outputs, - key=lambda x: x.index) - return self.request_output - def get_child_info(self, index: int) -> tuple[str, SamplingParams]: """Get child request ID and sampling params. @@ -149,227 +85,35 @@ def get_child_info(self, index: int) -> tuple[str, SamplingParams]: return (f"{index}_{self.request_id}", self._get_child_sampling_params(index)) - def process_output( - self, - child_req_output: RequestOutput, - index: int, - ) -> Optional[RequestOutput]: - """Filter, aggregate and transform parallel sampling - child request outputs. - - If the parent request has `stream=false` - (`output_kind == FINAL_ONLY`), each child will also have - `output_kind == FINAL_ONLY`. All child request outputs - must be aggregated into a single request output, with - multiple completions. This request output is only returned - once `n` completions are aggregated. - - If the parent request has `stream=true` - (`output_kind == DELTA`), each child will also have - `output_kind == DELTA`. All child request outputs - must be streamed directly to the caller. - - Args: - child_req_output: a single child request output - index: index within `n` child requests - - Returns: - `None`, unless a processed request output is ready to - send back to the caller. - """ - if self.output_kind != RequestOutputKind.FINAL_ONLY: - # stream=true: return child completions immediately - child_req_output.request_id = self.request_id - child_req_output.outputs[0].index = index - if child_req_output.finished: - # Parent request is complete if all child requests are - # complete. - self.num_finished_completions += 1 - child_req_output.finished = ( - self.num_finished_completions == self.n) - return child_req_output - - # stream=false: aggregate child completions - self._add_output(child_req_output, index) - if self.num_finished_completions == self.n: - # Return aggregated request output after obtaining - # all completions - return self._get_final_request_output() - return None - - async def wrap_child_async_generator( - self, - child_gen: AsyncGenerator[RequestOutput, None], - index: int, - ) -> AsyncGenerator[RequestOutput, None]: - """Output generator for a single parallel sampling - child request. - - Each parallel sampling request triggers at - least two child requests. This generator - yields zero or more request outputs to - return to the caller, as they become - available. - - Args: - child_gen: generator for child request - outputs. - index: index within the `n` child requests - - Returns: - Yields zero or more request outputs to return - to the caller. - """ - async for out in child_gen: - if req_out := self.process_output(out, index): - yield req_out - @property def n(self) -> int: return self.sampling_params.n - @property - def output_kind(self) -> RequestOutputKind: - return self.sampling_params.output_kind - - -class SyncParallelSamplingManager: - - def __init__(self): - # Parent req ID -> parent request manager - self.parent_reqs: dict[str, ParallelSamplingRequest] = {} - # Child req ID -> (child req index, parent req ID) - self.child_reqs: dict[str, tuple[int, str]] = {} - - def _register_parent_request(self, req: ParallelSamplingRequest) -> None: - """Register parallel sampling parent request.""" - self.parent_reqs[req.request_id] = req - - def _register_child_request(self, req_id: str, child_req_id: str, - index: int) -> None: - """Register parallel sampling child request with parent. - - Args: - req_id: parent request ID - child_req_id: child request ID - index: child request index within `n` child requests - """ - self.child_reqs[child_req_id] = (index, req_id) - - def get_num_unfinished_requests(self, num_core_reqs: int) -> int: - """Get the number of unfinished requests, correcting for parallel - sampling. - - Args: - num_core_reqs: The number of unfinished requests in the engine core. - - Returns: - Number of unfinished requests, where each parallel sampling req - counts as 1 - """ - return num_core_reqs + len(self.parent_reqs) - len(self.child_reqs) - - def add_request_parallel_sampling( + def make_request_output( self, - add_request: SyncAddRequestMethodType, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: - """Add sync parallel sampling request.""" - req = ParallelSamplingRequest(request_id, params) - self._register_parent_request(req) - # Add n child requests with unique request IDs & random seeds and n=1 - for idx in range(req.n): - child_req_id, child_params = req.get_child_info(idx) - self._register_child_request(request_id, child_req_id, idx) - add_request(request_id=child_req_id, - prompt=prompt, - params=child_params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) # type: ignore - - def step( - self, - outputs: list[RequestOutput], - ) -> list[RequestOutput]: - """Build parallel sampling request outputs. - - Extract child request outputs, aggregate them - into parent request output, and return parent - output when complete. - - Do not modify `n=1` requests. - - Args: - outputs: step request outputs. Mix of child request - outputs & `n=1` request outputs. + final_only: bool, + completion_output: CompletionOutput, + new_request_output: Callable[[str], RequestOutput], + ) -> Optional[RequestOutput]: + # Use an existing RequestOutput if we're aggregating + request_output = self.output_aggregator - Return: - List of parallel sampling parent request outputs & - unmodified `n=1` request outputs passed-thru from input. - """ - if not (self.parent_reqs and outputs): - # Return unmodified - return outputs - agg_outputs = [] - for output in outputs: - req_id = output.request_id - if child_req_entry := self.child_reqs.get(req_id, None): - # For each parallel sampling child request output: - (index, parent_req_id) = child_req_entry - req = self.parent_reqs[parent_req_id] - # Update parallel sampling request - if out := req.process_output(output, index): - # Return parent request output if complete; - # cleanup parent request bookkeeping. - agg_outputs.append(out) - del self.parent_reqs[parent_req_id] - # Cleanup child request bookkeeping. - del self.child_reqs[req_id] - else: - # Not a parallel sampling request output - agg_outputs.append(output) - return agg_outputs + # Make new RequestOutput otherwise + if request_output is None: + request_output = new_request_output(self.request_id) + # Add a new completion + request_output.outputs.append(completion_output) -async def generate_parallel_sampling_async( - generate: AsyncGenerateMethodType, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -) -> AsyncGenerator[RequestOutput, None]: - """Generate completions for async parallel sampling requests.""" - parent_req = ParallelSamplingRequest(request_id, sampling_params) + # If not streaming, aggregate until all child requests complete + if final_only and len(request_output.outputs) != self.n: + self.output_aggregator = request_output + return None - # Aggregate generators for n child requests - gens: list[AsyncGenerator[RequestOutput, None]] = [] - for idx in range(parent_req.n): - child_req_id, child_params = parent_req.get_child_info(idx) - child_gen = generate( - prompt=prompt, - sampling_params=child_params, - request_id=child_req_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - ) # type: ignore - gen = parent_req.wrap_child_async_generator(child_gen, idx) - gens.append(gen) + # We're done aggregating + self.output_aggregator = None - # Merge generators - async for _, out in merge_async_iterators(*gens): - yield out + # Parent completion output list must be sorted by index + request_output.outputs = sorted(request_output.outputs, + key=lambda x: x.index) + return request_output diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 625edb607467..abdca95670e1 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from vllm.outputs import RequestOutput from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason from vllm.v1.output_processor import RequestState @@ -150,7 +149,7 @@ def update_from_events(self, req_id: str, events: list["EngineCoreEvent"], self.num_preempted_reqs += 1 def update_from_finished_request(self, finish_reason: "FinishReason", - request_output: "RequestOutput", + num_prompt_tokens: int, req_stats: RequestStateStats): e2e_latency = self._time_since(req_stats.arrival_time) @@ -172,7 +171,7 @@ def update_from_finished_request(self, finish_reason: "FinishReason", finished_req = \ FinishedRequestStats(finish_reason=finish_reason, e2e_latency=e2e_latency, - num_prompt_tokens=len(request_output.prompt_token_ids), + num_prompt_tokens=num_prompt_tokens, num_generation_tokens=req_stats.num_generation_tokens, queued_time=queued_time, prefill_time=prefill_time,