@@ -38,7 +38,8 @@ inline __device__ void apply_rotary_embedding(
3838 scalar_t * __restrict__ query, // [batch_size, seq_len, num_heads,
3939 // head_size] or [num_tokens, num_heads,
4040 // head_size]
41- scalar_t * __restrict__ key, // [batch_size, seq_len, num_kv_heads,
41+ scalar_t * __restrict__ key, // nullptr or
42+ // [batch_size, seq_len, num_kv_heads,
4243 // head_size] or [num_tokens, num_kv_heads,
4344 // head_size]
4445 const scalar_t * cache_ptr, const int head_size, const int num_heads,
@@ -57,13 +58,15 @@ inline __device__ void apply_rotary_embedding(
5758 query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
5859 }
5960
60- const int nk = num_kv_heads * embed_dim;
61- for (int i = threadIdx .x ; i < nk; i += blockDim .x ) {
62- const int head_idx = i / embed_dim;
63- const int64_t token_head = token_idx * key_stride + head_idx * head_size;
64- const int rot_offset = i % embed_dim;
65- apply_token_rotary_embedding<scalar_t , IS_NEOX>(
66- key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
61+ if (key != nullptr ) {
62+ const int nk = num_kv_heads * embed_dim;
63+ for (int i = threadIdx .x ; i < nk; i += blockDim .x ) {
64+ const int head_idx = i / embed_dim;
65+ const int64_t token_head = token_idx * key_stride + head_idx * head_size;
66+ const int rot_offset = i % embed_dim;
67+ apply_token_rotary_embedding<scalar_t , IS_NEOX>(
68+ key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
69+ }
6770 }
6871}
6972
@@ -74,7 +77,8 @@ __global__ void rotary_embedding_kernel(
7477 scalar_t * __restrict__ query, // [batch_size, seq_len, num_heads,
7578 // head_size] or [num_tokens, num_heads,
7679 // head_size]
77- scalar_t * __restrict__ key, // [batch_size, seq_len, num_kv_heads,
80+ scalar_t * __restrict__ key, // nullptr or
81+ // [batch_size, seq_len, num_kv_heads,
7882 // head_size] or [num_tokens, num_kv_heads,
7983 // head_size]
8084 const scalar_t * __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@@ -98,7 +102,8 @@ __global__ void batched_rotary_embedding_kernel(
98102 scalar_t * __restrict__ query, // [batch_size, seq_len, num_heads,
99103 // head_size] or [num_tokens, num_heads,
100104 // head_size]
101- scalar_t * __restrict__ key, // [batch_size, seq_len, num_kv_heads,
105+ scalar_t * __restrict__ key, // nullptr or
106+ // [batch_size, seq_len, num_kv_heads,
102107 // head_size] or [num_tokens, num_kv_heads,
103108 // head_size]
104109 const scalar_t * __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@@ -127,51 +132,53 @@ void rotary_embedding(
127132 // [num_tokens, num_heads * head_size] or
128133 // [batch_size, seq_len, num_heads, head_size] or
129134 // [num_tokens, num_heads, head_size]
130- torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
131- // [num_tokens, num_kv_heads * head_size] or
132- // [batch_size, seq_len, num_heads, head_size] or
133- // [num_tokens, num_heads, head_size]
135+ std::optional<torch::Tensor> key,
136+ // null or
137+ // [batch_size, seq_len, num_kv_heads * head_size] or
138+ // [num_tokens, num_kv_heads * head_size] or
139+ // [batch_size, seq_len, num_kv_heads, head_size] or
140+ // [num_tokens, num_kv_heads, head_size]
134141 int64_t head_size,
135142 torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
136143 bool is_neox) {
137144 // num_tokens = batch_size * seq_len
138145 int64_t num_tokens = positions.numel ();
139146 int positions_ndim = positions.dim ();
140147
141- // Make sure num_tokens dim is consistent across positions, query, and key.
148+ // Make sure num_tokens dim is consistent across positions, query, and key
142149 TORCH_CHECK (
143150 positions_ndim == 1 || positions_ndim == 2 ,
144151 " positions must have shape [num_tokens] or [batch_size, seq_len]" );
145152 if (positions_ndim == 1 ) {
146- TORCH_CHECK (
147- query. size ( 0 ) == positions. size ( 0 ) && key. size (0 ) == positions.size (0 ),
148- " query, key and positions must have the same number of tokens" );
153+ TORCH_CHECK (query. size ( 0 ) == positions. size ( 0 ) &&
154+ (!key. has_value () || key-> size (0 ) == positions.size (0 ) ),
155+ " query, key and positions must have the same number of tokens" );
149156 }
150157 if (positions_ndim == 2 ) {
151158 TORCH_CHECK (
152159 query.size (0 ) == positions.size (0 ) &&
153- key.size (0 ) == positions.size (0 ) &&
160+ (! key.has_value () || key-> size (0 ) == positions.size (0 ) ) &&
154161 query.size (1 ) == positions.size (1 ) &&
155- key.size (1 ) == positions.size (1 ),
162+ (! key.has_value () || key-> size (1 ) == positions.size (1 ) ),
156163 " query, key and positions must have the same batch_size and seq_len" );
157164 }
158165
159166 // Make sure head_size is valid for query and key
160167 // hidden_size = num_heads * head_size
161168 int query_hidden_size = query.numel () / num_tokens;
162- int key_hidden_size = key.numel () / num_tokens;
169+ int key_hidden_size = key.has_value () ? key-> numel () / num_tokens : 0 ;
163170 TORCH_CHECK (query_hidden_size % head_size == 0 );
164171 TORCH_CHECK (key_hidden_size % head_size == 0 );
165172
166173 // Make sure query and key have consistent number of heads
167174 int num_heads = query_hidden_size / head_size;
168- int num_kv_heads = key_hidden_size / head_size;
175+ int num_kv_heads = key. has_value () ? key_hidden_size / head_size : num_heads ;
169176 TORCH_CHECK (num_heads % num_kv_heads == 0 );
170177
171178 int rot_dim = cos_sin_cache.size (1 );
172179 int seq_dim_idx = positions_ndim - 1 ;
173180 int64_t query_stride = query.stride (seq_dim_idx);
174- int64_t key_stride = key.stride (seq_dim_idx);
181+ int64_t key_stride = key.has_value () ? key-> stride (seq_dim_idx) : 0 ;
175182
176183 dim3 grid (num_tokens);
177184 dim3 block (std::min<int64_t >(num_heads * rot_dim / 2 , 512 ));
@@ -181,15 +188,16 @@ void rotary_embedding(
181188 if (is_neox) {
182189 vllm::rotary_embedding_kernel<scalar_t , true ><<<grid, block, 0 , stream>>> (
183190 positions.data_ptr <int64_t >(), query.data_ptr <scalar_t >(),
184- key.data_ptr <scalar_t >(), cos_sin_cache.data_ptr <scalar_t >(), rot_dim,
185- query_stride, key_stride, num_heads, num_kv_heads, head_size);
191+ key.has_value () ? key->data_ptr <scalar_t >() : nullptr ,
192+ cos_sin_cache.data_ptr <scalar_t >(), rot_dim, query_stride, key_stride,
193+ num_heads, num_kv_heads, head_size);
186194 } else {
187195 vllm::rotary_embedding_kernel<scalar_t , false >
188196 <<<grid, block, 0 , stream>>> (
189197 positions.data_ptr <int64_t >(), query.data_ptr <scalar_t >(),
190- key.data_ptr < scalar_t >(), cos_sin_cache. data_ptr <scalar_t >(),
191- rot_dim, query_stride, key_stride, num_heads, num_kv_heads ,
192- head_size);
198+ key.has_value () ? key-> data_ptr <scalar_t >() : nullptr ,
199+ cos_sin_cache. data_ptr < scalar_t >(), rot_dim, query_stride ,
200+ key_stride, num_heads, num_kv_heads, head_size);
193201 }
194202 });
195203}
@@ -204,10 +212,12 @@ void batched_rotary_embedding(
204212 // [num_tokens, num_heads * head_size] or
205213 // [batch_size, seq_len, num_heads, head_size] or
206214 // [num_tokens, num_heads, head_size]
207- torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
208- // [num_tokens, num_kv_heads * head_size] or
209- // [batch_size, seq_len, num_heads, head_size] or
210- // [num_tokens, num_heads, head_size]
215+ std::optional<torch::Tensor>
216+ key, // null or
217+ // [batch_size, seq_len, num_kv_heads * head_size] or
218+ // [num_tokens, num_kv_heads * head_size] or
219+ // [batch_size, seq_len, num_kv_heads, head_size] or
220+ // [num_tokens, num_kv_heads, head_size]
211221 int64_t head_size,
212222 torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
213223 bool is_neox, int64_t rot_dim,
@@ -221,38 +231,38 @@ void batched_rotary_embedding(
221231 " cos_sin_cache_offsets" );
222232
223233 int positions_ndim = positions.dim ();
224- // Make sure num_tokens dim is consistent across positions, query, and key.
234+ // Make sure num_tokens dim is consistent across positions, query, and key
225235 TORCH_CHECK (
226236 positions_ndim == 1 || positions_ndim == 2 ,
227237 " positions must have shape [num_tokens] or [batch_size, seq_len]" );
228238 if (positions_ndim == 1 ) {
229- TORCH_CHECK (
230- query. size ( 0 ) == positions. size ( 0 ) && key. size (0 ) == positions.size (0 ),
231- " query, key and positions must have the same number of tokens" );
239+ TORCH_CHECK (query. size ( 0 ) == positions. size ( 0 ) &&
240+ (!key. has_value () || key-> size (0 ) == positions.size (0 ) ),
241+ " query, key and positions must have the same number of tokens" );
232242 }
233243 if (positions_ndim == 2 ) {
234244 TORCH_CHECK (
235245 query.size (0 ) == positions.size (0 ) &&
236- key.size (0 ) == positions.size (0 ) &&
246+ (! key.has_value () || key-> size (0 ) == positions.size (0 ) ) &&
237247 query.size (1 ) == positions.size (1 ) &&
238- key.size (1 ) == positions.size (1 ),
248+ (! key.has_value () || key-> size (1 ) == positions.size (1 ) ),
239249 " query, key and positions must have the same batch_size and seq_len" );
240250 }
241251
242252 // Make sure head_size is valid for query and key
243253 int query_hidden_size = query.numel () / num_tokens;
244- int key_hidden_size = key.numel () / num_tokens;
254+ int key_hidden_size = key.has_value () ? key-> numel () / num_tokens : 0 ;
245255 TORCH_CHECK (query_hidden_size % head_size == 0 );
246256 TORCH_CHECK (key_hidden_size % head_size == 0 );
247257
248258 // Make sure query and key have consistent number of heads
249259 int num_heads = query_hidden_size / head_size;
250- int num_kv_heads = key_hidden_size / head_size;
260+ int num_kv_heads = key. has_value () ? key_hidden_size / head_size : num_heads ;
251261 TORCH_CHECK (num_heads % num_kv_heads == 0 );
252262
253263 int seq_dim_idx = positions_ndim - 1 ;
254264 int64_t query_stride = query.stride (seq_dim_idx);
255- int64_t key_stride = key.stride (seq_dim_idx);
265+ int64_t key_stride = key.has_value () ? key-> stride (seq_dim_idx) : 0 ;
256266
257267 dim3 grid (num_tokens);
258268 dim3 block (std::min<int64_t >(num_heads * rot_dim / 2 , 512 ));
@@ -263,14 +273,16 @@ void batched_rotary_embedding(
263273 vllm::batched_rotary_embedding_kernel<scalar_t , true >
264274 <<<grid, block, 0 , stream>>> (
265275 positions.data_ptr <int64_t >(), query.data_ptr <scalar_t >(),
266- key.data_ptr <scalar_t >(), cos_sin_cache.data_ptr <scalar_t >(),
276+ key.has_value () ? key->data_ptr <scalar_t >() : nullptr ,
277+ cos_sin_cache.data_ptr <scalar_t >(),
267278 cos_sin_cache_offsets.data_ptr <int64_t >(), rot_dim, query_stride,
268279 key_stride, num_heads, num_kv_heads, head_size);
269280 } else {
270281 vllm::batched_rotary_embedding_kernel<scalar_t , false >
271282 <<<grid, block, 0 , stream>>> (
272283 positions.data_ptr <int64_t >(), query.data_ptr <scalar_t >(),
273- key.data_ptr <scalar_t >(), cos_sin_cache.data_ptr <scalar_t >(),
284+ key.has_value () ? key->data_ptr <scalar_t >() : nullptr ,
285+ cos_sin_cache.data_ptr <scalar_t >(),
274286 cos_sin_cache_offsets.data_ptr <int64_t >(), rot_dim, query_stride,
275287 key_stride, num_heads, num_kv_heads, head_size);
276288 }
0 commit comments