Add ability to force bos id for mbart #22

Merged
merged 190 commits on Oct 5, 2023
Changes from all commits
190 commits
743369a
Merge with main (#1)
sfc-gh-ashankar Jul 11, 2023
e095d10
commit
sfc-gh-zhwang Oct 3, 2023
0141b94
commit
sfc-gh-zhwang Oct 3, 2023
4553c67
commit
sfc-gh-zhwang Oct 3, 2023
dec08a8
commit
sfc-gh-zhwang Oct 3, 2023
a9d7564
commit
sfc-gh-zhwang Oct 3, 2023
dddb699
commit
sfc-gh-zhwang Oct 3, 2023
62a99c2
commit
sfc-gh-zhwang Oct 3, 2023
a469c03
commit
sfc-gh-zhwang Oct 3, 2023
4e115cb
commit
sfc-gh-zhwang Oct 3, 2023
fb0cb6c
commit
sfc-gh-zhwang Oct 3, 2023
fa580c3
commit
sfc-gh-zhwang Oct 3, 2023
adebee7
commit
sfc-gh-zhwang Oct 3, 2023
933199a
commit
sfc-gh-zhwang Oct 3, 2023
18a666d
commit
sfc-gh-zhwang Oct 3, 2023
5b3df49
commit
sfc-gh-zhwang Oct 3, 2023
cac44d6
commit
sfc-gh-zhwang Oct 3, 2023
8356c50
commit
sfc-gh-zhwang Oct 3, 2023
b7b1c67
commit
sfc-gh-zhwang Oct 3, 2023
ec6d344
commit
sfc-gh-zhwang Oct 3, 2023
359227c
commit
sfc-gh-zhwang Oct 3, 2023
8bdd5d3
commit
sfc-gh-zhwang Oct 3, 2023
4a283f3
commit
sfc-gh-zhwang Oct 3, 2023
ad70082
commit
sfc-gh-zhwang Oct 3, 2023
810c4a6
commit
sfc-gh-zhwang Oct 3, 2023
f2f3292
commit
sfc-gh-zhwang Oct 3, 2023
b13b755
commit
sfc-gh-zhwang Oct 3, 2023
3345d4b
commit
sfc-gh-zhwang Oct 3, 2023
adc510c
commit
sfc-gh-zhwang Oct 3, 2023
d3c0325
commit
sfc-gh-zhwang Oct 3, 2023
dc32a83
commit
sfc-gh-zhwang Oct 3, 2023
16acf40
commit
sfc-gh-zhwang Oct 3, 2023
e4c4eb8
commit
sfc-gh-zhwang Oct 3, 2023
8b46d8c
commit
sfc-gh-zhwang Oct 3, 2023
ecbb7c7
commit
sfc-gh-zhwang Oct 3, 2023
5f6c84f
commit
sfc-gh-zhwang Oct 3, 2023
73a2b8a
commit
sfc-gh-zhwang Oct 3, 2023
db9404a
commit
sfc-gh-zhwang Oct 3, 2023
cf66777
commit
sfc-gh-zhwang Oct 3, 2023
1b2982f
commit
sfc-gh-zhwang Oct 3, 2023
62d1080
commit
sfc-gh-zhwang Oct 3, 2023
167fc06
commit
sfc-gh-zhwang Oct 3, 2023
31e8baa
commit
sfc-gh-zhwang Oct 3, 2023
d2a4e07
commit
sfc-gh-zhwang Oct 3, 2023
44d1a82
commit
sfc-gh-zhwang Oct 3, 2023
6939cc0
commit
sfc-gh-zhwang Oct 3, 2023
593e323
commit
sfc-gh-zhwang Oct 3, 2023
c5f5ff1
commit
sfc-gh-zhwang Oct 3, 2023
22e3fe1
commit
sfc-gh-zhwang Oct 3, 2023
e88eed3
commit
sfc-gh-zhwang Oct 3, 2023
d9b6600
commit
sfc-gh-zhwang Oct 3, 2023
513c896
commit
sfc-gh-zhwang Oct 3, 2023
3171ba5
commit
sfc-gh-zhwang Oct 3, 2023
6644645
commit
sfc-gh-zhwang Oct 3, 2023
1a38e3b
commit
sfc-gh-zhwang Oct 3, 2023
bfe5a99
commit
sfc-gh-zhwang Oct 3, 2023
bc142d9
commit
sfc-gh-zhwang Oct 3, 2023
576dcae
commit
sfc-gh-zhwang Oct 3, 2023
22e3a95
commit
sfc-gh-zhwang Oct 3, 2023
1102ac4
commit
sfc-gh-zhwang Oct 3, 2023
5fe1c26
commit
sfc-gh-zhwang Oct 3, 2023
30963a3
commit
sfc-gh-zhwang Oct 3, 2023
8430816
commit
sfc-gh-zhwang Oct 3, 2023
05bb001
commit
sfc-gh-zhwang Oct 3, 2023
c846e6b
commit
sfc-gh-zhwang Oct 3, 2023
85ff2db
commit
sfc-gh-zhwang Oct 3, 2023
7ece020
commit
sfc-gh-zhwang Oct 3, 2023
e75f695
commit
sfc-gh-zhwang Oct 3, 2023
aff9753
commit
sfc-gh-zhwang Oct 3, 2023
eda4b15
commit
sfc-gh-zhwang Oct 3, 2023
e780880
commit
sfc-gh-zhwang Oct 3, 2023
1dfe406
commit
sfc-gh-zhwang Oct 3, 2023
0bbaf68
commit
sfc-gh-zhwang Oct 3, 2023
2f57ae2
commit
sfc-gh-zhwang Oct 3, 2023
5a9c22e
commit
sfc-gh-zhwang Oct 3, 2023
043d70a
commit
sfc-gh-zhwang Oct 3, 2023
37f299c
commit
sfc-gh-zhwang Oct 3, 2023
69cda55
commit
sfc-gh-zhwang Oct 3, 2023
4559ea5
commit
sfc-gh-zhwang Oct 3, 2023
fa58a40
commit
sfc-gh-zhwang Oct 3, 2023
a556d14
commit
sfc-gh-zhwang Oct 3, 2023
607b91b
commit
sfc-gh-zhwang Oct 3, 2023
5f9cb8c
commit
sfc-gh-zhwang Oct 3, 2023
2855a1d
commit
sfc-gh-zhwang Oct 3, 2023
0179a4d
commit
sfc-gh-zhwang Oct 3, 2023
30c7e0b
commit
sfc-gh-zhwang Oct 3, 2023
71e9d92
commit
sfc-gh-zhwang Oct 3, 2023
3f6c3f1
commit
sfc-gh-zhwang Oct 3, 2023
df4f09a
commit
sfc-gh-zhwang Oct 3, 2023
0bbbf28
commit
sfc-gh-zhwang Oct 3, 2023
1be605c
commit
sfc-gh-zhwang Oct 3, 2023
6ea635d
commit
sfc-gh-zhwang Oct 3, 2023
8a1dcb4
commit
sfc-gh-zhwang Oct 3, 2023
6f093ba
commit
sfc-gh-zhwang Oct 3, 2023
96733cc
commit
sfc-gh-zhwang Oct 3, 2023
1878627
commit
sfc-gh-zhwang Oct 4, 2023
c3168bb
commit
sfc-gh-zhwang Oct 4, 2023
ad7994d
commit
sfc-gh-zhwang Oct 4, 2023
82c70ec
commit
sfc-gh-zhwang Oct 4, 2023
308634b
commit
sfc-gh-zhwang Oct 4, 2023
d57b90a
commit
sfc-gh-zhwang Oct 4, 2023
144feed
commit
sfc-gh-zhwang Oct 4, 2023
5dca6c7
commit
sfc-gh-zhwang Oct 4, 2023
ddf451d
commit
sfc-gh-zhwang Oct 4, 2023
51968e4
commit
sfc-gh-zhwang Oct 4, 2023
ba1518c
commit
sfc-gh-zhwang Oct 4, 2023
06dc654
commit
sfc-gh-zhwang Oct 4, 2023
6d1f22b
commit
sfc-gh-zhwang Oct 4, 2023
7e3e8f5
commit
sfc-gh-zhwang Oct 4, 2023
e559a27
commit
sfc-gh-zhwang Oct 4, 2023
14fcd33
commit
sfc-gh-zhwang Oct 4, 2023
b3c6a3c
commit
sfc-gh-zhwang Oct 4, 2023
d35e6d7
commit
sfc-gh-zhwang Oct 4, 2023
26ca197
commit
sfc-gh-zhwang Oct 4, 2023
cf15d8c
commit
sfc-gh-zhwang Oct 4, 2023
b8bf161
commit
sfc-gh-zhwang Oct 4, 2023
402003e
commit
sfc-gh-zhwang Oct 4, 2023
d1b6ee4
commit
sfc-gh-zhwang Oct 4, 2023
5756704
commit
sfc-gh-zhwang Oct 4, 2023
690b601
commit
sfc-gh-zhwang Oct 4, 2023
66786ff
commit
sfc-gh-zhwang Oct 4, 2023
1386421
commit
sfc-gh-zhwang Oct 4, 2023
9130b59
commit
sfc-gh-zhwang Oct 4, 2023
be449be
commit
sfc-gh-zhwang Oct 4, 2023
0a11917
commit
sfc-gh-zhwang Oct 4, 2023
6475254
commit
sfc-gh-zhwang Oct 4, 2023
c6b429f
commit
sfc-gh-zhwang Oct 4, 2023
5f13792
commit
sfc-gh-zhwang Oct 4, 2023
2b76d11
commit
sfc-gh-zhwang Oct 4, 2023
6def8b7
commit
sfc-gh-zhwang Oct 4, 2023
9d6f5fa
commit
sfc-gh-zhwang Oct 4, 2023
afa44e0
commit
sfc-gh-zhwang Oct 4, 2023
644a16d
commit
sfc-gh-zhwang Oct 4, 2023
e15eb73
commit
sfc-gh-zhwang Oct 4, 2023
60c5491
commit
sfc-gh-zhwang Oct 4, 2023
f18b9bb
commit
sfc-gh-zhwang Oct 4, 2023
6c9b8af
commit
sfc-gh-zhwang Oct 4, 2023
bdbd3c1
commit
sfc-gh-zhwang Oct 4, 2023
5aa1cc7
commit
sfc-gh-zhwang Oct 4, 2023
402d3b4
commit
sfc-gh-zhwang Oct 4, 2023
ffc99b4
commit
sfc-gh-zhwang Oct 4, 2023
c134751
commit
sfc-gh-zhwang Oct 4, 2023
ce59bd3
commit
sfc-gh-zhwang Oct 4, 2023
dd844b1
commit
sfc-gh-zhwang Oct 4, 2023
5e5db2d
commit
sfc-gh-zhwang Oct 4, 2023
14f4c5c
commit
sfc-gh-zhwang Oct 4, 2023
3bf8f0b
commit
sfc-gh-zhwang Oct 4, 2023
0e395df
commit
sfc-gh-zhwang Oct 4, 2023
42decf9
commit
sfc-gh-zhwang Oct 4, 2023
9ca2847
commit
sfc-gh-zhwang Oct 4, 2023
c85348b
commit
sfc-gh-zhwang Oct 4, 2023
19eb92e
commit
sfc-gh-zhwang Oct 4, 2023
42a4959
commit
sfc-gh-zhwang Oct 4, 2023
93c09b8
commit
sfc-gh-zhwang Oct 4, 2023
ffbcbd7
commit
sfc-gh-zhwang Oct 4, 2023
bb48a5e
commit
sfc-gh-zhwang Oct 4, 2023
9fcff86
commit
sfc-gh-zhwang Oct 5, 2023
1cd8b24
commit
sfc-gh-zhwang Oct 5, 2023
45a879f
commit
sfc-gh-zhwang Oct 5, 2023
916588f
commit
sfc-gh-zhwang Oct 5, 2023
ed59157
commit
sfc-gh-zhwang Oct 5, 2023
dffe9c0
commit
sfc-gh-zhwang Oct 5, 2023
7396c9b
commit
sfc-gh-zhwang Oct 5, 2023
59cef67
commit
sfc-gh-zhwang Oct 5, 2023
1e36c9b
commit
sfc-gh-zhwang Oct 5, 2023
b97107f
commit
sfc-gh-zhwang Oct 5, 2023
5671a23
commit
sfc-gh-zhwang Oct 5, 2023
ea8c5b8
commit
sfc-gh-zhwang Oct 5, 2023
5b1b3ea
commit
sfc-gh-zhwang Oct 5, 2023
1323a79
commit
sfc-gh-zhwang Oct 5, 2023
b3c8f26
commit
sfc-gh-zhwang Oct 5, 2023
4eee2a5
commit
sfc-gh-zhwang Oct 5, 2023
63e1586
commit
sfc-gh-zhwang Oct 5, 2023
0794d49
commit
sfc-gh-zhwang Oct 5, 2023
4333b24
commit
sfc-gh-zhwang Oct 5, 2023
c9cd870
commit
sfc-gh-zhwang Oct 5, 2023
043661a
commit
sfc-gh-zhwang Oct 5, 2023
c1384a0
commit
sfc-gh-zhwang Oct 5, 2023
dfba6e5
commit
sfc-gh-zhwang Oct 5, 2023
6ac30f1
commit
sfc-gh-zhwang Oct 5, 2023
37ccba5
commit
sfc-gh-zhwang Oct 5, 2023
ff1966a
commit
sfc-gh-zhwang Oct 5, 2023
0353f25
commit
sfc-gh-zhwang Oct 5, 2023
1150022
commit
sfc-gh-zhwang Oct 5, 2023
d88b1dc
commit
sfc-gh-zhwang Oct 5, 2023
cebd483
commit
sfc-gh-zhwang Oct 5, 2023
79f1ca5
/opt/tritonserver/bin/tritonserver --model-repository=/models
sfc-gh-zhwang Oct 5, 2023
07bba5d
commit
sfc-gh-zhwang Oct 5, 2023
b474c78
commit
sfc-gh-zhwang Oct 5, 2023
e177bd4
commit
sfc-gh-zhwang Oct 5, 2023
2 changes: 1 addition & 1 deletion examples/cpp/bart/config.ini
@@ -17,7 +17,7 @@ repetition_penalty=1.0 ; Use for sampling
presence_penalty=0.0 ; Only one of repetition_penalty and presence_penalty are allowed.
len_penalty=0.0
beam_search_diversity_rate=0.0
request_batch_size=8 # determine by the request
request_batch_size=1 # determine by the request
request_output_len=32 # determine by the request

[encoder]
329 changes: 329 additions & 0 deletions examples/pytorch/bart/utils/huggingface_bart_convert.py

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions src/fastertransformer/kernels/decoding_kernels.cu
@@ -64,6 +64,34 @@ void invokeDecodingInitialize(bool* finished,
finished, sequence_length, word_ids, cum_log_probs, sentence_ids, batch_size, beam_width, max_input_length);
}

__global__ void forceId(int* word_ids,
const int* force_bos_ids,
const int batch_size,
const int beam_width,
const int step)
{
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width;
index += blockDim.x * gridDim.x) {
if (word_ids != nullptr) {
word_ids[index+step*batch_size*beam_width] = force_bos_ids[index / beam_width];
}
}
}

void invokeForceId(int* word_ids,
const int* force_bos_ids,
const int batch_size,
const int beam_width,
const int step,
cudaStream_t stream)
{
dim3 grid((int)ceil(batch_size * beam_width * 1.0 / 256));
dim3 block(256);

forceId<<<grid, block, 0, stream>>>(
word_ids, force_bos_ids, batch_size, beam_width, step);
}

template void invokeDecodingInitialize(bool* finished,
int* sequence_length,
int* word_ids,
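
For reference, the kernel above writes one token per (batch, beam) slot: at the given step, every beam of batch b receives force_bos_ids[b]. A minimal CPU sketch of the same indexing, not part of this PR, assuming word_ids is flattened with the step index outermost ([step][batch][beam]):

#include <vector>

// CPU reference for forceId: overwrite the ids produced at `step` with the
// per-batch forced BOS token.
void force_id_reference(std::vector<int>& word_ids,
                        const std::vector<int>& force_bos_ids,  // one id per batch entry
                        int batch_size,
                        int beam_width,
                        int step)
{
    for (int index = 0; index < batch_size * beam_width; ++index) {
        // index / beam_width recovers the batch id, so all beams of a batch
        // share the same forced token at this step.
        word_ids[index + step * batch_size * beam_width] = force_bos_ids[index / beam_width];
    }
}
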
7 changes: 7 additions & 0 deletions src/fastertransformer/kernels/decoding_kernels.h
@@ -33,6 +33,13 @@ void invokeDecodingInitialize(bool* finished,
const int max_input_length,
cudaStream_t stream);

void invokeForceId(int* word_ids,
const int* force_bos_ids,
const int batch_size,
const int beam_width,
const int step,
cudaStream_t stream);

// get token from all_ids at step, then lookup from the embedding table
// by the token
template<typename T>
36 changes: 17 additions & 19 deletions src/fastertransformer/models/bart/BartDecoder.cc
@@ -546,35 +546,33 @@ void BartDecoder<T>::forward(std::vector<Tensor>* outp
stream_);
}
sync_check_cuda_error();
// {
// {
// T* buf;
// int st = local_batch_size * d_model_;
// buf = new T[st];
// cudaMemcpy(buf, decoder_output, sizeof(T) * st, cudaMemcpyDeviceToHost);
// auto step_ptr = input_tensors->at(4).data;
// int step = ((int*)step_ptr)[0];
// if (step == 1) {
// printf("decoder_output at layer %d step %d\n", l, step);
// for (int i=0; i<50; i++) {
// printf("%f ", double(buf[i]));
// }
// printf("buf last: %f\n", double(buf[st-1]));
// printf("\n");
// }
// }
// }

if (isLastLayerParallelId(l) == true && pipeline_para_.rank_ != pipeline_para_.world_size_ - 1
&& pipeline_para_.world_size_ > 1) {
// ftNcclSend(decoder_output, local_batch_size * d_model_, pipeline_para_.rank_ + 1,
// pipeline_para_, stream_);

ftNcclSend(decoder_output + local_batch_size * d_model_ / tensor_para_.world_size_ * tensor_para_.rank_,
local_batch_size * d_model_ / tensor_para_.world_size_,
pipeline_para_.rank_ + 1,
pipeline_para_,
stream_);
}
// {
// T* buf;
// int st = local_batch_size * d_model_;
// buf = new T[st];
// cudaMemcpy(buf, decoder_output, sizeof(T) * st, cudaMemcpyDeviceToHost);
// auto step_ptr = input_tensors->at(4).data;
// int step = ((int*)step_ptr)[0];
// if (step == 1) {
// printf("decoder_output at layer %d step %d\n", l, step);
// for (int i=0; i<50; i++) {
// printf("%f ", double(buf[i]));
// }
// printf("buf last: %f\n", double(buf[st-1]));
// printf("\n");
// }
// }
}

if (is_free_buffer_after_forward_ == true) {
78 changes: 30 additions & 48 deletions src/fastertransformer/models/bart/BartDecoding.cc
@@ -116,6 +116,7 @@ void BartDecoding<T>::allocateBuffer(

start_ids_buf_ = (int*)(allocator_->reMalloc(start_ids_buf_, sizeof(int) * batch_size, false));
end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false));
forced_bos_ids_buf_ = (int*)(allocator_->reMalloc(forced_bos_ids_buf_, sizeof(int) * batch_size, false));

output_ids_buf_ =
(int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * (max_seq_len + 1), false));
@@ -182,6 +183,7 @@ void BartDecoding<T>::freeBuffer()
allocator_->free((void**)(&tiled_encoder_sequence_length_));

allocator_->free((void**)(&start_ids_buf_));
allocator_->free((void**)(&forced_bos_ids_buf_));
allocator_->free((void**)(&end_ids_buf_));

allocator_->free((void**)(&output_ids_buf_));
@@ -343,6 +345,7 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
// stop_words_list [batch_size, 2, stop_words_length], optional
// bad_words_list [batch_size, 2, stop_words_length], optional
// start_id [batch_size] on cpu, optional
// forced_bos_id [batch_size] on cpu, optional
// end_id [batch_size] on cpu, optional
// runtime_top_k [1] or [batch_size] on cpu, optional, uint.
// runtime_top_p [1] or [batch_size] on cpu, optional, float.
@@ -382,6 +385,7 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
dynamic_decode_layer_->setup(batch_size, beam_width, &input_map);
handleOptArg(&input_map, "start_id", start_ids_buf_, start_id_, batch_size);
handleOptArg(&input_map, "end_id", end_ids_buf_, end_id_, batch_size);
handleOptArg(&input_map, "forced_bos_id", forced_bos_ids_buf_, -1, batch_size);
}

FT_CHECK_WITH_INFO(input_tensors->at("encoder_output").shape[2] == d_model_,
@@ -734,18 +738,6 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
{"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_local_batch_size}},
{"is_initialize_random_table", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_initialize_random_table}}});

// {
// T* buf;
// int st = batch_size * beam_width * vocab_size_padded_;
// buf = new T[st];
// cudaMemcpy(buf, logits_buf_, sizeof(T) * st, cudaMemcpyDeviceToHost);
// printf("logits_buf_\n");
// for (int i=0; i<50; i++) {
// printf("%f ", double(buf[i]));
// }
// printf("buf last: %f\n", double(buf[st-1]));
// printf("\n");
// }
if (cache_indirections_[src_indir_idx] != nullptr) {
dynamic_decode_input_tensors.insert(
"src_cache_indirection",
@@ -803,19 +795,33 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
}
dynamic_decode_output_tensors.insert(*t);
}
// {
// int* buf;
// int st = batch_size * (max_seq_len+1);
// buf = new int[st];
// cudaMemcpy(buf, output_ids_buf_, sizeof(int) * st, cudaMemcpyDeviceToHost);
// printf("start_ids_buf_ before forward: %d\n", batch_size);
// for (int i=0; i<st; i++) {
// printf("%d ", buf[i]);
// }
// printf("\n");
// }

dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
if (step == 1 && input_tensors->isExist("forced_bos_id")) {
invokeForceId(output_ids_buf_,
forced_bos_ids_buf_,
batch_size,
beam_width,
step,
stream_);
sync_check_cuda_error();
}
// {
// for (auto t = dynamic_decode_output_tensors.begin(); t != dynamic_decode_output_tensors.end(); ++t) {
// printf("step: %d, t->first: %s\n", step, t->first.c_str());
// // printf("%s\n", t->second.toString().c_str());
// {
// int* buf;
// int st = t->second.size();
// buf = new int[st];
// cudaMemcpy(buf, t->second.data, sizeof(int) * t->second.size(), cudaMemcpyDeviceToHost);
// for (int i=0; i<st; i++) {
// printf("%d ", buf[i]);
// }
// printf("\n");
// }
// }
// printf("\n\n");
// }
}
}

@@ -944,19 +950,6 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
}
}

// {
// int* buf;
// int st = batch_size * (max_seq_len+1);
// buf = new int[st];
// cudaMemcpy(buf, output_ids_buf_, sizeof(int) * st, cudaMemcpyDeviceToHost);
// printf("output_ids_buf_ after finalize: %d\n", batch_size);
// for (int i=0; i<st; i++) {
// printf("%d ", buf[i]);
// }
// printf("\n");

// }

if (pipeline_para_.world_size_ > 1) {
ftNcclGroupStart();
if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) {
@@ -1023,17 +1016,6 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
// throw errors when detected
ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_);

// {
// int* buf;
// int st = 32;
// buf = new int[st];
// cudaMemcpy(buf, output_tensors->at("output_ids").data, sizeof(int) * st, cudaMemcpyDeviceToHost);
// printf("output_ids after finalize: %s %d\n", output_tensors->at("output_ids").toString().c_str(), batch_size);
// for (int i=0; i<st; i++) {
// printf("%d ", buf[i]);
// }
// printf("\n");
// }
if (is_free_buffer_after_forward_) {
freeBuffer();
}
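
On the caller side, the new optional input is the CPU int32 tensor "forced_bos_id" of shape [batch_size] listed in the comment block above. A hypothetical call-site sketch, not part of this PR; the helper name and the example token value are made up, while the Tensor/TensorMap usage follows the patterns already in this file:

#include <vector>

namespace ft = fastertransformer;

// Attach a per-batch forced BOS token before calling BartDecoding<T>::forward().
// `forced_bos_ids` must stay alive until forward() returns, since the tensor
// only wraps the caller's buffer.
void add_forced_bos(ft::TensorMap* input_tensors, std::vector<int>& forced_bos_ids)
{
    input_tensors->insert("forced_bos_id",
                          ft::Tensor{ft::MEMORY_CPU,
                                     ft::TYPE_INT32,
                                     {forced_bos_ids.size()},
                                     forced_bos_ids.data()});
}

// Example: force every sequence in the batch to start with one token id.
//   std::vector<int> ids(batch_size, 250004);  // 250004 is only an illustrative mBART language id
//   add_forced_bos(&input_tensors, ids);
// With the tensor present, step 1 of decoding overwrites the just-sampled ids
// with the forced token via invokeForceId, as the code above shows.
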
5 changes: 3 additions & 2 deletions src/fastertransformer/models/bart/BartDecoding.h
@@ -94,8 +94,9 @@ class BartDecoding: public BaseLayer {
bool* finished_buf_ = nullptr;
bool* h_finished_buf_ = nullptr;

int* start_ids_buf_ = nullptr;
int* end_ids_buf_ = nullptr;
int* start_ids_buf_ = nullptr;
int* forced_bos_ids_buf_ = nullptr;
int* end_ids_buf_ = nullptr;

T* key_cache_ = nullptr;
T* value_cache_ = nullptr;
27 changes: 22 additions & 5 deletions src/fastertransformer/models/bart/BartDecodingWeight.cc
@@ -265,11 +265,28 @@ void BartDecodingWeight<T>::loadModel(std::string dir_path)
loadWeightFromBin<T>(
weights_ptr[3], {(size_t)weights_size[3]}, dir_path + "/decoder.final_layer_norm.weight.bin", model_file_type);
if (bart_with_bias) {
loadWeightFromBin<T>(weights_ptr[4],
{(size_t)weights_size[4]},
dir_path + "/decoder.final_layer_norm.bias.bin",
model_file_type);
loadWeightFromBin<T>(weights_ptr[5], {(size_t)weights_size[5]}, dir_path + "/decoder.final_logits_bias.bin", model_file_type);
if (mbart) {
loadWeightFromBin<T>(weights_ptr[4],
{(size_t)weights_size[4]},
dir_path + "/decoder.layer_norm.weight.bin",
model_file_type);
loadWeightFromBin<T>(weights_ptr[5],
{(size_t)weights_size[5]},
dir_path + "/decoder.final_layer_norm.bias.bin",
model_file_type);
loadWeightFromBin<T>(weights_ptr[6],
{(size_t)weights_size[6]},
dir_path + "/decoder.layer_norm.bias.bin",
model_file_type);
loadWeightFromBin<T>(weights_ptr[7], {(size_t)weights_size[7]}, dir_path + "/decoder.final_logits_bias.bin", model_file_type);

} else {
loadWeightFromBin<T>(weights_ptr[4],
{(size_t)weights_size[4]},
dir_path + "/decoder.final_layer_norm.bias.bin",
model_file_type);
loadWeightFromBin<T>(weights_ptr[5], {(size_t)weights_size[5]}, dir_path + "/decoder.final_logits_bias.bin", model_file_type);
}
}

for (int l = 0; l < num_layer_; l++) {
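
Relative to plain BART, the mbart branch above loads two extra decoder layer-norm files. A small sanity-check sketch, not part of this PR; the file names are copied verbatim from the loadWeightFromBin calls above, everything else is illustrative:

#include <fstream>
#include <string>
#include <vector>

// Returns true if a converted checkpoint directory contains the decoder-side
// files the mbart branch reads.
bool mbart_decoder_weights_present(const std::string& dir_path)
{
    const std::vector<std::string> files = {
        "/decoder.layer_norm.weight.bin",      // read only when mbart is set
        "/decoder.layer_norm.bias.bin",        // read only when mbart is set
        "/decoder.final_layer_norm.bias.bin",
        "/decoder.final_logits_bias.bin",
    };
    for (const auto& f : files) {
        if (!std::ifstream(dir_path + f).good()) {
            return false;
        }
    }
    return true;
}
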
4 changes: 2 additions & 2 deletions src/fastertransformer/models/bart/BartEncoder.cc
@@ -455,15 +455,15 @@ void BartEncoder<T>::forward(TensorMap* output_tensors,
// {
// T* buf;
// int batch_size = 1;
// int seq_len = 11;
// int seq_len = 13;
// int st = batch_size * seq_len * d_model_;
// printf("st: %d %d %d %d\n",batch_size, seq_len, d_model_, st);
// buf = new T[st];
// cudaMemcpy(buf, bart_encoder_emb_buf_, sizeof(T) * st, cudaMemcpyDeviceToHost);
// printf("bart_encoder_emb_buf_\n");
// for (int i=0; i < seq_len; i++) {
// for (int j=0; j<d_model_; j++) {
// printf("%f ", double(buf[i+j*seq_len]));
// printf("%f ", double(buf[i * d_model_+ j]));
// if (j > 10) {
// break;
// }
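
The corrected (still commented-out) debug print treats bart_encoder_emb_buf_ as a row-major [seq_len, d_model] buffer, so element (i, j) lives at i * d_model_ + j. A one-line sketch of that addressing, not part of this PR:

// Row-major [seq_len, d_model] addressing used by the corrected debug print.
inline int emb_index(int token_i, int dim_j, int d_model) { return token_i * d_model + dim_j; }
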
23 changes: 19 additions & 4 deletions src/fastertransformer/models/bart/BartEncoderWeight.cc
@@ -257,10 +257,25 @@ void BartEncoderWeight<T>::loadModel(std::string dir_path)
loadWeightFromBin<T>(
weights_ptr[2], {(size_t)weights_size[2]}, dir_path + "/encoder.final_layer_norm.weight.bin", model_file_type);
if (bart_with_bias) {
loadWeightFromBin<T>(weights_ptr[3],
{(size_t)weights_size[3]},
dir_path + "/encoder.final_layer_norm.bias.bin",
model_file_type);
if (mbart) {
loadWeightFromBin<T>(weights_ptr[3],
{(size_t)weights_size[3]},
dir_path + "/encoder.final_layer_norm.bias.bin",
model_file_type);
loadWeightFromBin<T>(weights_ptr[4],
{(size_t)weights_size[4]},
dir_path + "/encoder.layer_norm.weight.bin",
model_file_type);
loadWeightFromBin<T>(weights_ptr[5],
{(size_t)weights_size[5]},
dir_path + "/encoder.layer_norm.bias.bin",
model_file_type);
} else {
loadWeightFromBin<T>(weights_ptr[3],
{(size_t)weights_size[3]},
dir_path + "/encoder.final_layer_norm.bias.bin",
model_file_type);
}
}

for (int l = 0; l < num_layer_; l++) {
Expand Down
17 changes: 15 additions & 2 deletions src/fastertransformer/triton_backend/bart/BartTritonModel.cc
@@ -23,14 +23,15 @@ namespace ft = fastertransformer;

std::shared_ptr<AbstractTransformerModel> AbstractTransformerModel::createBartModel(std::string model_dir)
{
printf("createBartModel\n");
INIReader reader = INIReader(model_dir + "/config.ini");
if (reader.ParseError() < 0) {
std::cout << "[ERROR] Can't load '" << model_dir << "/config.ini"
<< "'\n";
return nullptr;
}

const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type");
const std::string data_type = "fp32"; // reader.Get("ft_instance_hyperparameter", "data_type");
if (data_type == "fp16") {
// return std::make_shared<BartTritonModel<half>>(reader, model_dir);
return std::make_shared<BartTritonModel<half>>(1, 1, 0, model_dir, 0);
@@ -60,6 +61,12 @@ BartTritonModel<T>::BartTritonModel(INIReader reader, std::string model_dir): mo
encoder_num_layer_ = reader.GetInteger("encoder", "num_layers");
encoder_vocab_size_ = reader.GetInteger("encoder", "vocab_size");
encoder_max_pos_seq_len_ = reader.GetInteger("encoder", "max_pos_seq_len");
mbart_para_ = reader.GetBoolean("encoder", "mbart", false);
if (mbart_para_) {
layernorm_type_ = ft::LayerNormType::pre_layernorm;
} else {
layernorm_type_ = ft::LayerNormType::post_layernorm;
}

// decoding
decoding_head_num_ = reader.GetInteger("decoder", "num_heads");
@@ -111,7 +118,13 @@ BartTritonModel<T>::BartTritonModel(size_t tensor_para_size,
encoder_vocab_size_ = reader.GetInteger("encoder", "vocab_size");
encoder_max_pos_seq_len_ =
reader.GetInteger("encoder", "max_pos_seq_len");

mbart_para_ = reader.GetBoolean("encoder", "mbart", false);
if (mbart_para_) {
layernorm_type_ = ft::LayerNormType::pre_layernorm;
} else {
layernorm_type_ = ft::LayerNormType::post_layernorm;
}

// decoding
decoding_head_num_ = reader.GetInteger("decoder", "num_heads");
decoding_size_per_head_ = reader.GetInteger("decoder", "d_kv");
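
Both constructors repeat the same flag handling, so the logic reads as follows; a condensed sketch, not a drop-in replacement, with a made-up helper name and the INIReader / ft::LayerNormType usage taken from the code above:

namespace ft = fastertransformer;

// "mbart" is an optional key in config.ini's [encoder] section; it defaults to
// false, i.e. plain BART, which keeps the post-layernorm path.
static ft::LayerNormType layernorm_from_config(const INIReader& reader)
{
    const bool mbart = reader.GetBoolean("encoder", "mbart", false);
    return mbart ? ft::LayerNormType::pre_layernorm : ft::LayerNormType::post_layernorm;
}
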
@@ -91,7 +91,6 @@ struct BartTritonModel: public AbstractTransformerModel {

// bart structure difference
bool bart_with_bias_ = true;
// TODO(zhwang): support mbart.
bool mbart_para_ = false;
bool use_gated_activation_ = false;
ft::PositionEmbeddingType position_embedding_type_ = ft::PositionEmbeddingType::absolute;