Using distributed or parallel set-up in script?: no
Using GPU in script?: yes
GPU type: NVIDIA A100-SXM4-80GB
Information
Model I am using (Bert, XLNet ...): Llama-3.1-8B-Instruct
Language I am using the model on (English, Chinese ...): English
Adapter setup I am using (if any): LoReFT
The problem arises when using:
the official example scripts: (give details below)
my own modified scripts: (give details below)
The task I am working on is:
an official GLUE/SQuAD task: (give the name)
my own task or dataset: (give details below)
To reproduce
Code:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from adapters import ReftConfig
import adapters
import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
config = ReftConfig(
    layers="all", prefix_positions=3, suffix_positions=0, r=1, orthogonality=True
)
adapters.init(model)
model.add_adapter("loreft", config=config)
model.train_adapter("loreft")
model.bfloat16()
model.cuda()
tokens = tokenizer(
    ["<|start_header_id|>user<|end_header_id|>\n\nhello, my name is Tom. Nice to meet you!<|start_header_id|>assistant<|end_header_id|>\n\nHi! I am Llama."],
    return_tensors="pt",
)
model(input_ids=tokens["input_ids"].cuda())
```
Traceback:
Traceback (most recent call last):
File "/data/liyanyang/tests/test_adapters.py", line 22, in <module>
model(input_ids=tokens["input_ids"].cuda())
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
outputs = self.model(
^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/context.py", line 116, in wrapper_func
results = f(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/model_mixin.py", line 1470, in forward
return super().forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1000, in forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1616, in _call_impl
hook_result = hook(self, args, result)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/methods/reft.py", line 209, in hook_fn
return (module.reft_layer(output[0]),) + output[1:]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/methods/reft.py", line 191, in forward
hidden_states = self.refts[first_adapter](hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/methods/reft.py", line 150, in forward
adapted_states[i] = unit(adapted_states[i])
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/adapters/methods/reft.py", line 39, in forward
projected_states = self.projection(x)
^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 117, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/utils/parametrize.py", line 379, in get_parametrized
return parametrization()
^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/utils/parametrize.py", line 276, in forward
x = self[0](self.original)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/liyanyang/miniconda3/envs/reason/lib/python3.12/site-packages/torch/nn/utils/parametrizations.py", line 100, in forward
Q = torch.linalg.householder_product(A, tau)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: "orgqr_cuda" not implemented for 'BFloat16'
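For reference, the bottom of the traceback shows that the failure is independent of adapters itself: the orthogonal parametrization on the ReFT projection bottoms out in `torch.linalg.householder_product` (the `orgqr` LAPACK routine), which has no BFloat16 kernel in this PyTorch version. A minimal sketch that reproduces the same limitation without the model (in my testing the CPU kernel is limited the same way as the CUDA one):

```python
import torch

# float32 works: householder_product composes Householder reflectors into Q
A = torch.randn(4, 4)
tau = torch.randn(4)
Q = torch.linalg.householder_product(A, tau)
print(Q.shape)  # torch.Size([4, 4])

# BFloat16 hits the same missing-kernel error as the traceback above
try:
    torch.linalg.householder_product(A.bfloat16(), tau.bfloat16())
    print("bf16 supported")
except RuntimeError as e:
    print("bf16 unsupported:", e)
```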
Expected behavior
ReFT should work with BFloat16 models.
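As an interim workaround (an assumption on my part, not a confirmed fix, and it changes the method: without orthogonality, LoReFT degrades to an unconstrained low-rank intervention), disabling the orthogonal parametrization avoids the `orgqr` call entirely, so the forward pass runs in BFloat16:

```python
# Hypothetical workaround: skip the orthogonal parametrization so that no
# householder_product/orgqr kernel is needed under bfloat16.
config = ReftConfig(
    layers="all", prefix_positions=3, suffix_positions=0, r=1, orthogonality=False
)
```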
Environment info
adapters version: 1.0.1
transformers version: 4.45.2