 # SPDX-License-Identifier: Apache-2.0

 import itertools
+import os
 from typing import Iterable, List, Optional, Tuple

 import torch

 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.pooler import CrossEncodingPooler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -47,7 +49,8 @@ def encoder_decoder_weights():
                 if not n.startswith("roberta."))


-class RobertaEmbedding(nn.Module):
+@CustomOp.register("roberta_embedding")
+class RobertaEmbedding(CustomOp):

     def __init__(self, config: RobertaConfig):
         super().__init__()
@@ -71,7 +74,80 @@ def __init__(self, config: RobertaConfig):
             raise ValueError("Only 'absolute' position_embedding_type" +
                              " is supported")

-    def forward(
+        self.use_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL',
+                                                 'false').lower() == 'true'
+
+    def forward_hpu(
+        self,
+        input_ids: torch.Tensor,
+        seq_lens: torch.Tensor,
+        position_ids: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        input_shape = input_ids.size()
+        inputs_embeds = self.word_embeddings(input_ids)
+
+        # Replace position ids because in RoBERTa models
+        # they have to start at padding_idx + 1 and ignore
+        # existing padding tokens.
+        # Modified for HPU, where position_ids and input_ids are laid out
+        # as [batch_size, bucket_size].
+        # References:
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
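+        # With merged prefill, all sequences are packed into row 0 of
+        # position_ids/input_ids and delimited by seq_lens; otherwise each
+        # row holds one sequence padded out to the bucket size.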
+        pos_list = []
+        token_list = []
+        if self.use_merged_prefill:
+            offset = 0
+            for seq_len in seq_lens:
+                pos_list.append(position_ids[0][offset:offset + seq_len])
+                token_list.append(input_ids[0][offset:offset + seq_len])
+                offset += seq_len
+
+            offset = 0
+            for positions, tokens, seq_len in zip(pos_list, token_list,
+                                                  seq_lens):
+                # Verify the assumption that incoming positions are
+                # always a sequence from 0 to N.
+                expected_pos = torch.arange(positions.size()[0],
+                                            dtype=torch.long,
+                                            device=inputs_embeds.device)
+                assert torch.equal(positions, expected_pos)
+                position_ids[0][offset:offset +
+                                seq_len] = create_position_ids_from_input_ids(
+                                    tokens, self.padding_idx)
+                offset += seq_len
+        else:
+            for offset in range(position_ids.size()[0]):
+                pos_list.append(position_ids[offset])
+                token_list.append(input_ids[offset])
+
+            for index, (positions, tokens, seq_len) in enumerate(
+                    zip(pos_list, token_list, seq_lens)):
+                # Verify the assumption that incoming positions are
+                # always a sequence from 0 to N.
+                expected_pos = torch.arange(positions.size()[0],
+                                            dtype=torch.long,
+                                            device=inputs_embeds.device)
+                valid_input_mask = expected_pos < seq_len
+                expected_pos = expected_pos * valid_input_mask
+                assert torch.equal(positions, expected_pos)
+                position_ids[index] = create_position_ids_from_input_ids_hpu(
+                    tokens, self.padding_idx, seq_len)
+
+        # Position embeddings.
+        position_embeddings = self.position_embeddings(position_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape,
+                                         dtype=torch.long,
+                                         device=inputs_embeds.device)
+
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        return embeddings
+
+    def forward_native(
         self,
         input_ids: torch.Tensor,
         seq_lens: torch.Tensor,
@@ -119,6 +195,46 @@ def forward(
         embeddings = self.LayerNorm(embeddings)
         return embeddings

+    def forward_cuda(
+        self,
+        input_ids: torch.Tensor,
+        seq_lens: torch.Tensor,
+        position_ids: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self.forward_native(input_ids, seq_lens, position_ids,
+                                   token_type_ids)
+
+
+# Adapted from transformers
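+# HPU variant of create_position_ids_from_input_ids (further below): besides
+# ignoring padding tokens, it zeroes the positions at indices >= seq_len so
+# the padded tail of a fixed-size bucket gets no position id.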
+def create_position_ids_from_input_ids_hpu(input_ids,
+                                           padding_idx,
+                                           seq_len,
+                                           past_key_values_length=0):
214+ """
215+ Replace non-padding symbols with their position numbers.
216+ Position numbers begin at padding_idx+1. Padding symbols
217+ are ignored. This is modified from fairseq's `utils.make_positions`.
218+
219+ Args:
220+ x: torch.Tensor x:
221+
222+ Returns: torch.Tensor
223+ """
+    # The series of casts and type-conversions here are carefully
+    # balanced to both work with ONNX export and XLA.
+    valid_input_mask = torch.arange(input_ids.size()[0],
+                                    dtype=torch.int,
+                                    device=input_ids.device)
+    valid_input_mask = valid_input_mask < seq_len
+
+    mask = input_ids.ne(padding_idx).int()
+
+    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
+                           past_key_values_length) * mask
+
+    return (incremental_indices.long() + padding_idx) * valid_input_mask
+

 # Adapted from transformers
 def create_position_ids_from_input_ids(input_ids,
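
For reference, a minimal standalone sketch (not part of the diff) of what the new create_position_ids_from_input_ids_hpu helper computes for one padded bucket row. The bucket size and token ids below are made up for illustration, and padding_idx=1 follows RoBERTa's convention:

import torch

# One bucket row of size 8 holding a sequence of length 5 plus padding (id 1).
input_ids = torch.tensor([0, 31414, 232, 328, 2, 1, 1, 1])
padding_idx, seq_len, past_key_values_length = 1, 5, 0

# Mask out bucket slots beyond the real sequence length.
valid_input_mask = torch.arange(input_ids.size()[0], dtype=torch.int) < seq_len
# Non-padding tokens get cumulative positions starting at padding_idx + 1.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
                       past_key_values_length) * mask
position_ids = (incremental_indices.long() + padding_idx) * valid_input_mask
print(position_ids)  # -> tensor([2, 3, 4, 5, 6, 0, 0, 0])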