Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TEST] Cache test data #2921

Merged
merged 3 commits into from
Mar 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions nnvm/tests/python/frontend/coreml/model_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,24 @@
from six.moves import urllib
import os
from PIL import Image
import numpy as np

def download(url, path, overwrite=False):
if os.path.exists(path) and not overwrite:
return
print('Downloading {} to {}.'.format(url, path))
urllib.request.urlretrieve(url, path)
from tvm.contrib.download import download_testdata

def get_mobilenet():
url = 'https://docs-assets.developer.apple.com/coreml/models/MobileNet.mlmodel'
dst = 'mobilenet.mlmodel'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
return os.path.abspath(real_dst)
real_dst = download_testdata(url, dst, module='coreml')
return real_dst

def get_resnet50():
url = 'https://docs-assets.developer.apple.com/coreml/models/Resnet50.mlmodel'
dst = 'resnet50.mlmodel'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
return os.path.abspath(real_dst)
real_dst = download_testdata(url, dst, module='coreml')
return real_dst

def get_cat_image():
url = 'https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png'
dst = 'cat.png'
real_dst = os.path.abspath(os.path.join(os.path.dirname(__file__), dst))
download(url, real_dst)
real_dst = download_testdata(url, dst, module='coreml')
img = Image.open(real_dst).resize((224, 224))
img = np.transpose(img, (2, 0, 1))[np.newaxis, :]
return np.asarray(img)
62 changes: 15 additions & 47 deletions nnvm/tests/python/frontend/darknet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,44 +12,16 @@
import numpy as np
import tvm
from tvm.contrib import graph_runtime
from tvm.contrib.download import download_testdata
from nnvm import frontend
from nnvm.testing.darknet import LAYERTYPE
from nnvm.testing.darknet import __darknetffi__
import nnvm.compiler
if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2


def _download(url, path, overwrite=False, sizecompare=False):
''' Download from internet'''
if os.path.isfile(path) and not overwrite:
if sizecompare:
file_size = os.path.getsize(path)
res_head = requests.head(url)
res_get = requests.get(url, stream=True)
if 'Content-Length' not in res_head.headers:
res_get = urllib2.urlopen(url)
urlfile_size = int(res_get.headers['Content-Length'])
if urlfile_size != file_size:
print("exist file got corrupted, downloading", path, " file freshly")
_download(url, path, True, False)
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path)
print('')
except:
urllib.urlretrieve(url, path)

DARKNET_LIB = 'libdarknet2.0.so'
DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \
+ DARKNET_LIB + '?raw=true'
_download(DARKNETLIB_URL, DARKNET_LIB)
LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
LIB = __darknetffi__.dlopen(download_testdata(DARKNETLIB_URL, DARKNET_LIB, module='darknet'))

def _read_memory_buffer(shape, data, dtype='float32'):
length = 1
Expand Down Expand Up @@ -82,6 +54,12 @@ def _get_tvm_output(net, data, build_dtype='float32'):
tvm_out.append(m.get_output(i).asnumpy())
return tvm_out

def _load_net(cfg_url, cfg_name, weights_url, weights_name):
    '''Fetch a darknet cfg/weights pair via download_testdata and load the network.

    Both files are requested with module='darknet'; the paths returned by
    download_testdata are utf-8 encoded before being handed to
    LIB.load_network. Returns whatever LIB.load_network returns.
    '''
    local_cfg = download_testdata(cfg_url, cfg_name, module='darknet')
    local_weights = download_testdata(weights_url, weights_name, module='darknet')
    return LIB.load_network(local_cfg.encode('utf-8'),
                            local_weights.encode('utf-8'),
                            0)

def test_forward(net, build_dtype='float32'):
'''Test network with given input image on both darknet and tvm'''
def get_darknet_output(net, img):
Expand Down Expand Up @@ -125,8 +103,8 @@ def get_darknet_output(net, img):

test_image = 'dog.jpg'
img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + test_image +'?raw=true'
_download(img_url, test_image)
img = LIB.letterbox_image(LIB.load_image_color(test_image.encode('utf-8'), 0, 0), net.w, net.h)
img_path = download_testdata(img_url, test_image, module='darknet')
img = LIB.letterbox_image(LIB.load_image_color(img_path.encode('utf-8'), 0, 0), net.w, net.h)
darknet_output = get_darknet_output(net, img)
batch_size = 1
data = np.empty([batch_size, img.c, img.h, img.w], dtype)
Expand Down Expand Up @@ -167,9 +145,7 @@ def test_forward_extraction():
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
test_forward(net)
LIB.free_network(net)

Expand All @@ -180,9 +156,7 @@ def test_forward_alexnet():
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
test_forward(net)
LIB.free_network(net)

Expand All @@ -193,9 +167,7 @@ def test_forward_resnet50():
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
test_forward(net)
LIB.free_network(net)

Expand All @@ -206,9 +178,7 @@ def test_forward_yolov2():
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
build_dtype = {}
test_forward(net, build_dtype)
LIB.free_network(net)
Expand All @@ -220,9 +190,7 @@ def test_forward_yolov3():
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
build_dtype = {}
test_forward(net, build_dtype)
LIB.free_network(net)
Expand Down
23 changes: 4 additions & 19 deletions nnvm/tests/python/frontend/onnx/model_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,7 @@
import os
import logging
from .super_resolution import get_super_resolution

def _download(url, filename, overwrite=False):
if os.path.isfile(filename) and not overwrite:
logging.debug('File %s existed, skip.', filename)
return
logging.debug('Downloading from url %s to %s', url, filename)
try:
import urllib.request
urllib.request.urlretrieve(url, filename)
except:
import urllib
urllib.urlretrieve(url, filename)

def _as_abs_path(fname):
cur_dir = os.path.abspath(os.path.dirname(__file__))
return os.path.join(cur_dir, fname)
from tvm.contrib.download import download_testdata


URLS = {
Expand All @@ -30,9 +15,9 @@ def _as_abs_path(fname):
# download and add paths
for k, v in URLS.items():
name = k.split('.')[0]
path = _as_abs_path(k)
_download(v, path, False)
locals()[name] = path
relpath = os.path.join('onnx', k)
abspath = download_testdata(v, relpath, module='onnx')
locals()[name] = abspath

# symbol for graph comparison
super_resolution_sym = get_super_resolution()
74 changes: 66 additions & 8 deletions python/tvm/contrib/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import os
import sys
import time
import uuid
import shutil

def download(url, path, overwrite=False, size_compare=False, verbose=1):
def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison

Expand Down Expand Up @@ -53,6 +55,11 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1):

# Stateful start time
start_time = time.time()
dirpath = os.path.dirname(path)
if not os.path.isdir(dirpath):
os.makedirs(dirpath)
random_uuid = str(uuid.uuid4())
tempfile = os.path.join(dirpath, random_uuid)

def _download_progress(count, block_size, total_size):
#pylint: disable=unused-argument
Expand All @@ -68,11 +75,62 @@ def _download_progress(count, block_size, total_size):
(percent, progress_size / (1024.0 * 1024), speed, duration))
sys.stdout.flush()

if sys.version_info >= (3,):
urllib2.urlretrieve(url, path, reporthook=_download_progress)
print("")
while retries >= 0:
# Disable pylint's too-broad-exception check
# pylint: disable=W0703
try:
if sys.version_info >= (3,):
urllib2.urlretrieve(url, tempfile, reporthook=_download_progress)
print("")
else:
f = urllib2.urlopen(url)
data = f.read()
with open(tempfile, "wb") as code:
code.write(data)
shutil.move(tempfile, path)
break
except Exception as err:
retries -= 1
if retries == 0:
os.remove(tempfile)
raise err
else:
print("download failed due to {}, retrying, {} attempt{} left"
.format(repr(err), retries, 's' if retries > 1 else ''))


# Root directory under the user's home where download_testdata() stores files.
TEST_DATA_ROOT_PATH = os.path.join(os.path.expanduser('~'), '.tvm_test_data')
try:
    os.mkdir(TEST_DATA_ROOT_PATH)
except OSError:
    # Attempt the mkdir first and inspect afterwards: a separate
    # "exists? then mkdir" check can lose a race against another process
    # importing this module at the same time. If the path is there now the
    # failure is harmless; any other failure (e.g. permissions) is re-raised.
    if not os.path.exists(TEST_DATA_ROOT_PATH):
        raise

def download_testdata(url, relpath, module=None):
    """Downloads the test data from the internet.

    The target path is built as TEST_DATA_ROOT_PATH/<module path>/<relpath>;
    the transfer itself is delegated to ``download`` with ``overwrite=False``
    and ``size_compare=True``.

    Parameters
    ----------
    url : str
        Download url.

    relpath : str
        Relative file path.

    module : Union[str, list, tuple], optional
        Subdirectory paths under test data folder. A list/tuple is joined
        into nested directories; None or an empty sequence adds no
        subdirectory.

    Returns
    -------
    abspath : str
        Absolute file path of downloaded file

    Raises
    ------
    ValueError
        If ``module`` is not None, a str, a list or a tuple.
    """
    if module is None:
        module_path = ''
    elif isinstance(module, str):
        module_path = module
    elif isinstance(module, (list, tuple)):
        # os.path.join() needs at least one component, so an empty sequence
        # is treated like "no subdirectory".
        module_path = os.path.join(*module) if module else ''
    else:
        # Format with repr: concatenating a non-str ``module`` onto the
        # message would raise TypeError instead of the intended ValueError.
        raise ValueError("Unsupported module: {!r}".format(module))
    abspath = os.path.join(TEST_DATA_ROOT_PATH, module_path, relpath)
    download(url, abspath, overwrite=False, size_compare=True)
    return abspath
49 changes: 18 additions & 31 deletions python/tvm/relay/testing/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import tensorflow as tf
from tensorflow.core.framework import graph_pb2

from tvm.contrib import util
from tvm.contrib.download import download_testdata

######################################################################
# Some helper functions
Expand Down Expand Up @@ -136,7 +136,7 @@ def id_to_string(self, node_id):
return ''
return self.node_lookup[node_id]

def get_workload_official(model_url, model_sub_path, temp_dir):
def get_workload_official(model_url, model_sub_path):
""" Import workload from tensorflow official

Parameters
Expand All @@ -158,21 +158,17 @@ def get_workload_official(model_url, model_sub_path, temp_dir):
"""

model_tar_name = os.path.basename(model_url)

from mxnet.gluon.utils import download
temp_path = temp_dir.relpath("./")
path_model = temp_path + model_tar_name

download(model_url, path_model)
model_path = download_testdata(model_url, model_tar_name, module=['tf', 'official'])
dir_path = os.path.dirname(model_path)

import tarfile
if path_model.endswith("tgz") or path_model.endswith("gz"):
tar = tarfile.open(path_model)
tar.extractall(path=temp_path)
if model_path.endswith("tgz") or model_path.endswith("gz"):
tar = tarfile.open(model_path)
tar.extractall(path=dir_path)
tar.close()
else:
raise RuntimeError('Could not decompress the file: ' + path_model)
return temp_path + model_sub_path
raise RuntimeError('Could not decompress the file: ' + model_path)
return os.path.join(dir_path, model_sub_path)

def get_workload(model_path, model_sub_path=None):
""" Import workload from frozen protobuf
Expand All @@ -192,24 +188,18 @@ def get_workload(model_path, model_sub_path=None):

"""

temp = util.tempdir()
if model_sub_path:
path_model = get_workload_official(model_path, model_sub_path, temp)
path_model = get_workload_official(model_path, model_sub_path)
else:
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
model_name = os.path.basename(model_path)
model_url = os.path.join(repo_base, model_path)

from mxnet.gluon.utils import download
path_model = temp.relpath(model_name)
download(model_url, path_model)
path_model = download_testdata(model_url, model_path, module='tf')

# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(path_model, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
temp.remove()
return graph_def

#######################################################################
Expand Down Expand Up @@ -292,7 +282,7 @@ def _get_feed_dict(input_name, input_data):

def _create_ptb_vocabulary(data_dir):
"""Read the PTB sample data input to create vocabulary"""
data_path = data_dir+'simple-examples/data/'
data_path = os.path.join(data_dir, 'simple-examples/data/')
file_name = 'ptb.train.txt'
def _read_words(filename):
"""Read the data for creating vocabulary"""
Expand Down Expand Up @@ -341,13 +331,10 @@ def get_workload_ptb():
ptb_model_file = 'RNN/ptb/ptb_model_with_lstmblockcell.pb'

import tarfile
from tvm.contrib.download import download
DATA_DIR = './ptb_data/'
if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR)
download(sample_url, DATA_DIR+sample_data_file)
t = tarfile.open(DATA_DIR+sample_data_file, 'r')
t.extractall(DATA_DIR)

word_to_id, id_to_word = _create_ptb_vocabulary(DATA_DIR)
file_path = download_testdata(sample_url, sample_data_file, module=['tf', 'ptb_data'])
dir_path = os.path.dirname(file_path)
t = tarfile.open(file_path, 'r')
t.extractall(dir_path)

word_to_id, id_to_word = _create_ptb_vocabulary(dir_path)
return word_to_id, id_to_word, get_workload(ptb_model_file)
Loading