
OPTDecoderLayer does not return attentions when gradient_checkpointing and training are enabled. #23366

gmlwns2000 opened this issue on May 15, 2023 · 0 comments · Fixed by #23367

gmlwns2000 commented May 15, 2023

Bug Description

In modeling_opt.py, lines 704-710, OPTDecoder calls OPTDecoderLayer.forward with the following argument order:

if self.gradient_checkpointing and self.training:
    def create_custom_forward(module):
        def custom_forward(*inputs):
            # None for past_key_value
            return module(*inputs, output_attentions, None)

        return custom_forward

    layer_outputs = torch.utils.checkpoint.checkpoint(
        create_custom_forward(decoder_layer),
        hidden_states,
        causal_attention_mask,
        head_mask[idx] if head_mask is not None else None,
        None,
    )
else:
    layer_outputs = decoder_layer(
        hidden_states,
        attention_mask=causal_attention_mask,
        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )

However, OPTDecoderLayer.forward declares its parameters in a different order than the call above:

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    layer_head_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,  # needs to be reordered
    use_cache: Optional[bool] = False,  # needs to be reordered
    past_key_value: Optional[Tuple[torch.Tensor]] = None,  # needs to be reordered
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

Therefore, output_attentions inside OPTDecoderLayer.forward is always None: the fourth positional argument in the checkpointed call is the None placeholder intended for past_key_value, but it binds to output_attentions instead.
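
To make the mismatch concrete, here is a minimal toy sketch (standalone Python, not the transformers source) of how the six positional arguments from the checkpointed call bind against the mis-ordered signature:

def forward(hidden_states, attention_mask=None, layer_head_mask=None,
            output_attentions=False, use_cache=False, past_key_value=None):
    # Binding for the call below:
    #   output_attentions <- None  (the placeholder meant for past_key_value)
    #   use_cache         <- output_attentions (True)
    #   past_key_value    <- None
    print(f"output_attentions={output_attentions}, use_cache={use_cache}")

# torch.utils.checkpoint.checkpoint forwards its extra arguments positionally,
# and custom_forward appends (output_attentions, None) at the end:
output_attentions = True
forward("hidden_states", "causal_attention_mask", None, None, output_attentions, None)
# prints: output_attentions=None, use_cache=True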

Solution

Reorder the parameter declaration of OPTDecoderLayer.forward as follows:

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    layer_head_mask: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
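
With this ordering, the same six positional arguments bind as intended: the None placeholder goes to past_key_value, output_attentions receives the caller's flag, and use_cache receives the trailing None (which is fine, since caching is disabled under checkpointing). The non-checkpointed branch passes keyword arguments, so it is unaffected by the reordering.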

System Information

  • transformers version: 4.29.1
  • Platform: Linux-5.15.0-58-generic-x86_64-with-glibc2.35
  • Python version: 3.9.16
  • Huggingface_hub version: 0.14.1
  • Safetensors version: 0.2.7
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): 2.12.0 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes and no; the bug occurs in both cases.
  • Using distributed or parallel set-up in script?: None

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch

import transformers
from transformers.models.opt.modeling_opt import OPTDecoder

model = transformers.OPTForCausalLM.from_pretrained('facebook/opt-125m')
model.train()  # gradient checkpointing is only active in training mode

# Enable gradient checkpointing on the decoder (use_cache is incompatible
# with checkpointing, so it is disabled).
for m in model.modules():
    if isinstance(m, OPTDecoder):
        m.gradient_checkpointing = True
        m.config.use_cache = False

# Each entry of output.attentions should be a torch.Tensor of attention weights.
output = model(torch.zeros((1, 4), dtype=torch.int64), output_attentions=True)
assert type(output.attentions) == tuple
assert type(output.attentions[0]) == torch.Tensor, type(output.attentions[0])

The test code above should finish without error. Instead, it fails as follows:

(torch) ainl@ainl-main-ubuntu:~/library/bug$ python -m opt_bug
Traceback (most recent call last):
  File "/home/ainl/anaconda3/envs/torch/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ainl/anaconda3/envs/torch/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/ainl/library/bug/opt_bug.py", line 13, in <module>
    assert type(output.attentions[0]) == torch.Tensor, type(output.attentions[0])
AssertionError: <class 'tuple'>
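
The <class 'tuple'> in the assertion message is consistent with the mis-binding described above: since use_cache receives the truthy value intended for output_attentions, each layer appends its present key/value tuple (rather than its attention weights) at index 1 of its outputs, and the decoder collects that tuple into output.attentions.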

My environment is as follows:

(torch) ainl@ainl-main-ubuntu:~/library/bug$ pip show torch transformers
Name: torch
Version: 2.0.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /home/ainl/anaconda3/envs/torch/lib/python3.9/site-packages
Requires: filelock, jinja2, networkx, nvidia-cublas-cu11, nvidia-cuda-cupti-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-runtime-cu11, nvidia-cudnn-cu11, nvidia-cufft-cu11, nvidia-curand-cu11, nvidia-cusolver-cu11, nvidia-cusparse-cu11, nvidia-nccl-cu11, nvidia-nvtx-cu11, sympy, triton, typing-extensions
Required-by: axial-positional-embedding, basicsr, deepspeed, facexlib, gfpgan, invisible-watermark, local-attention, onnx2torch, open-clip-torch, performer-pytorch, product-key-memory, pytorch-tabnet, realesrgan, sinkhorn-transformer, thop, timm, torch-tensorrt, torchaudio, torchdata, torchtext, torchvision, triton
---
Name: transformers
Version: 4.29.1
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /home/ainl/anaconda3/envs/torch/lib/python3.9/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, tokenizers, tqdm
Required-by:

Expected behavior

The test code above should run to completion without any errors.

Call for Moderator (Text-models)

@ArthurZucker and @younesbelkada
