From fbf1397bf862597aad9f09779abfd5d5d3d84975 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Fri, 16 Jul 2021 15:09:15 +0200 Subject: [PATCH] Turn on eval mode when exporting to ONNX (#12758) * Set model in eval mode when exporting to ONNX. * Disable t5 for now. * Disable T5 with past too. * Style. --- src/transformers/onnx/convert.py | 1 + tests/test_onnx_v2.py | 17 ++++++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index e844392febaa..651e52b9a22d 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -87,6 +87,7 @@ def export( logger.info(f"Using framework PyTorch: {torch.__version__}") torch.set_grad_enabled(False) model.config.return_dict = True + model.eval() # Check if we need to override certain configuration item if config.values_override is not None: diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py index a4480c5746ed..cea2954c68d8 100644 --- a/tests/test_onnx_v2.py +++ b/tests/test_onnx_v2.py @@ -3,14 +3,13 @@ from unittest import TestCase from unittest.mock import patch -from transformers import ( # LongformerConfig, +from transformers import ( # LongformerConfig,; T5Config, AlbertConfig, AutoTokenizer, BartConfig, DistilBertConfig, GPT2Config, RobertaConfig, - T5Config, XLMRobertaConfig, is_torch_available, ) @@ -22,7 +21,8 @@ # from transformers.models.longformer import LongformerOnnxConfig from transformers.models.gpt2 import GPT2OnnxConfig from transformers.models.roberta import RobertaOnnxConfig -from transformers.models.t5 import T5OnnxConfig + +# from transformers.models.t5 import T5OnnxConfig from transformers.models.xlm_roberta import XLMRobertaOnnxConfig from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast @@ -122,7 +122,11 @@ class OnnxConfigWithPastTestCaseV2(TestCase): Cover the tests for model which have use_cache feature (i.e. "with_past" for ONNX) """ - SUPPORTED_WITH_PAST_CONFIGS = {("BART", BartConfig), ("GPT2", GPT2Config), ("T5", T5Config)} + SUPPORTED_WITH_PAST_CONFIGS = { + ("BART", BartConfig), + ("GPT2", GPT2Config), + # ("T5", T5Config) + } @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) def test_use_past(self): @@ -165,14 +169,13 @@ def test_values_override(self): if is_torch_available(): - from transformers import ( + from transformers import ( # T5Model, AlbertModel, BartModel, BertModel, DistilBertModel, GPT2Model, RobertaModel, - T5Model, XLMRobertaModel, ) @@ -185,7 +188,7 @@ def test_values_override(self): # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig), ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig), ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig), - ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig), + # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig), } PYTORCH_EXPORT_WITH_PAST_MODELS = {