Fix shape mismatch on the masked_tokens param in decoder masked multi-head attention kernel. #773

Open · wants to merge 2 commits into base: main
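In the current kernel, the per-batch offset into masked_tokens is computed with memory_max_len, while the buffer callers actually pass is laid out as [batch_size, session_len]. Whenever the cache length differs from the session length, every batch row after the first is read at the wrong offset. Below is a minimal, self-contained sketch of that stride mismatch; the sizes are made up and only the two offset formulas mirror the diff.

#include <cstdio>
#include <vector>

int main()
{
    const int batch_size     = 2;
    const int session_len    = 8;  // masked_tokens rows: [batch_size, session_len]
    const int memory_max_len = 6;  // KV-cache length, can differ from session_len

    // masked_tokens is row-major with a row stride of session_len.
    std::vector<bool> masked_tokens(batch_size * session_len, false);
    masked_tokens[1 * session_len + 0] = true;  // mask token 0 of batch 1

    const int bi = 1, ti = 0;

    // Old offset (wrong when session_len != memory_max_len): lands inside row 0.
    const size_t bi_seq_len_offset     = bi * memory_max_len;
    // New offset (matches the buffer's actual row stride).
    const size_t bi_session_len_offset = bi * session_len;

    std::printf("old index reads %d, new index reads %d\n",
                (int)masked_tokens[bi_seq_len_offset + ti],
                (int)masked_tokens[bi_session_len_offset + ti]);
    // Prints: old index reads 0 (wrong row), new index reads 1 (correct mask).
    return 0;
}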
@@ -80,8 +80,10 @@ struct Multihead_attention_params_base {
     int batch_size = 0;
     // The beam width
     int beam_width = 0;
-    // The sequence length.
+    // The cache length.
     int memory_max_len = 0;
+    // The whole sequence length, which includes context and output.
+    int session_len = 0;
     // The number of heads (H).
     int num_heads = 0;
     // The hidden dimension per head (Dh).
@@ -1219,6 +1219,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;

     const size_t bi_seq_len_offset = bi * params.memory_max_len;
+    const size_t bi_session_len_offset = bi * params.session_len;

     // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
     int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
@@ -1515,7 +1516,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,

     for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
         const int ti_circ = ti % params.memory_max_len;
-        bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
+        bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_session_len_offset + ti];

         // The keys loaded from the key cache.
         K_vec_k k[K_VECS_PER_THREAD];
@@ -1627,7 +1628,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     float sum = 0.f;
     // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
     for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
-        bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
+        bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_session_len_offset + ti];
 #ifdef FP8_MHA
         float logit = 0.f;
         if (FP8_MHA_KERNEL) {
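After these two hunks the kernel carries both strides: the timestep is still wrapped with memory_max_len to address the circular K/V cache, while the padding-mask lookup uses the session-length stride, because masked_tokens is [batch_size, session_len]. A hedged sketch of just the mask lookup follows, with hypothetical standalone parameters standing in for params.* (the real code is inlined in the templated CUDA kernel).

#include <cstddef>

// Sketch: per-timestep mask lookup using the mask's own row stride.
// bi = batch index, ti = un-wrapped timestep, session_len = mask row length.
inline bool is_token_masked(const bool* masked_tokens,  // [batch_size, session_len]
                            int bi, int ti, int session_len)
{
    if (masked_tokens == nullptr) {
        return false;
    }
    const std::size_t bi_session_len_offset =
        static_cast<std::size_t>(bi) * session_len;
    return masked_tokens[bi_session_len_offset + ti];
}

Only the mask lookup switches strides; cache reads keep using ti % memory_max_len, as shown in the diff above.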
@@ -50,6 +50,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
                                         const int rotary_embedding_dim,
                                         const bool neox_rotary_style,
                                         const int memory_max_len,
+                                        const int session_len,
                                         const int* prefix_prompt_lengths,
                                         const int max_prefix_prompt_length,
                                         const int max_input_len,
@@ -105,6 +106,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
     params.batch_size = inference_batch_size;
     params.beam_width = beam_width;
     params.memory_max_len = memory_max_len;
+    params.session_len = session_len;
     params.prefix_prompt_lengths = prefix_prompt_lengths;
     params.max_prefix_prompt_length = max_prefix_prompt_length;
     params.length_per_sample = sequence_lengths;  // max_input_length + current output length
@@ -163,6 +165,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
         const int rotary_embedding_dim, \
         const bool neox_rotary_style, \
         const int memory_max_len, \
+        const int session_len, \
         const int* prefix_prompt_lengths, \
         const int max_prefix_prompt_length, \
         const int max_input_len, \
@@ -467,7 +470,7 @@ void DecoderSelfAttentionLayer<T>::forward(TensorMap* output_tens
     // finished [batch_size] (optional)
     // total_padding_tokens [batch_size] (optional)
     // max_input_length [1] on cpu (optional)
-    // masked_tokens [batch_size, memory_len], (optional)
+    // masked_tokens [batch_size, session_len], (optional)
     // cache_indirection [batch_size / beam_width, beam_width, memory_max_len] (optional)
     // d_prefix_prompt_lengths [batch_size] (optional)
     // max_prefix_prompt_length [1] on cpu (optional)
@@ -504,6 +507,7 @@ void DecoderSelfAttentionLayer<T>::forward(TensorMap* output_tens
     const int batch_size = input_tensors->at("input_query").shape[0];
     const int beam_width = cache_indir != nullptr ? input_tensors->at("cache_indirection").shape[1] : 1;
     const int memory_max_len = output_tensors->at("key_cache").shape[3];
+    const int session_len = masked_tokens != nullptr ? input_tensors->at("masked_tokens").shape[1] : 0;

     const int* d_prefix_prompt_lengths = input_tensors->getPtr<int>("d_prefix_prompt_lengths", nullptr);
     const int max_prefix_prompt_length = input_tensors->getVal<int>("max_prefix_prompt_length", 0);
@@ -596,6 +600,7 @@ void DecoderSelfAttentionLayer<T>::forward(TensorMap* output_tens
                                        rotary_embedding_dim_,
                                        neox_rotary_style_,
                                        memory_max_len,
+                                       session_len,
                                        d_prefix_prompt_lengths,
                                        max_prefix_prompt_length,
                                        input_tensors->getVal<int>("max_input_length", 0),
src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc: 10 changes (5 additions, 5 deletions)
@@ -158,7 +158,7 @@ void ParallelGpt<T>::allocateBuffer(size_t batch_size,
     parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * max_session_len, true));
     seq_limit_len_ = (uint32_t*)(allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false));
     tiled_masked_tokens_ =
-        (bool*)(allocator_->reMalloc(tiled_masked_tokens_, sizeof(bool) * batchxbeam * memory_len, true));
+        (bool*)(allocator_->reMalloc(tiled_masked_tokens_, sizeof(bool) * batchxbeam * max_session_len, true));

     context_decoder_input_buf_ = (T*)(allocator_->reMalloc(
         context_decoder_input_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false));
@@ -865,7 +865,7 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
     PUSH_RANGE("initialize output and parent ids");
     cudaMemsetAsync(output_ids_buf_, 0, sizeof(int) * batch_size * beam_width * session_len, stream_);
     cudaMemsetAsync(parent_ids_buf_, 0, sizeof(int) * batch_size * beam_width * session_len, stream_);
-    cudaMemsetAsync(tiled_masked_tokens_, false, sizeof(bool) * batch_size * beam_width * memory_len, stream_);
+    cudaMemsetAsync(tiled_masked_tokens_, false, sizeof(bool) * batch_size * beam_width * session_len, stream_);
     cudaMemsetAsync(tiled_total_padding_count_, 0, sizeof(int) * batch_size * beam_width, stream_);
     if (beam_width > 1) {
         cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * memory_len, stream_);
@@ -1180,7 +1180,7 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
     PUSH_RANGE("mask padding tokens");
     invokeMaskPaddingTokens(tiled_masked_tokens_,
                             input_tensors->at("input_lengths").getPtr<int>(),
-                            memory_len,
+                            session_len,
                             max_input_length,
                             initial_step,
                             batch_size,
@@ -1316,8 +1316,8 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
         {"masked_tokens",
          Tensor(MEMORY_GPU,
                 TYPE_BOOL,
-                {local_batch_size * beam_width, memory_len},
-                tiled_masked_tokens_ + id_offset * memory_len)}});
+                {local_batch_size * beam_width, session_len},
+                tiled_masked_tokens_ + id_offset * session_len)}});
     if (beam_width > 1) {
         decoder_input_tensors.insert({"cache_indirection",
                                       Tensor(MEMORY_GPU,
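On the ParallelGpt side, the same stride now governs the allocation, the memset, and the tensor slice handed to the decoder, so the {local_batch_size * beam_width, session_len} view always falls inside the backing buffer. The sketch below illustrates that invariant with made-up sizes; reMalloc and the Tensor class are replaced by plain calloc and an assert.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <cstdlib>

int main()
{
    const int batch_size = 4, beam_width = 2;
    const int session_len = 1024;  // was memory_len before this PR
    const std::size_t batchxbeam = std::size_t(batch_size) * beam_width;

    // Roughly mirrors allocateBuffer(): batchxbeam rows of session_len flags.
    bool* tiled_masked_tokens =
        static_cast<bool*>(std::calloc(batchxbeam * session_len, sizeof(bool)));

    // The per-microbatch slice passed to the decoder as a
    // {local_batch_size * beam_width, session_len} tensor.
    const std::size_t local_batch_size = 2, id_offset = 0;
    bool* slice = tiled_masked_tokens + id_offset * session_len;

    // Every row of the slice stays inside the allocation because both sides
    // use the same row stride (session_len).
    assert(slice + local_batch_size * beam_width * session_len
           <= tiled_masked_tokens + batchxbeam * session_len);
    std::printf("slice covers %zu rows of %d tokens\n",
                local_batch_size * beam_width, session_len);

    std::free(tiled_masked_tokens);
    return 0;
}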
@@ -269,7 +269,7 @@ void ParallelGptDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
     //      cache_indirection [local_batch_size / beam_width, beam_width, memory_len]
     //          Here, local_batch_size contains the beam_width, so local_batch_size / beam_width
     //          is real local_batch_size. (optional.)
-    //      masked_tokens [local_batch_size, memory_len]
+    //      masked_tokens [local_batch_size, session_len]
     //      linear_bias_slopes [head_num], optional

     // output tensors: