5 changes: 4 additions & 1 deletion vllm/model_executor/models/utils.py
@@ -606,12 +606,15 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):

     def make_empty_intermediate_tensors(
             batch_size: int,
+            context_size: int,
             dtype: torch.dtype,
             device: torch.device,
     ) -> IntermediateTensors:
         return IntermediateTensors({
             key:
-            torch.zeros((batch_size, hidden_size), dtype=dtype, device=device)
+            torch.zeros((batch_size, context_size, hidden_size),
+                        dtype=dtype,
+                        device=device)
             for key in keys
         })

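Note (illustrative, not part of the diff): a minimal usage sketch of the updated factory, with hypothetical keys, sizes, and device; it only demonstrates the new context_size argument added above.

import torch

from vllm.model_executor.models.utils import (
    make_empty_intermediate_tensors_factory)

# Hypothetical values for illustration only.
make_empty = make_empty_intermediate_tensors_factory(
    keys=["hidden_states", "residual"], hidden_size=4096)

placeholders = make_empty(
    batch_size=2,
    context_size=128,  # new argument: explicit sequence-length dimension
    dtype=torch.bfloat16,
    device=torch.device("cpu"))  # "hpu" on Gaudi; CPU used here for illustration

# Each placeholder is now [batch, context, hidden] instead of [batch, hidden]:
# placeholders["hidden_states"].shape == (2, 128, 4096)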
3 changes: 3 additions & 0 deletions vllm/sequence.py
@@ -1146,6 +1146,9 @@ def __eq__(self, other: object):
     def __repr__(self) -> str:
         return f"IntermediateTensors(tensors={self.tensors})"
 
+    def __iter__(self):
+        return iter(self.tensors)
+
 
 class PoolerOutput(
         msgspec.Struct,
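Note (illustrative, not from the PR): with __iter__ delegating to the underlying dict, an IntermediateTensors object can be iterated like a mapping, yielding its keys; the tensor shapes below are arbitrary.

import torch

from vllm.sequence import IntermediateTensors

tensors = IntermediateTensors({
    "hidden_states": torch.zeros(2, 8, 16),
    "residual": torch.zeros(2, 8, 16),
})

# Iteration now yields the keys, mirroring dict semantics, so generic
# tensor-dict helpers can walk the container without reaching into .tensors.
for key in tensors:
    print(key, tensors[key].shape)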
27 changes: 24 additions & 3 deletions vllm/worker/hpu_model_runner.py
@@ -32,7 +32,7 @@
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.backends.hpu_attn import HPUAttentionImpl
 from vllm.config import DeviceConfig, VllmConfig
-from vllm.distributed import broadcast_tensor_dict
+from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.distributed.parallel_state import get_world_group
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -421,6 +421,8 @@ def forward(self, *args, **kwargs):
         with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
                                  virtual_engine):
             hidden_states = self.model(*args, **kwargs)
+            if not get_pp_group().is_last_rank:
+                return hidden_states
             hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
             if selected_token_indices is not None:
                 hidden_states = hidden_states.index_select(
@@ -433,6 +435,9 @@ def compute_logits(self, *args, **kwargs):
     def sample(self, *args, **kwargs):
         return self.model.sample(*args, **kwargs)
 
+    def make_empty_intermediate_tensors(self, *args, **kwargs):
+        return self.model.make_empty_intermediate_tensors(*args, **kwargs)
+
     def generate_proposals(self, *args, **kwargs):
         return self.model.generate_proposals(*args, **kwargs)
 
@@ -1949,7 +1954,7 @@ def profile_run(self) -> None:
         kv_caches = [None] * num_layers
         bind_kv_cache(
             self.vllm_config.compilation_config.static_forward_context,
-            [kv_caches])
+            [kv_caches] * self.parallel_config.pipeline_parallel_size)
         _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
         max_batch_size = min(self.max_num_seqs,
                              self.max_num_batched_tokens // max_seq_len)
@@ -2030,7 +2035,18 @@ def warmup_scenario(self,
         is_single_step = \
             self.vllm_config.scheduler_config.num_scheduler_steps == 1
         if is_prompt or is_single_step:
-            self.execute_model(inputs, kv_caches, warmup_mode=True)
+            intermediate_tensors = None
+            if not get_pp_group().is_first_rank:
+                intermediate_tensors = \
+                    self.model.make_empty_intermediate_tensors(
+                        batch_size=batch_size,
+                        context_size=seq_len if is_prompt else 1,
+                        dtype=self.model_config.dtype,
+                        device=self.device)
+            self.execute_model(inputs,
+                               kv_caches,
+                               intermediate_tensors=intermediate_tensors,
+                               warmup_mode=True)
         else:  # decode with multi-step
             inputs = dataclasses.replace(inputs,
                                          is_first_multi_step=True,
@@ -2528,6 +2544,9 @@ def execute_model(
         use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
         assert not (use_delayed_sampling and num_steps != 1), \
             'Delayed sampling is not compatible with MSS!'
+        assert not (use_delayed_sampling and
+                    self.parallel_config.pipeline_parallel_size != 1), \
+            'Delayed sampling is not compatible with Pipeline Parallelism!'
         assert model_input.input_tokens is not None
         if use_delayed_sampling and not model_input.is_prompt and \
                 self.is_driver_worker:
@@ -2684,6 +2703,8 @@ def try_revert_dummy_output_tokens():
                     LoraMask.setLoraMask(
                         lora_logits_mask.index_select(
                             0, sampling_metadata.selected_token_indices))
+                if not get_pp_group().is_last_rank:
+                    return hidden_states
 
                 # Compute the logits.
                 with self.profiler.record_event(
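Taken together, the runner changes wire the HPU model adapter into vLLM's pipeline-parallel flow. Below is a simplified sketch of that flow; run_stage and the adapter call signature are illustrative stand-ins rather than actual vLLM APIs, and only the get_pp_group() rank checks, the make_empty_intermediate_tensors call, and the final view() mirror the diff above.

import torch

from vllm.distributed import get_pp_group


def run_stage(model, inputs, batch_size, seq_len, is_prompt, dtype, device):
    """Illustrative pipeline-parallel stage, not vLLM code."""
    pp_group = get_pp_group()

    intermediate_tensors = None
    if not pp_group.is_first_rank:
        # Non-first stages consume hidden states instead of token embeddings;
        # warmup pre-allocates placeholders of shape [batch, context, hidden],
        # using seq_len for prompts and 1 for single-token decode steps.
        intermediate_tensors = model.make_empty_intermediate_tensors(
            batch_size=batch_size,
            context_size=seq_len if is_prompt else 1,
            dtype=dtype,
            device=device)

    hidden_states = model(inputs, intermediate_tensors=intermediate_tensors)

    if not pp_group.is_last_rank:
        # Intermediate stages return the raw IntermediateTensors so they can
        # be handed to the next stage; only the last rank computes logits.
        return hidden_states

    # Last rank: flatten and continue to logits computation and sampling.
    return hidden_states.view(-1, hidden_states.shape[-1])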
11 changes: 8 additions & 3 deletions vllm/worker/hpu_worker.py
@@ -20,7 +20,7 @@

 import vllm.envs as envs
 from vllm.config import ParallelConfig, VllmConfig
-from vllm.distributed import (ensure_model_parallel_initialized,
+from vllm.distributed import (ensure_model_parallel_initialized, get_pp_group,
                               init_distributed_environment)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -63,8 +63,9 @@ def __init__(
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        if self.parallel_config and self.is_driver_worker:
+            assert self.rank % self.parallel_config.tensor_parallel_size == 0, \
+                "The driver worker must have rank 0."
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -526,6 +527,10 @@ def init_worker_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
+    if parallel_config.pipeline_parallel_size > 1:
+        # torch-ccl xpu need a collective API warm up
+        # before calling send/recv API
+        get_pp_group().all_reduce(torch.zeros(1).to('hpu'))
     if torch.distributed.is_initialized():
         torch_world_size = torch.distributed.get_world_size()
         if torch_world_size != parallel_config.world_size:
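A standalone sketch of the two worker-side ideas, using plain torch.distributed with a hypothetical process group standing in for vLLM's HPU pipeline group: with pipeline parallelism each tensor-parallel group keeps its own driver (the rank where rank % TP == 0), and a throwaway all-reduce warms up the collective backend before the first send/recv.

import torch
import torch.distributed as dist


def is_driver_rank(rank: int, tensor_parallel_size: int) -> bool:
    # Relaxed check from the diff above: the driver of each tensor-parallel
    # group is its first rank, so global rank 0 is no longer the only driver.
    return rank % tensor_parallel_size == 0


def warm_up_pp_group(pp_group: dist.ProcessGroup, device: torch.device) -> None:
    # A throwaway collective initializes the communicator before the first
    # point-to-point send/recv between pipeline stages (mirroring the
    # get_pp_group().all_reduce call on HPU above).
    dist.all_reduce(torch.zeros(1, device=device), group=pp_group)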