Add ability to force BOS id for mBART (#22)
* Merge with main (#1)

* Update beam_search_topk_kernels.cu

fix: fix bug of beam search

* fix: change int to int64_t in some kernels to prevent overflow

* fix: gpt tensor shapes inconsistency (NVIDIA#505)

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* Update gpt_guide.md (NVIDIA#529)

* fix: fix bug of gpt buffer and gpt gemm overflow

* Update T5DecodingWeight.cc

fix: fix loading bug of t5

* [Enhancement]add pytorch backend support for gptneox (NVIDIA#550)

* add pytorch backend support for gptneox

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* fix invalid early stopping

* 1) Some unused parameters and logic have been removed. 2) Revisions that would affect pipeline parallelism have been reverted. 3) The code has been made capable of direct validation on TabbyML/NeoX-1.3B.

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* Rename classes, removing 'parallel' from their names

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* Format the code.

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* Only print results when rank is 0.

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* Add dist.init_process_group().

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* update docs

Signed-off-by: AkiyamaYummy <842720660@qq.com>

---------

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* Update cublasMMWrapper.cc

Fix the CUBLAS_VERSION check in cublasMMWrapper

* Update cublasMMWrapper.cc

* fix overflow in softmax_kernel when processing long sequence lengths and large batch sizes (NVIDIA#524)

* Update unfused_attention_kernels.cu

fix bug of softmax kernel

* [Enhancement]create huggingface_gptneox_convert.py (NVIDIA#569)

* create huggingface_gptneox_convert.py

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* adjust handling of HF's multiple .bin files

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* update gptneox_guide.md

Signed-off-by: AkiyamaYummy <842720660@qq.com>

---------

Signed-off-by: AkiyamaYummy <842720660@qq.com>

* perf(bloom): improve performance of huggingface_bloom_convert.py, reducing time cost and memory usage (NVIDIA#568)

Co-authored-by: r.yang <r.yang@tianrang-inc.com>

* Fix/gpt early stop (NVIDIA#584)

* fix: fix bug of early stopping of gpt

* [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (NVIDIA#672)

FasterTransformer 2-shot all reduce is implemented as a reduce-scatter + all-gather. There is an indexing bug in the all-gather step. Prior to this change, 2-shot all reduce was only producing correct results on device 0. Now, all devices have the correct results.

* fix: swap tensor bug (NVIDIA#683)

* Support size_per_head=112 (NVIDIA#660)

* fix multi-gpu build

* add support for size_per_head=112 for gpt decoder

* remove mpi_cxx from multi-gpu build for now (NVIDIA#705)

---------

Signed-off-by: AkiyamaYummy <842720660@qq.com>
Co-authored-by: byshiue <bhsueh@nvidia.com>
Co-authored-by: _yummy_ <842720660@qq.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: zhangxin81 <115389973+zhangxin81@users.noreply.github.com>
Co-authored-by: 杨睿 <595403043@qq.com>
Co-authored-by: r.yang <r.yang@tianrang-inc.com>
Co-authored-by: Rahul Kindi <rkindi@users.noreply.github.com>
Co-authored-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
Co-authored-by: Daya Khudia <37562707+dskhudia@users.noreply.github.com>
Co-authored-by: Dean Wyatte <2512762+dwyatte@users.noreply.github.com>

* commit

---------

Signed-off-by: AkiyamaYummy <842720660@qq.com>
Co-authored-by: Asim Shankar <asim.shankar@snowflake.com>
Co-authored-by: byshiue <bhsueh@nvidia.com>
Co-authored-by: _yummy_ <842720660@qq.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: zhangxin81 <115389973+zhangxin81@users.noreply.github.com>
Co-authored-by: 杨睿 <595403043@qq.com>
Co-authored-by: r.yang <r.yang@tianrang-inc.com>
Co-authored-by: Rahul Kindi <rkindi@users.noreply.github.com>
Co-authored-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
Co-authored-by: Daya Khudia <37562707+dskhudia@users.noreply.github.com>
Co-authored-by: Dean Wyatte <2512762+dwyatte@users.noreply.github.com>
12 people authored Oct 5, 2023
1 parent 3336e68 commit e0b124a
Showing 4 changed files with 68 additions and 2 deletions.
28 changes: 28 additions & 0 deletions src/fastertransformer/kernels/decoding_kernels.cu
@@ -64,6 +64,34 @@ void invokeDecodingInitialize(bool* finished,
finished, sequence_length, word_ids, cum_log_probs, sentence_ids, batch_size, beam_width, max_input_length);
}

__global__ void forceId(int* word_ids,
const int* force_bos_ids,
const int batch_size,
const int beam_width,
const int step)
{
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width;
index += blockDim.x * gridDim.x) {
if (word_ids != nullptr) {
word_ids[index+step*batch_size*beam_width] = force_bos_ids[index / beam_width];
}
}
}

void invokeForceId(int* word_ids,
const int* force_bos_ids,
const int batch_size,
const int beam_width,
const int step,
cudaStream_t stream)
{
dim3 grid((int)ceil(batch_size * beam_width * 1.0 / 256));
dim3 block(256);

forceId<<<grid, block, 0, stream>>>(
word_ids, force_bos_ids, batch_size, beam_width, step);
}

template void invokeDecodingInitialize(bool* finished,
int* sequence_length,
int* word_ids,
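
For reference, the sketch below replicates the new kernel in a self-contained form that can be compiled with plain nvcc outside the FasterTransformer build, purely to illustrate the indexing: every beam of batch entry b at the given step receives force_bos_ids[b]. The name forceIdDemo, the launch shape, and the token id values are illustrative, not part of the commit; in-tree callers should go through invokeForceId().

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Local replica of the committed kernel, kept equivalent in logic.
__global__ void forceIdDemo(int* word_ids, const int* force_bos_ids,
                            const int batch_size, const int beam_width, const int step)
{
    // Grid-stride loop over all (batch, beam) slots of the current step.
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width;
         index += blockDim.x * gridDim.x) {
        if (word_ids != nullptr) {
            // word_ids is laid out as [step][batch][beam]; all beams of batch b share one forced id.
            word_ids[index + step * batch_size * beam_width] = force_bos_ids[index / beam_width];
        }
    }
}

int main()
{
    const int batch_size = 2, beam_width = 3, step = 1, max_seq_len = 4;
    std::vector<int> h_ids(max_seq_len * batch_size * beam_width, -1);
    std::vector<int> h_bos = {11, 22};  // one forced BOS id per batch entry (illustrative values)

    int *d_ids, *d_bos;
    cudaMalloc(&d_ids, h_ids.size() * sizeof(int));
    cudaMalloc(&d_bos, h_bos.size() * sizeof(int));
    cudaMemcpy(d_ids, h_ids.data(), h_ids.size() * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_bos, h_bos.data(), h_bos.size() * sizeof(int), cudaMemcpyHostToDevice);

    // Same launch shape as invokeForceId(): 256 threads per block, enough blocks to cover batch*beam.
    dim3 block(256), grid((batch_size * beam_width + 255) / 256);
    forceIdDemo<<<grid, block>>>(d_ids, d_bos, batch_size, beam_width, step);
    cudaMemcpy(h_ids.data(), d_ids, h_ids.size() * sizeof(int), cudaMemcpyDeviceToHost);

    // Only the step-1 slots are overwritten (11 11 11 22 22 22); all other steps stay -1.
    for (size_t i = 0; i < h_ids.size(); ++i) {
        printf("%d%c", h_ids[i], (i + 1) % (batch_size * beam_width) == 0 ? '\n' : ' ');
    }

    cudaFree(d_ids);
    cudaFree(d_bos);
    return 0;
}
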
7 changes: 7 additions & 0 deletions src/fastertransformer/kernels/decoding_kernels.h
@@ -33,6 +33,13 @@ void invokeDecodingInitialize(bool* finished,
const int max_input_length,
cudaStream_t stream);

void invokeForceId(int* word_ids,
const int* force_bos_ids,
const int batch_size,
const int beam_width,
const int step,
cudaStream_t stream);

// get token from all_ids at step, then lookup from the embedding table
// by the token
template<typename T>
30 changes: 30 additions & 0 deletions src/fastertransformer/models/bart/BartDecoding.cc
@@ -116,6 +116,7 @@ void BartDecoding<T>::allocateBuffer(

start_ids_buf_ = (int*)(allocator_->reMalloc(start_ids_buf_, sizeof(int) * batch_size, false));
end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false));
forced_bos_ids_buf_ = (int*)(allocator_->reMalloc(forced_bos_ids_buf_, sizeof(int) * batch_size, false));

output_ids_buf_ =
(int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * (max_seq_len + 1), false));
@@ -182,6 +183,7 @@ void BartDecoding<T>::freeBuffer()
allocator_->free((void**)(&tiled_encoder_sequence_length_));

allocator_->free((void**)(&start_ids_buf_));
allocator_->free((void**)(&forced_bos_ids_buf_));
allocator_->free((void**)(&end_ids_buf_));

allocator_->free((void**)(&output_ids_buf_));
@@ -343,6 +345,7 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
// stop_words_list [batch_size, 2, stop_words_length], optional
// bad_words_list [batch_size, 2, stop_words_length], optional
// start_id [batch_size] on cpu, optional
// forced_bos_id [batch_size] on cpu, optional
// end_id [batch_size] on cpu, optional
// runtime_top_k [1] or [batch_size] on cpu, optional, uint.
// runtime_top_p [1] or [batch_size] on cpu, optional, float.
@@ -382,6 +385,7 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
dynamic_decode_layer_->setup(batch_size, beam_width, &input_map);
handleOptArg(&input_map, "start_id", start_ids_buf_, start_id_, batch_size);
handleOptArg(&input_map, "end_id", end_ids_buf_, end_id_, batch_size);
handleOptArg(&input_map, "forced_bos_id", forced_bos_ids_buf_, -1, batch_size);
}

FT_CHECK_WITH_INFO(input_tensors->at("encoder_output").shape[2] == d_model_,
@@ -792,6 +796,32 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
dynamic_decode_output_tensors.insert(*t);
}
dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
if (step == 1 && input_tensors->isExist("forced_bos_id")) {
invokeForceId(output_ids_buf_,
forced_bos_ids_buf_,
batch_size,
beam_width,
step,
stream_);
sync_check_cuda_error();
}
// {
// for (auto t = dynamic_decode_output_tensors.begin(); t != dynamic_decode_output_tensors.end(); ++t) {
// printf("step: %d, t->first: %s\n", step, t->first.c_str());
// // printf("%s\n", t->second.toString().c_str());
// {
// int* buf;
// int st = t->second.size();
// buf = new int[st];
// cudaMemcpy(buf, t->second.data, sizeof(int) * t->second.size(), cudaMemcpyDeviceToHost);
// for (int i=0; i<st; i++) {
// printf("%d ", buf[i]);
// }
// printf("\n");
// }
// }
// printf("\n\n");
// }
}
}

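For callers, the new input is wired up like the existing optional "start_id" and "end_id" tensors. The following is a hedged sketch (not part of the commit) of how one might attach it to the input TensorMap before calling BartDecoding<T>::forward(); the include path and helper name are assumptions, and for mBART the forced id would typically be the target-language token:

#include <vector>
#include "src/fastertransformer/utils/Tensor.h"   // assumed include path within the repo

namespace ft = fastertransformer;

// Attach a per-batch forced BOS id; the buffer must stay alive until forward() returns.
void addForcedBosId(ft::TensorMap& input_tensors, std::vector<int>& forced_bos_ids)
{
    // Shape [batch_size], kept on the CPU, mirroring how "start_id" and "end_id" are passed.
    input_tensors.insert("forced_bos_id",
                         ft::Tensor{ft::MEMORY_CPU,
                                    ft::TYPE_INT32,
                                    {forced_bos_ids.size()},
                                    forced_bos_ids.data()});
}

When the tensor is present, the first decoding step overwrites the freshly sampled tokens with the per-batch forced id via invokeForceId(); when it is absent, decoding behaves exactly as before.
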
5 changes: 3 additions & 2 deletions src/fastertransformer/models/bart/BartDecoding.h
@@ -94,8 +94,9 @@ class BartDecoding: public BaseLayer {
bool* finished_buf_ = nullptr;
bool* h_finished_buf_ = nullptr;

int* start_ids_buf_ = nullptr;
int* end_ids_buf_ = nullptr;
int* start_ids_buf_ = nullptr;
int* forced_bos_ids_buf_ = nullptr;
int* end_ids_buf_ = nullptr;

T* key_cache_ = nullptr;
T* value_cache_ = nullptr;
