Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save arch.json when save-nnp option is specified. #16

Merged
merged 1 commit into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion nnabla_nas/contrib/classification/darts/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import json
import os

from graphviz import Digraph
import imageio
from nnabla.logger import logger
import numpy as np
Expand All @@ -26,6 +25,7 @@


def plot(choice, prob, filename):
from graphviz import Digraph
g = Digraph(format='png',
edge_attr=dict(fontsize='14', fontname="times"),
node_attr=dict(style='filled', shape='rect', align='center',
Expand Down Expand Up @@ -109,4 +109,13 @@ def save_dart_arch(model, output_path):
arch_file = os.path.join(output_path, 'arch.json')
logger.info('Saving arch to {}'.format(arch_file))
write_to_json_file(memo, arch_file)


def visualize_dart_arch(output_path):
r"""Saves visualized DARTS architecture.

Args:
output_path (str): Where to save the architecture.
"""
arch_file = os.path.join(output_path, 'arch.json')
visualize(arch_file, output_path)
20 changes: 16 additions & 4 deletions nnabla_nas/contrib/classification/darts/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from collections import Counter
from collections import OrderedDict
import json
import os

import nnabla.functions as F
from nnabla.initializer import ConstantInitializer
Expand All @@ -25,7 +24,7 @@
from .... import module as Mo
from ..base import ClassificationModel as Model
from ..misc import AuxiliaryHeadCIFAR
from .helper import save_dart_arch
from .helper import save_dart_arch, visualize_dart_arch


class SearchNet(Model):
Expand Down Expand Up @@ -176,8 +175,21 @@ def save_parameters(self, path=None, params=None, grad_only=False):
super().save_parameters(path, params=params, grad_only=grad_only)
if self._shared:
# save the architectures
output_path = os.path.dirname(path)
save_dart_arch(self, output_path)
save_dart_arch(self, path)

def save_net_nnp(self, path, inp, out, calc_latency=False,
func_real_latency=None, func_accum_latency=None):
super().save_net_nnp(path, inp, out, calc_latency=False,
func_real_latency=func_real_latency,
func_accum_latency=func_accum_latency)
if self._shared:
# save the architectures
save_dart_arch(self, path)

def visualize(self, path):
if self._shared:
# save the architectures
visualize_dart_arch(path)

def loss(self, outputs, targets, loss_weights=None):
loss = F.mean(F.softmax_cross_entropy(outputs[0], targets[0]))
Expand Down
8 changes: 3 additions & 5 deletions nnabla_nas/contrib/classification/fairnas/network.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..base import ClassificationModel as Model
from .modules import ChoiceBlock
from ..mobilenet.modules import ConvBNReLU, CANDIDATES
from ..mobilenet.helper import plot_mobilenet
from ..mobilenet.helper import visualize_mobilenet_arch
from ..mobilenet.network import _make_divisible, label_smoothing_loss


Expand Down Expand Up @@ -146,12 +146,10 @@ def extra_repr(self):
f'candidates={self._candidates}, '
f'skip_connect={self._skip_connect}')

def save_parameters(self, path=None, params=None, grad_only=False):
super().save_parameters(path, params=params, grad_only=grad_only)
def visualize(self, path):
# save the architectures
if isinstance(self._features[2]._mixed, Mo.MixedOp):
output_path = os.path.dirname(path)
plot_mobilenet(self, os.path.join(output_path, 'arch'))
visualize_mobilenet_arch(self, os.path.join(path, 'arch'))

def loss(self, outputs, targets, loss_weights=None):
assert len(outputs) == 1 and len(targets) == 1
Expand Down
4 changes: 2 additions & 2 deletions nnabla_nas/contrib/classification/mobilenet/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import graphviz
import imageio

default_style = {
Expand Down Expand Up @@ -40,7 +39,8 @@ def get_width(label):
return '2'


def plot_mobilenet(model, filename):
def visualize_mobilenet_arch(model, filename):
import graphviz
r"""Plot the architecture of MobileNet V2

Args:
Expand Down
8 changes: 3 additions & 5 deletions nnabla_nas/contrib/classification/mobilenet/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .... import module as Mo
from ..base import ClassificationModel as Model
from .helper import plot_mobilenet
from .helper import visualize_mobilenet_arch
from .modules import CANDIDATES
from .modules import ChoiceBlock
from .modules import ConvBNReLU
Expand Down Expand Up @@ -234,12 +234,10 @@ def print_arch(arch_idx, op_names):
self._arch_idx = arch_idx
return txt + ''.join(stats)

def save_parameters(self, path=None, params=None, grad_only=False):
super().save_parameters(path, params=params, grad_only=grad_only)
def visualize(self, path):
# save the architectures
if isinstance(self._features[2]._mixed, Mo.MixedOp):
output_path = os.path.dirname(path)
plot_mobilenet(self, os.path.join(output_path, 'arch'))
visualize_mobilenet_arch(self, os.path.join(path, 'arch'))

def loss(self, outputs, targets, loss_weights=None):
assert len(outputs) == 1 and len(targets) == 1
Expand Down
8 changes: 8 additions & 0 deletions nnabla_nas/contrib/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,11 @@ def metrics(self, outputs, targets):
NotImplementedError: [description]
"""
raise NotImplementedError

def visualize(self, path):
r"""Save visualized graph to a file.

Args:
path (str): Path to directory to save.
"""
return
4 changes: 2 additions & 2 deletions nnabla_nas/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,12 @@ def save_net_nnp(self, path, inp, out, calc_latency=False,
hasattr(self, 'name') and self.name) else 'empty'
contents = {'networks': [{'name': name_for_nnp,
'batch_size': batch_size,
'outputs': {'y': out},
'outputs': {"y'": out},
'names': {'x': inp}}],
'executors': [{'name': 'runtime',
'network': name_for_nnp,
'data': ['x'],
'output': ['y']}]}
'output': ["y'"]}]}

save(filename, contents, variable_batch_size=False)

Expand Down
3 changes: 1 addition & 2 deletions nnabla_nas/module/static/static_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
# import nnabla_nas.module as mo
from ... import module as mo

from graphviz import Digraph


def _get_abs_string_index(obj, idx):
"""Get the absolute index for the list of modules"""
Expand Down Expand Up @@ -904,6 +902,7 @@ def get_gv_graph(self, active_only=True,
color_map (dict): the mapping of class instance to
vertice color used to visualize the graph.
"""
from graphviz import Digraph
graph = Digraph(name=self.name)
# 1. get all the static modules in the graph
if active_only:
Expand Down
2 changes: 2 additions & 0 deletions nnabla_nas/runner/searcher/fairnas.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def callback_on_epoch_end(self):
)
# checkpoint
self.save_checkpoint()
if self.args['no_visualize']: # action:store_false
self.model.visualize(self.args['output_path'])

# reset loss and metric
self.loss.zero()
Expand Down
5 changes: 5 additions & 0 deletions nnabla_nas/runner/searcher/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def callback_on_epoch_end(self):
)
# checkpoint
self.save_checkpoint()
if self.args['no_visualize']: # action:store_false
self.model.visualize(self.args['output_path'])

self.monitor.info(self.model.summary() + '\n')

def callback_on_finish(self):
Expand All @@ -88,6 +91,8 @@ def callback_on_finish(self):
path=os.path.join(self.args['output_path'], 'weights.h5'),
params=self.model.get_net_parameters()
)
if self.args['no_visualize']: # action:store_false
self.model.visualize(self.args['output_path'])

def callback_on_start(self):
r"""Calls this on starting the training."""
Expand Down
2 changes: 2 additions & 0 deletions nnabla_nas/runner/trainer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def callback_on_epoch_end(self):
self.model.save_parameters(path)
# checkpoint
self.save_checkpoint({'best_metric': self._best_metric})
if self.args['no_visualize']: # action:store_false
self.model.visualize(self.args['output_path'])

# reset loss and metric
self.loss.zero()
Expand Down
2 changes: 2 additions & 0 deletions nnabla_nas/utils/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def main():
help='Path to save the monitoring log files.')
parser.add_argument('--save-nnp', action='store_true',
help='Store network and parameter with nnp format.')
parser.add_argument('--no-visualize', action='store_false',
help='Disable visualization with graphviz.')

options = parser.parse_args()

Expand Down