Commit

rebase on main
gnovack committed Jan 27, 2025
1 parent 9552c94 commit 23afd2e
Showing 10 changed files with 91 additions and 79 deletions.
2 changes: 0 additions & 2 deletions vllm/attention/ops/nki_flash_attn.py
@@ -356,7 +356,6 @@ def flash_paged_attention(
B_F_SIZE = 512
b, h, d, seqlen_q = query.shape
B_P_SIZE = 128
# B_P_SIZE = min(seqlen_q, 128)
B_D_SIZE = d
LARGE_TILE_SZ = config.seq_tile_size

@@ -379,7 +378,6 @@ def flash_paged_attention(
assert tuple(value_cache.shape) == (
# TODO(gnovack) - hacky padding block
num_blocks + 1,
# num_blocks,
block_size,
k_h,
d,
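Note on the cache layout the asserts above expect: the paged KV cache carries one extra block beyond num_blocks so padded or invalid slots have a sacrificial block to point at. A minimal standalone sketch (shapes are illustrative, not taken from a real config):

```python
import torch

num_blocks, block_size, num_kv_heads, head_dim = 512, 32, 8, 128

# one extra "padding" block, matching the `num_blocks + 1` in the assert above
key_cache = torch.zeros(num_blocks + 1, block_size, num_kv_heads, head_dim)
value_cache = torch.zeros_like(key_cache)

PAD_BLOCK = num_blocks  # any out-of-range slot can be pointed here safely
print(value_cache.shape)  # torch.Size([513, 32, 8, 128])
```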
7 changes: 2 additions & 5 deletions vllm/config.py
@@ -1308,10 +1308,7 @@ def __post_init__(self) -> None:
from vllm.executor import ray_utils
backend = "mp"
ray_found = ray_utils.ray_is_available()
if current_platform.is_neuron():
# neuron uses single process to control multiple devices
backend = "uni"
elif (current_platform.is_cuda()
if (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size):
if not ray_found:
raise ValueError("Unable to load Ray which is "
@@ -3277,7 +3274,7 @@ def _set_cudagraph_sizes(self):
batch_size_capture_list = []
if current_platform.is_neuron():
# TODO(gnovack) - choose a proper list of batch sizes
batch_size_capture_list = [128]
batch_size_capture_list = [128, self.scheduler_config.max_num_batched_tokens]
elif self.model_config is not None and \
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
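For the Neuron capture-size change above, a hedged sketch of the resulting list — the fixed 128 plus the scheduler's max_num_batched_tokens (the value below is illustrative, not a real config):

```python
def neuron_capture_sizes(max_num_batched_tokens: int) -> list[int]:
    # mirrors the TODO'd placeholder list in the hunk above
    return [128, max_num_batched_tokens]

print(neuron_capture_sizes(8192))  # [128, 8192]
```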
1 change: 0 additions & 1 deletion vllm/model_executor/layers/activation.py
@@ -91,7 +91,6 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
# TODO(gnovack) - clean this up
d = x.shape[-1] // 2
# s = F.silu(x[:, :, :d])
if len(x.shape) == 3:
s = x[:, :, :d] * torch.nn.functional.sigmoid(x[:, :, :d])
return s * x[:, :, d:]
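The forward_neuron path above spells out SiLU-and-mul by hand (x * sigmoid(x) on the first half, multiplied by the second half). A standalone check that this matches F.silu, with made-up shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 16)   # (batch, seq, 2 * d)
d = x.shape[-1] // 2

s = x[:, :, :d] * torch.sigmoid(x[:, :, :d])   # manual silu, as in the diff
out = s * x[:, :, d:]

assert torch.allclose(out, F.silu(x[:, :, :d]) * x[:, :, d:], atol=1e-6)
```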
18 changes: 7 additions & 11 deletions vllm/model_executor/models/llama.py
@@ -44,6 +44,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
@@ -195,19 +196,14 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:

qkv, _ = self.qkv_proj(hidden_states)

# TODO(gnovack) - Figure out a better way to streamline QKV splitting
if current_platform.is_neuron():
# TODO(gnovack) - Figure out a better way to streamline QKV computation
w_q = self.qkv_proj.weight.t()[:, :self.q_size]
q = torch.einsum('bsh,hq->bsq', hidden_states, w_q)

w_k = self.qkv_proj.weight.t()[:, self.q_size:self.q_size+self.kv_size]
k = torch.einsum('bsh,hk->bsk', hidden_states, w_k)

w_v = self.qkv_proj.weight.t()[:, self.q_size+self.kv_size:self.q_size + (2*self.kv_size)]
v = torch.einsum('bsh,hk->bsk', hidden_states, w_v)
q = qkv[:, :, :self.q_size]
k = qkv[:, :, self.q_size:self.q_size+self.kv_size]
v = qkv[:, :, self.q_size+self.kv_size:self.q_size + (2*self.kv_size)]
else:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

q, k = self.rotary_emb(positions, q, k)
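Both QKV variants in the hunk above compute the same thing (ignoring bias): projecting with a sliced weight via einsum is equivalent to slicing the fused qkv_proj output. A quick standalone check with toy sizes:

```python
import torch

h, q_size, kv_size = 64, 32, 16
hidden_states = torch.randn(1, 8, h)
W = torch.randn(h, q_size + 2 * kv_size)   # stands in for qkv_proj.weight.t()

q_einsum = torch.einsum('bsh,hq->bsq', hidden_states, W[:, :q_size])
k_einsum = torch.einsum('bsh,hk->bsk', hidden_states, W[:, q_size:q_size + kv_size])

qkv = hidden_states @ W                    # fused projection, no bias term
q_slice = qkv[:, :, :q_size]
k_slice = qkv[:, :, q_size:q_size + kv_size]

assert torch.allclose(q_einsum, q_slice, atol=1e-5)
assert torch.allclose(k_einsum, k_slice, atol=1e-5)
```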
9 changes: 1 addition & 8 deletions vllm/platforms/__init__.py
@@ -79,13 +79,6 @@ def hpu_platform_plugin() -> Optional[str]:

return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None

is_neuron = False
try:
import neuronx_distributed # noqa: F401
is_neuron = True
except ImportError:
pass

def xpu_platform_plugin() -> Optional[str]:
is_xpu = False

@@ -120,7 +113,7 @@ def cpu_platform_plugin() -> Optional[str]:
def neuron_platform_plugin() -> Optional[str]:
is_neuron = False
try:
import transformers_neuronx # noqa: F401
import neuronx_distributed # noqa: F401
is_neuron = True
except ImportError:
pass
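The plugin change above moves Neuron detection into neuron_platform_plugin() and keys it off neuronx_distributed instead of transformers_neuronx. A sketch of the same probe using importlib.find_spec rather than try/except; the returned class path is an assumption based on the vllm/platforms layout:

```python
from typing import Optional
import importlib.util


def neuron_platform_plugin() -> Optional[str]:
    is_neuron = importlib.util.find_spec("neuronx_distributed") is not None
    return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None


print(neuron_platform_plugin())  # None unless neuronx_distributed is installed
```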
8 changes: 1 addition & 7 deletions vllm/platforms/neuron.py
@@ -20,6 +20,7 @@ class NeuronPlatform(Platform):
device_name: str = "neuron"
device_type: str = "neuron"
ray_device_key: str = "neuron_cores"
dispatch_key: str = "XLA"
supported_quantization: list[str] = ["neuron_quant"]
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"

@@ -53,19 +54,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config.worker_cls = \
"vllm.worker.neuron_worker.NeuronWorker"

if parallel_config.world_size > 1:
parallel_config.distributed_executor_backend = "uni"

assert (vllm_config.lora_config is
None), "LoRA is not supported for Neuron backend."
assert (not vllm_config.speculative_config
), "Speculative decoding not yet supported for Neuron backend."

cache_config = vllm_config.cache_config
if cache_config:
# neuron needs block_size = max_model_len
vllm_config.cache_config.block_size = \
vllm_config.model_config.max_model_len

@classmethod
def is_pin_memory_available(cls) -> bool:
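check_and_update_config is the platform hook that rewrites the engine config; after this hunk it only pins the Neuron worker class and no longer forces the "uni" executor or block_size = max_model_len. A rough stand-in (the worker_cls == "auto" guard is an assumption about surrounding, unchanged lines; the config object is not vLLM's VllmConfig):

```python
from types import SimpleNamespace


def check_and_update_config(vllm_config) -> None:
    parallel_config = vllm_config.parallel_config
    if parallel_config.worker_cls == "auto":   # assumed guard, not shown in the hunk
        parallel_config.worker_cls = "vllm.worker.neuron_worker.NeuronWorker"


cfg = SimpleNamespace(parallel_config=SimpleNamespace(worker_cls="auto"))
check_and_update_config(cfg)
print(cfg.parallel_config.worker_cls)  # vllm.worker.neuron_worker.NeuronWorker
```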
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/neuron_attn.py
@@ -128,6 +128,7 @@ def __init__(
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -138,6 +139,7 @@ def __init__(
@torch.inference_mode()
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu_worker.py
@@ -20,7 +20,6 @@
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)

@@ -126,6 +125,7 @@ def init_device(self):
set_random_seed(self.model_config.seed)

# Construct the model runner
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
self.model_runner = GPUModelRunner(self.vllm_config, self.device)

def load_model(self) -> None:
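The GPUModelRunner import above moves from module scope into init_device. A toy worker showing the deferred-import pattern; only the import placement is the point, and the class itself is a stand-in:

```python
class WorkerSketch:
    """Stand-in worker illustrating the deferred GPUModelRunner import."""

    def __init__(self, vllm_config, device: str):
        self.vllm_config = vllm_config
        self.device = device
        self.model_runner = None

    def init_device(self):
        # Deferred import: the runner module can pull in GPU-only dependencies,
        # so it is imported only once a device is actually being initialized.
        from vllm.v1.worker.gpu_model_runner import GPUModelRunner
        self.model_runner = GPUModelRunner(self.vllm_config, self.device)
```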
102 changes: 70 additions & 32 deletions vllm/v1/worker/neuron_model_runner.py
@@ -7,6 +7,8 @@
import torch.distributed
import torch.nn as nn

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
@@ -20,7 +22,10 @@
from vllm.v1.attention.backends.neuron_attn import NeuronAttentionBackend, NeuronAttentionMetadata
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)

if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
@@ -96,6 +101,7 @@ def __init__(
max_num_blocks_per_req=self.max_num_blocks_per_req,
device="cpu",
pin_memory=self.pin_memory,
vocab_size=model_config.get_vocab_size(),
)

self.input_ids = torch.zeros(self.max_num_tokens,
@@ -109,7 +115,8 @@ def __init__(
dtype=self.dtype,
device="cpu")

self.neuron_compilation_batch_sizes = list(reversed(self.vllm_config.compilation_config.capture_sizes))
# TODO(gnovack) - use compile sizes...
self.neuron_compilation_batch_sizes = list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))

def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
@@ -155,8 +162,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
start_index = len(req_state.block_ids)
end_index = start_index + num_new_blocks
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table_cpu[
req_index, start_index:end_index] = req_data.new_block_ids
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)

req_ids_to_add: List[str] = []
# Add new requests to the cached states.
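The block-table update above goes through BlockTable.append_row(req_index, start_index, new_block_ids) instead of writing into block_table_cpu directly. A toy table with the same call shape (not vLLM's BlockTable):

```python
import torch


class TinyBlockTable:
    def __init__(self, max_reqs: int, max_blocks: int):
        self.cpu = torch.zeros(max_reqs, max_blocks, dtype=torch.int32)

    def append_row(self, row: int, start: int, block_ids: list[int]) -> None:
        self.cpu[row, start:start + len(block_ids)] = torch.tensor(
            block_ids, dtype=torch.int32)


bt = TinyBlockTable(max_reqs=4, max_blocks=8)
bt.append_row(row=0, start=0, block_ids=[3, 7])
bt.append_row(row=0, start=2, block_ids=[9])   # blocks appended on a later step
print(bt.cpu[0].tolist())  # [3, 7, 9, 0, 0, 0, 0, 0]
```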
@@ -215,11 +222,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0

# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table[:num_reqs].copy_(
self.input_batch.block_table_cpu_tensor[:num_reqs],
non_blocking=True)
self.input_batch.block_table.commit(num_reqs)

# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
@@ -278,7 +281,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size` here
# because M (max_model_len) is not necessarily divisible by block_size.
block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[
block_numbers = self.input_batch.block_table.get_cpu_tensor().flatten()[
req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size]
block_offsets = torch.from_numpy(positions_np % self.block_size)
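The indexing above resolves, for every token, which physical block it lives in: flatten the (num_reqs, max_num_blocks_per_req) table and index it at req_index * max_num_blocks_per_req + position // block_size, with position % block_size as the offset inside the block. A worked example with made-up numbers; the final slot_mapping line follows the usual block * block_size + offset convention, which is not visible in this hunk:

```python
import numpy as np
import torch

block_size, max_num_blocks_per_req = 4, 8
block_table = torch.tensor([[10, 11, 12, 0, 0, 0, 0, 0],
                            [20, 21,  0, 0, 0, 0, 0, 0]])

req_indices = np.array([0, 0, 0, 1, 1])    # owning request of each token
positions_np = np.array([0, 1, 5, 0, 3])   # token position within its request

block_numbers = block_table.flatten()[
    req_indices * max_num_blocks_per_req + positions_np // block_size]
block_offsets = torch.from_numpy(positions_np % block_size)
slot_mapping = block_numbers * block_size + block_offsets
print(slot_mapping.tolist())  # [40, 41, 45, 80, 83]
```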
@@ -333,14 +336,12 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
)
num_active_blocks_factor = max(LARGE_TILE_SZ // self.block_size // num_active_blocks_shifted, 1)
num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor

# assert (num_active_blocks * self.block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}"
assert (num_active_blocks * self.block_size) % LARGE_TILE_SZ == 0, "invalid {num_active_blocks=}"

context_kv_len = num_active_blocks * self.block_size
# assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"


block_table = self.input_batch.block_table[:num_reqs]
block_table = self.input_batch.block_table.get_cpu_tensor()[:num_reqs]
active_block_table = get_active_block_tables(
block_table,
torch.tensor(query_lens),
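The relaxed assert above only requires num_active_blocks * block_size to be a multiple of LARGE_TILE_SZ rather than exactly equal to it. The rounding arithmetic in isolation (values are illustrative; num_active_blocks_shifted would normally come from the block-table preprocessing earlier in _prepare_inputs):

```python
LARGE_TILE_SZ = 2048
block_size = 32

for num_active_blocks_shifted in (16, 64, 128):
    factor = max(LARGE_TILE_SZ // block_size // num_active_blocks_shifted, 1)
    num_active_blocks = num_active_blocks_shifted * factor
    context_kv_len = num_active_blocks * block_size
    assert context_kv_len % LARGE_TILE_SZ == 0, f"invalid {num_active_blocks=}"
    print(num_active_blocks_shifted, num_active_blocks, context_kv_len)
```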
@@ -363,8 +364,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
0,
max(context_kv_len, LARGE_TILE_SZ) - prior_mask.shape[1],
0,
# B_P_SIZE - prior_mask.shape[0],
padded_num_tokens - prior_mask.shape[0],
B_P_SIZE - prior_mask.shape[0],
),
"constant",
0,
@@ -375,8 +375,7 @@
0,
padded_num_tokens - active_mask.shape[1],
0,
# B_P_SIZE - active_mask.shape[0],
padded_num_tokens - active_mask.shape[0],
B_P_SIZE - active_mask.shape[0],
),
"constant",
0,
@@ -397,7 +396,7 @@
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_start_loc=seq_start_loc,
block_table=self.input_batch.block_table[:num_reqs],
block_table=self.input_batch.block_table.get_device_tensor()[:num_reqs],
slot_mapping=slot_mapping,
num_active_blocks=num_active_blocks,
active_block_table=active_block_table,
@@ -422,7 +421,11 @@ def _prepare_sampling(
or scheduler_output.scheduled_resumed_reqs):
skip_copy = False
# Create the sampling metadata.
sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
req_id_output_token_ids: Dict[str, List[int]] = \
{req_id: req.output_token_ids \
for req_id, req in self.requests.items()}

sampling_metadata = self.input_batch.make_sampling_metadata(req_id_output_token_ids, skip_copy)
return sampling_metadata

def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -539,13 +542,14 @@ def execute_model(

# Run the decoder.
# Use persistent buffers for CUDA graphs.
hidden_states = self.model(
input_ids=input_ids.unsqueeze(0).to(self.device),
positions=self.positions[:num_input_tokens].unsqueeze(0).to(self.device),
kv_caches=self.kv_caches,
attn_metadata=attn_metadata,
inputs_embeds=inputs_embeds.to(self.device) if inputs_embeds is not None else None,
).cpu()
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids.unsqueeze(0).to(self.device),
positions=self.positions[:num_input_tokens].unsqueeze(0).to(self.device),
kv_caches=self.kv_caches,
attn_metadata=attn_metadata,
inputs_embeds=inputs_embeds.to(self.device) if inputs_embeds is not None else None,
).cpu()
hidden_states = hidden_states[0, :num_scheduled_tokens]
hidden_states = hidden_states[logits_indices.cpu()]
logits = self.model.compute_logits(hidden_states, None)
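The decoder call above now runs under set_forward_context, so attention layers can read the current attn_metadata (and the KV caches bound to them) from an ambient context instead of explicit arguments. A generic stand-in for the pattern, not vLLM's implementation:

```python
import contextlib

_FORWARD_CTX: dict = {}


@contextlib.contextmanager
def set_forward_context_sketch(attn_metadata):
    _FORWARD_CTX["attn_metadata"] = attn_metadata
    try:
        yield
    finally:
        _FORWARD_CTX.pop("attn_metadata", None)


def run_decoder(model_fn, inputs, attn_metadata):
    with set_forward_context_sketch(attn_metadata):
        # layers called inside model_fn can look up _FORWARD_CTX["attn_metadata"]
        return model_fn(inputs)
```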
@@ -660,7 +664,7 @@ def _dummy_run(
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
with set_forward_context(None, self.vllm_config):
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = model(
input_ids=input_ids.unsqueeze(0).to(self.device),
positions=self.positions[:num_tokens].unsqueeze(0).to(self.device),
@@ -686,18 +690,25 @@ def capture_model(self) -> None:
elapsed_time = end_time - start_time
logger.info("Neuron compilation finished in %.0f secs", elapsed_time)

def initialize_kv_cache(self, num_blocks: int) -> None:
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
assert len(self.kv_caches) == 0
self.num_blocks = num_blocks
self.num_blocks = kv_cache_config.num_blocks

kv_caches: Dict[str, torch.Tensor] = {}

with torch.inference_mode():
kv_cache_shape = NeuronAttentionBackend.get_kv_cache_shape(
num_blocks + 1, self.block_size, self.num_kv_heads, self.head_size)
for _ in range(self.num_attn_layers):
self.num_blocks + 1, self.block_size, self.num_kv_heads, self.head_size)
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
cache = torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device='cpu')
self.kv_caches.append(cache.to(self.device))
kv_caches[layer_name] = cache.to(self.device)

bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)

def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
# TODO: Optimize this?
@@ -706,6 +717,33 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
return size
return None

def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""

forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: KVCacheSpec = {}
for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
)
else:
raise NotImplementedError
return kv_cache_spec


def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
num_blocks):
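get_kv_cache_spec and the reworked initialize_kv_cache wire KV caches per layer name: one spec per decoder attention layer, one cache tensor per layer, then bind_kv_cache publishes them to the forward context. A toy version of that flow; the dataclass and the cache shape (including the leading K/V dim and the +1 padding block) are assumptions, not vLLM's exact definitions:

```python
from dataclasses import dataclass
from typing import Dict

import torch


@dataclass
class FullAttentionSpecSketch:
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype


layers = {"model.layers.0.self_attn.attn": (8, 128),
          "model.layers.1.self_attn.attn": (8, 128)}

kv_cache_spec: Dict[str, FullAttentionSpecSketch] = {
    name: FullAttentionSpecSketch(block_size=32, num_kv_heads=heads,
                                  head_size=dim, dtype=torch.bfloat16)
    for name, (heads, dim) in layers.items()
}

num_blocks = 64
kv_caches: Dict[str, torch.Tensor] = {
    name: torch.zeros(2, num_blocks + 1, spec.block_size,
                      spec.num_kv_heads, spec.head_size, dtype=spec.dtype)
    for name, spec in kv_cache_spec.items()
}
print({name: tuple(t.shape) for name, t in kv_caches.items()})
```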