kv_attention.mojo
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
from math import isqrt
import mo
from max.tensor import Tensor, TensorShape
from max.graph import ops, Dim, Symbol, TensorType, _OpaqueType as OpaqueType
from max.graph.quantization import Float32Encoding, QuantizationEncoding
from kv_cache.types import (
ContiguousKVCache,
KVCacheLayout,
KVCacheStaticParams,
KVCacheKernelNames,
)
from max.serve.kv_cache.kernel_names import _kv_cache_kernel_names
from pipelines.nn import Linear
from pipelines.nn.attention import rope

@value
struct KVCacheOptimizedAttention[kv_params: KVCacheStaticParams]:
"""Attention block that supports the specialized ContiguousKVCache type."""
alias _kernel_names = _kv_cache_kernel_names[DType.float32, kv_params]()
# hyperparams
var n_heads: Int
var dim: Int
# weights/projections
var wqkv: Symbol
var wo: Linear

    def __call__(
self,
input: Symbol,
start_pos: Symbol,
freqs_cis: Symbol,
k_cache: Symbol,
v_cache: Symbol,
mask: Symbol,
) -> Tuple[Symbol, Symbol, Symbol]:
"""Constructs the forward pass for this attention block
input: Activations with shape (batch_size, seq_len, num_heads * head_dim)
start_pos: Scalar with index of starting token, effectively tracks
the number of entries in the cache.
freqs_cis: Positional frequencies tensor with shape
(seq_len, head_dim // 2, 2).
k_cache: Previously computed keys. This is a mo.opaque ContiguousKVCache object
with logical shape (batch, prev_seq_len, n_kv_heads, head_dim).
v_cache: Previously computed values. This is a mo.opaque ContiguousKVCache object
with logical shape (batch, prev_seq_len, n_kv_heads, head_dim).
"""
g = input.graph()
# extract shape characteristics of the input
batch_size, seq_len = input.shape()[0], input.shape()[1]
head_dim = g.scalar[DType.float32](kv_params.head_size)
# define opaque types for custom op outputs
# TODO give these guys actual values for num_kv_head and head_size
# We only use these types to get `id()`, and the actual value of this
# string is not used.
var k_cache_type = OpaqueType(
ContiguousKVCache[DType.float32, kv_params].id()
)
var v_cache_type = OpaqueType(
ContiguousKVCache[DType.float32, kv_params].id()
)
# reshape our rope positional frequencies
f_shape = ops.shape_of(freqs_cis)
new_f_shape = ops.stack(List[Symbol](f_shape[0], g.scalar(Int64(-1))))
freqs_cis_2d = ops.reshape(freqs_cis, new_f_shape)
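        # Fused QKV projection: computes the query activations and writes the
        # new key/value entries for this step directly into the KV cache.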
xq_type = input.type()
xq = ops.custom[self._kernel_names.fused_qkv_matmul_kernel](
List[Symbol](input, self.wqkv, k_cache, v_cache),
xq_type,
)
xq = xq.reshape(batch_size, seq_len, self.n_heads, kv_params.head_size)
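        # Apply rotary positional embeddings (RoPE) to the query and the
        # cached keys in a single fused kernel.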
xq = ops.custom[self._kernel_names.fused_qk_rope_kernel](
List[Symbol](xq, k_cache, freqs_cis_2d), xq.type()
)
@parameter
if kv_params.layout == KVCacheLayout.BHSD:
            # Flash attention shapes differ on CPU and GPU; we need to
            # transpose on CPU. This will be fixed by KERN-626.
xq = xq.swapaxes(1, 2)
# do flash attention
seq_len_sym = ops.shape_of(input)[1]
var attn_mask = attention_mask(
mask, start_pos, seq_len_sym, DType.float32
)
var output_type = xq.type()
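        # The last operand to the kernel is the 1 / sqrt(head_dim) attention scale.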
attn_out = ops.custom[self._kernel_names.flash_attention_kernel](
List[Symbol](xq, k_cache, v_cache, attn_mask, ops.rsqrt(head_dim)),
output_type,
)
# transpose hidden state to (batch_size, seq_len, num_heads * head_dim)
@parameter
if kv_params.layout == KVCacheLayout.BHSD:
            # Flash attention shapes differ on CPU and GPU; we need to
            # transpose on CPU. This will be fixed by KERN-626.
attn_out = attn_out.swapaxes(1, 2)
attn_out = attn_out.reshape(batch_size, seq_len, -1)
# final projection and return
return attn_out @ self.wo, k_cache, v_cache

def attention_mask(
mask: Symbol, start_pos: Symbol, seq_len: Symbol, activation_dtype: DType
) -> Symbol:
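    """Builds the attention mask for the current forward pass.

    Combines a causal (upper-triangular) mask over the new seq_len tokens with
    zeros over the start_pos previously cached positions, then replaces
    entries for padded tokens with large negative values according to `mask`.
    """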
g = start_pos.graph()
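    # Reshape to rank-0 scalars so they can be stacked into shape tensors below.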
seq_len = seq_len.reshape()
start_pos = start_pos.reshape()
# Mask out current sequence elements [i, j] where j > i with an
# upper-triangular matrix filled with -inf.
mask_val = ops.cast(
g.op(
"rmo.mo.broadcast_to",
List(
g.scalar(-10000, DType.float32),
ops.stack(List[Symbol](seq_len, seq_len)),
),
TensorType(
DType.float32,
"seq_len",
"seq_len",
),
),
activation_dtype,
)
new_mask = ops.band_part(
mask_val,
g.scalar[DType.int64](-1),
num_upper=g.scalar[DType.int64](0),
# Invert the mask from lower to upper.
exclude=True,
)
zeros = g.op(
"rmo.mo.broadcast_to",
List(
g.scalar(0, activation_dtype),
ops.stack(List[Symbol](seq_len, start_pos)),
),
TensorType(
activation_dtype,
"seq_len",
"start_pos",
),
)
full_seq_len = Dim("full_seq_len")
x = ops.concat(
List[Symbol](
zeros,
new_mask,
),
axis=1,
out_dim=full_seq_len,
)
    # In the above, x is a (seq_len, start_pos + seq_len) tensor of 0s with
    # the upper triangle mapped to -inf. To accommodate left padding, we
    # create a new tensor of -inf with the same shape as x and return the
    # values of this -inf tensor where a padded token is present and the
    # values of x where a valid token is present.
select_mask = g.op(
"rmo.mo.broadcast_to",
List(mask, ops.stack(List[Symbol](seq_len, ops.shape_of(x)[1]))),
TensorType(DType.bool, "seq_len", full_seq_len),
)
y = g.full[DType.float32](-10000.0, x.shape())
return ops.select(select_mask, x, y)