support pp virtual engine
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
heheda12345 committed Jan 8, 2025
1 parent ffe8cdd commit 2cb84f2
Showing 13 changed files with 56 additions and 39 deletions.
12 changes: 9 additions & 3 deletions vllm/attention/layer.py
@@ -123,7 +123,11 @@ def __init__(
self.attn_type = attn_type
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
self.kv_cache = torch.tensor([])
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([]) for _ in range(get_current_vllm_config(
).parallel_config.pipeline_parallel_size)
]

def forward(
self,
@@ -238,7 +242,8 @@ def unified_attention(
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
return self.impl.forward(query, key, value, self.kv_cache, attn_metadata,
kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._k_scale, self._v_scale)


@@ -270,10 +275,11 @@ def unified_attention_with_output(
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(query,
key,
value,
self.kv_cache,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
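The change above turns Attention.kv_cache from a single placeholder tensor into a list with one entry per virtual engine (sized by pipeline_parallel_size), and the unified attention entry points now index it with forward_context.virtual_engine. The following is a minimal, self-contained sketch of that indexing pattern; _ForwardContext and _AttentionStub are simplified stand-ins for illustration, not the real vllm classes.

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class _ForwardContext:  # simplified stand-in for vllm's ForwardContext
    virtual_engine: int


class _AttentionStub:  # simplified stand-in for vllm.attention.layer.Attention
    def __init__(self, pipeline_parallel_size: int):
        # One placeholder per virtual engine; bind_kv_cache later swaps each
        # placeholder for the real cache tensor of that engine.
        self.kv_cache: List[torch.Tensor] = [
            torch.tensor([]) for _ in range(pipeline_parallel_size)
        ]

    def forward(self, query: torch.Tensor, ctx: _ForwardContext) -> torch.Tensor:
        # Pick the cache belonging to the virtual engine running this pass;
        # the real code then hands it to self.impl.forward(...).
        kv_cache = self.kv_cache[ctx.virtual_engine]
        return query if kv_cache.numel() == 0 else query + kv_cache.mean()


layer = _AttentionStub(pipeline_parallel_size=2)
out = layer.forward(torch.ones(4), _ForwardContext(virtual_engine=1))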
7 changes: 6 additions & 1 deletion vllm/forward_context.py
@@ -28,6 +28,8 @@ class ForwardContext:
attn_layers: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass


_forward_context: Optional[ForwardContext] = None
@@ -42,7 +44,9 @@ def get_forward_context() -> ForwardContext:


@contextmanager
def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig):
def set_forward_context(attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
@@ -55,6 +59,7 @@ def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig):
prev_context = _forward_context
_forward_context = ForwardContext(
attn_layers=vllm_config.compilation_config.static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata)
try:
yield
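ForwardContext now also records which virtual engine the current pass belongs to, and set_forward_context gains an optional virtual_engine argument defaulting to 0, so existing single-engine call sites keep working unchanged. A rough sketch of the pattern under stand-in names (not the real vllm module):

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _ForwardContext:
    attn_metadata: Any
    virtual_engine: int  # set dynamically for each forward pass


_forward_context: Optional[_ForwardContext] = None


@contextmanager
def _set_forward_context(attn_metadata: Any, virtual_engine: int = 0):
    # Stash the per-pass context in a module-level slot and restore the
    # previous one on exit, mirroring the structure shown in the diff.
    global _forward_context
    prev = _forward_context
    _forward_context = _ForwardContext(attn_metadata, virtual_engine)
    try:
        yield
    finally:
        _forward_context = prev


# Model runners pass model_input.virtual_engine explicitly; callers that omit
# it fall back to virtual engine 0.
with _set_forward_context(attn_metadata=None, virtual_engine=1):
    assert _forward_context is not None and _forward_context.virtual_engine == 1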
15 changes: 11 additions & 4 deletions vllm/utils.py
@@ -1949,16 +1949,20 @@ def get_mp_context():
return multiprocessing.get_context(mp_method)


def bind_kv_cache(ctx: Dict[str, Any], kv_cache: List[torch.Tensor]) -> None:
def bind_kv_cache(
ctx: Dict[str, Any],
kv_cache: List[List[torch.Tensor]], # [virtual_engine][layer_index]
) -> None:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache = kv_cache[extract_layer_index(layer_name)]
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
# Special things handled here:
# 1. Some models have non-attention layers, e.g., Jamba
# 2. Pipeline parallelism, each rank only has a subset of layers
# 3. Encoder attention has no kv cache
# 4. Encoder-decoder models, encoder-decoder attention and decoder-only
# attention of the same layer (e.g., bart's decoder.layers.1.self_attn
# and decoder.layers.1.encoder_attn is mapped to the same kv cache tensor
# and decoder.layers.1.encoder_attn) is mapped to the same kv cache
# tensor
from vllm.attention import AttentionType
from vllm.model_executor.models.utils import extract_layer_index
layer_need_kv_cache = [
@@ -1974,4 +1978,7 @@ def bind_kv_cache(ctx: Dict[str, Any], kv_cache: List[torch.Tensor]) -> None:
kv_cache_idx = layer_index_sorted.index(
extract_layer_index(layer_name))
forward_ctx = ctx[layer_name]
forward_ctx.kv_cache = kv_cache[kv_cache_idx]
assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache):
assert forward_ctx.kv_cache[ve].numel() == 0
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
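With the new layout, kv_cache is indexed as [virtual_engine][layer_index], and every virtual engine's tensor is written into the matching slot of each layer's kv_cache list. Below is a simplified sketch of that binding, assuming every entry in ctx is an attention layer named "...layers.<i>..."; _LayerStub, _extract_layer_index, and _bind_kv_cache are illustrative stand-ins, and the special cases the real helper handles (non-attention layers, encoder attention, encoder-decoder cache sharing) are omitted.

import re
from typing import Dict, List

import torch


class _LayerStub:
    def __init__(self, num_virtual_engines: int):
        self.kv_cache = [torch.tensor([]) for _ in range(num_virtual_engines)]


def _extract_layer_index(layer_name: str) -> int:
    # Stand-in for vllm.model_executor.models.utils.extract_layer_index.
    return int(re.search(r"layers\.(\d+)", layer_name).group(1))


def _bind_kv_cache(ctx: Dict[str, _LayerStub],
                   kv_cache: List[List[torch.Tensor]]) -> None:
    # kv_cache is indexed [virtual_engine][layer_index]; with pipeline
    # parallelism each rank only owns a slice of layers, so global layer
    # indices are remapped to this rank's sorted local order.
    layer_indices = sorted(_extract_layer_index(name) for name in ctx)
    for layer_name, layer in ctx.items():
        idx = layer_indices.index(_extract_layer_index(layer_name))
        for ve, ve_kv_cache in enumerate(kv_cache):
            layer.kv_cache[ve] = ve_kv_cache[idx]


ctx = {"model.layers.2.attn": _LayerStub(2), "model.layers.3.attn": _LayerStub(2)}
caches = [[torch.zeros(1), torch.zeros(1)] for _ in range(2)]  # 2 VEs x 2 layers
_bind_kv_cache(ctx, caches)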
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -863,4 +863,4 @@ def initialize_kv_cache(self, num_blocks: int) -> None:
device=self.device))
bind_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
[self.kv_caches])
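The V1 GPU model runner keeps a flat per-layer list (self.kv_caches[layer_index]) and currently drives a single virtual engine, so it wraps the list as [self.kv_caches] to match the [virtual_engine][layer_index] layout bind_kv_cache now expects. A tiny illustration (shapes are made up):

import torch

kv_caches = [torch.zeros(2, 16) for _ in range(4)]  # kv_caches[layer_index]
nested = [kv_caches]                                 # nested[virtual_engine][layer_index]
assert nested[0][3] is kv_caches[3]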
18 changes: 5 additions & 13 deletions vllm/worker/cache_engine.py
@@ -4,11 +4,10 @@
import torch

from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
ModelConfig, ParallelConfig)
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
bind_kv_cache, get_dtype_size, is_pin_memory_available)
get_dtype_size, is_pin_memory_available)

logger = init_logger(__name__)

@@ -21,14 +20,9 @@ class CacheEngine:
as swapping and copying.
"""

def __init__(
self,
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig,
compilation_config: CompilationConfig,
) -> None:
def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig) -> None:
self.cache_config = cache_config
self.model_config = model_config
self.parallel_config = parallel_config
@@ -64,8 +58,6 @@ def __init__(
self.gpu_cache = self._allocate_kv_cache(
self.num_gpu_blocks, self.device_config.device_type)
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
bind_kv_cache(compilation_config.static_forward_context,
self.gpu_cache)

def _allocate_kv_cache(
self,
3 changes: 2 additions & 1 deletion vllm/worker/cpu_enc_dec_model_runner.py
@@ -305,7 +305,8 @@ def execute_model(
intermediate_tensors,
}

with set_forward_context(model_input.attn_metadata, self.vllm_config):
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(**execute_model_kwargs)

# Compute the logits.
3 changes: 2 additions & 1 deletion vllm/worker/cpu_model_runner.py
@@ -526,7 +526,8 @@ def execute_model(
execute_model_kwargs.update(
{"previous_hidden_states": previous_hidden_states})

with set_forward_context(model_input.attn_metadata, self.vllm_config):
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
3 changes: 2 additions & 1 deletion vllm/worker/cpu_pooling_model_runner.py
@@ -69,7 +69,8 @@ def execute_model(
intermediate_tensors,
}

with set_forward_context(model_input.attn_metadata, self.vllm_config):
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(**execute_model_kwargs)

# Only perform pooling in the driver worker.
14 changes: 7 additions & 7 deletions vllm/worker/cpu_worker.py
@@ -6,8 +6,8 @@

import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
ModelConfig, ParallelConfig, VllmConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@@ -33,8 +33,8 @@ class CPUCacheEngine:
"""

def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
parallel_config: ParallelConfig, device_config: DeviceConfig,
compilation_config: CompilationConfig) -> None:
parallel_config: ParallelConfig,
device_config: DeviceConfig) -> None:
assert device_config.device_type == "cpu"
self.cache_config = cache_config
self.model_config = model_config
@@ -66,8 +66,6 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,

# Initialize the cache.
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
bind_kv_cache(compilation_config.static_forward_context,
self.cpu_cache)

def _allocate_kv_cache(
self,
@@ -292,13 +290,15 @@ def _init_cache_engine(self) -> None:
self.model_config,
self.parallel_config,
self.device_config,
self.compilation_config,
) for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.cpu_cache = [
self.cache_engine[ve].cpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
for ve in range(self.parallel_config.pipeline_parallel_size):
bind_kv_cache(self.compilation_config.static_forward_context,
self.cpu_cache[ve], ve)
self.model_runner.block_size = self.cache_engine[0].block_size

assert all(
3 changes: 2 additions & 1 deletion vllm/worker/enc_dec_model_runner.py
@@ -175,7 +175,8 @@ def execute_model(
} if self.has_inner_state else {}

multi_modal_kwargs = model_input.multi_modal_kwargs or {}
with set_forward_context(model_input.attn_metadata, self.vllm_config):
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
5 changes: 3 additions & 2 deletions vllm/worker/model_runner.py
@@ -1527,7 +1527,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs)

with set_forward_context(attn_metadata, self.vllm_config):
with set_forward_context(attn_metadata, self.vllm_config,
virtual_engine):
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = (
@@ -1695,7 +1696,7 @@ def execute_model(

if not bypass_model_exec:
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
self.vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
3 changes: 2 additions & 1 deletion vllm/worker/pooling_model_runner.py
@@ -105,7 +105,8 @@ def execute_model(
if model_input.token_types is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_types

with set_forward_context(model_input.attn_metadata, self.vllm_config):
with set_forward_context(model_input.attn_metadata, self.vllm_config,
virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
7 changes: 4 additions & 3 deletions vllm/worker/worker.py
@@ -21,7 +21,7 @@
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import GiB_bytes, memory_profiling
from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -278,14 +278,15 @@ def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
CacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config,
self.compilation_config)
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.gpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.gpu_cache)

def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
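With binding removed from CacheEngine's constructor, the GPU worker now builds one cache engine per virtual engine, gathers self.gpu_cache as [virtual_engine][layer_index], and binds everything with a single bind_kv_cache call. A self-contained sketch of that flow using stand-in classes (not the real vllm worker):

from typing import Dict, List

import torch


class _CacheEngineStub:
    def __init__(self, num_layers: int):
        # Per-layer caches for one virtual engine (shape is illustrative only).
        self.gpu_cache: List[torch.Tensor] = [
            torch.zeros(2, 8) for _ in range(num_layers)
        ]


def _bind_kv_cache(ctx: Dict[str, object],
                   kv_cache: List[List[torch.Tensor]]) -> None:
    # Stand-in for vllm.utils.bind_kv_cache (see the sketch after the
    # vllm/utils.py hunk above); here it only checks the expected nesting.
    assert all(isinstance(per_layer, list) for per_layer in kv_cache)


pipeline_parallel_size, num_layers = 2, 4
cache_engine = [_CacheEngineStub(num_layers) for _ in range(pipeline_parallel_size)]
gpu_cache = [cache_engine[ve].gpu_cache for ve in range(pipeline_parallel_size)]
# gpu_cache[ve][layer_index] now holds every virtual engine's caches, bound once:
_bind_kv_cache({}, gpu_cache)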
