diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py
index 1cb9fcc6cdd482..8b95f5ff0192c8 100755
--- a/examples/language-modeling/run_mlm.py
+++ b/examples/language-modeling/run_mlm.py
@@ -275,6 +275,7 @@ def main():
         "cache_dir": model_args.cache_dir,
         "revision": model_args.model_revision,
         "use_auth_token": True if model_args.use_auth_token else None,
+        "ort": True if training_args.ort else None,
     }
     if model_args.config_name:
         config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
diff --git a/examples/seq2seq/run_translation.py b/examples/seq2seq/run_translation.py
index ff9a84bf68a840..599fa3f8b3d5d2 100755
--- a/examples/seq2seq/run_translation.py
+++ b/examples/seq2seq/run_translation.py
@@ -320,6 +320,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        ort=True if training_args.ort else None,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index 911fba8088481b..aec2d03203d3e9 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -500,6 +500,7 @@ def __init__(self, config):
         self.vocab_transform = nn.Linear(config.dim, config.dim)
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
+        self.ort = config.ort

         self.init_weights()

@@ -554,7 +555,10 @@ def forward(

         mlm_loss = None
         if labels is not None:
-            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
+            if self.ort:
+                mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)).to(torch.float32), labels.view(-1))
+            else:
+                mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))

         if not return_dict:
             output = (prediction_logits,) + dlbrt_output[1:]