.add_embeddings prevents model from training adapters #384

Closed
4 tasks
eugene-yang opened this issue Jul 5, 2022 · 0 comments · Fixed by #386
Labels
bug Something isn't working

Comments

@eugene-yang

Environment info

  • adapter-transformers version: v3.0.1+ (commit 11bd9d2)
  • Platform: Arch Linux
  • Python version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?):
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Information

Model I am using (Bert, XLNet ...): XLMR

Language I am using the model on (English, Chinese ...):

Adapter setup I am using (if any):

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The tasks I am working on are:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

from transformers import AdapterSetup, AutoAdapterModel, AutoTokenizer, AdapterConfig
model = AutoAdapterModel.from_pretrained('xlm-roberta-base')
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
tokenizer_new = AutoTokenizer.from_pretrained('xlm-roberta-base')
model.add_embeddings('new', tokenizer_new, reference_tokenizer=tokenizer, reference_embedding='default')
model.add_adapter('test')
model.train_adapter(['test'])

This gives me the following error message.

File /expscratch/eyang/workspace/adapter/adapter-transformers/src/transformers/adapters/model_mixin.py:949, in ModelWithHeadsAdaptersMixin.train_adapter(self, adapter_setup, train_embeddings)
    947     super().train_adapter(adapter_setup, train_embeddings)
    948 else:
--> 949     self.base_model.train_adapter(adapter_setup, train_embeddings)

File /expscratch/eyang/workspace/adapter/adapter-transformers/src/transformers/adapters/model_mixin.py:287, in ModelAdaptersMixin.train_adapter(self, adapter_setup, train_embeddings)
    285 """Sets the model into mode for training the given adapters."""
    286 self.train()
--> 287 self.freeze_model(True)
    288 adapter_setup = parse_composition(adapter_setup)
    289 self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, True, False))

File /expscratch/eyang/workspace/adapter/adapter-transformers/src/transformers/adapters/model_mixin.py:726, in ModelAdaptersMixin.freeze_model(self, freeze)
    724 # first freeze/ unfreeze all model weights
    725 for param in self.base_model.parameters():
--> 726     param.requires_grad = not freeze
    727 self.model_frozen = freeze

RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().

I used the following code to see which specific parameter is causing the issue:

for name, p in model.base_model.named_parameters():
    try:
        p.requires_grad = False
    except Exception as e:
        print(name)
        print(e)

and got

embeddings.word_embeddings.weight
you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().
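
For what it's worth, the same conclusion falls out of checking the leaf flag on that parameter directly; a quick hypothetical follow-up on the same model object (the attribute path is taken from the named_parameters output above, the rest is just standard PyTorch introspection):

w = model.base_model.embeddings.word_embeddings.weight
print(w.is_leaf)   # False -- exactly what freeze_model() trips over
print(w.grad_fn)   # not None, so the weight is still attached to an autograd graph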

Interestingly, the following code is fine.

from transformers import AdapterSetup, AutoAdapterModel, AutoTokenizer, AdapterConfig
model = AutoAdapterModel.from_pretrained('xlm-roberta-base')
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
model.add_embeddings('new', tokenizer)
model.add_adapter('test')
model.train_adapter(['test'])

So the referencing seems to be the part that breaks the computational graph (though I'm not sure why, since the parameters were cloned when adding the new embedding...).
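
If it helps, this looks like ordinary PyTorch behaviour rather than anything adapter-specific: copying rows from a weight that still requires grad, without detaching them first, turns the target weight into a non-leaf tensor, and non-leaf tensors are exactly what freeze_model() cannot toggle. A minimal stand-alone sketch with plain nn.Embedding stand-ins (my own reconstruction of the mechanism, not the actual add_embeddings code):

import torch.nn as nn

reference = nn.Embedding(10, 4)      # stand-in for the 'default' embedding, weight requires grad
new_emb = nn.Embedding(10, 4)
new_emb.requires_grad_(False)        # fresh embedding that will be filled in place

# Copying a reference row without detaching: the right-hand side still carries grad,
# so the in-place assignment gives new_emb.weight a grad_fn and it stops being a leaf.
new_emb.weight[0] = reference.weight[0].clone()
print(new_emb.weight.is_leaf)        # False

try:
    new_emb.weight.requires_grad = False   # same thing freeze_model() does
except RuntimeError as e:
    print(e)                               # "you can only change requires_grad flags of leaf variables..."

# Detaching before cloning keeps the copy out of the autograd graph,
# so the new weight stays a leaf and can be frozen/unfrozen normally.
fixed = nn.Embedding(10, 4)
fixed.requires_grad_(False)
fixed.weight[0] = reference.weight[0].detach().clone()
print(fixed.weight.is_leaf)          # True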

The following code is also fine:

from transformers import AdapterSetup, AutoAdapterModel, AutoTokenizer, AdapterConfig
model = AutoAdapterModel.from_pretrained('xlm-roberta-base')
model.add_adapter('test')
model.train_adapter(['test'])
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
tokenizer_new = AutoTokenizer.from_pretrained('xlm-roberta-base')
model.add_embeddings('new', tokenizer_new, reference_tokenizer=tokenizer, reference_embedding='default')

Expected behavior

train_adapter() should still work after adding embeddings with a reference embedding, i.e. the adapters should be trainable.

@eugene-yang eugene-yang added the bug Something isn't working label Jul 5, 2022
calpt added a commit that referenced this issue Jul 11, 2022
- Introduces a new `EmbeddingAdaptersWrapperMixin` to make embedding methods available to heads model classes. This is implemented in new per-model heads mixins. Closes #382.
- Fixes size issues with embeddings. Closes #383.
- Detach embedding weights before cloning. Closes #384.
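
For anyone skimming later: if I read the last bullet right, the fix boils down to detaching the reference rows before cloning them into the new embedding, so the copied weight never enters the autograd graph. Roughly this pattern (a hypothetical helper for illustration, not the actual code from #386):

import torch.nn as nn

def copy_reference_rows(new_embedding: nn.Embedding, ref_embedding: nn.Embedding, row_map: dict) -> None:
    # row_map: index in the new vocab -> index in the reference vocab (hypothetical helper,
    # not part of adapter-transformers)
    new_embedding.requires_grad_(False)  # the fresh embedding is filled in place, so keep it grad-free here
    for new_idx, ref_idx in row_map.items():
        # detach() before clone() keeps the copied rows out of the autograd graph,
        # so new_embedding.weight stays a leaf and freeze_model() can flip requires_grad later
        new_embedding.weight[new_idx] = ref_embedding.weight[ref_idx].detach().clone()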