Skip to content

Commit

Permalink
Merge pull request #16 from sony/feature/20220323-include-json
Browse files Browse the repository at this point in the history
save arch.json when save-nnp option is specified.
  • Loading branch information
TE-LukasMauch authored Apr 14, 2022
2 parents 3bbfb9b + bf5e437 commit 7b27524
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 21 deletions.
11 changes: 10 additions & 1 deletion nnabla_nas/contrib/classification/darts/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import json
import os

from graphviz import Digraph
import imageio
from nnabla.logger import logger
import numpy as np
Expand All @@ -26,6 +25,7 @@


def plot(choice, prob, filename):
from graphviz import Digraph
g = Digraph(format='png',
edge_attr=dict(fontsize='14', fontname="times"),
node_attr=dict(style='filled', shape='rect', align='center',
Expand Down Expand Up @@ -109,4 +109,13 @@ def save_dart_arch(model, output_path):
arch_file = os.path.join(output_path, 'arch.json')
logger.info('Saving arch to {}'.format(arch_file))
write_to_json_file(memo, arch_file)


def visualize_dart_arch(output_path):
r"""Saves visualized DARTS architecture.
Args:
output_path (str): Where to save the architecture.
"""
arch_file = os.path.join(output_path, 'arch.json')
visualize(arch_file, output_path)
20 changes: 16 additions & 4 deletions nnabla_nas/contrib/classification/darts/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from collections import Counter
from collections import OrderedDict
import json
import os

import nnabla.functions as F
from nnabla.initializer import ConstantInitializer
Expand All @@ -25,7 +24,7 @@
from .... import module as Mo
from ..base import ClassificationModel as Model
from ..misc import AuxiliaryHeadCIFAR
from .helper import save_dart_arch
from .helper import save_dart_arch, visualize_dart_arch


class SearchNet(Model):
Expand Down Expand Up @@ -176,8 +175,21 @@ def save_parameters(self, path=None, params=None, grad_only=False):
super().save_parameters(path, params=params, grad_only=grad_only)
if self._shared:
# save the architectures
output_path = os.path.dirname(path)
save_dart_arch(self, output_path)
save_dart_arch(self, path)

def save_net_nnp(self, path, inp, out, calc_latency=False,
func_real_latency=None, func_accum_latency=None):
super().save_net_nnp(path, inp, out, calc_latency=False,
func_real_latency=func_real_latency,
func_accum_latency=func_accum_latency)
if self._shared:
# save the architectures
save_dart_arch(self, path)

def visualize(self, path):
if self._shared:
# save the architectures
visualize_dart_arch(path)

def loss(self, outputs, targets, loss_weights=None):
loss = F.mean(F.softmax_cross_entropy(outputs[0], targets[0]))
Expand Down
8 changes: 3 additions & 5 deletions nnabla_nas/contrib/classification/fairnas/network.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..base import ClassificationModel as Model
from .modules import ChoiceBlock
from ..mobilenet.modules import ConvBNReLU, CANDIDATES
from ..mobilenet.helper import plot_mobilenet
from ..mobilenet.helper import visualize_mobilenet_arch
from ..mobilenet.network import _make_divisible, label_smoothing_loss


Expand Down Expand Up @@ -146,12 +146,10 @@ def extra_repr(self):
f'candidates={self._candidates}, '
f'skip_connect={self._skip_connect}')

def save_parameters(self, path=None, params=None, grad_only=False):
super().save_parameters(path, params=params, grad_only=grad_only)
def visualize(self, path):
# save the architectures
if isinstance(self._features[2]._mixed, Mo.MixedOp):
output_path = os.path.dirname(path)
plot_mobilenet(self, os.path.join(output_path, 'arch'))
visualize_mobilenet_arch(self, os.path.join(path, 'arch'))

def loss(self, outputs, targets, loss_weights=None):
assert len(outputs) == 1 and len(targets) == 1
Expand Down
4 changes: 2 additions & 2 deletions nnabla_nas/contrib/classification/mobilenet/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import graphviz
import imageio

default_style = {
Expand Down Expand Up @@ -40,7 +39,8 @@ def get_width(label):
return '2'


def plot_mobilenet(model, filename):
def visualize_mobilenet_arch(model, filename):
import graphviz
r"""Plot the architecture of MobileNet V2
Args:
Expand Down
8 changes: 3 additions & 5 deletions nnabla_nas/contrib/classification/mobilenet/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .... import module as Mo
from ..base import ClassificationModel as Model
from .helper import plot_mobilenet
from .helper import visualize_mobilenet_arch
from .modules import CANDIDATES
from .modules import ChoiceBlock
from .modules import ConvBNReLU
Expand Down Expand Up @@ -234,12 +234,10 @@ def print_arch(arch_idx, op_names):
self._arch_idx = arch_idx
return txt + ''.join(stats)

def save_parameters(self, path=None, params=None, grad_only=False):
super().save_parameters(path, params=params, grad_only=grad_only)
def visualize(self, path):
# save the architectures
if isinstance(self._features[2]._mixed, Mo.MixedOp):
output_path = os.path.dirname(path)
plot_mobilenet(self, os.path.join(output_path, 'arch'))
visualize_mobilenet_arch(self, os.path.join(path, 'arch'))

def loss(self, outputs, targets, loss_weights=None):
assert len(outputs) == 1 and len(targets) == 1
Expand Down
8 changes: 8 additions & 0 deletions nnabla_nas/contrib/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,11 @@ def metrics(self, outputs, targets):
NotImplementedError: [description]
"""
raise NotImplementedError

def visualize(self, path):
r"""Save visualized graph to a file.
Args:
path (str): Path to directory to save.
"""
return
4 changes: 2 additions & 2 deletions nnabla_nas/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,12 @@ def save_net_nnp(self, path, inp, out, calc_latency=False,
hasattr(self, 'name') and self.name) else 'empty'
contents = {'networks': [{'name': name_for_nnp,
'batch_size': batch_size,
'outputs': {'y': out},
'outputs': {"y'": out},
'names': {'x': inp}}],
'executors': [{'name': 'runtime',
'network': name_for_nnp,
'data': ['x'],
'output': ['y']}]}
'output': ["y'"]}]}

save(filename, contents, variable_batch_size=False)

Expand Down
3 changes: 1 addition & 2 deletions nnabla_nas/module/static/static_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
# import nnabla_nas.module as mo
from ... import module as mo

from graphviz import Digraph


def _get_abs_string_index(obj, idx):
"""Get the absolute index for the list of modules"""
Expand Down Expand Up @@ -904,6 +902,7 @@ def get_gv_graph(self, active_only=True,
color_map (dict): the mapping of class instance to
vertice color used to visualize the graph.
"""
from graphviz import Digraph
graph = Digraph(name=self.name)
# 1. get all the static modules in the graph
if active_only:
Expand Down
2 changes: 2 additions & 0 deletions nnabla_nas/runner/searcher/fairnas.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def callback_on_epoch_end(self):
)
# checkpoint
self.save_checkpoint()
if self.args['no_visualize']: # action:store_false
self.model.visualize(self.args['output_path'])

# reset loss and metric
self.loss.zero()
Expand Down
5 changes: 5 additions & 0 deletions nnabla_nas/runner/searcher/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def callback_on_epoch_end(self):
)
# checkpoint
self.save_checkpoint()
if self.args['no_visualize']: # action:store_false
self.model.visualize(self.args['output_path'])

self.monitor.info(self.model.summary() + '\n')

def callback_on_finish(self):
Expand All @@ -88,6 +91,8 @@ def callback_on_finish(self):
path=os.path.join(self.args['output_path'], 'weights.h5'),
params=self.model.get_net_parameters()
)
if self.args['no_visualize']: # action:store_false
self.model.visualize(self.args['output_path'])

def callback_on_start(self):
r"""Calls this on starting the training."""
Expand Down
2 changes: 2 additions & 0 deletions nnabla_nas/runner/trainer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def callback_on_epoch_end(self):
self.model.save_parameters(path)
# checkpoint
self.save_checkpoint({'best_metric': self._best_metric})
if self.args['no_visualize']: # action:store_false
self.model.visualize(self.args['output_path'])

# reset loss and metric
self.loss.zero()
Expand Down
2 changes: 2 additions & 0 deletions nnabla_nas/utils/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def main():
help='Path to save the monitoring log files.')
parser.add_argument('--save-nnp', action='store_true',
help='Store network and parameter with nnp format.')
parser.add_argument('--no-visualize', action='store_false',
help='Disable visualization with graphviz.')

options = parser.parse_args()

Expand Down

0 comments on commit 7b27524

Please sign in to comment.