Commit 897cb2a

Optimize data movement (#20)

1 parent 1f01a18 commit 897cb2a

File tree: 17 files changed, +275 −135 lines

cacheflow/models/activation.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+import torch
+import torch.nn as nn
+
+from cacheflow import activation_ops
+
+
+class SiluAndMul(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        x: torch.Tensor,      # (num_tokens, 2 * d)
+    ) -> torch.Tensor:        # (num_tokens, d)
+        num_tokens = x.shape[0]
+        d = x.shape[1] // 2
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.silu_and_mul(out, x)
+        return out
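
For reference, the fused silu_and_mul op applies SiLU to the first half of the last dimension and multiplies it by the second half (the SwiGLU gating used below in llama.py). A minimal pure-PyTorch sketch of the same computation, useful as a correctness check; the ref_silu_and_mul name is ours, not part of the commit:

import torch
import torch.nn.functional as F

def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Reference for the fused kernel: silu(x[..., :d]) * x[..., d:].
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    return F.silu(gate) * up

# A (num_tokens, 2 * d) input yields a (num_tokens, d) output.
x = torch.randn(4, 2 * 8)
assert ref_silu_and_mul(x).shape == (4, 8)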

cacheflow/models/attention.py

Lines changed: 46 additions & 46 deletions

@@ -1,6 +1,6 @@
-from typing import List, Optional
+from typing import Optional

-from flash_attn.flash_attention import FlashAttention
+from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn

@@ -16,40 +16,38 @@ def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)

-        self.flash_attn = FlashAttention(softmax_scale=self.scale)
-
     def multi_query_kv_attention(
         self,
-        output: torch.Tensor,                  # [num_prompt_tokens, num_heads, head_size]
-        query: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
-        key: torch.Tensor,                     # [num_prompt_tokens, num_heads, head_size]
-        value: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
-        prompt_lens: List[int],
+        output: torch.Tensor,                  # [num_prompt_tokens, num_heads, head_size]
+        query: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
+        key: torch.Tensor,                     # [num_prompt_tokens, num_heads, head_size]
+        value: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
+        cumulative_prompt_lens: torch.Tensor,  # [num_prompts + 1]
+        max_prompt_len: int,
     ) -> None:
         if query.dtype == torch.float:
             raise ValueError('The float data type is not supported by '
                              'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[2]
+        head_size = query.shape[-1]
         if head_size > 128:
             raise ValueError('FlashAttention does not support head_size > 128.')

-        device = query.device
-        prefix_sum = [0]
-        for prompt_len in prompt_lens:
-            prefix_sum.append(prefix_sum[-1] + prompt_len)
-        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
-        max_prompt_len = max(prompt_lens)
-
-        # FIXME(woosuk): Unnecessary copy. Optimize this.
-        qkv = torch.stack([query, key, value], dim=1)
-        out = self.flash_attn(
-            qkv,
-            cu_seqlens=prefix_sum,
-            max_s=max_prompt_len,
+        # Directly call FlashAttention's internal function to avoid allocating
+        # a new tensor for the output.
+        _flash_attn_forward(
+            query,
+            key,
+            value,
+            output,
+            cumulative_prompt_lens,
+            cumulative_prompt_lens,
+            max_prompt_len,
+            max_prompt_len,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
             causal=True,
-        )[0]
-        # FIXME(woosuk): Unnecessary copy. Optimize this.
-        output.copy_(out, non_blocking=True)
+            return_softmax=False,
+        )

     def single_query_cached_kv_attention(
         self,
@@ -90,21 +88,18 @@ def forward(
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
-        # Pre-allocate the output tensor.
-        output = torch.empty_like(query)
-
-        # Prune out paddings if any.
-        query = query[:input_metadata.num_valid_tokens]
-        key = key[:input_metadata.num_valid_tokens]
-        value = value[:input_metadata.num_valid_tokens]
+        # NOTE: The query, key, and value tensors must be sliced from a qkv
+        # tensor of shape [num_tokens, 3 * num_heads * head_size].

-        # Reshape the input tensors.
+        # Reshape the query, key, and value tensors.
         num_heads = value_cache.shape[1]
         head_size = value_cache.shape[2]
         query = query.view(-1, num_heads, head_size)
         key = key.view(-1, num_heads, head_size)
         value = value.view(-1, num_heads, head_size)
-        output = output.view(-1, num_heads, head_size)
+
+        # Pre-allocate the output tensor.
+        output = torch.empty_like(query)

         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
@@ -114,22 +109,31 @@ def forward(
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.prompt_lens,
+                input_metadata.cumulative_prompt_lens,
+                input_metadata.max_prompt_len,
             )

         # Wait until the cache op is done.
         if cache_event is not None:
             cache_event.wait()

         # Reshape the keys and values and store them in the cache.
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, input_metadata.slot_mapping)
+        num_valid_tokens = input_metadata.num_valid_tokens
+        if num_valid_tokens > 0:
+            # The stride is 3 because the key and value are sliced from qkv.
+            cache_ops.reshape_and_cache(
+                key[:num_valid_tokens],
+                value[:num_valid_tokens],
+                key_cache,
+                value_cache,
+                input_metadata.slot_mapping,
+            )

         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
-                output[num_prompt_tokens:],
-                query[num_prompt_tokens:],
+                output[num_prompt_tokens:num_valid_tokens],
+                query[num_prompt_tokens:num_valid_tokens],
                 key_cache,
                 value_cache,
                 input_metadata)
@@ -186,19 +190,15 @@ def forward(
    ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
-        out_query = torch.empty_like(query)
-        out_key = torch.empty_like(key)
        pos_encoding_ops.rotary_embedding_neox(
-            out_query,
-            out_key,
            positions,
            query,
            key,
            self.cos_sin_cache,
        )
        return super().forward(
-            out_query,
-            out_key,
+            query,
+            key,
            value,
            key_cache,
            value_cache,
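
The prompt tokens are packed back to back along the token dimension, so a [num_prompts + 1] tensor of cumulative prompt lengths is enough to tell FlashAttention where each prompt begins and ends, and the kernel can write straight into the pre-allocated output slice instead of returning a new tensor that would have to be copied back. A small sketch of that packed layout; the sizes and tensor names here are illustrative, not taken from the commit:

import torch

# Hypothetical batch of three prompts packed along the token dimension.
prompt_lens = [2, 5, 3]
cumulative_prompt_lens = torch.tensor([0, 2, 7, 10], dtype=torch.int)

num_heads, head_size = 4, 64
packed_q = torch.randn(sum(prompt_lens), num_heads, head_size)

# Rows belonging to prompt i are packed_q[cu[i]:cu[i + 1]].
for i, prompt_len in enumerate(prompt_lens):
    start = int(cumulative_prompt_lens[i])
    end = int(cumulative_prompt_lens[i + 1])
    assert end - start == prompt_len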

cacheflow/models/input_metadata.py

Lines changed: 5 additions & 0 deletions

@@ -12,6 +12,7 @@ def __init__(
         seq_groups: List[Tuple[List[int], SamplingParams]],
         seq_logprobs: Dict[int, float],     # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
+        cumulative_prompt_lens: torch.Tensor,
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,
@@ -20,13 +21,15 @@ def __init__(
         self.seq_groups = seq_groups
         self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
+        self.cumulative_prompt_lens = cumulative_prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables

         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
+        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
@@ -40,11 +43,13 @@ def __repr__(self) -> str:
         return (f'InputMetadata('
                 f'num_prompts={self.num_prompts}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'max_prompt_len={self.max_prompt_len}, '
                 f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                 f'max_context_len={self.max_context_len}), '
                 f'prompt_lens={self.prompt_lens}, '
+                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
                 f'slot_mapping={self.slot_mapping}, '
                 f'context_lens={self.context_lens}, '
                 f'block_tables={self.block_tables})')

cacheflow/models/llama.py

Lines changed: 7 additions & 12 deletions

@@ -11,6 +11,7 @@
 from transformers import LlamaConfig

 from cacheflow.models import InputMetadata
+from cacheflow.models.activation import SiluAndMul
 from cacheflow.models.attention import LlamaCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
@@ -39,16 +40,14 @@ def __init__(
         self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                            bias=False, input_is_parallel=True,
                                            perform_initialization=False)
-        assert hidden_act == 'silu'
-        self.act_fn = nn.SiLU()
+        if hidden_act != 'silu':
+            raise ValueError(f'Unsupported activation: {hidden_act}. '
+                             'Only silu is supported for now.')
+        self.act_fn = SiluAndMul()

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
-        gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
-        gate, up = torch.split(gate_up, 1, dim=-2)
-        gate = gate.squeeze(dim=-2).contiguous()
-        up = up.squeeze(dim=-2).contiguous()
-        x = self.act_fn(gate) * up
+        x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x

@@ -94,11 +93,7 @@ def forward(
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
-        q, k, v = torch.split(qkv, 1, dim=-2)
-        q = q.squeeze(dim=-2).contiguous()
-        k = k.squeeze(dim=-2).contiguous()
-        v = v.squeeze(dim=-2).contiguous()
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(
             positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
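
The point of replacing the reshape/split/.contiguous() sequence with qkv.chunk(chunks=3, dim=-1) is that chunk returns views into the fused projection output rather than three freshly copied tensors; downstream code reads them in place (hence the "stride is 3" note in attention.py). A quick sketch of that behavior with made-up sizes:

import torch

num_tokens, num_heads, head_size = 10, 4, 64
qkv = torch.randn(num_tokens, 3 * num_heads * head_size)

q, k, v = qkv.chunk(chunks=3, dim=-1)

# The chunks are views into qkv's storage, not copies ...
assert q.data_ptr() == qkv.data_ptr()
# ... so slicing along the last dimension leaves them non-contiguous,
# with a token stride of 3 * num_heads * head_size elements.
assert not q.is_contiguous()
assert q.stride(0) == 3 * num_heads * head_size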

cacheflow/models/opt.py

Lines changed: 2 additions & 5 deletions

@@ -69,17 +69,14 @@ def forward(
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
-        q, k, v = torch.split(qkv, 1, dim=-2)
-        q = q.squeeze(dim=-2).contiguous()
-        k = k.squeeze(dim=-2).contiguous()
-        v = v.squeeze(dim=-2).contiguous()
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(
             q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output, _ = self.out_proj(attn_output)
         return output

+

 class OPTDecoderLayer(nn.Module):

     def __init__(self, config: OPTConfig):

cacheflow/worker/worker.py

Lines changed: 8 additions & 0 deletions

@@ -128,6 +128,11 @@ def prepare_inputs(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+        cumulative_prompt_lens: List[int] = [0]
+        for prompt_len in prompt_lens:
+            cumulative_prompt_lens.append(
+                cumulative_prompt_lens[-1] + prompt_len)
+
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -183,11 +188,14 @@ def prepare_inputs(
                                for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')
+        cumulative_prompt_lens_tensor = torch.tensor(
+            cumulative_prompt_lens, dtype=torch.int, device='cuda')

         input_metadata = InputMetadata(
             seq_groups=seq_groups,
             seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
+            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
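
The prefix-sum loop above turns per-prompt lengths into the cumulative offsets FlashAttention expects: prompt_lens = [3, 5, 2] yields [0, 3, 8, 10]. An equivalent formulation using itertools.accumulate, shown only to illustrate the arithmetic, not as how the worker is written:

from itertools import accumulate
from typing import List

def cumulative_lens(prompt_lens: List[int]) -> List[int]:
    # Entry i marks where prompt i starts in the packed token
    # dimension; the last entry is the total number of prompt tokens.
    return [0] + list(accumulate(prompt_lens))

assert cumulative_lens([3, 5, 2]) == [0, 3, 8, 10]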

csrc/activation.cpp

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+#include <torch/extension.h>
+
+void silu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "silu_and_mul",
+    &silu_and_mul,
+    "Activation function used in SwiGLU.");
+}

csrc/activation_kernels.cu

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+namespace cacheflow {
+
+template<typename T>
+__device__ __forceinline__ T silu(const T& x) {
+  // x * sigmoid(x)
+  return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
+template<typename scalar_t>
+__global__ void silu_and_mul_kernel(
+  scalar_t* __restrict__ out,          // [num_tokens, d]
+  const scalar_t* __restrict__ input,  // [num_tokens, 2, d]
+  const int d) {
+  const int token_idx = blockIdx.x;
+  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
+    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
+    out[token_idx * d + idx] = silu(x) * y;
+  }
+}
+
+} // namespace cacheflow
+
+void silu_and_mul(
+  torch::Tensor& out,    // [num_tokens, d]
+  torch::Tensor& input)  // [num_tokens, 2 * d]
+{
+  int num_tokens = input.size(0);
+  int d = input.size(1) / 2;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(d, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    input.scalar_type(),
+    "silu_and_mul_kernel",
+    [&] {
+      cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        d);
+    });
+}
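
The kernel launches one thread block per token (grid = num_tokens) and up to 1024 threads per block, each thread striding over the hidden dimension: it reads the gate element at flat offset token * 2d + i and the multiplicand at token * 2d + d + i. A small Python sketch of that flat indexing against the 2D view it assumes, purely illustrative:

import torch
import torch.nn.functional as F

num_tokens, d = 3, 5
inp = torch.randn(num_tokens, 2 * d)
flat = inp.reshape(-1)  # row-major, as the kernel sees the buffer

out = torch.empty(num_tokens, d)
for token_idx in range(num_tokens):
    for idx in range(d):
        x = flat[token_idx * 2 * d + idx]      # gate half
        y = flat[token_idx * 2 * d + d + idx]  # mul half
        out[token_idx, idx] = F.silu(x) * y

# Matches the 2D formulation used on the Python side.
expected = F.silu(inp[:, :d]) * inp[:, d:]
assert torch.allclose(out, expected)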
