@@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel(
       token_idx, query_stride, key_stride, head_stride);
 }
 
-template <typename scalar_t, bool IS_NEOX>
-__global__ void batched_rotary_embedding_kernel(
-    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
-                                            // [num_tokens]
-    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
-                                   // head_size] or [num_tokens, num_heads,
-                                   // head_size]
-    scalar_t* __restrict__ key,    // nullptr or
-                                   // [batch_size, seq_len, num_kv_heads,
-                                   // head_size] or [num_tokens, num_kv_heads,
-                                   // head_size]
-    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
-                                                 // 2]
-    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
-    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
-    const int64_t head_stride, const int num_heads, const int num_kv_heads,
-    const int head_size) {
-  // Each thread block is responsible for one token.
-  const int token_idx = blockIdx.x;
-  int64_t pos = positions[token_idx];
-  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
-  const scalar_t* cache_ptr =
-      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
-
-  apply_rotary_embedding<scalar_t, IS_NEOX>(
-      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
-      token_idx, query_stride, key_stride, head_stride);
-}
-
 }  // namespace vllm
 
 void rotary_embedding(
@@ -211,96 +182,3 @@ void rotary_embedding(
     }
   });
 }
-
-/*
-  Batched version of rotary embedding, pack multiple LoRAs together
-  and process in batched manner.
-*/
-void batched_rotary_embedding(
-    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
-    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size] or
-                           // [batch_size, seq_len, num_heads, head_size] or
-                           // [num_tokens, num_heads, head_size]
-    std::optional<torch::Tensor>
-        key,  // null or
-              // [batch_size, seq_len, num_kv_heads * head_size] or
-              // [num_tokens, num_kv_heads * head_size] or
-              // [batch_size, seq_len, num_heads, head_size] or
-              // [num_tokens, num_heads, head_size]
-    int64_t head_size,
-    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
-    bool is_neox, int64_t rot_dim,
-    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or [batch_size]
-) {
-  // num_tokens = batch_size * seq_len
-  int64_t num_tokens = cos_sin_cache_offsets.size(0);
-  TORCH_CHECK(
-      positions.size(0) == num_tokens || positions.numel() == num_tokens,
-      "positions must have the same num_tokens or batch_size as "
-      "cos_sin_cache_offsets");
-
-  int positions_ndim = positions.dim();
-  // Make sure num_tokens dim is consistent across positions, query, and key
-  TORCH_CHECK(
-      positions_ndim == 1 || positions_ndim == 2,
-      "positions must have shape [num_tokens] or [batch_size, seq_len]");
-  if (positions_ndim == 1) {
-    TORCH_CHECK(query.size(0) == positions.size(0) &&
-                    (!key.has_value() || key->size(0) == positions.size(0)),
-                "query, key and positions must have the same number of tokens");
-  }
-  if (positions_ndim == 2) {
-    TORCH_CHECK(
-        query.size(0) == positions.size(0) &&
-            (!key.has_value() || key->size(0) == positions.size(0)) &&
-            query.size(1) == positions.size(1) &&
-            (!key.has_value() || key->size(1) == positions.size(1)),
-        "query, key and positions must have the same batch_size and seq_len");
-  }
-
-  // Make sure head_size is valid for query and key
-  int query_hidden_size = query.numel() / num_tokens;
-  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
-  TORCH_CHECK(query_hidden_size % head_size == 0);
-  TORCH_CHECK(key_hidden_size % head_size == 0);
-
-  // Make sure query and key have consistent number of heads
-  int num_heads = query_hidden_size / head_size;
-  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
-  TORCH_CHECK(num_heads % num_kv_heads == 0);
-
-  int seq_dim_idx = positions_ndim - 1;
-  int64_t query_stride = query.stride(seq_dim_idx);
-  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
-  // Determine head stride: for [*, heads, head_size] use stride of last dim;
-  // for flat [*, heads*head_size], heads blocks are contiguous of size
-  // head_size
-  int query_ndim = query.dim();
-  int64_t head_stride =
-      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
-    if (is_neox) {
-      vllm::batched_rotary_embedding_kernel<scalar_t, true>
-          <<<grid, block, 0, stream>>>(
-              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
-              cos_sin_cache.data_ptr<scalar_t>(),
-              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-              key_stride, head_stride, num_heads, num_kv_heads, head_size);
-    } else {
-      vllm::batched_rotary_embedding_kernel<scalar_t, false>
-          <<<grid, block, 0, stream>>>(
-              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
-              cos_sin_cache.data_ptr<scalar_t>(),
-              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-              key_stride, head_stride, num_heads, num_kv_heads, head_size);
-    }
-  });
-}
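The removed kernel follows the same per-token pattern as the surviving rotary_embedding_kernel; the only extra step is that each token adds its cos_sin_cache_offsets entry to its position before indexing the cos/sin cache, which is what lets several LoRAs share one packed cache. Below is a minimal, self-contained sketch of that indexing plus a NeoX-style rotation, illustrative only and not the vLLM kernel: toy_batched_rope_kernel, its parameter names, the float/long long types, the managed-memory test harness, and the assumption that each cache row stores cos values followed by sin values are all my own choices for the sketch.

// Illustrative sketch only (not the vLLM kernel): one block per token,
// cache row = (per-token offset + position) * rot_dim, NeoX-style rotation
// of the first rot_dim elements of each query head.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void toy_batched_rope_kernel(
    const long long* positions,      // [num_tokens]
    const long long* cache_offsets,  // [num_tokens], per-token (per-LoRA) offset
    const float* cos_sin_cache,      // [max_position, rot_dim] = [cos | sin]
    float* query,                    // [num_tokens, num_heads, head_size]
    int rot_dim, int num_heads, int head_size) {
  const int token_idx = blockIdx.x;
  const float* cache =
      cos_sin_cache + (cache_offsets[token_idx] + positions[token_idx]) * rot_dim;
  const int embed_dim = rot_dim / 2;
  const float* cos_ptr = cache;
  const float* sin_ptr = cache + embed_dim;

  for (int i = threadIdx.x; i < num_heads * embed_dim; i += blockDim.x) {
    const int head = i / embed_dim;
    const int rot = i % embed_dim;
    float* q = query + (token_idx * num_heads + head) * head_size;
    const float x = q[rot];              // first element of the rotated pair
    const float y = q[rot + embed_dim];  // its partner (NeoX layout)
    q[rot] = x * cos_ptr[rot] - y * sin_ptr[rot];
    q[rot + embed_dim] = y * cos_ptr[rot] + x * sin_ptr[rot];
  }
}

int main() {
  const int num_tokens = 2, num_heads = 2, head_size = 8, rot_dim = 8,
            max_position = 16;
  long long *positions, *offsets;
  float *cache, *query;
  cudaMallocManaged(&positions, num_tokens * sizeof(long long));
  cudaMallocManaged(&offsets, num_tokens * sizeof(long long));
  cudaMallocManaged(&cache, max_position * rot_dim * sizeof(float));
  cudaMallocManaged(&query, num_tokens * num_heads * head_size * sizeof(float));
  positions[0] = 0; positions[1] = 1;
  offsets[0] = 0;   offsets[1] = 4;  // second token reads a different cache region
  for (int i = 0; i < max_position * rot_dim; ++i) cache[i] = 0.01f * i;
  for (int i = 0; i < num_tokens * num_heads * head_size; ++i) query[i] = 1.0f;

  // One block per token, as in the removed host wrapper (grid = num_tokens).
  toy_batched_rope_kernel<<<num_tokens, 64>>>(positions, offsets, cache, query,
                                              rot_dim, num_heads, head_size);
  cudaDeviceSynchronize();
  printf("query[0] after rotation: %f\n", query[0]);
  return 0;
}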