From edf535e9eb12bf13686336f4b7987edfc2d4c42d Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Mar 2023 17:28:08 +0100 Subject: [PATCH] Fix sdpa with batch size = 1, better benchmark (#915) * fix bug in spda * simplify t5 * simplify * fix * typo * update bench * update bench * style --------- Co-authored-by: Your Name --- .../models/decoder_models.py | 97 +++++++------- .../benchmark/benchmark_bettertransformer.py | 126 +++++++++++------- ...mark_bettertransformer_training_minimal.py | 51 ++++--- tests/bettertransformer/test_decoder.py | 9 +- 4 files changed, 169 insertions(+), 114 deletions(-) diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index 5d4713244c..499006b174 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -58,6 +58,12 @@ def wrapped_scaled_dot_product( mask_value = torch.finfo(value.dtype).min mask_value = torch.full([], mask_value, dtype=value.dtype) + # in gpt-neo-x and gpt-j the query and keys are always in fp32 + # thus we need to cast them to the value dtype + if self.downcast_qk: + query = query.to(value.dtype) + key = key.to(value.dtype) + if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1: raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.") @@ -86,16 +92,15 @@ def wrapped_scaled_dot_product( causal_mask = causal_mask.expand(batch_size, -1, -1, -1) attention_mask = causal_mask + attention_mask - # in gpt-neo-x and gpt-j the query and keys are always in fp32 - # thus we need to cast them to the value dtype - if self.downcast_qk: - query = query.to(value.dtype) - key = key.to(value.dtype) - sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + # in gpt-neo-x and gpt-j the query and keys are always in fp32 + # thus we need to cast them to the value dtype + if self.downcast_qk: + sdpa_result = sdpa_result.to(value.dtype) + return sdpa_result, None def forward(self, *args, **kwargs): @@ -201,6 +206,11 @@ def wrapped_scaled_dot_product( if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1: raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.") + # in codegen the query and key are always in fp32 regardless of the dtype of the model + # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226 + query = query.to(value.dtype) + key = key.to(value.dtype) + if batch_size == 1: if query.shape[2] > 1: # first step of the decoding @@ -232,11 +242,6 @@ def wrapped_scaled_dot_product( # we use torch.min to avoid having tensor(-inf) attention_mask = torch.min(causal_mask, attention_mask) - # in codegen the query and key are always in fp32 regardless of the dtype of the model - # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226 - query = query.to(value.dtype) - key = key.to(value.dtype) - sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -381,6 +386,8 @@ def forward( super().forward_checker() raise_on_head_mask(layer_head_mask) + if output_attentions is True: + raise 
ValueError("output_attentions=True can not be supported with BetterTransformer.") if len(self.orig_layer.pruned_heads) > 0: raise ValueError( f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.orig_layer.pruned_heads}." @@ -451,49 +458,45 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): past_key_value[1] if past_key_value is not None else None, ) - if (position_bias is None and not self.orig_layer.has_relative_attention_bias) or ( - position_bias is not None and position_bias[0, 0, 0, 0] == 0 - ): - if position_bias is None and not self.orig_layer.has_relative_attention_bias: + query_states = self.scale * query_states + if position_bias is None and not self.orig_layer.has_relative_attention_bias: + if mask is None: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False + ) + elif mask is not None: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + + if position_bias is None: + if not self.orig_layer.has_relative_attention_bias: position_bias = torch.zeros( (1, self.orig_layer.n_heads, real_seq_length, key_length), - device=query_states.device, - dtype=query_states.dtype, + device=value_states.device, + dtype=value_states.dtype, ) if self.orig_layer.gradient_checkpointing and self.orig_layer.training: position_bias.requires_grad = True + else: + position_bias = self.orig_layer.compute_bias(real_seq_length, key_length, device=value_states.device) - query_states = self.scale * query_states - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - else: - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - position_bias = self.orig_layer.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - - if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - scores += position_bias - attn_weights = torch.nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = torch.nn.functional.dropout( - attn_weights, p=self.orig_layer.dropout, training=self.orig_layer.training - ) # (batch_size, n_heads, seq_length, key_length) + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) - attn_output = torch.matmul(attn_weights, value_states) + if self.orig_layer.has_relative_attention_bias: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False + ) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False 
+ ) attn_output = unshape(attn_output) # (batch_size, seq_length, dim) attn_output = self.orig_layer.o(attn_output) @@ -501,8 +504,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): present_key_value_state = (key_states, value_states) if (self.orig_layer.is_decoder and use_cache) else None outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - if output_attentions: - outputs = outputs + (attn_weights,) return outputs diff --git a/tests/benchmark/benchmark_bettertransformer.py b/tests/benchmark/benchmark_bettertransformer.py index 9a31fc58fa..3dfb375b30 100644 --- a/tests/benchmark/benchmark_bettertransformer.py +++ b/tests/benchmark/benchmark_bettertransformer.py @@ -2,9 +2,10 @@ import torch from tqdm import tqdm -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig from optimum.bettertransformer import BetterTransformer +from optimum.exporters import TasksManager def get_parser(): @@ -104,57 +105,55 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s def timing_cuda(model, num_batches, input_ids, masks, is_decoder, generation_config=None): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) + + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.empty_cache() + torch.cuda.synchronize() + start_event.record() - for _ in range(num_batches): + for _ in tqdm(range(num_batches)): if is_decoder: _ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config) else: _ = model(input_ids, masks) end_event.record() torch.cuda.synchronize() - return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches + max_memory = torch.cuda.max_memory_allocated(device) + return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches, max_memory -def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id): + +def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id): # Warmup if is_decoder: - min_length = max(max_token - 20, 5) - gen_config = GenerationConfig( - do_greedy=True, max_new_tokens=max_token, - min_length=min_length, + min_new_tokens=max_token, use_cache=True, pad_token_id=pad_token_id, ) - _ = hf_model.generate(input_ids, attention_mask=masks, generation_config=gen_config) - torch.cuda.synchronize() - bt_model.generate(input_ids, attention_mask=masks, generation_config=gen_config) + _ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config) torch.cuda.synchronize() else: - _ = hf_model(input_ids, masks) - torch.cuda.synchronize() - _ = bt_model(input_ids, masks) + _ = model(input_ids, masks) torch.cuda.synchronize() # benchmark if is_decoder: - total_hf_time = timing_cuda(hf_model, num_batches, input_ids, masks, is_decoder, gen_config) - total_bt_time = timing_cuda(bt_model, num_batches, input_ids, masks, is_decoder, gen_config) + total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks, is_decoder, gen_config) else: - total_hf_time = timing_cuda(hf_model, num_batches, input_ids, masks, is_decoder) - total_bt_time = timing_cuda(bt_model, num_batches, input_ids, masks, is_decoder) + total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks, is_decoder) - return total_bt_time, total_hf_time + return total_time, max_mem if __name__ == "__main__": parser = get_parser() 
args = parser.parse_args() - BATCH_SIZES = [8, 16, 64] - SEQ_LEN = [64, 128, 256] + BATCH_SIZES = [2] + SEQ_LEN = [64] if args.is_decoder: PAD_PERCENTAGES = [0] else: @@ -166,53 +165,85 @@ def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - if args.is_decoder: - hf_model = AutoModelForCausalLM.from_pretrained( - args.model_name, torch_dtype=torch.float16 if args.use_half else None, use_cache=True - ).eval() + task = TasksManager.infer_task_from_model(args.model_name) + + if task == "causal-lm": + autoclass = AutoModelForCausalLM + elif task == "seq2seq-lm": + autoclass = AutoModelForSeq2SeqLM else: - hf_model = AutoModel.from_pretrained( - args.model_name, torch_dtype=torch.float16 if args.use_half else None - ).eval() + autoclass = AutoModel if args.use_cuda: - hf_model = hf_model.to(0) + with torch.device("cuda:0"): + hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None) + # in PyTorch we trust :) + hf_model = hf_model.to("cuda:0") + hf_model = hf_model.to(torch.float16) + else: + hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None) + bt_model = BetterTransformer.transform(hf_model, keep_original_model=True) output_file = open("log_{}.csv".format(args.model_name.replace("/", "-")), "w") output_file.write( - "num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup\n" + "num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup, Mem eager (MB), Mem BT (MB), Mem saved\n" ) for bs in tqdm(BATCH_SIZES): for seq_len in tqdm(SEQ_LEN): for pad_perc in tqdm(PAD_PERCENTAGES): + print(f"-- Running: bs={bs}, seq_len={seq_len}") # current_std = int(seq_len*pad_perc) # max_seqlen = seq_len + current_std max_seqlen = seq_len mean_seqlen = int((1 - pad_perc) * max_seqlen) - input_ids, _, masks = get_batch(bs, mean_seqlen, max_seqlen, args.seqlen_stdev) + input_ids, _, masks = get_batch( + bs, mean_seqlen, max_seqlen, args.seqlen_stdev, vocab_size=hf_model.config.vocab_size + ) if args.use_cuda: input_ids = input_ids.to(device) masks = masks.to(device) - if not args.use_mask: + + if args.use_mask is False and bs == 1: masks = None - total_bt_time, total_hf_time = benchmark( - hf_model, - bt_model, - input_ids, - masks, - args.num_batches, - args.is_decoder, - args.max_token, - tokenizer.pad_token_id, - ) + with torch.inference_mode(): + total_hf_time, max_mem_eager = benchmark( + hf_model, + input_ids, + masks, + args.num_batches, + args.is_decoder, + args.max_token, + tokenizer.pad_token_id, + ) + + # raise error if no optimized kernel is available + with torch.backends.cuda.sdp_kernel( + enable_flash=True, enable_math=True, enable_mem_efficient=True + ): + total_bt_time, max_mem_bt = benchmark( + bt_model, + input_ids, + masks, + args.num_batches, + args.is_decoder, + args.max_token, + tokenizer.pad_token_id, + ) speedup = total_hf_time / total_bt_time + mem_saved = max_mem_eager / max_mem_bt + + max_mem_eager = max_mem_eager * 1e-6 + max_mem_bt = max_mem_bt * 1e-6 + + print(f"PT eager: {total_hf_time:.3f} s, peak {max_mem_eager:.2f} MB") + print(f"PT native: {total_bt_time:.3f} s, peak {max_mem_bt:.2f} MB") output_file.write( - "{},{},{},{},{},{},{},{},{},{}\n".format( + "{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format( args.num_batches, args.use_cuda, bs, @@ 
-220,9 +251,12 @@ def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max args.use_half, args.use_mask, pad_perc, - total_hf_time, - total_bt_time, - speedup, + f"{total_hf_time:.3f}", + f"{total_bt_time:.3f}", + f"{speedup:.3f}", + f"{max_mem_eager:.3f}", + f"{max_mem_bt:.3f}", + f"{mem_saved:.3f}", ) ) output_file.close() diff --git a/tests/benchmark/benchmark_bettertransformer_training_minimal.py b/tests/benchmark/benchmark_bettertransformer_training_minimal.py index 768d713ba3..c8abdca165 100644 --- a/tests/benchmark/benchmark_bettertransformer_training_minimal.py +++ b/tests/benchmark/benchmark_bettertransformer_training_minimal.py @@ -51,8 +51,8 @@ def seed_init_fn(x): return -def benchmark_training(model, inputs: Dict, num_epochs: int): - num_training_steps = num_epochs * 1000 +def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int): + num_training_steps = 1024 // batch_size progress_bar = tqdm(range(num_training_steps)) model.train() @@ -64,6 +64,11 @@ def benchmark_training(model, inputs: Dict, num_epochs: int): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) + + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.empty_cache() + + torch.cuda.synchronize() start_event.record() for _ in range(num_epochs): for _ in range(num_training_steps): @@ -76,7 +81,9 @@ def benchmark_training(model, inputs: Dict, num_epochs: int): end_event.record() torch.cuda.synchronize() - return (start_event.elapsed_time(end_event) * 1.0e-3) / num_epochs + max_memory = torch.cuda.max_memory_allocated(device) + + return (start_event.elapsed_time(end_event) * 1.0e-3) / num_epochs, max_memory if __name__ == "__main__": @@ -89,12 +96,14 @@ def benchmark_training(model, inputs: Dict, num_epochs: int): dtype = torch.float32 if args.use_half is False else torch.float16 hf_model = hf_model.to(device=device, dtype=dtype) - BATCH_SIZES = [8, 16, 32, 64] - SEQ_LEN = [32, 64, 128, 256] + BATCH_SIZES = [8] + SEQ_LEN = [1024] device = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu") - output_file = open("log_{}.csv".format(args.model_name.replace("/", "-")), "w") - output_file.write("num_epochs, batch_size, seq_len, is cuda, HF time / epoch (s), BT time / epoch (s), Speedup\n") + output_file = open("log_{}_train.csv".format(args.model_name.replace("/", "-")), "w") + output_file.write( + "num_epochs, batch_size, seq_len, is cuda, HF time / epoch (s), BT time / epoch (s), Speedup, Eager peak mem (MB), BT peak mem (MB), Mem saving\n" + ) num_epochs = args.num_epochs for batch_size in BATCH_SIZES: @@ -109,22 +118,25 @@ def benchmark_training(model, inputs: Dict, num_epochs: int): "attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device), } - hf_time_per_epoch = benchmark_training(hf_model, inputs=inputs, num_epochs=num_epochs) - - print(f"Vanilla time / epoch : {hf_time_per_epoch:.3f} s") + hf_time_per_epoch, eager_max_mem = benchmark_training( + hf_model, inputs=inputs, num_epochs=num_epochs, batch_size=batch_size + ) bt_model = BetterTransformer.transform(hf_model, keep_original_model=True) - bt_model = bt_model.to(device=device, dtype=dtype) - bt_time_per_epoch = benchmark_training( - bt_model, - inputs=inputs, - num_epochs=num_epochs, - ) + # raise error if no optimized kernel is available + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): + bt_time_per_epoch, bt_max_mem = benchmark_training( + bt_model, 
inputs=inputs, num_epochs=num_epochs, batch_size=batch_size + ) + + eager_max_mem = eager_max_mem * 1e-6 + bt_max_mem = bt_max_mem * 1e-6 - print(f"BT time / epoch : {bt_time_per_epoch:.3f} s") + print(f"PT eager: {hf_time_per_epoch:.3f} s, peak {eager_max_mem:.2f} MB") + print(f"PT native: {bt_time_per_epoch:.3f} s, peak {bt_max_mem:.2f} MB") speedup = hf_time_per_epoch / bt_time_per_epoch - print(f"Speedup: {speedup:.3f}x") + mem_saved = eager_max_mem / bt_max_mem output_file.write( "{},{},{},{},{},{},{}\n".format( @@ -135,6 +147,9 @@ def benchmark_training(model, inputs: Dict, num_epochs: int): f"{hf_time_per_epoch:.3f}", f"{bt_time_per_epoch:.3f}", f"{speedup:.3f}", + f"{eager_max_mem:.3f}", + f"{bt_max_mem:.3f}", + f"{mem_saved:.3f}", ) ) output_file.close() diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py index efe98862eb..1877e33bf4 100644 --- a/tests/bettertransformer/test_decoder.py +++ b/tests/bettertransformer/test_decoder.py @@ -68,18 +68,23 @@ def test_logits_without_cache(self, test_name: str, model_type: str, padding, ba { "model_type": SUPPORTED_ARCH, "use_to_operator": [True, False], + "batch_size": [1, 2], } ) ) @pytest.mark.fp16 @require_torch_gpu @pytest.mark.gpu_test - def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: bool): + def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: bool, batch_size: int): self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] self._test_fp16_inference( - model_id, model_type=model_type, use_to_operator=use_to_operator, automodel_class=AutoModelForCausalLM + model_id, + model_type=model_type, + use_to_operator=use_to_operator, + automodel_class=AutoModelForCausalLM, + batch_size=batch_size, ) @parameterized.expand(
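
Editor's note on the core fix (illustrative commentary, not part of the patch): the first hunks of decoder_models.py address the fact that GPT-J, GPT-NeoX and CodeGen keep the query and key tensors in fp32 even when the model runs in fp16, while the cast to the value dtype previously happened only after the batch-size branching. The batch_size == 1 fast path therefore reached scaled_dot_product_attention with an fp32/fp16 mix, which PyTorch rejects; the patch moves the cast above the branch. The sketch below is a minimal, hypothetical helper (the name sdpa_with_downcast, the simplified mask handling, and the assumed (batch, heads, seq, head_dim) layout are illustration choices, not the actual optimum code) showing the intended ordering:

import torch

def sdpa_with_downcast(query, key, value, attention_mask=None, downcast_qk=True):
    # query, key, value: (batch, num_heads, seq_len, head_dim)
    batch_size, _, query_len, _ = query.shape

    # GPT-J / GPT-NeoX style models keep query and key in fp32 even for fp16
    # checkpoints, so align dtypes *before* any branching; otherwise the
    # batch_size == 1 path below calls SDPA with mismatched dtypes.
    if downcast_qk:
        query = query.to(value.dtype)
        key = key.to(value.dtype)

    # an explicit padding mask on a single sequence (padding='max_length') is
    # unsupported, because the mask is dropped in the fast path below
    if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
        raise ValueError("padding='max_length' with a batch size of 1 is not supported")

    if batch_size == 1:
        # single unpadded sequence: let the fused kernel apply causality itself
        # (prefill needs a causal mask, a one-token decoding step does not)
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=query_len > 1
        )

    # batched case: pass the precomputed additive mask (causal + padding) explicitly
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

Under these assumptions, calling the helper with fp32 query/key and an fp16 value succeeds at batch size 1 only because the cast precedes the branch, which is the ordering the patch establishes in wrapped_scaled_dot_product.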