Skip to content

Commit

Permalink
Enable ONNX export test for supported model.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfuntowicz committed Jun 28, 2021
1 parent 749d514 commit aedaecc
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions tests/test_onnx_v2.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest import TestCase
from unittest.mock import patch

from transformers.models.bert.configuration_bert import BertOnnxConfig
from transformers import AutoTokenizer, is_torch_available, AlbertConfig, DistilBertConfig, LongformerConfig, \
RobertaConfig, XLMRobertaConfig
from transformers.models.albert import AlbertOnnxConfig

from transformers.models.bert.configuration_bert import BertOnnxConfig, BertConfig
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat

# from transformers.onnx.convert import convert_pytorch
from transformers.onnx.config import DEFAULT_ONNX_OPSET
from transformers.onnx.utils import (
compute_effective_axis_dimension,
compute_serialized_parameters_size,
flatten_output_collection_property,
)
from transformers.testing_utils import require_onnx
from transformers.testing_utils import require_onnx, slow, require_torch


@require_onnx
Expand Down Expand Up @@ -69,13 +79,35 @@ def test_use_external_data_format(self):
self.assertTrue(OnnxConfig.use_external_data_format((TWO_GB_LIMIT + 1) // ParameterFormat.Float.size))


class OnnxExportTestCaseV2(TestCase):
EXPORT_DEFAULT_MODELS = {
("BERT", "bert-base-cased", BertOnnxConfig),
if is_torch_available():
from transformers import AlbertModel, BertModel, DistilBertModel, LongformerModel, RobertaModel, XLMRobertaModel
PYTORCH_EXPORT_DEFAULT_MODELS = {
("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
}

def export_default(self):
pass

def export_with_past(self):
class OnnxExportTestCaseV2(TestCase):
@slow
@require_torch
def test_pytorch_export_default(self):
from transformers.onnx.convert import convert_pytorch

for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:

with self.subTest(name):
tokenizer = AutoTokenizer.from_pretrained(model)
model = model_class(config_class())
onnx_config = onnx_config_class.default(model.config)

with NamedTemporaryFile("w") as output:
convert_pytorch(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name))

@slow
@require_torch
def test_pytorch_export_with_past(self):
pass

0 comments on commit aedaecc

Please sign in to comment.