Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Add image segmentation end-to-end tests and expand object classification tests #19815

Merged
merged 4 commits into from
Feb 2, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 151 additions & 24 deletions tests/python-pytest/onnx/test_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, model_name, input_shape, input_dtype, tmpdir):
self.export()

def get_model(self):
self.model = mx.gluon.model_zoo.vision.get_model(self.model_name, pretrained=True, ctx=self.ctx, root=self.modelpath)
self.model = gluoncv.model_zoo.get_model(self.model_name, pretrained=True, ctx=self.ctx)
self.model.hybridize()

def export(self):
Expand Down Expand Up @@ -69,13 +69,74 @@ def download_test_images(image_urls, tmpdir):
return paths

@pytest.mark.parametrize('model', [
'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25',
'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25',
'resnet18_v1', 'resnet18_v2', 'resnet34_v1', 'resnet34_v2', 'resnet50_v1', 'resnet50_v2',
'resnet101_v1', 'resnet101_v2', 'resnet152_v1', 'resnet152_v2',
'squeezenet1.0', 'squeezenet1.1',
'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'
'alexnet',
'cifar_resnet20_v1',
'cifar_resnet56_v1',
'cifar_resnet110_v1',
'cifar_resnet20_v2',
'cifar_resnet56_v2',
'cifar_resnet110_v2',
'cifar_wideresnet16_10',
'cifar_wideresnet28_10',
'cifar_wideresnet40_8',
'cifar_resnext29_16x64d',
'darknet53',
'densenet121',
'densenet161',
'densenet169',
'densenet201',
'googlenet',
'mobilenet1.0',
'mobilenet0.75',
'mobilenet0.5',
'mobilenet0.25',
'mobilenetv2_1.0',
'mobilenetv2_0.75',
'mobilenetv2_0.5',
'mobilenetv2_0.25',
'mobilenetv3_large',
'mobilenetv3_small',
# failing due to accuracy
#'resnest14',
#'resnest26',
#'resnest50',
#'resnest101',
#'resnest200',
#'resnest269',
'resnet18_v1',
'resnet18_v1b_0.89',
'resnet18_v2',
'resnet34_v1',
'resnet34_v2',
'resnet50_v1',
'resnet50_v1d_0.86',
'resnet50_v1d_0.48',
'resnet50_v1d_0.37',
'resnet50_v1d_0.11',
'resnet50_v2',
'resnet101_v1',
'resnet101_v1d_0.76',
'resnet101_v1d_0.73',
'resnet101_v2',
'resnet152_v1',
'resnet152_v2',
'resnext50_32x4d',
'resnext101_32x4d',
'resnext101_64x4d',
'senet_154',
'se_resnext101_32x4d',
'se_resnext101_64x4d',
'se_resnext50_32x4d',
'squeezenet1.0',
'squeezenet1.1',
'vgg11',
'vgg11_bn',
'vgg13',
'vgg13_bn',
'vgg16',
'vgg16_bn',
'vgg19',
'vgg19_bn'
])
def test_obj_class_model_inference_onnxruntime(tmp_path, model):
def normalize_image(imgfile):
Expand All @@ -99,11 +160,11 @@ def normalize_image(imgfile):
input_name = session.get_inputs()[0].name

test_image_urls = [
'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/dog.jpg',
'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/apron.jpg',
'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/dolphin.jpg',
'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/hammerheadshark.jpg',
'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/lotus.jpg'
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/apron.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dolphin.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/hammerheadshark.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/lotus.jpg'
]

for img in download_test_images(test_image_urls, tmp_path):
Expand All @@ -117,37 +178,42 @@ def normalize_image(imgfile):



class GluonCVModel(GluonModel):
def __init__(self, *args, **kwargs):
super(GluonCVModel, self).__init__(*args, **kwargs)
def get_model(self):
self.model = gluoncv.model_zoo.get_model(self.model_name, pretrained=True, ctx=self.ctx)
self.model.hybridize()

@pytest.mark.parametrize('model', [
'center_net_resnet18_v1b_voc',
'center_net_resnet50_v1b_voc',
'center_net_resnet101_v1b_voc',
'center_net_resnet18_v1b_coco',
'center_net_resnet50_v1b_coco',
#'center_net_resnet50_v1b_coco',
'center_net_resnet101_v1b_coco'
])
def test_obj_detection_model_inference_onnxruntime(tmp_path, model):
def normalize_image(imgfile):
x, _ = gluoncv.data.transforms.presets.center_net.load_test(imgfile, short=512)
return x
img = mx.image.imread(imgfile)
img, _ = mx.image.center_crop(img, size=(512, 512))
img, _ = gluoncv.data.transforms.presets.center_net.transform_test(img, short=512)
return img

try:
tmp_path = str(tmp_path)
M = GluonCVModel(model, (1,3,512,683), 'float32', tmp_path)
M = GluonModel(model, (1,3,512,512), 'float32', tmp_path)
onnx_file = M.export_onnx()
# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
input_name = session.get_inputs()[0].name

test_image_urls = ['https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/dog.jpg']
test_image_urls = [
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog2.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog3.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/car6.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/apron.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dolphin.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/hammerheadshark.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/lotus.jpg'
]

for img in download_test_images(test_image_urls, tmp_path):
img_data = normalize_image(os.path.join(tmp_path, img))
Expand All @@ -161,6 +227,67 @@ def normalize_image(imgfile):
shutil.rmtree(tmp_path)


@pytest.mark.parametrize('model', [
'fcn_resnet50_ade',
'fcn_resnet101_ade',
'deeplab_resnet50_ade',
'deeplab_resnet101_ade',
# the 4 models below are failing due to an accuracy issue
#'deeplab_resnest50_ade',
#'deeplab_resnest101_ade',
#'deeplab_resnest200_ade',
#'deeplab_resnest269_ade',
'fcn_resnet101_coco',
'deeplab_resnet101_coco',
'fcn_resnet101_voc',
'deeplab_resnet101_voc',
'deeplab_resnet152_voc',
'deeplab_resnet50_citys',
'deeplab_resnet101_citys',
'deeplab_v3b_plus_wideresnet_citys'
])
def test_img_segmentation_model_inference_onnxruntime(tmp_path, model):
def normalize_image(imgfile):
img = mx.image.imread(imgfile).astype('float32')
img, _ = mx.image.center_crop(img, size=(480, 480))
img = gluoncv.data.transforms.presets.segmentation.test_transform(img, mx.cpu(0))
return img


try:
tmp_path = str(tmp_path)
M = GluonModel(model, (1,3,480,480), 'float32', tmp_path)
onnx_file = M.export_onnx()
# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
input_name = session.get_inputs()[0].name

test_image_urls = [
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog2.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog3.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/car6.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dog.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/apron.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/dolphin.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/hammerheadshark.jpg',
'https://raw.githubusercontent.com/apache/incubator-mxnet-ci/master/test-data/images/lotus.jpg'
]

for img in download_test_images(test_image_urls, tmp_path):
img_data = normalize_image(os.path.join(tmp_path, img))
mx_result = M.predict(img_data)
onnx_result = session.run([], {input_name: img_data.asnumpy()})
assert(len(mx_result) == len(onnx_result))
for i in range(len(mx_result)):
assert_almost_equal(mx_result[i], onnx_result[i])

finally:
shutil.rmtree(tmp_path)



@with_seed()
@pytest.mark.parametrize('model', ['bert_12_768_12'])
def test_bert_inference_onnxruntime(tmp_path, model):
Expand Down