From d9195cf4a742133cc68f4bfe2fd577ffdf53f568 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Mar 2023 13:17:44 +0100 Subject: [PATCH 1/8] fix bug in spda --- .../models/decoder_models.py | 27 +++++++++++-------- tests/bettertransformer/test_decoder.py | 9 +++++-- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index 5d4713244c..1e359cc0ad 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -58,6 +58,12 @@ def wrapped_scaled_dot_product( mask_value = torch.finfo(value.dtype).min mask_value = torch.full([], mask_value, dtype=value.dtype) + # in gpt-neo-x and gpt-j the query and keys are always in fp32 + # thus we need to cast them to the value dtype + if self.downcast_qk: + query = query.to(value.dtype) + key = key.to(value.dtype) + if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1: raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.") @@ -86,16 +92,15 @@ def wrapped_scaled_dot_product( causal_mask = causal_mask.expand(batch_size, -1, -1, -1) attention_mask = causal_mask + attention_mask - # in gpt-neo-x and gpt-j the query and keys are always in fp32 - # thus we need to cast them to the value dtype - if self.downcast_qk: - query = query.to(value.dtype) - key = key.to(value.dtype) - sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + # in gpt-neo-x and gpt-j the query and keys are always in fp32 + # thus we need to cast them to the value dtype + if self.downcast_qk: + sdpa_result = sdpa_result.to(value.dtype) + return sdpa_result, None def forward(self, *args, **kwargs): @@ -201,6 +206,11 @@ def wrapped_scaled_dot_product( if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1: raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.") + # in codegen the query and key are always in fp32 regardless of the dtype of the model + # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226 + query = query.to(value.dtype) + key = key.to(value.dtype) + if batch_size == 1: if query.shape[2] > 1: # first step of the decoding @@ -232,11 +242,6 @@ def wrapped_scaled_dot_product( # we use torch.min to avoid having tensor(-inf) attention_mask = torch.min(causal_mask, attention_mask) - # in codegen the query and key are always in fp32 regardless of the dtype of the model - # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226 - query = query.to(value.dtype) - key = key.to(value.dtype) - sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py index efe98862eb..1877e33bf4 100644 --- a/tests/bettertransformer/test_decoder.py +++ b/tests/bettertransformer/test_decoder.py @@ -68,18 +68,23 @@ def test_logits_without_cache(self, test_name: str, model_type: str, padding, ba { "model_type": SUPPORTED_ARCH, "use_to_operator": [True, False], + "batch_size": 
[1, 2], } ) ) @pytest.mark.fp16 @require_torch_gpu @pytest.mark.gpu_test - def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: bool): + def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: bool, batch_size: int): self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] self._test_fp16_inference( - model_id, model_type=model_type, use_to_operator=use_to_operator, automodel_class=AutoModelForCausalLM + model_id, + model_type=model_type, + use_to_operator=use_to_operator, + automodel_class=AutoModelForCausalLM, + batch_size=batch_size, ) @parameterized.expand( From 72d7f13d9f7c64b7bb04e2128d53aad25ae994d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Mar 2023 14:10:39 +0100 Subject: [PATCH 2/8] simplify t5 --- .../models/decoder_models.py | 54 +++++++------------ 1 file changed, 19 insertions(+), 35 deletions(-) diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index 1e359cc0ad..8cfdb7fe42 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -386,6 +386,8 @@ def forward( super().forward_checker() raise_on_head_mask(layer_head_mask) + if output_attentions is True: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") if len(self.orig_layer.pruned_heads) > 0: raise ValueError( f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.orig_layer.pruned_heads}." @@ -456,31 +458,20 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): past_key_value[1] if past_key_value is not None else None, ) - if (position_bias is None and not self.orig_layer.has_relative_attention_bias) or ( - position_bias is not None and position_bias[0, 0, 0, 0] == 0 - ): - if position_bias is None and not self.orig_layer.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.orig_layer.n_heads, real_seq_length, key_length), - device=query_states.device, - dtype=query_states.dtype, - ) - if self.orig_layer.gradient_checkpointing and self.orig_layer.training: - position_bias.requires_grad = True - - query_states = self.scale * query_states - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - else: - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - + if True: if position_bias is None: - position_bias = self.orig_layer.compute_bias(real_seq_length, key_length, device=scores.device) + if not self.orig_layer.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.orig_layer.n_heads, real_seq_length, key_length), + device=value_states.device, + dtype=value_states.dtype, + ) + if self.orig_layer.gradient_checkpointing and self.orig_layer.training: + position_bias.requires_grad = True + else: + position_bias = self.orig_layer.compute_bias( + real_seq_length, key_length, device=value_states.device + ) # if key and values are already calculated # we want only the last query position bias @@ -490,15 +481,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if mask is not None: position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) - scores += 
position_bias - attn_weights = torch.nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = torch.nn.functional.dropout( - attn_weights, p=self.orig_layer.dropout, training=self.orig_layer.training - ) # (batch_size, n_heads, seq_length, key_length) - - attn_output = torch.matmul(attn_weights, value_states) + # query_states = self.scale * query_states + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False + ) attn_output = unshape(attn_output) # (batch_size, seq_length, dim) attn_output = self.orig_layer.o(attn_output) @@ -506,8 +492,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): present_key_value_state = (key_states, value_states) if (self.orig_layer.is_decoder and use_cache) else None outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - if output_attentions: - outputs = outputs + (attn_weights,) return outputs From 280b485e8c65c2a9c4cf24891ae1c7277a41b3f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Mar 2023 14:36:31 +0100 Subject: [PATCH 3/8] simplify --- .../models/decoder_models.py | 61 ++++++++++++------- tests/bettertransformer/testing_utils.py | 3 +- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index 8cfdb7fe42..7e2dd0e375 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -458,33 +458,50 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): past_key_value[1] if past_key_value is not None else None, ) - if True: - if position_bias is None: - if not self.orig_layer.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.orig_layer.n_heads, real_seq_length, key_length), - device=value_states.device, - dtype=value_states.dtype, - ) - if self.orig_layer.gradient_checkpointing and self.orig_layer.training: - position_bias.requires_grad = True - else: - position_bias = self.orig_layer.compute_bias( - real_seq_length, key_length, device=value_states.device - ) + if position_bias is None: + if not self.orig_layer.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.orig_layer.n_heads, real_seq_length, key_length), + device=value_states.device, + dtype=value_states.dtype, + ) + if self.orig_layer.gradient_checkpointing and self.orig_layer.training: + position_bias.requires_grad = True + else: + position_bias = self.orig_layer.compute_bias(real_seq_length, key_length, device=value_states.device) - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) - # query_states = self.scale * query_states + query_states = self.scale * query_states + if not self.orig_layer.has_relative_attention_bias: 
attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False + query_states, + key_states, + value_states, + attn_mask=mask if mask is not None else None, + dropout_p=0.0, + is_causal=False, ) + else: + if mask is not None: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False + ) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias + mask, + dropout_p=0.0, + is_causal=False, + ) attn_output = unshape(attn_output) # (batch_size, seq_length, dim) attn_output = self.orig_layer.o(attn_output) diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py index 34772b2a4b..87d48a470d 100644 --- a/tests/bettertransformer/testing_utils.py +++ b/tests/bettertransformer/testing_utils.py @@ -60,7 +60,8 @@ "roformer": "hf-internal-testing/tiny-random-RoFormerModel", "splinter": "hf-internal-testing/tiny-random-SplinterModel", "tapas": "hf-internal-testing/tiny-random-TapasModel", - "t5": "hf-internal-testing/tiny-random-t5", + # "t5": "hf-internal-testing/tiny-random-t5", + "t5": "t5-base", "vilt": "hf-internal-testing/tiny-vilt-random-vqa", "vit": "hf-internal-testing/tiny-random-ViTModel", "vit_mae": "hf-internal-testing/tiny-random-ViTMAEModel", From 3db897e86f45d82f1b623841a4cca4e7e57cde55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Mar 2023 15:14:37 +0100 Subject: [PATCH 4/8] fix --- .../models/decoder_models.py | 71 +++++++++---------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index 5d4713244c..b6b0420324 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -381,6 +381,8 @@ def forward( super().forward_checker() raise_on_head_mask(layer_head_mask) + if output_attentions is True: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") if len(self.orig_layer.pruned_heads) > 0: raise ValueError( f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.orig_layer.pruned_heads}." 
@@ -451,58 +453,53 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): past_key_value[1] if past_key_value is not None else None, ) - if (position_bias is None and not self.orig_layer.has_relative_attention_bias) or ( - position_bias is not None and position_bias[0, 0, 0, 0] == 0 - ): - if position_bias is None and not self.orig_layer.has_relative_attention_bias: + query_states = self.scale * query_states + if position_bias is None and not self.orig_layer.has_relative_attention_bias: + if mask is None: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False + ) + elif mask is not None: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + + if position_bias is None: + if not self.orig_layer.has_relative_attention_bias: position_bias = torch.zeros( (1, self.orig_layer.n_heads, real_seq_length, key_length), - device=query_states.device, - dtype=query_states.dtype, + device=value_states.device, + dtype=value_states.dtype, ) if self.orig_layer.gradient_checkpointing and self.orig_layer.training: position_bias.requires_grad = True + else: + position_bias = self.orig_layer.compute_bias(real_seq_length, key_length, device=value_states.device) - query_states = self.scale * query_states - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - else: - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - position_bias = self.orig_layer.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - - if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - scores += position_bias - attn_weights = torch.nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = torch.nn.functional.dropout( - attn_weights, p=self.orig_layer.dropout, training=self.orig_layer.training - ) # (batch_size, n_heads, seq_length, key_length) + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) - attn_output = torch.matmul(attn_weights, value_states) + if self.orig_layer.has_relative_attention_bias: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False + ) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False + ) attn_output = unshape(attn_output) # (batch_size, seq_length, dim) + attn_output = self.orig_layer.o(attn_output) present_key_value_state = (key_states, value_states) if (self.orig_layer.is_decoder and use_cache) else None outputs = 
(attn_output,) + (present_key_value_state,) + (position_bias,) - if output_attentions: - outputs = outputs + (attn_weights,) return outputs From 4fa0dc88d04f22451adb1ae9f1a35ad85ca5cc62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Mar 2023 16:21:44 +0100 Subject: [PATCH 5/8] typo --- tests/bettertransformer/testing_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py index 87d48a470d..34772b2a4b 100644 --- a/tests/bettertransformer/testing_utils.py +++ b/tests/bettertransformer/testing_utils.py @@ -60,8 +60,7 @@ "roformer": "hf-internal-testing/tiny-random-RoFormerModel", "splinter": "hf-internal-testing/tiny-random-SplinterModel", "tapas": "hf-internal-testing/tiny-random-TapasModel", - # "t5": "hf-internal-testing/tiny-random-t5", - "t5": "t5-base", + "t5": "hf-internal-testing/tiny-random-t5", "vilt": "hf-internal-testing/tiny-vilt-random-vqa", "vit": "hf-internal-testing/tiny-random-ViTModel", "vit_mae": "hf-internal-testing/tiny-random-ViTMAEModel", From c2748247ace4fe93f6da0269fa9313a46f3dbca1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 23 Mar 2023 16:31:08 +0100 Subject: [PATCH 6/8] update bench --- .../benchmark/benchmark_bettertransformer.py | 119 +++++++++++------- ...mark_bettertransformer_training_minimal.py | 50 +++++--- 2 files changed, 107 insertions(+), 62 deletions(-) diff --git a/tests/benchmark/benchmark_bettertransformer.py b/tests/benchmark/benchmark_bettertransformer.py index 9a31fc58fa..bd48bae55c 100644 --- a/tests/benchmark/benchmark_bettertransformer.py +++ b/tests/benchmark/benchmark_bettertransformer.py @@ -2,7 +2,7 @@ import torch from tqdm import tqdm -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, AutoModelForSeq2SeqLM from optimum.bettertransformer import BetterTransformer @@ -104,57 +104,54 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s def timing_cuda(model, num_batches, input_ids, masks, is_decoder, generation_config=None): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) + + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.empty_cache() + torch.cuda.synchronize() + start_event.record() - for _ in range(num_batches): + for _ in tqdm(range(num_batches)): if is_decoder: _ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config) else: _ = model(input_ids, masks) end_event.record() torch.cuda.synchronize() - return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches + max_memory = torch.cuda.max_memory_allocated(device) + return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches, max_memory -def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id): +def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id): # Warmup if is_decoder: - min_length = max(max_token - 20, 5) - gen_config = GenerationConfig( - do_greedy=True, max_new_tokens=max_token, - min_length=min_length, + min_new_tokens=max_token, use_cache=True, pad_token_id=pad_token_id, ) - _ = hf_model.generate(input_ids, attention_mask=masks, generation_config=gen_config) - torch.cuda.synchronize() - bt_model.generate(input_ids, attention_mask=masks, 
generation_config=gen_config) + _ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config) torch.cuda.synchronize() else: - _ = hf_model(input_ids, masks) - torch.cuda.synchronize() - _ = bt_model(input_ids, masks) + _ = model(input_ids, masks) torch.cuda.synchronize() # benchmark if is_decoder: - total_hf_time = timing_cuda(hf_model, num_batches, input_ids, masks, is_decoder, gen_config) - total_bt_time = timing_cuda(bt_model, num_batches, input_ids, masks, is_decoder, gen_config) + total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks, is_decoder, gen_config) else: - total_hf_time = timing_cuda(hf_model, num_batches, input_ids, masks, is_decoder) - total_bt_time = timing_cuda(bt_model, num_batches, input_ids, masks, is_decoder) + total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks, is_decoder) - return total_bt_time, total_hf_time + return total_time, max_mem if __name__ == "__main__": parser = get_parser() args = parser.parse_args() - BATCH_SIZES = [8, 16, 64] - SEQ_LEN = [64, 128, 256] + BATCH_SIZES = [8] + SEQ_LEN = [64] if args.is_decoder: PAD_PERCENTAGES = [0] else: @@ -167,52 +164,81 @@ def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max tokenizer.pad_token = tokenizer.eos_token if args.is_decoder: - hf_model = AutoModelForCausalLM.from_pretrained( - args.model_name, torch_dtype=torch.float16 if args.use_half else None, use_cache=True - ).eval() + # autoclass = AutoModelForCausalLM + autoclass = AutoModelForSeq2SeqLM else: - hf_model = AutoModel.from_pretrained( - args.model_name, torch_dtype=torch.float16 if args.use_half else None - ).eval() + autoclass = AutoModel if args.use_cuda: - hf_model = hf_model.to(0) + with torch.device("cuda:0"): + hf_model = autoclass.from_pretrained( + args.model_name, torch_dtype=torch.float16 if args.use_half else None + ) + # in PyTorch we trust :) + hf_model = hf_model.to("cuda:0") + hf_model = hf_model.to(torch.float16) + else: + hf_model = autoclass.from_pretrained( + args.model_name, torch_dtype=torch.float16 if args.use_half else None + ) + bt_model = BetterTransformer.transform(hf_model, keep_original_model=True) output_file = open("log_{}.csv".format(args.model_name.replace("/", "-")), "w") output_file.write( - "num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup\n" + "num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup, Mem eager (MB), Mem BT (MB), Mem saved\n" ) for bs in tqdm(BATCH_SIZES): for seq_len in tqdm(SEQ_LEN): for pad_perc in tqdm(PAD_PERCENTAGES): + print(f"-- Running: bs={bs}, seq_len={seq_len}") # current_std = int(seq_len*pad_perc) # max_seqlen = seq_len + current_std max_seqlen = seq_len mean_seqlen = int((1 - pad_perc) * max_seqlen) - input_ids, _, masks = get_batch(bs, mean_seqlen, max_seqlen, args.seqlen_stdev) + input_ids, _, masks = get_batch(bs, mean_seqlen, max_seqlen, args.seqlen_stdev, vocab_size=hf_model.config.vocab_size) if args.use_cuda: input_ids = input_ids.to(device) masks = masks.to(device) - if not args.use_mask: + + if args.use_mask is False and bs == 1: masks = None - total_bt_time, total_hf_time = benchmark( - hf_model, - bt_model, - input_ids, - masks, - args.num_batches, - args.is_decoder, - args.max_token, - tokenizer.pad_token_id, - ) + with torch.inference_mode(): + total_hf_time, max_mem_eager = benchmark( + hf_model, + input_ids, + masks, + args.num_batches, + args.is_decoder, + args.max_token, + 
tokenizer.pad_token_id, + ) + + # raise error if no optimized kernel is available + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True): + total_bt_time, max_mem_bt = benchmark( + bt_model, + input_ids, + masks, + args.num_batches, + args.is_decoder, + args.max_token, + tokenizer.pad_token_id, + ) speedup = total_hf_time / total_bt_time + mem_saved = max_mem_eager / max_mem_bt + + max_mem_eager = max_mem_eager * 1e-6 + max_mem_bt = max_mem_bt * 1e-6 + + print(f"PT eager: {total_hf_time:.3f} s, peak {max_mem_eager:.2f} MB") + print(f"PT native: {total_bt_time:.3f} s, peak {max_mem_bt:.2f} MB") output_file.write( - "{},{},{},{},{},{},{},{},{},{}\n".format( + "{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format( args.num_batches, args.use_cuda, bs, @@ -220,9 +246,12 @@ def benchmark(hf_model, bt_model, input_ids, masks, num_batches, is_decoder, max args.use_half, args.use_mask, pad_perc, - total_hf_time, - total_bt_time, - speedup, + f"{total_hf_time:.3f}", + f"{total_bt_time:.3f}", + f"{speedup:.3f}", + f"{max_mem_eager:.3f}", + f"{max_mem_bt:.3f}", + f"{mem_saved:.3f}" ) ) output_file.close() diff --git a/tests/benchmark/benchmark_bettertransformer_training_minimal.py b/tests/benchmark/benchmark_bettertransformer_training_minimal.py index 768d713ba3..d64ef340fa 100644 --- a/tests/benchmark/benchmark_bettertransformer_training_minimal.py +++ b/tests/benchmark/benchmark_bettertransformer_training_minimal.py @@ -51,8 +51,8 @@ def seed_init_fn(x): return -def benchmark_training(model, inputs: Dict, num_epochs: int): - num_training_steps = num_epochs * 1000 +def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int): + num_training_steps = 1024 // batch_size progress_bar = tqdm(range(num_training_steps)) model.train() @@ -64,6 +64,11 @@ def benchmark_training(model, inputs: Dict, num_epochs: int): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) + + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.empty_cache() + + torch.cuda.synchronize() start_event.record() for _ in range(num_epochs): for _ in range(num_training_steps): @@ -76,7 +81,9 @@ def benchmark_training(model, inputs: Dict, num_epochs: int): end_event.record() torch.cuda.synchronize() - return (start_event.elapsed_time(end_event) * 1.0e-3) / num_epochs + max_memory = torch.cuda.max_memory_allocated(device) + + return (start_event.elapsed_time(end_event) * 1.0e-3) / num_epochs, max_memory if __name__ == "__main__": @@ -89,12 +96,12 @@ def benchmark_training(model, inputs: Dict, num_epochs: int): dtype = torch.float32 if args.use_half is False else torch.float16 hf_model = hf_model.to(device=device, dtype=dtype) - BATCH_SIZES = [8, 16, 32, 64] - SEQ_LEN = [32, 64, 128, 256] + BATCH_SIZES = [64] + SEQ_LEN = [1024] device = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu") - output_file = open("log_{}.csv".format(args.model_name.replace("/", "-")), "w") - output_file.write("num_epochs, batch_size, seq_len, is cuda, HF time / epoch (s), BT time / epoch (s), Speedup\n") + output_file = open("log_{}_train.csv".format(args.model_name.replace("/", "-")), "w") + output_file.write("num_epochs, batch_size, seq_len, is cuda, HF time / epoch (s), BT time / epoch (s), Speedup, Eager peak mem (MB), BT peak mem (MB), Mem saving\n") num_epochs = args.num_epochs for batch_size in BATCH_SIZES: @@ -109,22 +116,28 @@ def benchmark_training(model, inputs: Dict, num_epochs: int): "attention_mask": 
torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device), } - hf_time_per_epoch = benchmark_training(hf_model, inputs=inputs, num_epochs=num_epochs) - - print(f"Vanilla time / epoch : {hf_time_per_epoch:.3f} s") + hf_time_per_epoch, eager_max_mem = benchmark_training(hf_model, inputs=inputs, num_epochs=num_epochs, batch_size=batch_size) bt_model = BetterTransformer.transform(hf_model, keep_original_model=True) bt_model = bt_model.to(device=device, dtype=dtype) - bt_time_per_epoch = benchmark_training( - bt_model, - inputs=inputs, - num_epochs=num_epochs, - ) - print(f"BT time / epoch : {bt_time_per_epoch:.3f} s") + # raise error if no optimized kernel is available + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True): + bt_time_per_epoch, bt_max_mem = benchmark_training( + bt_model, + inputs=inputs, + num_epochs=num_epochs, + batch_size=batch_size + ) + + eager_max_mem = eager_max_mem * 1e-6 + bt_max_mem = bt_max_mem * 1e-6 + + print(f"PT eager: {hf_time_per_epoch:.3f} s, peak {eager_max_mem:.2f} MB") + print(f"PT native: {bt_time_per_epoch:.3f} s, peak {bt_max_mem:.2f} MB") speedup = hf_time_per_epoch / bt_time_per_epoch - print(f"Speedup: {speedup:.3f}x") + mem_saved = eager_max_mem / bt_max_mem output_file.write( "{},{},{},{},{},{},{}\n".format( @@ -135,6 +148,9 @@ def benchmark_training(model, inputs: Dict, num_epochs: int): f"{hf_time_per_epoch:.3f}", f"{bt_time_per_epoch:.3f}", f"{speedup:.3f}", + f"{eager_max_mem:.3f}", + f"{bt_max_mem:.3f}", + f"{mem_saved:.3f}" ) ) output_file.close() From 453f18c3d9dbdd3a785aae41cc80c4d7963d11f0 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 23 Mar 2023 17:25:09 +0100 Subject: [PATCH 7/8] update bench --- tests/benchmark/benchmark_bettertransformer.py | 12 ++++++++---- .../benchmark_bettertransformer_training_minimal.py | 8 +++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/benchmark/benchmark_bettertransformer.py b/tests/benchmark/benchmark_bettertransformer.py index bd48bae55c..6d29afbe6b 100644 --- a/tests/benchmark/benchmark_bettertransformer.py +++ b/tests/benchmark/benchmark_bettertransformer.py @@ -6,6 +6,7 @@ from optimum.bettertransformer import BetterTransformer +from optimum.exporters import TasksManager def get_parser(): parser = argparse.ArgumentParser() @@ -150,7 +151,7 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t parser = get_parser() args = parser.parse_args() - BATCH_SIZES = [8] + BATCH_SIZES = [2] SEQ_LEN = [64] if args.is_decoder: PAD_PERCENTAGES = [0] @@ -163,8 +164,11 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - if args.is_decoder: - # autoclass = AutoModelForCausalLM + task = TasksManager.infer_task_from_model(args.model_name) + + if task == "causal-lm": + autoclass = AutoModelForCausalLM + elif task == "seq2seq-lm": autoclass = AutoModelForSeq2SeqLM else: autoclass = AutoModel @@ -217,7 +221,7 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t ) # raise error if no optimized kernel is available - with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): total_bt_time, max_mem_bt = benchmark( bt_model, input_ids, diff --git 
a/tests/benchmark/benchmark_bettertransformer_training_minimal.py b/tests/benchmark/benchmark_bettertransformer_training_minimal.py index d64ef340fa..46bafe7a0a 100644 --- a/tests/benchmark/benchmark_bettertransformer_training_minimal.py +++ b/tests/benchmark/benchmark_bettertransformer_training_minimal.py @@ -5,7 +5,7 @@ import numpy as np import torch from tqdm.auto import tqdm -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM from optimum.bettertransformer import BetterTransformer @@ -96,7 +96,7 @@ def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int): dtype = torch.float32 if args.use_half is False else torch.float16 hf_model = hf_model.to(device=device, dtype=dtype) - BATCH_SIZES = [64] + BATCH_SIZES = [8] SEQ_LEN = [1024] device = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu") @@ -119,11 +119,9 @@ def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int): hf_time_per_epoch, eager_max_mem = benchmark_training(hf_model, inputs=inputs, num_epochs=num_epochs, batch_size=batch_size) bt_model = BetterTransformer.transform(hf_model, keep_original_model=True) - bt_model = bt_model.to(device=device, dtype=dtype) - # raise error if no optimized kernel is available - with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): bt_time_per_epoch, bt_max_mem = benchmark_training( bt_model, inputs=inputs, From 3acb3da7e2ab7fb1b43ee939e219f86546c95a0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 23 Mar 2023 17:25:50 +0100 Subject: [PATCH 8/8] style --- .../benchmark/benchmark_bettertransformer.py | 23 ++++++++++--------- ...mark_bettertransformer_training_minimal.py | 17 +++++++------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/tests/benchmark/benchmark_bettertransformer.py b/tests/benchmark/benchmark_bettertransformer.py index 6d29afbe6b..3dfb375b30 100644 --- a/tests/benchmark/benchmark_bettertransformer.py +++ b/tests/benchmark/benchmark_bettertransformer.py @@ -2,12 +2,12 @@ import torch from tqdm import tqdm -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, AutoModelForSeq2SeqLM +from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig from optimum.bettertransformer import BetterTransformer - from optimum.exporters import TasksManager + def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( @@ -122,6 +122,7 @@ def timing_cuda(model, num_batches, input_ids, masks, is_decoder, generation_con return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches, max_memory + def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id): # Warmup if is_decoder: @@ -175,16 +176,12 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t if args.use_cuda: with torch.device("cuda:0"): - hf_model = autoclass.from_pretrained( - args.model_name, torch_dtype=torch.float16 if args.use_half else None - ) + hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None) # in PyTorch we trust :) hf_model = hf_model.to("cuda:0") hf_model = hf_model.to(torch.float16) else: - hf_model = autoclass.from_pretrained( - 
args.model_name, torch_dtype=torch.float16 if args.use_half else None - ) + hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None) bt_model = BetterTransformer.transform(hf_model, keep_original_model=True) @@ -200,7 +197,9 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t # max_seqlen = seq_len + current_std max_seqlen = seq_len mean_seqlen = int((1 - pad_perc) * max_seqlen) - input_ids, _, masks = get_batch(bs, mean_seqlen, max_seqlen, args.seqlen_stdev, vocab_size=hf_model.config.vocab_size) + input_ids, _, masks = get_batch( + bs, mean_seqlen, max_seqlen, args.seqlen_stdev, vocab_size=hf_model.config.vocab_size + ) if args.use_cuda: input_ids = input_ids.to(device) @@ -221,7 +220,9 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t ) # raise error if no optimized kernel is available - with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): + with torch.backends.cuda.sdp_kernel( + enable_flash=True, enable_math=True, enable_mem_efficient=True + ): total_bt_time, max_mem_bt = benchmark( bt_model, input_ids, @@ -255,7 +256,7 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t f"{speedup:.3f}", f"{max_mem_eager:.3f}", f"{max_mem_bt:.3f}", - f"{mem_saved:.3f}" + f"{mem_saved:.3f}", ) ) output_file.close() diff --git a/tests/benchmark/benchmark_bettertransformer_training_minimal.py b/tests/benchmark/benchmark_bettertransformer_training_minimal.py index 46bafe7a0a..c8abdca165 100644 --- a/tests/benchmark/benchmark_bettertransformer_training_minimal.py +++ b/tests/benchmark/benchmark_bettertransformer_training_minimal.py @@ -5,7 +5,7 @@ import numpy as np import torch from tqdm.auto import tqdm -from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM +from transformers import AutoModelForCausalLM from optimum.bettertransformer import BetterTransformer @@ -101,7 +101,9 @@ def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int): device = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu") output_file = open("log_{}_train.csv".format(args.model_name.replace("/", "-")), "w") - output_file.write("num_epochs, batch_size, seq_len, is cuda, HF time / epoch (s), BT time / epoch (s), Speedup, Eager peak mem (MB), BT peak mem (MB), Mem saving\n") + output_file.write( + "num_epochs, batch_size, seq_len, is cuda, HF time / epoch (s), BT time / epoch (s), Speedup, Eager peak mem (MB), BT peak mem (MB), Mem saving\n" + ) num_epochs = args.num_epochs for batch_size in BATCH_SIZES: @@ -116,17 +118,16 @@ def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int): "attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device), } - hf_time_per_epoch, eager_max_mem = benchmark_training(hf_model, inputs=inputs, num_epochs=num_epochs, batch_size=batch_size) + hf_time_per_epoch, eager_max_mem = benchmark_training( + hf_model, inputs=inputs, num_epochs=num_epochs, batch_size=batch_size + ) bt_model = BetterTransformer.transform(hf_model, keep_original_model=True) # raise error if no optimized kernel is available with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): bt_time_per_epoch, bt_max_mem = benchmark_training( - bt_model, - inputs=inputs, - num_epochs=num_epochs, - batch_size=batch_size + bt_model, inputs=inputs, num_epochs=num_epochs, batch_size=batch_size ) 
eager_max_mem = eager_max_mem * 1e-6 @@ -148,7 +149,7 @@ def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int): f"{speedup:.3f}", f"{eager_max_mem:.3f}", f"{bt_max_mem:.3f}", - f"{mem_saved:.3f}" + f"{mem_saved:.3f}", ) ) output_file.close()
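
For reference, a minimal standalone sketch (independent of the optimum code above; tensor names and shapes are illustrative) of the pattern the T5 patches converge on: folding the additive relative-position bias into the attn_mask argument of torch.nn.functional.scaled_dot_product_attention, and pre-scaling the query by sqrt(head_dim) to cancel the 1/sqrt(head_dim) scaling that SDPA applies internally, since T5 computes unscaled attention scores.

    import math
    import torch

    batch, n_heads, q_len, kv_len, head_dim = 2, 8, 5, 5, 64
    query = torch.randn(batch, n_heads, q_len, head_dim)
    key = torch.randn(batch, n_heads, kv_len, head_dim)
    value = torch.randn(batch, n_heads, kv_len, head_dim)
    # additive bias, e.g. a T5-style relative-position bias (a padding mask, if any,
    # would already have been added into it)
    position_bias = torch.randn(1, n_heads, q_len, kv_len)

    # reference path: T5 does NOT scale the scores by 1/sqrt(head_dim)
    scores = query @ key.transpose(-1, -2) + position_bias
    reference = torch.softmax(scores, dim=-1) @ value

    # SDPA path: pre-scale the query to cancel SDPA's internal 1/sqrt(head_dim),
    # and pass the bias (broadcastable to batch x heads x q_len x kv_len) as attn_mask
    sdpa_out = torch.nn.functional.scaled_dot_product_attention(
        query * math.sqrt(head_dim), key, value,
        attn_mask=position_bias, dropout_p=0.0, is_causal=False,
    )
    torch.testing.assert_close(sdpa_out, reference, atol=1e-4, rtol=1e-4)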