diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 909c2ad955f25..45fe1989f9bff 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -2,14 +2,13 @@ # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company ############################################################################### -import importlib from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type import torch +import math import vllm.hpu.xops as xops from vllm.hpu.attn_bias import (AttentionBias, - BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -18,7 +17,6 @@ from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention, HabanaPagedAttentionMetadata) from vllm.logger import init_logger -from vllm.utils import is_hip logger = init_logger(__name__) @@ -119,11 +117,11 @@ def __post_init__(self): class HabanaAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| + |<--------------- num_prefill_tokens ----------------->| |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. @@ -196,48 +194,37 @@ def forward( HabanaPagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, + attn_metadata.kv_cache_dtype, attn_metadata.prefill_metadata is not None) if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if kv_cache is None or prefill_meta.block_tables.numel() == 0: - # normal attention. - # block tables are empty if the prompt does not have a cached - # prefix. - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. 
- query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - + # TODO: move this outside of model if prefill_meta.attn_bias is None: if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) + lens = torch.tensor(attn_metadata.prefill_metadata.seq_lens, device=query.device, dtype=torch.int32) + len_mask = (torch.arange(0, seq_len, device=query.device, dtype=torch.int32) + .view(1, seq_len) + .ge(lens.unsqueeze(-1)) + .view(batch_size, 1, 1, seq_len)) + causal_mask = torch.triu( + torch.ones((batch_size, 1, seq_len, seq_len), device=query.device, dtype=torch.bool), + diagonal=1 + ) + mask = causal_mask.logical_or(len_mask) + attn_bias = (torch.zeros_like(mask, dtype=query.dtype) + .masked_fill_(mask, -math.inf)) if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) + raise NotImplementedError("Sliding window is not supported on HPU") prefill_meta.attn_bias = attn_bias else: prefill_meta.attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, batch_size, seq_len, query.dtype) - query_shape = (batch_size, seq_len, self.num_kv_heads, self.num_queries_per_kv, self.head_size) if self.num_kv_heads != self.num_heads else (batch_size, seq_len, self.num_heads, self.head_size) - kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.num_queries_per_kv, self.head_size) if self.num_kv_heads != self.num_heads else (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) - out = xops.memory_efficient_attention_forward( + query_shape = (batch_size, seq_len, self.num_heads, self.head_size) + kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) + out = xops.prompt_attention( query.view(query_shape), key.view(kv_shape), value.view(kv_shape), diff --git a/vllm/hpu/xops.py b/vllm/hpu/xops.py index c9d237744a917..d6404a4872c0d 100644 --- a/vllm/hpu/xops.py +++ b/vllm/hpu/xops.py @@ -5,62 +5,37 @@ # LICENSE file in the root directory of this source tree. 
############################################################################### -import habana_frameworks.torch as htorch import torch -import torch.nn.functional as F -from typing import List, Optional, Tuple, Union -from .attn_bias import AttentionBias, BlockDiagonalCausalMask +from typing import Optional -try: - from habana_frameworks.torch.hpex.kernels import FusedSDPA -except ImportError: - print("Not using HPU fused scaled dot-product attention kernel.") - FusedSDPA = None +import vllm.hpu.utils -def memory_efficient_attention_forward( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias: Optional[torch.Tensor] = None, - p: float = 0.0, - scale: Optional[float] = None, -) -> torch.Tensor: - assert attn_bias is not None, "Attention mask is required for prompt processing" - dim = query.dim() - is_causal = isinstance(attn_bias, BlockDiagonalCausalMask) - if FusedSDPA and (is_causal or attn_bias is None): - bs = query.shape[0] - seq_len_q = query.shape[1] - seq_len_kv = key.shape[1] - heads = query.shape[-2] if dim != 5 else query.shape[-3] - attn_groups = 1 if dim != 5 else query.shape[-2] - head_dim = query.shape[-1] - if dim == 4: - # [bs, seq_len, 1, heads, head_dim] -> [bs, heads, seq_len, head_dim] - query = query.reshape(bs, seq_len_q, heads, head_dim).permute(0, 2, 1, 3) - key = key.reshape(bs, seq_len_kv, heads, head_dim).permute(0, 2, 1, 3) - value = value.reshape(bs, seq_len_kv, heads, head_dim).permute(0, 2, 1, 3) - elif dim == 5: - # [bs, seq_len, heads, attn_groups, head_dim] -> [bs, heads, attn_groups, seq_len, head_dim] - query = query.reshape(bs, seq_len_q, heads, attn_groups, head_dim).permute(0, 2, 3, 1, 4) - key = key.reshape(bs, seq_len_kv, heads, attn_groups, head_dim).permute(0, 2, 3, 1, 4) - value = value.reshape(bs, seq_len_kv, heads, attn_groups, head_dim).permute(0, 2, 3, 1, 4) - else: - raise ValueError(f"Unsupported attention dimension: {dim}") - - import habana_frameworks.torch.hpu as ht - with ht.sdp_kernel(enable_recompute=False): # (flash_attention_recompute and q_len == 1)): - out = FusedSDPA.apply( - query, key, value, None, p, is_causal, scale - ) - htorch.core.mark_step() - if dim == 4: - # [bs, heads, seq_len, head_dim] -> [bs, seq_len, heads, head_dim] - out = out.permute(0, 2, 1, 3).reshape(bs, seq_len_q, heads, head_dim) - elif dim == 5: - # [bs, heads, attn_groups, seq_len, head_dim] -> [bs, seq_len, heads, attn_groups, head_dim] - out = out.permute(0, 3, 1, 2, 4).reshape(bs, seq_len_q, heads, attn_groups, head_dim) - else: - raise NotImplementedError(f'Only FusedSDPA causal or non-masked attention is supported.\nFusedSDPA support: {FusedSDPA is not None}\nis_causal: {is_causal}\nmask_present: {attn_bias is not None}') - return out +@vllm.hpu.utils.with_mark_steps +def prompt_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + p: float = 0.0, + scale: Optional[float] = None, +) -> torch.Tensor: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + query_heads = query.size(1) + kv_heads = key.size(1) + if query_heads != kv_heads: + query = query.unflatten(1, (kv_heads, -1)) + key = key.unflatten(1, (kv_heads, 1)) + value = value.unflatten(1, (kv_heads, 1)) + attn_bias = attn_bias.unsqueeze(2) + attn_weights = torch.matmul(query * scale, key.transpose(-1, -2)) + if attn_bias is not None: + attn_weights.add_(attn_bias) + attn_weights = torch.softmax(attn_weights, dim=-1) + attn_weights = torch.matmul(attn_weights, 
value)
+    if query_heads != kv_heads:
+        attn_weights = attn_weights.flatten(1, 2)
+    attn_weights = attn_weights.transpose(1, 2)
+    return attn_weights
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index e2076018b5609..78b3e6417366e 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -192,12 +192,6 @@ def _prepare_seq_groups(
     # Total number of prompts from given sequence groups.
     num_prompts = 0
 
-    # FIXME: On HPU prompts are right-padded. We need to take that into account
-    # when updating model_output_idx
-    if is_hpu() and len(seq_lens) > 0:
-        assert seq_lens == query_lens, 'Prompt chunking is not yet supported on HPU!'
-        max_seq_len = max(seq_lens)
-
     for i, seq_group_metadata in enumerate(seq_group_metadata_list):
         seq_ids = list(seq_group_metadata.seq_data.keys())
         sampling_params = seq_group_metadata.sampling_params
@@ -225,12 +219,10 @@ def _prepare_seq_groups(
             prompt_logprob_len = (query_len - num_prefill_sample
                                   if do_sample else query_len)
             sample_len = num_prefill_sample if do_sample else 0
-            padding_len = 0 if not is_hpu() else max_seq_len - seq_len
         else:
             # Decode
             prompt_logprob_len = 0
             sample_len = len(seq_ids) if do_sample else 0
-            padding_len = 0
 
         # Update indices to select from the model output.
         """
@@ -249,7 +241,6 @@ def _prepare_seq_groups(
         selected_token_indices.extend(
             range(model_output_idx, model_output_idx + sample_len))
         model_output_idx += sample_len
-        model_output_idx += padding_len
 
         # We now find indices for logprob computation and sampling.
         """
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index e306ef0ae12cb..995864e3f81e7 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -2,60 +2,77 @@ # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################
-import contextlib
 import time
 from enum import IntEnum
-from typing import Dict, List, NamedTuple, Optional, Set, Tuple
-
-# for logging hpugraph capture
-import tqdm
-import pandas as pd
-import tabulate
+from typing import List, NamedTuple, Optional, Set, Tuple, Dict
 import os
-import contextlib
 import math
 import itertools
-import numpy as np
+import operator
 
 import torch
-import torch.nn as nn
 import habana_frameworks.torch as htorch
-from habana_frameworks.torch.hpu.metrics import metric_localcontext
 
 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
-from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
+from vllm.config import (DeviceConfig, LoadConfig, CacheConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
-from vllm.distributed.device_communicators import custom_all_reduce
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader import get_model
-from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
-from vllm.utils import (HabanaMemoryProfiler, async_tensor_h2d,
-                        is_pin_memory_available, make_tensor_with_pad,
-                        maybe_expand_dim, pad_to_max_length, format_bytes)
+from vllm.utils import (HabanaMemoryProfiler, is_pin_memory_available,
+                        make_tensor_with_pad, format_bytes)
 
 logger = init_logger(__name__)
 
-_PAD_SLOT_ID = -1
+_PAD_SLOT_ID = 0
 LORA_WARMUP_RANK = 8
-_BATCH_SIZE_ALIGNMENT = 16
-# Capture graphs for token size 1, 2, 4, 8, 16, 32, 48, ..., 512.
-# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4, 8] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
-]
-# Capture graphs for token size 1, 32, 64, 128, 256, 512, 768 ...
2048 -_MAX_SEQ_LEN_ALIGNMENT = 256 -_MAX_SEQ_LENS_TO_CAPTURE = [1, 32, 64, 128] + [ - _MAX_SEQ_LEN_ALIGNMENT * i for i in range(1, 9) -] + +# Read bucketing configuration from env variables +# phase is either 'prompt' or 'decode' +# dim is either 'bs' or 'seq' +# example env variable: VLLM_DECODE_BS_STEP=128 +def read_bucket_settings(phase: str, dim: str, **defaults: Dict): + params = ['min', 'step', 'max'] + values = [os.environ.get(f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper(), defaults[p]) for p in params] + return values + + +def warmup_buckets(config: Tuple[int, int, int]): + bmin, bstep, bmax = config + base = itertools.repeat(2) + ramp_up = itertools.accumulate(base, func=operator.mul, initial=bmin) + ramp_up = itertools.takewhile(lambda x: x < bstep and x <= bmax, ramp_up) + stable = range(bstep, bmax + 1, bstep) + return list(ramp_up) + list(stable) + + +def next_pow2(value: int): + res = 1 + while value > 1: + value = (value + 1) // 2 + res *= 2 + return res + + +def round_up(value: int, k: int): + return (value + k - 1) // k * k + + +def find_bucket(value: int, config: Tuple[int, int, int]): + bmin, bstep, bmax = config + if value < bstep: + result = min(next_pow2(value), bstep) + else: + result = round_up(value, bstep) + return result class PreparePromptMetadata(NamedTuple): @@ -127,6 +144,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, load_config: LoadConfig, + cache_config: CacheConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, @@ -139,22 +157,16 @@ def __init__( self.load_config = load_config self.is_driver_worker = is_driver_worker - # model_config can be None in tests/samplers/test_sampler.py. - # FIXME(woosuk): This is a hack to make the tests work. Refactor this. self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) self.device_config = (device_config if device_config is not None else DeviceConfig()) self.device = self.device_config.device - # Set after load_model. - self.lora_manager: LRUCacheWorkerLoRAManager = None - - self.graph_runner_class = HPUGraphRunner - self.graph_runners: Dict[Tuple[int, int], self.graph_runner_class] = {} - - self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture - if self.model_config is not None else 0) + self.max_num_seqs = self.scheduler_config.max_num_seqs + self.max_model_len = self.scheduler_config.max_model_len + self.max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + self.block_size = cache_config.block_size self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype @@ -164,16 +176,11 @@ def __init__( self.model_config.dtype if model_config is not None else None) # Lazy initialization - self.model: torch.nn.Module # Set after load_model - self.block_size: int # Set after initial profiling. - # When using CUDA graph, the input block tables must be padded to - # max_seq_len_to_capture. However, creating the block table in - # Python can be expensive. To optimize this, we cache the block table - # in numpy and only copy the actual input content at every iteration. - # The shape of the cached block table will be - # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables: torch.Tensor # Set after initial profiling. 
+ self.lora_manager: LRUCacheWorkerLoRAManager = None + self.model: torch.nn.Module = None + self.excluded_from_warmup = [] + self._setup_buckets() def load_model(self) -> None: with HabanaMemoryProfiler() as m: @@ -207,16 +214,18 @@ def load_model(self) -> None: self.model.embedding_padding_modules) self.model = self.lora_manager.create_lora_manager(self.model) - def set_block_size(self, block_size: int) -> None: - self.block_size = block_size - - self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), - dtype=np.int32) - - def get_max_block_per_batch(self) -> int: - block_size = self.block_size - return (self.max_seq_len_to_capture + block_size - 1) // block_size + def _setup_buckets(self) -> None: + self.prompt_bs_bucket_cfg = read_bucket_settings('prompt', 'bs', min=1, step=32, max=min(self.max_num_seqs, 64)) + self.decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', min=1, step=128, max=self.max_num_seqs) + self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', 'seq', min=self.block_size, step=self.block_size, max=1024) + self.decode_seq_bucket_cfg = read_bucket_settings('decode', 'seq', min=self.block_size, step=self.block_size, max=2048) + logger.info(f"Prompt bucket config (min, step, max_warmup) bs:{self.prompt_bs_bucket_cfg}, seq:{self.prompt_seq_bucket_cfg}") + logger.info(f"Decode bucket config (min, step, max_warmup) bs:{self.decode_bs_bucket_cfg}, seq:{self.decode_seq_bucket_cfg}") + + # FIXME: exclude from warmup as it causes OOM on llama-70b + self.excluded_from_warmup = [ + (64, 1024, True) + ] def _prepare_prompt( self, @@ -285,8 +294,6 @@ def _prepare_prompt( # actual prompt lens context_lens.append(context_len) - if context_len != 0: - import pdb; pdb.set_trace() # what happens if we hit that path?? query_lens.append(seq_len - context_len) input_tokens.append(prompt_tokens) @@ -357,34 +364,31 @@ def _prepare_prompt( multi_modal_input = None max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - max_prompt_len = max(seq_lens) + max_prompt_len = max(find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), self.block_size) + input_tokens = make_tensor_with_pad(input_tokens, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - + max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = make_tensor_with_pad(input_positions, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - + max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device) + slot_mapping = make_tensor_with_pad(slot_mapping, - max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) + max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) - # Prepare prefix block tables - max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) + block_tables = make_tensor_with_pad(prefix_block_tables, + max_len=max_prompt_block_table_len, + pad=0, + dtype=torch.int, + device=self.device) # Query length can be shorter than key (i.e., prompt) when prefill # is chunked or prefix cached. 
@@ -394,7 +398,6 @@ def _prepare_prompt( subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device=self.device) @@ -426,6 +429,7 @@ def _prepare_prompt( multi_modal_input=multi_modal_input, slot_mapping=slot_mapping, ) + def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -479,28 +483,7 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) - # vLLM uses cuda graph only for decoding requests. - # See `capture_model` API for more details. - # For decoding requests, batch_size == input_tokens. - batch_size = len(input_tokens) max_seq_len = max(seq_lens) - use_captured_graph = ( - not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_seq_len <= _MAX_SEQ_LENS_TO_CAPTURE[-1] - and max_seq_len <= self.max_seq_len_to_capture) - if use_captured_graph: - graph_batch_size = _get_graph_batch_size(batch_size) - assert graph_batch_size >= batch_size - for _ in range(graph_batch_size - batch_size): - input_tokens.append([0]) - input_positions.append([0]) - slot_mapping.append([_PAD_SLOT_ID]) - seq_lens.append(1) - block_tables.append([]) - lora_index_mapping.append(0) - batch_size = graph_batch_size - input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) @@ -514,33 +497,15 @@ def _prepare_decode( dtype=torch.int, device=self.device) - if use_captured_graph: - # When using cuda-graph all these tensors should be - # padded. - assert seq_lens_tensor.shape[0] == len(input_tokens) - assert seq_lens_tensor.shape[0] == len(input_positions) - assert seq_lens_tensor.shape[0] == len(slot_mapping) - - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - graph_max_seq_len = _get_graph_max_seq_len(max_seq_len) - assert graph_max_seq_len >= max_seq_len - graph_block_count = math.ceil(graph_max_seq_len / self.block_size) - input_block_tables = self.graph_block_tables[:batch_size, :graph_block_count] - for i, block_table in enumerate(block_tables): - if block_table: - input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=self.device) - else: - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) attn_metadata = self.attn_backend.make_metadata( is_prompt=False, seq_lens=None, @@ -551,7 +516,7 @@ def _prepare_decode( seq_start_loc=None, context_lens_tensor=None, block_tables=block_tables, - use_cuda_graph=use_captured_graph, + use_cuda_graph=False, ) return PrepareDecodeMetadata( input_tokens=input_tokens, @@ -563,7 +528,6 @@ def _prepare_decode( slot_mapping=slot_mapping, ) - def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -611,7 +575,7 @@ def prepare_input_tensors( num_prefill_tokens = len(input_tokens) num_decode_tokens = len(decode_input_tokens) - # NOTE(kzawora): Here we diverge from GPU code - we don't support mixed batches, so we either use decode or prefill inputs, without coalescing. 
+ # NOTE(kzawora): Here we diverge from GPU code - we don't support mixed batches, so we either use decode or prefill inputs, without coalescing. assert (num_prefills == 0 and num_decode_tokens > 0) or (num_prefills > 0 and num_decode_tokens == 0), "HPU does not support mixed batches!" if num_decode_tokens > 0: input_tokens = decode_input_tokens @@ -621,6 +585,14 @@ def prepare_input_tensors( lora_prompt_mapping = decode_lora_prompt_mapping lora_requests = decode_lora_requests + # FIXME: We need to adjust selected_token_indices to accomodate for padding + max_len = input_tokens.size(1) + paddings = [max_len - s for s in seq_lens] + paddings = [0] + paddings[:-1] + paddings = list(itertools.accumulate(paddings)) + paddings = torch.tensor(paddings, dtype=sampling_metadata.selected_token_indices.dtype, device=sampling_metadata.selected_token_indices.device) + sampling_metadata.selected_token_indices.add_(paddings) + if self.lora_config: lora_mapping = LoRAMapping( lora_index_mapping, @@ -629,9 +601,6 @@ def prepare_input_tensors( else: lora_mapping = None - # Broadcast the metadata. - # If batch contains both prefill and decode, it sends 2 broadcasts. - # If it only contains 1 type, it triggers a single broadcast. if (prefill_attn_metadata is not None and decode_attn_metadata is not None): batch_type = BatchType.MIXED @@ -721,13 +690,19 @@ def prepare_input_tensors( sampling_metadata, lora_requests, lora_mapping, multi_modal_input) - @torch.inference_mode() def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: + if self.is_driver_worker: + is_prompt = seq_group_metadata_list[0].is_prompt + real_batch_size = len(seq_group_metadata_list) + bucket_cfg = self.prompt_bs_bucket_cfg if is_prompt else self.decode_bs_bucket_cfg + batch_size_padding = find_bucket(real_batch_size, bucket_cfg) - real_batch_size + seq_group_metadata_list = seq_group_metadata_list.copy() + seq_group_metadata_list.extend(seq_group_metadata_list[0] for _ in range(batch_size_padding)) (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input ) = self.prepare_input_tensors(seq_group_metadata_list) @@ -735,17 +710,6 @@ def execute_model( if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) - # Currently HPU graph is only supported by the decode phase. - prefill_meta = attn_metadata.prefill_metadata - decode_meta = attn_metadata.decode_metadata - if prefill_meta is None and decode_meta.use_cuda_graph: - graph_batch_size = input_tokens.shape[0] - graph_block_count = decode_meta.block_tables.shape[1] - graph_runner_key = (graph_batch_size, graph_block_count) - model_executable = self.graph_runners[graph_runner_key] - logger.info(f"Executing {self.graph_runner_class.__name__} with batch {graph_batch_size}, block_count {graph_block_count} (context_len up to {graph_block_count*self.block_size}, currently {torch.max(decode_meta.seq_lens_tensor).item()})") - else: - model_executable = self.model execute_model_kwargs = { "input_ids": input_tokens, "positions": input_positions, @@ -754,11 +718,14 @@ def execute_model( } if self.vision_language_config: execute_model_kwargs.update({"image_input": multi_modal_input}) - hidden_states = model_executable(**execute_model_kwargs) + + htorch.core.mark_step() + hidden_states = self.model(**execute_model_kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # Compute the logits. 
logits = self.model.compute_logits(hidden_states, sampling_metadata) + htorch.core.mark_step() # Only perform sampling in the driver worker. if not self.is_driver_worker: @@ -769,520 +736,63 @@ def execute_model( logits=logits, sampling_metadata=sampling_metadata, ) - + output.outputs = output.outputs[:real_batch_size] + htorch.core.mark_step() return output - @torch.inference_mode() - def profile_run(self) -> None: - # Enable top-k sampling to reflect the accurate memory usage. - sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs - - # This represents the maximum number of different requests - # that will have unique loras, an therefore the max amount of memory - # consumption create dummy lora request copies from the lora request - # passed in, which contains a lora from the lora warmup path. - dummy_lora_requests = [] - dummy_lora_requests_per_seq = [] - if self.lora_config: - for idx in range(self.lora_config.max_loras): - lora_id = idx + 1 - dummy_lora_request = LoRARequest( - lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_local_path="/not/a/real/path", - ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) - dummy_lora_requests.append(dummy_lora_request) - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) - ] - - # Profile memory usage with max_num_sequences sequences and the total - # number of tokens equal to max_num_batched_tokens. - seqs: List[SequenceGroupMetadata] = [] - # Additional GPU memory may be needed for vision encoding, which needs - # to be accounted for when calculating the GPU blocks for - # vLLM blocker manager. - # To exercise the worst scenario for GPU memory consumption, - # the number of seqs (batch_size) is chosen to maximize the number - # of images processed. - if self.vision_language_config: - max_num_seqs = min( - max_num_seqs, - int(max_num_batched_tokens / - self.vision_language_config.image_feature_size)) - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - seq_data = SequenceData([0] * seq_len) - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: seq_data}, - sampling_params=sampling_params, - block_tables=None, - lora_request=dummy_lora_requests_per_seq[group_id] - if dummy_lora_requests_per_seq else None, - ) - seqs.append(seq) + def create_dummy_seq_group_metadata(self, group_id, seq_len, is_prompt): + sampling_params = SamplingParams(temperature=0) + num_blocks = math.ceil(seq_len / self.block_size) + if is_prompt: + input_len = seq_len + output_len = 0 + block_tables = None + else: + input_len = seq_len - 1 + output_len = 1 + block_tables = {group_id: [0] * num_blocks} + prompt_token_ids = [0] * input_len + output_token_ids = [1] * output_len + seq_data = SequenceData(prompt_token_ids) + seq_data.output_token_ids = output_token_ids + return SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=(output_len == 0), + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=block_tables, + ) - # Run the model with the dummy inputs. 
+ def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - self.execute_model(seqs, kv_caches) + seq_len = self.max_model_len // self.max_num_seqs + self.warmup_scenario(self.max_num_seqs, seq_len, True, kv_caches) + + def warmup_scenario(self, batch_size, seq_len, is_prompt, kv_caches) -> None: + seqs = [self.create_dummy_seq_group_metadata(i, seq_len, is_prompt) for i in range(batch_size)] + _ = self.execute_model(seqs, kv_caches) torch.hpu.synchronize() - return - - def remove_all_loras(self): - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_loras() - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_loras(lora_requests, lora_mapping) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_lora(lora_id) - - def list_loras(self) -> Set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_loras() @torch.inference_mode() - def capture_model(self, kv_caches: List[torch.Tensor]) -> None: - """Cuda graph capture a model. - - Note that CUDA graph's performance gain is negligible if number - of batched tokens are larger than 200. And since CUDA graph - requires fixed sized tensors, supporting large/variable batch - size requires high GPU memory overhead. Thus, vLLM only captures - decoding requests. Mixed batch (chunked prefill + decoding) or - prefill requests are not captured. - - Since it is used for decoding-only, it assumes there's only 1 token - per sequence in the batch. - """ - # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never - # deleted before the CUDA graphs. - - assert not self.model_config.enforce_eager - logger.info("Capturing the model for HPUGraphs. This may lead to " - "unexpected consequences if the model is not static. To " - "run the model in eager mode, set 'enforce_eager=True' or " - "use '--enforce-eager' in the CLI.") - logger.info("HPUGraphs can take additional ~10 GiB memory per HPU. " - "If you are running out of memory, consider decreasing " - "`gpu_memory_utilization` or enforcing eager mode. " - "You can also reduce the `max_num_seqs` as needed " - "to decrease memory usage.") + def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: + times = 1 # TODO: this is will be updated once HPU graphs are reintroduced + scenarios = [] + scenarios.extend(itertools.product(warmup_buckets(self.decode_bs_bucket_cfg), warmup_buckets(self.decode_seq_bucket_cfg), [False])) + scenarios.extend(itertools.product(warmup_buckets(self.prompt_bs_bucket_cfg), warmup_buckets(self.prompt_seq_bucket_cfg), [True])) + scenarios = [scenario for scenario in reversed(scenarios) for _ in range(times) if scenario not in self.excluded_from_warmup] + + start_mem = HabanaMemoryProfiler.current_memory_usage() start_time = time.perf_counter() - - # Prepare dummy inputs. These will be reused for all batch sizes. 
- max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).to('hpu') - input_positions = torch.zeros(max_batch_size, 1, dtype=torch.long).to('hpu') - slot_mapping = torch.zeros(max_batch_size, 1, dtype=torch.long).to('hpu') # TODO(kzawora): when using torch.empty, following occurs: RuntimeError: Error when trying to cast Long to Int, Input values range [0, 139632108750000] exceeds Int range [-2147483648, 2147483647] - slot_mapping.fill_(_PAD_SLOT_ID) - context_lens = torch.ones(max_batch_size, dtype=torch.int32).to('hpu') - block_tables = torch.from_numpy(self.graph_block_tables).to('hpu') - - graph_batch_size = _get_graph_batch_size( - self.scheduler_config.max_num_seqs) - batch_size_capture_list = [ - bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size - ] - - # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce - # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use - # either custom all-reduce kernel or CuPy NCCL. When not using CUDA - # graph, we use either custom all-reduce kernel or PyTorch NCCL. - # We always prioritize using custom all-reduce kernel but fall back - # to PyTorch or CuPy NCCL if it is disabled or not supported. - with custom_all_reduce.capture(): - # NOTE: Capturing the largest batch size first may help reduce the - # memory usage of CUDA graph. - valid_combinations = [] - total_combinations = len(_BATCH_SIZES_TO_CAPTURE)*len(_MAX_SEQ_LENS_TO_CAPTURE) - import pandas as pd - df = pd.DataFrame(index=_BATCH_SIZES_TO_CAPTURE, columns=_MAX_SEQ_LENS_TO_CAPTURE) - for idx, (batch_size, max_seq_len) in enumerate(itertools.product(reversed(_BATCH_SIZES_TO_CAPTURE), reversed(_MAX_SEQ_LENS_TO_CAPTURE))): - block_count = math.ceil(max_seq_len / self.block_size) - # Skip capture of "out-of-bound" batch sizes and context lengths - if batch_size > self.scheduler_config.max_num_seqs: - logger.debug(f"[{idx}/{total_combinations}] Skipping capture for batch {batch_size}, max_seq_len {max_seq_len}, block_count {block_count}. Reason: Batch out of bound.") - df[max_seq_len][batch_size] = 'batch OoB' - continue - if max_seq_len > self.max_seq_len_to_capture: - logger.debug(f"[{idx}/{total_combinations}] Skipping capture for batch {batch_size}, max_seq_len {max_seq_len}, block_count {block_count}. Reason: Nax context length out of bound.") - df[max_seq_len][batch_size] = 'ctx OoB' - continue - block_count = math.ceil(max_seq_len / self.block_size) - captured_block_counts = [math.ceil(cl / self.block_size) for (n, cl) in valid_combinations if n == batch_size] - if block_count in captured_block_counts: - logger.debug(f"[{idx}/{total_combinations}] Skipping capture for batch {batch_size}, max_seq_len {max_seq_len}, block_count {block_count}. Reason: Block size already captured.") - df[max_seq_len][batch_size] = 'redundant' - continue - logger.debug(f"[{idx}/{total_combinations}] Will capture for batch {batch_size}, max_seq_len {max_seq_len}, block_count {block_count}. Constraints met.") - df[max_seq_len][batch_size] = 'VALID' - valid_combinations.append((batch_size, max_seq_len)) - - total_valid_hpugraphs = len(valid_combinations) - logger.info(f"Starting capture {total_valid_hpugraphs} valid HPUGraphs. 
Skipping capture of {total_combinations-total_valid_hpugraphs}/{total_combinations} graphs due to batch/context constraints.") - logger.debug(f"Capture summary (row: batch_size; col: max_seq_len):") - logger.debug(tabulate.tabulate(df, tablefmt='mixed_outline', headers='keys', showindex="always")) - - graph_runner_name = self.graph_runner_class.__name__ - graph_mem_usage_df = pd.DataFrame(index=list(reversed(sorted({b for b,c in valid_combinations}))), columns=list(reversed(sorted({c for b,c in valid_combinations})))) - pbar = tqdm.tqdm(valid_combinations) - start_mem = HabanaMemoryProfiler.current_memory_usage() - log_graph_compilation_all = os.environ.get('VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0' - log_graph_compilation = os.environ.get('VLLM_HPU_LOG_STEP_GRAPH_COMPILATION', '0') != '0' or log_graph_compilation_all - - for idx, (batch_size, max_seq_len) in enumerate(pbar): - block_count = math.ceil(max_seq_len / self.block_size) - # Create dummy attn_metadata. - decode_metadata = self.attn_backend.make_metadata( - is_prompt=False, - seq_lens=None, - seq_lens_tensor=context_lens[:batch_size], - max_query_len=None, - max_seq_len=block_count*self.block_size, - subquery_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, # NOTE(kzawora): this seems sus, shoudn't we have seq_lens tensor here? - block_tables=block_tables[:batch_size, :block_count], - use_cuda_graph=True, - ) - attn_metadata = AttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - prefill_metadata=None, - decode_metadata=decode_metadata, - kv_cache_dtype=self.kv_cache_dtype, - ) - - if self.lora_config: - lora_mapping = LoRAMapping( - [0] * batch_size, - [0] * batch_size, - ) - self.set_active_loras(set(), lora_mapping) - graph_runner = self.graph_runner_class(self.model) - local_start_mem = HabanaMemoryProfiler.current_memory_usage() - capture_start = time.time() - desc = f'Capturing {graph_runner_name} for batch {batch_size}, max_seq_len {max_seq_len}, block_count {block_count}, allocated {format_bytes(local_start_mem - start_mem)} device memory in total ({format_bytes(HabanaMemoryProfiler.current_memory_usage())}/{format_bytes(HabanaMemoryProfiler.total_memory())} used)' - pbar.set_description(desc) - logger.debug(f"[{idx}/{total_valid_hpugraphs}] {desc}...") - profiling_ctx = contextlib.nullcontext() if not (log_graph_compilation_all or log_graph_compilation) else metric_localcontext("graph_compilation") - with profiling_ctx as gc_local_metric: - graph_runner.capture( - input_tokens[:batch_size], - input_positions[:batch_size], - kv_caches, - attn_metadata, - ) - if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0) or log_graph_compilation_all: - logger.info(f"VLLM_HPU_STEP_GRAPH_COMPILATION: {gc_local_metric.stats()}, {graph_runner_name}; batch {batch_size}, max_seq_len {max_seq_len}, block_count {block_count}") - self.graph_runners[(batch_size, block_count)] = graph_runner - capture_end = time.time() - local_end_mem = HabanaMemoryProfiler.current_memory_usage() - mem_usage_str = format_bytes(local_end_mem - local_start_mem) - graph_mem_usage_df[max_seq_len][batch_size] = mem_usage_str - logger.debug(f"[{idx}/{total_valid_hpugraphs}] {desc}... done in {capture_end-capture_start:.2f} seconds! 
Took {mem_usage_str} of device memory ({format_bytes(HabanaMemoryProfiler.current_memory_usage())}/{format_bytes(HabanaMemoryProfiler.total_memory())} used)") - + for i, (batch_size, seq_len, is_prompt) in enumerate(scenarios): + mem_usage = 100.0 * HabanaMemoryProfiler.current_memory_usage() / HabanaMemoryProfiler.total_memory() + logger.info(f"[Warmup][{i+1}/{len(scenarios)}] batch_size:{batch_size} seq_len:{seq_len} is_prompt:{is_prompt} mem_usage:{mem_usage:0.1f}%") + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) end_time = time.perf_counter() - elapsed_time = end_time - start_time - # This usually takes < 10 seconds. end_mem = HabanaMemoryProfiler.current_memory_usage() - logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs, allocated {format_bytes(end_mem - start_mem)} of device memory ({format_bytes(HabanaMemoryProfiler.current_memory_usage())}/{format_bytes(HabanaMemoryProfiler.total_memory())} used)") - logger.info(f"Graph memory allocation summary (row: batch_size; col: max_seq_len):") - logger.info(tabulate.tabulate(graph_mem_usage_df, tablefmt='mixed_outline', headers='keys', showindex="always")) - - def __del__(self) -> None: - # Delete the CUDA graphs before deleting the CuPy NCCL communicator. - # NOTE(woosuk): This is necessary because otherwise deadlocks can - # happen. - # FIXME(woosuk): This is a bit hacky. Find a more robust solution. - self.graph_runners.clear() + elapsed_time = end_time - start_time + logger.info(f"Warmup finished in {elapsed_time:.0f} secs, allocated {format_bytes(end_mem - start_mem)} of device memory") @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() - - -class FakeHPUGraphRunner: - - def __init__(self, model: nn.Module): - self.model = model - - def capture( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> None: - return - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - return self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - ) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - -class FakeHPUGraphRunnerWithWarmup: - - def __init__(self, model: nn.Module): - self.model = model - - def capture( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> None: - htorch.core.mark_step() - out = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - ) - htorch.core.mark_step() - htorch.hpu.synchronize() - return - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - htorch.core.mark_step() - out = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - ) - htorch.core.mark_step() - return out - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) -class HPUGraphRunner: - - def __init__(self, model: nn.Module): - self.model = model - self.graph = None - self.input_buffers: Dict[str, torch.Tensor] = {} - self.output_buffers: Dict[str, torch.Tensor] = {} - - def capture( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> None: - assert self.graph is None - # Run the model once without capturing the graph. 
- # This is to make sure that the captured graph does not include the - # kernel launches for initial benchmarking (e.g., Triton autotune). - self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - ) - htorch.hpu.synchronize() - - # Capture the graph. - # NOTE(woosuk): Python 3.8 does not support multi-line with statements. - # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement - self.graph = htorch.hpu.HPUGraph() - with htorch.hpu.graph(self.graph): # noqa: SIM117 - hidden_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - ) - torch.hpu.synchronize() - - # Save the input and output buffers. - self.input_buffers = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - } - self.output_buffers = {"hidden_states": hidden_states} - return - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - **kwargs, - ) -> torch.Tensor: - # KV caches are fixed tensors, so we don't need to copy them. - del kv_caches - - # Copy the input tensors to the input buffers. - self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) - self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, - non_blocking=True) - self.input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - self.input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) - # Run the graph. - self.graph.replay() - - # Return the output tensor. 
- return self.output_buffers["hidden_states"] - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - -class ExperimentalHPUGraphRunner: - def __init__(self, model: nn.Module): - self.model = model - - def capture( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> None: - class ModelWrapper(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - self.attn_backend = get_attn_backend(torch.bfloat16) - def forward(self, input_ids, positions, kv_caches, slot_mapping, context_lens, block_tables): - wrapper_attn_metadata = self.attn_backend.make_metadata( - is_prompt=attn_metadata.is_prompt, - seq_lens=None, - seq_lens_tensor=None, - num_prefill_tokens=0, - num_generation_tokens=attn_metadata.num_generation_tokens, - max_subquery_len=None, - max_seq_len=attn_metadata.max_seq_len, - max_prompt_len=None, - subquery_start_loc=None, - seq_start_loc=None, - context_lens=context_lens, - block_tables=block_tables, - use_cuda_graph=True, - kv_cache_dtype=attn_metadata.kv_cache_dtype, - ) - return self.model( - input_ids, - positions, - kv_caches, - wrapper_attn_metadata - ) - self.graph_model = htorch.hpu.wrap_in_hpu_graph(ModelWrapper(self.model)) - out = self.graph_model( - input_ids, - positions, - kv_caches, - attn_metadata.slot_mapping, - attn_metadata.context_lens, - attn_metadata.block_tables, - ) - htorch.hpu.synchronize() - return - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - out = self.graph_model( - input_ids, - positions, - kv_caches, - attn_metadata.slot_mapping, - attn_metadata.context_lens, - attn_metadata.block_tables, - ) - return out - - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - -def _get_graph_batch_size(batch_size: int) -> int: - """Returns the padded batch size given actual batch size. - - Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, - 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... - """ - if batch_size <= 2: - return batch_size - elif batch_size <= 4: - return 4 - elif batch_size <= 8: - return 8 - else: - return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // - _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) - - -def _get_graph_max_seq_len(max_seq_len: int) -> int: - """Returns the padded batch size given actual batch size. - - Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, - 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... 
-    """
-    if max_seq_len <= 32:
-        return 32
-    elif max_seq_len <= 64:
-        return 64
-    elif max_seq_len <= 128:
-        return 128
-    else:
-        return ((max_seq_len + _MAX_SEQ_LEN_ALIGNMENT - 1) //
-                _MAX_SEQ_LEN_ALIGNMENT * _MAX_SEQ_LEN_ALIGNMENT)
diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py
index 43ccd235c174f..eeba9e5c4adba 100644
--- a/vllm/worker/habana_worker.py
+++ b/vllm/worker/habana_worker.py
@@ -73,13 +73,14 @@ def __init__(
             assert False, "To be tested: vision language model on HPU"
 
         self.model_runner = HabanaModelRunner(model_config,
-                                              parallel_config,
-                                              scheduler_config,
-                                              device_config,
-                                              load_config=load_config,
-                                              lora_config=self.lora_config,
-                                              kv_cache_dtype=self.cache_config.cache_dtype,
-                                              is_driver_worker=is_driver_worker)
+                                              parallel_config,
+                                              scheduler_config,
+                                              device_config,
+                                              load_config=load_config,
+                                              cache_config=cache_config,
+                                              lora_config=self.lora_config,
+                                              kv_cache_dtype=self.cache_config.cache_dtype,
+                                              is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: CacheEngine
@@ -168,12 +169,10 @@ def _init_cache_engine(self) -> None:
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.hpu_cache = self.cache_engine.gpu_cache
-        self.model_runner.set_block_size(self.cache_engine.block_size)
         htorch.hpu.synchronize()  # we want to materialize cache tensors before we proceed with graph capture/execution
 
     def _warm_up_model(self) -> None:
-        if not self.model_config.enforce_eager:
-            self.model_runner.capture_model(self.hpu_cache)
+        self.model_runner.warmup_model(self.hpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
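
A small illustration (values invented for this note, not part of the diff) of the prefill attention bias now built in habana_attn.py and consumed by xops.prompt_attention: sequences are right-padded to seq_len, so a per-sequence length mask is OR-ed with the causal mask and converted into an additive -inf bias.

import math
import torch

batch_size, seq_len = 2, 4      # illustrative shapes only
seq_lens = [4, 2]               # second sequence has 2 padding tokens
dtype = torch.bfloat16

lens = torch.tensor(seq_lens, dtype=torch.int32)
# True wherever the column index reaches past a sequence's real length
len_mask = (torch.arange(0, seq_len, dtype=torch.int32)
            .view(1, seq_len)
            .ge(lens.unsqueeze(-1))
            .view(batch_size, 1, 1, seq_len))
# Standard causal mask over the (seq_len, seq_len) score matrix
causal_mask = torch.triu(
    torch.ones((batch_size, 1, seq_len, seq_len), dtype=torch.bool),
    diagonal=1)
mask = causal_mask.logical_or(len_mask)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
# Columns >= seq_lens[1] = 2 are -inf in every row, on top of the causal part
print(attn_bias[1, 0])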
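
For reference, a standalone sketch of the bucketing helpers introduced in habana_model_runner.py (function bodies copied from the diff above). The (1, 32, 64) config corresponds to the default prompt batch-size bucket from _setup_buckets when max_num_seqs >= 64: warmup_model warms up every bucket returned by warmup_buckets, and execute_model pads the incoming batch to find_bucket(real_batch_size, cfg).

import itertools
import operator
from typing import Tuple

def next_pow2(value: int):
    res = 1
    while value > 1:
        value = (value + 1) // 2
        res *= 2
    return res

def round_up(value: int, k: int):
    return (value + k - 1) // k * k

def warmup_buckets(config: Tuple[int, int, int]):
    # Powers-of-two ramp-up below the step, then multiples of the step
    bmin, bstep, bmax = config
    base = itertools.repeat(2)
    ramp_up = itertools.accumulate(base, func=operator.mul, initial=bmin)
    ramp_up = itertools.takewhile(lambda x: x < bstep and x <= bmax, ramp_up)
    stable = range(bstep, bmax + 1, bstep)
    return list(ramp_up) + list(stable)

def find_bucket(value: int, config: Tuple[int, int, int]):
    bmin, bstep, bmax = config
    if value < bstep:
        result = min(next_pow2(value), bstep)
    else:
        result = round_up(value, bstep)
    return result

prompt_bs_cfg = (1, 32, 64)            # min, step, max_warmup
print(warmup_buckets(prompt_bs_cfg))   # [1, 2, 4, 8, 16, 32, 64]
print(find_bucket(5, prompt_bs_cfg))   # 5 requests are padded to a batch of 8
print(find_bucket(33, prompt_bs_cfg))  # 33 requests are padded to a batch of 64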