Add support for qdense_batchnorm in QKeras #74

Open
wants to merge 10 commits into master
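A minimal usage sketch of the new layer (illustrative only; the model shape, quantizer settings, and freeze delay below are assumptions, not taken from this PR):

import tensorflow as tf
from qkeras import QDenseBatchnorm, quantized_bits

# Hypothetical model: a Dense layer fused with batchnorm, 8-bit weights.
inputs = tf.keras.Input(shape=(16,))
x = QDenseBatchnorm(
    32,
    kernel_quantizer=quantized_bits(8, 0, 1),
    bias_quantizer=quantized_bits(8, 0, 1),
    ema_freeze_delay=1000,  # freeze bn statistics after 1000 steps
    folding_mode="ema_stats_folding")(inputs)
outputs = tf.keras.layers.Activation("softmax")(x)
model = tf.keras.Model(inputs, outputs)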
1 change: 1 addition & 0 deletions qkeras/__init__.py
@@ -34,6 +34,7 @@
#from .qtools.settings import cfg
from .qconv2d_batchnorm import QConv2DBatchnorm
from .qdepthwiseconv2d_batchnorm import QDepthwiseConv2DBatchnorm
from .qdense_batchnorm import QDenseBatchnorm

assert tf.executing_eagerly(), "QKeras requires TF with eager execution mode on"

327 changes: 327 additions & 0 deletions qkeras/qdense_batchnorm.py
@@ -0,0 +1,327 @@
# Copyright 2020 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fold batchnormalization with previous QDense layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import warnings

import numpy as np
import six


from qkeras.qlayers import QDense
from qkeras.quantizers import *

import tensorflow.compat.v2 as tf
from tensorflow.keras import layers
from tensorflow.python.framework import smart_cond as tf_utils
from tensorflow.python.ops import math_ops

tf.compat.v2.enable_v2_behavior()


class QDenseBatchnorm(QDense):
"""Implements a quantized Dense layer fused with Batchnorm."""

def __init__(
self,
units,
activation=None,
use_bias=True,
kernel_initializer="he_normal",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
kernel_quantizer=None,
bias_quantizer=None,
kernel_range=None,
bias_range=None,

# batchnorm params
axis=-1,
momentum=0.99,
epsilon=0.001,
center=True,
scale=True,
beta_initializer="zeros",
gamma_initializer="ones",
moving_mean_initializer="zeros",
moving_variance_initializer="ones",
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
fused=None,
trainable=True,
virtual_batch_size=None,
adjustment=None,

# other params
ema_freeze_delay=None,
folding_mode="ema_stats_folding",
**kwargs):

super(QDenseBatchnorm, self).__init__(
units=units,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
kernel_quantizer=kernel_quantizer,
bias_quantizer=bias_quantizer,
kernel_range=kernel_range,
bias_range=bias_range,
**kwargs)

# initialization of batchnorm part of the composite layer
self.batchnorm = layers.BatchNormalization(
axis=axis, momentum=momentum, epsilon=epsilon, center=center,
scale=scale, beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
moving_mean_initializer=moving_mean_initializer,
moving_variance_initializer=moving_variance_initializer,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
beta_constraint=beta_constraint, gamma_constraint=gamma_constraint,
renorm=renorm, renorm_clipping=renorm_clipping,
renorm_momentum=renorm_momentum, fused=fused, trainable=trainable,
virtual_batch_size=virtual_batch_size, adjustment=adjustment
)

self.ema_freeze_delay = ema_freeze_delay
assert folding_mode in ["ema_stats_folding", "batch_stats_folding"]
self.folding_mode = folding_mode

def build(self, input_shape):
super(QDenseBatchnorm, self).build(input_shape)

    # self._iteration (i.e., training_steps) is initialized with -1. When
    # loading a ckpt, it restores the number of training steps that were
    # previously trained; otherwise training counts up from scratch.
# TODO(lishanok): develop a way to count iterations outside layer
self._iteration = tf.Variable(-1, trainable=False, name="iteration",
dtype=tf.int64)

def call(self, inputs, training=None):

    # Resolve the training flag; marks whether the layer is in training mode.
training = self.batchnorm._get_training_value(training) # pylint: disable=protected-access

    # Decide whether the batchnorm statistics should still be updated.
    if (self.ema_freeze_delay is None) or (self.ema_freeze_delay < 0):
      # if ema_freeze_delay is None or negative, never freeze the bn stats
bn_training = tf.cast(training, dtype=bool)
else:
bn_training = tf.math.logical_and(training, tf.math.less_equal(
self._iteration, self.ema_freeze_delay))
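    # Example (illustrative): with ema_freeze_delay=1000, the bn statistics
    # update for the first 1000 training steps and are frozen afterwards.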

kernel = self.kernel

    # compute the unfolded QDense output
qdense_outputs = tf.keras.backend.dot(
inputs,
kernel
)

if self.use_bias:
bias = self.bias
qdense_outputs = tf.keras.backend.bias_add(
qdense_outputs,
bias,
data_format="channels_last")
else:
bias = 0

    # Run batchnorm on the qdense output to update its moving statistics.
_ = self.batchnorm(qdense_outputs, training=bn_training)

self._iteration.assign_add(tf_utils.smart_cond(
training, lambda: tf.constant(1, tf.int64),
lambda: tf.constant(0, tf.int64)))

# calculate mean and variance from current batch
bn_shape = qdense_outputs.shape
ndims = len(bn_shape)
reduction_axes = [i for i in range(ndims) if i not in self.batchnorm.axis]
keep_dims = len(self.batchnorm.axis) > 1
mean, variance = self.batchnorm._moments( # pylint: disable=protected-access
math_ops.cast(qdense_outputs, self.batchnorm._param_dtype), # pylint: disable=protected-access
reduction_axes,
keep_dims=keep_dims)

# get batchnorm weights
gamma = self.batchnorm.gamma
beta = self.batchnorm.beta
moving_mean = self.batchnorm.moving_mean
moving_variance = self.batchnorm.moving_variance

if self.folding_mode == "batch_stats_folding":
# using batch mean and variance in the initial training stage
# after sufficient training, switch to moving mean and variance
new_mean = tf_utils.smart_cond(
bn_training, lambda: mean, lambda: moving_mean)
new_variance = tf_utils.smart_cond(
bn_training, lambda: variance, lambda: moving_variance)

# get the inversion factor so that we replace division by multiplication
inv = math_ops.rsqrt(new_variance + self.batchnorm.epsilon)
if gamma is not None:
inv *= gamma

# fold bias with bn stats
folded_bias = inv * (bias - new_mean) + beta

elif self.folding_mode == "ema_stats_folding":
# We always scale the weights with a correction factor to the long term
# statistics prior to quantization. This ensures that there is no jitter
# in the quantized weights due to batch to batch variation. During the
# initial phase of training, we undo the scaling of the weights so that
# outputs are identical to regular batch normalization. We also modify
# the bias terms correspondingly. After sufficient training, switch from
# using batch statistics to long term moving averages for batch
# normalization.
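      # In equations, with sigma = sqrt(variance + epsilon): the kernel is
      # always folded with the moving stats, w_fold = gamma * w / sigma_mv;
      # during training the output is additionally rescaled by
      # y_corr = sigma_mv / sigma_batch, which reproduces batchnorm computed
      # with batch statistics.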

      # use batch stats for calculating the bias before the bn freeze, and
      # moving stats after the freeze
mv_inv = math_ops.rsqrt(moving_variance + self.batchnorm.epsilon)
batch_inv = math_ops.rsqrt(variance + self.batchnorm.epsilon)

if gamma is not None:
mv_inv *= gamma
batch_inv *= gamma
folded_bias = tf_utils.smart_cond(
bn_training,
lambda: batch_inv * (bias - mean) + beta,
lambda: mv_inv * (bias - moving_mean) + beta)
      # Moving stats are always used to fold the kernel (as in TFLite);
      # before the bn freeze, an additional correction factor is applied to
      # the dense output below.
      inv = mv_inv
    else:
      raise ValueError(
          "folding_mode must be 'ema_stats_folding' or 'batch_stats_folding',"
          " got %s." % self.folding_mode)

# wrap dense kernel with bn parameters
    folded_kernel = inv * kernel
# quantize the folded kernel
if self.kernel_quantizer is not None:
q_folded_kernel = self.kernel_quantizer_internal(folded_kernel)
else:
q_folded_kernel = folded_kernel

    # quantize the folded bias
if self.bias_quantizer_internal is not None:
q_folded_bias = self.bias_quantizer_internal(folded_bias)
else:
q_folded_bias = folded_bias

applied_kernel = q_folded_kernel
applied_bias = q_folded_bias

    # compute the qdense output using the quantized folded kernel
folded_outputs = tf.keras.backend.dot(inputs, applied_kernel)

    if training is True and self.folding_mode == "ema_stats_folding":
y_corr = tf_utils.smart_cond(
bn_training,
lambda: (math_ops.sqrt(moving_variance + self.batchnorm.epsilon) *
math_ops.rsqrt(variance + self.batchnorm.epsilon)),
lambda: tf.constant(1.0, shape=moving_variance.shape))
folded_outputs = math_ops.mul(folded_outputs, y_corr)

folded_outputs = tf.keras.backend.bias_add(
folded_outputs,
applied_bias,
data_format="channels_last"
)

if self.activation is not None:
return self.activation(folded_outputs)

return folded_outputs

def get_config(self):
base_config = super().get_config()
bn_config = self.batchnorm.get_config()
config = {"ema_freeze_delay": self.ema_freeze_delay,
"folding_mode": self.folding_mode}
name = base_config["name"]
out_config = dict(
list(base_config.items())
+ list(bn_config.items()) + list(config.items()))

    # Names from the different configs override one another; use the base
    # layer's name as this layer's config name.
out_config["name"] = name
return out_config

def get_quantization_config(self):
return {
"kernel_quantizer": str(self.kernel_quantizer_internal),
"bias_quantizer": str(self.bias_quantizer_internal),
}

def get_quantizers(self):
return self.quantizers

# def get_prunable_weights(self):
# return [self.kernel]

def get_folded_weights(self):
"""Function to get the batchnorm folded weights.
This function converts the weights by folding batchnorm parameters into
the weight of QDense. The high-level equation:
W_fold = gamma * W / sqrt(variance + epsilon)
bias_fold = gamma * (bias - moving_mean) / sqrt(variance + epsilon) + beta
"""

kernel = self.kernel
if self.use_bias:
bias = self.bias
else:
bias = 0

# get batchnorm weights and moving stats
gamma = self.batchnorm.gamma
beta = self.batchnorm.beta
moving_mean = self.batchnorm.moving_mean
moving_variance = self.batchnorm.moving_variance

# get the inversion factor so that we replace division by multiplication
inv = math_ops.rsqrt(moving_variance + self.batchnorm.epsilon)
if gamma is not None:
inv *= gamma

    # fold the dense kernel and bias with the bn parameters
folded_kernel = inv * kernel
folded_bias = inv * (bias - moving_mean) + beta

return [folded_kernel, folded_bias]
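After training, the folded weights can be extracted for deployment. A short sketch (the model object and layer index are hypothetical, continuing the example above):

# Hypothetical: pull out the batchnorm-folded weights for inference.
layer = model.layers[1]  # the QDenseBatchnorm layer
folded_kernel, folded_bias = layer.get_folded_weights()
# The folded pair can be quantized and assigned to a plain QDense layer of
# matching shape, removing the separate batchnorm op at inference time.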
8 changes: 6 additions & 2 deletions qkeras/utils.py
@@ -36,6 +36,7 @@
from .qlayers import Clip
from .qconv2d_batchnorm import QConv2DBatchnorm
from .qdepthwiseconv2d_batchnorm import QDepthwiseConv2DBatchnorm
from .qdense_batchnorm import QDenseBatchnorm
from .qlayers import QActivation
from .qlayers import QAdaptiveActivation
from .qpooling import QAveragePooling2D
@@ -96,6 +97,7 @@
"QDepthwiseConv2DBatchnorm",
"QAveragePooling2D",
"QGlobalAveragePooling2D",
"QDenseBatchnorm",
]


@@ -264,7 +266,7 @@ def model_save_quantized_weights(model, filename=None):
hw_weights = []

if any(isinstance(layer, t) for t in [
QConv2DBatchnorm, QDepthwiseConv2DBatchnorm]):
QConv2DBatchnorm, QDenseBatchnorm, QDepthwiseConv2DBatchnorm]):
qs = layer.get_quantizers()
ws = layer.get_folded_weights()
elif any(isinstance(layer, t) for t in [QSimpleRNN, QLSTM, QGRU]):
@@ -380,7 +382,7 @@ def model_save_quantized_weights(model, filename=None):
if has_scale:
saved_weights[layer.name]["scales"] = scales
if not any(isinstance(layer, t) for t in [
QConv2DBatchnorm, QDepthwiseConv2DBatchnorm]):
QConv2DBatchnorm, QDenseBatchnorm, QDepthwiseConv2DBatchnorm]):
# Set layer weights in the format that software inference uses
layer.set_weights(weights)
else:
@@ -1056,6 +1058,8 @@ def _add_supported_quantized_objects(custom_objects):
custom_objects["QConv2DBatchnorm"] = QConv2DBatchnorm
custom_objects["QDepthwiseConv2DBatchnorm"] = QDepthwiseConv2DBatchnorm

custom_objects["QDenseBatchnorm"] = QDenseBatchnorm

custom_objects["QAveragePooling2D"] = QAveragePooling2D
custom_objects["QGlobalAveragePooling2D"] = QGlobalAveragePooling2D
custom_objects["QScaleShift"] = QScaleShift