Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LightSeq QAT #307

Merged
merged 132 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
132 commits
Select commit Hold shift + click to select a range
13ac633
ls embedding support qat
Apr 7, 2022
ace049a
[WIP]ls transformer qat
May 5, 2022
23041c2
fix fairseq transformer cli shape bug of output projection
godweiyang May 6, 2022
5c6ad24
ln_bw_i8 test passed!
godweiyang May 7, 2022
576de44
test with_mean of ln_i8
godweiyang May 7, 2022
178b774
update to the latest version of master, fix conflict
godweiyang May 7, 2022
2c6c30b
ls encoder attn add qat
May 9, 2022
580775c
dropout_relu_bias_i8 passed!
godweiyang May 10, 2022
dd1229b
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
godweiyang May 10, 2022
75eb4f4
dropout_gelu_bias unit test passed!
godweiyang May 16, 2022
e4b8e9c
dropout_relu_bias_bwd_i8 passed!
godweiyang May 16, 2022
0b0c890
Merge branch 'ls-qat-wy' into ls-qat
godweiyang May 16, 2022
6fab849
dropout_gelu_bias_bwd_i8 unit test passed!
godweiyang May 17, 2022
b5ab256
format
godweiyang May 17, 2022
f2f8401
dropout_gelu_bias_bwd_i8 unit test passed!
godweiyang May 17, 2022
59e5ebf
format
godweiyang May 17, 2022
d39f15c
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
godweiyang May 18, 2022
ffda943
polish unit test
godweiyang May 18, 2022
5370059
[WIP] ls encoder qat test
May 23, 2022
e66708c
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
May 23, 2022
20231b6
quant_bias_add_transform_20314, quant_transform4d_0213 unit test passed!
godweiyang May 23, 2022
991ca74
fix unit test bug
godweiyang May 26, 2022
a8af545
[WIP] ls encoder qat unit test
May 30, 2022
e03da8a
fix bug
godweiyang May 30, 2022
238f717
set default module to disable quant, fix bugs in examples
godweiyang May 30, 2022
1e035a8
fix encoder bug
Jun 1, 2022
b6ec156
encoder qat test pass
Jun 6, 2022
f934ca0
decoder qat forward test pass
Jun 8, 2022
3d15555
fix bug in encoder bw
godweiyang Jun 9, 2022
a98e6b5
fix conflict
godweiyang Jun 9, 2022
7d06ffc
fix bug of cmax grad
godweiyang Jun 10, 2022
0a8fc46
fix bug of act mask
godweiyang Jun 10, 2022
89208b1
fix bug in tensor quantizer
godweiyang Jun 10, 2022
09a3d10
fix cmax grad bug
godweiyang Jun 13, 2022
07e376d
[WIP] decoder support qat
Jun 14, 2022
c57b398
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
Jun 14, 2022
a380b72
ls decoder qat pass
Jun 20, 2022
09f6bf5
ls encoder qat pass
Jun 22, 2022
4e2b257
add unit test for quant bert encoder
godweiyang Jul 5, 2022
956f681
fix memory bug
godweiyang Jul 11, 2022
1c9ddb4
fix cmax grad bug in huggingface
godweiyang Jul 12, 2022
db9d160
quant bert enc fw&bw test passed!
godweiyang Jul 14, 2022
e6338f7
fix hf cmax export bug
godweiyang Jul 14, 2022
d508f62
fix fairseq out_proj bug
godweiyang Jul 14, 2022
c641567
fix fairseq shell bug
godweiyang Jul 14, 2022
60a368e
fix decoder mem bug
godweiyang Jul 15, 2022
b18995f
modify initial lr of fairseq quant training
godweiyang Jul 15, 2022
04d1291
decoupled qat code
Jul 18, 2022
972a54d
modify huggingface training scripts
godweiyang Jul 18, 2022
7ea40bd
decoupled qat code
godweiyang Jul 18, 2022
1b117f0
add cmax grad
Jul 18, 2022
1859341
delete enc_kv output quant
godweiyang Jul 18, 2022
7d4a8cd
modify ffn2gemm quant like inference
godweiyang Jul 20, 2022
b569eb7
fuse dequantize
Jul 20, 2022
94ffe31
fix post ln mem bug
godweiyang Jul 20, 2022
d596b29
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
godweiyang Jul 20, 2022
eb6c59a
add decoder self attn qkv cache quant
godweiyang Jul 21, 2022
9cd64f5
export quant model (stage 1)
godweiyang Jul 21, 2022
b5ab18c
export quant model (stage 2)
godweiyang Jul 22, 2022
42607a1
export quant model (stage 3)
godweiyang Jul 22, 2022
8358e48
support vit quant train
godweiyang Jul 26, 2022
5d40f93
add gradient clip
Jul 26, 2022
725f112
fix hf export bug
godweiyang Jul 26, 2022
28b2b76
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
godweiyang Jul 26, 2022
075af1a
fix quant gpt bug
godweiyang Jul 26, 2022
b622b0d
support quant gpt training
godweiyang Jul 26, 2022
d3d3566
modify huggingface training scripts
godweiyang Jul 27, 2022
baba9d7
support ls bert, gpt export
godweiyang Jul 27, 2022
3bbfac3
support custom quant transformer export
godweiyang Jul 28, 2022
483c252
optimizer ffn fake quant and dcmax
Jul 28, 2022
0983fe0
support quant gpt export
godweiyang Jul 28, 2022
bfa3b76
support quant vit export
godweiyang Jul 29, 2022
0c21296
add quant linear layer
Jul 29, 2022
89a92e7
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
Jul 29, 2022
e683775
fix quant linear layer bug
Jul 29, 2022
c180d2c
support quant vit infer
godweiyang Aug 1, 2022
e8ad835
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
godweiyang Aug 1, 2022
931c0b0
speedup cublass igemm on A100 (by huxingwu)
godweiyang Aug 2, 2022
0fe0862
optimize ls_quant_dropout_act_bias_bwd_kernel
Aug 2, 2022
ee03974
polish training gemm algo code
godweiyang Aug 3, 2022
56211cb
support gemm best algo search on different GPUs and shapes
godweiyang Aug 5, 2022
2e99f37
search in the range (min_bsz, 512, 1) and (512, max_bsz, 32)
godweiyang Aug 5, 2022
e2439e1
add configs_sm75/h512_i2048_b1-10016.json
godweiyang Aug 5, 2022
156bb6e
support col32 igemm
Aug 5, 2022
477d3e7
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
Aug 5, 2022
01dccf8
add configs_sm75/h768_i3072_b1-10016.json
godweiyang Aug 5, 2022
6f13034
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
godweiyang Aug 5, 2022
4fbbb0f
add configs_sm80/h512_i2048_b1-10016.json
godweiyang Aug 5, 2022
8305cd3
add configs_sm75/h1024_i4096_b1-10016.json
godweiyang Aug 5, 2022
3f02450
add configs_sm80/h768_i3072_b1-10016.json
godweiyang Aug 5, 2022
c0c9d81
fix syntax error
godweiyang Aug 5, 2022
1b23508
configs_sm80/h1024_i4096_b1-10016.json
godweiyang Aug 5, 2022
d521929
modify gemm test config format
godweiyang Aug 8, 2022
3868f59
merge all the configs to one
godweiyang Aug 8, 2022
33702dd
support search all shapes which are not in the config
godweiyang Aug 8, 2022
a9104d0
polish the merged config
godweiyang Aug 8, 2022
430e4e4
add cublas_algo_map cpp code
godweiyang Aug 9, 2022
aa7d4a1
move get_sm func to lightseq kernels
godweiyang Aug 9, 2022
cf21734
move gemm_test to lightseq ops
godweiyang Aug 9, 2022
e582959
modify default config dir, fix algo_map bug
godweiyang Aug 9, 2022
d91024a
fix col32 bug
Aug 10, 2022
690801e
col major igemm become default
Aug 11, 2022
2877db5
fix dcax kernel bug
Aug 11, 2022
649a037
loosen cuda 11.6 requirement
Aug 12, 2022
fc6a035
add vit cpp example
godweiyang Aug 12, 2022
0a76d2e
Merge branch 'ls-qat' of https://github.com/bytedance/lightseq into l…
godweiyang Aug 12, 2022
cd9c85c
fix bug from col32 gemm and a100 tuned col gemm
Aug 15, 2022
7eac0ed
support training encoder qkv_linear auto-tune gemm (in comment)
godweiyang Aug 16, 2022
7a1bd0c
add required header file
godweiyang Aug 16, 2022
1904c74
dynamic use col32 or col4 in different GPUs
godweiyang Aug 16, 2022
220c2e8
fix multidefinition bug
godweiyang Aug 16, 2022
964ff11
fix weight transform col32 bug
godweiyang Aug 17, 2022
50f6512
add best algo for inference gemm (in comments)
godweiyang Aug 18, 2022
78857bb
support easy benchmark for gpt and transformer
Aug 19, 2022
36b6607
support benmark huggingface
Aug 22, 2022
efa5c71
fix embedding clip_max bug
Aug 23, 2022
3e63d41
ls quant linear support more shape
Aug 23, 2022
12d58ed
fix quant linear bug
Aug 24, 2022
c8e5b89
fix quant linear bug
Aug 24, 2022
41e5cec
update pad function for older torch
Aug 24, 2022
33cc7b0
fix quant linear bug
Aug 24, 2022
123874a
remove redundant code
Aug 24, 2022
8dac833
Merge branch 'master' into ls-qat
Aug 26, 2022
332e22e
fix export bug
Aug 26, 2022
a8df5d6
Merge remote-tracking branch 'origin/master' into ls-qat
Aug 26, 2022
11af3e7
fix format
Aug 26, 2022
389106d
fix conflicts
godweiyang Aug 30, 2022
a62cad9
fix custom train&infer bug
godweiyang Aug 30, 2022
541ec0f
fix quant infer size overflow
godweiyang Aug 30, 2022
91d7c1a
fix ls gpt export bug (extra_decode_length)
godweiyang Aug 30, 2022
a42a98c
fix hf bart cmax init and state
Aug 30, 2022
2e8fe5a
fix max-batch-tokens bug of bart predict
godweiyang Aug 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(LightSeq LANGUAGES C CXX CUDA)

set(CMAKE_CUDA_ARCHITECTURES
60
61
70
75
80
86
87)
find_package(CUDA 11.6 REQUIRED)
set(CMAKE_CUDA_ARCHITECTURES 60 61 70 75 80 86)
find_package(CUDA 11 REQUIRED)

option(FP16_MODE "inference with fp16" OFF)
option(DEBUG_MODE "debug computation result" OFF)
Expand Down
4 changes: 2 additions & 2 deletions docs/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ The main code is as follows (some parameters are omitted). Complete code is avai
```python
model = Transformer()
encoder_state_dict, decoder_state_dict = _extract_weight(state_dict)
export_ls_embedding(model, encoder_state_dict, is_encoder=True)
export_ls_embedding(model, encoder_state_dict, is_encoder=False)
export_ls_embedding(model, encoder_state_dict, max_length, emb_dim, is_encoder=True)
export_ls_embedding(model, encoder_state_dict, max_length, emb_dim, is_encoder=False)
export_ls_encoder(model, encoder_state_dict)
export_ls_decoder(model, decoder_state_dict)
export_fs_weights(model, state_dict)
Expand Down
40 changes: 40 additions & 0 deletions examples/inference/benchmark_bart.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash
# Benchmark LightSeq beam-search inference latency for facebook/bart-base.
#
# For every (batch_size, beam_size, input_seq_len) combination this script:
#   1. runs python/generate_model.py to produce lightseq_<model>_bench.hdf5,
#   2. runs the transformer_example binary from ../../build on that model,
#   3. extracts one latency value from the example's output.
# The full output of every run is appended to <model>_bench.log and one summary
# row per run to <model>_bench.txt, both next to this script. The summary is
# rendered with the `tabulate` command-line tool at the end.

SCRIPT=$(realpath "$0")
CUR_DIR=$(dirname "$SCRIPT")

model_full_name=facebook/bart-base
# Keep only the part after the organisation prefix ("bart-base").
model_name=$(echo "$model_full_name" | cut -d "/" -f 2)
all_log="$CUR_DIR/${model_name}_bench.log"
res_log="$CUR_DIR/${model_name}_bench.txt"

# Start from clean logs so results of earlier runs are not mixed in.
# Paths are quoted everywhere: the checkout directory may contain whitespace.
rm -f "$res_log" "$all_log"
echo "batch_size input_seq_len output_seq_len beam_size latency" >>"$res_log"

for batch_size in 1 8 32; do
  for beam_size in 1 4 32; do
    for input_seq_len in 8 16 32 64; do
      output_seq_len=$input_seq_len

      # Abort instead of exporting/benchmarking from an unexpected directory.
      cd "$CUR_DIR/python" || exit 1
      python3 generate_model.py --model_name "$model_full_name" --sampling_method beam_search \
        --beam_size "$beam_size" --input_seq_len "$input_seq_len" --output_seq_len="$output_seq_len"
      model_path=$(realpath "lightseq_${model_name}_bench.hdf5")

      cd "$CUR_DIR/../../build" || exit 1
      ./examples/inference/cpp/transformer_example \
        "$model_path" "$batch_size" "$input_seq_len" |& tee temp.log

      cat temp.log >>"$all_log"
      # The latency is read as the 4th whitespace-separated field of the 5th
      # line from the end of the output.
      # NOTE(review): presumably the "lightseq inference latency: <x> ms" line;
      # this is position-dependent, so re-check if the example's output changes.
      latency=$(tail -n 5 temp.log | head -n 1 | awk '{print $4}')
      echo "$batch_size $input_seq_len $output_seq_len $beam_size $latency" >>"$res_log"
      rm temp.log
    done
  done
done

# Pretty-print the summary, using its first row as the table header.
pip3 install tabulate
tabulate --header "$res_log"
40 changes: 40 additions & 0 deletions examples/inference/benchmark_gpt.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash
# Benchmark LightSeq top-k sampling inference latency for gpt2.
#
# For every (batch_size, topk, input_seq_len) combination this script:
#   1. runs python/generate_model.py to produce lightseq_<model>_bench.hdf5,
#   2. runs the gpt_example binary from ../../build on that model,
#   3. extracts one latency value from the example's output.
# The full output of every run is appended to <model>_bench.log and one summary
# row per run to <model>_bench.txt, both next to this script. The summary is
# rendered with the `tabulate` command-line tool at the end.

SCRIPT=$(realpath "$0")
CUR_DIR=$(dirname "$SCRIPT")

model_full_name=gpt2
model_name=$model_full_name
all_log="$CUR_DIR/${model_name}_bench.log"
res_log="$CUR_DIR/${model_name}_bench.txt"

# Start from clean logs so results of earlier runs are not mixed in.
# Paths are quoted everywhere: the checkout directory may contain whitespace.
rm -f "$res_log" "$all_log"
echo "batch_size input_seq_len output_seq_len topk latency" >>"$res_log"

for batch_size in 1 8 32; do
  for topk in 1 4 32; do
    for input_seq_len in 118 86 22; do
      # Input and output lengths always add up to 150 tokens.
      output_seq_len=$((150 - input_seq_len))

      # Abort instead of exporting/benchmarking from an unexpected directory.
      cd "$CUR_DIR/python" || exit 1
      python3 generate_model.py --model_name "$model_full_name" --sampling_method topk \
        --topk "$topk" --input_seq_len "$input_seq_len" --output_seq_len="$output_seq_len"
      model_path=$(realpath "lightseq_${model_name}_bench.hdf5")

      cd "$CUR_DIR/../../build" || exit 1
      ./examples/inference/cpp/gpt_example \
        "$model_path" "$batch_size" "$input_seq_len" |& tee temp.log

      cat temp.log >>"$all_log"
      # The latency is read as the 4th whitespace-separated field of the 3rd
      # line from the end of the output.
      # NOTE(review): presumably the "lightseq inference latency: <x> ms" line;
      # this is position-dependent, so re-check if the example's output changes.
      latency=$(tail -n 3 temp.log | head -n 1 | awk '{print $4}')
      echo "$batch_size $input_seq_len $output_seq_len $topk $latency" >>"$res_log"
      rm temp.log
    done
  done
done

# Pretty-print the summary, using its first row as the table header.
pip3 install tabulate
tabulate --header "$res_log"
41 changes: 41 additions & 0 deletions examples/inference/benchmark_quant_bart.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/bin/bash
# Benchmark LightSeq quantized-transformer beam-search inference latency for
# facebook/bart-base.
#
# For every (batch_size, beam_size, input_seq_len) combination this script:
#   1. runs python/generate_model.py to produce lightseq_<model>_bench.hdf5,
#   2. runs the quant_transformer_example binary from ../../build on that model,
#   3. extracts one latency value from the example's output.
# The full output of every run is appended to quant_<model>_bench.log and one
# summary row per run to quant_<model>_bench.txt, both next to this script. The
# summary is rendered with the `tabulate` command-line tool at the end.

SCRIPT=$(realpath "$0")
CUR_DIR=$(dirname "$SCRIPT")

model_full_name=facebook/bart-base
# Keep only the part after the organisation prefix ("bart-base").
model_name=$(echo "$model_full_name" | cut -d "/" -f 2)
all_log="$CUR_DIR/quant_${model_name}_bench.log"
res_log="$CUR_DIR/quant_${model_name}_bench.txt"

# Start from clean logs so results of earlier runs are not mixed in.
# (The previous version removed $res_log in both branches, so a stale
# $all_log was never deleted and kept growing across runs.)
# Paths are quoted everywhere: the checkout directory may contain whitespace.
rm -f "$all_log" "$res_log"
echo "batch_size input_seq_len output_seq_len beam_size latency" >>"$res_log"

for batch_size in 1 8 32; do
  for beam_size in 1 4 32; do
    for input_seq_len in 16 32 64; do
      output_seq_len=$input_seq_len

      # Abort instead of exporting/benchmarking from an unexpected directory.
      cd "$CUR_DIR/python" || exit 1
      # NOTE(review): unlike benchmark_quant_gpt.sh, no "--enable_quant true"
      # is passed here although the quantized example is run below - confirm
      # that this is intended.
      python3 generate_model.py --model_name "$model_full_name" --sampling_method beam_search \
        --beam_size "$beam_size" --input_seq_len "$input_seq_len" --output_seq_len="$output_seq_len"
      model_path=$(realpath "lightseq_${model_name}_bench.hdf5")

      cd "$CUR_DIR/../../build" || exit 1
      ./examples/inference/cpp/quant_transformer_example \
        "$model_path" "$batch_size" "$input_seq_len" |& tee temp.log

      cat temp.log >>"$all_log"
      # The latency is read as the 4th whitespace-separated field of the 5th
      # line from the end of the output.
      # NOTE(review): presumably the "lightseq inference latency: <x> ms" line;
      # this is position-dependent, so re-check if the example's output changes.
      latency=$(tail -n 5 temp.log | head -n 1 | awk '{print $4}')
      echo "$batch_size $input_seq_len $output_seq_len $beam_size $latency" >>"$res_log"
      rm temp.log
    done
  done
done

# Pretty-print the summary, using its first row as the table header.
pip3 install tabulate
tabulate --header "$res_log"
40 changes: 40 additions & 0 deletions examples/inference/benchmark_quant_gpt.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash
# Benchmark LightSeq quantized GPT top-k sampling inference latency.
#
# For every (batch_size, topk, input_seq_len) combination this script:
#   1. runs python/generate_model.py with "--enable_quant true" to produce
#      lightseq_quant_gpt2_bench.hdf5,
#   2. runs the quant_gpt_example binary from ../../build on that model,
#   3. extracts one latency value from the example's output.
# The full output of every run is appended to quant_gpt2_bench.log and one
# summary row per run to quant_gpt2_bench.txt, both next to this script. The
# summary is rendered with the `tabulate` command-line tool at the end.

SCRIPT=$(realpath "$0")
CUR_DIR=$(dirname "$SCRIPT")

# Hard-coded checkpoint path handed to generate_model.py.
# NOTE(review): presumably produced by a prior quantized GPT training run;
# adjust to your environment.
model_full_name=/tmp/quant/test-clm/pytorch_model.bin
model_name=quant_gpt2
all_log="$CUR_DIR/${model_name}_bench.log"
res_log="$CUR_DIR/${model_name}_bench.txt"

# Start from clean logs so results of earlier runs are not mixed in.
# Paths are quoted everywhere: the checkout directory may contain whitespace.
rm -f "$res_log" "$all_log"
echo "batch_size input_seq_len output_seq_len topk latency" >>"$res_log"

for batch_size in 1 8 32; do
  for topk in 1 4 32; do
    for input_seq_len in 118 86 22; do
      # Input and output lengths always add up to 150 tokens.
      output_seq_len=$((150 - input_seq_len))

      # Abort instead of exporting/benchmarking from an unexpected directory.
      cd "$CUR_DIR/python" || exit 1
      python3 generate_model.py --model_name "$model_full_name" --sampling_method topk \
        --topk "$topk" --input_seq_len "$input_seq_len" --output_seq_len="$output_seq_len" --enable_quant true
      model_path=$(realpath "lightseq_${model_name}_bench.hdf5")

      cd "$CUR_DIR/../../build" || exit 1
      ./examples/inference/cpp/quant_gpt_example \
        "$model_path" "$batch_size" "$input_seq_len" |& tee temp.log

      cat temp.log >>"$all_log"
      # The latency is read as the 4th whitespace-separated field of the 3rd
      # line from the end of the output.
      # NOTE(review): presumably the "lightseq inference latency: <x> ms" line;
      # this is position-dependent, so re-check if the example's output changes.
      latency=$(tail -n 3 temp.log | head -n 1 | awk '{print $4}')
      echo "$batch_size $input_seq_len $output_seq_len $topk $latency" >>"$res_log"
      rm temp.log
    done
  done
done

# Pretty-print the summary, using its first row as the table header.
pip3 install tabulate
tabulate --header "$res_log"
6 changes: 6 additions & 0 deletions examples/inference/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,9 @@ target_link_libraries(quant_gpt_example PUBLIC liblightseq)

add_executable(transformer_decoder_example decoder_example.cc.cu)
target_link_libraries(transformer_decoder_example PUBLIC transformer_model)

add_executable(vit_example vit_example.cc)
target_link_libraries(vit_example PUBLIC liblightseq)

add_executable(quant_vit_example quant_vit_example.cc)
target_link_libraries(quant_vit_example PUBLIC liblightseq)
21 changes: 15 additions & 6 deletions examples/inference/cpp/gpt_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@ int main(int argc, char* argv[]) {
std::string model_weights_path = argv[1];
std::vector<int> example_input = {40, 1842, 345, 11, 475, 345, 910, 326};
int eg_seq_len = example_input.size();
int max_batch_size = 128;

int batch_size = 1;
int batch_seq_len = eg_seq_len;

if (argc == 4) {
batch_size = atoi(argv[2]);
batch_seq_len = atoi(argv[3]);
}
if (batch_size > max_batch_size) {
throw std::runtime_error("batch_size exceeds the maximum (128)!");
}

int max_batch_size = std::max(8, batch_size);

std::vector<int> host_input;
for (int i = 0; i < batch_size; ++i) {
Expand All @@ -39,6 +38,7 @@ int main(int argc, char* argv[]) {
d_input, host_input.data(), sizeof(int) * batch_size * batch_seq_len,
cudaMemcpyHostToDevice));

model->benchmark_mode(true);
model->set_input_ptr(0, d_input);
model->set_input_shape(0, {batch_size, batch_seq_len});

Expand All @@ -56,13 +56,22 @@ int main(int argc, char* argv[]) {
lightseq::cuda::CHECK_GPU_ERROR(cudaStreamSynchronize(0));
std::cout << "infer preprocessing finished" << std::endl;

std::chrono::duration<double> elapsed;
int iter = 0;
/* ---step5. infer and log--- */
for (int i = 0; i < 10; i++) {
for (int i = 0; i < 20; i++) {
auto start = std::chrono::high_resolution_clock::now();
model->Infer();
lightseq::cuda::print_time_duration(start, "one infer time", 0);
auto finish = std::chrono::high_resolution_clock::now();
if (i >= 5) {
iter++;
elapsed += finish - start;
}
}

std::cout << "lightseq inference latency: " << elapsed.count() * 1000 / iter
<< " ms" << std::endl;

for (int i = 0; i < model->get_output_size(); i++) {
const int* d_output;
d_output = static_cast<const int*>(model->get_output_ptr(i));
Expand Down
18 changes: 15 additions & 3 deletions examples/inference/cpp/quant_gpt_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ int main(int argc, char* argv[]) {
std::string model_weights_path = argv[1];
std::vector<int> example_input = {40, 1842, 345, 11, 475, 345, 910, 326};
int eg_seq_len = example_input.size();
int max_batch_size = 128;
int batch_size = 1;
int batch_seq_len = eg_seq_len;

if (argc == 4) {
batch_size = atoi(argv[2]);
batch_seq_len = atoi(argv[3]);
}

int max_batch_size = std::max(4, batch_size);

if (batch_size > max_batch_size) {
throw std::runtime_error("batch_size exceeds the maximum (128)!");
}
Expand All @@ -39,6 +41,7 @@ int main(int argc, char* argv[]) {
d_input, host_input.data(), sizeof(int) * batch_size * batch_seq_len,
cudaMemcpyHostToDevice));

model->benchmark_mode(true);
model->set_input_ptr(0, d_input);
model->set_input_shape(0, {batch_size, batch_seq_len});

Expand All @@ -56,13 +59,22 @@ int main(int argc, char* argv[]) {
lightseq::cuda::CHECK_GPU_ERROR(cudaStreamSynchronize(0));
std::cout << "infer preprocessing finished" << std::endl;

std::chrono::duration<double> elapsed;
int iter = 0;
/* ---step5. infer and log--- */
for (int i = 0; i < 10; i++) {
for (int i = 0; i < 20; i++) {
auto start = std::chrono::high_resolution_clock::now();
model->Infer();
lightseq::cuda::print_time_duration(start, "one infer time", 0);
auto finish = std::chrono::high_resolution_clock::now();
if (i >= 5) {
iter++;
elapsed += finish - start;
}
}

std::cout << "lightseq inference latency: " << elapsed.count() * 1000 / iter
<< " ms" << std::endl;

for (int i = 0; i < model->get_output_size(); i++) {
const int* d_output;
d_output = static_cast<const int*>(model->get_output_ptr(i));
Expand Down
21 changes: 14 additions & 7 deletions examples/inference/cpp/quant_transformer_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,14 @@ int main(int argc, char* argv[]) {
std::vector<int> example_input = {63, 47, 65, 1507, 88, 74,
10, 2057, 362, 9, 284, 6};
int eg_seq_len = example_input.size();
int max_batch_size = 128;
int batch_size = 1;
int batch_seq_len = eg_seq_len;

if (argc == 4) {
batch_size = atoi(argv[2]);
batch_seq_len = atoi(argv[3]);
}
if (batch_size > max_batch_size) {
throw std::runtime_error("batch_size exceeds the maximum (128)!");
}
int max_batch_size = std::max(4, batch_size);

std::vector<int> host_input;
for (int i = 0; i < batch_size; ++i) {
Expand All @@ -41,6 +38,7 @@ int main(int argc, char* argv[]) {
d_input, host_input.data(), sizeof(int) * batch_size * batch_seq_len,
cudaMemcpyHostToDevice));

model->benchmark_mode(true);
model->set_input_ptr(0, d_input);
model->set_input_shape(0, {batch_size, batch_seq_len});

Expand All @@ -58,13 +56,22 @@ int main(int argc, char* argv[]) {
lightseq::cuda::CHECK_GPU_ERROR(cudaStreamSynchronize(0));
std::cout << "infer preprocessing finished" << std::endl;

std::chrono::duration<double> elapsed;
int iter = 0;
/* ---step5. infer and log--- */
for (int i = 0; i < 20; i++) {
auto start = std::chrono::high_resolution_clock::now();
model->Infer();
lightseq::cuda::print_time_duration(start, "one infer time", 0);
auto finish = std::chrono::high_resolution_clock::now();
if (i >= 5) {
iter++;
elapsed += finish - start;
}
}

std::cout << "lightseq inference latency: " << elapsed.count() * 1000 / iter
<< " ms" << std::endl;

for (int i = 0; i < model->get_output_size(); i++) {
const void* d_output;
d_output = static_cast<const float*>(model->get_output_ptr(i));
Expand All @@ -76,9 +83,9 @@ int main(int argc, char* argv[]) {
std::cout << std::endl;

if (!i)
lightseq::cuda::print_vec((int*)d_output, "output", 15);
lightseq::cuda::print_vec((int*)d_output, "output", batch_size);
else
lightseq::cuda::print_vec((float*)d_output, "output", 5);
lightseq::cuda::print_vec((float*)d_output, "output", batch_size);
}

// const int* res = model.get_result_ptr();
Expand Down
Loading