# Copyright (c) 2020 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nnabla as nn
import numpy as np

from .search import Searcher
from nnabla_nas.utils.estimator.latency import LatencyEstimator
from nnabla_nas.utils.estimator.latency import LatencyGraphEstimator


class ProxylessNasSearcher(Searcher):
    r"""ProxylessNAS: Direct Neural Architecture Search on Target Task and
    Hardware.
    """

    def callback_on_start(self):
        r"""Initialize the reward baseline and resume from a checkpoint
        if available."""
        # moving average of past rewards, used as the baseline subtracted
        # from each sampled reward in `valid_on_batch`
        self._reward = nn.NdArray.from_numpy_array(np.zeros((1,)))
        # load checkpoint if available
        self.load_checkpoint()

    def train_on_batch(self, key='train'):
        r"""Update the model parameters."""
        # sample a new candidate architecture and train its weights
        self.update_graph(key)
        params = self.model.get_net_parameters(grad_only=True)
        self.optimizer[key].set_parameters(params)
        bz, p = self.mbs_train, self.placeholder['train']
        self.optimizer[key].zero_grad()
        if self.comm.n_procs > 1:
            grads = [x.grad for x in params.values()]
            self.event.default_stream_synchronize()
        # accumulate gradients over `accum_train` minibatches
        for _ in range(self.accum_train):
            self._load_data(p, self.dataloader['train'].next())
            p['loss'].forward(clear_no_need_grad=True)
            for k, m in p['metrics'].items():
                m.forward(clear_buffer=True)
                self.monitor.update(f'{k}/train', m.d.copy(), bz)
            p['loss'].backward(clear_buffer=True)
            loss = p['loss'].d.copy()
            self.monitor.update('loss/train', loss * self.accum_train, bz)
        if self.comm.n_procs > 1:
            # average gradients across workers before the update
            self.comm.all_reduce(grads, division=True, inplace=False)
            self.event.add_default_stream_event()
        self.optimizer[key].update()

    def valid_on_batch(self):
        r"""Update the architecture parameters."""
        beta, n_iter = 0.9, 10
        bz, p = self.mbs_valid, self.placeholder['valid']
        valid_data = [self.dataloader['valid'].next()
                      for i in range(self.accum_valid)]
        rewards, grads = [], []
        if self.comm.n_procs > 1:
            self.event.default_stream_synchronize()
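        # Sample `n_iter` architectures with the model weights held fixed.
        # Each forward pass yields a reward; no backward pass is run here,
        # so the per-sample gradients copied from the architecture
        # parameters are the score-function terms populated when the graph
        # was sampled.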
        for _ in range(n_iter):
            reward = 0
            self.update_graph('valid')
            arch_params = self.model.get_arch_parameters(grad_only=True)
            self.optimizer['valid'].set_parameters(arch_params)
            for minibatch in valid_data:
                self._load_data(p, minibatch)
                p['loss'].forward(clear_buffer=True)
                for k, m in p['metrics'].items():
                    m.forward(clear_buffer=True)
                    self.monitor.update(f'{k}/valid', m.d.copy(), bz)
                loss = p['loss'].d.copy()
                # reward is validation accuracy, averaged over the
                # accumulated minibatches
                reward += (1 - p['metrics']['error'].d) / self.accum_valid
                self.monitor.update('loss/valid', loss * self.accum_valid, bz)
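            # Scale the reward by a soft latency penalty: each estimator
            # leaves the reward untouched while its estimate stays within
            # the bound, and multiplies it by (bound / value) ** weight once
            # the bound is exceeded, following the latency-aware objective
            # of ProxylessNAS.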
            for k, v in self.regularizer.items():
                if isinstance(v, LatencyGraphEstimator):
                    # estimate latency from the computation graph
                    inp = [nn.Variable((1,) + si[1:]) for si in
                           self.model.input_shapes]
                    out = self.model.call(*inp)
                    value = v.get_estimation(out)
                elif isinstance(v, LatencyEstimator):
                    # estimate latency module by module
                    value = v.get_estimation(self.model)
                else:
                    raise NotImplementedError(
                        f'unsupported regularizer: {type(v).__name__}')
                reward *= (min(1.0, v._bound / value)) ** v._weight
                self.monitor.update(k, value, 1)
            rewards.append(reward)
            grads.append([m.g.copy() for m in arch_params.values()])
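        # REINFORCE update: for every architecture parameter, average the
        # baseline-subtracted rewards weighted by the per-sample score
        # gradients collected above.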
        for j, m in enumerate(arch_params.values()):
            m.grad.zero()
            for i, r in enumerate(rewards):
                m.g += (r - self._reward.data) * grads[i][j] / n_iter
        # update the reward baseline as an exponential moving average of
        # the mean sampled reward
        self._reward.data = beta * sum(rewards) / n_iter + \
            (1 - beta) * self._reward.data
        if self.comm.n_procs > 1:
            # average the architecture gradients and the baseline across
            # workers
            self.comm.all_reduce(
                [x.grad for x in arch_params.values()],
                division=True,
                inplace=False
            )
            self.comm.all_reduce(self._reward, division=True, inplace=False)
            self.event.add_default_stream_event()
        self.monitor.update('reward', self._reward.data[0], self.bs_valid)
        self.optimizer['valid'].update()
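

if __name__ == '__main__':
    # Minimal, self-contained sketch (numpy only) of the baseline-subtracted
    # REINFORCE update performed in `valid_on_batch` above. The toy "policy"
    # is a softmax over three hypothetical candidate ops with synthetic
    # rewards; all names and numbers here are illustrative and not part of
    # the searcher API.
    rng = np.random.RandomState(0)
    alpha = np.zeros(3)                       # architecture parameters
    baseline, beta, n_iter, lr = 0.0, 0.9, 10, 0.1
    op_reward = np.array([0.2, 0.9, 0.5])     # synthetic reward per op

    for _ in range(100):
        probs = np.exp(alpha) / np.exp(alpha).sum()
        rewards, grads = [], []
        for _ in range(n_iter):
            i = rng.choice(3, p=probs)        # sample an architecture
            r = op_reward[i] + 0.05 * rng.randn()
            g = -probs.copy()                 # d log p(i) / d alpha
            g[i] += 1.0
            rewards.append(r)
            grads.append(g)
        # baseline-subtracted score-function gradient, as in valid_on_batch
        grad = sum((r - baseline) * g
                   for r, g in zip(rewards, grads)) / n_iter
        alpha += lr * grad                    # ascend the expected reward
        baseline = beta * np.mean(rewards) + (1 - beta) * baseline
    print('learned op probabilities:', np.round(probs, 3))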