diff --git a/alt_e2eshark/onnx_tests/helper_classes.py b/alt_e2eshark/onnx_tests/helper_classes.py index 57897ef7..f7a1835d 100644 --- a/alt_e2eshark/onnx_tests/helper_classes.py +++ b/alt_e2eshark/onnx_tests/helper_classes.py @@ -14,7 +14,7 @@ from onnx.helper import make_node, make_graph, make_model from pathlib import Path from e2e_testing import azutils -from e2e_testing.framework import OnnxModelInfo, TestTensors +from e2e_testing.framework import ExtraOptions, ImporterOptions, OnnxModelInfo, TestTensors from e2e_testing.onnx_utils import ( modify_model_output, find_node, @@ -113,6 +113,10 @@ def __init__(self, is_validated: bool, model_url: str, name: str, onnx_model_pat os.mkdir(self.cache_dir) super().__init__(name, onnx_model_path, opset_version) + def update_extra_options(self): + # Default to using opset version 21 for all ONNX Model Zoo models. + self.extra_options = ExtraOptions(import_model_options=ImporterOptions(opset_version=21)) + def unzip_model_archive(self, tar_path): model_dir = str(Path(self.model).parent) with tarfile.open(tar_path) as tar: @@ -229,6 +233,10 @@ def __init__(self, name: str, onnx_model_path: str): self.cache_dir = os.path.join(parent_cache_dir, name) super().__init__(name, onnx_model_path, opset_version) + def update_extra_options(self): + # Default to using opset version 21 for all Azure models. + self.extra_options = ExtraOptions(import_model_options=ImporterOptions(opset_version=21)) + def update_sess_options(self): self.sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL