2 changes: 1 addition & 1 deletion README_GAUDI.md
@@ -385,7 +385,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM

- `PT_HPU_LAZY_MODE`: if `0`, the PyTorch Eager backend for Gaudi is used; if `1`, the PyTorch Lazy backend for Gaudi is used. `1` is the default.
- `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava and qwen models.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava, qwen and roberta models.
- `VLLM_PROMPT_USE_FLEX_ATTENTION` is enabled only for the llama model, and allows using torch.nn.attention.flex_attention instead of FusedSDPA. Note that this requires `VLLM_PROMPT_USE_FUSEDSDPA=0` (see the configuration sketch below).
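
The following is a minimal, non-authoritative sketch of how these variables might be set from Python before vLLM is initialized; they can equally be exported in the shell, and the model name and tensor-parallel size used here are placeholders, not part of this change:

```python
import os

# Set the HPU PyTorch Bridge variables documented above before vLLM starts its workers.
os.environ["PT_HPU_LAZY_MODE"] = "1"                      # Lazy backend (the default)
os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"     # needed for tensor parallel with HPU Graphs
os.environ["PT_HPUGRAPH_DISABLE_TENSOR_CACHE"] = "false"  # needed for llava, qwen and roberta models

from vllm import LLM  # assumes a Gaudi-enabled vLLM build is installed

# Placeholder model and parallelism; adjust for your setup.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2)
outputs = llm.generate(["Hello, Gaudi!"])
print(outputs[0].outputs[0].text)
```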

# Quantization, FP8 Inference and Model Calibration Process
@@ -361,7 +361,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM

- `PT_HPU_LAZY_MODE`: if `0`, the PyTorch Eager backend for Gaudi is used; if `1`, the PyTorch Lazy backend for Gaudi is used. `1` is the default.
- `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava model.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava, qwen and roberta models.

## Quantization, FP8 Inference and Model Calibration Process

120 changes: 118 additions & 2 deletions vllm/model_executor/models/roberta.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import itertools
import os
from typing import Iterable, List, Optional, Tuple

import torch
@@ -9,6 +10,7 @@

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -47,7 +49,8 @@ def encoder_decoder_weights():
if not n.startswith("roberta."))


class RobertaEmbedding(nn.Module):
@CustomOp.register("roberta_embedding")
class RobertaEmbedding(CustomOp):

def __init__(self, config: RobertaConfig):
super().__init__()
@@ -71,7 +74,80 @@ def __init__(self, config: RobertaConfig):
raise ValueError("Only 'absolute' position_embedding_type" +
" is supported")

def forward(
self.use_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL',
'false').lower() == 'true'

def forward_hpu(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens.
# Modified for HPU, where position_ids and input_ids
# are shaped [batch_size, bucket_size].
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
pos_list = []
token_list = []
if self.use_merged_prefill:
offset = 0
for seq_len in seq_lens:
pos_list.append(position_ids[0][offset:offset + seq_len])
token_list.append(input_ids[0][offset:offset + seq_len])
offset += seq_len

offset = 0
for positions, tokens, seq_len in zip(pos_list, token_list,
seq_lens):
# Verify the assumption that incoming positions are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
position_ids[0][offset:offset +
seq_len] = create_position_ids_from_input_ids(
tokens, self.padding_idx)
offset += seq_len
else:
for offset in range(position_ids.size()[0]):
pos_list.append(position_ids[offset])
token_list.append(input_ids[offset])

for index, (positions, tokens, seq_len) in enumerate(
zip(pos_list, token_list, seq_lens)):
# Verify the assumption that incoming positions are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
valid_input_mask = expected_pos < seq_len
expected_pos = expected_pos * valid_input_mask
assert torch.equal(positions, expected_pos)
position_ids[index] = create_position_ids_from_input_ids_hpu(
tokens, self.padding_idx, seq_len)

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)

token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings

def forward_native(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
@@ -119,6 +195,46 @@ def forward(
embeddings = self.LayerNorm(embeddings)
return embeddings

def forward_cuda(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.forward_native(input_ids, seq_lens, position_ids,
token_type_ids)


# Adapted from transformers
def create_position_ids_from_input_ids_hpu(input_ids,
padding_idx,
seq_len,
past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.

Args:
    input_ids: torch.Tensor, 1-D tensor of token ids
    padding_idx: int, id of the padding token
    seq_len: int, number of valid (non-padding) positions in the bucket
    past_key_values_length: int, offset added to the computed positions

Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA.
valid_input_mask = torch.arange(input_ids.size()[0],
dtype=torch.int,
device=input_ids.device)
valid_input_mask = valid_input_mask < seq_len

mask = input_ids.ne(padding_idx).int()

incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
past_key_values_length) * mask

return (incremental_indices.long() + padding_idx) * valid_input_mask


# Adapted from transformers
def create_position_ids_from_input_ids(input_ids,
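
To make the position-id handling above concrete, here is a small self-contained illustration that mirrors the logic of `create_position_ids_from_input_ids_hpu` on a bucket-padded sequence. The helper is re-declared locally purely for demonstration and only assumes `torch`; the token ids and bucket size are made up:

```python
import torch


def position_ids_for_bucket(input_ids, padding_idx, seq_len):
    """Standalone mirror of create_position_ids_from_input_ids_hpu, for illustration."""
    # Slots past the real sequence length belong to the HPU bucket padding.
    valid_input_mask = torch.arange(input_ids.size()[0], dtype=torch.int,
                                    device=input_ids.device) < seq_len
    # Non-padding tokens receive increasing indices starting at 1.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=0).type_as(mask) * mask
    # Shift so real positions start at padding_idx + 1, then zero out the bucket padding.
    return (incremental_indices.long() + padding_idx) * valid_input_mask


# A length-4 sequence bucket-padded to 8 slots; 1 is RoBERTa's padding token id.
input_ids = torch.tensor([0, 31414, 232, 2, 1, 1, 1, 1])
print(position_ids_for_bucket(input_ids, padding_idx=1, seq_len=4))
# tensor([2, 3, 4, 5, 0, 0, 0, 0])
```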