[TEST] Cache test data (apache#2921)
icemelon authored and wweic committed Mar 29, 2019
1 parent 3b917cf commit 52f8b4e
Showing 7 changed files with 115 additions and 139 deletions.
21 changes: 6 additions & 15 deletions nnvm/tests/python/frontend/coreml/model_zoo/__init__.py
@@ -1,33 +1,24 @@
-from six.moves import urllib
 import os
 from PIL import Image
 import numpy as np
 
-def download(url, path, overwrite=False):
-    if os.path.exists(path) and not overwrite:
-        return
-    print('Downloading {} to {}.'.format(url, path))
-    urllib.request.urlretrieve(url, path)
+from tvm.contrib.download import download_testdata
 
 def get_mobilenet():
     url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
     dst = 'mobilenet.mlmodel'
-    real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
-    download(url, real_dst)
-    return os.path.abspath(real_dst)
+    real_dst = download_testdata(url, dst, module='coreml')
+    return real_dst
 
 def get_resnet50():
     url = 'https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel'
     dst = 'resnet50.mlmodel'
-    real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
-    download(url, real_dst)
-    return os.path.abspath(real_dst)
+    real_dst = download_testdata(url, dst, module='coreml')
+    return real_dst
 
 def get_cat_image():
     url = 'https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png'
     dst = 'cat.png'
-    real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
-    download(url, real_dst)
+    real_dst = download_testdata(url, dst, module='coreml')
     img = Image.open(real_dst).resize((224, 224))
     img = np.transpose(img, (2, 0, 1))[np.newaxis, :]
    return np.asarray(img)
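A minimal sketch of the caching flow these changes rely on (using the MobileNet URL from the hunk above): download_testdata stores the file under ~/.tvm_test_data/<module>/ on the first call and returns the cached path on later runs.

from tvm.contrib.download import download_testdata

# First call downloads to ~/.tvm_test_data/coreml/mobilenet.mlmodel;
# subsequent calls return the cached path without re-downloading.
url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
path = download_testdata(url, 'mobilenet.mlmodel', module='coreml')
print(path)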
62 changes: 15 additions & 47 deletions nnvm/tests/python/frontend/darknet/test_forward.py
@@ -12,44 +12,16 @@
 import numpy as np
 import tvm
 from tvm.contrib import graph_runtime
+from tvm.contrib.download import download_testdata
 from nnvm import frontend
 from nnvm.testing.darknet import LAYERTYPE
 from nnvm.testing.darknet import __darknetffi__
 import nnvm.compiler
-if sys.version_info >= (3,):
-    import urllib.request as urllib2
-else:
-    import urllib2
-
-
-def _download(url, path, overwrite=False, sizecompare=False):
-    ''' Download from internet'''
-    if os.path.isfile(path) and not overwrite:
-        if sizecompare:
-            file_size = os.path.getsize(path)
-            res_head = requests.head(url)
-            res_get = requests.get(url, stream=True)
-            if 'Content-Length' not in res_head.headers:
-                res_get = urllib2.urlopen(url)
-            urlfile_size = int(res_get.headers['Content-Length'])
-            if urlfile_size != file_size:
-                print("exist file got corrupted, downloading", path, " file freshly")
-                _download(url, path, True, False)
-                return
-        print('File {} exists, skip.'.format(path))
-        return
-    print('Downloading from url {} to {}'.format(url, path))
-    try:
-        urllib.request.urlretrieve(url, path)
-        print('')
-    except:
-        urllib.urlretrieve(url, path)
 
 DARKNET_LIB = 'libdarknet2.0.so'
 DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \
                + DARKNET_LIB + '?raw=true'
-_download(DARKNETLIB_URL, DARKNET_LIB)
-LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
+LIB = __darknetffi__.dlopen(download_testdata(DARKNETLIB_URL, DARKNET_LIB, module='darknet'))
 
 def _read_memory_buffer(shape, data, dtype='float32'):
     length = 1
@@ -82,6 +54,12 @@ def _get_tvm_output(net, data, build_dtype='float32'):
         tvm_out.append(m.get_output(i).asnumpy())
     return tvm_out
 
+def _load_net(cfg_url, cfg_name, weights_url, weights_name):
+    cfg_path = download_testdata(cfg_url, cfg_name, module='darknet')
+    weights_path = download_testdata(weights_url, weights_name, module='darknet')
+    net = LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0)
+    return net
+
 def test_forward(net, build_dtype='float32'):
     '''Test network with given input image on both darknet and tvm'''
     def get_darknet_output(net, img):
@@ -125,8 +103,8 @@ def get_darknet_output(net, img):
 
     test_image = 'dog.jpg'
     img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + test_image + '?raw=true'
-    _download(img_url, test_image)
-    img = LIB.letterbox_image(LIB.load_image_color(test_image.encode('utf-8'), 0, 0), net.w, net.h)
+    img_path = download_testdata(img_url, test_image, module='darknet')
+    img = LIB.letterbox_image(LIB.load_image_color(img_path.encode('utf-8'), 0, 0), net.w, net.h)
     darknet_output = get_darknet_output(net, img)
     batch_size = 1
     data = np.empty([batch_size, img.c, img.h, img.w], dtype)
@@ -167,9 +145,7 @@ def test_forward_extraction():
     weights_name = model_name + '.weights'
     cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
     weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
-    _download(cfg_url, cfg_name)
-    _download(weights_url, weights_name)
-    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     test_forward(net)
     LIB.free_network(net)

@@ -180,9 +156,7 @@ def test_forward_alexnet():
     weights_name = model_name + '.weights'
     cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
     weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
-    _download(cfg_url, cfg_name)
-    _download(weights_url, weights_name)
-    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     test_forward(net)
     LIB.free_network(net)

@@ -193,9 +167,7 @@ def test_forward_resnet50():
     weights_name = model_name + '.weights'
     cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
     weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
-    _download(cfg_url, cfg_name)
-    _download(weights_url, weights_name)
-    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     test_forward(net)
     LIB.free_network(net)

@@ -206,9 +178,7 @@ def test_forward_yolov2():
     weights_name = model_name + '.weights'
     cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
     weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
-    _download(cfg_url, cfg_name)
-    _download(weights_url, weights_name)
-    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     build_dtype = {}
     test_forward(net, build_dtype)
     LIB.free_network(net)
@@ -220,9 +190,7 @@ def test_forward_yolov3():
     weights_name = model_name + '.weights'
     cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
     weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
-    _download(cfg_url, cfg_name)
-    _download(weights_url, weights_name)
-    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
     build_dtype = {}
     test_forward(net, build_dtype)
     LIB.free_network(net)
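For orientation, a sketch of what each darknet test now reduces to with the shared _load_net helper (yolov3 names are taken from the hunk above; LIB is the darknet FFI handle opened at the top of the file):

from tvm.contrib.download import download_testdata

cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/yolov3.cfg?raw=true'
weights_url = 'http://pjreddie.com/media/files/yolov3.weights?raw=true'
# Both files are cached under ~/.tvm_test_data/darknet/ and reused across runs.
cfg_path = download_testdata(cfg_url, 'yolov3.cfg', module='darknet')
weights_path = download_testdata(weights_url, 'yolov3.weights', module='darknet')
# net = LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0)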
23 changes: 4 additions & 19 deletions nnvm/tests/python/frontend/onnx/model_zoo/__init__.py
@@ -3,22 +3,7 @@
 import os
 import logging
 from .super_resolution import get_super_resolution
-
-def _download(url, filename, overwrite=False):
-    if os.path.isfile(filename) and not overwrite:
-        logging.debug('File %s existed, skip.', filename)
-        return
-    logging.debug('Downloading from url %s to %s', url, filename)
-    try:
-        import urllib.request
-        urllib.request.urlretrieve(url, filename)
-    except:
-        import urllib
-        urllib.urlretrieve(url, filename)
-
-def _as_abs_path(fname):
-    cur_dir = os.path.abspath(os.path.dirname(__file__))
-    return os.path.join(cur_dir, fname)
+from tvm.contrib.download import download_testdata
 
 
 URLS = {
@@ -30,9 +15,9 @@ def _as_abs_path(fname):
 # download and add paths
 for k, v in URLS.items():
     name = k.split('.')[0]
-    path = _as_abs_path(k)
-    _download(v, path, False)
-    locals()[name] = path
+    relpath = os.path.join('onnx', k)
+    abspath = download_testdata(v, relpath, module='onnx')
+    locals()[name] = abspath
 
 # symbol for graph comparison
 super_resolution_sym = get_super_resolution()
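The locals() line above binds each cached model path to a module-level name. A hypothetical sketch of the same pattern in isolation (the key 'my_model.onnx' and its URL are invented for illustration; the real entries live in the collapsed URLS dict):

import os
from tvm.contrib.download import download_testdata

URLS = {'my_model.onnx': 'https://example.com/my_model.onnx'}  # hypothetical entry

for k, v in URLS.items():
    name = k.split('.')[0]                 # -> 'my_model'
    relpath = os.path.join('onnx', k)      # cached as onnx/my_model.onnx under the module dir
    abspath = download_testdata(v, relpath, module='onnx')
    # At module scope locals() is the module namespace, so importers can
    # then reference model_zoo.my_model; inside a function this would not work.
    locals()[name] = abspath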
74 changes: 66 additions & 8 deletions python/tvm/contrib/download.py
@@ -5,8 +5,10 @@
 import os
 import sys
 import time
+import uuid
+import shutil
 
-def download(url, path, overwrite=False, size_compare=False, verbose=1):
+def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3):
     """Downloads the file from the internet.
     Set the input options correctly to overwrite or do the size comparison
 
@@ -53,6 +55,11 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1):
 
     # Stateful start time
     start_time = time.time()
+    dirpath = os.path.dirname(path)
+    if not os.path.isdir(dirpath):
+        os.makedirs(dirpath)
+    random_uuid = str(uuid.uuid4())
+    tempfile = os.path.join(dirpath, random_uuid)
 
     def _download_progress(count, block_size, total_size):
         #pylint: disable=unused-argument
@@ -68,11 +75,62 @@ def _download_progress(count, block_size, total_size):
                          (percent, progress_size / (1024.0 * 1024), speed, duration))
         sys.stdout.flush()
 
-    if sys.version_info >= (3,):
-        urllib2.urlretrieve(url, path, reporthook=_download_progress)
-        print("")
-    else:
-        f = urllib2.urlopen(url)
-        data = f.read()
-        with open(path, "wb") as code:
-            code.write(data)
+    while retries >= 0:
+        # Disable pyling too broad Exception
+        # pylint: disable=W0703
+        try:
+            if sys.version_info >= (3,):
+                urllib2.urlretrieve(url, tempfile, reporthook=_download_progress)
+                print("")
+            else:
+                f = urllib2.urlopen(url)
+                data = f.read()
+                with open(tempfile, "wb") as code:
+                    code.write(data)
+            shutil.move(tempfile, path)
+            break
+        except Exception as err:
+            retries -= 1
+            if retries == 0:
+                os.remove(tempfile)
+                raise err
+            else:
+                print("download failed due to {}, retrying, {} attempt{} left"
+                      .format(repr(err), retries, 's' if retries > 1 else ''))
+
+
+TEST_DATA_ROOT_PATH = os.path.join(os.path.expanduser('~'), '.tvm_test_data')
+if not os.path.exists(TEST_DATA_ROOT_PATH):
+    os.mkdir(TEST_DATA_ROOT_PATH)
+
+def download_testdata(url, relpath, module=None):
+    """Downloads the test data from the internet.
+
+    Parameters
+    ----------
+    url : str
+        Download url.
+
+    relpath : str
+        Relative file path.
+
+    module : Union[str, list, tuple], optional
+        Subdirectory paths under test data folder.
+
+    Returns
+    -------
+    abspath : str
+        Absolute file path of downloaded file
+    """
+    global TEST_DATA_ROOT_PATH
+    if module is None:
+        module_path = ''
+    elif isinstance(module, str):
+        module_path = module
+    elif isinstance(module, (list, tuple)):
+        module_path = os.path.join(*module)
+    else:
+        raise ValueError("Unsupported module: " + module)
+    abspath = os.path.join(TEST_DATA_ROOT_PATH, module_path, relpath)
+    download(url, abspath, overwrite=False, size_compare=True)
+    return abspath
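The core robustness change in download() is the temp-file-then-rename dance plus retries: a failed transfer can never leave a truncated file at the final path. A condensed, Python-3-only sketch of that pattern (the function name _fetch is ours, not TVM's):

import os
import shutil
import uuid
import urllib.request

def _fetch(url, path, retries=3):
    """Download to a uniquely named temp file, then move it into place."""
    dirpath = os.path.dirname(path) or '.'
    os.makedirs(dirpath, exist_ok=True)
    tmp = os.path.join(dirpath, str(uuid.uuid4()))
    while True:
        try:
            urllib.request.urlretrieve(url, tmp)
            shutil.move(tmp, path)  # atomic rename when tmp and path share a filesystem
            return
        except Exception as err:
            retries -= 1
            if retries <= 0:
                if os.path.exists(tmp):
                    os.remove(tmp)  # clean up the partial temp file
                raise
            print('download failed ({!r}), {} retries left'.format(err, retries))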
49 changes: 18 additions & 31 deletions python/tvm/relay/testing/tf.py
@@ -13,7 +13,7 @@
 import tensorflow as tf
 from tensorflow.core.framework import graph_pb2
 
-from tvm.contrib import util
+from tvm.contrib.download import download_testdata
 
 ######################################################################
 # Some helper functions
@@ -136,7 +136,7 @@ def id_to_string(self, node_id):
             return ''
         return self.node_lookup[node_id]
 
-def get_workload_official(model_url, model_sub_path, temp_dir):
+def get_workload_official(model_url, model_sub_path):
     """ Import workload from tensorflow official
 
     Parameters
@@ -158,21 +158,17 @@ def get_workload_official(model_url, model_sub_path, temp_dir):
     """
 
     model_tar_name = os.path.basename(model_url)
-
-    from mxnet.gluon.utils import download
-    temp_path = temp_dir.relpath("./")
-    path_model = temp_path + model_tar_name
-
-    download(model_url, path_model)
+    model_path = download_testdata(model_url, model_tar_name, module=['tf', 'official'])
+    dir_path = os.path.dirname(model_path)
 
     import tarfile
-    if path_model.endswith("tgz") or path_model.endswith("gz"):
-        tar = tarfile.open(path_model)
-        tar.extractall(path=temp_path)
+    if model_path.endswith("tgz") or model_path.endswith("gz"):
+        tar = tarfile.open(model_path)
+        tar.extractall(path=dir_path)
         tar.close()
     else:
-        raise RuntimeError('Could not decompress the file: ' + path_model)
-    return temp_path + model_sub_path
+        raise RuntimeError('Could not decompress the file: ' + model_path)
+    return os.path.join(dir_path, model_sub_path)
 
 def get_workload(model_path, model_sub_path=None):
     """ Import workload from frozen protobuf
@@ -192,24 +188,18 @@ def get_workload(model_path, model_sub_path=None):
     """
 
-    temp = util.tempdir()
     if model_sub_path:
-        path_model = get_workload_official(model_path, model_sub_path, temp)
+        path_model = get_workload_official(model_path, model_sub_path)
     else:
         repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
-        model_name = os.path.basename(model_path)
         model_url = os.path.join(repo_base, model_path)
-
-        from mxnet.gluon.utils import download
-        path_model = temp.relpath(model_name)
-        download(model_url, path_model)
+        path_model = download_testdata(model_url, model_path, module='tf')
 
     # Creates graph from saved graph_def.pb.
     with tf.gfile.FastGFile(path_model, 'rb') as f:
         graph_def = tf.GraphDef()
         graph_def.ParseFromString(f.read())
         graph = tf.import_graph_def(graph_def, name='')
-        temp.remove()
         return graph_def
 
 #######################################################################
@@ -292,7 +282,7 @@ def _get_feed_dict(input_name, input_data):
 
 def _create_ptb_vocabulary(data_dir):
     """Read the PTB sample data input to create vocabulary"""
-    data_path = data_dir+'simple-examples/data/'
+    data_path = os.path.join(data_dir, 'simple-examples/data/')
     file_name = 'ptb.train.txt'
     def _read_words(filename):
         """Read the data for creating vocabulary"""
@@ -341,13 +331,10 @@ def get_workload_ptb():
     ptb_model_file = 'RNN/ptb/ptb_model_with_lstmblockcell.pb'
 
     import tarfile
-    from tvm.contrib.download import download
-    DATA_DIR = './ptb_data/'
-    if not os.path.exists(DATA_DIR):
-        os.mkdir(DATA_DIR)
-    download(sample_url, DATA_DIR+sample_data_file)
-    t = tarfile.open(DATA_DIR+sample_data_file, 'r')
-    t.extractall(DATA_DIR)
-
-    word_to_id, id_to_word = _create_ptb_vocabulary(DATA_DIR)
+    file_path = download_testdata(sample_url, sample_data_file, module=['tf', 'ptb_data'])
+    dir_path = os.path.dirname(file_path)
+    t = tarfile.open(file_path, 'r')
+    t.extractall(dir_path)
+
+    word_to_id, id_to_word = _create_ptb_vocabulary(dir_path)
     return word_to_id, id_to_word, get_workload(ptb_model_file)
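A sketch of the cached-archive pattern get_workload_official and get_workload_ptb now share: download the tarball once into the cache, then extract it next to the archive so later runs skip both steps. The archive URL and member name below are hypothetical; the real ones come from the callers.

import os
import tarfile
from tvm.contrib.download import download_testdata

archive_url = 'https://example.com/models/sample_model.tgz'  # hypothetical
archive = download_testdata(archive_url, 'sample_model.tgz', module=['tf', 'official'])
dir_path = os.path.dirname(archive)
with tarfile.open(archive) as tar:
    tar.extractall(path=dir_path)  # extracted files sit beside the cached tarball
model_path = os.path.join(dir_path, 'sample_model.pb')  # hypothetical member name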