diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 2a6a045fe4..46551f4bd2 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -49,7 +49,7 @@ def gpt2_wrapped_scaled_dot_product(
     if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
         raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")
 
-    if batch_size == 1:
+    if batch_size == 1 or self.training:
         if query.shape[2] > 1:
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
@@ -103,7 +103,7 @@ def gpt_neo_wrapped_scaled_dot_product(
     if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
         raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")
 
-    if batch_size == 1 and self.attention_type == "global":
+    if (batch_size == 1 or self.training) and self.attention_type == "global":
         if query.shape[2] > 1:
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
@@ -153,7 +153,7 @@ def codegen_wrapped_scaled_dot_product(
     query = query.to(value.dtype)
     key = key.to(value.dtype)
 
-    if batch_size == 1:
+    if batch_size == 1 or self.training:
         if query.shape[2] > 1:
             # first step of the decoding
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
@@ -247,7 +247,7 @@ def opt_forward(
     query_states = self._shape(query_states, tgt_len, batch_size)
     query_states = query_states * self.scale
 
-    if batch_size == 1:
+    if batch_size == 1 or self.training:
         if query_states.shape[2] > 1:
             attn_output = torch.nn.functional.scaled_dot_product_attention(
                 query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True
diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py
index 8a3936d1de..c7d694ab4c 100644
--- a/optimum/bettertransformer/transformation.py
+++ b/optimum/bettertransformer/transformation.py
@@ -28,7 +28,7 @@
 logger = logging.getLogger(__name__)
 
 if is_accelerate_available():
-    from accelerate import dispatch_model, infer_auto_device_map
+    from accelerate import dispatch_model
     from accelerate.hooks import remove_hook_from_module
 
 ERROR_MESSAGE = r"The Better Transformers implementation for the model {model_name} has not been implemented yet. Please open an issue requesting the addition of this model with its `BetterTransformer` implementation."
@@ -192,6 +192,7 @@ def transform(
         # Check if we have to load the model using `accelerate`
         if hasattr(model, "hf_device_map"):
             load_accelerate = True
+            hf_device_map = model.hf_device_map
         else:
             load_accelerate = False
 
@@ -224,8 +225,7 @@
         hf_config = model.config
 
         if load_accelerate:
-            # remove the hooks from the original model to
-            # avoid weights being on `meta` device.
+            # Remove the hooks from the original model to avoid weights being on `meta` device.
             remove_hook_from_module(model, recurse=True)
 
         if keep_original_model:
@@ -249,24 +249,27 @@
         if BetterTransformerManager.requires_nested_tensor(model_fast.config.model_type):
             set_last_layer(model_fast)
 
-        # Step 6: Add a class arguments, we might need to identify whether the model
+        # Add a class arguments, we might need to identify whether the model
         # has been correctly converted to its `BetterTransformer` version.
         setattr(model_fast, "use_bettertransformer", True)
 
-        # Step 7: dispatch model if `accelerate` is enabled
         if load_accelerate:
-            device_map_bt = infer_auto_device_map(model_fast, max_memory=max_memory)
-
-            remove_hook_from_module(model_fast, recurse=True)
-
-            model_fast = dispatch_model(model_fast, device_map_bt)
+            model_fast = dispatch_model(model_fast, hf_device_map)
 
+            # It is not recommended to have `keep_original_model=True` with a model
+            # that is loaded with accelerate but just in case
             if keep_original_model:
-                # It is not recommended to have `keep_original_model=True` with a model
-                # that is loaded with accelerate but just in case ..
-                model = dispatch_model(model, model.hf_device_map)
+                model = dispatch_model(model, hf_device_map)
+
+        # See: https://github.com/pytorch/pytorch/issues/96099
+        if BetterTransformerManager.requires_torch_20(model_fast.config.model_type):
+            logging.warning(
+                f"For training, the BetterTransformer implementation for {model_fast.config.model_type} "
+                " architecture currently does not support padding as fused kernels do not support custom"
+                " attention masks. Beware that passing padded batched training data may result in unexpected outputs."
+            )
 
-        # Step 8: overwrite the `save_pretrained` method
+        # Overwrite the `save_pretrained` method
         # by raising an error if the user tries to save the model
         # or push it to the hub.
         model_fast._old_save_pretrained = model_fast.save_pretrained
diff --git a/tests/benchmark/benchmark_bettertransformer.py b/tests/benchmark/benchmark_bettertransformer.py
index 3dfb375b30..38517d66d2 100644
--- a/tests/benchmark/benchmark_bettertransformer.py
+++ b/tests/benchmark/benchmark_bettertransformer.py
@@ -183,12 +183,14 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
     else:
         hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None)
 
-    bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)
-
     output_file = open("log_{}.csv".format(args.model_name.replace("/", "-")), "w")
     output_file.write(
-        "num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup, Mem eager (MB), Mem BT (MB), Mem saved\n"
+        "num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, Latency eager (s), Latency BT (s), Speedup (%), Mem eager (MB), Mem BT (MB), Mem saved (%)\n"
     )
+
+    all_total_hf_time = {}
+    all_max_mem_eager = {}
+
     for bs in tqdm(BATCH_SIZES):
         for seq_len in tqdm(SEQ_LEN):
             for pad_perc in tqdm(PAD_PERCENTAGES):
@@ -219,6 +221,30 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
                         tokenizer.pad_token_id,
                     )
 
+                all_total_hf_time[(bs, seq_len)] = total_hf_time
+                all_max_mem_eager[(bs, seq_len)] = max_mem_eager
+
+    bt_model = BetterTransformer.transform(hf_model)
+    for bs in tqdm(BATCH_SIZES):
+        for seq_len in tqdm(SEQ_LEN):
+            for pad_perc in tqdm(PAD_PERCENTAGES):
+                print(f"-- Running: bs={bs}, seq_len={seq_len}")
+                # current_std = int(seq_len*pad_perc)
+                # max_seqlen = seq_len + current_std
+                max_seqlen = seq_len
+                mean_seqlen = int((1 - pad_perc) * max_seqlen)
+                input_ids, _, masks = get_batch(
+                    bs, mean_seqlen, max_seqlen, args.seqlen_stdev, vocab_size=hf_model.config.vocab_size
+                )
+
+                if args.use_cuda:
+                    input_ids = input_ids.to(device)
+                    masks = masks.to(device)
+
+                if args.use_mask is False and bs == 1:
+                    masks = None
+
+                with torch.inference_mode():
                     # raise error if no optimized kernel is available
                     with torch.backends.cuda.sdp_kernel(
                         enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -233,8 +259,11 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
                             tokenizer.pad_token_id,
                         )
 
-                speedup = total_hf_time / total_bt_time
-                mem_saved = max_mem_eager / max_mem_bt
+                total_hf_time = all_total_hf_time[(bs, seq_len)]
+                max_mem_eager = all_max_mem_eager[(bs, seq_len)]
+
+                speedup = (total_hf_time / total_bt_time - 1) * 100
+                mem_saved = (max_mem_eager / max_mem_bt - 1) * 100
 
                 max_mem_eager = max_mem_eager * 1e-6
                 max_mem_bt = max_mem_bt * 1e-6
@@ -259,4 +288,5 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
                         f"{mem_saved:.3f}",
                     )
                 )
 
+    output_file.close()
diff --git a/tests/benchmark/benchmark_bettertransformer_training_minimal.py b/tests/benchmark/benchmark_bettertransformer_training_minimal.py
index 6f356ec0ea..747b72184a 100644
--- a/tests/benchmark/benchmark_bettertransformer_training_minimal.py
+++ b/tests/benchmark/benchmark_bettertransformer_training_minimal.py
@@ -17,9 +17,9 @@ def get_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--num-epochs",
+        "--num_training_steps",
         type=int,
-        default=5,
+        default=100,
         help="",
     )
 
@@ -51,8 +51,7 @@ def seed_init_fn(x):
     return
 
 
-def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int):
-    num_training_steps = 1024 // batch_size
+def benchmark_training(model, inputs: Dict, num_training_steps: int):
     progress_bar = tqdm(range(num_training_steps))
 
     model.train()
@@ -70,31 +69,31 @@ def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int):
     torch.cuda.synchronize()
     start_event.record()
 
-    for _ in range(num_epochs):
-        for _ in range(num_training_steps):
-            model.zero_grad()
-            outputs = model(**inputs)
-            loss = outputs.logits.sum()
-            loss.backward()
-
-            progress_bar.update(1)
+    for _ in range(num_training_steps):
+        model.zero_grad()
+        outputs = model(**inputs)
+        loss = outputs.logits.sum()
+        loss.backward()
+
+        progress_bar.update(1)
 
     end_event.record()
     torch.cuda.synchronize()
     max_memory = torch.cuda.max_memory_allocated(device)
 
-    return (start_event.elapsed_time(end_event) * 1.0e-3) / num_epochs, max_memory
+    return (start_event.elapsed_time(end_event) * 1.0e-3) / num_training_steps, max_memory
 
 
 if __name__ == "__main__":
     parser = get_parser()
     args = parser.parse_args()
 
-    hf_model = AutoModelForCausalLM.from_pretrained(args.model_name)
-
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    dtype = torch.float32 if args.use_half is False else torch.float16
-    hf_model = hf_model.to(device=device, dtype=dtype)
+    with torch.device(device):
+        hf_model = AutoModelForCausalLM.from_pretrained(
+            args.model_name, torch_dtype=torch.float16 if args.use_half else None
+        )
+    hf_model = hf_model.to(device)
 
     BATCH_SIZES = [8]
     SEQ_LEN = [1024]
@@ -102,13 +101,14 @@ def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int):
     output_file = open("log_{}_train.csv".format(args.model_name.replace("/", "-")), "w")
     output_file.write(
-        "num_epochs, batch_size, seq_len, is cuda, HF time / epoch (s), BT time / epoch (s), Speedup, Eager peak mem (MB), BT peak mem (MB), Mem saving\n"
+        "num_training_steps, batch_size, seq_len, is cuda, Time per batch (eager - s), Time per batch (BT - s), Speedup (%), Eager peak mem (MB), BT peak mem (MB), Mem saving (%)\n"
     )
 
-    num_epochs = args.num_epochs
+    all_hf_time_per_batch = {}
+    all_eager_max_mem = {}
 
     for batch_size in BATCH_SIZES:
         for sequence_length in SEQ_LEN:
-            print(f"Benchmark on: bs={batch_size}, seq_len={sequence_length}")
+            print(f"Benchmark PT on: bs={batch_size}, seq_len={sequence_length}")
 
             vocab_size = hf_model.config.vocab_size
             inputs = {
@@ -118,38 +118,55 @@ def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int):
                 "attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device),
             }
 
-            hf_time_per_epoch, eager_max_mem = benchmark_training(
-                hf_model, inputs=inputs, num_epochs=num_epochs, batch_size=batch_size
+            hf_time_per_batch, eager_max_mem = benchmark_training(
+                hf_model, inputs=inputs, num_training_steps=args.num_training_steps
             )
 
-            bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)
+            all_hf_time_per_batch[(batch_size, sequence_length)] = hf_time_per_batch
+            all_eager_max_mem[(batch_size, sequence_length)] = eager_max_mem
+
+    bt_model = BetterTransformer.transform(hf_model)
+    for batch_size in BATCH_SIZES:
+        for sequence_length in SEQ_LEN:
+            print(f"Benchmark BT on: bs={batch_size}, seq_len={sequence_length}")
+
+            vocab_size = hf_model.config.vocab_size
+            inputs = {
+                "input_ids": torch.randint(vocab_size - 1, (batch_size, sequence_length), dtype=torch.int64).to(
+                    device
+                ),
+                "attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device),
+            }
 
             # raise error if no optimized kernel is available
             with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
-                bt_time_per_epoch, bt_max_mem = benchmark_training(
-                    bt_model, inputs=inputs, num_epochs=num_epochs, batch_size=batch_size
+                bt_time_per_batch, bt_max_mem = benchmark_training(
+                    bt_model, inputs=inputs, num_training_steps=args.num_training_steps
                 )
 
-            eager_max_mem = eager_max_mem * 1e-6
+            eager_max_mem = all_eager_max_mem[(batch_size, sequence_length)] * 1e-6
             bt_max_mem = bt_max_mem * 1e-6
 
-            print(f"PT eager: {hf_time_per_epoch:.3f} s, peak {eager_max_mem:.2f} MB")
-            print(f"PT native: {bt_time_per_epoch:.3f} s, peak {bt_max_mem:.2f} MB")
-            speedup = hf_time_per_epoch / bt_time_per_epoch
-            mem_saved = eager_max_mem / bt_max_mem
+            hf_time_per_batch = all_hf_time_per_batch[(batch_size, sequence_length)]
+
+            print(f"PT eager: {hf_time_per_batch:.3f} s, peak {eager_max_mem:.2f} MB")
+            print(f"PT native: {bt_time_per_batch:.3f} s, peak {bt_max_mem:.2f} MB")
+            speedup = (hf_time_per_batch / bt_time_per_batch - 1) * 100
+            mem_saved = (eager_max_mem / bt_max_mem - 1) * 100
 
             output_file.write(
                 "{},{},{},{},{},{},{},{},{},{}\n".format(
-                    num_epochs,
+                    args.num_training_steps,
                     batch_size,
                     sequence_length,
                     args.use_cuda,
-                    f"{hf_time_per_epoch:.3f}",
-                    f"{bt_time_per_epoch:.3f}",
+                    f"{hf_time_per_batch:.3f}",
+                    f"{bt_time_per_batch:.3f}",
                     f"{speedup:.3f}",
                     f"{eager_max_mem:.3f}",
                     f"{bt_max_mem:.3f}",
                     f"{mem_saved:.3f}",
                 )
             )
 
+    output_file.close()
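Note on the attention.py hunks above: they all apply the same dispatch rule, taking the mask-free fused path not only for batch_size == 1 but also whenever the module is in training mode, so that torch.nn.functional.scaled_dot_product_attention can be used for the backward pass. The following standalone sketch is an illustration only, not part of the patch: the helper name, tensor shapes, and the fixed dropout_p=0.0 are invented, and it assumes PyTorch >= 2.0.

    # sdpa_dispatch_sketch.py -- illustration only; names and shapes are invented.
    import torch


    def dispatch_sdpa(query, key, value, attention_mask=None, training=False):
        # Mirrors the rule in the patched wrappers: training mode (or batch size 1)
        # uses SDPA's built-in causal masking with no custom mask, which is why
        # padded batches are not supported on this path.
        batch_size, _, query_length, _ = query.shape
        if batch_size == 1 or training:
            return torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=query_length > 1
            )
        # General path: keep the explicit (e.g. padding) attention mask.
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )


    if __name__ == "__main__":
        q = k = v = torch.rand(2, 4, 8, 16)  # (batch, num_heads, seq_len, head_dim)
        print(dispatch_sdpa(q, k, v, training=True).shape)  # torch.Size([2, 4, 8, 16])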