Implement custom kernel for LLaMA rotary embedding #14

Merged · 9 commits · Mar 30, 2023
73 changes: 71 additions & 2 deletions cacheflow/models/attention.py
@@ -6,13 +6,14 @@

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):
class GPTCacheFlowAttention(nn.Module):

def __init__(self, scale: float) -> None:
super(OPTCacheFlowAttention, self).__init__()
super().__init__()
self.scale = float(scale)

self.flash_attn = FlashAttention(softmax_scale=self.scale)
@@ -136,3 +137,71 @@ def forward(
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
return output.view(-1, num_heads * head_size)


class OPTCacheFlowAttention(GPTCacheFlowAttention):
"""OPT uses the same attention mechanism as GPT."""

def __init__(self, scale: float) -> None:
super().__init__(scale)


class LlamaCacheFlowAttention(GPTCacheFlowAttention):
"""Llama uses GPT-NeoX style rotary embedding."""

def __init__(
self,
scale: float,
head_size: int,
max_position: int = 8192,
base: int = 10000,
) -> None:
super().__init__(scale)

# Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)

# FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model. Make it more robust.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, head_size]
self.register_buffer('cos_sin_cache', cache, persistent=False)

def forward(
self,
positions: torch.LongTensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Apply rotary embedding to the query and key before passing them
# to the attention op.
out_query = torch.empty_like(query)
out_key = torch.empty_like(key)
pos_encoding_ops.rotary_embedding_neox(
out_query,
out_key,
positions,
query,
key,
self.cos_sin_cache,
)
return super().forward(
out_query,
out_key,
value,
key_cache,
value_cache,
input_metadata,
cache_event,
)
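
The fused op applies the GPT-NeoX style rotation directly from this cache, whose rows hold cos in the first half and sin in the second half of each head dimension. A minimal pure-PyTorch sketch of the same computation, useful for reasoning about the cache layout; the helper name is illustrative and not part of this PR:

import torch

def rotary_embedding_neox_ref(
    positions: torch.LongTensor,   # [num_tokens]
    query: torch.Tensor,           # [num_tokens, num_heads * head_size]
    key: torch.Tensor,             # [num_tokens, num_heads * head_size]
    cos_sin_cache: torch.Tensor,   # [max_position, head_size] = cat([cos, sin], dim=-1)
    num_heads: int,
    head_size: int,
):
    num_tokens = positions.shape[0]
    # Gather the cached cos/sin rows for each token's position.
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)   # each [num_tokens, head_size // 2]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)           # broadcast over heads

    def rotate(x: torch.Tensor) -> torch.Tensor:
        x = x.view(num_tokens, num_heads, head_size)
        x1, x2 = x.chunk(2, dim=-1)   # NeoX style: split each head into halves
        out = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
        return out.view(num_tokens, num_heads * head_size)

    return rotate(query), rotate(key)
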
64 changes: 5 additions & 59 deletions cacheflow/models/llama.py
@@ -8,12 +8,10 @@
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers import PreTrainedModel

from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -41,48 +39,8 @@ def forward(self, hidden_states):
return self.weight * hidden_states


class LlamaRotaryEmbedding(torch.nn.Module):

def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
self.max_position_embeddings = max_position_embeddings

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
self.register_buffer("inv_freq", inv_freq)

# Create cos and sin embeddings.
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=self.inv_freq.dtype)
sin = emb.sin().to(dtype=self.inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)

def forward(
self,
positions: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
return cos, sin


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
# TODO: Optimize.
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


class LlamaMLP(nn.Module):

def __init__(
self,
hidden_size: int,
@@ -156,9 +114,7 @@ def __init__(
input_is_parallel=True,
perform_initialization=False,
)
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
# FIXME(woosuk): Rename this.
self.attn = OPTCacheFlowAttention(scale=self.scaling)
self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)

def forward(
self,
@@ -171,19 +127,9 @@ def forward(
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)

# Apply rotrary embedding.
# TODO: Optimize.
q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
cos, sin = self.rotary_emb(positions)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)

key_cache, value_cache = kv_cache
k_cache, v_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output

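
Note that the deleted LlamaRotaryEmbedding duplicated the frequencies (cat((freqs, freqs))) before taking cos and sin, while the new cos_sin_cache in LlamaCacheFlowAttention stores cos and sin of the un-duplicated frequencies side by side, roughly halving the cached data. A quick sketch of how the two layouts relate (illustrative only, not part of the PR):

import torch

dim, max_position, base = 128, 8192, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j->ij', t, inv_freq.float())

# Old (HF-style) caches: two [max_position, dim] tensors with duplicated halves.
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached, sin_cached = emb.cos(), emb.sin()

# New fused cache: one [max_position, dim] tensor, first half cos, second half sin.
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

assert torch.allclose(cos_cached[:, :dim // 2], cos_sin_cache[:, :dim // 2])
assert torch.allclose(sin_cached[:, dim // 2:], cos_sin_cache[:, dim // 2:])
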
3 changes: 1 addition & 2 deletions cacheflow/models/memory_analyzer.py
@@ -165,8 +165,7 @@ def __init__(
self.head_size = config.hidden_size // self.num_heads
self.ffn_size = config.intermediate_size
self.vocab_size = config.vocab_size
# FIXME
self.max_position = 2048
self.max_position = 8192

def _get_param_size(self) -> int:
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
3 changes: 1 addition & 2 deletions cacheflow/models/opt.py
@@ -51,7 +51,7 @@ def __init__(
assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim**-0.5
self.scaling = self.head_dim ** -0.5

# TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
@@ -66,7 +66,6 @@ def __init__(
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)

self.attn = OPTCacheFlowAttention(scale=self.scaling)

def forward(
Expand Down
2 changes: 1 addition & 1 deletion cacheflow/models/sample.py
@@ -12,7 +12,7 @@
class Sampler(nn.Module):

def __init__(self) -> None:
super(Sampler, self).__init__()
super().__init__()

def forward(
self,
6 changes: 3 additions & 3 deletions csrc/cache_kernels.cu
@@ -122,13 +122,13 @@ void reshape_and_cache(
torch::Tensor& value_cache,
torch::Tensor& slot_mapping) {
int num_tokens = key.size(0);
int head_num = key.size(1);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);

dim3 grid(num_tokens);
dim3 block(std::min(head_num * head_size, 512));
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
key.scalar_type(),
@@ -140,7 +140,7 @@
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
head_num,
num_heads,
head_size,
block_size,
x);
16 changes: 16 additions & 0 deletions csrc/pos_encoding.cpp
@@ -0,0 +1,16 @@
#include <torch/extension.h>

void rotary_embedding_neox(
torch::Tensor& out_query,
torch::Tensor& out_key,
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
torch::Tensor& cos_sin_cache);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"rotary_embedding_neox",
&rotary_embedding_neox,
"Apply GPT-NeoX style rotary embedding to query and key");
}
83 changes: 83 additions & 0 deletions csrc/pos_encoding_kernels.cu
@@ -0,0 +1,83 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
scalar_t* __restrict__ out_query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ out_key, // [num_tokens, num_heads, head_size]
const int64_t* __restrict__ positions, // [num_tokens]
const scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
const int num_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;

const int embed_dim = head_size / 2;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int idx = token_idx * n + i;

const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int token_head = token_idx * n + head_idx * head_size;

const bool is_first_half = head_offset < embed_dim;
const int rot_offset = head_offset % embed_dim;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;

const scalar_t cos = __ldg(cache_ptr + x_index);
const scalar_t sin = __ldg(cache_ptr + y_index);

const scalar_t q_x = __ldg(query + token_head + x_index);
const scalar_t q_y = __ldg(query + token_head + y_index);
const scalar_t q_cos = is_first_half ? q_x : q_y;
const scalar_t q_sin = is_first_half ? -q_y : q_x;
out_query[idx] = q_cos * cos + q_sin * sin;

const scalar_t k_x = __ldg(key + token_head + x_index);
const scalar_t k_y = __ldg(key + token_head + y_index);
const scalar_t k_cos = is_first_half ? k_x : k_y;
const scalar_t k_sin = is_first_half ? -k_y : k_x;
out_key[idx] = k_cos * cos + k_sin * sin;
}
}

} // namespace cacheflow

void rotary_embedding_neox(
torch::Tensor& out_query, // [num_tokens, num_heads * head_size]
torch::Tensor& out_key, // [num_tokens, num_heads * head_size]
torch::Tensor& positions, // [num_tokens]
torch::Tensor& query, // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [num_tokens, num_heads * head_size]
torch::Tensor& cos_sin_cache) // [max_position, head_size]
{
int num_tokens = query.size(0);
int head_size = cos_sin_cache.size(1);
int num_heads = query.size(1) / head_size;

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
query.scalar_type(),
"rotary_embedding_neox",
[&] {
cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
out_query.data_ptr<scalar_t>(),
out_key.data_ptr<scalar_t>(),
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
num_heads,
head_size);
});
}
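
A minimal parity check for the new kernel against a pure-PyTorch reference (for example the rotary_embedding_neox_ref sketch shown earlier); shapes, dtype, and tolerances here are illustrative and assume a CUDA device with the extension built:

import torch
from cacheflow import pos_encoding_ops

num_tokens, num_heads, head_size, max_position, base = 7, 8, 128, 8192, 10000
dtype, device = torch.half, 'cuda'

# Build the cache exactly as LlamaCacheFlowAttention does.
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to(dtype=dtype, device=device)

positions = torch.randint(max_position, (num_tokens,), dtype=torch.int64, device=device)
query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device=device)
key = torch.randn_like(query)

# The op is out-of-place: the caller preallocates the output tensors.
out_query = torch.empty_like(query)
out_key = torch.empty_like(key)
pos_encoding_ops.rotary_embedding_neox(
    out_query, out_key, positions, query, key, cos_sin_cache)

ref_query, ref_key = rotary_embedding_neox_ref(
    positions, query, key, cos_sin_cache, num_heads, head_size)
assert torch.allclose(out_query, ref_query, atol=1e-2, rtol=1e-2)  # loose fp16 tolerance
assert torch.allclose(out_key, ref_key, atol=1e-2, rtol=1e-2)
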
8 changes: 8 additions & 0 deletions setup.py
@@ -23,6 +23,14 @@
)
ext_modules.append(attention_extension)

# Positional encodings.
positional_encoding_extension = cpp_extension.CUDAExtension(
name='cacheflow.pos_encoding_ops',
sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(positional_encoding_extension)

setuptools.setup(
name='cacheflow',
ext_modules=ext_modules,
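
After rebuilding the package (for example with pip install -e .), the new op should be importable the same way attention.py uses it; a quick smoke check, assuming the CUDA extension compiled successfully:

from cacheflow import pos_encoding_ops

# The binding registered in csrc/pos_encoding.cpp should be exposed.
assert hasattr(pos_encoding_ops, 'rotary_embedding_neox')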