10 changes: 7 additions & 3 deletions examples/offline_inference/basic/basic.py
Member: Please revert this - we should make an examples/offline_inference/tpu/ folder to keep this

@@ -10,15 +10,19 @@
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
max_num_seqs=16,
max_model_len=128,
enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
28 changes: 21 additions & 7 deletions vllm/forward_context.py
@@ -12,6 +12,7 @@
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -33,14 +34,16 @@

@dataclass
class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context
# Copy from vllm_config.compilation_config.static_forward_context
no_compile_layers: dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
# TODO: Extend to support per-layer dynamic forward context
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
# TODO: Remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
# Set dynamically for each forward pass
dp_metadata: Optional[DPMetadata] = None
# Whether this is a profile run (before KV cache init)
is_profile_run: bool = False,

Check failure on line 46 in vllm/forward_context.py (GitHub Actions / pre-commit): Incompatible types in assignment (expression has type "tuple[bool]", variable has type "bool") [assignment]


_forward_context: Optional[ForwardContext] = None
@@ -58,7 +61,8 @@
def set_forward_context(attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: int = 0):
num_tokens: int = 0,
is_profile_run: bool = False):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
@@ -93,12 +97,15 @@

global _forward_context
prev_context = _forward_context

_forward_context = ForwardContext(
no_compile_layers=vllm_config.compilation_config.
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata)
dp_metadata=dp_metadata,
is_profile_run=is_profile_run)

try:
yield
finally:
@@ -111,10 +118,17 @@
else:
# for v1 attention backends
batchsize = attn_metadata.num_input_tokens

# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
torch.cuda.synchronize()
if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
xm.mark_step()
xm.wait_device_ops()
else:
torch.cuda.synchronize()

now = time.perf_counter()
# time measurement is in milliseconds
batchsize_forward_time[batchsize].append(
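A minimal sketch of how a runner is expected to use the new is_profile_run flag, based on the profile_run/dummy_run split in tpu_model_runner.py below (variable names are illustrative):

# Memory-profiling pass, run before the KV cache exists: layers can check
# forward_context.is_profile_run to skip KV-cache writes.
with set_forward_context(attn_metadata, vllm_config, 0, is_profile_run=True):
    model(input_ids=input_ids, positions=position_ids)

# Regular pass after initialize_kv_cache: the default is_profile_run=False
# keeps the normal KV-cache update path.
with set_forward_context(attn_metadata, vllm_config, 0):
    model(input_ids=input_ids, positions=position_ids)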
76 changes: 52 additions & 24 deletions vllm/v1/worker/tpu_model_runner.py
@@ -30,7 +30,6 @@
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

if TYPE_CHECKING:
@@ -104,9 +103,6 @@ def __init__(
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size

# Lazy initialization
# self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

@@ -582,7 +578,6 @@ def execute_model(
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
hidden_states = hidden_states[:total_num_scheduled_tokens]
@@ -680,8 +675,8 @@ def load_model(self) -> None:

def _dummy_run(
self,
kv_caches,
num_tokens: int,
is_profile_run: bool,
) -> None:
if self.is_multimodal_model:
input_ids = None
@@ -728,15 +723,28 @@ def _dummy_run(
torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)

with set_forward_context(attn_metadata, self.vllm_config, 0):
with set_forward_context(attn_metadata,
self.vllm_config,
0,
is_profile_run=is_profile_run):
assert self.model is not None
self.model(
input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
)

# This is used before KV cache init
def profile_run(self, num_tokens) -> None:
self._dummy_run(num_tokens=num_tokens, is_profile_run=True)

# This is used after KV cache init
def dummy_run(
Collaborator: _dummy_run and dummy_run look confusing, do we really need this overload?

self,
num_tokens: int,
) -> None:
self._dummy_run(num_tokens=num_tokens, is_profile_run=False)
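On the overload question above: one alternative would be a single public entry point with a defaulted flag instead of the two thin wrappers; a sketch, not part of this PR:

def dummy_run(
    self,
    num_tokens: int,
    is_profile_run: bool = False,
) -> None:
    # is_profile_run=True only for the memory-profiling pass that runs
    # before initialize_kv_cache; otherwise the default applies.
    ...  # body identical to _dummy_run above

# Call sites:
#   self.dummy_run(max_num_batched_tokens, is_profile_run=True)  # profiling
#   self.dummy_run(num_tokens)                                   # capture_model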

def capture_model(self) -> None:
"""Compile the model."""

@@ -745,7 +753,7 @@ def capture_model(self) -> None:
start = time.perf_counter()
num_tokens = 16
while True:
self._dummy_run(self.kv_caches, num_tokens)
self.dummy_run(num_tokens)
logger.info(" -- num_tokens: %d", num_tokens)
xm.mark_step()
xm.wait_device_ops()
@@ -769,6 +777,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:

kv_caches: dict[str, torch.Tensor] = {}

kv_cache_shape_prev = None
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
@@ -779,6 +788,12 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
layer_spec.head_size)
dtype = layer_spec.dtype

# Ensure all "kv_cache_shape" are the same across the model
if kv_cache_shape_prev is None:
kv_cache_shape_prev = kv_cache_shape
else:
assert kv_cache_shape == kv_cache_shape_prev
Collaborator: qq: is this for ruling out some model architecture?


tpu_k_cache = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
@@ -788,23 +803,32 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
else:
raise NotImplementedError

bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
# ModelWrapperV1 needs to know the KV cache shape
self.model.set_kv_cache_shape(kv_cache_shape_prev)

# Associates each attention layer in the `forward_context` with the
# initialized KV cache.
forward_context = self.vllm_config.compilation_config \
.static_forward_context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
Comment on lines +811 to +815 (Collaborator): nit: do you see any use in having this bit factored as a util, similarly to bind_kv_cache? We could re-use at least in tpu_worker
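A minimal sketch of what such a shared helper could look like, reusing the loop above verbatim (the helper name and its location are assumptions, not part of this PR):

from typing import Any
import torch

def bind_kv_caches_to_forward_context(
        kv_caches: dict[str, torch.Tensor],
        static_forward_context: dict[str, Any]) -> None:
    """Attach each initialized KV cache to its attention layer entry."""
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        static_forward_context[layer_name].kv_cache = [kv_cache]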



class ModelWrapperV1(nn.Module):
Collaborator: Is it possible to implement ModelWrapperV1 like this?

class ModelWrapperV1(nn.Module):

     def __init__(self, model: nn.Module, num_kv_heads, num_blocks, block_size):
         super().__init__()
         self.model = model
         self.num_kv_heads = num_kv_heads
         ...

     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
         is_profile_run: bool = False,
     ) -> torch.Tensor:
         if not is_profile_run:
              num_kv_heads = self.num_kv_heads
              ...

class TPUModelRunner:
    def _dummy_run(
           self,
           num_tokens: int,
           is_profile_run: bool,
     ) -> None:
         self.model.forward(..., is_profile_run=is_profile_run)

Collaborator Author: @heheda12345 this is not possible because num_blocks is not known until determine_num_available_blocks is done and initialize_kv_cache is executed.

Collaborator: Can we pass a fake value first and update it after determine_num_available_blocks?


def __init__(self, model: nn.Module):
super().__init__()
self.model = model
self.kv_cache_shape = None

def set_kv_cache_shape(self, kv_cache_shape):
Collaborator: nit: we can probably get away without setters as long as we keep the class and the logic lean

self.kv_cache_shape = kv_cache_shape
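Per the nit above, the setter could also be dropped in favor of a direct attribute assignment from the runner; a one-line sketch (not part of this PR):

# in TPUModelRunner.initialize_kv_cache, instead of the setter call:
self.model.kv_cache_shape = kv_cache_shape_prev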

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
@@ -817,16 +841,20 @@ def forward(
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
# Skip this in memory profiling at initialization.
if kv_caches[0][0].numel() > 0:
attn_metadata = get_forward_context().attn_metadata
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
# kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata

# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
#
# Note: We skip this step during first profiling run (before KV init)
if not forward_context.is_profile_run:
assert self.kv_cache_shape # Ensure initialized
num_kv_heads, num_blocks, block_size, _ = self.kv_cache_shape

slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
head_indicies = torch.arange(0,
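To illustrate the flattening described in the comment above, a standalone sketch of how slot_mapping can be remapped so index_copy_ works on dimension 0 (shapes and names here are illustrative, not taken verbatim from the Pallas backend):

import torch

num_kv_heads, num_blocks, block_size, head_size = 8, 4, 16, 128
num_tokens = 3

# KV cache as allocated on TPU: [num_kv_heads, num_blocks, block_size, head_size]
tpu_k_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)

# Per-token slots in a single head's flat [num_blocks * block_size] space.
slot_mapping = torch.tensor([5, 17, 42])
key = torch.randn(num_tokens, num_kv_heads, head_size)

# Flatten the first three dims to [num_kv_heads * num_blocks * block_size, ...]
# and offset each head so every (head, slot) pair maps to a unique row.
head_indices = torch.arange(num_kv_heads) * num_blocks * block_size
flat_indices = (slot_mapping.unsqueeze(0) + head_indices.unsqueeze(1)).flatten()
flat_cache = tpu_k_cache.view(-1, head_size)
flat_cache.index_copy_(0, flat_indices,
                       key.transpose(0, 1).reshape(-1, head_size))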
20 changes: 10 additions & 10 deletions vllm/v1/worker/tpu_worker.py
@@ -21,7 +21,6 @@
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.tpu_model_runner import TPUModelRunner

logger = init_logger(__name__)
@@ -128,18 +127,19 @@ def determine_available_memory(self) -> int:
else:
raise NotImplementedError

runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)
# Associates each attention layer in the `forward_context` with the
# initialized KV cache.
forward_context = self.vllm_config.compilation_config \
.static_forward_context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]

self.model_runner._dummy_run(
runner_kv_caches,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
self.model_runner.profile_run(
num_tokens=self.scheduler_config.max_num_batched_tokens)

# Synchronize before measuring the memory usage.
xm.mark_step()
xm.wait_device_ops()

# Get the maximum amount of memory used by the model weights and