From bf5e437906b0d3168b949c33f2a6514c3746ef25 Mon Sep 17 00:00:00 2001 From: Tomonobu Tsujikawa Date: Wed, 23 Mar 2022 18:31:23 +0900 Subject: [PATCH] save arch.json when save-nnp option is specified. --- .../contrib/classification/darts/helper.py | 11 +++++++++- .../contrib/classification/darts/network.py | 20 +++++++++++++++---- .../contrib/classification/fairnas/network.py | 8 +++----- .../classification/mobilenet/helper.py | 4 ++-- .../classification/mobilenet/network.py | 8 +++----- nnabla_nas/contrib/model.py | 8 ++++++++ nnabla_nas/module/module.py | 4 ++-- nnabla_nas/module/static/static_module.py | 3 +-- nnabla_nas/runner/searcher/fairnas.py | 2 ++ nnabla_nas/runner/searcher/search.py | 5 +++++ nnabla_nas/runner/trainer/train.py | 2 ++ nnabla_nas/utils/cli/cli.py | 2 ++ 12 files changed, 56 insertions(+), 21 deletions(-) mode change 100755 => 100644 nnabla_nas/contrib/classification/fairnas/network.py diff --git a/nnabla_nas/contrib/classification/darts/helper.py b/nnabla_nas/contrib/classification/darts/helper.py index 5fb33278..4fb3ecc3 100644 --- a/nnabla_nas/contrib/classification/darts/helper.py +++ b/nnabla_nas/contrib/classification/darts/helper.py @@ -15,7 +15,6 @@ import json import os -from graphviz import Digraph import imageio from nnabla.logger import logger import numpy as np @@ -26,6 +25,7 @@ def plot(choice, prob, filename): + from graphviz import Digraph g = Digraph(format='png', edge_attr=dict(fontsize='14', fontname="times"), node_attr=dict(style='filled', shape='rect', align='center', @@ -109,4 +109,13 @@ def save_dart_arch(model, output_path): arch_file = os.path.join(output_path, 'arch.json') logger.info('Saving arch to {}'.format(arch_file)) write_to_json_file(memo, arch_file) + + +def visualize_dart_arch(output_path): + r"""Saves visualized DARTS architecture. + + Args: + output_path (str): Where to save the architecture. + """ + arch_file = os.path.join(output_path, 'arch.json') visualize(arch_file, output_path) diff --git a/nnabla_nas/contrib/classification/darts/network.py b/nnabla_nas/contrib/classification/darts/network.py index 11461ff7..d8cbe01b 100644 --- a/nnabla_nas/contrib/classification/darts/network.py +++ b/nnabla_nas/contrib/classification/darts/network.py @@ -15,7 +15,6 @@ from collections import Counter from collections import OrderedDict import json -import os import nnabla.functions as F from nnabla.initializer import ConstantInitializer @@ -25,7 +24,7 @@ from .... import module as Mo from ..base import ClassificationModel as Model from ..misc import AuxiliaryHeadCIFAR -from .helper import save_dart_arch +from .helper import save_dart_arch, visualize_dart_arch class SearchNet(Model): @@ -176,8 +175,21 @@ def save_parameters(self, path=None, params=None, grad_only=False): super().save_parameters(path, params=params, grad_only=grad_only) if self._shared: # save the architectures - output_path = os.path.dirname(path) - save_dart_arch(self, output_path) + save_dart_arch(self, path) + + def save_net_nnp(self, path, inp, out, calc_latency=False, + func_real_latency=None, func_accum_latency=None): + super().save_net_nnp(path, inp, out, calc_latency=False, + func_real_latency=func_real_latency, + func_accum_latency=func_accum_latency) + if self._shared: + # save the architectures + save_dart_arch(self, path) + + def visualize(self, path): + if self._shared: + # save the architectures + visualize_dart_arch(path) def loss(self, outputs, targets, loss_weights=None): loss = F.mean(F.softmax_cross_entropy(outputs[0], targets[0])) diff --git a/nnabla_nas/contrib/classification/fairnas/network.py b/nnabla_nas/contrib/classification/fairnas/network.py old mode 100755 new mode 100644 index 8b8039e4..8f26a30e --- a/nnabla_nas/contrib/classification/fairnas/network.py +++ b/nnabla_nas/contrib/classification/fairnas/network.py @@ -8,7 +8,7 @@ from ..base import ClassificationModel as Model from .modules import ChoiceBlock from ..mobilenet.modules import ConvBNReLU, CANDIDATES -from ..mobilenet.helper import plot_mobilenet +from ..mobilenet.helper import visualize_mobilenet_arch from ..mobilenet.network import _make_divisible, label_smoothing_loss @@ -146,12 +146,10 @@ def extra_repr(self): f'candidates={self._candidates}, ' f'skip_connect={self._skip_connect}') - def save_parameters(self, path=None, params=None, grad_only=False): - super().save_parameters(path, params=params, grad_only=grad_only) + def visualize(self, path): # save the architectures if isinstance(self._features[2]._mixed, Mo.MixedOp): - output_path = os.path.dirname(path) - plot_mobilenet(self, os.path.join(output_path, 'arch')) + visualize_mobilenet_arch(self, os.path.join(path, 'arch')) def loss(self, outputs, targets, loss_weights=None): assert len(outputs) == 1 and len(targets) == 1 diff --git a/nnabla_nas/contrib/classification/mobilenet/helper.py b/nnabla_nas/contrib/classification/mobilenet/helper.py index 05459cb7..43808b1d 100644 --- a/nnabla_nas/contrib/classification/mobilenet/helper.py +++ b/nnabla_nas/contrib/classification/mobilenet/helper.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import graphviz import imageio default_style = { @@ -40,7 +39,8 @@ def get_width(label): return '2' -def plot_mobilenet(model, filename): +def visualize_mobilenet_arch(model, filename): + import graphviz r"""Plot the architecture of MobileNet V2 Args: diff --git a/nnabla_nas/contrib/classification/mobilenet/network.py b/nnabla_nas/contrib/classification/mobilenet/network.py index d3b8d611..e5200cbb 100644 --- a/nnabla_nas/contrib/classification/mobilenet/network.py +++ b/nnabla_nas/contrib/classification/mobilenet/network.py @@ -21,7 +21,7 @@ from .... import module as Mo from ..base import ClassificationModel as Model -from .helper import plot_mobilenet +from .helper import visualize_mobilenet_arch from .modules import CANDIDATES from .modules import ChoiceBlock from .modules import ConvBNReLU @@ -234,12 +234,10 @@ def print_arch(arch_idx, op_names): self._arch_idx = arch_idx return txt + ''.join(stats) - def save_parameters(self, path=None, params=None, grad_only=False): - super().save_parameters(path, params=params, grad_only=grad_only) + def visualize(self, path): # save the architectures if isinstance(self._features[2]._mixed, Mo.MixedOp): - output_path = os.path.dirname(path) - plot_mobilenet(self, os.path.join(output_path, 'arch')) + visualize_mobilenet_arch(self, os.path.join(path, 'arch')) def loss(self, outputs, targets, loss_weights=None): assert len(outputs) == 1 and len(targets) == 1 diff --git a/nnabla_nas/contrib/model.py b/nnabla_nas/contrib/model.py index a357b020..61ed2d74 100644 --- a/nnabla_nas/contrib/model.py +++ b/nnabla_nas/contrib/model.py @@ -95,3 +95,11 @@ def metrics(self, outputs, targets): NotImplementedError: [description] """ raise NotImplementedError + + def visualize(self, path): + r"""Save visualized graph to a file. + + Args: + path (str): Path to directory to save. + """ + return diff --git a/nnabla_nas/module/module.py b/nnabla_nas/module/module.py index b0d070c8..3023f45b 100644 --- a/nnabla_nas/module/module.py +++ b/nnabla_nas/module/module.py @@ -312,12 +312,12 @@ def save_net_nnp(self, path, inp, out, calc_latency=False, hasattr(self, 'name') and self.name) else 'empty' contents = {'networks': [{'name': name_for_nnp, 'batch_size': batch_size, - 'outputs': {'y': out}, + 'outputs': {"y'": out}, 'names': {'x': inp}}], 'executors': [{'name': 'runtime', 'network': name_for_nnp, 'data': ['x'], - 'output': ['y']}]} + 'output': ["y'"]}]} save(filename, contents, variable_batch_size=False) diff --git a/nnabla_nas/module/static/static_module.py b/nnabla_nas/module/static/static_module.py index de98d27d..34f2c562 100644 --- a/nnabla_nas/module/static/static_module.py +++ b/nnabla_nas/module/static/static_module.py @@ -26,8 +26,6 @@ # import nnabla_nas.module as mo from ... import module as mo -from graphviz import Digraph - def _get_abs_string_index(obj, idx): """Get the absolute index for the list of modules""" @@ -904,6 +902,7 @@ def get_gv_graph(self, active_only=True, color_map (dict): the mapping of class instance to vertice color used to visualize the graph. """ + from graphviz import Digraph graph = Digraph(name=self.name) # 1. get all the static modules in the graph if active_only: diff --git a/nnabla_nas/runner/searcher/fairnas.py b/nnabla_nas/runner/searcher/fairnas.py index 1f3b58bb..5eb4c4b6 100755 --- a/nnabla_nas/runner/searcher/fairnas.py +++ b/nnabla_nas/runner/searcher/fairnas.py @@ -166,6 +166,8 @@ def callback_on_epoch_end(self): ) # checkpoint self.save_checkpoint() + if self.args['no_visualize']: # action:store_false + self.model.visualize(self.args['output_path']) # reset loss and metric self.loss.zero() diff --git a/nnabla_nas/runner/searcher/search.py b/nnabla_nas/runner/searcher/search.py index eac3c790..4d420471 100644 --- a/nnabla_nas/runner/searcher/search.py +++ b/nnabla_nas/runner/searcher/search.py @@ -73,6 +73,9 @@ def callback_on_epoch_end(self): ) # checkpoint self.save_checkpoint() + if self.args['no_visualize']: # action:store_false + self.model.visualize(self.args['output_path']) + self.monitor.info(self.model.summary() + '\n') def callback_on_finish(self): @@ -88,6 +91,8 @@ def callback_on_finish(self): path=os.path.join(self.args['output_path'], 'weights.h5'), params=self.model.get_net_parameters() ) + if self.args['no_visualize']: # action:store_false + self.model.visualize(self.args['output_path']) def callback_on_start(self): r"""Calls this on starting the training.""" diff --git a/nnabla_nas/runner/trainer/train.py b/nnabla_nas/runner/trainer/train.py index 93a61bd2..deefe189 100644 --- a/nnabla_nas/runner/trainer/train.py +++ b/nnabla_nas/runner/trainer/train.py @@ -151,6 +151,8 @@ def callback_on_epoch_end(self): self.model.save_parameters(path) # checkpoint self.save_checkpoint({'best_metric': self._best_metric}) + if self.args['no_visualize']: # action:store_false + self.model.visualize(self.args['output_path']) # reset loss and metric self.loss.zero() diff --git a/nnabla_nas/utils/cli/cli.py b/nnabla_nas/utils/cli/cli.py index eecb3aaa..10ec9584 100755 --- a/nnabla_nas/utils/cli/cli.py +++ b/nnabla_nas/utils/cli/cli.py @@ -62,6 +62,8 @@ def main(): help='Path to save the monitoring log files.') parser.add_argument('--save-nnp', action='store_true', help='Store network and parameter with nnp format.') + parser.add_argument('--no-visualize', action='store_false', + help='Disable visualization with graphviz.') options = parser.parse_args()