Fix slow tests & sdpa dropout #974

Merged (2 commits) on Apr 17, 2023
45 changes: 28 additions & 17 deletions optimum/bettertransformer/models/attention.py
@@ -49,14 +49,15 @@ def gpt2_wrapped_scaled_dot_product(
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

dropout_p = self.dropout_prob_attn if self.training else 0.0
if batch_size == 1 or self.training:
if query.shape[2] > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)
@@ -73,7 +74,7 @@ def gpt2_wrapped_scaled_dot_product(
attention_mask = causal_mask + attention_mask

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

# in gpt-neo-x and gpt-j the query and keys are always in fp32
@@ -103,14 +104,15 @@ def gpt_neo_wrapped_scaled_dot_product(
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

dropout_p = self.dropout_prob_attn if self.training else 0.0
if (batch_size == 1 or self.training) and self.attention_type == "global":
if query.shape[2] > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)
@@ -125,7 +127,7 @@ def gpt_neo_wrapped_scaled_dot_product(
attention_mask = causal_mask + attention_mask

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

return sdpa_result, None
@@ -153,18 +155,19 @@ def codegen_wrapped_scaled_dot_product(
query = query.to(value.dtype)
key = key.to(value.dtype)

dropout_p = self.dropout_prob_attn if self.training else 0.0
if batch_size == 1 or self.training:
if query.shape[2] > 1:
# first step of the decoding
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
# in this case, which is the later decoding steps, the `causal_mask`` in
# https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
# is [True, ..., True] so actually not causal
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)
@@ -183,7 +186,7 @@ def codegen_wrapped_scaled_dot_product(
attention_mask = torch.min(causal_mask, attention_mask)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

return sdpa_result, None
@@ -247,18 +250,20 @@ def opt_forward(
query_states = self._shape(query_states, tgt_len, batch_size)

query_states = query_states * self.scale

dropout_p = self.dropout if self.training else 0.0
if batch_size == 1 or self.training:
if query_states.shape[2] > 1:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True
query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

if attn_output.size() != (batch_size, self.num_heads, tgt_len, self.head_dim):
@@ -361,15 +366,16 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
past_key_value[1] if past_key_value is not None else None,
)

dropout_p = self.dropout if self.training else 0.0
query_states = self.scale * query_states
if position_bias is None and not self.has_relative_attention_bias:
if mask is None:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
elif mask is not None:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=mask, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=mask, dropout_p=dropout_p, is_causal=False
)

if position_bias is None:
@@ -394,11 +400,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):

if self.has_relative_attention_bias:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False
)

attn_output = unshape(attn_output) # (batch_size, seq_length, dim)
@@ -471,7 +477,12 @@ def bart_forward(
value_states = value_states

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=False,
)

if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
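
Every wrapper in this file gets the same treatment: the hard-coded dropout_p=0.0 is replaced by a probability that is non-zero only in training mode, so inference stays deterministic while training regains attention dropout. A minimal sketch of the gating, using a hypothetical helper name (in the real code, dropout_prob_attn is set per architecture in decoder_models.py below):

import torch

# Sketch only: `sdpa_with_optional_dropout` is a hypothetical helper, not optimum API.
def sdpa_with_optional_dropout(query, key, value, dropout_prob_attn, training,
                               attention_mask=None, is_causal=False):
    # Dropout is applied only while training; at inference it must be 0.0
    # so that repeated forward passes give identical outputs.
    dropout_p = dropout_prob_attn if training else 0.0
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=is_causal
    )
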
19 changes: 14 additions & 5 deletions optimum/bettertransformer/models/decoder_models.py
@@ -65,6 +65,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.supports_training = True
self.downcast_qk = False
self.dropout_prob_attn = config.attn_pdrop

def forward(self, *args, **kwargs):
super().forward_checker()
@@ -91,19 +92,19 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
"scale_attn",
"masked_bias",
]
# Attribute only for transformers>=4.28
if hasattr(layer, "embed_positions"):
submodules.append("embed_positions")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

self.module_mapping = None
self.original_layers_mapping = {submodule: submodule for submodule in submodules}

# this attributes does not exist in transformers<=4.27.4
if hasattr(self, "embed_positions"):
self.original_layers_mapping["embed_positions"] = "embed_positions"
setattr(self, "embed_positions", getattr(layer, "embed_positions"))

self.downcast_qk = True
self.supports_training = True
self.dropout_prob_attn = config.attn_pdrop

def forward(self, *args, **kwargs):
super().forward_checker()
@@ -127,6 +128,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.downcast_qk = True
self.supports_training = True
self.dropout_prob_attn = 0.0 # no dropout for gpt-neox

def forward(self, *args, **kwargs):
super().forward_checker()
@@ -156,6 +158,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.scale = torch.sqrt(torch.tensor(layer.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.supports_training = True
self.dropout_prob_attn = float(config.attention_dropout)

def forward(self, *args, **kwargs):
super().forward_checker()
@@ -173,12 +176,18 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.module_mapping = None
submodules = ["attn_dropout", "resid_dropout", "qkv_proj", "out_proj", "causal_mask", "scale_attn"]

# Attribute only for transformers>=4.28
if hasattr(layer, "embed_positions"):
submodules.append("embed_positions")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

self.original_layers_mapping = {submodule: submodule for submodule in submodules}

self.supports_training = True
self.dropout_prob_attn = config.attn_pdrop

def forward(self, *args, **kwargs):
super().forward_checker()
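
The wrapped layers now record the attention-dropout probability of their architecture, read from the model config (attn_pdrop for GPT-2-style configs, attention_dropout for others, and a hard-coded 0.0 where the attention has no dropout). A hedged sketch of looking that value up from a standard transformers config, falling back across the common attribute names:

from transformers import AutoConfig

# Sketch only: the attribute name depends on the architecture.
config = AutoConfig.from_pretrained("gpt2")
dropout_prob_attn = getattr(config, "attn_pdrop", None)
if dropout_prob_attn is None:
    dropout_prob_attn = float(getattr(config, "attention_dropout", 0.0))
print(dropout_prob_attn)  # 0.1 for the default gpt2 config
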
15 changes: 13 additions & 2 deletions optimum/bettertransformer/transformation.py
@@ -28,7 +28,7 @@
logger = logging.getLogger(__name__)

if is_accelerate_available():
from accelerate import dispatch_model
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import remove_hook_from_module

ERROR_MESSAGE = r"The Better Transformers implementation for the model {model_name} has not been implemented yet. Please open an issue requesting the addition of this model with its `BetterTransformer` implementation."
@@ -254,7 +254,18 @@ def transform(
setattr(model_fast, "use_bettertransformer", True)

if load_accelerate:
model_fast = dispatch_model(model_fast, hf_device_map)
all_model_tensors = [name for name, _ in model_fast.state_dict().items()]
for module_name in hf_device_map.keys():
all_model_tensors = [name for name in all_model_tensors if not name.startswith(module_name)]

if len(all_model_tensors) > 0:
# This is the case where a transformed submodule is broken into several devices:
# as the submodules map may differ, we need to reinfer the device map
bt_device_map = infer_auto_device_map(model_fast, max_memory=max_memory)
else:
bt_device_map = hf_device_map

model_fast = dispatch_model(model_fast, bt_device_map)

# It is not recommended to have `keep_original_model=True` with a model
# that is loaded with accelerate but just in case
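
With this change, a model dispatched across devices with accelerate can still be transformed when BetterTransformer's renamed submodules are no longer covered by the original hf_device_map; in that case a fresh map is inferred from max_memory and used for dispatch. A usage sketch along the lines of the new test below (memory values are illustrative):

from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer

max_memory = {0: "2GB", "cpu": "8GB"}  # illustrative split between GPU 0 and CPU
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", max_memory=max_memory)
bt_model = BetterTransformer.transform(model, keep_original_model=True, max_memory=max_memory)
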
40 changes: 39 additions & 1 deletion tests/bettertransformer/test_decoder.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import pytest
@@ -22,7 +23,7 @@

from optimum.bettertransformer import BetterTransformer
from optimum.utils import DummyPastKeyValuesGenerator, NormalizedConfigManager
from optimum.utils.testing_utils import grid_parameters, require_torch_20, require_torch_gpu
from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_20, require_torch_gpu


class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
@@ -193,3 +194,40 @@ def test_invert_model_logits(self, test_name: str, model_type: str, keep_origina
self._test_invert_model_logits(
model_id=model_id, model_type=model_type, keep_original_model=keep_original_model
)

@parameterized.expand(
grid_parameters(
{"keep_original_model": [True], "max_memory": [{0: "300MB", "cpu": "3GB"}, {0: "2GB"}]},
add_test_name=False,
)
)
@require_torch_gpu
@require_accelerate
def test_accelerate_compatibility_cpu_gpu(self, keep_original_model=True, max_memory=None):
hf_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", max_memory=max_memory).eval()
bt_model = BetterTransformer.transform(
hf_model, keep_original_model=keep_original_model, max_memory=max_memory
)

inputs_ids = torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]])
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])

# Check that the model has been dispatched on CPU and GPU
self.assertSetEqual(set(hf_model.hf_device_map.values()), set(max_memory))
self.assertSetEqual(set(bt_model.hf_device_map.values()), set(max_memory))

# Check that the model has weights on GPU and CPU
self.assertEqual(bt_model.transformer.h[0].mlp.c_fc.weight.device, torch.device("cuda:0"))

# Weights that are offloaded on the CPU are offloaded on the `meta` device
if "cpu" in set(max_memory):
self.assertEqual(bt_model.transformer.h[-1].mlp.c_fc.weight.device, torch.device("meta"))

with torch.inference_mode():
output_bt = bt_model(inputs_ids, attention_mask=attention_mask)
output_hf = hf_model(inputs_ids, attention_mask=attention_mask)

self.assertEqual(output_bt[0].device, torch.device("cpu"))
self.assertTrue(torch.allclose(output_bt[0], output_hf[0], atol=1e-3))

gc.collect()
2 changes: 1 addition & 1 deletion tests/exporters/onnx/test_onnx_export.py
@@ -209,8 +209,8 @@ def _onnx_export(
monolith: bool,
device="cpu",
):
model_class = TasksManager.get_model_class_for_task(task)
config = AutoConfig.from_pretrained(model_name)
model_class = TasksManager.get_model_class_for_task(task, model_type=config.model_type.replace("_", "-"))
model = model_class.from_config(config)

# Dynamic axes aren't supported for YOLO-like models. This means they cannot be exported to ONNX on CUDA devices.
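
The export test now loads the config before resolving the model class, so that the model type (with underscores normalized to dashes) can be passed to TasksManager.get_model_class_for_task. A sketch of the reordered lookup, with placeholder task and checkpoint names:

from transformers import AutoConfig
from optimum.exporters.tasks import TasksManager

task = "causal-lm"   # placeholder task string
model_name = "gpt2"  # placeholder checkpoint
config = AutoConfig.from_pretrained(model_name)
model_class = TasksManager.get_model_class_for_task(task, model_type=config.model_type.replace("_", "-"))
model = model_class.from_config(config)
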