[BT] Add fp16 support #859

Merged · 15 commits · Mar 7, 2023
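In rough terms, this change lets a decoder model held in half precision go through `BetterTransformer.transform` and run on GPU. A minimal sketch of the two ways of obtaining the fp16 model that the new test exercises via its `use_to_operator` switch (the checkpoint name is a placeholder, not something this PR prescribes):

```python
import torch
from transformers import AutoModelForCausalLM

from optimum.bettertransformer import BetterTransformer

model_id = "gpt2"  # placeholder; any decoder architecture supported by BetterTransformer

# Path 1: load the weights directly in fp16
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(0)

# Path 2 (what use_to_operator=True exercises): load in the default dtype, then cast
# model = AutoModelForCausalLM.from_pretrained(model_id).to(0, torch.float16)

# Convert the fp16 model; requires a CUDA device for fp16 inference
model = BetterTransformer.transform(model)
```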
51 changes: 33 additions & 18 deletions optimum/bettertransformer/models/decoder_models.py
@@ -38,8 +38,7 @@ def __init__(self, gpt_layer: "nn.Module", config: "PretrainedConfig"):
self.gpt_layer = gpt_layer
self.gpt_layer._attn = self.wrapped_scaled_dot_product

mask_value = torch.finfo(torch.float32).min
self._mask_value = torch.full([], mask_value, dtype=torch.float32)
self.downcast_qk = config.model_type in ["gptj", "gpt_neox"]

def wrapped_scaled_dot_product(
self,
@@ -52,6 +51,9 @@ def wrapped_scaled_dot_product(
raise_on_head_mask(head_mask)
batch_size = query.shape[0]

mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)

if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

@@ -74,12 +76,18 @@ def wrapped_scaled_dot_product(
torch.bool
)

causal_mask = torch.where(causal_mask, 0, self._mask_value)
causal_mask = torch.where(causal_mask, 0, mask_value)

# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
attention_mask = causal_mask + attention_mask

# in GPT-NeoX and GPT-J the query and key are always in fp32,
# thus we need to cast them to the value dtype
if self.downcast_qk:
query = query.to(value.dtype)
key = key.to(value.dtype)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -98,9 +106,6 @@ def __init__(self, gpt_layer: "nn.Module", config: "PretrainedConfig"):
self.gpt_layer = gpt_layer
self.gpt_layer._attn = self.wrapped_scaled_dot_product

mask_value = torch.finfo(torch.float32).min
self._mask_value = torch.full([], mask_value, dtype=torch.float32)

if self.gpt_layer.bias[0][0][-1][0] == 1:
self.attention_type = "global"
else:
@@ -121,6 +126,10 @@ def wrapped_scaled_dot_product(
raise_on_head_mask(head_mask)
query = query * self.scale
batch_size = query.shape[0]

mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)

if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

@@ -136,12 +145,13 @@ def wrapped_scaled_dot_product(
else:
query_length, key_length = query.size(-2), key.size(-2)

causal_mask = self.gpt_layer.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = self.gpt_layer.bias[:, :, key_length - query_length : key_length, :key_length]

causal_mask = torch.where(causal_mask, 0, self._mask_value)
causal_mask = torch.where(causal_mask, 0, mask_value)
if batch_size > 1:
# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

attention_mask = causal_mask + attention_mask

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
@@ -162,9 +172,6 @@ def __init__(self, gpt_layer: "nn.Module", config: "PretrainedConfig"):
self.gpt_layer = gpt_layer
self.gpt_layer._attn = self.wrapped_scaled_dot_product

mask_value = torch.finfo(torch.float32).min
self._mask_value = torch.full([], mask_value, dtype=torch.float32)

def wrapped_scaled_dot_product(
self,
query: torch.Tensor,
@@ -175,6 +182,9 @@ ):
):
raise_on_head_mask(head_mask)
batch_size = query.shape[0]
mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)

if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

@@ -201,14 +211,19 @@ def wrapped_scaled_dot_product(
torch.bool
)

causal_mask = torch.where(causal_mask, 0, self._mask_value)
causal_mask = torch.where(causal_mask, 0, mask_value)

# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)

# in codegen the query and key are always in fp32 regardless of the dtype of the model
# https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
query = query.to(value.dtype)
key = key.to(value.dtype)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -226,9 +241,6 @@ def __init__(self, opt_layer: "nn.Module", config: "PretrainedConfig"):

self.opt_layer = opt_layer

mask_value = torch.finfo(torch.float32).min
self._mask_value = torch.full([], mask_value, dtype=torch.float32)

self.scale = torch.sqrt(torch.tensor(self.opt_layer.head_dim, dtype=torch.float32)).to(
torch.get_default_dtype()
)
@@ -245,6 +257,9 @@ def forward(
super().forward_checker()
raise_on_head_mask(layer_head_mask)

mask_value = torch.finfo(torch.float32).min
self._mask_value = torch.full([], mask_value, dtype=torch.float32)

if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

@@ -329,9 +344,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.layer = layer

mask_value = torch.finfo(torch.float32).min
self._mask_value = torch.full([], mask_value, dtype=torch.float32)

head_dim = self.layer.d_model // self.layer.n_heads # hidden size / num attention heads
self.scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

@@ -349,6 +361,9 @@ def forward(
):
super().forward_checker()
raise_on_head_mask(layer_head_mask)
mask_value = torch.finfo(torch.float32).min
self._mask_value = torch.full([], mask_value, dtype=torch.float32)

if len(self.layer.pruned_heads) > 0:
raise ValueError(
f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.layer.pruned_heads}."
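The change repeated across these attention wrappers is the same: the additive mask fill value is no longer an fp32 constant precomputed in `__init__`, but is built at call time from `value.dtype`, and for the architectures that keep query/key in fp32 (GPT-J, GPT-NeoX, CodeGen) both are downcast before calling SDPA. A minimal standalone sketch of that pattern (the function name is illustrative, not code from this PR):

```python
import torch

def sdpa_with_dtype_aware_mask(query, key, value, causal_bool_mask, attention_mask=None):
    # Build the additive mask fill value in the dtype of the value tensor,
    # so fp16 inputs get the fp16 minimum instead of an fp32 constant.
    mask_value = torch.full([], torch.finfo(value.dtype).min, dtype=value.dtype)
    causal_mask = torch.where(causal_bool_mask, 0, mask_value)

    if attention_mask is not None:
        causal_mask = causal_mask + attention_mask

    # Architectures that keep query/key in fp32 need them cast to the value dtype
    # before scaled_dot_product_attention.
    query = query.to(value.dtype)
    key = key.to(value.dtype)

    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=causal_mask, dropout_p=0.0, is_causal=False
    )
```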
19 changes: 18 additions & 1 deletion tests/bettertransformer/test_decoder.py
@@ -14,6 +14,7 @@
# limitations under the License.
import unittest

import pytest
import torch
from packaging.version import parse
from parameterized import parameterized
@@ -22,7 +23,7 @@

from optimum.bettertransformer import BetterTransformer, BetterTransformerManager
from optimum.utils import DummyPastKeyValuesGenerator, NormalizedConfigManager
from optimum.utils.testing_utils import grid_parameters
from optimum.utils.testing_utils import grid_parameters, require_torch_gpu


class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
@@ -62,6 +63,22 @@ def test_logits_without_cache(self, test_name: str, model_type: str, padding, ba
model_id = MODELS_DICT[model_type]
super()._test_logits(model_id, padding=padding, batch_size=batch_size)

@parameterized.expand(
grid_parameters(
{
"model_type": SUPPORTED_ARCH,
"use_to_operator": [True, False],
}
)
)
@pytest.mark.fp16
@require_torch_gpu
def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: bool):
self._skip_on_torch_version(model_type)

model_id = MODELS_DICT[model_type]
super()._test_fp16_inference(model_id, use_to_operator=use_to_operator)

@parameterized.expand(
grid_parameters(
{
42 changes: 40 additions & 2 deletions tests/bettertransformer/testing_utils.py
@@ -14,10 +14,10 @@
# limitations under the License.

import torch
from transformers import AutoModel
from transformers import AutoModel, AutoModelForCausalLM

from optimum.bettertransformer import BetterTransformer
from optimum.utils.testing_utils import flatten_dict
from optimum.utils.testing_utils import flatten_dict, require_torch_gpu


MODELS_DICT = {
@@ -79,6 +79,44 @@ class BetterTransformersTestMixin:
def prepare_inputs_for_class(self, model_id=None):
raise NotImplementedError

@require_torch_gpu
def _test_fp16_inference(self, model_id: str, use_to_operator=False, **preprocessor_kwargs):
r"""
This tests if the converted model runs fine under fp16.
"""
# The first row of the attention mask needs to be all ones -> check: https://github.com/pytorch/pytorch/blob/19171a21ee8a9cc1a811ac46d3abd975f0b6fc3b/test/test_nn.py#L5283
inputs = self.prepare_inputs_for_class(model_id=model_id, **preprocessor_kwargs).to(0)

torch.manual_seed(0)
if not use_to_operator:
hf_random_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(0)
else:
hf_random_model = AutoModelForCausalLM.from_pretrained(model_id).to(0, torch.float16)

torch.manual_seed(0)
converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=True)

self.assertFalse(
hasattr(hf_random_model, "use_bettertransformer"),
f"The model {hf_random_model.__class__.__name__} has been converted to a `fast` model by mistake.",
)

with torch.no_grad():
r"""
Make sure the models are in eval mode! Also make sure that the original model
has not been converted to a fast model. The check is done above.
"""
torch.manual_seed(0)
output_hf = hf_random_model.generate(**inputs)

torch.manual_seed(0)
output_bt = converted_model.generate(**inputs)

self.assertTrue(
torch.allclose(output_hf, output_bt, atol=1e-4),
f"The logits of the converted model {converted_model.__class__.__name__} are not equal to the logits of the original model {hf_random_model.__class__.__name__}.",
)

def _test_logits(self, model_id: str, **preprocessor_kwargs):
r"""
This tests if the converted model produces the same logits
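For reference, the parity check performed by `_test_fp16_inference` can be reproduced outside the test harness roughly as follows (a sketch under the same assumptions as the test: a CUDA device and a decoder checkpoint supported by BetterTransformer; `model_id` is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.bettertransformer import BetterTransformer

model_id = "gpt2"  # placeholder; any supported decoder architecture
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello there,", return_tensors="pt").to(0)

# Load in fp16 and keep the original model so both can be compared
hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(0)
bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

with torch.no_grad():
    torch.manual_seed(0)
    output_hf = hf_model.generate(**inputs)
    torch.manual_seed(0)
    output_bt = bt_model.generate(**inputs)

# Default (greedy) generation is deterministic, so the token ids should match;
# the test above compares them with torch.allclose and a small tolerance.
assert torch.allclose(output_hf, output_bt, atol=1e-4)
```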