Skip to content

Commit

Permalink
Turn on eval mode when exporting to ONNX (#12758)
Browse files Browse the repository at this point in the history
* Set model in eval mode when exporting to ONNX.

* Disable t5 for now.

* Disable T5 with past too.

* Style.
  • Loading branch information
mfuntowicz authored Jul 16, 2021
1 parent 8ef3f36 commit fbf1397
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/transformers/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def export(
logger.info(f"Using framework PyTorch: {torch.__version__}")
torch.set_grad_enabled(False)
model.config.return_dict = True
model.eval()

# Check if we need to override certain configuration item
if config.values_override is not None:
Expand Down
17 changes: 10 additions & 7 deletions tests/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from unittest import TestCase
from unittest.mock import patch

from transformers import ( # LongformerConfig,
from transformers import ( # LongformerConfig,; T5Config,
AlbertConfig,
AutoTokenizer,
BartConfig,
DistilBertConfig,
GPT2Config,
RobertaConfig,
T5Config,
XLMRobertaConfig,
is_torch_available,
)
Expand All @@ -22,7 +21,8 @@
# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig

# from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
Expand Down Expand Up @@ -122,7 +122,11 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
Cover the tests for model which have use_cache feature (i.e. "with_past" for ONNX)
"""

SUPPORTED_WITH_PAST_CONFIGS = {("BART", BartConfig), ("GPT2", GPT2Config), ("T5", T5Config)}
SUPPORTED_WITH_PAST_CONFIGS = {
("BART", BartConfig),
("GPT2", GPT2Config),
# ("T5", T5Config)
}

@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
def test_use_past(self):
Expand Down Expand Up @@ -165,14 +169,13 @@ def test_values_override(self):


if is_torch_available():
from transformers import (
from transformers import ( # T5Model,
AlbertModel,
BartModel,
BertModel,
DistilBertModel,
GPT2Model,
RobertaModel,
T5Model,
XLMRobertaModel,
)

Expand All @@ -185,7 +188,7 @@ def test_values_override(self):
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
}

PYTORCH_EXPORT_WITH_PAST_MODELS = {
Expand Down

0 comments on commit fbf1397

Please sign in to comment.