Commit

Fix (graph_eq): correct handling of already equalized BN
Giuseppe5 committed Mar 31, 2023
1 parent 191206b commit 04a194f
Showing 2 changed files with 68 additions and 44 deletions.
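
In the previous implementation, _equalize_bn unconditionally registered orig_bias/orig_weight, deleted the bias parameter and swapped in a dynamically created *Equalized subclass, so running equalization more than once on the same BatchNorm re-applied all of those steps. The new version checks the class name first and, for an already equalized module, only refreshes the stored tensors. A minimal, hypothetical sketch of that guard pattern (apply_once and factors are illustrative names, not Brevitas API):

import torch
import torch.nn as nn


def apply_once(module: nn.Module, factors: torch.Tensor):
    # Guard on the class name: rewrite the class only the first time around,
    # otherwise just refresh the stored state in place.
    if module.__class__.__name__.endswith('Equalized'):
        module.factors.copy_(factors)
        return
    module.register_buffer('factors', factors.clone())
    child = type(module.__class__.__name__ + 'Equalized', (module.__class__,), {})
    module.__class__ = child


m = nn.BatchNorm1d(4)
apply_once(m, torch.ones(4) * 2)
apply_once(m, torch.ones(4) * 3)  # second call: no re-registration, no stacked subclass
print(type(m).__name__)           # BatchNorm1dEqualized
print(m.factors)                  # tensor([3., 3., 3., 3.])
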
110 changes: 67 additions & 43 deletions src/brevitas/graph/equalize.py
@@ -65,48 +65,67 @@


def _equalize_bn(bn_module: nn.Module, scaling_factors: torch.Tensor):
    class_name = bn_module.__class__.__name__ + 'Equalized'
    bn_module.register_parameter('orig_bias', nn.Parameter(bn_module.bias.clone().detach()))
    bn_module.register_parameter('orig_weight', nn.Parameter(bn_module.weight.clone().detach()))
    bn_module.register_buffer('scaling_factors', scaling_factors.clone().detach())
    bn_module.register_buffer('inverse_scaling_factors', torch.ones_like(bn_module.orig_bias))

    del bn_module.bias

    def new_bias(self):
        return self.inverse_scaling_factors * \
            (self.running_mean.data * self.orig_weight / torch.sqrt(self.running_var + self.eps) \
             * (self.scaling_factors - 1) + self.orig_bias)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        output_dict = super(self.__class__, self).state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)
        output_dict[prefix + 'bias'] = self.bias

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):

        equalized_bn_key = prefix + 'orig_weight'
        is_equalized_bn = equalized_bn_key in state_dict
        if not is_equalized_bn:
            self.scaling_factors.fill_(1.)
            self.inverse_scaling_factors.fill_(1.)
            state_dict[prefix + 'orig_bias'] = state_dict[prefix + 'bias']
            state_dict[prefix + 'orig_weight'] = state_dict[prefix + 'weight']
        super(self.__class__, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        if not is_equalized_bn:
            missing_keys.remove(prefix + 'scaling_factors')
            missing_keys.remove(prefix + 'inverse_scaling_factors')
            unexpected_keys.remove(prefix + 'bias')

    var = {
        'bias': property(new_bias),
        'state_dict': state_dict,
        '_load_from_state_dict': _load_from_state_dict}
    child_class = type(class_name, (bn_module.__class__,), var)
    bn_module.__class__ = child_class
    class_name = bn_module.__class__.__name__
    if 'Equalized' in class_name:
        bn_module.orig_bias.data = bn_module.bias.clone()
        bn_module.orig_weight.data = bn_module.weight.data.clone()
        bn_module.scaling_factors = scaling_factors.clone().detach()
        bn_module.inverse_scaling_factors.fill_(1.)
    else:
        class_name = bn_module.__class__.__name__ + 'Equalized'
        bn_module.register_parameter('orig_bias', nn.Parameter(bn_module.bias.clone().detach()))
        bn_module.register_parameter('orig_weight', nn.Parameter(bn_module.weight.clone().detach()))
        bn_module.register_buffer('scaling_factors', scaling_factors.clone().detach())
        bn_module.register_buffer('inverse_scaling_factors', torch.ones_like(bn_module.orig_bias))

        del bn_module.bias

        def new_bias(self):
            return self.inverse_scaling_factors * \
                (self.running_mean.data * self.orig_weight / torch.sqrt(self.running_var + self.eps) \
                 * (self.scaling_factors - 1) + self.orig_bias)

        def state_dict(self, destination=None, prefix='', keep_vars=False):
            output_dict = super(self.__class__, self).state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars)
            output_dict[prefix + 'bias'] = self.bias

        def _load_from_state_dict(
                self,
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs):

            equalized_bn_key = prefix + 'orig_weight'
            is_equalized_bn = equalized_bn_key in state_dict
            if not is_equalized_bn:
                self.scaling_factors.fill_(1.)
                self.inverse_scaling_factors.fill_(1.)
                state_dict[prefix + 'orig_bias'] = state_dict[prefix + 'bias']
                state_dict[prefix + 'orig_weight'] = state_dict[prefix + 'weight']
            super(self.__class__, self)._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs)
            if not is_equalized_bn:
                missing_keys.remove(prefix + 'scaling_factors')
                missing_keys.remove(prefix + 'inverse_scaling_factors')
                unexpected_keys.remove(prefix + 'bias')

        var = {
            'bias': property(new_bias),
            'state_dict': state_dict,
            '_load_from_state_dict': _load_from_state_dict}
        child_class = type(class_name, (bn_module.__class__,), var)
        bn_module.__class__ = child_class


def _select_scale_computation_fn(
@@ -329,12 +348,17 @@ def _cross_layer_equalization(
        src_broadcast_size = [1] * module.weight.ndim
        src_broadcast_size[axis] = module.weight.size(axis)
        if isinstance(module, _batch_norm):
            # If it is BatchNorm, we need to define a new metaclass where the bias is a function of

            # If it is a BatchNorm that has not yet been equalized,
            # we need to define a new subclass where the bias is a function of
            # running_mean and running_var, as well as a function of the pre-equalized weight and bias.
            # If the stats are updated, the bias value will change accordingly.
            # If a new training run is performed, the weight and the pre-equalized bias and weight
            # will be learned.
            # If the BN has already been equalized, we simply perform an update of the parameters
            # and buffers.
            _equalize_bn(module, scaling_factors)

        module.weight.data = module.weight.data * torch.reshape(scaling_factors, src_broadcast_size)

    return scaling_factors
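
As a condensed, runnable sketch of the pattern the rewritten _equalize_bn implements (state-dict handling omitted; equalize_bn_sketch is an illustrative name, not Brevitas API): the bias becomes a property computed from the running statistics, so it tracks them as they update, and a second equalization pass only refreshes the stored tensors instead of re-subclassing.

import torch
import torch.nn as nn


def equalize_bn_sketch(bn: nn.Module, scaling_factors: torch.Tensor):
    if 'Equalized' in bn.__class__.__name__:
        # Already equalized: refresh the stored tensors in place.
        bn.orig_bias.data = bn.bias.detach().clone()
        bn.orig_weight.data = bn.weight.data.clone()
        bn.scaling_factors = scaling_factors.detach().clone()
        bn.inverse_scaling_factors.fill_(1.)
        return
    # First application: stash the pre-equalization parameters and swap in a
    # dynamically created subclass whose bias is recomputed on every access.
    bn.register_parameter('orig_bias', nn.Parameter(bn.bias.detach().clone()))
    bn.register_parameter('orig_weight', nn.Parameter(bn.weight.detach().clone()))
    bn.register_buffer('scaling_factors', scaling_factors.detach().clone())
    bn.register_buffer('inverse_scaling_factors', torch.ones_like(bn.orig_bias))
    del bn.bias  # the parameter is replaced by a property on the new subclass

    def new_bias(self):
        # Bias as a function of the running stats and the pre-equalized weight and bias.
        return self.inverse_scaling_factors * (
            self.running_mean * self.orig_weight / torch.sqrt(self.running_var + self.eps) *
            (self.scaling_factors - 1) + self.orig_bias)

    child = type(bn.__class__.__name__ + 'Equalized', (bn.__class__,), {'bias': property(new_bias)})
    bn.__class__ = child


bn = nn.BatchNorm1d(4)
equalize_bn_sketch(bn, torch.full((4,), 2.0))
bias_before = bn.bias.detach().clone()
bn.train()
bn(torch.randn(32, 4) + 3.0)                        # forward pass updates the running stats...
print(torch.equal(bias_before, bn.bias.detach()))   # ...so the bias changes: False
equalize_bn_sketch(bn, torch.full((4,), 0.5))       # re-equalizing only refreshes state
print(type(bn).__name__)                            # BatchNorm1dEqualized

The scaled weight itself is applied outside _equalize_bn (see the module.weight.data update above), so the sketch only mirrors the bias bookkeeping.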
2 changes: 1 addition & 1 deletion tests/brevitas/graph/test_equalization.py
@@ -155,7 +155,7 @@ def test_bn_stats_torchvision_models(model_coverage: tuple, merge_bias: bool):
            post_bn_stats.append(module.bias.data.clone())

    for pre_val, post_val in zip(pre_bn_stats, post_bn_stats):
        assert not torch.allclose(pre_val, post_val, atol=ATOL)
        assert not torch.allclose(pre_val, post_val)
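
For reference on the test change above: with an explicit atol, torch.allclose treats small but genuine shifts in the BN bias as equal, so the assert not torch.allclose(...) check could fail on small but real changes; with the default tolerances even small changes register as different. A small standalone illustration (the tensors here are made up; ATOL in the test module is defined elsewhere):

import torch

a = torch.tensor([1.0000, 2.0000])
b = torch.tensor([1.0005, 2.0005])

print(torch.allclose(a, b, atol=1e-3))  # True: a loose atol hides the small change
print(torch.allclose(a, b))             # False: default rtol=1e-5, atol=1e-8 detects it
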


@pytest_cases.parametrize("merge_bias", [True, False])
