/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include <algorithm>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
namespace tensorrt_llm
{
namespace kernels
{
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(call) \
do \
{ \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) \
{ \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
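// Usage sketch (illustrative; d_out, batch_size, hidden_size and stream are placeholders):
//
//   float* d_out = nullptr;
//   CHECK_CUDA(cudaMalloc(&d_out, batch_size * hidden_size * sizeof(float)));
//   CHECK_CUDA(cudaMemsetAsync(d_out, 0, batch_size * hidden_size * sizeof(float), stream));
//
// On failure, the macro prints the failing file/line and the CUDA error string, then exits.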
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Round up to next higher power of 2 (return x if it's already a power
/// of 2).
inline int pow2roundup(int x)
{
if (x < 0)
return 0;
--x;
x |= x >> 1;
x |= x >> 2;
x |= x >> 4;
x |= x >> 8;
x |= x >> 16;
return x + 1;
}
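// Examples: pow2roundup(1) == 1, pow2roundup(5) == 8, pow2roundup(64) == 64; non-positive inputs
// return 0.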
////////////////////////////////////////////////////////////////////////////////////////////////////
// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
template <typename T>
struct Multihead_attention_params_base
{
// The output buffer. Dimensions B x D.
void* out = nullptr;
// The input Qs and the associated bias. Dimensions B x D and D, resp.
T const *q = nullptr, *q_bias = nullptr;
// The input Ks and the associated bias. Dimensions B x D and D, resp.
T const *k = nullptr, *k_bias = nullptr;
// The input Vs and the associated bias. Dimensions B x D and D, resp.
T const *v = nullptr, *v_bias = nullptr;
// The indirections to use for cache when beam sampling.
int const* cache_indir = nullptr;
    // Quantization scales.
float const* query_weight_output_scale = nullptr;
float const* attention_qk_scale = nullptr;
float const* attention_output_weight_input_scale_inv = nullptr;
// Stride to handle the case when KQV is a single buffer
int stride = 0;
// The batch size.
int batch_size = 0;
    // The beam width.
int beam_width = 0;
// By default, max_attention_window_size == cyclic_attention_window_size
// unless each layer has different cyclic kv cache length.
// Max cache capacity (used to allocate KV cache)
int max_attention_window_size = 0;
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
int cyclic_attention_window_size = 0;
// Length of the sink token in KV cache
int sink_token_length = 0;
// The number of heads (H).
int num_heads = 0;
// Controls MHA/MQA/GQA
int num_kv_heads = 0;
// The hidden dimension per head (Dh).
int hidden_size_per_head = 0;
// Rotary position embedding type
PositionEmbeddingType position_embedding_type = PositionEmbeddingType::kLEARNED_ABSOLUTE;
// The per-head latent space reserved for rotary embeddings.
int rotary_embedding_dim = 0;
float rotary_embedding_base = 0.0f;
RotaryScalingType rotary_embedding_scale_type = RotaryScalingType::kNONE;
float rotary_embedding_scale = 1.0f;
// The pre-computed rotary inv freq when building the engines (as constant weights).
float const* rotary_embedding_inv_freq_cache = nullptr;
float rotary_embedding_short_m_scale = 1.0f;
float rotary_embedding_long_m_scale = 1.0f;
int rotary_embedding_max_positions = 0;
int rotary_embedding_original_max_positions = 0;
int rotary_cogvlm_vision_start = -1;
int rotary_cogvlm_vision_length = -1;
// Position shift for streamingllm
bool position_shift_enabled = false;
    // The current timestep. TODO: check whether this param is only needed for cross attention.
    int timestep = 0;
    // The current timestep of each sequence (supports a different timestep per sequence).
    // The 1.f / sqrt(Dh). Computed on the host.
float inv_sqrt_dh = 0.0f;
    // The tanh scale factor and its inverse (only used by Grok).
float qk_tanh_scale = 0.0f;
float qk_tanh_inverse_scale = 0.0f;
// The attention mask [batch_size, attention_mask_stride (i.e. max_kv_seqlen)]
bool const* attention_mask = nullptr;
int attention_mask_stride = 0;
// If relative position embedding is used
T const* relative_attention_bias = nullptr;
int relative_attention_bias_stride = 0;
int max_distance = 0;
// block sparse config
bool block_sparse_attention = false;
BlockSparseParams block_sparse_params{64, false, 16, 8};
// The slope per head of linear position bias to attention score (H).
T const* linear_bias_slopes = nullptr;
T const* ia3_key_weights = nullptr;
T const* ia3_value_weights = nullptr;
int const* ia3_tasks = nullptr;
float const* qkv_scale_quant_orig = nullptr;
float const* attention_out_scale_orig_quant = nullptr;
// 8 bits kv cache scales.
float const* kv_scale_orig_quant = nullptr;
float const* kv_scale_quant_orig = nullptr;
bool int8_kv_cache = false;
bool fp8_kv_cache = false;
    // Multi-block setup.
    mutable bool multi_block_mode = true;
    // Number of streaming multiprocessors on the device, used to tune the block size for maximum
    // occupancy.
    int multi_processor_count = 1;
mutable int timesteps_per_block = 1;
mutable int seq_len_tile = 1;
mutable int min_seq_len_tile = 1;
mutable int max_seq_len_tile = 1;
// The partial output buffer. Dimensions max_seq_len_tile x B x D. (for each timestep only seq_len_tile x B x D is
// needed)
T* partial_out = nullptr;
// ThreadBlock sum. Dimensions max_seq_len_tile x 1. (for each timestep only seq_len_tile x 1 is needed)
float* partial_sum = nullptr;
// ThreadBlock max. Dimensions max_seq_len_tile x 1. (for each timestep only seq_len_tile x 1 is needed)
float* partial_max = nullptr;
    // Threadblock counter used to identify the completion of the partial attention computations.
int* block_counter = nullptr;
int const* memory_length_per_sample = nullptr;
};
template <typename T, bool USE_CROSS_ATTENTION = false>
struct Multihead_attention_params;
// self-attention params
template <typename T>
struct Multihead_attention_params<T, false> : public Multihead_attention_params_base<T>
{
static constexpr bool DO_CROSS_ATTENTION = false;
int max_decoder_seq_len = 0;
    // Allows the attention kernel to exit early for finished sequences.
    bool* finished = nullptr;
    // Required for masked attention when sequences have different lengths.
    int const* length_per_sample = nullptr;
    // Input lengths to identify the padding (i.e. input seq < padding < newly generated seq).
int const* input_lengths = nullptr;
};
template <class T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;
// cross-attention params
template <typename T>
struct Multihead_attention_params<T, true> : public Multihead_attention_params_base<T>
{
static constexpr bool DO_CROSS_ATTENTION = true;
int max_decoder_seq_len = 0;
    // Allows the attention kernel to exit early for finished sequences.
    bool* finished = nullptr;
    // Required for masked attention when sequences have different lengths.
    int const* length_per_sample = nullptr;
    // Input lengths to identify the padding (i.e. input seq < padding < newly generated seq).
int const* input_lengths = nullptr;
};
template <class T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
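// Minimal population sketch for the self-attention params (illustrative only; the device pointers
// and shape values are placeholders a caller would take from its own attention layer):
//
//   Masked_multihead_attention_params<float> params;
//   params.out = d_out;                         // [B, D] output buffer
//   params.q = d_q;                             // [B, D] query activations
//   params.k = d_k;                             // [B, D] key activations
//   params.v = d_v;                             // [B, D] value activations
//   params.batch_size = 8;                      // B
//   params.beam_width = 1;                      // no beam search
//   params.num_heads = 32;                      // H
//   params.num_kv_heads = 32;                   // == num_heads for MHA; smaller for GQA/MQA
//   params.hidden_size_per_head = 128;          // Dh
//   params.max_attention_window_size = 4096;    // KV cache capacity
//   params.cyclic_attention_window_size = 4096;
//   params.timestep = step;                     // current decoding step
//   params.inv_sqrt_dh = 1.0f / sqrtf(128.0f);  // computed on the host, as documented above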
////////////////////////////////////////////////////////////////////////////////////////////////////
// Pre-check whether head size is supported when building engines.
bool mmha_supported(int head_size);
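// Engine-build-time guard sketch (illustrative): reject unsupported head sizes up front instead of
// failing inside the kernel dispatch.
//
//   if (!mmha_supported(hidden_size_per_head))
//   {
//       throw std::invalid_argument("Head size is not supported by the MMHA kernels.");
//   }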
#define DECLARE_MMHA_NORMAL_AND_PAGED(T) \
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params, \
const KVBlockArray& block_array, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params, \
const KVLinearBuffer& kv_cache_buffer, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
void masked_multihead_attention(const Cross_multihead_attention_params<T>& params, \
const KVBlockArray& block_array, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
void masked_multihead_attention(const Cross_multihead_attention_params<T>& params, \
const KVLinearBuffer& kv_cache_buffer, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream);
DECLARE_MMHA_NORMAL_AND_PAGED(float);
DECLARE_MMHA_NORMAL_AND_PAGED(uint16_t);
#ifdef ENABLE_BF16
DECLARE_MMHA_NORMAL_AND_PAGED(__nv_bfloat16);
#endif
#undef DECLARE_MMHA_NORMAL_AND_PAGED
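// Dispatch sketch (illustrative): the overload resolves on the KV-cache layout type. Construction
// of the KVBlockArray / KVLinearBuffer views is omitted here; see kvCacheUtils.h for those types.
//
//   Masked_multihead_attention_params<float> params;  // populated as in the sketch above
//   KVBlockArray paged_kv = /* paged KV-cache view */;
//   KVLinearBuffer shift_k_cache = /* only read when position_shift_enabled is set */;
//   masked_multihead_attention(params, paged_kv, shift_k_cache, stream);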
////////////////////////////////////////////////////////////////////////////////////////////////////
// Estimate the minimum number of multi-block tiles needed so that the per-timestep QK scratch data
// fits in dynamic shared memory.
template <typename T>
inline int estimate_min_multi_block_count(int max_timesteps, int max_dynamic_shmem_per_block)
{
    // One QK element covers 4 timesteps: ceil((max_timesteps + 1) / 4).
    auto const qk_elts = static_cast<int>((max_timesteps + 1 + 4 - 1) / 4);
    // 16 bytes of dynamic shared memory per element.
    int size_per_elts = 16;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
    if (sizeof(T) != 4)
    {
        // When logits are not accumulated in fp32, 4 extra T-sized values are stored per element.
        size_per_elts += 4 * sizeof(T);
    }
#endif
    int elts_per_block = max_dynamic_shmem_per_block / size_per_elts;
    int min_block_count = (qk_elts + elts_per_block - 1) / elts_per_block;
    return std::max(1, min_block_count);
}
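// Worked example (assuming MMHA_USE_FP32_ACUM_FOR_LOGITS is undefined and T = uint16_t, the fp16
// instantiation):
//
//   max_timesteps = 8192, max_dynamic_shmem_per_block = 48 * 1024 = 49152
//   qk_elts         = (8192 + 1 + 3) / 4 = 2049
//   size_per_elts   = 16 + 4 * sizeof(uint16_t) = 24 bytes
//   elts_per_block  = 49152 / 24 = 2048
//   min_block_count = ceil(2049 / 2048) = 2, i.e. at least two sequence-length tiles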
} // namespace kernels
} // namespace tensorrt_llm