Merge pull request #62 from sony/feature/20220818-fix-dynamicbn
Fix Dynamic BN for inefficient memory use
hyingho authored Dec 14, 2022
2 parents 34ecf2b + efc7f77 commit 2268b35
Showing 3 changed files with 54 additions and 22 deletions.
66 changes: 49 additions & 17 deletions nnabla_nas/contrib/common/ofa/elastic_nn/modules/dynamic_op.py
@@ -153,34 +153,66 @@ def __init__(self, max_feature_dim, n_dims):
self._scope_name = f'<dynamicbatchnorm2d at {hex(id(self))}>'

self._max_feature_dim = max_feature_dim
self.bn = Mo.BatchNormalization(self._max_feature_dim, n_dims)
self._n_dims = n_dims

self.bn = Mo.BatchNormalization(max_feature_dim, n_dims)
self.use_static_bn = True
self.set_running_statistics = False
self._prev_running_stats = None

def _update_running_stats(self):
"""
Note: This implementation is a workaround to avoid undesirable network
graph construction in static mode that causes inefficient memory use.
This method is called before the forward graph construction of each DynamicBN.
However, this means the update is missed in some cases where the model is used
before the next forward graph construction (e.g., saving parameters), unless
this method is called after the last iteration.
We decided to leave this issue as is, since skipping the last update
shouldn't affect the performance much.
This implementation could probably be improved by replacing this part
with F.assign applied after F.batch_normalization(...).
"""
if self._prev_running_stats is None:
return
bn = self.bn
smean, svar, feature_dim = self._prev_running_stats
self._prev_running_stats = None
channel_axis = 1
if feature_dim < bn._mean.shape[channel_axis]:
bn._mean.data.copy_from(F.concatenate(smean.data, bn._mean.data[:, feature_dim:, :, :], axis=1))
bn._var.data.copy_from(F.concatenate(svar.data, bn._var.data[:, feature_dim:, :, :], axis=1))
else:
bn._mean.data.copy_from(smean.data)
bn._var.data.copy_from(svar.data)

@staticmethod
def bn_forward(x, bn: Mo.BatchNormalization, max_feature_dim, feature_dim, training,
use_static_bn, set_running_statistics):
if use_static_bn or set_running_statistics:
return bn(x)
def call(self, input):
if self.use_static_bn or self.set_running_statistics:
return self.bn(input)
else:
assert not nn.get_auto_forward(), "This code block is verified with static mode only so far."
if self.training:
"""
Note: We decided to call self._update_running_stats() only in training mode.
For OFA, running this part in validation mode induces a larger loss because
reset_running_statistics() runs before this; running this part would overwrite
the re-calculated BN mean/var statistics.
"""
self._update_running_stats()
feature_dim = input.shape[1]
bn = self.bn
sbeta, sgamma = bn._beta[:, :feature_dim, :, :], bn._gamma[:, :feature_dim, :, :]
smean = nn.Variable(sbeta.shape)
svar = nn.Variable(sbeta.shape)
smean.data = bn._mean.data[:, :feature_dim, :, :]
svar.data = bn._var.data[:, :feature_dim, :, :]
y = F.batch_normalization(x, sbeta, sgamma, smean, svar, batch_stat=training,)
if training:
bn._mean = F.concatenate(smean, bn._mean[:, feature_dim:, :, :], axis=1)
bn._var = F.concatenate(svar, bn._var[:, feature_dim:, :, :], axis=1)
y = F.batch_normalization(input, sbeta, sgamma, smean, svar, batch_stat=self.training)
if self.training:
self._prev_running_stats = (smean, svar, feature_dim)
return y

def call(self, input):
feature_dim = input.shape[1]
y = self.bn_forward(
input, self.bn, self._max_feature_dim, feature_dim, self.training,
self.use_static_bn, self.set_running_statistics)
return y


class DynamicDepthwiseConv(Mo.Module):

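The docstring of _update_running_stats suggests that the deferred copy could be replaced by F.assign applied after F.batch_normalization(...). Below is a minimal sketch of that alternative, not the implementation adopted in this commit; it assumes nnabla's F.assign / F.sink functions, and the helper name bn_with_inplace_stats_update is hypothetical.

import nnabla as nn
import nnabla.functions as F

def bn_with_inplace_stats_update(x, bn, feature_dim, training):
    # Sketch only: slice the affine parameters and running stats to the
    # currently active number of channels, as in DynamicBatchNorm2d.call().
    sbeta = bn._beta[:, :feature_dim, :, :]
    sgamma = bn._gamma[:, :feature_dim, :, :]
    smean = nn.Variable(sbeta.shape)
    svar = nn.Variable(sbeta.shape)
    smean.data = bn._mean.data[:, :feature_dim, :, :]
    svar.data = bn._var.data[:, :feature_dim, :, :]

    y = F.batch_normalization(x, sbeta, sgamma, smean, svar, batch_stat=training)

    if training:
        full_dim = bn._mean.shape[1]
        if feature_dim < full_dim:
            new_mean = F.concatenate(smean, bn._mean[:, feature_dim:, :, :], axis=1)
            new_var = F.concatenate(svar, bn._var[:, feature_dim:, :, :], axis=1)
        else:
            new_mean, new_var = smean, svar
        # F.assign writes the value back into the running-stat buffers when the
        # forward pass executes, so no deferred copy is needed. The outputs of
        # F.assign must be included in the forward computation (e.g., via
        # F.sink together with the loss) for the update to actually happen.
        stats_update = F.sink(F.assign(bn._mean, new_mean),
                              F.assign(bn._var, new_var))
        return y, stats_update
    return y, None

Whether this is preferable to the deferred copy depends on how the surrounding runner drives the forward pass; the commit keeps the deferred-copy workaround instead.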
7 changes: 5 additions & 2 deletions nnabla_nas/module/batchnorm.py
@@ -113,6 +113,10 @@ def __init__(self, n_features, n_dims, axes=[1], decay_rate=0.9, eps=1e-5,

def call(self, input):
if self.set_running_statistics:
"""
Note: this code block has been verified only with the
once-for-all (OFA) algorithm so far.
"""
batch_mean = F.mean(input, axis=(0, 2, 3), keepdims=True)
batch_var = F.mean(input ** 2, axis=(0, 2, 3),
keepdims=True) - batch_mean ** 2
@@ -122,8 +126,7 @@ def call(self, input):

_feature_dim = batch_mean.shape[1]
return F.batch_normalization(
input, self._beta[:, :_feature_dim, :,
:], self._gamma[:, :_feature_dim, :, :],
input, self._beta[:, :_feature_dim, :, :], self._gamma[:, :_feature_dim, :, :],
batch_mean, batch_var, decay_rate=self._decay_rate, eps=self._eps, batch_stat=False
)
else:
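For reference, the statistics computed in the set_running_statistics branch above are the plain per-channel batch mean and the E[x^2] - E[x]^2 variance over the (N, H, W) axes. A small numpy sketch of the same computation (array shapes are illustrative):

import numpy as np

x = np.random.randn(8, 16, 32, 32).astype(np.float32)  # (N, C, H, W) activations

# Per-channel batch statistics, matching F.mean(input, axis=(0, 2, 3), keepdims=True)
batch_mean = x.mean(axis=(0, 2, 3), keepdims=True)                           # shape (1, 16, 1, 1)
batch_var = (x ** 2).mean(axis=(0, 2, 3), keepdims=True) - batch_mean ** 2   # E[x^2] - E[x]^2

# These freshly measured statistics are then passed to
# F.batch_normalization(..., batch_stat=False), so normalization uses them
# directly instead of the stored running averages.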
3 changes: 0 additions & 3 deletions nnabla_nas/runner/searcher/ofa.py
@@ -20,7 +20,6 @@

import nnabla as nn
import nnabla.functions as F
from nnabla_ext.cuda import clear_memory_cache

from ... import contrib
from .search import Searcher
@@ -87,7 +86,6 @@ def run(self):
train_keys = [m.name for m in self.monitor.meters.values()
if 'train' in m.name]
self.monitor.display(i, key=train_keys)
clear_memory_cache()
if self.cur_epoch % self.hparams["validation_frequency"] == 0:
self.valid_genotypes(mode='valid')

@@ -186,7 +184,6 @@ def valid_genotypes(self, mode='valid'):
desc=f'{mode} [{self.cur_epoch}/{self.hparams["epoch"]}]'):
self.update_graph(mode)
self.valid_on_batch(is_test=is_test)
clear_memory_cache()
self.monitor.info(f'img_size={img_size}, genotype={genotype} \n')
self.callback_on_epoch_end(is_test=is_test)
self.monitor.write(self.cur_epoch)
