Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

NAS benchmark integration (stage 2) - download #4205

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nni/nas/benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .utils import load_benchmark, download_benchmark
20 changes: 18 additions & 2 deletions nni/nas/benchmarks/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
import os


# TODO: need to be refactored to support automatic download
# Name of the environment variable that overrides the NNI home directory.
ENV_NNI_HOME = 'NNI_HOME'
# Name of the environment variable holding the user's cache root.
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
# Fallback cache root used when XDG_CACHE_HOME is not set ("~" is expanded by _get_nasbenchmark_dir).
DEFAULT_CACHE_DIR = '~/.cache'

DATABASE_DIR = os.environ.get("NASBENCHMARK_DIR", os.path.expanduser("~/.nni/nasbenchmark"))

def _get_nasbenchmark_dir():
    """
    Resolve the directory in which benchmark database files are stored.

    The NNI home directory is taken from ``$NNI_HOME``; if that is not usable,
    ``$XDG_CACHE_HOME/nni`` is used; if that is not usable either, it falls back
    to ``DEFAULT_CACHE_DIR/nni``. A variable that is set but empty is treated as
    unset, otherwise an empty value would silently turn the result into a path
    relative to the current working directory.

    Returns
    -------
    str
        ``<nni home>/nasbenchmark``, with a leading ``~`` in the home directory expanded.
    """
    # `or` (rather than a getenv default) also covers the "set but empty" case.
    cache_root = os.getenv(ENV_XDG_CACHE_HOME) or DEFAULT_CACHE_DIR
    nni_home = os.getenv(ENV_NNI_HOME) or os.path.join(cache_root, 'nni')
    return os.path.join(os.path.expanduser(nni_home), 'nasbenchmark')


DATABASE_DIR = _get_nasbenchmark_dir()

# Download URLs of the benchmark database files, keyed by benchmark name.
# The token after the last "-" in each file name is the expected prefix of the
# file's SHA-256 hex digest; load_or_download_file() in utils.py verifies it.
DB_URLS = {
'nasbench101': 'https://nni.blob.core.windows.net/nasbenchmark/nasbench101-209f5694.db',
'nasbench201': 'https://nni.blob.core.windows.net/nasbenchmark/nasbench201-b2b60732.db',
'nds': 'https://nni.blob.core.windows.net/nasbenchmark/nds-5745c235.db'
}
5 changes: 4 additions & 1 deletion nni/nas/benchmarks/nasbench101/db_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from tqdm import tqdm
from nasbench import api # pylint: disable=import-error

from .model import db, Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats
from nni.nas.benchmarks.utils import load_benchmark
from .model import Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats
from .graph_util import nasbench_format_to_architecture_repr, hash_module


Expand All @@ -13,6 +14,8 @@ def main():
help='Path to the file to be converted, e.g., nasbench_full.tfrecord')
args = parser.parse_args()
nasbench = api.NASBench(args.input_file)

db = load_benchmark('nasbench101')
with db:
db.create_tables([Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats])
for hashval in tqdm(nasbench.hash_iterator(), desc='Dumping data into database'):
Expand Down
15 changes: 6 additions & 9 deletions nni/nas/benchmarks/nasbench101/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import os
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField

from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase

from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps

db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True)
proxy = Proxy()


class Nb101TrialConfig(Model):
Expand Down Expand Up @@ -35,7 +32,7 @@ class Nb101TrialConfig(Model):
num_epochs = IntegerField(index=True)

class Meta:
database = db
database = proxy


class Nb101TrialStats(Model):
Expand Down Expand Up @@ -68,7 +65,7 @@ class Nb101TrialStats(Model):
training_time = FloatField()

class Meta:
database = db
database = proxy


class Nb101IntermediateStats(Model):
Expand Down Expand Up @@ -99,4 +96,4 @@ class Nb101IntermediateStats(Model):
training_time = FloatField()

class Meta:
database = db
database = proxy
8 changes: 7 additions & 1 deletion nni/nas/benchmarks/nasbench101/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from peewee import fn
from playhouse.shortcuts import model_to_dict
from .model import Nb101TrialStats, Nb101TrialConfig

from nni.nas.benchmarks.utils import load_benchmark
from .model import Nb101TrialStats, Nb101TrialConfig, proxy
from .graph_util import hash_module, infer_num_vertices


Expand Down Expand Up @@ -33,6 +35,10 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects,
where each of them has been converted into a dict.
"""

if proxy.obj is None:
proxy.initialize(load_benchmark('nasbench101'))

fields = []
if reduction == 'none':
reduction = None
Expand Down
5 changes: 4 additions & 1 deletion nni/nas/benchmarks/nasbench201/db_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import tqdm
import torch

from nni.nas.benchmarks.utils import load_benchmark
from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3
from .model import db, Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats
from .model import Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats


def parse_arch_str(arch_str):
Expand Down Expand Up @@ -39,6 +40,8 @@ def main():
'imagenet16-120': ['train', 'x-valid', 'x-test', 'ori-test'],
}

db = load_benchmark('nasbench201')

with db:
db.create_tables([Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats])
print('Loading NAS-Bench-201 pickle...')
Expand Down
15 changes: 6 additions & 9 deletions nni/nas/benchmarks/nasbench201/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import os
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField

from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase

from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps

db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench201.db'), autoconnect=True)
proxy = Proxy()


class Nb201TrialConfig(Model):
Expand Down Expand Up @@ -48,7 +45,7 @@ class Nb201TrialConfig(Model):
])

class Meta:
database = db
database = proxy


class Nb201TrialStats(Model):
Expand Down Expand Up @@ -113,7 +110,7 @@ class Nb201TrialStats(Model):
ori_test_evaluation_time = FloatField()

class Meta:
database = db
database = proxy


class Nb201IntermediateStats(Model):
Expand Down Expand Up @@ -157,4 +154,4 @@ class Nb201IntermediateStats(Model):
ori_test_loss = FloatField(null=True)

class Meta:
database = db
database = proxy
8 changes: 7 additions & 1 deletion nni/nas/benchmarks/nasbench201/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from peewee import fn
from playhouse.shortcuts import model_to_dict
from .model import Nb201TrialStats, Nb201TrialConfig

from nni.nas.benchmarks.utils import load_benchmark
from .model import Nb201TrialStats, Nb201TrialConfig, proxy


def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False):
Expand Down Expand Up @@ -32,6 +34,10 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects,
where each of them has been converted into a dict.
"""

if proxy.obj is None:
proxy.initialize(load_benchmark('nasbench201'))

fields = []
if reduction == 'none':
reduction = None
Expand Down
5 changes: 4 additions & 1 deletion nni/nas/benchmarks/nds/db_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import numpy as np
import tqdm

from .model import db, NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
from nni.nas.benchmarks.utils import load_benchmark
from .model import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats


def inject_item(db, item, proposer, dataset, generator):
Expand Down Expand Up @@ -120,6 +121,8 @@ def main():
'Vanilla_rng3.json'
]

db = load_benchmark('nds')

with db:
db.create_tables([NdsTrialConfig, NdsTrialStats, NdsIntermediateStats])
for json_idx, json_file in enumerate(sweep_list, start=1):
Expand Down
15 changes: 6 additions & 9 deletions nni/nas/benchmarks/nds/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import os
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField

from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase

from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps

db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nds.db'), autoconnect=True)
proxy = Proxy()


class NdsTrialConfig(Model):
Expand Down Expand Up @@ -67,7 +64,7 @@ class NdsTrialConfig(Model):
num_epochs = IntegerField()

class Meta:
database = db
database = proxy


class NdsTrialStats(Model):
Expand Down Expand Up @@ -112,7 +109,7 @@ class NdsTrialStats(Model):
iter_time = FloatField()

class Meta:
database = db
database = proxy


class NdsIntermediateStats(Model):
Expand Down Expand Up @@ -140,4 +137,4 @@ class NdsIntermediateStats(Model):
test_acc = FloatField()

class Meta:
database = db
database = proxy
8 changes: 7 additions & 1 deletion nni/nas/benchmarks/nds/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from peewee import fn
from playhouse.shortcuts import model_to_dict
from .model import NdsTrialStats, NdsTrialConfig

from nni.nas.benchmarks.utils import load_benchmark
from .model import NdsTrialStats, NdsTrialConfig, proxy


def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset,
Expand Down Expand Up @@ -41,6 +43,10 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
A generator of :class:`nni.nas.benchmark.nds.NdsTrialStats` objects,
where each of them has been converted into a dict.
"""

if proxy.obj is None:
proxy.initialize(load_benchmark('nds'))

fields = []
if reduction == 'none':
reduction = None
Expand Down
100 changes: 100 additions & 0 deletions nni/nas/benchmarks/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,105 @@
import functools
import hashlib
import json
import logging
import os
import shutil
import tempfile
from pathlib import Path

import requests
import tqdm
from playhouse.sqlite_ext import SqliteExtDatabase

from .constants import DB_URLS, DATABASE_DIR


# json.dumps with keys sorted, so the output does not depend on dict insertion order.
json_dumps = functools.partial(json.dumps, sort_keys=True)

# to prevent repetitive loading of benchmarks
# (maps benchmark name -> database object created by load_benchmark)
_loaded_benchmarks = {}


def load_or_download_file(local_path: str, download_url: str, download: bool = False, progress: bool = True) -> None:
    """
    Verify a local file against the hash embedded in its name, downloading it first if allowed.

    The expected hash is the token after the last ``-`` in the stem of ``local_path``
    (e.g. ``209f5694`` for ``nasbench101-209f5694.db``). The SHA-256 hex digest of the
    file content must start with that token.

    Parameters
    ----------
    local_path : str
        Path of the file to check, or the destination of the download.
    download_url : str
        URL the file is fetched from when it does not exist locally.
    download : bool
        Whether a missing file may be downloaded.
    progress : bool
        Whether to display a progress bar during download.

    Raises
    ------
    FileNotFoundError
        If the file does not exist and ``download`` is false.
    RuntimeError
        If the SHA-256 digest of the file does not start with the expected prefix.
    """
    _logger = logging.getLogger(__name__)

    target = Path(local_path)
    hash_prefix = target.stem.split('-')[-1]
    sha256 = hashlib.sha256()
    tmp_file = None

    try:
        if target.exists():
            _logger.info('"%s" already exists. Checking hash.', local_path)
            with target.open('rb') as fr:
                for chunk in iter(lambda: fr.read(8192), b''):
                    sha256.update(chunk)
        elif download:
            _logger.info('"%s" does not exist. Downloading "%s"', local_path, download_url)

            # Follow download implementation in torchvision:
            # We deliberately save it in a temp file and move it after
            # download is complete. This prevents a local working checkpoint
            # being overridden by a broken download.
            dst_dir = target.parent
            dst_dir.mkdir(exist_ok=True, parents=True)

            tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
            # The context manager releases the streamed connection even on failure.
            with requests.get(download_url, stream=True) as response:
                # Content-Length is optional (e.g. chunked transfer encoding);
                # tqdm accepts total=None and then shows a bar without a total.
                content_length = response.headers.get('content-length')
                total_length = int(content_length) if content_length is not None else None
                with tqdm.tqdm(total=total_length, disable=not progress,
                               unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                    for chunk in response.iter_content(8192):
                        tmp_file.write(chunk)
                        sha256.update(chunk)
                        pbar.update(len(chunk))
            # Close (which also flushes) before moving: some platforms, notably
            # Windows, refuse to move a file that is still open.
            tmp_file.close()
        else:
            raise FileNotFoundError('Download is not enabled, but file still does not exist: {}'.format(local_path))

        digest = sha256.hexdigest()
        if not digest.startswith(hash_prefix):
            raise RuntimeError('Invalid hash value (expected "{}", got "{}")'.format(hash_prefix, digest))

        if tmp_file is not None:
            shutil.move(tmp_file.name, local_path)
    finally:
        if tmp_file is not None:
            # close() is idempotent; the temp file only still exists here if
            # the download or the hash check failed before the move.
            tmp_file.close()
            if os.path.exists(tmp_file.name):
                os.remove(tmp_file.name)


def load_benchmark(benchmark: str) -> SqliteExtDatabase:
    """
    Open a benchmark as a database, reusing the instance from an earlier call if any.

    The database file is looked up under ``DATABASE_DIR`` and checked with
    ``load_or_download_file`` (download disabled) before it is opened.

    Parameters
    ----------
    benchmark : str
        Benchmark name like nasbench201. Must be a key of ``DB_URLS``.

    Returns
    -------
    SqliteExtDatabase
        Database object backed by the local benchmark file.
    """
    try:
        return _loaded_benchmarks[benchmark]
    except KeyError:
        pass

    download_url = DB_URLS[benchmark]
    file_name = os.path.basename(download_url)
    db_path = os.path.join(DATABASE_DIR, file_name)
    load_or_download_file(db_path, download_url)

    database = SqliteExtDatabase(db_path, autoconnect=True)
    _loaded_benchmarks[benchmark] = database
    return database


def download_benchmark(benchmark: str, progress: bool = True):
    """
    Fetch the database file of a converted benchmark into ``DATABASE_DIR``.

    The work is delegated to ``load_or_download_file`` with downloading enabled.

    Parameters
    ----------
    benchmark : str
        Benchmark name like nasbench201. Must be a key of ``DB_URLS``.
    progress : bool
        Forwarded to ``load_or_download_file`` to control the progress bar.
    """
    download_url = DB_URLS[benchmark]
    file_name = os.path.basename(download_url)
    destination = os.path.join(DATABASE_DIR, file_name)
    load_or_download_file(destination, download_url, download=True, progress=progress)