diff --git a/naslib/optimizers/discrete/bananas/optimizer.py b/naslib/optimizers/discrete/bananas/optimizer.py index 7c2658702..507af71f2 100644 --- a/naslib/optimizers/discrete/bananas/optimizer.py +++ b/naslib/optimizers/discrete/bananas/optimizer.py @@ -60,6 +60,9 @@ def __init__(self, config, zc_api=None): config.search, 'zc_names') else None self.zc_only = config.search.zc_only if hasattr( config.search, 'zc_only') else False + + self.load_labeled = config.search.load_labeled if hasattr( + config.search, 'load_labeled') else False def adapt_search_space(self, search_space, scope=None, dataset_api=None): assert ( @@ -119,7 +122,7 @@ def _sample_new_model(self): model = torch.nn.Module() model.arch = self.search_space.clone() model.arch.sample_random_architecture( - dataset_api=self.dataset_api, load_labeled=self.use_zc_api) + dataset_api=self.dataset_api, load_labeled=self.load_labeled) model.arch_hash = model.arch.get_hash() if self.search_space.instantiate_model == True: diff --git a/naslib/optimizers/discrete/npenas/optimizer.py b/naslib/optimizers/discrete/npenas/optimizer.py index a6e067633..52765f32b 100644 --- a/naslib/optimizers/discrete/npenas/optimizer.py +++ b/naslib/optimizers/discrete/npenas/optimizer.py @@ -59,6 +59,8 @@ def __init__(self, config, zc_api=None): config.search, 'zc_names') else None self.zc_only = config.search.zc_only if hasattr( config.search, 'zc_only') else False + self.load_labeled = config.search.load_labeled if hasattr( + config.search, 'load_labeled') else False def adapt_search_space(self, search_space, scope=None, dataset_api=None): assert ( @@ -119,7 +121,7 @@ def _sample_new_model(self): model = torch.nn.Module() model.arch = self.search_space.clone() model.arch.sample_random_architecture( - dataset_api=self.dataset_api, load_labeled=self.use_zc_api) + dataset_api=self.dataset_api, load_labeled=self.load_labeled) model.arch_hash = model.arch.get_hash() if self.search_space.instantiate_model == True: diff --git a/naslib/runners/zc/zc_config.yaml b/naslib/runners/zc/zc_config.yaml index 70968d7a8..b0577ff83 100644 --- a/naslib/runners/zc/zc_config.yaml +++ b/naslib/runners/zc/zc_config.yaml @@ -6,23 +6,66 @@ cutout_prob: 1.0 dataset: cifar10 out_dir: run predictor: fisher -search_space: nasbench101 #nasbench201 #nasbench301 -seed: 0 +search_space: nasbench201 #nasbench101 #nasbench301 test_size: 200 train_size: 400 -zc_ensemble: true -zc_only: true -zc_names: - - params - - flops - - jacov - - plain - - grasp - - snip - - fisher - - grad_norm - - epe_nas - - synflow - - l2_norm optimizer: npenas -train_portion: 0.7 \ No newline at end of file +train_portion: 0.7 +seed: 0 + +search: + # for bohb + seed: 0 + budgets: 50000000 + checkpoint_freq: 1000 + fidelity: 108 + + # for all optimizers + epochs: 10 + + # for bananas and npenas, choose one predictor + # out of the 16 model-based predictors + predictor_type: var_sparse_gp + + # number of initial architectures + num_init: 10 + + # NPENAS + k: 10 + num_ensemble: 3 + acq_fn_type: its + acq_fn_optimization: mutation + encoding_type: adjacency_one_hot + num_arches_to_mutate: 1 + max_mutations: 1 + num_candidates: 50 + + # jacov data loader + batch_size: 256 + data_size: 25000 + cutout: False + cutout_length: 16 + cutout_prob: 1.0 + train_portion: 0.7 + + # other params + debug_predictor: False + sample_size: 10 + population_size: 30 + + # zc parameters + use_zc_api: False + zc_ensemble: true + zc_names: + - params + - flops + - jacov + - plain + - grasp + - snip + - fisher + - grad_norm + - epe_nas + - synflow + - l2_norm + zc_only: true \ No newline at end of file