Skip to content

Commit

Permalink
[Relay][TopHub] Add switch to disable TopHub download (apache#4015)
Browse files Browse the repository at this point in the history
  • Loading branch information
soiferj authored and wweic committed Oct 18, 2019
1 parent 4b28d66 commit 5cc891f
Showing 1 changed file with 29 additions and 8 deletions.
37 changes: 29 additions & 8 deletions python/tvm/autotvm/tophub.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
TopHub: Tensor Operator Hub
To get the best performance, we typically need auto-tuning for the specific devices.
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you call nnvm.compiler.build_module .
TVM will download these parameters for you when you call
nnvm.compiler.build_module or relay.build.
"""
# pylint: disable=invalid-name

Expand All @@ -30,6 +31,16 @@
from .. import target as _target
from ..contrib.download import download
from .record import load_from_file
from .util import EmptyContext

# environment variable to read TopHub location
AUTOTVM_TOPHUB_LOC_VAR = "TOPHUB_LOCATION"

# default location of TopHub
AUTOTVM_TOPHUB_DEFAULT_LOC = "https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub"

# value of AUTOTVM_TOPHUB_LOC_VAR to specify to not read from TopHub
AUTOTVM_TOPHUB_NONE_LOC = "NONE"

# root path to store TopHub files
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
Expand Down Expand Up @@ -61,6 +72,9 @@ def _alias(name):
}
return table.get(name, name)

def _get_tophub_location():
location = os.getenv(AUTOTVM_TOPHUB_LOC_VAR, None)
return AUTOTVM_TOPHUB_DEFAULT_LOC if location is None else location

def context(target, extra_files=None):
"""Return the dispatch context with pre-tuned parameters.
Expand All @@ -75,6 +89,10 @@ def context(target, extra_files=None):
extra_files: list of str, optional
Extra log files to load
"""
tophub_location = _get_tophub_location()
if tophub_location == AUTOTVM_TOPHUB_NONE_LOC:
return EmptyContext()

best_context = ApplyHistoryBest([])

targets = target if isinstance(target, (list, tuple)) else [target]
Expand All @@ -94,7 +112,7 @@ def context(target, extra_files=None):
for name in possible_names:
name = _alias(name)
if name in all_packages:
if not check_backend(name):
if not check_backend(tophub_location, name):
continue

filename = "%s_%s.log" % (name, PACKAGE_VERSION[name])
Expand All @@ -108,7 +126,7 @@ def context(target, extra_files=None):
return best_context


def check_backend(backend):
def check_backend(tophub_location, backend):
"""Check whether have pre-tuned parameters of the certain target.
If not, will download it.
Expand All @@ -135,18 +153,21 @@ def check_backend(backend):
else:
import urllib2
try:
download_package(package_name)
download_package(tophub_location, package_name)
return True
except urllib2.URLError as e:
logging.warning("Failed to download tophub package for %s: %s", backend, e)
return False


def download_package(package_name):
def download_package(tophub_location, package_name):
"""Download pre-tuned parameters of operators for a backend
Parameters
----------
tophub_location: str
The location to download TopHub parameters from
package_name: str
The name of package
"""
Expand All @@ -160,9 +181,9 @@ def download_package(package_name):
if not os.path.isdir(path):
os.mkdir(path)

logger.info("Download pre-tuned parameters package %s", package_name)
download("https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub/%s"
% package_name, os.path.join(rootpath, package_name), True, verbose=0)
download_url = "{0}/{1}".format(tophub_location, package_name)
logger.info("Download pre-tuned parameters package from %s", download_url)
download(download_url, os.path.join(rootpath, package_name), True, verbose=0)


# global cache for load_reference_log
Expand Down

0 comments on commit 5cc891f

Please sign in to comment.