From 3c15c43769b46e62954a7225e360f5cb38d85d4d Mon Sep 17 00:00:00 2001 From: yuikosakuma1 Date: Wed, 21 Sep 2022 12:34:21 +0900 Subject: [PATCH 1/2] Unify valid and test geotypes code --- nnabla_nas/runner/searcher/ofa.py | 63 +++++++++++++------------------ 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/nnabla_nas/runner/searcher/ofa.py b/nnabla_nas/runner/searcher/ofa.py index 98bac736..d9f0d51d 100644 --- a/nnabla_nas/runner/searcher/ofa.py +++ b/nnabla_nas/runner/searcher/ofa.py @@ -70,23 +70,7 @@ def run(self): # Test for init parameters if self.args['task'] != 'fullnet': - OFAResize.IS_TRAINING = False - for genotype in self.args['valid_genotypes']: - for img_size in self.args['valid_image_size_list']: - self.monitor.reset() - OFAResize.ACTIVE_SIZE = img_size - self.model.set_valid_arch(genotype) - self.reset_running_statistics() - for i in tqdm(range(self.one_epoch_test), desc='Test for init parameters'): - self.update_graph('test') - self.valid_on_batch(is_test=True) - clear_memory_cache() - self.monitor.info(f'img_size={img_size}, genotype={genotype} \n') - self.callback_on_epoch_end(is_test=True) - - self.loss.zero() - for k in self.metrics: - self.metrics[k].zero() + self.valid_genotypes(mode='test') # training for self.cur_epoch in range(self.cur_epoch, self.args['epoch']): @@ -105,25 +89,8 @@ def run(self): self.monitor.display(i, key=train_keys) clear_memory_cache() if self.cur_epoch % self.args["validation_frequency"] == 0: - OFAResize.IS_TRAINING = False - for genotype in self.args['valid_genotypes']: - for img_size in self.args['valid_image_size_list']: - self.monitor.reset() - OFAResize.ACTIVE_SIZE = img_size - self.model.set_valid_arch(genotype) - self.reset_running_statistics() - for i in tqdm(range(self.one_epoch_valid), - desc=f'Valid [{self.cur_epoch}/{self.args["epoch"]}]'): - self.update_graph('valid') - self.valid_on_batch(is_test=False) - clear_memory_cache() - self.monitor.info(f'img_size={img_size}, genotype={genotype} \n') - self.callback_on_epoch_end(is_test=False) - self.monitor.write(self.cur_epoch) - - self.loss.zero() - for k in self.metrics: - self.metrics[k].zero() + self.valid_genotypes(mode='valid') + return self def callback_on_start(self): @@ -204,6 +171,30 @@ def valid_on_batch(self, is_test=False): [self.loss] + list(self.metrics.values()), division=True, inplace=False) self.event.add_default_stream_event() + def valid_genotypes(self, mode='valid'): + assert mode in ['valid', 'test'] + one_epoch = self.one_epoch_valid if mode == 'valid' else self.one_epoch_test + is_test = True if mode == 'test' else False + + OFAResize.IS_TRAINING = False + for genotype in self.args['valid_genotypes']: + for img_size in self.args['valid_image_size_list']: + self.monitor.reset() + OFAResize.ACTIVE_SIZE = img_size + self.model.set_valid_arch(genotype) + self.reset_running_statistics() + for _ in tqdm(range(one_epoch), desc=f'{mode} [{self.cur_epoch}/{self.args["epoch"]}]'): + self.update_graph(mode) + self.valid_on_batch(is_test=is_test) + clear_memory_cache() + self.monitor.info(f'img_size={img_size}, genotype={genotype} \n') + self.callback_on_epoch_end(is_test=is_test) + self.monitor.write(self.cur_epoch) + + self.loss.zero() + for k in self.metrics: + self.metrics[k].zero() + def callback_on_epoch_end(self, epoch=None, is_test=False, info=None): if is_test: num_of_samples = self.one_epoch_test * self.accum_test * self.mbs_test From edde90b659991d71e873dca5712ad725243d2afa Mon Sep 17 00:00:00 2001 From: yuikosakuma1 Date: Thu, 22 Sep 2022 11:03:48 +0900 Subject: [PATCH 2/2] Remove one_epoch variable --- nnabla_nas/runner/searcher/ofa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nnabla_nas/runner/searcher/ofa.py b/nnabla_nas/runner/searcher/ofa.py index d9f0d51d..5d176ee1 100644 --- a/nnabla_nas/runner/searcher/ofa.py +++ b/nnabla_nas/runner/searcher/ofa.py @@ -173,7 +173,6 @@ def valid_on_batch(self, is_test=False): def valid_genotypes(self, mode='valid'): assert mode in ['valid', 'test'] - one_epoch = self.one_epoch_valid if mode == 'valid' else self.one_epoch_test is_test = True if mode == 'test' else False OFAResize.IS_TRAINING = False @@ -183,7 +182,8 @@ def valid_genotypes(self, mode='valid'): OFAResize.ACTIVE_SIZE = img_size self.model.set_valid_arch(genotype) self.reset_running_statistics() - for _ in tqdm(range(one_epoch), desc=f'{mode} [{self.cur_epoch}/{self.args["epoch"]}]'): + for _ in tqdm(range(self.one_epoch_valid if mode == 'valid' else self.one_epoch_test), + desc=f'{mode} [{self.cur_epoch}/{self.args["epoch"]}]'): self.update_graph(mode) self.valid_on_batch(is_test=is_test) clear_memory_cache()