Fix slow tests & sdpa dropout (#974)
* fix slow tests & sdpa dropout

* add comment
fxmarty authored Apr 17, 2023
1 parent f7f1ef1 commit 9829418
Showing 5 changed files with 95 additions and 26 deletions.
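
The core of the fix: the attention wrappers previously hard-coded dropout_p=0.0 in every call to torch.nn.functional.scaled_dot_product_attention, so attention dropout was silently disabled when training a converted model. The pattern now applied throughout attention.py is sketched below (a minimal illustration, not the exact Optimum code); dropout_prob_attn is the attribute each decoder wrapper sets in decoder_models.py, while the OPT and T5 paths reuse the layer's existing self.dropout.

import torch

def sdpa_with_dropout(module, query, key, value, attn_mask=None, is_causal=False):
    # Apply attention dropout only while training; in eval mode this reduces to
    # the previous behaviour, dropout_p=0.0.
    dropout_p = module.dropout_prob_attn if module.training else 0.0
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
    )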
45 changes: 28 additions & 17 deletions optimum/bettertransformer/models/attention.py
@@ -49,14 +49,15 @@ def gpt2_wrapped_scaled_dot_product(
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

dropout_p = self.dropout_prob_attn if self.training else 0.0
if batch_size == 1 or self.training:
if query.shape[2] > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)
@@ -73,7 +74,7 @@ def gpt2_wrapped_scaled_dot_product(
attention_mask = causal_mask + attention_mask

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

# in gpt-neo-x and gpt-j the query and keys are always in fp32
@@ -103,14 +104,15 @@ def gpt_neo_wrapped_scaled_dot_product(
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, 0, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

dropout_p = self.dropout_prob_attn if self.training else 0.0
if (batch_size == 1 or self.training) and self.attention_type == "global":
if query.shape[2] > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)
@@ -125,7 +127,7 @@ def gpt_neo_wrapped_scaled_dot_product(
attention_mask = causal_mask + attention_mask

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

return sdpa_result, None
@@ -153,18 +155,19 @@ def codegen_wrapped_scaled_dot_product(
query = query.to(value.dtype)
key = key.to(value.dtype)

dropout_p = self.dropout_prob_attn if self.training else 0.0
if batch_size == 1 or self.training:
if query.shape[2] > 1:
# first step of the decoding
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
# in this case, which corresponds to the later decoding steps, the `causal_mask` in
# https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
# is [True, ..., True] so actually not causal
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)
@@ -183,7 +186,7 @@ def codegen_wrapped_scaled_dot_product(
attention_mask = torch.min(causal_mask, attention_mask)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

return sdpa_result, None
@@ -247,18 +250,20 @@ def opt_forward(
query_states = self._shape(query_states, tgt_len, batch_size)

query_states = query_states * self.scale

dropout_p = self.dropout if self.training else 0.0
if batch_size == 1 or self.training:
if query_states.shape[2] > 1:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True
query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

if attn_output.size() != (batch_size, self.num_heads, tgt_len, self.head_dim):
@@ -361,15 +366,16 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
past_key_value[1] if past_key_value is not None else None,
)

dropout_p = self.dropout if self.training else 0.0
query_states = self.scale * query_states
if position_bias is None and not self.has_relative_attention_bias:
if mask is None:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
elif mask is not None:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=mask, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=mask, dropout_p=dropout_p, is_causal=False
)

if position_bias is None:
@@ -394,11 +400,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):

if self.has_relative_attention_bias:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=0.0, is_causal=False
query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False
)

attn_output = unshape(attn_output) # (batch_size, seq_length, dim)
@@ -471,7 +477,12 @@ def bart_forward(
value_states = value_states

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=False,
)

if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
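
A note on the is_causal branches in the wrappers above: when batch_size == 1 or the module is training, no padding mask is needed, and the causal flag is only set while the query spans more than one position (the prefill / first generation step). During cached decoding the query length is 1, every cached key position may be attended to, and the reference causal_mask is all True, so the call must be non-causal. A condensed sketch of that branch:

import torch

def decode_step_sdpa(query, key, value, dropout_p=0.0):
    # query, key, value: (batch, num_heads, seq_len, head_dim); key and value hold
    # the full KV cache. Prefill (query length > 1) is causal; a single-token
    # decoding step attends to the whole cache, so is_causal must be False.
    is_causal = query.shape[2] > 1
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=is_causal
    )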
19 changes: 14 additions & 5 deletions optimum/bettertransformer/models/decoder_models.py
@@ -65,6 +65,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.supports_training = True
self.downcast_qk = False
self.dropout_prob_attn = config.attn_pdrop

def forward(self, *args, **kwargs):
super().forward_checker()
@@ -91,19 +92,19 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
"scale_attn",
"masked_bias",
]
# Attribute only for transformers>=4.28
if hasattr(layer, "embed_positions"):
submodules.append("embed_positions")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

self.module_mapping = None
self.original_layers_mapping = {submodule: submodule for submodule in submodules}

# this attribute does not exist in transformers<=4.27.4
if hasattr(self, "embed_positions"):
self.original_layers_mapping["embed_positions"] = "embed_positions"
setattr(self, "embed_positions", getattr(layer, "embed_positions"))

self.downcast_qk = True
self.supports_training = True
self.dropout_prob_attn = config.attn_pdrop

def forward(self, *args, **kwargs):
super().forward_checker()
@@ -127,6 +128,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.downcast_qk = True
self.supports_training = True
self.dropout_prob_attn = 0.0 # no dropout for gpt-neox

def forward(self, *args, **kwargs):
super().forward_checker()
@@ -156,6 +158,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.scale = torch.sqrt(torch.tensor(layer.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.supports_training = True
self.dropout_prob_attn = float(config.attention_dropout)

def forward(self, *args, **kwargs):
super().forward_checker()
@@ -173,12 +176,18 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.module_mapping = None
submodules = ["attn_dropout", "resid_dropout", "qkv_proj", "out_proj", "causal_mask", "scale_attn"]

# Attribute only for transformers>=4.28
if hasattr(layer, "embed_positions"):
submodules.append("embed_positions")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

self.original_layers_mapping = {submodule: submodule for submodule in submodules}

self.supports_training = True
self.dropout_prob_attn = config.attn_pdrop

def forward(self, *args, **kwargs):
super().forward_checker()
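
Each wrapper records its architecture's attention-dropout probability at construction time: config.attn_pdrop for GPT-2, GPT-J and CodeGen, config.attention_dropout for GPT-Neo, and 0.0 for GPT-NeoX; OPT and T5 keep using the layer's self.dropout in attention.py instead. A hypothetical generic resolver, for illustration only (the real wrappers hard-code the right attribute per architecture):

def resolve_attention_dropout(config) -> float:
    # Illustrative helper, not part of Optimum: probe the config for the usual
    # attention-dropout attribute names and fall back to 0.0 (e.g. GPT-NeoX).
    for attr in ("attn_pdrop", "attention_dropout"):
        value = getattr(config, attr, None)
        if value is not None:
            return float(value)
    return 0.0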
15 changes: 13 additions & 2 deletions optimum/bettertransformer/transformation.py
@@ -28,7 +28,7 @@
logger = logging.getLogger(__name__)

if is_accelerate_available():
from accelerate import dispatch_model
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import remove_hook_from_module

ERROR_MESSAGE = r"The Better Transformers implementation for the model {model_name} has not been implemented yet. Please open an issue requesting the addition of this model with its `BetterTransformer` implementation."
@@ -254,7 +254,18 @@ def transform(
setattr(model_fast, "use_bettertransformer", True)

if load_accelerate:
model_fast = dispatch_model(model_fast, hf_device_map)
all_model_tensors = [name for name, _ in model_fast.state_dict().items()]
for module_name in hf_device_map.keys():
all_model_tensors = [name for name in all_model_tensors if not name.startswith(module_name)]

if len(all_model_tensors) > 0:
# This is the case where a transformed submodule is broken into several devices:
# as the submodules map may differ, we need to reinfer the device map
bt_device_map = infer_auto_device_map(model_fast, max_memory=max_memory)
else:
bt_device_map = hf_device_map

model_fast = dispatch_model(model_fast, bt_device_map)

# It is not recommended to have `keep_original_model=True` with a model
# that is loaded with accelerate but just in case
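
Why the device map is re-inferred: BetterTransformer swaps whole attention submodules, so after transformation some parameter names may no longer start with any key of the original hf_device_map. The coverage check above detects that case and falls back to accelerate's infer_auto_device_map (with the same max_memory constraints) before dispatching. A toy illustration of the prefix check, with made-up names:

# Hypothetical device map and tensor names, only to illustrate the check above.
hf_device_map = {"transformer.h.0": 0, "transformer.h.1": "cpu"}
tensor_names = ["transformer.h.0.attn.qkv_proj.weight", "transformer.wte.weight"]

uncovered = [name for name in tensor_names if not any(name.startswith(key) for key in hf_device_map)]
print(uncovered)  # ['transformer.wte.weight'] -> re-infer the device map before dispatch_model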
40 changes: 39 additions & 1 deletion tests/bettertransformer/test_decoder.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import pytest
@@ -22,7 +23,7 @@

from optimum.bettertransformer import BetterTransformer
from optimum.utils import DummyPastKeyValuesGenerator, NormalizedConfigManager
from optimum.utils.testing_utils import grid_parameters, require_torch_20, require_torch_gpu
from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_20, require_torch_gpu


class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
@@ -193,3 +194,40 @@ def test_invert_model_logits(self, test_name: str, model_type: str, keep_origina
self._test_invert_model_logits(
model_id=model_id, model_type=model_type, keep_original_model=keep_original_model
)

@parameterized.expand(
grid_parameters(
{"keep_original_model": [True], "max_memory": [{0: "300MB", "cpu": "3GB"}, {0: "2GB"}]},
add_test_name=False,
)
)
@require_torch_gpu
@require_accelerate
def test_accelerate_compatibility_cpu_gpu(self, keep_original_model=True, max_memory=None):
hf_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", max_memory=max_memory).eval()
bt_model = BetterTransformer.transform(
hf_model, keep_original_model=keep_original_model, max_memory=max_memory
)

inputs_ids = torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]])
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])

# Check that the model has been dispatched on CPU and GPU
self.assertSetEqual(set(hf_model.hf_device_map.values()), set(max_memory))
self.assertSetEqual(set(bt_model.hf_device_map.values()), set(max_memory))

# Check that the model has weights on GPU and CPU
self.assertEqual(bt_model.transformer.h[0].mlp.c_fc.weight.device, torch.device("cuda:0"))

# Weights that are offloaded on the CPU are offloaded on the `meta` device
if "cpu" in set(max_memory):
self.assertEqual(bt_model.transformer.h[-1].mlp.c_fc.weight.device, torch.device("meta"))

with torch.inference_mode():
output_bt = bt_model(inputs_ids, attention_mask=attention_mask)
output_hf = hf_model(inputs_ids, attention_mask=attention_mask)

self.assertEqual(output_bt[0].device, torch.device("cpu"))
self.assertTrue(torch.allclose(output_bt[0], output_hf[0], atol=1e-3))

gc.collect()
2 changes: 1 addition & 1 deletion tests/exporters/onnx/test_onnx_export.py
@@ -209,8 +209,8 @@ def _onnx_export(
monolith: bool,
device="cpu",
):
model_class = TasksManager.get_model_class_for_task(task)
config = AutoConfig.from_pretrained(model_name)
model_class = TasksManager.get_model_class_for_task(task, model_type=config.model_type.replace("_", "-"))
model = model_class.from_config(config)

# Dynamic axes aren't supported for YOLO-like models. This means they cannot be exported to ONNX on CUDA devices.
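
The exporter-test fix loads the config first so that the model class can be resolved per model type, letting TasksManager pick a model-type-specific class when one is registered for the task. A usage sketch (the checkpoint and task name are illustrative, not taken from the test):

from transformers import AutoConfig
from optimum.exporters import TasksManager

config = AutoConfig.from_pretrained("gpt2")
# Model types are registered with dashes, hence the replace("_", "-").
model_class = TasksManager.get_model_class_for_task(
    "causal-lm", model_type=config.model_type.replace("_", "-")
)
model = model_class.from_config(config)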
