Commit 60e32e3

Merge pull request #8 from microsoft/raviskolli/ort

Raviskolli/ort

raviskolli authored Apr 27, 2021
2 parents: 0aaf93a + f9156dc
Showing 3 changed files with 7 additions and 1 deletion.
examples/language-modeling/run_mlm.py (1 addition & 0 deletions)

@@ -275,6 +275,7 @@ def main():
         "cache_dir": model_args.cache_dir,
         "revision": model_args.model_revision,
         "use_auth_token": True if model_args.use_auth_token else None,
+        "ort": True if training_args.ort else None,
     }
     if model_args.config_name:
         config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
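
Note that this commit does not show where `training_args.ort` is defined; presumably the flag is added to the training-argument class elsewhere in this fork. A minimal sketch of how such a flag could be declared follows; the subclass name and help text are assumptions, not part of this commit.

# Hypothetical sketch: where a `training_args.ort` flag could come from.
from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class ORTTrainingArguments(TrainingArguments):
    ort: bool = field(
        default=False,
        metadata={"help": "Enable ONNX Runtime (ORT) specific model code paths."},
    )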
examples/seq2seq/run_translation.py (1 addition & 0 deletions)

@@ -320,6 +320,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        ort=True if training_args.ort else None,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
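
Here the `ort=...` keyword is passed to the config's `from_pretrained` call (not fully shown) so model code can read it as `config.ort`, as the DistilBERT change below does. Whether `from_pretrained` preserves an unrecognized kwarg depends on the transformers version, so this illustrative snippet sets the attribute explicitly; the checkpoint name is an arbitrary example.

# Illustrative only: expose the flag to model code as `config.ort`.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("distilbert-base-uncased")
config.ort = True  # read by DistilBertForMaskedLM in modeling_distilbert.py below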
src/transformers/models/distilbert/modeling_distilbert.py (5 additions & 1 deletion)

@@ -500,6 +500,7 @@ def __init__(self, config):
         self.vocab_transform = nn.Linear(config.dim, config.dim)
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
+        self.ort = config.ort

         self.init_weights()

@@ -554,7 +555,10 @@ def forward(

         mlm_loss = None
         if labels is not None:
-            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
+            if self.ort:
+                mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)).to(torch.float32), labels.view(-1))
+            else:
+                mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))

         if not return_dict:
             output = (prediction_logits,) + dlbrt_output[1:]
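
In the ORT path the logits are upcast to float32 before the loss: under fp16 mixed-precision training, computing cross-entropy on half-precision logits can overflow or lose precision in the log-softmax. A standalone sketch of that cast, where batch size, sequence length, and vocabulary size are illustrative assumptions:

# Standalone sketch of the fp32 upcast used in the ORT branch above.
import torch
import torch.nn as nn

mlm_loss_fct = nn.CrossEntropyLoss()
prediction_logits = torch.randn(8, 128, 30522, dtype=torch.float16)  # (batch, seq, vocab)
labels = torch.randint(0, 30522, (8, 128))

# Flatten to (batch * seq, vocab) and upcast so the log-softmax runs in float32.
mlm_loss = mlm_loss_fct(
    prediction_logits.view(-1, prediction_logits.size(-1)).to(torch.float32),
    labels.view(-1),
)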
