Skip to content

Commit

Permalink
[BT] Fix bt benchmark (#858)
Browse files Browse the repository at this point in the history
fix script
  • Loading branch information
younesbelkada authored Mar 6, 2023
1 parent c7b384a commit f297f2a
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions tests/benchmark/benchmark_bettertransformer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse

import torch
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from optimum.bettertransformer import BetterTransformer
Expand Down Expand Up @@ -154,7 +155,11 @@ def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max

BATCH_SIZES = [8, 16, 64]
SEQ_LEN = [64, 128, 256]
PAD_PERCENTAGES = [0, 0.1, 0.2, 0.5, 0.75]
if args.is_decoder:
PAD_PERCENTAGES = [0]
else:
PAD_PERCENTAGES = [0, 0.1, 0.2, 0.5, 0.75]

device = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")
tokenizer = AutoTokenizer.from_pretrained(args.model_name)

Expand All @@ -178,9 +183,9 @@ def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max
output_file.write(
"num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup\n"
)
for bs in BATCH_SIZES:
for seq_len in SEQ_LEN:
for pad_perc in PAD_PERCENTAGES:
for bs in tqdm(BATCH_SIZES):
for seq_len in tqdm(SEQ_LEN):
for pad_perc in tqdm(PAD_PERCENTAGES):
# current_std = int(seq_len*pad_perc)
# max_seqlen = seq_len + current_std
max_seqlen = seq_len
Expand Down

0 comments on commit f297f2a

Please sign in to comment.