From 512c7bfaca6cf847418bd94f98c22ba8ae25ea1b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 30 Mar 2023 07:01:13 +0000 Subject: [PATCH 1/9] Minor --- cacheflow/models/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index 6c7dcbedd3b26..87eeccb61634b 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -12,7 +12,7 @@ class Sampler(nn.Module): def __init__(self) -> None: - super(Sampler, self).__init__() + super().__init__() def forward( self, From 56674f4b1d35cad1161f9da0d4cbb349cddf85c0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 30 Mar 2023 07:29:18 +0000 Subject: [PATCH 2/9] Add test code for rotary embedding --- tests/kernels/pos_encoding.py | 61 +++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 tests/kernels/pos_encoding.py diff --git a/tests/kernels/pos_encoding.py b/tests/kernels/pos_encoding.py new file mode 100644 index 0000000000000..d60115f6c237d --- /dev/null +++ b/tests/kernels/pos_encoding.py @@ -0,0 +1,61 @@ +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class RefRotaryEmbedding(nn.Module): + """Reference implementation of the GPT-NeoX style rotary embedding.""" + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + ) -> None: + super().__init__() + self.max_position_embeddings = max_position_embeddings + + # Create cos and sin embeddings. 
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
+        t = torch.arange(max_position_embeddings).float()
+        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
+        emb = torch.cat((freqs, freqs), dim=-1)
+        cos = emb.cos().to(dtype=inv_freq.dtype)
+        sin = emb.sin().to(dtype=inv_freq.dtype)
+        self.register_buffer("cos_cached", cos, persistent=False)
+        self.register_buffer("sin_cached", sin, persistent=False)
+
+    def forward(
+        self,
+        positions: torch.LongTensor,
+        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
+        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        cos = F.embedding(positions, self.cos_cached)
+        sin = F.embedding(positions, self.sin_cached)
+        query = query.transpose(0, 1)
+        key = key.transpose(0, 1)
+        query, key = apply_rotary_pos_emb(query, key, cos, sin)
+        query = query.transpose(0, 1).contiguous()
+        key = key.transpose(0, 1).contiguous()
+        # Output query/key shape: [num_tokens, num_heads, head_size]
+        return query, key

From 3533de0fa1e71332d0fa1ae50255ec537260678a Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 30 Mar 2023 07:29:31 +0000
Subject: [PATCH 3/9] Minor

---
 cacheflow/models/opt.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py
index 3a7e6a1103855..a74c6a97100b3 100644
--- a/cacheflow/models/opt.py
+++ b/cacheflow/models/opt.py
@@ -51,7 +51,7 @@ def __init__(
         assert num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = embed_dim // total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.scaling = self.head_dim ** -0.5
 
         # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
         self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
@@ -66,7 +66,6 @@ def __init__(
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-
         self.attn = OPTCacheFlowAttention(scale=self.scaling)
 
     def forward(

From 8e0e6a43a046f85201be4fbb6fb1e53461066e1f Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 30 Mar 2023 08:53:07 +0000
Subject: [PATCH 4/9] Minor

---
 csrc/cache_kernels.cu | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 8d69e4dc6459b..d7a0faa814108 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -122,13 +122,13 @@ void reshape_and_cache(
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping) {
   int num_tokens = key.size(0);
-  int head_num = key.size(1);
+  int num_heads = key.size(1);
   int head_size = key.size(2);
   int block_size = key_cache.size(3);
   int x = key_cache.size(4);
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(head_num * head_size, 512));
+  dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     key.scalar_type(),
@@ -140,7 +140,7 @@ void reshape_and_cache(
         key_cache.data_ptr<scalar_t>(),
         value_cache.data_ptr<scalar_t>(),
         slot_mapping.data_ptr<int>(),
-        head_num,
+        num_heads,
         head_size,
         block_size,
         x);

From 3b6652add29d9928da36e087f0c55d4e42a6b090 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 30 Mar 2023 10:21:34 +0000
Subject: [PATCH 5/9] Add rotary embedding kernel

---
 csrc/pos_encoding.cpp        | 16 ++++++++
 csrc/pos_encoding_kernels.cu | 83 ++++++++++++++++++++++++++++++++++++
 setup.py                     |  8 ++++
 3 files changed, 107 insertions(+)
 create mode 100644 csrc/pos_encoding.cpp
 create mode 100644 csrc/pos_encoding_kernels.cu

diff --git a/csrc/pos_encoding.cpp b/csrc/pos_encoding.cpp
new file mode 100644
index 0000000000000..a10bec85a98a7
--- /dev/null
+++ b/csrc/pos_encoding.cpp
@@ -0,0 +1,16 @@
+#include <torch/extension.h>
+
+void rotary_embedding_neox(
+  torch::Tensor& out_query,
+  torch::Tensor& out_key,
+  torch::Tensor& positions,
+  torch::Tensor& query,
+  torch::Tensor& key,
+  torch::Tensor& cos_sin_cache);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "rotary_embedding_neox",
+    &rotary_embedding_neox,
+    "Apply GPT-NeoX style rotary embedding to query and key");
+}
diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
new file mode 100644
index 0000000000000..50cf209fb200d
--- /dev/null
+++ b/csrc/pos_encoding_kernels.cu
@@ -0,0 +1,83 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+namespace cacheflow {
+
+template<typename scalar_t>
+__global__ void rotary_embedding_neox_kernel(
+  scalar_t* __restrict__ out_query,           // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ out_key,             // [num_tokens, num_heads, head_size]
+  const int64_t* __restrict__ positions,      // [num_tokens]
+  const scalar_t* __restrict__ query,         // [num_tokens, num_heads, head_size]
+  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
+  const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
+  const int num_heads,
+  const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
+
+  const int embed_dim = head_size / 2;
+  const int n = num_heads * head_size;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    const int idx = token_idx * n + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int token_head = token_idx * n + head_idx * head_size;
+
+    const bool is_first_half = head_offset < embed_dim;
+    const int rot_offset = head_offset % embed_dim;
+    const int x_index = rot_offset;
+    const int y_index = embed_dim + rot_offset;
+
+    const scalar_t cos = __ldg(cache_ptr + x_index);
+    const scalar_t sin = __ldg(cache_ptr + y_index);
+
+    const scalar_t q_x = __ldg(query + token_head + x_index);
+    const scalar_t q_y = __ldg(query + token_head + y_index);
+    const scalar_t q_cos = is_first_half ? q_x : q_y;
+    const scalar_t q_sin = is_first_half ? -q_y : q_x;
+    out_query[idx] = q_cos * cos + q_sin * sin;
+
+    const scalar_t k_x = __ldg(key + token_head + x_index);
+    const scalar_t k_y = __ldg(key + token_head + y_index);
+    const scalar_t k_cos = is_first_half ? k_x : k_y;
+    const scalar_t k_sin = is_first_half ? -k_y : k_x;
+    out_key[idx] = k_cos * cos + k_sin * sin;
+  }
+}
+
+} // namespace cacheflow
+
+void rotary_embedding_neox(
+  torch::Tensor& out_query,     // [num_tokens, num_heads * head_size]
+  torch::Tensor& out_key,       // [num_tokens, num_heads * head_size]
+  torch::Tensor& positions,     // [num_tokens]
+  torch::Tensor& query,         // [num_tokens, num_heads * head_size]
+  torch::Tensor& key,           // [num_tokens, num_heads * head_size]
+  torch::Tensor& cos_sin_cache) // [max_position, head_size]
+{
+  int num_tokens = query.size(0);
+  int head_size = cos_sin_cache.size(1);
+  int num_heads = query.size(1) / head_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    query.scalar_type(),
+    "rotary_embedding_neox",
+    [&] {
+      cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out_query.data_ptr<scalar_t>(),
+        out_key.data_ptr<scalar_t>(),
+        positions.data_ptr<int64_t>(),
+        query.data_ptr<scalar_t>(),
+        key.data_ptr<scalar_t>(),
+        cos_sin_cache.data_ptr<scalar_t>(),
+        num_heads,
+        head_size);
+    });
+}
diff --git a/setup.py b/setup.py
index 428088a8682ad..9889918fb1837 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,14 @@
 )
 ext_modules.append(attention_extension)
 
+# Positional encodings.
+positional_encoding_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.pos_encoding_ops',
+    sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(positional_encoding_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,

From b29eb16100a583436001de57fa82f543dcb2b9fe Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 30 Mar 2023 10:23:29 +0000
Subject: [PATCH 6/9] Add test code for rotary embedding kernel

---
 tests/kernels/pos_encoding.py | 72 ++++++++++++++++++++++++++++++++++-
 1 file changed, 70 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/pos_encoding.py b/tests/kernels/pos_encoding.py
index d60115f6c237d..87153d9ea098d 100644
--- a/tests/kernels/pos_encoding.py
+++ b/tests/kernels/pos_encoding.py
@@ -4,6 +4,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from cacheflow import pos_encoding_ops
+
 
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -22,7 +24,7 @@ def apply_rotary_pos_emb(
     return q_embed, k_embed
 
 
-class RefRotaryEmbedding(nn.Module):
+class RefRotaryEmbeddingNeox(nn.Module):
     """Reference implementation of the GPT-NeoX style rotary embedding."""
 
     def __init__(
@@ -46,7 +48,7 @@ def __init__(
 
     def forward(
         self,
-        positions: torch.LongTensor,
+        positions: torch.LongTensor,  # [num_tokens]
         query: torch.Tensor,  # [num_tokens, num_heads, head_size]
         key: torch.Tensor,  # [num_tokens, num_heads, head_size]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -59,3 +61,69 @@ def forward(
         key = key.transpose(0, 1).contiguous()
         # Output query/key shape: [num_tokens, num_heads, head_size]
         return query, key
+
+
+@torch.inference_mode()
+def test_rotary_embedding_neox(
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    max_position: int,
+    dtype: torch.dtype,
+    base: int = 10000,
+) -> None:
+    positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
+    query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
+    key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
+
+    # Create the rotary embedding.
+ inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size)) + t = torch.arange(max_position).float() + freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) + cos = freqs.cos() + sin = freqs.sin() + cos_sin_cache = torch.cat((cos, sin), dim=-1) + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + + # Run the kernel. + out_query = torch.empty_like(query) + out_key = torch.empty_like(key) + pos_encoding_ops.rotary_embedding_neox( + out_query, + out_key, + positions, + query, + key, + cos_sin_cache, + ) + + # Run the reference implementation. + ref_rotary_embedding = RefRotaryEmbeddingNeox( + dim=head_size, + max_position_embeddings=max_position, + base=base, + ).to(dtype=dtype, device='cuda') + ref_query, ref_key = ref_rotary_embedding( + positions, + query.view(num_tokens, num_heads, head_size), + key.view(num_tokens, num_heads, head_size), + ) + ref_query = ref_query.view(num_tokens, num_heads * head_size) + ref_key = ref_key.view(num_tokens, num_heads * head_size) + + # Compare the results. + assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) + assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) + + +if __name__ == '__main__': + for dtype in [torch.half, torch.float]: + for head_size in [64, 128, 256]: + print(f'Running tests for head_size={head_size} and dtype={dtype}') + test_rotary_embedding_neox( + num_tokens=2145, + num_heads=5, + head_size=head_size, + max_position=8192, + dtype=dtype, + ) From 7392665d166e11515cf89756ae120b9af7472d95 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 30 Mar 2023 10:25:18 +0000 Subject: [PATCH 7/9] Implement Llama attention layer --- cacheflow/models/attention.py | 73 ++++++++++++++++++++++++++++- cacheflow/models/llama.py | 64 ++----------------------- cacheflow/models/memory_analyzer.py | 3 +- 3 files changed, 77 insertions(+), 63 deletions(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 7f24670b7eaa3..e8c49af6c9b39 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -6,13 +6,14 @@ from cacheflow import attention_ops from cacheflow import cache_ops +from cacheflow import pos_encoding_ops from cacheflow.models import InputMetadata -class OPTCacheFlowAttention(nn.Module): +class GPTCacheFlowAttention(nn.Module): def __init__(self, scale: float) -> None: - super(OPTCacheFlowAttention, self).__init__() + super().__init__() self.scale = float(scale) self.flash_attn = FlashAttention(softmax_scale=self.scale) @@ -136,3 +137,71 @@ def forward( # Reshape the output tensor. # NOTE(woosuk): The output tensor may include paddings. return output.view(-1, num_heads * head_size) + + +class OPTCacheFlowAttention(GPTCacheFlowAttention): + """OPT uses the same attention mechanism as GPT.""" + + def __init__(self, scale: float) -> None: + super().__init__(scale) + + +class LlamaCacheFlowAttention(GPTCacheFlowAttention): + """Llama uses GPT-NeoX style rotary embedding.""" + + def __init__( + self, + scale: float, + head_size: int, + max_position: int = 8192, + base: int = 10000, + ) -> None: + super().__init__(scale) + + # Create the rotary embedding. + inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size)) + t = torch.arange(max_position).float() + freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + + # FIXME(woosuk): This assumes that we configure the default dtype when + # initializing the model. Make it more robust. 
+ torch_dtype = torch.get_default_dtype() + cache = cache.to(torch_dtype) + # Embedding size: [max_position, head_size] + self.register_buffer('cos_sin_cache', cache, persistent=False) + + def forward( + self, + positions: torch.LongTensor, # [num_tokens] + query: torch.Tensor, # [num_tokens, num_heads * head_size] + key: torch.Tensor, # [num_tokens, num_heads * head_size] + value: torch.Tensor, # [num_tokens, num_heads * head_size] + key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] + value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: # [num_tokens, num_heads * head_size] + # Apply the rotary embedding to the query and key before passing them + # to the attention op. + out_query = torch.empty_like(query) + out_key = torch.empty_like(key) + pos_encoding_ops.rotary_embedding_neox( + out_query, + out_key, + positions, + query, + key, + self.cos_sin_cache, + ) + return super().forward( + out_query, + out_key, + value, + key_cache, + value_cache, + input_metadata, + cache_event, + ) diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index 4ddbc698eb789..236feab4b4cb4 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -8,12 +8,10 @@ import numpy as np import torch from torch import nn -import torch.nn.functional as F from transformers import LlamaConfig -from transformers import PreTrainedModel from cacheflow.models import InputMetadata -from cacheflow.models.attention import OPTCacheFlowAttention +from cacheflow.models.attention import LlamaCacheFlowAttention from cacheflow.models.sample import Sampler from cacheflow.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -41,48 +39,8 @@ def forward(self, hidden_states): return self.weight * hidden_states -class LlamaRotaryEmbedding(torch.nn.Module): - - def __init__(self, dim, max_position_embeddings=2048, base=10000): - super().__init__() - self.max_position_embeddings = max_position_embeddings - - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) - self.register_buffer("inv_freq", inv_freq) - - # Create cos and sin embeddings. - t = torch.arange(max_position_embeddings).float() - freqs = torch.einsum("i,j->ij", t, self.inv_freq.float()) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype=self.inv_freq.dtype) - sin = emb.sin().to(dtype=self.inv_freq.dtype) - self.register_buffer("cos_cached", cos, persistent=False) - self.register_buffer("sin_cached", sin, persistent=False) - - def forward( - self, - positions: torch.LongTensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - cos = F.embedding(positions, self.cos_cached) - sin = F.embedding(positions, self.sin_cached) - return cos, sin - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin): - # TODO: Optimize. - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class LlamaMLP(nn.Module): + def __init__( self, hidden_size: int, @@ -156,9 +114,7 @@ def __init__( input_is_parallel=True, perform_initialization=False, ) - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim) - # FIXME(woosuk): Rename this. 
- self.attn = OPTCacheFlowAttention(scale=self.scaling) + self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim) def forward( self, @@ -171,19 +127,9 @@ def forward( q, _ = self.q_proj(hidden_states) k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) - - # Apply rotrary embedding. - # TODO: Optimize. - q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1) - k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1) - cos, sin = self.rotary_emb(positions) - q, k = apply_rotary_pos_emb(q, k, cos, sin) - q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim) - k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim) - - key_cache, value_cache = kv_cache + k_cache, v_cache = kv_cache attn_output = self.attn( - q, k, v, key_cache, value_cache, input_metadata, cache_event) + positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) output, _ = self.o_proj(attn_output) return output diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index d3dc8f44bbf98..3a539cca97633 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -165,8 +165,7 @@ def __init__( self.head_size = config.hidden_size // self.num_heads self.ffn_size = config.intermediate_size self.vocab_size = config.vocab_size - # FIXME - self.max_position = 2048 + self.max_position = 8192 def _get_param_size(self) -> int: word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size From eef11baf2ab9f1defa91383f9d307a722a54b7f8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 30 Mar 2023 10:30:11 +0000 Subject: [PATCH 8/9] Minor fix in comment --- cacheflow/models/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index e8c49af6c9b39..8b132e4423798 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -158,7 +158,7 @@ def __init__( ) -> None: super().__init__(scale) - # Create the rotary embedding. + # Create the cos and sin cache. inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size)) t = torch.arange(max_position).float() freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) @@ -184,7 +184,7 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # [num_tokens, num_heads * head_size] - # Apply the rotary embedding to the query and key before passing them + # Apply rotary embedding to the query and key before passing them # to the attention op. out_query = torch.empty_like(query) out_key = torch.empty_like(key) From 1a261885e9ab5ac7f228ae80a5d0c010870ac347 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 30 Mar 2023 17:39:15 +0000 Subject: [PATCH 9/9] Test more head sizes --- tests/kernels/pos_encoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/pos_encoding.py b/tests/kernels/pos_encoding.py index 87153d9ea098d..2dbce545e3455 100644 --- a/tests/kernels/pos_encoding.py +++ b/tests/kernels/pos_encoding.py @@ -118,7 +118,7 @@ def test_rotary_embedding_neox( if __name__ == '__main__': for dtype in [torch.half, torch.float]: - for head_size in [64, 128, 256]: + for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: print(f'Running tests for head_size={head_size} and dtype={dtype}') test_rotary_embedding_neox( num_tokens=2145,
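A note on the cache layout used above: RefRotaryEmbeddingNeox duplicates cos/sin across both halves of the head and applies rotate_half, while rotary_embedding_neox_kernel stores a single (cos, sin) pair per rot_offset in a flat [max_position, head_size] cache and selects the partner element with is_first_half. The snippet below is a minimal pure-PyTorch sketch (an editorial illustration, not part of the patch series; the sizes are arbitrary and only the names mirror the patch) showing that the two formulations agree for one token and one head.

# Minimal sketch (illustrative only): the kernel's is_first_half / rot_offset
# indexing over the flat cos_sin_cache matches the rotate_half formulation
# used by the reference implementation.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


head_size, max_position, base = 8, 16, 10000
embed_dim = head_size // 2

# Cache layout used by the kernel: cos in the first half, sin in the second half.
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_position, head_size]

pos = 3
q = torch.randn(head_size)
cos, sin = cos_sin_cache[pos].split(embed_dim)

# Reference path: duplicate cos/sin across both halves, then apply rotate_half.
q_ref = q * torch.cat((cos, cos)) + rotate_half(q) * torch.cat((sin, sin))

# Kernel path: per-element selection, as in rotary_embedding_neox_kernel.
q_kernel = torch.empty_like(q)
for i in range(head_size):
    is_first_half = i < embed_dim
    rot_offset = i % embed_dim
    x, y = q[rot_offset], q[embed_dim + rot_offset]
    q_cos, q_sin = (x, -y) if is_first_half else (y, x)
    q_kernel[i] = q_cos * cos[rot_offset] + q_sin * sin[rot_offset]

assert torch.allclose(q_ref, q_kernel, atol=1e-6)

This is also why the host function can recover head_size from cos_sin_cache.size(1): each cache row packs head_size // 2 cosines followed by head_size // 2 sines, so the kernel needs no embedding lookup or transpose on the hot path.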