Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Add onnx end-to-end tests for pose estimation and action recognition models. #19834

Merged
merged 8 commits into from
Feb 4, 2021
3 changes: 1 addition & 2 deletions ci/docker/install/ubuntu_onnx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,4 @@ echo "Installing libprotobuf-dev and protobuf-compiler ..."
apt-get update || true
apt-get install -y libprotobuf-dev protobuf-compiler

echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX, tabulate and onnxruntime..."
pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.6.0 gluonnlp gluoncv
pip3 install pytest pytest-cov pytest-xdist protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.6.0 gluonnlp gluoncv
3 changes: 1 addition & 2 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1278,13 +1278,12 @@ integrationtest_ubuntu_cpu_onnx() {
export PYTHONPATH=./python/
export MXNET_SUBGRAPH_VERBOSE=0
export DMLC_LOG_STACK_TRACE_DEPTH=10
#tests/python-pytest/onnx/backend_test.py
COV_ARG="--cov=./ --cov-report=xml --cov-append"
pytest $COV_ARG --verbose tests/python-pytest/onnx/mxnet_export_test.py
pytest $COV_ARG --verbose tests/python-pytest/onnx/test_models.py
pytest $COV_ARG --verbose tests/python-pytest/onnx/test_node.py
pytest $COV_ARG --verbose tests/python-pytest/onnx/test_operators.py
pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py
pytest -n 12 $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py
}

integrationtest_ubuntu_gpu_python() {
Expand Down
161 changes: 137 additions & 24 deletions tests/python-pytest/onnx/test_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ def download_test_images(image_urls, tmpdir):
])
def test_obj_class_model_inference_onnxruntime(tmp_path, model):
def normalize_image(imgfile):
img_data = mx.image.imread(imgfile).transpose([2, 0, 1]).astype('float32')
img_data = mx.image.imread(imgfile)
img_data = mx.image.imresize(img_data, 224, 224)
img_data = img_data.transpose([2, 0, 1]).astype('float32')
mean_vec = mx.nd.array([0.485, 0.456, 0.406])
stddev_vec = mx.nd.array([0.229, 0.224, 0.225])
norm_img_data = mx.nd.zeros(img_data.shape).astype('float32')
Expand All @@ -160,11 +162,16 @@ def normalize_image(imgfile):
input_name = session.get_inputs()[0].name

test_image_urls = [
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/apron.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dolphin.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/hammerheadshark.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/lotus.jpg'
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/bikers.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/car.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/dancer.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/duck.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/fieldhockey.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/flower.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/runners.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/shark.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/soccer2.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/tree.jpg',
]

for img in download_test_images(test_image_urls, tmp_path):
Expand Down Expand Up @@ -205,18 +212,18 @@ def normalize_image(imgfile):
input_name = session.get_inputs()[0].name

test_image_urls = [
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog2.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog3.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/car6.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/apron.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dolphin.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/hammerheadshark.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/lotus.jpg'
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/car.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/duck.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/fieldhockey.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/flower.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/runners.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/shark.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/soccer2.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/tree.jpg',
]

for img in download_test_images(test_image_urls, tmp_path):
img_data = normalize_image(os.path.join(tmp_path, img))
img_data = normalize_image(img)
mx_class_ids, mx_scores, mx_boxes = M.predict(img_data)
onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
assert_almost_equal(mx_class_ids, onnx_class_ids)
Expand Down Expand Up @@ -265,18 +272,20 @@ def normalize_image(imgfile):
input_name = session.get_inputs()[0].name

test_image_urls = [
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog2.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog3.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/car6.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/apron.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dolphin.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/hammerheadshark.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/lotus.jpg'
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/bikers.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/car.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/dancer.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/duck.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/fieldhockey.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/flower.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/runners.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/shark.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/soccer2.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/tree.jpg',
]

for img in download_test_images(test_image_urls, tmp_path):
img_data = normalize_image(os.path.join(tmp_path, img))
img_data = normalize_image(img)
mx_result = M.predict(img_data)
onnx_result = session.run([], {input_name: img_data.asnumpy()})
assert(len(mx_result) == len(onnx_result))
Expand All @@ -286,6 +295,110 @@ def normalize_image(imgfile):
finally:
shutil.rmtree(tmp_path)

@pytest.mark.parametrize('model', [
'simple_pose_resnet18_v1b',
'simple_pose_resnet50_v1b',
'simple_pose_resnet50_v1d',
'simple_pose_resnet101_v1b',
'simple_pose_resnet101_v1d',
'simple_pose_resnet152_v1b',
'simple_pose_resnet152_v1d',
#'mobile_pose_resnet18_v1b',
#'mobile_pose_resnet50_v1b',
#'mobile_pose_mobilenet1.0',
#'mobile_pose_mobilenetv2_1.0',
#'mobile_pose_mobilenetv3_large',
#'mobile_pose_mobilenetv3_small',
#'alpha_pose_resnet101_v1b_coco',
])
def test_pose_estimation_model_inference_onnxruntime(tmp_path, model):
def normalize_image(imgfile):
img = mx.image.imread(imgfile).astype('float32')
img, _ = mx.image.center_crop(img, size=(512, 512))
img = gluoncv.data.transforms.presets.segmentation.test_transform(img, mx.cpu(0))
return img

try:
tmp_path = str(tmp_path)
M = GluonModel(model, (1,3,512,512), 'float32', tmp_path)
onnx_file = M.export_onnx()
# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
input_name = session.get_inputs()[0].name

test_image_urls = [
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/bikers.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/dancer.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/fieldhockey.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/runners.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/soccer2.jpg',
]

for img in download_test_images(test_image_urls, tmp_path):
img_data = normalize_image(img)
mx_result = M.predict(img_data)
onnx_result = session.run([], {input_name: img_data.asnumpy()})
assert(len(mx_result) == len(onnx_result))
for i in range(len(mx_result)):
assert_almost_equal(mx_result[i], onnx_result[i])

finally:
shutil.rmtree(tmp_path)


@pytest.mark.parametrize('model', [
'inceptionv1_kinetics400',
'resnet18_v1b_kinetics400',
'resnet34_v1b_kinetics400',
'resnet50_v1b_hmdb51',
'resnet50_v1b_sthsthv2',
'vgg16_ucf101',
# the following models are failing due to an accuracy issue
#'resnet50_v1b_kinetics400',
#'resnet101_v1b_kinetics400',
#'resnet152_v1b_kinetics400',
#'inceptionv3_kinetics400',
#'inceptionv3_ucf101',
])
def test_action_recognition_model_inference_onnxruntime(tmp_path, model):
batch_size = 64
input_len = 224
if 'inceptionv3' in model:
input_len = 340

def load_video(filepath):
iterator = mx.image.ImageIter(batch_size=batch_size, data_shape=(3,input_len,input_len), path_imgrec=filepath)
for batch in iterator:
return batch.data[0]

try:
tmp_path = str(tmp_path)
M = GluonModel(model, (batch_size,3,input_len,input_len), 'float32', tmp_path)
onnx_file = M.export_onnx()
# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
input_name = session.get_inputs()[0].name

test_video_urls = [
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/actions/biking.rec',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/actions/diving.rec',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/actions/golfing.rec',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/actions/sledding.rec',
]

for video in download_test_images(test_video_urls, tmp_path):
data = load_video(video)
mx_result = M.predict(data)
onnx_result = session.run([], {input_name: data.asnumpy()})[0]
assert_almost_equal(mx_result, onnx_result)

finally:
shutil.rmtree(tmp_path)



@with_seed()
Expand Down