Skip to content

Commit

Permalink
Move onnxruntime pin to torch_common
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
  • Loading branch information
quic-mtuttle authored and quic-akhobare committed Mar 8, 2023
1 parent c1ea125 commit 0824ecb
Show file tree
Hide file tree
Showing 6 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_one_shot_quantize_dequantize_cpu_vs_gpu(self):

quant_info_gpu = libquant_info.QcQuantizeInfo()
quant_node_gpu = helper.make_node(op_name, inputs=['input'], outputs=['output'],
domain="aimet.customop.cuda", quant_info=libpymo.PtrToInt64(quant_info_gpu))
domain=op_domain, quant_info=libpymo.PtrToInt64(quant_info_gpu))
model_gpu = create_model_from_node(quant_node_gpu, input_arr.shape)
session_gpu = build_session(model_gpu, available_providers)
qc_op_gpu = QcQuantizeOp(quant_info=quant_info_gpu,
Expand Down
5 changes: 3 additions & 2 deletions TrainingExtensions/onnx/test/python/test_quantsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import torch
import numpy as np
from onnx import load_model
import onnxruntime as ort
import pytest
from aimet_common.defs import QuantScheme
from aimet_onnx.quantsim import QuantizationSimModel
Expand Down Expand Up @@ -310,8 +311,8 @@ def onnx_callback(session, inputs):

for node in onnx_sim_gpu.model.graph().node:
if node.op_type == "QcQuantizeOp":
# Note: this check will fail if onnxruntime-gpu is not correctly installed
assert node.domain == "aimet.customop.cuda"
if 'CUDAExecutionProvider' in ort.get_available_providers():
assert node.domain == "aimet.customop.cuda"
for node in onnx_sim_cpu.model.graph().node:
if node.op_type == "QcQuantizeOp":
assert node.domain == "aimet.customop.cpu"
Expand Down
1 change: 1 addition & 0 deletions packaging/dependencies/reqs_pip_torch_common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ holoviews==1.12.7
matplotlib>=3
numpy<1.24,>=1.16.6
onnx==1.10.0
onnxruntime==1.10.0
onnxruntime-extensions
onnxsim
scikit-learn==1.1.3
Expand Down
1 change: 0 additions & 1 deletion packaging/dependencies/tf-torch-cpu/reqs_pip_torch_cpu.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
cumm==0.2.8
onnxruntime==1.10.0
spconv==2.1.20
torch==1.9.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
torchvision==0.10.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
1 change: 0 additions & 1 deletion packaging/dependencies/torch-cpu/reqs_pip_torch_cpu.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
cumm==0.2.8
onnxruntime==1.10.0
spconv==2.1.20
torch==1.9.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
torchvision==0.10.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
1 change: 0 additions & 1 deletion packaging/dependencies/torch-gpu/reqs_pip_torch_gpu.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
cumm-cu111==0.2.8
onnxruntime-gpu==1.10.0
spconv-cu111==2.1.20
torch==1.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html

0 comments on commit 0824ecb

Please sign in to comment.