[BT] Add fp16 support #859

Merged: 15 commits, Mar 7, 2023
23 changes: 19 additions & 4 deletions optimum/bettertransformer/models/decoder_models.py
@@ -38,6 +38,8 @@ def __init__(self, gpt_layer: "nn.Module", config: "PretrainedConfig"):
self.gpt_layer = gpt_layer
self.gpt_layer._attn = self.wrapped_scaled_dot_product

self.downcast_qk = config.model_type in ["gptj", "gpt_neox"]

mask_value = torch.finfo(torch.float32).min
self._mask_value = torch.full([], mask_value, dtype=torch.float32)

@@ -74,12 +76,18 @@ def wrapped_scaled_dot_product(
torch.bool
)

- causal_mask = torch.where(causal_mask, 0, self._mask_value)
+ causal_mask = torch.where(causal_mask, 0, self._mask_value.to(value.dtype))
@fxmarty (Contributor) commented on Mar 7, 2023:
This is not equivalent:

import torch

mask_value = torch.finfo(torch.float32).min
mask_value = torch.full([], mask_value, dtype=torch.float32)
casted = mask_value.to(torch.float16)

mask_value = torch.finfo(torch.float16).min
mask_value = torch.full([], mask_value, dtype=torch.float16)
# this assert fails: `casted` is -inf (the fp32 minimum overflows the fp16 range),
# while the fp16 minimum is -65504
assert torch.equal(casted, mask_value)

Not sure whether it has any influence though. I would just put the definition of mask_value directly in the forward, as in transformers.

Reply from the PR author (Contributor):
Thanks, adapted as suggested!
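
For context, a minimal standalone sketch of the suggested approach (illustrative only, with made-up shapes): the mask fill value is derived from the runtime dtype instead of casting a pre-built fp32 scalar, as transformers does.

import torch

# Hypothetical stand-ins for the real attention inputs
value = torch.randn(1, 4, 8, 16, dtype=torch.float16)
causal_mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))

# Build the mask value inside the forward so its dtype always matches `value`
mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype, device=value.device)
causal_mask = torch.where(causal_mask, 0, mask_value)

assert causal_mask.dtype == torch.float16
assert not torch.isinf(causal_mask).any()  # fp16 min (-65504), not -inf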


# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
attention_mask = causal_mask + attention_mask

# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if self.downcast_qk:
query = query.to(value.dtype)
key = key.to(value.dtype)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -121,6 +129,7 @@ def wrapped_scaled_dot_product(
raise_on_head_mask(head_mask)
query = query * self.scale
batch_size = query.shape[0]

if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

@@ -136,12 +145,13 @@
else:
query_length, key_length = query.size(-2), key.size(-2)

- causal_mask = self.gpt_layer.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+ causal_mask = self.gpt_layer.bias[:, :, key_length - query_length : key_length, :key_length]

- causal_mask = torch.where(causal_mask, 0, self._mask_value)
+ causal_mask = torch.where(causal_mask, 0, self._mask_value.to(value.dtype))
if batch_size > 1:
# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

attention_mask = causal_mask + attention_mask

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
@@ -201,14 +211,19 @@ def wrapped_scaled_dot_product(
torch.bool
)

- causal_mask = torch.where(causal_mask, 0, self._mask_value)
+ causal_mask = torch.where(causal_mask, 0, self._mask_value.to(value.dtype))

# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)

# in codegen the query and key are always in fp32 regardless of the dtype of the model
# https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
query = query.to(value.dtype)
key = key.to(value.dtype)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
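On the query/key downcast added above for GPT-J, GPT-NeoX and CodeGen: torch's scaled_dot_product_attention expects query, key and value to share a dtype, so the fp32 query/key produced by those attention implementations have to be cast down when the model runs in fp16. A rough standalone illustration (arbitrary shapes; a CUDA device is assumed only to mirror the fp16 tests):

import torch

# Illustration only, not the PR code: cast fp32 query/key to the value dtype
# before calling SDPA, as the wrapper above does when running in fp16.
if torch.cuda.is_available():
    query = torch.randn(1, 4, 8, 16, dtype=torch.float32, device="cuda")
    key = torch.randn(1, 4, 8, 16, dtype=torch.float32, device="cuda")
    value = torch.randn(1, 4, 8, 16, dtype=torch.float16, device="cuda")

    query, key = query.to(value.dtype), key.to(value.dtype)
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    assert out.dtype == torch.float16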
16 changes: 16 additions & 0 deletions tests/bettertransformer/test_decoder.py
@@ -14,6 +14,7 @@
# limitations under the License.
import unittest

import pytest
younesbelkada marked this conversation as resolved.
import torch
from packaging.version import parse
from parameterized import parameterized
@@ -62,6 +63,21 @@ def test_logits_without_cache(self, test_name: str, model_type: str, padding, ba
model_id = MODELS_DICT[model_type]
super()._test_logits(model_id, padding=padding, batch_size=batch_size)

@parameterized.expand(
grid_parameters(
{
"model_type": SUPPORTED_ARCH,
"use_to_operator": [True, False],
}
)
)
@pytest.mark.fp16
younesbelkada marked this conversation as resolved.
def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: bool):
self._skip_on_torch_version(model_type)

model_id = MODELS_DICT[model_type]
super()._test_fp16_inference(model_id, use_to_operator=use_to_operator)

@parameterized.expand(
grid_parameters(
{
42 changes: 40 additions & 2 deletions tests/bettertransformer/testing_utils.py
@@ -14,10 +14,10 @@
# limitations under the License.

import torch
- from transformers import AutoModel
+ from transformers import AutoModel, AutoModelForCausalLM

from optimum.bettertransformer import BetterTransformer
- from optimum.utils.testing_utils import flatten_dict
+ from optimum.utils.testing_utils import flatten_dict, require_torch_gpu


MODELS_DICT = {
@@ -79,6 +79,44 @@ class BetterTransformersTestMixin:
def prepare_inputs_for_class(self, model_id=None):
raise NotImplementedError

@require_torch_gpu
def _test_fp16_inference(self, model_id: str, use_to_operator=False, **preprocessor_kwargs):
r"""
This tests if the converted model runs fine under fp16.
"""
# The first row of the attention mask needs to be all ones -> check: https://github.com/pytorch/pytorch/blob/19171a21ee8a9cc1a811ac46d3abd975f0b6fc3b/test/test_nn.py#L5283
inputs = self.prepare_inputs_for_class(model_id=model_id, **preprocessor_kwargs).to(0)

torch.manual_seed(0)
if not use_to_operator:
hf_random_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(0)
else:
hf_random_model = AutoModelForCausalLM.from_pretrained(model_id).to(0, torch.float16)

torch.manual_seed(0)
converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=True)

self.assertFalse(
hasattr(hf_random_model, "use_bettertransformer"),
f"The model {hf_random_model.__class__.__name__} has been converted to a `fast` model by mistake.",
)

with torch.no_grad():
r"""
Make sure the models are in eval mode! Also make sure that the original model
has not been converted to a fast model; that check is done above.
"""
torch.manual_seed(0)
output_hf = hf_random_model.generate(**inputs)

torch.manual_seed(0)
output_bt = converted_model.generate(**inputs)

self.assertTrue(
torch.allclose(output_hf, output_bt, atol=1e-4),
f"The logits of the converted model {converted_model.__class__.__name__} are not equal to the logits of the original model {hf_random_model.__class__.__name__}.",
)

def _test_logits(self, model_id: str, **preprocessor_kwargs):
r"""
This tests if the converted model produces the same logits
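One possible way to run just the new fp16 tests locally (this assumes a CUDA GPU and that the fp16 marker is registered in the pytest configuration; otherwise pytest only emits an unknown-marker warning):

pytest tests/bettertransformer/test_decoder.py -m fp16 -v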