diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py
index a6e96c0bb433..1a9af1e4425f 100644
--- a/examples/offline_inference/basic/basic.py
+++ b/examples/offline_inference/basic/basic.py
@@ -10,10 +10,14 @@
     "The future of AI is",
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+sampling_params = SamplingParams()  # temperature=0.8, top_p=0.95
 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+# llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
+          max_num_seqs=16,
+          max_model_len=128,
+          enforce_eager=True)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
@@ -21,4 +25,4 @@
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
\ No newline at end of file
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 540a35e1ecb9..3174decc17a1 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -12,6 +12,7 @@
 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -33,14 +34,16 @@ class DPMetadata:
 
 @dataclass
 class ForwardContext:
-    # copy from vllm_config.compilation_config.static_forward_context
+    # Copy from vllm_config.compilation_config.static_forward_context
     no_compile_layers: dict[str, Any]
-    # TODO: extend to support per-layer dynamic forward context
+    # TODO: Extend to support per-layer dynamic forward context
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
-    # TODO: remove after making all virtual_engines share the same kv cache
+    # TODO: Remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
-    # set dynamically for each forward pass
+    # Set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
+    # Whether this is a profile run (before KV cache init)
+    is_profile_run: bool = False
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -58,7 +61,8 @@ def get_forward_context() -> ForwardContext:
 def set_forward_context(attn_metadata: Any,
                         vllm_config: VllmConfig,
                         virtual_engine: int = 0,
-                        num_tokens: int = 0):
+                        num_tokens: int = 0,
+                        is_profile_run: bool = False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -93,12 +97,15 @@ def set_forward_context(attn_metadata: Any,
 
     global _forward_context
     prev_context = _forward_context
+
     _forward_context = ForwardContext(
         no_compile_layers=vllm_config.compilation_config.
         static_forward_context,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
-        dp_metadata=dp_metadata)
+        dp_metadata=dp_metadata,
+        is_profile_run=is_profile_run)
+
     try:
         yield
     finally:
@@ -111,10 +118,17 @@ def set_forward_context(attn_metadata: Any,
             else:
                 # for v1 attention backends
                 batchsize = attn_metadata.num_input_tokens
+
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
-            torch.cuda.synchronize()
+            if current_platform.is_tpu():
+                import torch_xla.core.xla_model as xm
+                xm.mark_step()
+                xm.wait_device_ops()
+            else:
+                torch.cuda.synchronize()
+
             now = time.perf_counter()
             # time measurement is in milliseconds
             batchsize_forward_time[batchsize].append(
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f9a3217fbef3..47ba560ee7b5 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -30,7 +30,6 @@
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 if TYPE_CHECKING:
@@ -104,9 +103,6 @@ def __init__(
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size
 
-        # Lazy initialization
-        # self.model: nn.Module  # Set after load_model
-        self.kv_caches: list[torch.Tensor] = []
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
 
@@ -582,7 +578,6 @@ def execute_model(
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.position_ids,
-                kv_caches=self.kv_caches,
                 inputs_embeds=inputs_embeds,
             )
         hidden_states = hidden_states[:total_num_scheduled_tokens]
@@ -680,8 +675,8 @@ def load_model(self) -> None:
 
     def _dummy_run(
         self,
-        kv_caches,
         num_tokens: int,
+        is_profile_run: bool,
     ) -> None:
         if self.is_multimodal_model:
             input_ids = None
@@ -728,15 +723,28 @@ def _dummy_run(
             torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
             torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
 
-        with set_forward_context(attn_metadata, self.vllm_config, 0):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 0,
+                                 is_profile_run=is_profile_run):
             assert self.model is not None
             self.model(
                 input_ids=input_ids,
                 positions=position_ids,
-                kv_caches=kv_caches,
                 inputs_embeds=inputs_embeds,
             )
 
+    # Used for memory profiling before the KV cache is initialized.
+    def profile_run(self, num_tokens: int) -> None:
+        self._dummy_run(num_tokens=num_tokens, is_profile_run=True)
+
+    # Used after the KV cache is initialized (e.g. to compile the model).
+    def dummy_run(
+        self,
+        num_tokens: int,
+    ) -> None:
+        self._dummy_run(num_tokens=num_tokens, is_profile_run=False)
+
     def capture_model(self) -> None:
         """Compile the model."""
@@ -745,7 +753,7 @@ def capture_model(self) -> None:
         start = time.perf_counter()
         num_tokens = 16
         while True:
-            self._dummy_run(self.kv_caches, num_tokens)
+            self.dummy_run(num_tokens)
             logger.info("  -- num_tokens: %d", num_tokens)
             xm.mark_step()
             xm.wait_device_ops()
@@ -769,6 +777,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
 
         kv_caches: dict[str, torch.Tensor] = {}
 
+        kv_cache_shape_prev = None
         for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
             tensor_config = kv_cache_config.tensors[layer_name]
             assert tensor_config.size % layer_spec.page_size_bytes == 0
@@ -779,6 +788,12 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                     layer_spec.head_size)
                 dtype = layer_spec.dtype
 
+                # Ensure all "kv_cache_shape" are the same across the model
+                if kv_cache_shape_prev is None:
+                    kv_cache_shape_prev = kv_cache_shape
+                else:
+                    assert kv_cache_shape == kv_cache_shape_prev
+
                 tpu_k_cache = torch.zeros(kv_cache_shape,
                                           dtype=dtype,
                                           device=self.device)
@@ -788,10 +803,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             else:
                 raise NotImplementedError
 
-        bind_kv_cache(
-            kv_caches,
-            self.vllm_config.compilation_config.static_forward_context,
-            self.kv_caches)
+        # ModelWrapperV1 needs to know the KV cache shape
+        self.model.set_kv_cache_shape(kv_cache_shape_prev)
+
+        # Associates each attention layer in the `forward_context` with the
+        # initialized KV cache.
+        forward_context = self.vllm_config.compilation_config \
+            .static_forward_context
+        for layer_name, kv_cache in kv_caches.items():
+            # NOTE: Use list because of v0 PP virtual engine.
+            forward_context[layer_name].kv_cache = [kv_cache]
 
 
 class ModelWrapperV1(nn.Module):
@@ -799,12 +820,15 @@ class ModelWrapperV1(nn.Module):
     def __init__(self, model: nn.Module):
         super().__init__()
         self.model = model
+        self.kv_cache_shape = None
+
+    def set_kv_cache_shape(self, kv_cache_shape):
+        self.kv_cache_shape = kv_cache_shape
 
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
@@ -817,16 +841,20 @@ def forward(
             inputs_embeds: The input embeddings of shape [num_tokens,
                 hidden_size]. It is used for multimodal models.
         """
-        # Skip this in memory profiling at initialization.
-        if kv_caches[0][0].numel() > 0:
-            attn_metadata = get_forward_context().attn_metadata
-            # index_copy_(slot_mapping) only works when the inserted dimension
-            # is 0. However, the KV cache in the Pallas backend has the shape
-            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
-            # work, we need to flatten the first three dimensions and modify
-            # the slot_mapping accordingly.
-            # kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
-            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
+        forward_context = get_forward_context()
+        attn_metadata = forward_context.attn_metadata
+
+        # index_copy_(slot_mapping) only works when the inserted dimension
+        # is 0. However, the KV cache in the Pallas backend has the shape
+        # [num_kv_heads, num_blocks, block_size, head_size]. To make it
+        # work, we need to flatten the first three dimensions and modify
+        # the slot_mapping accordingly.
+        #
+        # Note: We skip this step during the profile run (before KV cache init)
+        if not forward_context.is_profile_run:
+            assert self.kv_cache_shape  # Ensure initialized
+            num_kv_heads, num_blocks, block_size, _ = self.kv_cache_shape
+
             slot_mapping = attn_metadata.slot_mapping
             slot_mapping = slot_mapping.flatten()
             head_indicies = torch.arange(0,
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index d09f5dd84007..ddbe552c4d6a 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -21,7 +21,6 @@
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.tpu_model_runner import TPUModelRunner
 
 logger = init_logger(__name__)
@@ -128,18 +127,19 @@ def determine_available_memory(self) -> int:
             else:
                 raise NotImplementedError
 
-        runner_kv_caches: list[torch.Tensor] = []
-        bind_kv_cache(
-            kv_caches,
-            self.vllm_config.compilation_config.static_forward_context,
-            runner_kv_caches)
+        # Associates each attention layer in the `forward_context` with the
+        # initialized KV cache.
+        forward_context = self.vllm_config.compilation_config \
+            .static_forward_context
+        for layer_name, kv_cache in kv_caches.items():
+            # NOTE: Use list because of v0 PP virtual engine.
+            forward_context[layer_name].kv_cache = [kv_cache]
 
-        self.model_runner._dummy_run(
-            runner_kv_caches,
-            num_tokens=self.scheduler_config.max_num_batched_tokens,
-        )
+        self.model_runner.profile_run(
+            num_tokens=self.scheduler_config.max_num_batched_tokens)
 
         # Synchronize before measuring the memory usage.
+        xm.mark_step()
         xm.wait_device_ops()
 
         # Get the maximum amount of memory used by the model weights and
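
Note on the flow above: the worker's memory-profiling pass now calls profile_run, which reaches ModelWrapperV1.forward through the global forward context with is_profile_run=True, so the KV-cache write is skipped before the caches exist; after initialize_kv_cache, the dummy_run path runs with the caches bound and the stored kv_cache_shape. The standalone sketch below (illustrative names only, not vLLM's API) mirrors that pattern with a minimal context manager and model stub:

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class _ForwardContext:
    # Mirrors the is_profile_run flag added to ForwardContext in the diff.
    is_profile_run: bool = False


_ctx: Optional[_ForwardContext] = None


@contextmanager
def _set_forward_context(is_profile_run: bool = False) -> Iterator[None]:
    """Minimal stand-in for set_forward_context (illustrative only)."""
    global _ctx
    prev, _ctx = _ctx, _ForwardContext(is_profile_run=is_profile_run)
    try:
        yield
    finally:
        _ctx = prev


class _ModelStub:
    """Skips the KV-cache write while the forward context marks a profile run."""

    def __init__(self) -> None:
        self.kv_cache_shape = None  # set later, like set_kv_cache_shape()

    def forward(self) -> None:
        assert _ctx is not None
        if not _ctx.is_profile_run:
            assert self.kv_cache_shape is not None  # caches must exist by now
            print("writing KV cache with shape", self.kv_cache_shape)
        else:
            print("profile run: skipping KV cache write")


model = _ModelStub()
with _set_forward_context(is_profile_run=True):  # before KV cache init
    model.forward()
model.kv_cache_shape = (8, 512, 16, 128)  # after initialize_kv_cache()
with _set_forward_context():  # normal dummy/decode run
    model.forward()

This mirrors the diff's move away from passing kv_caches into forward and checking numel() to decide whether to skip the cache update.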
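
Similarly, the loop that replaces bind_kv_cache attaches each layer's cache tuple directly to its entry in compilation_config.static_forward_context. A self-contained sketch of that binding pattern, assuming stub classes and layer names that are illustrative rather than the real vLLM objects:

import torch


class _AttnLayerStub:
    """Stand-in for an attention layer registered in the static forward context."""

    def __init__(self) -> None:
        # One entry per virtual engine (v0 pipeline parallelism); v1 uses index 0.
        self.kv_cache: list[tuple[torch.Tensor, torch.Tensor]] = []


# Stand-in for vllm_config.compilation_config.static_forward_context.
static_forward_context = {"model.layers.0.attn": _AttnLayerStub()}

# Per the Pallas backend comment above, each cache has the shape
# [num_kv_heads, num_blocks, block_size, head_size], one (K, V) pair per layer.
kv_caches = {
    "model.layers.0.attn": (torch.zeros(2, 8, 16, 64),
                            torch.zeros(2, 8, 16, 64)),
}

for layer_name, kv_cache in kv_caches.items():
    # NOTE: wrapped in a list because of the v0 PP virtual-engine convention.
    static_forward_context[layer_name].kv_cache = [kv_cache]

virtual_engine = 0
k_cache, v_cache = static_forward_context["model.layers.0.attn"].kv_cache[virtual_engine]
print(k_cache.shape, v_cache.shape)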