Fix non-nesting bug in BetterTransformer integration (#637)
* small patch

* fix all models

* add test

* fix import

* import bis

* more tests

* fix bug

* informative log

* fix test

* add reason

* remove print

* better test

* fix log

Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com>
younesbelkada and fxmarty authored Dec 23, 2022
1 parent 5f00fee commit 4600452
Showing 4 changed files with 118 additions and 28 deletions.
30 changes: 6 additions & 24 deletions optimum/bettertransformer/models/encoder_models.py
@@ -99,10 +99,7 @@ def forward(self, hidden_states, attention_mask, *_):
# 0->false->keep this token -inf->true->mask this token
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
-seqlen = attention_mask.shape[1]
-lengths = torch.sum(~attention_mask, 1)
-if not all([l == seqlen for l in lengths]):
-    hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
+hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
@@ -207,10 +204,7 @@ def forward(self, hidden_states, attention_mask, *_):
# 0->false->keep this token -inf->true->mask this token
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
-seqlen = attention_mask.shape[1]
-lengths = torch.sum(~attention_mask, 1)
-if not all([l == seqlen for l in lengths]):
-    hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
+hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
@@ -323,10 +317,7 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
attention_mask = attention_mask.squeeze(1)[:, 0]
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
-seqlen = attention_mask.shape[1]
-lengths = torch.sum(~attention_mask, 1)
-if not all([l == seqlen for l in lengths]):
-    hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
+hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
@@ -441,10 +432,7 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
attention_mask = attention_mask.squeeze(1)[:, 0]
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
-seqlen = attention_mask.shape[1]
-lengths = torch.sum(~attention_mask, 1)
-if not all([l == seqlen for l in lengths]):
-    hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
+hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
@@ -948,10 +936,7 @@ def forward(self, hidden_states, attention_mask, **__):
if len(attention_mask.shape) == 4:
attention_mask = attention_mask.squeeze(1)[:, 0]
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
-seqlen = attention_mask.shape[1]
-lengths = torch.sum(~attention_mask, 1)
-if not all([l == seqlen for l in lengths]):
-    hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
+hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
@@ -1061,17 +1046,14 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
# 0->false->keep this token -inf->true->mask this token
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
-seqlen = attention_mask.shape[1]
-lengths = torch.sum(~attention_mask, 1)

# FSMT swaps the first two axis before calling the encoder stack
# Reference: https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/fsmt/modeling_fsmt.py#L508
if hidden_states.shape[0] != attention_mask.shape[0]:
hidden_states = hidden_states.transpose(1, 0)
original_shape = hidden_states.shape

-if not all([l == seqlen for l in lengths]):
-    hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
+hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
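The pattern removed in each hunk above only nested the hidden states when the batch actually contained padding; after the change the padded hidden states are always converted to a nested tensor and the dense attention mask is dropped before the fused encoder-layer kernel runs. As a rough, self-contained sketch of what that conversion does, using the public torch.nested API with made-up shapes rather than the private torch._nested_tensor_from_mask helper the library calls:

import torch

batch_size, seq_len, hidden_dim = 2, 4, 8
hidden_states = torch.randn(batch_size, seq_len, hidden_dim)
# Following the convention in the diff: True marks a padding token to drop.
attention_mask = torch.tensor([[False, False, False, True],
                               [False, False, True, True]])

# Number of real (kept) tokens per sequence.
lengths = (~attention_mask).sum(dim=1)
# Build a nested tensor holding only the real tokens of each sequence,
# regardless of whether the batch happens to contain padding.
nested = torch.nested.nested_tensor(
    [h[: int(l)] for h, l in zip(hidden_states, lengths)]
)
print([t.shape for t in nested.unbind()])  # [torch.Size([3, 8]), torch.Size([2, 8])]

Each entry of the nested tensor keeps only the real tokens of its sequence, which is what lets the fast path skip computation on padding.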
4 changes: 2 additions & 2 deletions tests/benchmark/benchmark_bettertransformer.py
@@ -118,9 +118,9 @@ def benchmark(model_name, num_batches, batch_size, avg_seqlen, max_seqlen, seqle
masks = None

# Warmup
-_ = hf_model(input_ids[0].unsqueeze(0), masks[0].unsqueeze(0))
+_ = hf_model(input_ids, masks)
torch.cuda.synchronize()
-_ = bt_model(input_ids[0].unsqueeze(0), masks[0].unsqueeze(0))
+_ = bt_model(input_ids, masks)
torch.cuda.synchronize()

total_hf_time = timing_cuda(hf_model, num_batches, input_ids, masks)
4 changes: 2 additions & 2 deletions tests/benchmark/benchmark_bettertransformer_vit.py
@@ -71,9 +71,9 @@ def benchmark(model_name, num_batches, batch_size, is_cuda, is_half):
input_features = input_features.to(0)

# Warmup
-_ = hf_model(input_features[0].unsqueeze(0))
+_ = hf_model(input_features)
torch.cuda.synchronize()
-_ = bt_model(input_features[0].unsqueeze(0))
+_ = bt_model(input_features)
torch.cuda.synchronize()

total_hf_time = timing_cuda(hf_model, num_batches, input_features)
108 changes: 108 additions & 0 deletions tests/bettertransformer/test_gpu.py
@@ -0,0 +1,108 @@
import os
import unittest

import torch
from transformers import AutoModel

from optimum.bettertransformer import BetterTransformer
from optimum.utils import logging
from optimum.utils.testing_utils import grid_parameters
from parameterized import parameterized


logger = logging.get_logger()
logging.set_verbosity_info()


def timing_cuda(model, num_batches, input_ids, masks):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(num_batches):
_ = model(input_ids, masks)
end_event.record()
torch.cuda.synchronize()
return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches


def benchmark(model_name: str, num_batches: int, batch_size: int, max_seqlen: int, is_half: bool):
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16 if is_half else None).eval()
hf_model = hf_model.to("cuda:0")
bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

vocab_size = hf_model.config.vocab_size
input_ids = torch.randint(vocab_size - 1, (batch_size, max_seqlen), dtype=torch.int64) + 1
masks = torch.ones(batch_size, max_seqlen, dtype=torch.int64)

input_ids = input_ids.to("cuda:0")
masks = masks.to("cuda:0")

# Warmup
_ = hf_model(input_ids[0].unsqueeze(0), masks[0].unsqueeze(0))
torch.cuda.synchronize()
_ = bt_model(input_ids[0].unsqueeze(0), masks[0].unsqueeze(0))
torch.cuda.synchronize()

total_hf_time = timing_cuda(hf_model, num_batches, input_ids, masks)
total_bt_time = timing_cuda(bt_model, num_batches, input_ids, masks)

return total_bt_time, total_hf_time


class TestSpeedup(unittest.TestCase):
"""
TODO: test missing for:
- WhisperEncoderLayerBetterTransformer
- ViTLayerBetterTransformer
- ViltLayerBetterTransformer
- Wav2Vec2EncoderLayerBetterTransformer
- FSMTEncoderLayerBetterTransformer
- CLIPLayerBetterTransformer
"""

REPRESENTATIVE_MODELS = [
"bert-base-uncased",
# "albert-base-v2", # TODO: AlbertLayerBetterTransformer seem to nest/unnest tensors all the time
"facebook/bart-base",
"facebook/mbart-large-50",
"distilbert-base-uncased",
]

@parameterized.expand(
grid_parameters(
{
"model_name": REPRESENTATIVE_MODELS,
"batch_size": [32, 64],
"sequence_length": [64, 128],
"use_half": [True, False],
}
)
)
@unittest.skipIf(int(os.environ.get("TEST_LEVEL", 0)) < 1, reason="disabled by default")
def test_base_speedup(
self, test_name: str, model_name: str, batch_size: int, sequence_length: int, use_half: bool
):
"""
Test to validate the BetterTransformer base speedup on GPU.
The speedup check is low because we still hit https://github.com/pytorch/pytorch/issues/91305
"""
num_batches = 50

total_bt_time, total_hf_time = benchmark(
model_name,
num_batches,
batch_size,
sequence_length,
use_half,
)

speedup = total_hf_time / total_bt_time

self.assertTrue(speedup > 0.85, msg=f"The BetterTransformer base speedup for {test_name} is {speedup}")

if speedup >= 0.85 and speedup < 1:
logger.warning(f"The BetterTransformer base speedup for {test_name} is {speedup}")
if speedup >= 1:
logger.info(f"The BetterTransformer base speedup for {test_name} is {speedup}")
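
The new test above drives BetterTransformer end to end on GPU; stripped of the timing and parameterization, the core usage it relies on looks roughly like this (the model name, batch shape, and half precision mirror the defaults in the test, and this sketch is not part of the committed file):

import torch
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16)
model = model.eval().to("cuda:0")
# keep_original_model=True leaves the original module intact for comparison.
bt_model = BetterTransformer.transform(model, keep_original_model=True)

input_ids = torch.randint(1, model.config.vocab_size, (32, 64), device="cuda:0")
masks = torch.ones_like(input_ids)
with torch.no_grad():
    _ = bt_model(input_ids, masks)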
