Skip to content

Commit

Permalink
Default to opset 21 in extra_options for ONNX models (#421)
Browse files Browse the repository at this point in the history
Adds overrides for `update_extra_options()` to `AzureDownloadableModel`
and `OnnxModelZooDownloadableModel` to set opset version to 21 during
import.

Fixes CI failures introduced by the removal of the `--opset-version`
flag in commit `9215c13b696aa8ca767207c15e63a2dff6eccccf`.
  • Loading branch information
vinayakdsci authored Dec 26, 2024
1 parent 138fe87 commit ad1aaa7
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion alt_e2eshark/onnx_tests/helper_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from onnx.helper import make_node, make_graph, make_model
from pathlib import Path
from e2e_testing import azutils
from e2e_testing.framework import OnnxModelInfo, TestTensors
from e2e_testing.framework import ExtraOptions, ImporterOptions, OnnxModelInfo, TestTensors
from e2e_testing.onnx_utils import (
modify_model_output,
find_node,
Expand Down Expand Up @@ -113,6 +113,10 @@ def __init__(self, is_validated: bool, model_url: str, name: str, onnx_model_pat
os.mkdir(self.cache_dir)
super().__init__(name, onnx_model_path, opset_version)

def update_extra_options(self):
# Default to using opset version 21 for all ONNX Model Zoo models.
self.extra_options = ExtraOptions(import_model_options=ImporterOptions(opset_version=21))

def unzip_model_archive(self, tar_path):
model_dir = str(Path(self.model).parent)
with tarfile.open(tar_path) as tar:
Expand Down Expand Up @@ -229,6 +233,10 @@ def __init__(self, name: str, onnx_model_path: str):
self.cache_dir = os.path.join(parent_cache_dir, name)
super().__init__(name, onnx_model_path, opset_version)

def update_extra_options(self):
# Default to using opset version 21 for all Azure models.
self.extra_options = ExtraOptions(import_model_options=ImporterOptions(opset_version=21))

def update_sess_options(self):
self.sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

Expand Down

0 comments on commit ad1aaa7

Please sign in to comment.