Commit 4445dca

Modify RobertaEmbedding forward as custom op method (#1049)
Same PR as #996. Just for v1.21.0_next branch.
1 parent: 2edff28 · commit: 4445dca

3 files changed: +120 -4 lines changed


README_GAUDI.md

Lines changed: 1 addition & 1 deletion
@@ -386,7 +386,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM
 
 - `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used. `1` is the default.
 - `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
-- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava and qwen models.
+- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava, qwen and roberta models.
 - `VLLM_PROMPT_USE_FLEX_ATTENTION` is enabled only for llama model, and allows to use torch.nn.attention.flex_attention instead of FusedSDPA. Note, this requires `VLLM_PROMPT_USE_FUSEDSDPA=0`
 
 # Quantization, FP8 Inference and Model Calibration Process
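
These variables are read from the process environment, so a launcher script can pin them before vLLM is imported. A minimal sketch, assuming the values shown suit the deployment (they are illustrative, not taken from this commit):

    import os

    # PT_HPU_LAZY_MODE=1 is the documented default (Lazy backend);
    # PT_HPU_ENABLE_LAZY_COLLECTIVES is required for tensor parallel
    # inference with HPU Graphs; keeping PT_HPUGRAPH_DISABLE_TENSOR_CACHE
    # at "false" is what this commit now requires for roberta models too.
    os.environ.setdefault("PT_HPU_LAZY_MODE", "1")
    os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true")
    os.environ.setdefault("PT_HPUGRAPH_DISABLE_TENSOR_CACHE", "false")

    from vllm import LLM  # import only after the environment is configured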

docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md

Lines changed: 1 addition & 1 deletion
@@ -361,7 +361,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM
 
 - `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used. `1` is the default.
 - `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
-- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava model.
+- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava, qwen and roberta models.
 
 ## Quantization, FP8 Inference and Model Calibration Process
 
vllm/model_executor/models/roberta.py

Lines changed: 118 additions & 2 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import itertools
+import os
 from typing import Iterable, List, Optional, Tuple
 
 import torch
@@ -9,6 +10,7 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.pooler import CrossEncodingPooler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -47,7 +49,8 @@ def encoder_decoder_weights():
                 if not n.startswith("roberta."))
 
 
-class RobertaEmbedding(nn.Module):
+@CustomOp.register("roberta_embedding")
+class RobertaEmbedding(CustomOp):
 
     def __init__(self, config: RobertaConfig):
         super().__init__()
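
Registering the embedding as a custom op is what lets the per-backend methods in the hunks below (forward_native, forward_cuda, forward_hpu) be selected at runtime. A minimal standalone sketch of that dispatch idea, using a hypothetical SketchCustomOp stand-in rather than vLLM's actual CustomOp base class:

    import torch
    from torch import nn


    class SketchCustomOp(nn.Module):
        """Hypothetical stand-in for vllm.model_executor.custom_op.CustomOp."""

        _registry: dict = {}

        @classmethod
        def register(cls, name: str):
            # Decorator that records the op under a name, mirroring
            # @CustomOp.register("roberta_embedding") in this commit.
            def wrap(op_cls):
                cls._registry[name] = op_cls
                return op_cls
            return wrap

        def __init__(self, platform: str = "native"):
            super().__init__()
            self._platform = platform

        def forward(self, *args, **kwargs):
            # Route to forward_hpu / forward_cuda when the subclass defines
            # them, otherwise fall back to the reference forward_native.
            impl = getattr(self, f"forward_{self._platform}", None)
            return (impl or self.forward_native)(*args, **kwargs)

vLLM's real base class resolves the platform itself; the change to roberta.py only has to supply forward_hpu alongside the existing logic, which is renamed forward_native.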
@@ -71,7 +74,80 @@ def __init__(self, config: RobertaConfig):
             raise ValueError("Only 'absolute' position_embedding_type" +
                              " is supported")
 
-    def forward(
+        self.use_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL',
+                                                 'false').lower() == 'true'
+
+    def forward_hpu(
+        self,
+        input_ids: torch.Tensor,
+        seq_lens: torch.Tensor,
+        position_ids: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        input_shape = input_ids.size()
+        inputs_embeds = self.word_embeddings(input_ids)
+
+        # Replace position ids because in RoBERTa models
+        # they have to start at padding_idx + 1 and ignore
+        # existing padding tokens.
+        # Modified for HPU: position_ids and input_ids are
+        # shaped [batch_size, bucket_size].
+        # References:
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
+        pos_list = []
+        token_list = []
+        if self.use_merged_prefill:
+            offset = 0
+            for seq_len in seq_lens:
+                pos_list.append(position_ids[0][offset:offset + seq_len])
+                token_list.append(input_ids[0][offset:offset + seq_len])
+                offset += seq_len
+
+            offset = 0
+            for positions, tokens, seq_len in zip(pos_list, token_list,
+                                                  seq_lens):
+                # Verify assumption that incoming positions are
+                # always a sequence from 0 to N.
+                expected_pos = torch.arange(positions.size()[0],
+                                            dtype=torch.long,
+                                            device=inputs_embeds.device)
+                assert torch.equal(positions, expected_pos)
+                position_ids[0][offset:offset +
+                                seq_len] = create_position_ids_from_input_ids(
+                                    tokens, self.padding_idx)
+                offset += seq_len
+        else:
+            for offset in range(position_ids.size()[0]):
+                pos_list.append(position_ids[offset])
+                token_list.append(input_ids[offset])
+
+            for index, (positions, tokens, seq_len) in enumerate(
+                    zip(pos_list, token_list, seq_lens)):
+                # Verify assumption that incoming positions are
+                # always a sequence from 0 to N.
+                expected_pos = torch.arange(positions.size()[0],
+                                            dtype=torch.long,
+                                            device=inputs_embeds.device)
+                valid_input_mask = expected_pos < seq_len
+                expected_pos = expected_pos * valid_input_mask
+                assert torch.equal(positions, expected_pos)
+                position_ids[index] = create_position_ids_from_input_ids_hpu(
+                    tokens, self.padding_idx, seq_len)
+
+        # Position embeddings.
+        position_embeddings = self.position_embeddings(position_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape,
+                                         dtype=torch.long,
+                                         device=inputs_embeds.device)
+
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        return embeddings
+
+    def forward_native(
         self,
         input_ids: torch.Tensor,
         seq_lens: torch.Tensor,
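
The branch on use_merged_prefill above assumes two different tensor layouts. A toy illustration (shapes and values are assumptions, not taken from the commit): with VLLM_MERGED_PREFILL=true all sequences are packed into a single row and sliced apart by seq_lens; otherwise each sequence occupies its own zero-padded row of the bucket:

    import torch

    seq_lens = [3, 2]

    # Merged prefill: one row of shape [1, total_tokens], sequences back to back.
    merged_positions = torch.tensor([[0, 1, 2, 0, 1]])
    offset = 0
    for seq_len in seq_lens:
        chunk = merged_positions[0][offset:offset + seq_len]
        print(chunk)  # tensor([0, 1, 2]) then tensor([0, 1])
        offset += seq_len

    # Default path: shape [batch_size, bucket_size], zero-padded past seq_len,
    # which is what the `expected_pos * valid_input_mask` assertion verifies.
    bucketed_positions = torch.tensor([[0, 1, 2, 0],
                                       [0, 1, 0, 0]])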
@@ -119,6 +195,46 @@ def forward(
         embeddings = self.LayerNorm(embeddings)
         return embeddings
 
+    def forward_cuda(
+        self,
+        input_ids: torch.Tensor,
+        seq_lens: torch.Tensor,
+        position_ids: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self.forward_native(input_ids, seq_lens, position_ids,
+                                   token_type_ids)
+
+
+# Adapted from transformers
+def create_position_ids_from_input_ids_hpu(input_ids,
+                                           padding_idx,
+                                           seq_len,
+                                           past_key_values_length=0):
+    """
+    Replace non-padding symbols with their position numbers.
+    Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids: torch.Tensor
+        padding_idx: int
+        seq_len: int
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully
+    # balanced to both work with ONNX export and XLA.
+    valid_input_mask = torch.arange(input_ids.size()[0],
+                                    dtype=torch.int,
+                                    device=input_ids.device)
+    valid_input_mask = valid_input_mask < seq_len
+
+    mask = input_ids.ne(padding_idx).int()
+
+    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
+                           past_key_values_length) * mask
+
+    return (incremental_indices.long() + padding_idx) * valid_input_mask
+
 
 # Adapted from transformers
 def create_position_ids_from_input_ids(input_ids,
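
For intuition, the arithmetic of the new helper can be reproduced inline on a toy bucket (padding_idx=1, past_key_values_length=0, and the token ids below are assumed example values, not taken from the commit): real tokens get positions starting at padding_idx + 1, and everything past seq_len in the bucket is zeroed by the mask:

    import torch

    input_ids = torch.tensor([0, 42, 7, 1, 1, 1])  # 3 real tokens, 3 bucket pads
    padding_idx, seq_len = 1, 3

    valid_input_mask = torch.arange(input_ids.size()[0], dtype=torch.int) < seq_len
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=0).type_as(mask) * mask
    position_ids = (incremental_indices.long() + padding_idx) * valid_input_mask

    print(position_ids)  # tensor([2, 3, 4, 0, 0, 0])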
