[P1] Loading REFT for RoBERTa Models #86

Open
hSterz opened this issue May 13, 2024 · 4 comments
Labels: question (Further information is requested)

@hSterz

hSterz commented May 13, 2024

I trained and saved ReFT modules for a RoBERTa model, but loading them does not seem to be possible with the current implementation. I get the following error:

```
Traceback (most recent call last):
  File "/mnt/home/pyreft/examples/loreft/eval_nusaX.py", line 159, in <module>
    main()
  File "/mnt/home/pyreft/examples/loreft/eval_nusaX.py", line 57, in main
    reft_model = pyreft.ReftModel.load(
  File "/mnt/home/pyreft/pyreft/reft_model.py", line 26, in load
    model = pv.IntervenableModel.load(*args, **kwargs)
  File "/mnt/home/miniconda3/envs/reft/lib/python3.10/site-packages/pyvene/models/intervenable_base.py", line 550, in load
    intervenable = IntervenableModel(saving_config, model)
  File "/mnt/home/miniconda3/envs/reft/lib/python3.10/site-packages/pyvene/models/intervenable_base.py", line 116, in __init__
    component_dim = get_dimension_by_component(
  File "/mnt/home/miniconda3/envs/reft/lib/python3.10/site-packages/pyvene/models/modeling_utils.py", line 101, in get_dimension_by_component
    if component not in type_to_dimension_mapping[model_type]:
KeyError: <class 'transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification'>
```

It looks like `type_to_dimension_mapping` does not include RoBERTa. Or does RoBERTa fall under one of the existing model types?
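The traceback suggests the failure is a plain dictionary lookup keyed by the model class: `type_to_dimension_mapping[model_type]` is indexed with `RobertaForSequenceClassification`, and a class with no entry raises `KeyError` before any saved weights are touched. The toy sketch below reproduces that pattern using hypothetical stand-in classes and a hypothetical mapping; it is not pyvene's actual code or model-card data.

```python
# Toy reproduction of the lookup pattern shown in the traceback above.
# The classes and mapping contents are hypothetical stand-ins, not pyvene's data.


class LlamaForCausalLM:
    """Stand-in for a model class that has an entry in the mapping."""


class RobertaForSequenceClassification:
    """Stand-in for a model class with no entry in the mapping."""


# Keyed by model *class*, as the KeyError in the traceback suggests.
type_to_dimension_mapping = {
    LlamaForCausalLM: {"block_output": ("hidden_size",)},
}


def get_dimension_by_component(model_type, component):
    """Return the config attribute(s) giving a component's dimension, or None."""
    # Indexing with an unregistered class raises KeyError before the
    # membership test ever runs, which is the failure in the traceback.
    if component not in type_to_dimension_mapping[model_type]:
        return None
    return type_to_dimension_mapping[model_type][component]


print(get_dimension_by_component(LlamaForCausalLM, "block_output"))
try:
    get_dimension_by_component(RobertaForSequenceClassification, "block_output")
except KeyError as err:
    print("KeyError:", err)
```

In this toy, adding an entry for the RoBERTa class to the mapping would make the lookup succeed; doing the equivalent inside pyvene itself is not shown here.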

frankaging added the [P1] prefix to the title on May 13, 2024
frankaging self-assigned this on May 13, 2024
frankaging added the "question" (Further information is requested) label on May 13, 2024
@frankaging
Collaborator

Hey @hSterz, thanks for your question!

The reason is that RoBERTa is not natively supported by pyvene (pyreft's parent library). We use RoBERTa precisely to show that pyreft can work with any PyTorch model.

To use RoBERTa, you have to set up your config with string access to the component (e.g., "model.layers[0].output"). Note that the example below is for a LLaMA model, but the same concept applies:

```python
import pyreft

# get reft model
# (assumes `model` is an already-loaded Hugging Face model, here a 7B LLaMA)
reft_config = pyreft.ReftConfig(representations={
    "layer": 15, "component": "block_output",
    # alternatively, you can specify as string component access,
    # "component": "model.layers[0].output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(embed_dim=model.config.hidden_size,
                                              low_rank_dimension=4)})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device("cuda")
reft_model.print_trainable_parameters()

"""
trainable intervention params: 32,772 || trainable model params: 0
model params: 6,738,415,616 || trainable%: 0.00048634578018881287
"""
```

You can also see how we did it in our actual code here:
https://github.com/stanfordnlp/pyreft/blob/main/examples/loreft/train.py#L286
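To make the string form concrete: in `"model.layers[0].output"`, the last segment cannot be an attribute (a LLaMA decoder layer has no `output` submodule), so it most plausibly names the hook type, with everything before it being the module path. The sketch below parses such a string under that assumption. `resolve_component` is a hypothetical helper for illustration, not pyvene's parser, and `SimpleNamespace` objects stand in for real modules so it runs without a model.

```python
import re
from types import SimpleNamespace


def resolve_component(root, component):
    """Split an access string such as "model.layers[0].output" into the
    module it names and the hook type given by its last segment.

    Hypothetical helper for illustration only; pyvene's own parsing may differ.
    """
    *module_parts, hook_type = component.split(".")
    if hook_type not in ("input", "output"):
        raise ValueError(f"last segment must be 'input' or 'output', got {hook_type!r}")
    obj = root
    for part in module_parts:
        # Each segment is an attribute name optionally followed by [int] indices.
        match = re.fullmatch(r"(\w+)((?:\[\d+\])*)", part)
        if match is None:
            raise ValueError(f"cannot parse segment {part!r}")
        name, indices = match.groups()
        obj = getattr(obj, name)
        for idx in re.findall(r"\[(\d+)\]", indices):
            obj = obj[int(idx)]
    return obj, hook_type


# Stand-in mirroring the nesting of a Hugging Face RoBERTa classifier.
layers = [SimpleNamespace(name="layer0"), SimpleNamespace(name="layer1")]
toy_model = SimpleNamespace(roberta=SimpleNamespace(encoder=SimpleNamespace(layer=layers)))
module, hook = resolve_component(toy_model, "roberta.encoder.layer[1].output")
print(module.name, hook)
```

In a Hugging Face `RobertaForSequenceClassification`, the encoder layers sit under `roberta.encoder.layer` (an `nn.ModuleList`, which supports integer indexing), so under this convention a string like `"roberta.encoder.layer[8].output"` would name the output of the ninth encoder layer. Treat the exact string pyreft expects as an assumption and check it against the linked `train.py`.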

@hSterz
Author

hSterz commented May 15, 2024

Thank you for the reply, @frankaging! My question is: how can I load a ReFT module that was added and trained as described in your example?

@frankaging
Collaborator

@hSterz Got it! If the model is natively supported by pyvene (the supported models can be found here), you can load it as follows:

```python
reft_model = pyreft.ReftModel.load("<your_directory>", model)
```

If the model is not supported by pyvene, you have to either (1) add support for it in pyvene and reinstall pyvene, or (2) reinitialize the pyreft model and load the saved interventions manually yourself. All the interventions can be accessed via `reft_model.interventions`.

Let me know if this helps.

@m-dev12

m-dev12 commented Jul 10, 2024

Hi @hSterz, how did you finally go about this? I am facing the same error.
