feedback from code review
Jesse Swanson authored and Jesse Swanson committed Apr 23, 2021
1 parent 046e4bb commit d2d4894
Showing 2 changed files with 28 additions and 2 deletions.
28 changes: 27 additions & 1 deletion guides/models/adding_models.md
@@ -1,6 +1,6 @@
# Adding a model

`jiant` supports or can easily be exteneded to support Hugging Face Hugging Face's [Transformer models](https://huggingface.co/transformers/viewer/) since `jiant` utilizes [Auto Classes](https://huggingface.co/transformers/model_doc/auto.html) to determine the architecture of the model used based on the name of the [pretrained model](https://huggingface.co/models). To add a model not currently supported in `jiant`, follow the following steps:
`jiant` supports or can easily be extended to support Hugging Face's [Transformer models](https://huggingface.co/transformers/viewer/) since `jiant` utilizes [Auto Classes](https://huggingface.co/transformers/model_doc/auto.html) to determine the architecture of the model used based on the name of the [pretrained model](https://huggingface.co/models). Although `jiant` uses AutoModels to resolve model classes, the `jiant` pipeline requires additional information (such as matching the correct tokenizer for the models). Furthermore, there are subtle differences in the models that `jiant` must abstract and additional steps are required to add a Hugging Face model to `jiant`. To add a model not currently supported in `jiant`, follow the following steps:

## 1. Add to ModelArchitectures enum
Add the model to the ModelArchitectures enum in [`model_resolution.py`](../../jiant/tasks/model_resolution.py) as a member-string mapping. For example, adding the field DEBERTAV2 = "deberta-v2" would add Deberta V2 to the ModelArchitectures enum.
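
As a rough illustration of the member-string mapping described above, the following is a minimal standalone sketch, not the actual contents of `model_resolution.py`. The `BERT` member is included only as an assumed placeholder for an existing entry; `DEBERTAV2 = "deberta-v2"` is the field named in this guide.

```python
from enum import Enum


class ModelArchitectures(Enum):
    # Assumed placeholder for an already-supported architecture.
    BERT = "bert"
    # New member: the enum name maps to the architecture string.
    DEBERTAV2 = "deberta-v2"


# Look the member up by its string value, then read the value back.
arch = ModelArchitectures("deberta-v2")
print(arch.name, arch.value)
```

Looking a member up by value, as in `ModelArchitectures("deberta-v2")`, raises a `ValueError` for strings that have not been added to the enum.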
@@ -68,3 +68,29 @@ class DebertaV2MLMHead(BaseMLMHead):
...
````

## 5. Fine-tune the model
You should now be able to use the new model with the following simple fine-tuning example (Deberta-V2 used as an example below):

```python
from jiant.proj.simple import runscript as run
import jiant.scripts.download_data.runscript as downloader

EXP_DIR = "/path/to/exp"

# Download the Data
downloader.download_data(["mrpc"], f"{EXP_DIR}/tasks")

# Set up the arguments for the Simple API
args = run.RunConfiguration(
    run_name="simple",
    exp_dir=EXP_DIR,
    data_dir=f"{EXP_DIR}/tasks",
    hf_pretrained_model_name_or_path="microsoft/deberta-v2-xlarge",
    tasks="mrpc",
    train_batch_size=16,
    num_train_epochs=3
)

# Run!
run.run_simple(args)
```
2 changes: 1 addition & 1 deletion jiant/proj/main/modeling/model_setup.py
@@ -161,7 +161,7 @@ def load_encoder_from_transformers_weights(
        if k.startswith(encoder_prefix):
            load_weights_dict[strings.remove_prefix(k, encoder_prefix)] = v
        elif k.startswith(encoder_prefix.split("-")[0]):
            # workaround for deberta-v2
            # workaround for deberta-v2 -> remove the "-v2" suffix since the weight names are prefixed with "deberta" and not "deberta-v2"
            load_weights_dict[strings.remove_prefix(k, encoder_prefix.split("-")[0] + ".")] = v
        else:
            remainder_weights_dict[k] = v
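
To make the workaround easier to follow, here is a small self-contained sketch of the same prefix-matching logic on plain dictionaries. It assumes that `encoder_prefix` is `"deberta-v2."` for this model (inferred from the comment in the hunk, not shown in it). `remove_prefix` and `split_encoder_weights` are hypothetical stand-ins written for this example, not jiant's own helpers.

```python
def remove_prefix(s, prefix):
    # Hypothetical stand-in for jiant's strings.remove_prefix.
    if not s.startswith(prefix):
        raise ValueError(f"{s!r} does not start with {prefix!r}")
    return s[len(prefix):]


def split_encoder_weights(weights_dict, encoder_prefix):
    # Separate encoder weights (prefix stripped) from all other weights.
    load_weights_dict = {}
    remainder_weights_dict = {}
    # "deberta-v2." -> "deberta": the part before the first hyphen.
    base_prefix = encoder_prefix.split("-")[0]
    for k, v in weights_dict.items():
        if k.startswith(encoder_prefix):
            load_weights_dict[remove_prefix(k, encoder_prefix)] = v
        elif k.startswith(base_prefix):
            # Checkpoint keys use "deberta." rather than "deberta-v2.".
            load_weights_dict[remove_prefix(k, base_prefix + ".")] = v
        else:
            remainder_weights_dict[k] = v
    return load_weights_dict, remainder_weights_dict


# Toy weights: integers stand in for tensors.
weights = {
    "deberta.embeddings.word_embeddings.weight": 1,
    "deberta.encoder.layer.0.attention.weight": 2,
    "cls.predictions.bias": 3,
}
loaded, remainder = split_encoder_weights(weights, "deberta-v2.")
print(sorted(loaded))
print(sorted(remainder))
```

With these toy keys, no key starts with `"deberta-v2."`, so the first branch never matches; the `elif` branch is what routes the `deberta.`-prefixed keys into the encoder dictionary.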