From 27a6ed5619cff2da02be213602c90c58b93e89e4 Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Wed, 31 May 2023 15:42:33 -0700
Subject: [PATCH] Enable ONNX test in CI (#2363)

* Enable ONNX test in CI
---
 requirements/developer.txt       |  2 ++
 test/pytest/test_onnx.py         | 18 +-----------------
 ts/torch_handler/base_handler.py |  6 +++---
 3 files changed, 6 insertions(+), 20 deletions(-)

diff --git a/requirements/developer.txt b/requirements/developer.txt
index 1b3447c3d8..49d60df970 100644
--- a/requirements/developer.txt
+++ b/requirements/developer.txt
@@ -15,3 +15,5 @@ twine==4.0.2
 mypy==1.3.0
 torchpippy==0.1.1
 intel_extension_for_pytorch==2.0.100; sys_platform != 'win32' and sys_platform != 'darwin'
+onnxruntime==1.15.0
+onnx==1.14.0
diff --git a/test/pytest/test_onnx.py b/test/pytest/test_onnx.py
index 477e57e4fa..dd466544ee 100644
--- a/test/pytest/test_onnx.py
+++ b/test/pytest/test_onnx.py
@@ -1,18 +1,7 @@
 import subprocess
 
-import pytest
 import torch
-
-try:
-    import onnx
-    import torch.onnx
-
-    print(
-        onnx.__version__
-    )  # Adding this so onnx import doesn't get removed by pre-commit
-    ONNX_ENABLED = True
-except:
-    ONNX_ENABLED = False
+import torch.onnx
 
 
 class ToyModel(torch.nn.Module):
@@ -28,7 +17,6 @@ def forward(self, x):
 
 
 # For a custom model you still need to manually author your converter, as far as I can tell there isn't a nice out of the box that exists
-@pytest.mark.skipif(ONNX_ENABLED == False, reason="ONNX is not installed")
 def test_convert_to_onnx():
     model = ToyModel()
     dummy_input = torch.randn(1, 1)
@@ -55,7 +43,6 @@ def test_convert_to_onnx():
     )
 
 
-@pytest.mark.skipif(ONNX_ENABLED == False, reason="ONNX is not installed")
 def test_model_packaging_and_start():
     subprocess.run("mkdir model_store", shell=True)
     subprocess.run(
@@ -65,7+52,6 @@ def test_model_packaging_and_start():
     )
 
 
-@pytest.mark.skipif(ONNX_ENABLED == False, reason="ONNX is not installed")
 def test_model_start():
     subprocess.run(
         "torchserve --start --ncs --model-store model_store --models onnx.mar",
@@ -74,7 +60,6 @@ def test_model_start():
     )
 
 
-@pytest.mark.skipif(ONNX_ENABLED == False, reason="ONNX is not installed")
 def test_inference():
     subprocess.run(
         "curl -X POST http://127.0.0.1:8080/predictions/onnx --data-binary '1'",
@@ -82,6 +67,5 @@ def test_inference():
     )
 
 
-@pytest.mark.skipif(ONNX_ENABLED == False, reason="ONNX is not installed")
 def test_stop():
     subprocess.run("torchserve --stop", shell=True, check=True)
diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py
index 8a4f77b054..08405e79fd 100644
--- a/ts/torch_handler/base_handler.py
+++ b/ts/torch_handler/base_handler.py
@@ -77,10 +77,10 @@
     ONNX_AVAILABLE = False
 
 
-def setup_ort_session(model_pt_path):
+def setup_ort_session(model_pt_path, map_location):
     providers = (
         ["CUDAExecutionProvider", "CPUExecutionProvider"]
-        if self.map_location == "cuda"
+        if map_location == "cuda"
         else ["CPUExecutionProvider"]
     )
 
@@ -168,7 +168,7 @@ def initialize(self, context):
 
         # Convert your model by following instructions: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
         elif self.model_pt_path.endswith(".onnx") and ONNX_AVAILABLE:
-            self.model = setup_ort_session(self.model_pt_path)
+            self.model = setup_ort_session(self.model_pt_path, self.map_location)
             logger.info("Succesfully setup ort session")
         else:
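
For reference, below is a minimal standalone sketch of the flow this patch exercises: export a toy model to ONNX as test_convert_to_onnx does, then load it through the fixed setup_ort_session helper (which previously referenced self.map_location from a free function and crashed) and run one inference with onnxruntime. The single-linear-layer ToyModel body, the toy.onnx file name, and the __main__ driver are illustrative assumptions, not part of the patch, and the helper is reproduced here without any session options TorchServe may additionally configure.

import onnxruntime as ort
import torch


class ToyModel(torch.nn.Module):
    """Assumed single linear layer, standing in for the ToyModel in test/pytest/test_onnx.py."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


def setup_ort_session(model_pt_path, map_location):
    # Same provider selection as the patched helper: prefer CUDA only when
    # the handler's map_location is "cuda", otherwise stay CPU-only.
    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if map_location == "cuda"
        else ["CPUExecutionProvider"]
    )
    return ort.InferenceSession(model_pt_path, providers=providers)


if __name__ == "__main__":
    # Export the toy model to ONNX ("toy.onnx" is an illustrative path).
    model = ToyModel().eval()
    dummy_input = torch.randn(1, 1)
    torch.onnx.export(model, dummy_input, "toy.onnx")

    # Build a CPU-only session and run a single inference on the dummy input.
    session = setup_ort_session("toy.onnx", map_location="cpu")
    input_name = session.get_inputs()[0].name
    (output,) = session.run(None, {input_name: dummy_input.numpy()})
    print(output)  # a (1, 1) array

Passing map_location explicitly, rather than reading self.map_location inside the function, is what lets the helper live at module level and be reused outside the handler class, which is also what makes it testable in isolation as above.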