This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
NAS benchmark integration (stage 2) - download #4205
Merged
QuanluZhang
merged 6 commits into
microsoft:master
from
ultmaster:retiarii/benchmark-integration-2
Oct 12, 2021
Merged
Changes from 5 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
bdc8a11
benchmark integration download
ultmaster be31de0
fix import path
ultmaster acc0d39
fix doc and lint
ultmaster b856235
update loading logic
ultmaster 4de4504
Update db in db_gen
ultmaster e9c340d
startswith
ultmaster File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .utils import load_benchmark, download_benchmark |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,22 @@ | ||
import os | ||
|
||
|
||
# TODO: need to be refactored to support automatic download | ||
ENV_NNI_HOME = 'NNI_HOME' | ||
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' | ||
DEFAULT_CACHE_DIR = '~/.cache' | ||
|
||
DATABASE_DIR = os.environ.get("NASBENCHMARK_DIR", os.path.expanduser("~/.nni/nasbenchmark")) | ||
|
||
def _get_nasbenchmark_dir(): | ||
nni_home = os.path.expanduser( | ||
os.getenv(ENV_NNI_HOME, | ||
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'nni'))) | ||
return os.path.join(nni_home, 'nasbenchmark') | ||
|
||
|
||
DATABASE_DIR = _get_nasbenchmark_dir() | ||
|
||
DB_URLS = { | ||
'nasbench101': 'https://nni.blob.core.windows.net/nasbenchmark/nasbench101-209f5694.db', | ||
'nasbench201': 'https://nni.blob.core.windows.net/nasbenchmark/nasbench201-b2b60732.db', | ||
'nds': 'https://nni.blob.core.windows.net/nasbenchmark/nds-5745c235.db' | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,105 @@ | ||
import functools | ||
import hashlib | ||
import json | ||
import logging | ||
import os | ||
import shutil | ||
import tempfile | ||
from pathlib import Path | ||
|
||
import requests | ||
import tqdm | ||
from playhouse.sqlite_ext import SqliteExtDatabase | ||
|
||
from .constants import DB_URLS, DATABASE_DIR | ||
|
||
|
||
json_dumps = functools.partial(json.dumps, sort_keys=True) | ||
|
||
# to prevent repetitive loading of benchmarks | ||
_loaded_benchmarks = {} | ||
|
||
|
||
def load_or_download_file(local_path: str, download_url: str, download: bool = False, progress: bool = True): | ||
f = None | ||
hash_prefix = Path(local_path).stem.split('-')[-1] | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
try: | ||
sha256 = hashlib.sha256() | ||
|
||
if Path(local_path).exists(): | ||
_logger.info('"%s" already exists. Checking hash.', local_path) | ||
with Path(local_path).open('rb') as fr: | ||
while True: | ||
chunk = fr.read(8192) | ||
if len(chunk) == 0: | ||
break | ||
sha256.update(chunk) | ||
elif download: | ||
_logger.info('"%s" does not exist. Downloading "%s"', local_path, download_url) | ||
|
||
# Follow download implementation in torchvision: | ||
# We deliberately save it in a temp file and move it after | ||
# download is complete. This prevents a local working checkpoint | ||
# being overridden by a broken download. | ||
dst_dir = Path(local_path).parent | ||
dst_dir.mkdir(exist_ok=True, parents=True) | ||
|
||
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) | ||
r = requests.get(download_url, stream=True) | ||
total_length = int(r.headers.get('content-length')) | ||
with tqdm.tqdm(total=total_length, disable=not progress, | ||
unit='B', unit_scale=True, unit_divisor=1024) as pbar: | ||
for chunk in r.iter_content(8192): | ||
f.write(chunk) | ||
sha256.update(chunk) | ||
pbar.update(len(chunk)) | ||
f.flush() | ||
else: | ||
raise FileNotFoundError('Download is not enabled, but file still does not exist: {}'.format(local_path)) | ||
|
||
digest = sha256.hexdigest() | ||
if digest[:len(hash_prefix)] != hash_prefix: | ||
raise RuntimeError('invalid hash value (expected "{}", got "{}")'.format(hash_prefix, digest)) | ||
|
||
if f is not None: | ||
shutil.move(f.name, local_path) | ||
finally: | ||
if f is not None: | ||
f.close() | ||
if os.path.exists(f.name): | ||
os.remove(f.name) | ||
|
||
|
||
def load_benchmark(benchmark: str) -> SqliteExtDatabase: | ||
""" | ||
Load a benchmark as a database. | ||
|
||
Parmaeters | ||
---------- | ||
benchmark : str | ||
Benchmark name like nasbench201. | ||
""" | ||
if benchmark in _loaded_benchmarks: | ||
return _loaded_benchmarks[benchmark] | ||
url = DB_URLS[benchmark] | ||
local_path = os.path.join(DATABASE_DIR, os.path.basename(url)) | ||
load_or_download_file(local_path, url) | ||
_loaded_benchmarks[benchmark] = SqliteExtDatabase(local_path, autoconnect=True) | ||
return _loaded_benchmarks[benchmark] | ||
|
||
|
||
def download_benchmark(benchmark: str, progress: bool = True): | ||
""" | ||
Download a converted benchmark. | ||
|
||
Parameters | ||
---------- | ||
benchmark : str | ||
Benchmark name like nasbench201. | ||
""" | ||
url = DB_URLS[benchmark] | ||
local_path = os.path.join(DATABASE_DIR, os.path.basename(url)) | ||
load_or_download_file(local_path, url, True, progress) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
startswith
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code is directly copied from pytorch. But I think you are right.