Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

NAS benchmark integration (stage 2) - download #4205

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nni/nas/benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .utils import load_benchmark, download_benchmark
20 changes: 18 additions & 2 deletions nni/nas/benchmarks/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
import os


# TODO: need to be refactored to support automatic download
ENV_NNI_HOME = 'NNI_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'

DATABASE_DIR = os.environ.get("NASBENCHMARK_DIR", os.path.expanduser("~/.nni/nasbenchmark"))

def _get_nasbenchmark_dir():
nni_home = os.path.expanduser(
os.getenv(ENV_NNI_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'nni')))
return os.path.join(nni_home, 'nasbenchmark')


DATABASE_DIR = _get_nasbenchmark_dir()

DB_URLS = {
'nasbench101': 'https://nni.blob.core.windows.net/nasbenchmark/nasbench101-209f5694.db',
'nasbench201': 'https://nni.blob.core.windows.net/nasbenchmark/nasbench201-b2b60732.db',
'nds': 'https://nni.blob.core.windows.net/nasbenchmark/nds-5745c235.db'
}
5 changes: 4 additions & 1 deletion nni/nas/benchmarks/nasbench101/db_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from tqdm import tqdm
from nasbench import api # pylint: disable=import-error

from .model import db, Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats
from nni.nas.benchmarks.utils import load_benchmark
from .model import Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats
from .graph_util import nasbench_format_to_architecture_repr, hash_module


Expand All @@ -13,6 +14,8 @@ def main():
help='Path to the file to be converted, e.g., nasbench_full.tfrecord')
args = parser.parse_args()
nasbench = api.NASBench(args.input_file)

db = load_benchmark('nasbench101')
with db:
db.create_tables([Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats])
for hashval in tqdm(nasbench.hash_iterator(), desc='Dumping data into database'):
Expand Down
15 changes: 6 additions & 9 deletions nni/nas/benchmarks/nasbench101/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import os
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField

from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase

from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps

db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True)
proxy = Proxy()


class Nb101TrialConfig(Model):
Expand Down Expand Up @@ -35,7 +32,7 @@ class Nb101TrialConfig(Model):
num_epochs = IntegerField(index=True)

class Meta:
database = db
database = proxy


class Nb101TrialStats(Model):
Expand Down Expand Up @@ -68,7 +65,7 @@ class Nb101TrialStats(Model):
training_time = FloatField()

class Meta:
database = db
database = proxy


class Nb101IntermediateStats(Model):
Expand Down Expand Up @@ -99,4 +96,4 @@ class Nb101IntermediateStats(Model):
training_time = FloatField()

class Meta:
database = db
database = proxy
8 changes: 7 additions & 1 deletion nni/nas/benchmarks/nasbench101/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from peewee import fn
from playhouse.shortcuts import model_to_dict
from .model import Nb101TrialStats, Nb101TrialConfig

from nni.nas.benchmarks.utils import load_benchmark
from .model import Nb101TrialStats, Nb101TrialConfig, proxy
from .graph_util import hash_module, infer_num_vertices


Expand Down Expand Up @@ -33,6 +35,10 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects,
where each of them has been converted into a dict.
"""

if proxy.obj is None:
proxy.initialize(load_benchmark('nasbench101'))

fields = []
if reduction == 'none':
reduction = None
Expand Down
5 changes: 4 additions & 1 deletion nni/nas/benchmarks/nasbench201/db_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import tqdm
import torch

from nni.nas.benchmarks.utils import load_benchmark
from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3
from .model import db, Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats
from .model import Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats


def parse_arch_str(arch_str):
Expand Down Expand Up @@ -39,6 +40,8 @@ def main():
'imagenet16-120': ['train', 'x-valid', 'x-test', 'ori-test'],
}

db = load_benchmark('nasbench201')

with db:
db.create_tables([Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats])
print('Loading NAS-Bench-201 pickle...')
Expand Down
15 changes: 6 additions & 9 deletions nni/nas/benchmarks/nasbench201/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import os
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField

from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase

from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps

db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench201.db'), autoconnect=True)
proxy = Proxy()


class Nb201TrialConfig(Model):
Expand Down Expand Up @@ -48,7 +45,7 @@ class Nb201TrialConfig(Model):
])

class Meta:
database = db
database = proxy


class Nb201TrialStats(Model):
Expand Down Expand Up @@ -113,7 +110,7 @@ class Nb201TrialStats(Model):
ori_test_evaluation_time = FloatField()

class Meta:
database = db
database = proxy


class Nb201IntermediateStats(Model):
Expand Down Expand Up @@ -157,4 +154,4 @@ class Nb201IntermediateStats(Model):
ori_test_loss = FloatField(null=True)

class Meta:
database = db
database = proxy
8 changes: 7 additions & 1 deletion nni/nas/benchmarks/nasbench201/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from peewee import fn
from playhouse.shortcuts import model_to_dict
from .model import Nb201TrialStats, Nb201TrialConfig

from nni.nas.benchmarks.utils import load_benchmark
from .model import Nb201TrialStats, Nb201TrialConfig, proxy


def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False):
Expand Down Expand Up @@ -32,6 +34,10 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects,
where each of them has been converted into a dict.
"""

if proxy.obj is None:
proxy.initialize(load_benchmark('nasbench201'))

fields = []
if reduction == 'none':
reduction = None
Expand Down
5 changes: 4 additions & 1 deletion nni/nas/benchmarks/nds/db_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import numpy as np
import tqdm

from .model import db, NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
from nni.nas.benchmarks.utils import load_benchmark
from .model import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats


def inject_item(db, item, proposer, dataset, generator):
Expand Down Expand Up @@ -120,6 +121,8 @@ def main():
'Vanilla_rng3.json'
]

db = load_benchmark('nds')

with db:
db.create_tables([NdsTrialConfig, NdsTrialStats, NdsIntermediateStats])
for json_idx, json_file in enumerate(sweep_list, start=1):
Expand Down
15 changes: 6 additions & 9 deletions nni/nas/benchmarks/nds/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import os
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField

from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase

from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps

db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nds.db'), autoconnect=True)
proxy = Proxy()


class NdsTrialConfig(Model):
Expand Down Expand Up @@ -67,7 +64,7 @@ class NdsTrialConfig(Model):
num_epochs = IntegerField()

class Meta:
database = db
database = proxy


class NdsTrialStats(Model):
Expand Down Expand Up @@ -112,7 +109,7 @@ class NdsTrialStats(Model):
iter_time = FloatField()

class Meta:
database = db
database = proxy


class NdsIntermediateStats(Model):
Expand Down Expand Up @@ -140,4 +137,4 @@ class NdsIntermediateStats(Model):
test_acc = FloatField()

class Meta:
database = db
database = proxy
8 changes: 7 additions & 1 deletion nni/nas/benchmarks/nds/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from peewee import fn
from playhouse.shortcuts import model_to_dict
from .model import NdsTrialStats, NdsTrialConfig

from nni.nas.benchmarks.utils import load_benchmark
from .model import NdsTrialStats, NdsTrialConfig, proxy


def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset,
Expand Down Expand Up @@ -41,6 +43,10 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
A generator of :class:`nni.nas.benchmark.nds.NdsTrialStats` objects,
where each of them has been converted into a dict.
"""

if proxy.obj is None:
proxy.initialize(load_benchmark('nds'))

fields = []
if reduction == 'none':
reduction = None
Expand Down
100 changes: 100 additions & 0 deletions nni/nas/benchmarks/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,105 @@
import functools
import hashlib
import json
import logging
import os
import shutil
import tempfile
from pathlib import Path

import requests
import tqdm
from playhouse.sqlite_ext import SqliteExtDatabase

from .constants import DB_URLS, DATABASE_DIR


json_dumps = functools.partial(json.dumps, sort_keys=True)

# to prevent repetitive loading of benchmarks
_loaded_benchmarks = {}


def load_or_download_file(local_path: str, download_url: str, download: bool = False, progress: bool = True):
f = None
hash_prefix = Path(local_path).stem.split('-')[-1]

_logger = logging.getLogger(__name__)

try:
sha256 = hashlib.sha256()

if Path(local_path).exists():
_logger.info('"%s" already exists. Checking hash.', local_path)
with Path(local_path).open('rb') as fr:
while True:
chunk = fr.read(8192)
if len(chunk) == 0:
break
sha256.update(chunk)
elif download:
_logger.info('"%s" does not exist. Downloading "%s"', local_path, download_url)

# Follow download implementation in torchvision:
# We deliberately save it in a temp file and move it after
# download is complete. This prevents a local working checkpoint
# being overridden by a broken download.
dst_dir = Path(local_path).parent
dst_dir.mkdir(exist_ok=True, parents=True)

f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
r = requests.get(download_url, stream=True)
total_length = int(r.headers.get('content-length'))
with tqdm.tqdm(total=total_length, disable=not progress,
unit='B', unit_scale=True, unit_divisor=1024) as pbar:
for chunk in r.iter_content(8192):
f.write(chunk)
sha256.update(chunk)
pbar.update(len(chunk))
f.flush()
else:
raise FileNotFoundError('Download is not enabled, but file still does not exist: {}'.format(local_path))

digest = sha256.hexdigest()
if digest[:len(hash_prefix)] != hash_prefix:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

startswith

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is directly copied from pytorch. But I think you are right.

raise RuntimeError('invalid hash value (expected "{}", got "{}")'.format(hash_prefix, digest))

if f is not None:
shutil.move(f.name, local_path)
finally:
if f is not None:
f.close()
if os.path.exists(f.name):
os.remove(f.name)


def load_benchmark(benchmark: str) -> SqliteExtDatabase:
"""
Load a benchmark as a database.

Parmaeters
----------
benchmark : str
Benchmark name like nasbench201.
"""
if benchmark in _loaded_benchmarks:
return _loaded_benchmarks[benchmark]
url = DB_URLS[benchmark]
local_path = os.path.join(DATABASE_DIR, os.path.basename(url))
load_or_download_file(local_path, url)
_loaded_benchmarks[benchmark] = SqliteExtDatabase(local_path, autoconnect=True)
return _loaded_benchmarks[benchmark]


def download_benchmark(benchmark: str, progress: bool = True):
"""
Download a converted benchmark.

Parameters
----------
benchmark : str
Benchmark name like nasbench201.
"""
url = DB_URLS[benchmark]
local_path = os.path.join(DATABASE_DIR, os.path.basename(url))
load_or_download_file(local_path, url, True, progress)