Fix sdpa with batch size = 1, better benchmark (#915)
* fix bug in sdpa

* simplify t5

* simplify

* fix

* typo

* update bench

* update bench

* style

---------

Co-authored-by: Your Name <you@example.com>
fxmarty and Your Name committed Mar 23, 2023
1 parent 1cef3b0 commit edf535e
Showing 4 changed files with 169 additions and 114 deletions.
97 changes: 49 additions & 48 deletions optimum/bettertransformer/models/decoder_models.py
@@ -58,6 +58,12 @@ def wrapped_scaled_dot_product(
mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)

# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if self.downcast_qk:
query = query.to(value.dtype)
key = key.to(value.dtype)

if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

@@ -86,16 +92,15 @@ def wrapped_scaled_dot_product(
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
attention_mask = causal_mask + attention_mask

# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if self.downcast_qk:
query = query.to(value.dtype)
key = key.to(value.dtype)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if self.downcast_qk:
sdpa_result = sdpa_result.to(value.dtype)

return sdpa_result, None

def forward(self, *args, **kwargs):
@@ -201,6 +206,11 @@ def wrapped_scaled_dot_product(
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

# in codegen the query and key are always in fp32 regardless of the dtype of the model
# https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
query = query.to(value.dtype)
key = key.to(value.dtype)

if batch_size == 1:
if query.shape[2] > 1:
# first step of the decoding
@@ -232,11 +242,6 @@ def wrapped_scaled_dot_product(
# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)

# in codegen the query and key are always in fp32 regardless of the dtype of the model
# https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
query = query.to(value.dtype)
key = key.to(value.dtype)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
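
For reference, a minimal self-contained sketch of the pattern these hunks implement (names are illustrative, not the library's): the fp32 query/key are cast to the value dtype before any batch-size branch, a padded batch of size 1 is rejected, and an unpadded batch of size 1 skips the explicit mask so SDPA can dispatch to a fused kernel.

import torch


def sdpa_decoder_attention(query, key, value, attention_mask=None):
    # query/key/value: (batch, n_heads, seq_len, head_dim)
    batch_size = query.shape[0]

    # Cast query/key to the value dtype up front (GPT-J / GPT-NeoX / CodeGen
    # keep them in fp32), so every branch below sees consistent dtypes.
    query = query.to(value.dtype)
    key = key.to(value.dtype)

    # A batch of size 1 is treated as unpadded below, so padding='max_length'
    # (which would still need the mask) is rejected.
    if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
        raise ValueError("padding='max_length' with a batch size of 1 is not supported.")

    if batch_size == 1:
        # Single unpadded sequence: use is_causal for the prefill step
        # (query length > 1) and no mask at all for one-token decode steps.
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=query.shape[2] > 1
        )

    # Batched case: fall back to an explicit additive attention mask.
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
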
@@ -381,6 +386,8 @@ def forward(
super().forward_checker()
raise_on_head_mask(layer_head_mask)

if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
if len(self.orig_layer.pruned_heads) > 0:
raise ValueError(
f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.orig_layer.pruned_heads}."
@@ -451,58 +458,52 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
past_key_value[1] if past_key_value is not None else None,
)

if (position_bias is None and not self.orig_layer.has_relative_attention_bias) or (
position_bias is not None and position_bias[0, 0, 0, 0] == 0
):
if position_bias is None and not self.orig_layer.has_relative_attention_bias:
query_states = self.scale * query_states
if position_bias is None and not self.orig_layer.has_relative_attention_bias:
if mask is None:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
)
elif mask is not None:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=mask, dropout_p=0.0, is_causal=False
)

if position_bias is None:
if not self.orig_layer.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.orig_layer.n_heads, real_seq_length, key_length),
device=query_states.device,
dtype=query_states.dtype,
device=value_states.device,
dtype=value_states.dtype,
)
if self.orig_layer.gradient_checkpointing and self.orig_layer.training:
position_bias.requires_grad = True
else:
position_bias = self.orig_layer.compute_bias(real_seq_length, key_length, device=value_states.device)

query_states = self.scale * query_states
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
)

else:
# compute scores
scores = torch.matmul(
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

if position_bias is None:
position_bias = self.orig_layer.compute_bias(real_seq_length, key_length, device=scores.device)

# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

scores += position_bias
attn_weights = torch.nn.functional.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)
attn_weights = torch.nn.functional.dropout(
attn_weights, p=self.orig_layer.dropout, training=self.orig_layer.training
) # (batch_size, n_heads, seq_length, key_length)
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)

attn_output = torch.matmul(attn_weights, value_states)
if self.orig_layer.has_relative_attention_bias:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False
)

attn_output = unshape(attn_output) # (batch_size, seq_length, dim)
attn_output = self.orig_layer.o(attn_output)

present_key_value_state = (key_states, value_states) if (self.orig_layer.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

if output_attentions:
outputs = outputs + (attn_weights,)
return outputs
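
For reference, a minimal sketch of the simplified T5 path above (not the library code verbatim): the relative-position bias and the padding mask are folded into one additive attn_mask and handed to scaled_dot_product_attention in place of the manual matmul/softmax/dropout. T5 omits the 1/sqrt(d) scaling that SDPA applies internally, so this sketch assumes the query is pre-multiplied by sqrt(head_dim) to cancel it, playing the role of self.scale above.

import math

import torch


def t5_style_sdpa(query, key, value, position_bias=None, mask=None):
    # query/key/value: (batch, n_heads, seq_len, head_dim); position_bias is the
    # precomputed relative-position bias, mask the additive padding/causal mask.
    head_dim = query.shape[-1]
    # Cancel SDPA's internal 1/sqrt(d) factor, which T5 does not use.
    query = query * math.sqrt(head_dim)

    # Fold bias and mask into a single additive attn_mask (None if neither is set).
    attn_mask = position_bias
    if mask is not None:
        attn_mask = mask if attn_mask is None else attn_mask + mask

    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
    )
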


126 changes: 80 additions & 46 deletions tests/benchmark/benchmark_bettertransformer.py
@@ -2,9 +2,10 @@

import torch
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

from optimum.bettertransformer import BetterTransformer
from optimum.exporters import TasksManager


def get_parser():
@@ -104,57 +105,55 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s
def timing_cuda(model, num_batches, input_ids, masks, is_decoder, generation_config=None):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

torch.cuda.reset_peak_memory_stats(device)
torch.cuda.empty_cache()
torch.cuda.synchronize()

start_event.record()
for _ in range(num_batches):
for _ in tqdm(range(num_batches)):
if is_decoder:
_ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config)
else:
_ = model(input_ids, masks)
end_event.record()
torch.cuda.synchronize()
return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches
max_memory = torch.cuda.max_memory_allocated(device)

return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches, max_memory

def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id):

def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id):
# Warmup
if is_decoder:
min_length = max(max_token - 20, 5)

gen_config = GenerationConfig(
do_greedy=True,
max_new_tokens=max_token,
min_length=min_length,
min_new_tokens=max_token,
use_cache=True,
pad_token_id=pad_token_id,
)
_ = hf_model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
torch.cuda.synchronize()
bt_model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
_ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
torch.cuda.synchronize()

else:
_ = hf_model(input_ids, masks)
torch.cuda.synchronize()
_ = bt_model(input_ids, masks)
_ = model(input_ids, masks)
torch.cuda.synchronize()

# benchmark
if is_decoder:
total_hf_time = timing_cuda(hf_model, num_batches, input_ids, masks, is_decoder, gen_config)
total_bt_time = timing_cuda(bt_model, num_batches, input_ids, masks, is_decoder, gen_config)
total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks, is_decoder, gen_config)
else:
total_hf_time = timing_cuda(hf_model, num_batches, input_ids, masks, is_decoder)
total_bt_time = timing_cuda(bt_model, num_batches, input_ids, masks, is_decoder)
total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks, is_decoder)

return total_bt_time, total_hf_time
return total_time, max_mem


if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()

BATCH_SIZES = [8, 16, 64]
SEQ_LEN = [64, 128, 256]
BATCH_SIZES = [2]
SEQ_LEN = [64]
if args.is_decoder:
PAD_PERCENTAGES = [0]
else:
@@ -166,63 +165,98 @@ def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

if args.is_decoder:
hf_model = AutoModelForCausalLM.from_pretrained(
args.model_name, torch_dtype=torch.float16 if args.use_half else None, use_cache=True
).eval()
task = TasksManager.infer_task_from_model(args.model_name)

if task == "causal-lm":
autoclass = AutoModelForCausalLM
elif task == "seq2seq-lm":
autoclass = AutoModelForSeq2SeqLM
else:
hf_model = AutoModel.from_pretrained(
args.model_name, torch_dtype=torch.float16 if args.use_half else None
).eval()
autoclass = AutoModel

if args.use_cuda:
hf_model = hf_model.to(0)
with torch.device("cuda:0"):
hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None)
# in PyTorch we trust :)
hf_model = hf_model.to("cuda:0")
hf_model = hf_model.to(torch.float16)
else:
hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None)

bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

output_file = open("log_{}.csv".format(args.model_name.replace("/", "-")), "w")
output_file.write(
"num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup\n"
"num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup, Mem eager (MB), Mem BT (MB), Mem saved\n"
)
for bs in tqdm(BATCH_SIZES):
for seq_len in tqdm(SEQ_LEN):
for pad_perc in tqdm(PAD_PERCENTAGES):
print(f"-- Running: bs={bs}, seq_len={seq_len}")
# current_std = int(seq_len*pad_perc)
# max_seqlen = seq_len + current_std
max_seqlen = seq_len
mean_seqlen = int((1 - pad_perc) * max_seqlen)
input_ids, _, masks = get_batch(bs, mean_seqlen, max_seqlen, args.seqlen_stdev)
input_ids, _, masks = get_batch(
bs, mean_seqlen, max_seqlen, args.seqlen_stdev, vocab_size=hf_model.config.vocab_size
)

if args.use_cuda:
input_ids = input_ids.to(device)
masks = masks.to(device)
if not args.use_mask:

if args.use_mask is False and bs == 1:
masks = None

total_bt_time, total_hf_time = benchmark(
hf_model,
bt_model,
input_ids,
masks,
args.num_batches,
args.is_decoder,
args.max_token,
tokenizer.pad_token_id,
)
with torch.inference_mode():
total_hf_time, max_mem_eager = benchmark(
hf_model,
input_ids,
masks,
args.num_batches,
args.is_decoder,
args.max_token,
tokenizer.pad_token_id,
)

# raise error if no optimized kernel is available
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
total_bt_time, max_mem_bt = benchmark(
bt_model,
input_ids,
masks,
args.num_batches,
args.is_decoder,
args.max_token,
tokenizer.pad_token_id,
)

speedup = total_hf_time / total_bt_time
mem_saved = max_mem_eager / max_mem_bt

max_mem_eager = max_mem_eager * 1e-6
max_mem_bt = max_mem_bt * 1e-6

print(f"PT eager: {total_hf_time:.3f} s, peak {max_mem_eager:.2f} MB")
print(f"PT native: {total_bt_time:.3f} s, peak {max_mem_bt:.2f} MB")

output_file.write(
"{},{},{},{},{},{},{},{},{},{}\n".format(
"{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
args.num_batches,
args.use_cuda,
bs,
seq_len,
args.use_half,
args.use_mask,
pad_perc,
total_hf_time,
total_bt_time,
speedup,
f"{total_hf_time:.3f}",
f"{total_bt_time:.3f}",
f"{speedup:.3f}",
f"{max_mem_eager:.3f}",
f"{max_mem_bt:.3f}",
f"{mem_saved:.3f}",
)
)
output_file.close()
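
For reference, a minimal sketch of the measurement pattern the updated benchmark relies on: CUDA events for wall-clock time and the peak-memory counter, reset before the timed loop. timed_run and run_once are illustrative names; run_once stands in for either model.generate(...) or a plain forward pass.

import torch


def timed_run(run_once, num_batches, device="cuda:0"):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Reset the peak-memory counter so max_memory_allocated reflects this run only.
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()
    for _ in range(num_batches):
        run_once()
    end_event.record()
    torch.cuda.synchronize()

    seconds_per_batch = (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches
    peak_memory_mb = torch.cuda.max_memory_allocated(device) * 1e-6
    return seconds_per_batch, peak_memory_mb

Speedup and memory saving then follow as ratios between the eager and BetterTransformer runs (total_hf_time / total_bt_time and max_mem_eager / max_mem_bt), which is what the script writes to the CSV.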
