Improve bettertransformer benchmark script (#939)
* fix mem

* fix

* fix

* just trying

* fix

* just trying

* just trying

* fix bt?

* fix

* add warning
fxmarty authored Apr 4, 2023
1 parent 0bf2c05 commit 283555f
Showing 4 changed files with 106 additions and 56 deletions.
8 changes: 4 additions & 4 deletions optimum/bettertransformer/models/attention.py
@@ -49,7 +49,7 @@ def gpt2_wrapped_scaled_dot_product(
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

if batch_size == 1:
if batch_size == 1 or self.training:
if query.shape[2] > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
@@ -103,7 +103,7 @@ def gpt_neo_wrapped_scaled_dot_product(
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

if batch_size == 1 and self.attention_type == "global":
if (batch_size == 1 or self.training) and self.attention_type == "global":
if query.shape[2] > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
@@ -153,7 +153,7 @@ def codegen_wrapped_scaled_dot_product(
query = query.to(value.dtype)
key = key.to(value.dtype)

if batch_size == 1:
if batch_size == 1 or self.training:
if query.shape[2] > 1:
# first step of the decoding
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
@@ -247,7 +247,7 @@ def opt_forward(
query_states = self._shape(query_states, tgt_len, batch_size)

query_states = query_states * self.scale
if batch_size == 1:
if batch_size == 1 or self.training:
if query_states.shape[2] > 1:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True
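The four changes above share one pattern: the mask-free `is_causal=True` fast path of `torch.nn.functional.scaled_dot_product_attention`, previously reserved for `batch_size == 1`, is now also taken whenever the module is in training mode (hence the training/padding warning added in `transformation.py` below). A minimal sketch of that dispatch, assuming PyTorch 2.0; the `sdpa_causal_or_masked` helper is hypothetical and only stands in for the wrapped attention functions:

```python
import torch


def sdpa_causal_or_masked(query, key, value, attention_mask, training: bool):
    # Shapes are (batch, heads, seq_len, head_dim), as in the wrapped modules.
    batch_size = query.shape[0]
    if batch_size == 1 or training:
        if query.shape[2] > 1:
            # Prefill / full training sequence: let the fused kernel apply the
            # causal mask itself instead of passing an explicit attention mask.
            return torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
            )
        # Single-token decoding step: causal masking is a no-op here.
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
    # Batched inference with padding: fall back to an explicit additive mask.
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0
    )
```

Because the `is_causal=True` branch ignores the attention mask entirely, padded tokens in a training batch would be attended to as if they were real tokens, which is what the new warning cautions against.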
31 changes: 17 additions & 14 deletions optimum/bettertransformer/transformation.py
@@ -28,7 +28,7 @@
logger = logging.getLogger(__name__)

if is_accelerate_available():
from accelerate import dispatch_model, infer_auto_device_map
from accelerate import dispatch_model
from accelerate.hooks import remove_hook_from_module

ERROR_MESSAGE = r"The Better Transformers implementation for the model {model_name} has not been implemented yet. Please open an issue requesting the addition of this model with its `BetterTransformer` implementation."
@@ -192,6 +192,7 @@ def transform(
# Check if we have to load the model using `accelerate`
if hasattr(model, "hf_device_map"):
load_accelerate = True
hf_device_map = model.hf_device_map
else:
load_accelerate = False

@@ -224,8 +225,7 @@ def transform(
hf_config = model.config

if load_accelerate:
# remove the hooks from the original model to
# avoid weights being on `meta` device.
# Remove the hooks from the original model to avoid weights being on `meta` device.
remove_hook_from_module(model, recurse=True)

if keep_original_model:
@@ -249,24 +249,27 @@
if BetterTransformerManager.requires_nested_tensor(model_fast.config.model_type):
set_last_layer(model_fast)

# Step 6: Add a class arguments, we might need to identify whether the model
# Add a class arguments, we might need to identify whether the model
# has been correctly converted to its `BetterTransformer` version.
setattr(model_fast, "use_bettertransformer", True)

# Step 7: dispatch model if `accelerate` is enabled
if load_accelerate:
device_map_bt = infer_auto_device_map(model_fast, max_memory=max_memory)

remove_hook_from_module(model_fast, recurse=True)

model_fast = dispatch_model(model_fast, device_map_bt)
model_fast = dispatch_model(model_fast, hf_device_map)

# It is not recommended to have `keep_original_model=True` with a model
# that is loaded with accelerate but just in case
if keep_original_model:
# It is not recommended to have `keep_original_model=True` with a model
# that is loaded with accelerate but just in case ..
model = dispatch_model(model, model.hf_device_map)
model = dispatch_model(model, hf_device_map)

# See: https://github.com/pytorch/pytorch/issues/96099
if BetterTransformerManager.requires_torch_20(model_fast.config.model_type):
logging.warning(
f"For training, the BetterTransformer implementation for {model_fast.config.model_type} "
" architecture currently does not support padding as fused kernels do not support custom"
" attention masks. Beware that passing padded batched training data may result in unexpected outputs."
)

# Step 8: overwrite the `save_pretrained` method
# Overwrite the `save_pretrained` method
# by raising an error if the user tries to save the model
# or push it to the hub.
model_fast._old_save_pretrained = model_fast.save_pretrained
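The functional change in `transformation.py` is that, for models loaded with an `accelerate` device map, the `hf_device_map` captured before conversion is reused to re-dispatch the converted model instead of recomputing one with `infer_auto_device_map`; the conversion also now warns that training with padded batches is unsupported, since the fused kernels cannot take custom attention masks. A simplified sketch of the re-dispatch flow, assuming `accelerate` is installed; `redispatch_after_conversion` is a hypothetical helper for illustration, not the library's API:

```python
from accelerate import dispatch_model
from accelerate.hooks import remove_hook_from_module


def redispatch_after_conversion(model, model_fast):
    # Reuse the device map computed when the original model was loaded,
    # rather than inferring a new one for the converted model.
    hf_device_map = model.hf_device_map

    # Drop the old accelerate hooks so no weights are left on the `meta` device.
    remove_hook_from_module(model_fast, recurse=True)

    # Re-dispatch the converted model with the original device map.
    return dispatch_model(model_fast, hf_device_map)
```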
40 changes: 35 additions & 5 deletions tests/benchmark/benchmark_bettertransformer.py
@@ -183,12 +183,14 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
else:
hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None)

bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

output_file = open("log_{}.csv".format(args.model_name.replace("/", "-")), "w")
output_file.write(
"num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, HF time, BT time, Speedup, Mem eager (MB), Mem BT (MB), Mem saved\n"
"num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, Latency eager (s), Latency BT (s), Speedup (%), Mem eager (MB), Mem BT (MB), Mem saved (%)\n"
)

all_total_hf_time = {}
all_max_mem_eager = {}

for bs in tqdm(BATCH_SIZES):
for seq_len in tqdm(SEQ_LEN):
for pad_perc in tqdm(PAD_PERCENTAGES):
@@ -219,6 +221,30 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
tokenizer.pad_token_id,
)

all_total_hf_time[(bs, seq_len)] = total_hf_time
all_max_mem_eager[(bs, seq_len)] = max_mem_eager

bt_model = BetterTransformer.transform(hf_model)
for bs in tqdm(BATCH_SIZES):
for seq_len in tqdm(SEQ_LEN):
for pad_perc in tqdm(PAD_PERCENTAGES):
print(f"-- Running: bs={bs}, seq_len={seq_len}")
# current_std = int(seq_len*pad_perc)
# max_seqlen = seq_len + current_std
max_seqlen = seq_len
mean_seqlen = int((1 - pad_perc) * max_seqlen)
input_ids, _, masks = get_batch(
bs, mean_seqlen, max_seqlen, args.seqlen_stdev, vocab_size=hf_model.config.vocab_size
)

if args.use_cuda:
input_ids = input_ids.to(device)
masks = masks.to(device)

if args.use_mask is False and bs == 1:
masks = None

with torch.inference_mode():
# raise error if no optimized kernel is available
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -233,8 +259,11 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
tokenizer.pad_token_id,
)

speedup = total_hf_time / total_bt_time
mem_saved = max_mem_eager / max_mem_bt
total_hf_time = all_total_hf_time[(bs, seq_len)]
max_mem_eager = all_max_mem_eager[(bs, seq_len)]

speedup = (total_hf_time / total_bt_time - 1) * 100
mem_saved = (max_mem_eager / max_mem_bt - 1) * 100

max_mem_eager = max_mem_eager * 1e-6
max_mem_bt = max_mem_bt * 1e-6
@@ -259,4 +288,5 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
f"{mem_saved:.3f}",
)
)

output_file.close()
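The inference benchmark is reorganized into two passes: every configuration is first measured with the eager Transformers model and the latency/peak-memory results are cached by `(batch_size, seq_len)`, then the model is converted in place with `BetterTransformer.transform` (no longer with `keep_original_model=True`, so the eager and BT copies never have to coexist in memory) and the BT pass reuses the cached eager numbers. The BT pass still runs under `torch.backends.cuda.sdp_kernel(...)` so it fails if no fused kernel is available, and speedup and memory savings are now reported as percentages. A self-contained sketch of that layout with fabricated numbers; `measure` is a made-up stand-in for the script's `benchmark()` helper:

```python
from typing import Dict, Tuple

BATCH_SIZES = [8, 16]
SEQ_LENS = [64, 128]


def measure(variant: str, bs: int, seq_len: int) -> Tuple[float, float]:
    # Returns (latency_s, peak_mem_mb); the numbers here are fabricated for illustration.
    base_time, base_mem = 0.001 * bs * seq_len, 10.0 * bs
    if variant == "bt":
        return base_time / 1.3, base_mem / 1.2
    return base_time, base_mem


# Pass 1: eager model only, results cached per configuration.
eager_results: Dict[Tuple[int, int], Tuple[float, float]] = {}
for bs in BATCH_SIZES:
    for seq_len in SEQ_LENS:
        eager_results[(bs, seq_len)] = measure("eager", bs, seq_len)

# The real script converts the model in place at this point: BetterTransformer.transform(hf_model)

# Pass 2: converted model, reusing the cached eager measurements.
for bs in BATCH_SIZES:
    for seq_len in SEQ_LENS:
        bt_time, bt_mem = measure("bt", bs, seq_len)
        hf_time, hf_mem = eager_results[(bs, seq_len)]
        # Metrics are percentages: 30.0 means "30% faster" / "30% less memory" than eager.
        speedup = (hf_time / bt_time - 1) * 100
        mem_saved = (hf_mem / bt_mem - 1) * 100
        print(f"bs={bs}, seq_len={seq_len}: speedup {speedup:.1f}%, mem saved {mem_saved:.1f}%")
```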
83 changes: 50 additions & 33 deletions tests/benchmark/benchmark_bettertransformer_training_minimal.py
@@ -17,9 +17,9 @@ def get_parser():
parser = argparse.ArgumentParser()

parser.add_argument(
"--num-epochs",
"--num_training_steps",
type=int,
default=5,
default=100,
help="",
)

@@ -51,8 +51,7 @@ def seed_init_fn(x):
return


def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int):
num_training_steps = 1024 // batch_size
def benchmark_training(model, inputs: Dict, num_training_steps: int):
progress_bar = tqdm(range(num_training_steps))

model.train()
@@ -70,45 +69,46 @@ def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int):

torch.cuda.synchronize()
start_event.record()
for _ in range(num_epochs):
for _ in range(num_training_steps):
model.zero_grad()
outputs = model(**inputs)
loss = outputs.logits.sum()
loss.backward()

progress_bar.update(1)
for _ in range(num_training_steps):
model.zero_grad()
outputs = model(**inputs)
loss = outputs.logits.sum()
loss.backward()

progress_bar.update(1)
end_event.record()
torch.cuda.synchronize()

max_memory = torch.cuda.max_memory_allocated(device)

return (start_event.elapsed_time(end_event) * 1.0e-3) / num_epochs, max_memory
return (start_event.elapsed_time(end_event) * 1.0e-3) / num_training_steps, max_memory


if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()

hf_model = AutoModelForCausalLM.from_pretrained(args.model_name)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float32 if args.use_half is False else torch.float16
hf_model = hf_model.to(device=device, dtype=dtype)
with torch.device(device):
hf_model = AutoModelForCausalLM.from_pretrained(
args.model_name, torch_dtype=torch.float16 if args.use_half else None
)
hf_model = hf_model.to(device)

BATCH_SIZES = [8]
SEQ_LEN = [1024]
device = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")

output_file = open("log_{}_train.csv".format(args.model_name.replace("/", "-")), "w")
output_file.write(
"num_epochs, batch_size, seq_len, is cuda, HF time / epoch (s), BT time / epoch (s), Speedup, Eager peak mem (MB), BT peak mem (MB), Mem saving\n"
"num_training_steps, batch_size, seq_len, is cuda, Time per batch (eager - s), Time per batch (BT - s), Speedup (%), Eager peak mem (MB), BT peak mem (MB), Mem saving (%)\n"
)
num_epochs = args.num_epochs
all_hf_time_per_batch = {}
all_eager_max_mem = {}

for batch_size in BATCH_SIZES:
for sequence_length in SEQ_LEN:
print(f"Benchmark on: bs={batch_size}, seq_len={sequence_length}")
print(f"Benchmark PT on: bs={batch_size}, seq_len={sequence_length}")

vocab_size = hf_model.config.vocab_size
inputs = {
@@ -118,38 +118,55 @@ def benchmark_training(model, inputs: Dict, num_epochs: int, batch_size: int):
"attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device),
}

hf_time_per_epoch, eager_max_mem = benchmark_training(
hf_model, inputs=inputs, num_epochs=num_epochs, batch_size=batch_size
hf_time_per_batch, eager_max_mem = benchmark_training(
hf_model, inputs=inputs, num_training_steps=args.num_training_steps
)

bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)
all_hf_time_per_batch[(batch_size, sequence_length)] = hf_time_per_batch
all_eager_max_mem[(batch_size, sequence_length)] = eager_max_mem

bt_model = BetterTransformer.transform(hf_model)
for batch_size in BATCH_SIZES:
for sequence_length in SEQ_LEN:
print(f"Benchmark BT on: bs={batch_size}, seq_len={sequence_length}")

vocab_size = hf_model.config.vocab_size
inputs = {
"input_ids": torch.randint(vocab_size - 1, (batch_size, sequence_length), dtype=torch.int64).to(
device
),
"attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device),
}

# raise error if no optimized kernel is available
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
bt_time_per_epoch, bt_max_mem = benchmark_training(
bt_model, inputs=inputs, num_epochs=num_epochs, batch_size=batch_size
bt_time_per_batch, bt_max_mem = benchmark_training(
bt_model, inputs=inputs, num_training_steps=args.num_training_steps
)

eager_max_mem = eager_max_mem * 1e-6
eager_max_mem = all_eager_max_mem[(batch_size, sequence_length)] * 1e-6
bt_max_mem = bt_max_mem * 1e-6

print(f"PT eager: {hf_time_per_epoch:.3f} s, peak {eager_max_mem:.2f} MB")
print(f"PT native: {bt_time_per_epoch:.3f} s, peak {bt_max_mem:.2f} MB")
speedup = hf_time_per_epoch / bt_time_per_epoch
mem_saved = eager_max_mem / bt_max_mem
hf_time_per_batch = all_hf_time_per_batch[(batch_size, sequence_length)]

print(f"PT eager: {hf_time_per_batch:.3f} s, peak {eager_max_mem:.2f} MB")
print(f"PT native: {bt_time_per_batch:.3f} s, peak {bt_max_mem:.2f} MB")
speedup = (hf_time_per_batch / bt_time_per_batch - 1) * 100
mem_saved = (eager_max_mem / bt_max_mem - 1) * 100

output_file.write(
"{},{},{},{},{},{},{},{},{},{}\n".format(
num_epochs,
args.num_training_steps,
batch_size,
sequence_length,
args.use_cuda,
f"{hf_time_per_epoch:.3f}",
f"{bt_time_per_epoch:.3f}",
f"{hf_time_per_batch:.3f}",
f"{bt_time_per_batch:.3f}",
f"{speedup:.3f}",
f"{eager_max_mem:.3f}",
f"{bt_max_mem:.3f}",
f"{mem_saved:.3f}",
)
)

output_file.close()
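The training benchmark follows the same layout as the inference script (eager pass first, convert in place, BT pass second, cached eager results) and now times a fixed `--num_training_steps` of identical steps (forward, summed-logits loss, backward) between two CUDA events, reporting average time per batch and peak allocated memory, with speedup and memory savings again expressed as percentages. A minimal, self-contained sketch of that timing scheme, assuming a CUDA device; the toy `nn.Linear` stands in for the real language model:

```python
import torch

device = torch.device("cuda")  # the sketch, like the script, assumes CUDA is available
model = torch.nn.Linear(256, 256).to(device)
inputs = torch.randn(8, 256, device=device)
num_training_steps = 100

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

torch.cuda.reset_peak_memory_stats(device)
torch.cuda.synchronize()
start_event.record()
for _ in range(num_training_steps):
    model.zero_grad()
    loss = model(inputs).sum()  # summed output as a dummy loss, like the script's logits.sum()
    loss.backward()
end_event.record()
torch.cuda.synchronize()

# elapsed_time() is in milliseconds; convert to seconds and average over the steps.
time_per_step = (start_event.elapsed_time(end_event) * 1e-3) / num_training_steps
max_memory_mb = torch.cuda.max_memory_allocated(device) * 1e-6
print(f"{time_per_step:.5f} s/step, peak {max_memory_mb:.1f} MB")
```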
