Fix sdpa with batch size = 1, better benchmark #915

Merged · 9 commits · Mar 23, 2023
97 changes: 49 additions & 48 deletions optimum/bettertransformer/models/decoder_models.py
@@ -58,6 +58,12 @@ def wrapped_scaled_dot_product(
mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)

# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if self.downcast_qk:
query = query.to(value.dtype)
key = key.to(value.dtype)

if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

@@ -86,16 +92,15 @@ def wrapped_scaled_dot_product(
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
attention_mask = causal_mask + attention_mask

# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if self.downcast_qk:
query = query.to(value.dtype)
key = key.to(value.dtype)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if self.downcast_qk:
sdpa_result = sdpa_result.to(value.dtype)

return sdpa_result, None

def forward(self, *args, **kwargs):
@@ -201,6 +206,11 @@ def wrapped_scaled_dot_product(
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

# in codegen the query and key are always in fp32 regardless of the dtype of the model
# https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
query = query.to(value.dtype)
key = key.to(value.dtype)

if batch_size == 1:
if query.shape[2] > 1:
# first step of the decoding
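The body of this batch_size == 1 branch is collapsed at this point in the diff. Presumably, in line with the PR title, the idea is that a single unpadded sequence lets the prefill step rely on is_causal=True instead of materializing a causal mask, while a single-token decode step needs no mask at all. A sketch of that dispatch (my own code, not the PR's):

import torch

def sdpa_single_sequence(query, key, value):
    # Prefill (query length > 1): let SDPA apply causal masking itself.
    if query.shape[2] > 1:
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
        )
    # Single-token decode step: the new token attends to every cached key,
    # so no mask is required.
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
    )

q = torch.rand(1, 16, 10, 64)  # (batch, heads, seq_len, head_dim)
print(sdpa_single_sequence(q, q, q).shape)  # torch.Size([1, 16, 10, 64])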
@@ -232,11 +242,6 @@ def wrapped_scaled_dot_product(
# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)

# in codegen the query and key are always in fp32 regardless of the dtype of the model
# https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
query = query.to(value.dtype)
key = key.to(value.dtype)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -381,6 +386,8 @@ def forward(
super().forward_checker()
raise_on_head_mask(layer_head_mask)

if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
if len(self.orig_layer.pruned_heads) > 0:
raise ValueError(
f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.orig_layer.pruned_heads}."
@@ -451,58 +458,52 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
past_key_value[1] if past_key_value is not None else None,
)

if (position_bias is None and not self.orig_layer.has_relative_attention_bias) or (
position_bias is not None and position_bias[0, 0, 0, 0] == 0
):
if position_bias is None and not self.orig_layer.has_relative_attention_bias:
query_states = self.scale * query_states
if position_bias is None and not self.orig_layer.has_relative_attention_bias:
if mask is None:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
)
elif mask is not None:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=mask, dropout_p=0.0, is_causal=False
)

if position_bias is None:
if not self.orig_layer.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.orig_layer.n_heads, real_seq_length, key_length),
device=query_states.device,
dtype=query_states.dtype,
device=value_states.device,
dtype=value_states.dtype,
)
if self.orig_layer.gradient_checkpointing and self.orig_layer.training:
position_bias.requires_grad = True
else:
position_bias = self.orig_layer.compute_bias(real_seq_length, key_length, device=value_states.device)

query_states = self.scale * query_states
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
)

else:
# compute scores
scores = torch.matmul(
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

if position_bias is None:
position_bias = self.orig_layer.compute_bias(real_seq_length, key_length, device=scores.device)

# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

scores += position_bias
attn_weights = torch.nn.functional.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)
attn_weights = torch.nn.functional.dropout(
attn_weights, p=self.orig_layer.dropout, training=self.orig_layer.training
) # (batch_size, n_heads, seq_length, key_length)
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)

attn_output = torch.matmul(attn_weights, value_states)
if self.orig_layer.has_relative_attention_bias:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False
)

attn_output = unshape(attn_output) # (batch_size, seq_length, dim)
attn_output = self.orig_layer.o(attn_output)

present_key_value_state = (key_states, value_states) if (self.orig_layer.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

if output_attentions:
outputs = outputs + (attn_weights,)
return outputs
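The reworked T5 path above hands the relative position bias (plus any padding mask) to SDPA as an additive attn_mask rather than adding it to manually computed scores. The self.scale multiplication in the no-bias branch presumably compensates for the 1/sqrt(head_dim) scaling SDPA applies internally, since T5 does not scale its attention scores. A standalone sketch of the bias-as-mask call, with made-up shapes:

import torch

batch, n_heads, q_len, k_len, head_dim = 2, 8, 5, 5, 64
query = torch.rand(batch, n_heads, q_len, head_dim)
key = torch.rand(batch, n_heads, k_len, head_dim)
value = torch.rand(batch, n_heads, k_len, head_dim)

# Stand-in for compute_bias(...): an additive bias broadcast over the batch.
position_bias = torch.randn(1, n_heads, q_len, k_len)

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=position_bias, dropout_p=0.0, is_causal=False
)
print(attn_output.shape)  # torch.Size([2, 8, 5, 64])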


126 changes: 80 additions & 46 deletions tests/benchmark/benchmark_bettertransformer.py
@@ -2,9 +2,10 @@

import torch
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

from optimum.bettertransformer import BetterTransformer
from optimum.exporters import TasksManager


def get_parser():
@@ -104,57 +105,55 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s
def timing_cuda(model, num_batches, input_ids, masks, is_decoder, generation_config=None):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

torch.cuda.reset_peak_memory_stats(device)
torch.cuda.empty_cache()
torch.cuda.synchronize()

start_event.record()
for _ in range(num_batches):
for _ in tqdm(range(num_batches)):
if is_decoder:
_ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config)
else:
_ = model(input_ids, masks)
end_event.record()
torch.cuda.synchronize()
return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches
max_memory = torch.cuda.max_memory_allocated(device)

return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches, max_memory

def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id):

def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id):
# Warmup
if is_decoder:
min_length = max(max_token - 20, 5)

gen_config = GenerationConfig(
do_greedy=True,
max_new_tokens=max_token,
min_length=min_length,
min_new_tokens=max_token,
use_cache=True,
pad_token_id=pad_token_id,
)
_ = hf_model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
torch.cuda.synchronize()
bt_model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
_ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
torch.cuda.synchronize()

else:
_ = hf_model(input_ids, masks)
torch.cuda.synchronize()
_ = bt_model(input_ids, masks)
_ = model(input_ids, masks)
torch.cuda.synchronize()

# benchmark
if is_decoder:
total_hf_time = timing_cuda(hf_model, num_batches, input_ids, masks, is_decoder, gen_config)
total_bt_time = timing_cuda(bt_model, num_batches, input_ids, masks, is_decoder, gen_config)
total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks, is_decoder, gen_config)
else:
total_hf_time = timing_cuda(hf_model, num_batches, input_ids, masks, is_decoder)
total_bt_time = timing_cuda(bt_model, num_batches, input_ids, masks, is_decoder)
total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks, is_decoder)

return total_bt_time, total_hf_time
return total_time, max_mem
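timing_cuda now returns peak GPU memory alongside the per-batch latency. The measurement pattern (reset the peak-memory counter, synchronize, bracket the loop with CUDA events, then read max_memory_allocated) is reusable on its own; a minimal sketch, assuming a CUDA device is available:

import torch

def time_and_measure(fn, num_batches, device="cuda:0"):
    # Clear the peak-memory counter and make sure no prior work is pending,
    # then time num_batches calls with CUDA events (result in seconds/batch).
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(num_batches):
        fn()
    end_event.record()
    torch.cuda.synchronize()

    seconds_per_batch = (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches
    peak_memory = torch.cuda.max_memory_allocated(device)  # bytes
    return seconds_per_batch, peak_memory

if torch.cuda.is_available():
    x = torch.rand(1024, 1024, device="cuda:0")
    t, mem = time_and_measure(lambda: x @ x, num_batches=10)
    print(f"{t:.6f} s/batch, peak {mem * 1e-6:.2f} MB")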


if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()

BATCH_SIZES = [8, 16, 64]
SEQ_LEN = [64, 128, 256]
BATCH_SIZES = [2]
SEQ_LEN = [64]
if args.is_decoder:
PAD_PERCENTAGES = [0]
else:
@@ -166,63 +165,98 @@ def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

if args.is_decoder:
hf_model = AutoModelForCausalLM.from_pretrained(
args.model_name, torch_dtype=torch.float16 if args.use_half else None, use_cache=True
).eval()
task = TasksManager.infer_task_from_model(args.model_name)

if task == "causal-lm":
autoclass = AutoModelForCausalLM
elif task == "seq2seq-lm":
autoclass = AutoModelForSeq2SeqLM
else:
hf_model = AutoModel.from_pretrained(
args.model_name, torch_dtype=torch.float16 if args.use_half else None
).eval()
autoclass = AutoModel

if args.use_cuda:
hf_model = hf_model.to(0)
with torch.device("cuda:0"):
hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None)
# in PyTorch we trust :)
hf_model = hf_model.to("cuda:0")
hf_model = hf_model.to(torch.float16)
else:
hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None)

bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

output_file = open("log_{}.csv".format(args.model_name.replace("/", "-")), "w")
output_file.write(
"num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup\n"
"num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup, Mem eager (MB), Mem BT (MB), Mem saved\n"
)
for bs in tqdm(BATCH_SIZES):
for seq_len in tqdm(SEQ_LEN):
for pad_perc in tqdm(PAD_PERCENTAGES):
print(f"-- Running: bs={bs}, seq_len={seq_len}")
# current_std = int(seq_len*pad_perc)
# max_seqlen = seq_len + current_std
max_seqlen = seq_len
mean_seqlen = int((1 - pad_perc) * max_seqlen)
input_ids, _, masks = get_batch(bs, mean_seqlen, max_seqlen, args.seqlen_stdev)
input_ids, _, masks = get_batch(
bs, mean_seqlen, max_seqlen, args.seqlen_stdev, vocab_size=hf_model.config.vocab_size
)

if args.use_cuda:
input_ids = input_ids.to(device)
masks = masks.to(device)
if not args.use_mask:

if args.use_mask is False and bs == 1:
masks = None

total_bt_time, total_hf_time = benchmark(
hf_model,
bt_model,
input_ids,
masks,
args.num_batches,
args.is_decoder,
args.max_token,
tokenizer.pad_token_id,
)
with torch.inference_mode():
total_hf_time, max_mem_eager = benchmark(
hf_model,
input_ids,
masks,
args.num_batches,
args.is_decoder,
args.max_token,
tokenizer.pad_token_id,
)

# raise error if no optimized kernel is available
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
total_bt_time, max_mem_bt = benchmark(
bt_model,
input_ids,
masks,
args.num_batches,
args.is_decoder,
args.max_token,
tokenizer.pad_token_id,
)

speedup = total_hf_time / total_bt_time
mem_saved = max_mem_eager / max_mem_bt

max_mem_eager = max_mem_eager * 1e-6
max_mem_bt = max_mem_bt * 1e-6

print(f"PT eager: {total_hf_time:.3f} s, peak {max_mem_eager:.2f} MB")
print(f"PT native: {total_bt_time:.3f} s, peak {max_mem_bt:.2f} MB")

output_file.write(
"{},{},{},{},{},{},{},{},{},{}\n".format(
"{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
args.num_batches,
args.use_cuda,
bs,
seq_len,
args.use_half,
args.use_mask,
pad_perc,
total_hf_time,
total_bt_time,
speedup,
f"{total_hf_time:.3f}",
f"{total_bt_time:.3f}",
f"{speedup:.3f}",
f"{max_mem_eager:.3f}",
f"{max_mem_bt:.3f}",
f"{mem_saved:.3f}",
)
)
output_file.close()
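A side note on the torch.backends.cuda.sdp_kernel context manager used around the BetterTransformer run: with all three backends enabled it mainly documents intent, but disabling the math fallback turns a silent fallback into an error, which is a quick way to confirm a benchmark really hits a fused kernel. A sketch, assuming PyTorch 2.0 and a CUDA device:

import torch

q = torch.rand(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the fused kernels only; if neither flash attention nor the
# memory-efficient kernel supports these inputs, the call raises instead of
# silently falling back to the math implementation.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True):
    out = torch.nn.functional.scaled_dot_product_attention(q, q, q, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])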