ReFT does not support BFloat16 #772

Open · lyy1994 opened this issue Dec 31, 2024 · 1 comment
Labels: bug (Something isn't working)

lyy1994 commented Dec 31, 2024

Environment info

  • adapters version: 1.0.1
  • transformers version: 4.45.2
  • Platform: Linux-5.4.0-202-generic-x86_64-with-glibc2.31
  • Python version: 3.12.5
  • Huggingface_hub version: 0.24.6
  • Safetensors version: 0.4.5
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA A100-SXM4-80GB

Information

Model I am using (Bert, XLNet ...): Llama-3.1-8B-Instruct

Language I am using the model on (English, Chinese ...): English

Adapter setup I am using (if any): LoReFT

The problem arises when using:

  • [ ] the official example scripts: (give details below)
  • [x] my own modified scripts: (give details below)

The task I am working on is:

  • [ ] an official GLUE/SQuAD task: (give the name)
  • [x] my own task or dataset: (give details below)

To reproduce

Code:

from transformers import AutoModelForCausalLM, AutoTokenizer
from adapters import ReftConfig
import adapters
import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

config = ReftConfig(
    layers="all", prefix_positions=3, suffix_positions=0, r=1, orthogonality=True
)
adapters.init(model)
model.add_adapter("loreft", config=config)
model.train_adapter("loreft")
model.bfloat16()
model.cuda()

tokens = tokenizer(["<|start_header_id|>user<|end_header_id|>\n\nhello, my name is Tom. Nice to meet you!<|start_header_id|>assistant<|end_header_id|>\n\nHi! I am Llama."], return_tensors="pt")
model(input_ids=tokens["input_ids"].cuda())

Traceback:

Traceback (most recent call last):
  File "/data/liyanyang/tests/test_adapters.py", line 22, in <module>
    model(input_ids=tokens["input_ids"].cuda())                                                                                                              
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                  
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                     
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(     
              ^^^^^^^^^^^                                                                                                                                    
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                  
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                     
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/context.py", line 116, in wrapper_func
    results = f(self, *args, **kwargs)          
              ^^^^^^^^^^^^^^^^^^^^^^^^                   
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/model_mixin.py", line 1470, in forward
    return super().forward(*args, **kwargs)                  
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1000, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1616, in _call_impl
    hook_result = hook(self, args, result)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/methods/reft.py", line 209, in hook_fn
    return (module.reft_layer(output[0]),) + output[1:]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/methods/reft.py", line 191, in forward
    hidden_states = self.refts[first_adapter](hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/methods/reft.py", line 150, in forward
    adapted_states[i] = unit(adapted_states[i])
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/methods/reft.py", line 39, in forward
    projected_states = self.projection(x)
                       ^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
                           ^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/utils/parametrize.py", line 379, in get_parametrized
    return parametrization()
           ^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/utils/parametrize.py", line 276, in forward
    x = self[0](self.original)
        ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/utils/parametrizations.py", line 100, in forward
    Q = torch.linalg.householder_product(A, tau)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: "orgqr_cuda" not implemented for 'BFloat16'

Expected behavior

BFloat16 should work with ReFT.
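
For reference, the failure can be isolated from adapters entirely: the householder orthogonal parametrization computes its weight via torch.linalg.householder_product ("orgqr"), which has no BFloat16 CUDA kernel. A minimal sketch (assuming a CUDA device is available):

import torch
from torch import nn

# Minimal sketch isolating the root cause, independent of adapters:
# the householder orthogonal parametrization recomputes the weight via
# torch.linalg.householder_product on every forward pass, and that op
# has no BFloat16 CUDA kernel.
linear = nn.Linear(8, 4, bias=False)
nn.utils.parametrizations.orthogonal(linear, orthogonal_map="householder")
linear.bfloat16().cuda()  # casts the parametrization's tensors to bfloat16
out = linear(torch.randn(2, 8, dtype=torch.bfloat16, device="cuda"))
# RuntimeError: "orgqr_cuda" not implemented for 'BFloat16'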

lyy1994 added the bug label on Dec 31, 2024
lyy1994 (Author) commented Dec 31, 2024

Looks like we should cast x to float32 before calling the projection parametrized by nn.utils.parametrizations.orthogonal when orthogonality is true, here:

projected_states = self.projection(x)
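
A hedged sketch of that idea, written against the forward in adapters/methods/reft.py from the traceback. The self.orthogonal flag is an assumption mirroring ReftConfig's orthogonality option, and casting x alone is not enough if the parametrized weight is bfloat16, so this also presumes the projection's own parameters stay in float32:

# Hypothetical patch sketch for adapters/methods/reft.py (around line 39
# of the installed version). Assumes the projection's parameters are kept
# in float32, since the householder product itself has no BFloat16 kernel;
# only the activations are cast at the boundary.
if self.orthogonal:  # assumed attribute mirroring ReftConfig.orthogonality
    projected_states = self.projection(x.float()).to(x.dtype)
else:
    projected_states = self.projection(x)

Alternatively, computing the householder product in float32 inside the parametrization itself would avoid touching the ReFT forward at all.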
