diff --git a/CHANGELOG b/CHANGELOG
index 69c4cc86..398ad913 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,2 +1,3 @@
-v0.5, 07/18/2019 -- Initial release.
-v0.6, 12/03/2019 -- Support tensorflow 2.0 and tf.keras
+v0.5, 2019/07 -- Initial release.
+v0.6, 2020/03 -- Support tensorflow 2.0, tf.keras and python3.
+v0.7, 2020/03 -- Enhancement of binary and ternary quantization.
diff --git a/README.md b/README.md
index 4864114b..060afb62 100644
--- a/README.md
+++ b/README.md
@@ -137,6 +137,7 @@
 that accept an alpha parameter, we need to specify a range of alpha, and
 for po2 type of quantizers, we need to specify the range of max_value.
 
+
 ### Example
 
 Suppose you have the following network.
diff --git a/examples/example_ternary.py b/examples/example_ternary.py
new file mode 100644
index 00000000..9f11dc66
--- /dev/null
+++ b/examples/example_ternary.py
@@ -0,0 +1,113 @@
+# Copyright 2020 Google LLC
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import  # Not necessary in a Python 3-only module
+from __future__ import division  # Not necessary in a Python 3-only module
+from __future__ import print_function  # Not necessary in a Python 3-only module
+
+from absl import app
+from absl import flags
+import matplotlib
+import numpy as np
+
+matplotlib.use('TkAgg')
+import matplotlib.pyplot as plt
+
+
+FLAGS = flags.FLAGS
+
+
+def _stochastic_rounding(x, precision, resolution, delta):
+  """Stochastic rounding for ternary quantization.
+
+  Args:
+    x: input array; values are expected to lie in the region
+      [delta - precision, delta + precision] around the discontinuity.
+    precision: A float. Half-width of the region around delta where we
+      apply stochastic rounding.
+    resolution: controls the quantization resolution.
+    delta: the discontinuity point (a positive number).
+
+  Returns:
+    A tensor with stochastically rounded numbers.
+ """ + delta_left = delta - precision + delta_right = delta + precision + scale = 1 / resolution + scale_delta_left = delta_left * scale + scale_delta_right = delta_right * scale + scale_2_delta = scale_delta_right - scale_delta_left + scale_x = x * scale + fraction = scale_x - scale_delta_left + # print(precision, scale, x[0], np.floor(scale_x[0]), scale_x[0], fraction[0]) + + # we use uniform distribution + random_selector = np.random.uniform(0, 1, size=x.shape) * scale_2_delta + + # print(precision, scale, x[0], delta_left[0], delta_right[0]) + # print('x', scale_x[0], fraction[0], random_selector[0], scale_2_delta[0]) + # rounddown = fraction < random_selector + result = np.where(fraction < random_selector, + scale_delta_left / scale, + scale_delta_right / scale) + return result + + +def _ternary(x, sto=False): + m = np.amax(np.abs(x), keepdims=True) + scale = 2 * m / 3.0 + thres = scale / 2.0 + ratio = 0.1 + + if sto: + sign_bit = np.sign(x) + x = np.abs(x) + prec = x / scale + x = ( + sign_bit * scale * _stochastic_rounding( + x / scale, + precision=0.3, resolution=0.01, # those two are all normalized. + delta=thres / scale)) + # prec + prec *ratio) + # mm = np.amax(np.abs(x), keepdims=True) + return np.where(np.abs(x) < thres, np.zeros_like(x), np.sign(x)) + + +def main(argv): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + # x = np.arange(-3.0, 3.0, 0.01) + # x = np.random.uniform(-0.01, 0.01, size=1000) + x = np.random.uniform(-10.0, 10.0, size=1000) + # x = np.random.uniform(-1, 1, size=1000) + x = np.sort(x) + tr = np.zeros_like(x) + t = np.zeros_like(x) + iter_count = 500 + for _ in range(iter_count): + y = _ternary(x) + yr = _ternary(x, sto=True) + t = t + y + tr = tr + yr + + plt.plot(x, t/iter_count) + plt.plot(x, tr/iter_count) + plt.ylabel('mean (%s samples)' % iter_count) + plt.show() + + +if __name__ == '__main__': + app.run(main) diff --git a/notebook/QKerasTutorial.ipynb b/notebook/QKerasTutorial.ipynb new file mode 100644 index 00000000..cc115ea7 --- /dev/null +++ b/notebook/QKerasTutorial.ipynb @@ -0,0 +1,869 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "##### Copyright 2020 Google LLC\n", + "#\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# QKeras Lab Book" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "__QKeras__ is a quantization extension to Keras that provides drop-in replacement for some of the Keras layers, especially the ones that creates parameters and activation layers, and perform arithmetic operations, so that we can quickly create a deep quantized version of Keras network.\n", + "\n", + "According to Tensorflow documentation, Keras is a high-level API to build and train deep learning models. 
It's used for fast prototyping, advanced research, and production, with three key advantages:\n", + "\n", + "- User friendly
\n", + "Keras has a simple, consistent interface optimized for common use cases. It provides clear and actionable feedback for user errors.\n", + "\n", + "- Modular and composable
\n", + "Keras models are made by connecting configurable building blocks together, with few restrictions.\n", + "\n", + "- Easy to extend
\n",
+    "Write custom building blocks to express new ideas for research. Create new layers, loss functions, and develop state-of-the-art models.\n",
+    "\n",
+    "__QKeras__ is being designed to extend the functionality of Keras following Keras' design principles, i.e. being user friendly, modular and extensible, while also being \"minimally intrusive\" of Keras native functionality.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Related Work\n",
+    "\n",
+    "__QKeras__ has been implemented based on the work of _\"B. Moons et al. - Minimum Energy Quantized Neural Networks\"_, Asilomar Conference on Signals, Systems and Computers, 2017, and _\"Zhou, S. et al. DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients,\"_ but the framework should be easily extensible. The original code from QNN can be found below.\n",
+    "\n",
+    "https://github.com/BertMoons/QuantizedNeuralNetworks-Keras-Tensorflow\n",
+    "\n",
+    "__QKeras__ extends QNN by providing a richer set of layers (including SeparableConv2D, DepthwiseConv2D, ternary and stochastic ternary quantizations), besides some functions to aid the estimation of the accumulators and the conversion from non-quantized to quantized networks. Finally, our main goal is ease of use, so we attempt to make QKeras layers a true drop-in replacement for Keras, so that users can easily exchange non-quantized layers for quantized ones."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Layers Implemented in QKeras\n",
+    "\n",
+    "The following layers have been implemented in __QKeras__.\n",
+    "\n",
+    "- __`QDense`__\n",
+    "\n",
+    "- __`QConv1D`__\n",
+    "\n",
+    "- __`QConv2D`__\n",
+    "\n",
+    "- __`QDepthwiseConv2D`__\n",
+    "\n",
+    "- __`QSeparableConv2D`__ (depthwise + pointwise expanded, extended from the MobileNet SeparableConv2D implementation)\n",
+    "\n",
+    "- __`QActivation`__\n",
+    "\n",
+    "- __`QAveragePooling2D`__ (in fact, an AveragePooling2D stacked with a QActivation layer for quantization of the result, so this layer does not really exist)\n",
+    "\n",
+    "- __`QBatchNormalization`__\n",
+    "\n",
+    "- __`QOctaveConv2D`__\n",
+    "\n",
+    "It is worth noting that not all functionality is safe at this time to be used with other high-level operations, such as layer wrappers. For example, `Bidirectional` layer wrappers are used with RNNs. If this is required, we encourage users to invoke quantization functions as strings instead of passing the actual functions as a workaround, but we may change that implementation in the future.\n",
+    "\n",
+    "__`QSeparableConv2D`__ is implemented as a depthwise + pointwise quantized expansion, which is extended from the `SeparableConv2D` implementation of MobileNet. With the exception of __`QBatchNormalization`__, if quantizers are not specified, no quantization is applied to the layer and it ends up behaving like the original unquantized layer. __`QBatchNormalization`__, on the other hand, has been implemented differently: if the user does not specify any quantizers as parameters, it uses a setup that has worked best when attempting to implement quantization efficiently in hardware and software, i.e. 
`gamma` and `variance` with po2 quantizers (as they become shift registers in an implementation; the variance po2 quantizer is further constrained to use a quadratic approximation, because we take the square root of the variance to obtain the standard deviation), `beta` using a po2 quantizer to maintain the dynamic range aspect of the center parameter, and `mean` remaining unquantized, as it inherits the properties of the previous layer.\n",
+    "\n",
+    "Activation has been migrated to __`QActivation`__, although __QKeras__ also recognizes the activation parameter used in convolutional and dense layers.\n",
+    "\n",
+    "We have improved the setup of quantization: convolution, dense and batch normalization layers now notify the quantizers when the quantizers are used as internal parameters, so the user does not need to worry about setting up the options that work best for `weights` and `bias`, like `alpha` and `use_stochastic_rounding` (although users may override the automatic setup).\n",
+    "\n",
+    "Finally, in the current version, we have eliminated the need to set up the range of the quantizers, like `kernel_range` in __`QDense`__. This is now computed automatically. Although we kept the parameters for backward compatibility, they will be removed in the future."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Activation Layers and Quantizers Implemented in __QKeras__\n",
+    "\n",
+    "Quantizers and activation layers are treated interchangeably in __QKeras__. \n",
+    "\n",
+    "The list of quantizers and their parameters is given below.\n",
+    "\n",
+    "- __`smooth_sigmoid(x)`__\n",
+    "\n",
+    "- __`hard_sigmoid(x)`__\n",
+    "\n",
+    "- __`binary_sigmoid(x)`__\n",
+    "\n",
+    "- __`smooth_tanh(x)`__\n",
+    "\n",
+    "- __`hard_tanh(x)`__\n",
+    "\n",
+    "- __`binary_tanh(x)`__\n",
+    "\n",
+    "- __`quantized_bits(bits=8, integer=0, symmetric=0, keep_negative=1, alpha=None, use_stochastic_rounding=False)(x)`__\n",
+    "\n",
+    "- __`bernoulli(alpha=None, temperature=6.0, use_real_sigmoid=True)(x)`__\n",
+    "\n",
+    "- __`stochastic_ternary(alpha=None, threshold=None, temperature=8.0, use_real_sigmoid=True)(x)`__\n",
+    "\n",
+    "- __`ternary(alpha=None, threshold=None, use_stochastic_rounding=False)(x)`__\n",
+    "\n",
+    "- __`stochastic_binary(alpha=None, temperature=6.0, use_real_sigmoid=True)(x)`__\n",
+    "\n",
+    "- __`binary(use_01=False, alpha=None, use_stochastic_rounding=False)(x)`__\n",
+    "\n",
+    "- __`quantized_relu(bits=8, integer=0, use_sigmoid=0, use_stochastic_rounding=False)(x)`__\n",
+    "\n",
+    "- __`quantized_ulaw(bits=8, integer=0, symmetric=0, u=255.0)(x)`__\n",
+    "\n",
+    "- __`quantized_tanh(bits=8, integer=0, symmetric=0, use_stochastic_rounding=False)(x)`__\n",
+    "\n",
+    "- __`quantized_po2(bits=8, max_value=None, use_stochastic_rounding=False, quadratic_approximation=False)(x)`__\n",
+    "\n",
+    "- __`quantized_relu_po2(bits=8, max_value=None, use_stochastic_rounding=False, quadratic_approximation=False)(x)`__\n",
+    "\n",
+    "The __`stochastic_*`__ functions and __`bernoulli`__ rely on stochastic versions of the activation functions, so they are best suited for weights and biases. They draw a random number with uniform distribution from the `sigmoid` of the input x, and the result is based on the expected value of the activation function. Please refer to the papers if you want to understand the underlying theory, or to the documentation in qkeras/quantizers.py. The parameter `temperature` determines how steep the sigmoid function behaves, and the default values seem to work well.\n",
+    "\n",
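+    "As an illustration, quantizers can be called directly on tensors, which is a quick way to inspect their behavior. The following is a minimal sketch (the exact outputs depend on the quantizer configuration and the QKeras version):\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "from qkeras import quantized_bits, ternary\n",
+    "\n",
+    "x = np.linspace(-1.0, 1.0, 9).astype(np.float32)\n",
+    "\n",
+    "# 4-bit fixed-point quantization with 0 integer bits.\n",
+    "print(quantized_bits(bits=4, integer=0)(x))\n",
+    "\n",
+    "# Ternary quantization maps x to {-1, 0, +1} (times a scale when alpha is set).\n",
+    "print(ternary()(x))\n",
+    "```\n",
+    "\n",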
+    "As we lower the number of bits, rounding becomes problematic, as it adds a bias to the number system. Numpy attempts to reduce the effect of this bias by rounding to even instead of rounding to infinity. Recent results (_\"Suyog Gupta, Ankur Agrawal, Kailash Gopalakrishnan, Pritish Narayanan; Deep Learning with Limited Numerical Precision\"_ [https://arxiv.org/abs/1502.02551]) suggested using stochastic rounding, which uses the fractional part of the number as a probability to round up or down. We can turn on stochastic rounding in some quantizers by setting `use_stochastic_rounding` to `True` in __`quantized_bits`__, __`binary`__, __`ternary`__, __`quantized_relu`__, __`quantized_tanh`__, __`quantized_po2`__, and __`quantized_relu_po2`__. Please note that if one is considering an efficient hardware or software implementation, this flag should be avoided in activations, as it may affect the efficiency of the implementation. In addition, as mentioned before, we already set this flag to `True` in some quantized layers when the quantizers are used for weights/biases.\n",
+    "\n",
+    "The parameter `bits` specifies the number of bits for the quantization, and `integer` specifies how many bits of `bits` are to the left of the decimal point. Finally, in our experience training networks with __`QSeparableConv2D`__, it is advisable to allocate more bits to the depthwise than to the pointwise quantization, and both __`quantized_bits`__ and __`quantized_tanh`__ should use symmetric versions for weights and biases in order to properly converge and eliminate bias.\n",
+    "\n",
+    "We have substantially improved the stochastic rounding implementation in __QKeras__ $>= 0.7$, and added a symbolic way to compute alpha in __`binary`__, __`stochastic_binary`__, __`ternary`__, __`stochastic_ternary`__, __`bernoulli`__ and __`quantized_bits`__. Right now, the scale and the threshold (for ternary and stochastic_ternary) can be computed independently of the distribution of the inputs, which is required when using these quantizers for weights.\n",
+    "\n",
+    "The main problem in using very small bit widths in large deep learning networks stems from the fact that weights are initialized with variance roughly $\propto \sqrt{1/\tt{fanin}}$, but during training the variance shifts outwards. If the smallest quantization representation (the threshold in ternary networks) is larger than $\sqrt{1/\tt{fanin}}$, we run the risk of having the weights stuck at 0 during training. So, the weights need to dynamically adjust to the variance shift from initialization to the end of training. This can be done by scaling the quantization. \n",
+    "\n",
+    "The scale is computed using the formula $\sum(\tt{dot}(Q,x))/\sum(\tt{dot}(Q,Q))$, which is described in several papers, including _Mohammad Rastegari, Vicente Ordonez, Joseph Redmon, Ali Farhadi \"XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks\"_ [https://arxiv.org/abs/1603.05279]. The scale is computed for each output channel, which makes our implementation sometimes behave like a mini-batch normalization adjustment. \n",
+    "\n",
+    "For __`ternary`__ and __`stochastic_ternary`__, we iterate between scale computation and threshold computation, as presented in _K. Hwang and W. 
Sung, \"Fixed-point feedforward deep neural network design using weights +1, 0, and −1,\" 2014 IEEE Workshop on Signal Processing Systems (SiPS), Belfast, 2014, pp. 1-6_ which makes the search for threshold and scale tolerant to different input distributions. This is especially important when we need to consider that the threshold shifts depending on the input distribution, affecting the scale as well, as pointed out by _Fengfu Li, Bo Zhang, Bin Liu, \"Ternary Weight Networks\"_ [https://arxiv.org/abs/1605.04711]. \n", + "\n", + "When computing the scale in these quantizers, if `alpha=\"auto\"`, we compute the scale as a floating point number. If `alpha=\"auto_po2\"`, we enforce the scale to be a power of 2, meaning that an actual hardware or software implementation can be performed by just shifting the result of the convolution or dense layer to the right or left by checking the sign of the scale (positive shifts left, negative shifts right), and taking the log2 of the scale. This behavior is compatible with shared exponent approaches, as it performs a shift adjustment to the channel.\n", + "\n", + "We have implemented a method for each quantizer called __`_set_trainable_parameter`__ that instructs __QKeras__ to set best options when this quantizer is used as a weight or for gamma, variance and beta in __`QBatchNormalization`__, so in principle, users should not worry about this.\n", + "\n", + "The following pictures show the behavior of __`binary`__ vs stochastic rounding in __`binary`__ vs __`stochastic_binary`__ (Figure 1) and __`ternary`__ vs stochastic rounding in __`ternary`__ and __`stochastic_ternary`__ (Figure 2). We generated a normally distributed input with mean 0.0 and standard deviation of 0.02, ordered the data, and ran the quantizer 1,000 times, averaging the result for each case. Note that because of scale, the output does not range from $[-1.0, +1.0]$, but from $[-\\tt{scale}, +\\tt{scale}]$.\n", + "\n", + "\n", + "\"Binary
Figure 1: Behavior of binary quantizers
\n", + "\"Ternary
Figure 2: Behavior of ternary quantizers
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using QKeras\n", + "\n", + "__QKeras__ works by tagging all variables and weights/bias created by Keras as well as output of arithmetic layers by quantized functions. Quantized functions can be instantiated directly in __`QDense`__/__`QConv2D`__/__`QSeparableConv2D`__ functions, and they can be passed to __`QActivation`__, which act as a merged quantization and activation function.\n", + "\n", + "In order to successfully quantize a model, users need to replace layers that create variables (trainable or not) (`Dense`, `Conv2D`, etc) by their equivalent ones in __Qkeras__ (__`QDense`__, __`QConv2D`__, etc), and any layers that perform math operations need to be quantized afterwards.\n", + "\n", + "Quantized values are clipped between their maximum and minimum quantized representation (which may be different than $[-1.0, 1.0]$), although for `po2` type of quantizers, we still recommend the users to specify the parameter for `max_value`.\n", + "\n", + "An example of a very simple network is given below in Keras." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", + " from ._conv import register_converters as _register_converters\n" + ] + } + ], + "source": [ + "import six\n", + "import numpy as np\n", + "import tensorflow.compat.v2 as tf\n", + "\n", + "from tensorflow.keras.layers import *\n", + "from tensorflow.keras.models import Model\n", + "from tensorflow.keras.datasets import mnist\n", + "from tensorflow.keras.utils import to_categorical" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def CreateModel(shape, nb_classes):\n", + " x = x_in = Input(shape)\n", + " x = Conv2D(18, (3, 3), name=\"conv2d_1\")(x)\n", + " x = Activation(\"relu\", name=\"act_1\")(x)\n", + " x = Conv2D(32, (3, 3), name=\"conv2d_2\")(x)\n", + " x = Activation(\"relu\", name=\"act_2\")(x)\n", + " x = Flatten(name=\"flatten\")(x)\n", + " x = Dense(nb_classes, name=\"dense\")(x)\n", + " x = Activation(\"softmax\", name=\"softmax\")(x)\n", + " \n", + " model = Model(inputs=x_in, outputs=x)\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def get_data():\n", + " (x_train, y_train), (x_test, y_test) = mnist.load_data()\n", + " x_train = x_train.reshape(x_train.shape + (1,)).astype(\"float32\")\n", + " x_test = x_test.reshape(x_test.shape + (1,)).astype(\"float32\")\n", + "\n", + " x_train /= 256.0\n", + " x_test /= 256.0\n", + "\n", + " x_mean = np.mean(x_train, axis=0)\n", + "\n", + " x_train -= x_mean\n", + " x_test -= x_mean\n", + "\n", + " nb_classes = np.max(y_train)+1\n", + " y_train = to_categorical(y_train, nb_classes)\n", + " y_test = to_categorical(y_test, nb_classes)\n", + "\n", + " return (x_train, y_train), (x_test, y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "(x_train, y_train), (x_test, y_test) = get_data()\n", + "\n", + "model = CreateModel(x_train.shape[1:], y_train.shape[-1])" + ] + }, + { + "cell_type": "code", + 
"execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 60000 samples, validate on 10000 samples\n", + "Epoch 1/3\n", + "60000/60000 [==============================] - 7s 122us/sample - loss: 0.2103 - accuracy: 0.9393 - val_loss: 0.0685 - val_accuracy: 0.9781\n", + "Epoch 2/3\n", + "60000/60000 [==============================] - 6s 108us/sample - loss: 0.0642 - accuracy: 0.9808 - val_loss: 0.0575 - val_accuracy: 0.9817\n", + "Epoch 3/3\n", + "60000/60000 [==============================] - 7s 112us/sample - loss: 0.0457 - accuracy: 0.9860 - val_loss: 0.0502 - val_accuracy: 0.9844\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fit(x_train, y_train, epochs=3, batch_size=128, validation_data=(x_test, y_test), verbose=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Great! it is relatively easy to create a network that converges in MNIST with very high test accuracy. The reader should note that we named all the layers as it will make it easier to automatically convert the network by name." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_1 (InputLayer) [(None, 28, 28, 1)] 0 \n", + "_________________________________________________________________\n", + "conv2d_1 (Conv2D) (None, 26, 26, 18) 180 \n", + "_________________________________________________________________\n", + "act_1 (Activation) (None, 26, 26, 18) 0 \n", + "_________________________________________________________________\n", + "conv2d_2 (Conv2D) (None, 24, 24, 32) 5216 \n", + "_________________________________________________________________\n", + "act_2 (Activation) (None, 24, 24, 32) 0 \n", + "_________________________________________________________________\n", + "flatten (Flatten) (None, 18432) 0 \n", + "_________________________________________________________________\n", + "dense (Dense) (None, 10) 184330 \n", + "_________________________________________________________________\n", + "softmax (Activation) (None, 10) 0 \n", + "=================================================================\n", + "Total params: 189,726\n", + "Trainable params: 189,726\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The corresponding quantized network is presented below." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from qkeras import *\n",
+    "\n",
+    "def CreateQModel(shape, nb_classes):\n",
+    "  x = x_in = Input(shape)\n",
+    "  x = QConv2D(18, (3, 3),\n",
+    "      kernel_quantizer=\"stochastic_ternary\",\n",
+    "      bias_quantizer=\"quantized_po2(4)\",\n",
+    "      name=\"conv2d_1\")(x)\n",
+    "  x = QActivation(\"quantized_relu(2)\", name=\"act_1\")(x)\n",
+    "  x = QConv2D(32, (3, 3),\n",
+    "      kernel_quantizer=\"stochastic_ternary\",\n",
+    "      bias_quantizer=\"quantized_po2(4)\",\n",
+    "      name=\"conv2d_2\")(x)\n",
+    "  x = QActivation(\"quantized_relu(2)\", name=\"act_2\")(x)\n",
+    "  x = Flatten(name=\"flatten\")(x)\n",
+    "  x = QDense(nb_classes,\n",
+    "      kernel_quantizer=\"quantized_bits(3,0,1)\",\n",
+    "      bias_quantizer=\"quantized_bits(3)\",\n",
+    "      name=\"dense\")(x)\n",
+    "  x = Activation(\"softmax\", name=\"softmax\")(x)\n",
+    "\n",
+    "  model = Model(inputs=x_in, outputs=x)\n",
+    "\n",
+    "  return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "qmodel = CreateQModel(x_train.shape[1:], y_train.shape[-1])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from tensorflow.keras.optimizers import Adam\n",
+    "\n",
+    "qmodel.compile(\n",
+    "    loss=\"categorical_crossentropy\",\n",
+    "    optimizer=Adam(0.0005),\n",
+    "    metrics=[\"accuracy\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
"stream", + "text": [ + "Train on 60000 samples, validate on 10000 samples\n", + "Epoch 1/10\n", + "60000/60000 [==============================] - 52s 869us/sample - loss: 0.5034 - accuracy: 0.8428 - val_loss: 0.2422 - val_accuracy: 0.9276\n", + "Epoch 2/10\n", + "60000/60000 [==============================] - 49s 813us/sample - loss: 0.2080 - accuracy: 0.9371 - val_loss: 0.1981 - val_accuracy: 0.9415\n", + "Epoch 3/10\n", + "60000/60000 [==============================] - 50s 832us/sample - loss: 0.1703 - accuracy: 0.9503 - val_loss: 0.1454 - val_accuracy: 0.9571\n", + "Epoch 4/10\n", + "60000/60000 [==============================] - 49s 813us/sample - loss: 0.1448 - accuracy: 0.9573 - val_loss: 0.1296 - val_accuracy: 0.9616\n", + "Epoch 5/10\n", + "60000/60000 [==============================] - 48s 806us/sample - loss: 0.1225 - accuracy: 0.9639 - val_loss: 0.1267 - val_accuracy: 0.9636\n", + "Epoch 6/10\n", + "60000/60000 [==============================] - 49s 818us/sample - loss: 0.1074 - accuracy: 0.9686 - val_loss: 0.1092 - val_accuracy: 0.9662\n", + "Epoch 7/10\n", + "60000/60000 [==============================] - 48s 801us/sample - loss: 0.1020 - accuracy: 0.9698 - val_loss: 0.1041 - val_accuracy: 0.9718\n", + "Epoch 8/10\n", + "60000/60000 [==============================] - 50s 825us/sample - loss: 0.0979 - accuracy: 0.9712 - val_loss: 0.1175 - val_accuracy: 0.9662\n", + "Epoch 9/10\n", + "60000/60000 [==============================] - 49s 823us/sample - loss: 0.0910 - accuracy: 0.9726 - val_loss: 0.1015 - val_accuracy: 0.9711\n", + "Epoch 10/10\n", + "60000/60000 [==============================] - 50s 836us/sample - loss: 0.0845 - accuracy: 0.9755 - val_loss: 0.1044 - val_accuracy: 0.9710\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qmodel.fit(x_train, y_train, epochs=10, batch_size=128, validation_data=(x_test, y_test), verbose=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You should note that we had to lower the learning rate and train the network for longer time. On the other hand, the network should not involve in any multiplications in the convolution layers, and very small multipliers in the dense layers.\n", + "\n", + "Please note that the last `Activation` was not changed to __`QActivation`__ as during inference we usually perform the operation `argmax` on the result instead of `softmax`.\n", + "\n", + "It seems it is a lot of code to write besides the main network, but in fact, this additional code is only specifying the sizes of the weights and the sizes of the outputs in the case of the activations. Right now, we do not have a way to extract this information from the network structure or problem we are trying to solve, and if we quantize too much a layer, we may end up not been able to recover from that later on." 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Converting a Model Automatically\n",
+    "\n",
+    "In addition to the drop-in replacement of Keras functions, we have written the following function to assist anyone who wants to quantize a network.\n",
+    "\n",
+    "__`model_quantize(model, quantizer_config, activation_bits, custom_objects=None, transfer_weights=False)`__\n",
+    "\n",
+    "This function converts a non-quantized model (such as `model` in the previous example) into a quantized version, by applying the configuration specified by the dictionary `quantizer_config`. `activation_bits` is used for unnamed activation functions, and this parameter will probably be removed in future versions.\n",
+    "\n",
+    "The parameter `custom_objects` specifies a dictionary of objects unknown to Keras, which is required when you copy a model with lambda layers or customized layer functions, for example. If `transfer_weights` is `True`, the returned model will have as initial weights the weights from the original model, instead of using random initial weights.\n",
+    "\n",
+    "The dictionary specified in `quantizer_config` can be indexed by a layer name or a layer class name. In the example below, `conv2d_1` corresponds to the first convolutional layer of the example, while `QConv2D` corresponds to the default behavior of two-dimensional convolutional layers. The reader should note that right now we recommend using __`QActivation`__ with a dictionary to avoid the conversion of activations such as `softmax` and `linear`. In addition, although we could use the `activation` field in the layers, we do not recommend that. \n",
+    "\n",
+    "`{\n",
+    "  \"conv2d_1\": {\n",
+    "      \"kernel_quantizer\": \"stochastic_ternary\",\n",
+    "      \"bias_quantizer\": \"quantized_po2(4)\"\n",
+    "  },\n",
+    "  \"QConv2D\": {\n",
+    "      \"kernel_quantizer\": \"stochastic_ternary\",\n",
+    "      \"bias_quantizer\": \"quantized_po2(4)\"\n",
+    "  },\n",
+    "  \"QDense\": {\n",
+    "      \"kernel_quantizer\": \"quantized_bits(3,0,1)\",\n",
+    "      \"bias_quantizer\": \"quantized_bits(3)\"\n",
+    "  },\n",
+    "  \"act_1\": \"quantized_relu(2)\",\n",
+    "  \"QActivation\": { \"relu\": \"quantized_relu(2)\" }\n",
+    "}`\n",
+    "\n",
+    "In the following example, we will quantize the model using a different strategy.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 73,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config = {\n",
+    "  \"conv2d_1\": {\n",
+    "      \"kernel_quantizer\": \"stochastic_binary\",\n",
+    "      \"bias_quantizer\": \"quantized_po2(4)\"\n",
+    "  },\n",
+    "  \"QConv2D\": {\n",
+    "      \"kernel_quantizer\": \"stochastic_ternary\",\n",
+    "      \"bias_quantizer\": \"quantized_po2(4)\"\n",
+    "  },\n",
+    "  \"QDense\": {\n",
+    "      \"kernel_quantizer\": \"quantized_bits(4,0,1)\",\n",
+    "      \"bias_quantizer\": \"quantized_bits(4)\"\n",
+    "  },\n",
+    "  \"QActivation\": { \"relu\": \"binary\" },\n",
+    "  \"act_2\": \"quantized_relu(3)\",\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 75,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "conv2d_1 kernel: stochastic_binary(alpha=auto_po2) bias: quantized_po2(4)\n",
+      "act_1 quantizer: binary()\n",
+      "conv2d_2 kernel: stochastic_ternary(alpha=auto_po2,threshold=0.33) bias: quantized_po2(4)\n",
+      "act_2 quantizer: quantized_relu(3,0)\n",
+      "dense kernel: quantized_bits(4,0,1,alpha=auto_po2,use_stochastic_rounding=1) bias: quantized_bits(4,0,1,alpha=auto_po2,use_stochastic_rounding=1)\n",
+      "\n",
+      "Model: \"model_1\"\n",
"_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_2 (InputLayer) [(None, 28, 28, 1)] 0 \n", + "_________________________________________________________________\n", + "conv2d_1 (QConv2D) (None, 26, 26, 18) 180 \n", + "_________________________________________________________________\n", + "act_1 (QActivation) (None, 26, 26, 18) 0 \n", + "_________________________________________________________________\n", + "conv2d_2 (QConv2D) (None, 24, 24, 32) 5216 \n", + "_________________________________________________________________\n", + "act_2 (QActivation) (None, 24, 24, 32) 0 \n", + "_________________________________________________________________\n", + "flatten (Flatten) (None, 18432) 0 \n", + "_________________________________________________________________\n", + "dense (QDense) (None, 10) 184330 \n", + "_________________________________________________________________\n", + "softmax (Activation) (None, 10) 0 \n", + "=================================================================\n", + "Total params: 189,726\n", + "Trainable params: 189,726\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "from qkeras.utils import model_quantize\n", + "\n", + "qmodel = model_quantize(model, config, 4, transfer_weights=True)\n", + "\n", + "for layer in qmodel.layers:\n", + " if hasattr(layer, \"kernel_quantizer\"):\n", + " print(layer.name, \"kernel:\", str(layer.kernel_quantizer_internal), \"bias:\", str(layer.bias_quantizer_internal))\n", + " elif hasattr(layer, \"quantizer\"):\n", + " print(layer.name, \"quantizer:\", str(layer.quantizer))\n", + "\n", + "print()\n", + "qmodel.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [], + "source": [ + "qmodel.compile(\n", + " loss=\"categorical_crossentropy\",\n", + " optimizer=Adam(0.001),\n", + " metrics=[\"accuracy\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 60000 samples, validate on 10000 samples\n", + "Epoch 1/10\n", + "60000/60000 [==============================] - 43s 722us/sample - loss: 0.0796 - accuracy: 0.9757 - val_loss: 0.0968 - val_accuracy: 0.9725\n", + "Epoch 2/10\n", + "60000/60000 [==============================] - 44s 726us/sample - loss: 0.0717 - accuracy: 0.9776 - val_loss: 0.0853 - val_accuracy: 0.9747\n", + "Epoch 3/10\n", + "60000/60000 [==============================] - 45s 748us/sample - loss: 0.0690 - accuracy: 0.9789 - val_loss: 0.0878 - val_accuracy: 0.9740\n", + "Epoch 4/10\n", + "60000/60000 [==============================] - 44s 733us/sample - loss: 0.0636 - accuracy: 0.9800 - val_loss: 0.0816 - val_accuracy: 0.9753\n", + "Epoch 5/10\n", + "60000/60000 [==============================] - 45s 758us/sample - loss: 0.0572 - accuracy: 0.9822 - val_loss: 0.0861 - val_accuracy: 0.9753\n", + "Epoch 6/10\n", + "60000/60000 [==============================] - 44s 734us/sample - loss: 0.0565 - accuracy: 0.9819 - val_loss: 0.0819 - val_accuracy: 0.9763\n", + "Epoch 7/10\n", + "60000/60000 [==============================] - 46s 765us/sample - loss: 0.0489 - accuracy: 0.9842 - val_loss: 0.0859 - val_accuracy: 0.9758\n", + "Epoch 8/10\n", + "60000/60000 [==============================] - 47s 783us/sample - loss: 0.0485 - accuracy: 
0.9845 - val_loss: 0.0889 - val_accuracy: 0.9771\n",
+      "Epoch 9/10\n",
+      "60000/60000 [==============================] - 44s 737us/sample - loss: 0.0477 - accuracy: 0.9850 - val_loss: 0.0729 - val_accuracy: 0.9791\n",
+      "Epoch 10/10\n",
+      "60000/60000 [==============================] - 45s 742us/sample - loss: 0.0484 - accuracy: 0.9843 - val_loss: 0.0796 - val_accuracy: 0.9780\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 78,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "qmodel.fit(x_train, y_train, epochs=10, batch_size=128, validation_data=(x_test, y_test), verbose=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In addition to __`model_quantize`__, __QKeras__ offers the following utility functions.\n",
+    "\n",
+    "__`BinaryToThermometer(x, classes, value_range, with_residue=False, merge_with_channels, use_two_hot_encoding=False)`__\n",
+    "\n",
+    "This function converts a dense binary encoding of inputs to one-hot (with scales).\n",
+    "\n",
+    "Given an input matrix `x` with values (for example) 0, 1, 2, 3, 4, 5, 6, 7, create a number of classes as follows:\n",
+    "\n",
+    "If classes=2, value_range=8, with_residue=0, a true one-hot representation is created, and the remaining bits are truncated, using a one-bit representation.\n",
+    "\n",
+    "`\n",
+    "0 - [1,0] 1 - [1,0] 2 - [1,0] 3 - [1,0]\n",
+    "4 - [0,1] 5 - [0,1] 6 - [0,1] 7 - [0,1]\n",
+    "`\n",
+    "\n",
+    "If classes=2, value_range=8, with_residue=1, the residue is added to the one-hot class, and the class will use 2 bits (for the remainder) + 1 bit (for the one-hot).\n",
+    "\n",
+    "`\n",
+    "0 - [1,0] 1 - [1.25,0] 2 - [1.5,0] 3 - [1.75,0]\n",
+    "4 - [0,1] 5 - [0,1.25] 6 - [0,1.5] 7 - [0,1.75]\n",
+    "`\n",
+    "\n",
+    "The arguments of this function are as follows:\n",
+    "\n",
+    "`\n",
+    "x: the input vector we want to convert. Typically its dimension will be\n",
+    "   (B,H,W,C) for an image, or (B,T,C) or (B,C) for a 1D signal, where\n",
+    "   B=batch, H=height, W=width, C=channels or features, T=time for time\n",
+    "   series.\n",
+    "classes: the number of classes (or log2(classes) bits) to use for the\n",
+    "   values.\n",
+    "value_range: max(x) - min(x) over all possible x values (e.g. for 8 bits,\n",
+    "   we would use 256 here).\n",
+    "with_residue: if true, we split the value range into two sets and add\n",
+    "   the decimal fraction of the set to the one-hot representation for partial\n",
+    "   thermometer representation.\n",
+    "merge_with_channels: if True, we will not create a separate dimension\n",
+    "   for the resulting matrix, but we will merge this dimension with\n",
+    "   the last dimension.\n",
+    "use_two_hot_encoding: if true, we will distribute the weight between\n",
+    "   the current value and the next one to make sure the numbers will always\n",
+    "   be < 1.\n",
+    "`\n",
+    "\n",
+    "__`model_save_quantized_weights(model, filename)`__\n",
+    "\n",
+    "This function saves the quantized weights in the model, or writes the quantized weights to the file `filename` for production, as the weights are maintained non-quantized during training. Typically, you should call this function before productizing the final model. The saved model is compatible with Keras for inference, so for power-of-2 quantization, we will not save `(sign, round(log2(weights)))`, but rather `(-1)**sign*2**(round(log2(weights)))`. 
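\n",
+    "\n",
+    "A minimal usage sketch (file names are illustrative), assuming `qmodel` is the quantized model from above:\n",
+    "\n",
+    "```python\n",
+    "from qkeras.utils import model_save_quantized_weights, load_qmodel\n",
+    "\n",
+    "# Write inference-ready (quantized) weights to disk.\n",
+    "model_save_quantized_weights(qmodel, \"qmodel_weights.h5\")\n",
+    "\n",
+    "# Save the full model with Keras, then reload it with the QKeras\n",
+    "# custom objects handled automatically.\n",
+    "qmodel.save(\"qmodel.h5\")\n",
+    "qmodel = load_qmodel(\"qmodel.h5\")\n",
+    "```\n",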
+    "The function also returns a dictionary containing the name of each layer and its quantized weights, and for power-of-2 quantizations, it will return `sign` and `round(log2(weights))` so that other tools can properly process them.\n",
+    "\n",
+    "__`load_qmodel(filepath, custom_objects=None, compile=True)`__\n",
+    "\n",
+    "Loads a quantized model from a Keras `model.save()` h5 file, where `filepath` is the path to the file, `custom_objects` is an optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization, and `compile` instructs __QKeras__ to compile the model after reading it. If an optimizer was found as part of the saved model, the model is already compiled. Otherwise, the model is uncompiled and a warning will be displayed. When `compile` is set to `False`, the compilation is omitted without any warning.\n",
+    "\n",
+    "__`print_model_sparsity(model)`__\n",
+    "\n",
+    "Prints the sparsity of the pruned layers in the model.\n",
+    "\n",
+    "__`quantized_model_debug(model, X_test, plot=False)`__\n",
+    "\n",
+    "Debugs and plots model weights and activations. It is usually useful to print weights, biases and activations for inputs and outputs when debugging a model. `model` contains the mixed quantized/unquantized layers of a model. We only print/plot activations and weights/biases for quantized layers, with the exception of `Activation`. `X_test` is the set of inputs we will use to compute activations, and we recommend using a subsample of the set you want to debug. If `plot` is `True`, we also plot weights and activations (inputs/outputs) for each layer.\n",
+    "\n",
+    "__`extract_model_operations(model)`__\n",
+    "\n",
+    "As each operation depends on the quantization method for the weights/bias and on the quantization of the inputs, we estimate which operations are required for each layer of the quantized model. For example, if the inputs of a __`QDense`__ layer are quantized using __`quantized_relu_po2`__ and the weights are quantized using __`quantized_bits`__, the matrix multiplication can be implemented as a barrel shifter + accumulator, without multiplication operations. Right now, we return for each layer one of the following operations: `mult`, `barrel`, `mux`, `adder`, `xor`, together with the sizes of the operators.\n",
+    "\n",
+    "We are currently refactoring this function and it may be substantially changed in the future.\n",
+    "\n",
+    "__`print_qstats(model)`__\n",
+    "\n",
+    "Prints statistics of the number of operations per operation type and per layer, so that the user can see how big the model is. This function utilizes __`extract_model_operations`__.\n",
+    "\n",
+    "An example of the output is presented below.\n",
+    "\n",
+    "`Number of operations in model:\n",
+    "    conv2d_0_m : 25088 (smult_4_8)\n",
+    "    conv2d_1_m : 663552 (smult_4_4)\n",
+    "    conv2d_2_m : 147456 (smult_4_4)\n",
+    "    dense      : 5760 (smult_4_4)\n",
+    "\n",
+    "Number of operation types in model:\n",
+    "    smult_4_4 : 816768\n",
+    "    smult_4_8 : 25088`\n",
+    "\n",
+    "In this example, `smult_4_4` stands for 4x4-bit signed multiplication and `smult_4_8` stands for 8x4-bit signed multiplication.\n",
+    "\n",
+    "We are currently refactoring this function and it may be substantially changed in the future.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For the quantized network `qmodel`, let's print the statistics of the model and its weights."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 79,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Number of operations in model:\n",
+      "    conv2d_1 : 109512 (smux_1_8)\n",
+      "    conv2d_2 : 2985984 (smux_2_1)\n",
+      "    dense    : 184320 (smult_4_3)\n",
+      "\n",
+      "Number of operation types in model:\n",
+      "    smult_4_3 : 184320\n",
+      "    smux_1_8  : 109512\n",
+      "    smux_2_1  : 2985984\n",
+      "\n",
+      "Weight profiling:\n",
+      "    conv2d_1_weights : 162 (1-bit unit)\n",
+      "    conv2d_1_bias    : 18 (4-bit unit)\n",
+      "    conv2d_2_weights : 5184 (2-bit unit)\n",
+      "    conv2d_2_bias    : 32 (4-bit unit)\n",
+      "    dense_weights    : 184320 (4-bit unit)\n",
+      "    dense_bias       : 10 (4-bit unit)\n"
+     ]
+    }
+   ],
+   "source": [
+    "print_qstats(qmodel)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 81,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "input    -0.5451  0.9960\n",
+      "conv2d_1 -4.6218  4.0295 ( -1.0000  1.0000) ( -0.5000  0.5000) a( 0.125000 0.500000)\n",
+      "act_1    -1.0000  1.0000\n",
+      "conv2d_2 -21.2500 14.2500 ( -1.0000  1.0000) ( -0.2500 -0.1250) a( 0.125000 0.250000)\n",
+      "act_2     0.0000  0.8750\n",
+      "dense    -52.1094 39.4062 ( -0.5000  0.3750) ( -0.1250  0.1250) a( 1.000000 1.000000)\n",
+      "softmax   0.0000  1.0000\n"
+     ]
+    }
+   ],
+   "source": [
+    "from qkeras.utils import quantized_model_debug\n",
+    "\n",
+    "quantized_model_debug(qmodel, x_test, plot=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here the values in `conv2d_1 -4.6218 4.0295 ( -1.0000 1.0000) ( -0.5000 0.5000) a( 0.125000 0.500000)` correspond to the min and max values of the output of the convolution layer, the weight range (min and max), the bias range (min and max) and the alpha range (min and max)."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 2",
+   "language": "python",
+   "name": "python2"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebook/images/figure1.png b/notebook/images/figure1.png
new file mode 100644
index 00000000..8ec06971
Binary files /dev/null and b/notebook/images/figure1.png differ
diff --git a/notebook/images/figure2.png b/notebook/images/figure2.png
new file mode 100644
index 00000000..465e75d9
Binary files /dev/null and b/notebook/images/figure2.png differ
diff --git a/qkeras/__init__.py b/qkeras/__init__.py
index c7e36239..bd0a86c5 100644
--- a/qkeras/__init__.py
+++ b/qkeras/__init__.py
@@ -27,4 +27,4 @@
 from .qpooling import *  # pylint: disable=wildcard-import
 from .safe_eval import *  # pylint: disable=wildcard-import
 
-__version__ = "0.6.0"
+__version__ = "0.7.0"
diff --git a/qkeras/qconvolutional.py b/qkeras/qconvolutional.py
index 337b8d79..8e4c9f40 100644
--- a/qkeras/qconvolutional.py
+++ b/qkeras/qconvolutional.py
@@ -13,27 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
# ============================================================================== - +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import warnings import tensorflow as tf - -from tensorflow.keras import activations from tensorflow.keras import constraints from tensorflow.keras import initializers from tensorflow.keras import regularizers -from tensorflow.keras.constraints import Constraint from tensorflow.keras.layers import Activation from tensorflow.keras.layers import Conv1D from tensorflow.keras.layers import Conv2D from tensorflow.keras.layers import DepthwiseConv2D from tensorflow.keras.layers import Dropout from tensorflow.keras.layers import InputSpec -from tensorflow.keras.layers import Layer -from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer - -from .qlayers import Clip +from .qlayers import get_auto_range_constraint_initializer from .qlayers import QActivation -from .quantizers import get_quantizer from .quantizers import get_quantized_initializer +from .quantizers import get_quantizer + +from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer class QConv1D(Conv1D, PrunableLayer): @@ -71,33 +70,45 @@ def __init__(self, bias_constraint=None, kernel_quantizer=None, bias_quantizer=None, - kernel_range=1.0, - bias_range=1.0, + kernel_range=None, + bias_range=None, **kwargs): - self.kernel_quantizer = kernel_quantizer - self.bias_quantizer = bias_quantizer + if kernel_range is not None: + warnings.warn("kernel_range is deprecated in QConv1D layer.") + + if bias_range is not None: + warnings.warn("bias_range is deprecated in QConv1D layer.") + self.kernel_range = kernel_range self.bias_range = bias_range + self.kernel_quantizer = kernel_quantizer + self.bias_quantizer = bias_quantizer + self.kernel_quantizer_internal = get_quantizer(self.kernel_quantizer) self.bias_quantizer_internal = get_quantizer(self.bias_quantizer) + # optimize parameter set to "auto" scaling mode if possible + if hasattr(self.kernel_quantizer_internal, "_set_trainable_parameter"): + self.kernel_quantizer_internal._set_trainable_parameter() + self.quantizers = [ self.kernel_quantizer_internal, self.bias_quantizer_internal ] - if kernel_quantizer: - if kernel_constraint: - kernel_constraint = constraints.get(kernel_constraint) - kernel_constraint = Clip(-kernel_range, kernel_range, kernel_constraint, - kernel_quantizer) + kernel_constraint, kernel_initializer = ( + get_auto_range_constraint_initializer(self.kernel_quantizer_internal, + kernel_constraint, + kernel_initializer)) - if bias_quantizer: - if bias_constraint: - bias_constraint = constraints.get(bias_constraint) - bias_constraint = Clip(-bias_range, bias_range, bias_constraint, - bias_quantizer) + if use_bias: + bias_constraint, bias_initializer = ( + get_auto_range_constraint_initializer(self.bias_quantizer_internal, + bias_constraint, + bias_initializer)) + if activation is not None: + activation = get_quantizer(activation) super(QConv1D, self).__init__( filters=filters, @@ -145,8 +156,10 @@ def call(self, inputs): def get_config(self): config = { - "kernel_quantizer": constraints.serialize(self.kernel_quantizer_internal), - "bias_quantizer": constraints.serialize(self.bias_quantizer_internal), + "kernel_quantizer": + constraints.serialize(self.kernel_quantizer_internal), + "bias_quantizer": + constraints.serialize(self.bias_quantizer_internal), "kernel_range": self.kernel_range, "bias_range": 
self.bias_range } @@ -194,37 +207,48 @@ def __init__(self, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, - kernel_range=1.0, - bias_range=1.0, + kernel_range=None, + bias_range=None, kernel_quantizer=None, bias_quantizer=None, **kwargs): - self.kernel_quantizer = kernel_quantizer - self.bias_quantizer = bias_quantizer + if kernel_range is not None: + warnings.warn("kernel_range is deprecated in QConv2D layer.") + + if bias_range is not None: + warnings.warn("bias_range is deprecated in QConv2D layer.") + self.kernel_range = kernel_range self.bias_range = bias_range - kernel_initializer = get_quantized_initializer(kernel_initializer, kernel_range) + self.kernel_quantizer = kernel_quantizer + self.bias_quantizer = bias_quantizer self.kernel_quantizer_internal = get_quantizer(self.kernel_quantizer) self.bias_quantizer_internal = get_quantizer(self.bias_quantizer) + # optimize parameter set to "auto" scaling mode if possible + if hasattr(self.kernel_quantizer_internal, "_set_trainable_parameter"): + self.kernel_quantizer_internal._set_trainable_parameter() + self.quantizers = [ self.kernel_quantizer_internal, self.bias_quantizer_internal ] - if kernel_quantizer: - if kernel_constraint: - kernel_constraint = constraints.get(kernel_constraint) - kernel_constraint = Clip(-kernel_range, kernel_range, kernel_constraint, - kernel_quantizer) + kernel_constraint, kernel_initializer = ( + get_auto_range_constraint_initializer(self.kernel_quantizer_internal, + kernel_constraint, + kernel_initializer)) - if bias_quantizer: - if bias_constraint: - bias_constraint = constraints.get(bias_constraint) - bias_constraint = Clip(-bias_range, bias_range, bias_constraint, - bias_quantizer) + if use_bias: + bias_constraint, bias_initializer = ( + get_auto_range_constraint_initializer(self.bias_quantizer_internal, + bias_constraint, + bias_initializer)) + + if activation is not None: + activation = get_quantizer(activation) super(QConv2D, self).__init__( filters=filters, @@ -277,10 +301,8 @@ def get_config(self): constraints.serialize(self.kernel_quantizer_internal), "bias_quantizer": constraints.serialize(self.bias_quantizer_internal), - "kernel_range": - self.kernel_range, - "bias_range": - self.bias_range + "kernel_range": self.kernel_range, + "bias_range": self.bias_range } base_config = super(QConv2D, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -328,21 +350,45 @@ def __init__(self, dilation_rate=(1, 1), depthwise_quantizer=None, bias_quantizer=None, - depthwise_range=1.0, - bias_range=1.0, + depthwise_range=None, + bias_range=None, **kwargs): - if depthwise_quantizer: - if depthwise_constraint: - depthwise_constraint = constraints.get(depthwise_constraint) - depthwise_constraint = Clip(-depthwise_range, depthwise_range, - depthwise_constraint, depthwise_quantizer) + if depthwise_range is not None: + warnings.warn("depthwise_range is deprecated in QDepthwiseConv2D layer.") - if bias_quantizer: - if bias_constraint: - bias_constraint = constraints.get(bias_constraint) - bias_constraint = Clip(-bias_range, bias_range, bias_constraint, - bias_quantizer) + if bias_range is not None: + warnings.warn("bias_range is deprecated in QDepthwiseConv2D layer.") + + self.depthwise_range = depthwise_range + self.bias_range = bias_range + + self.depthwise_quantizer = depthwise_quantizer + self.bias_quantizer = bias_quantizer + + self.depthwise_quantizer_internal = get_quantizer(self.depthwise_quantizer) + self.bias_quantizer_internal = 
get_quantizer(self.bias_quantizer) + + # optimize parameter set to "auto" scaling mode if possible + if hasattr(self.depthwise_quantizer_internal, "_set_trainable_parameter"): + self.depthwise_quantizer_internal._set_trainable_parameter() + + self.quantizers = [ + self.depthwise_quantizer_internal, self.bias_quantizer_internal + ] + + depthwise_constraint, depthwise_initializer = ( + get_auto_range_constraint_initializer(self.depthwise_quantizer_internal, + depthwise_constraint, + depthwise_initializer)) + + if use_bias: + bias_constraint, bias_initializer = ( + get_auto_range_constraint_initializer(self.bias_quantizer_internal, + bias_constraint, + bias_initializer)) + if activation is not None: + activation = get_quantizer(activation) super(QDepthwiseConv2D, self).__init__( kernel_size=kernel_size, @@ -361,20 +407,6 @@ def __init__(self, bias_constraint=bias_constraint, dilation_rate=dilation_rate, **kwargs) - self.bias_constraint = bias_constraint - - self.depthwise_quantizer = depthwise_quantizer - self.bias_quantizer = bias_quantizer - - self.depthwise_range = depthwise_range - self.bias_range = bias_range - - self.depthwise_quantizer_internal = get_quantizer(self.depthwise_quantizer) - self.bias_quantizer_internal = get_quantizer(self.bias_quantizer) - - self.quantizers = [ - self.depthwise_quantizer_internal, self.bias_quantizer_internal - ] def build(self, input_shape): if len(input_shape) < 4: @@ -468,34 +500,35 @@ def get_prunable_weights(self): return [] -def QSeparableConv2D(filters, # pylint: disable=invalid-name - kernel_size, - strides=(1, 1), - padding="VALID", - dilation_rate=(1, 1), - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer="he_normal", - pointwise_initializer="he_normal", - bias_initializer="zeros", - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - depthwise_quantizer=None, - pointwise_quantizer=None, - bias_quantizer=None, - depthwise_activation=None, - depthwise_range=1.0, - pointwise_range=1.0, - bias_range=1.0, - depthwise_dropout_rate=0.0, - pw_first=False, - name=""): +def QSeparableConv2D( + filters, # pylint: disable=invalid-name + kernel_size, + strides=(1, 1), + padding="VALID", + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="he_normal", + pointwise_initializer="he_normal", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + depthwise_quantizer=None, + pointwise_quantizer=None, + bias_quantizer=None, + depthwise_activation=None, + depthwise_range=None, + pointwise_range=None, + bias_range=None, + depthwise_dropout_rate=0.0, + pw_first=False, + name=""): """Adds a quantized separableconv2d.""" # we use here a modified version that appeared in mobilenet that adds @@ -523,7 +556,7 @@ def QSeparableConv2D(filters, # pylint: disable=invalid-name # SeparableConv2D. 
# - def _call(inputs): + def _call(inputs): # pylint: disable=invalid-name """Internally builds qseparableconv2d.""" x = inputs diff --git a/qkeras/qlayers.py b/qkeras/qlayers.py index 59fa709d..f08bf7eb 100644 --- a/qkeras/qlayers.py +++ b/qkeras/qlayers.py @@ -31,13 +31,12 @@ # https://ieeexplore.ieee.org/abstract/document/6986082 # https://ieeexplore.ieee.org/iel4/78/5934/00229903.pdf # - from __future__ import absolute_import from __future__ import division from __future__ import print_function - +import warnings +import six import tensorflow.compat.v2 as tf - from tensorflow.keras import activations from tensorflow.keras import constraints from tensorflow.keras import initializers @@ -45,14 +44,34 @@ from tensorflow.keras.constraints import Constraint from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Layer +from .quantizers import get_quantized_initializer +from .quantizers import get_quantizer from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer -import numpy as np -import six +def get_auto_range_constraint_initializer(quantizer, constraint, initializer): + """Get value range automatically for quantizer. -from .quantizers import get_quantized_initializer -from .quantizers import get_quantizer + Arguments: + quantizer: A quantizer class in quantizers.py. + constraint: A tf.keras constraint. + initializer: A tf.keras initializer. + + Returns: + a tuple (constraint, initializer), where + constraint is clipped by Clip class in this file, based on the + value range of quantizer. + initializer is initializer contraint by value range of quantizer. + """ + if quantizer is not None: + max_value = quantizer.max() if hasattr(quantizer, "max") else 1.0 + min_value = quantizer.min() if hasattr(quantizer, "min") else -1.0 + if constraint: + constraint = constraints.get(constraint) + constraint = Clip(min_value, max_value, constraint, quantizer) + initializer = get_quantized_initializer(initializer, + max(abs(min_value), abs(max_value))) + return constraint, initializer # @@ -60,7 +79,6 @@ # we may be replacing their instantiation by QActivation in the future. # - class QActivation(Layer, PrunableLayer): """Implements quantized activation layers.""" @@ -107,8 +125,6 @@ def get_prunable_weights(self): # 1. quantization approximation is symmetric (b = 0). # 2. max(x) and min(x) are 1 and -1 respectively. # - - class Clip(Constraint): """Clips weight constraint.""" @@ -145,7 +161,6 @@ def get_config(self): """Returns configuration of constraint class.""" return {"min_value": self.min_value, "max_value": self.max_value} - # # Definition of Quantized NN classes. These classes were copied # from the equivalent layers in Keras, and we modified to apply quantization. @@ -156,9 +171,9 @@ def get_config(self): class QDense(Dense, PrunableLayer): """Implements a quantized Dense layer.""" - # most of these parameters follow the implementation of Dense in - # Keras, # with the exception of kernel_range, bias_range, - # kernel_quantizer and bias_quantizer, and kernel_initializer. + # Most of these parameters follow the implementation of Dense in + # Keras, with the exception of kernel_range, bias_range, + # kernel_quantizer, bias_quantizer, and kernel_initializer. # # kernel_quantizer: quantizer function/class for kernel # bias_quantizer: quantizer function/class for bias @@ -169,7 +184,6 @@ class QDense(Dense, PrunableLayer): # # we refer the reader to the documentation of Dense in Keras for the # other parameters. 
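# Editor's note (illustration, not part of the patch): with the helper
# introduced above, a quantizer that knows its own representable range now
# drives both the weight constraint and the initializer scale. For example,
# quantized_bits(4, 0, 1) reports max() == (1 - 2**-3) == 0.875 and
# min() == -0.875, so get_auto_range_constraint_initializer wraps any user
# constraint in Clip(-0.875, 0.875, ...) and rescales the initializer
# accordingly -- automating what the now-deprecated kernel_range/bias_range
# arguments asked users to supply by hand.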
- # def __init__(self, units, @@ -184,26 +198,18 @@ def __init__(self, bias_constraint=None, kernel_quantizer=None, bias_quantizer=None, - kernel_range=1.0, - bias_range=1.0, + kernel_range=None, + bias_range=None, **kwargs): - self.kernel_range = kernel_range - self.bias_range = bias_range + if kernel_range is not None: + warnings.warn("kernel_range is deprecated in QDense layer.") - kernel_initializer = get_quantized_initializer(kernel_initializer, - kernel_range) - if kernel_quantizer: - if kernel_constraint: - kernel_constraint = constraints.get(kernel_constraint) - kernel_constraint = Clip(-kernel_range, kernel_range, kernel_constraint, - kernel_quantizer) + if bias_range is not None: + warnings.warn("bias_range is deprecated in QDense layer.") - if bias_quantizer: - if bias_constraint: - bias_constraint = constraints.get(bias_constraint) - bias_constraint = Clip(-bias_range, bias_range, bias_constraint, - bias_quantizer) + self.kernel_range = kernel_range + self.bias_range = bias_range self.kernel_quantizer = kernel_quantizer self.bias_quantizer = bias_quantizer @@ -211,10 +217,27 @@ def __init__(self, self.kernel_quantizer_internal = get_quantizer(self.kernel_quantizer) self.bias_quantizer_internal = get_quantizer(self.bias_quantizer) + # optimize parameter set to "auto" scaling mode if possible + if hasattr(self.kernel_quantizer_internal, "_set_trainable_parameter"): + self.kernel_quantizer_internal._set_trainable_parameter() + self.quantizers = [ self.kernel_quantizer_internal, self.bias_quantizer_internal ] + kernel_constraint, kernel_initializer = ( + get_auto_range_constraint_initializer(self.kernel_quantizer_internal, + kernel_constraint, + kernel_initializer)) + + if use_bias: + bias_constraint, bias_initializer = ( + get_auto_range_constraint_initializer(self.bias_quantizer_internal, + bias_constraint, + bias_initializer)) + if activation is not None: + activation = get_quantizer(activation) + super(QDense, self).__init__( units=units, activation=activation, @@ -254,12 +277,9 @@ def compute_output_shape(self, input_shape): def get_config(self): config = { - "units": - self.units, - "activation": - activations.serialize(self.activation), - "use_bias": - self.use_bias, + "units": self.units, + "activation": activations.serialize(self.activation), + "use_bias": self.use_bias, "kernel_quantizer": constraints.serialize(self.kernel_quantizer_internal), "bias_quantizer": @@ -278,10 +298,8 @@ def get_config(self): constraints.serialize(self.kernel_constraint), "bias_constraint": constraints.serialize(self.bias_constraint), - "kernel_range": - self.kernel_range, - "bias_range": - self.bias_range + "kernel_range": self.kernel_range, + "bias_range": self.bias_range } base_config = super(QDense, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/qkeras/qnormalization.py b/qkeras/qnormalization.py index 245edf2f..a40b2477 100644 --- a/qkeras/qnormalization.py +++ b/qkeras/qnormalization.py @@ -15,10 +15,12 @@ # # ============================================================================== """Definition of normalization quantization package.""" - from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np +import six +import warnings import tensorflow.compat.v2 as tf @@ -31,14 +33,13 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -from 
tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer - -import numpy as np -import six - from .qlayers import Clip +from .qlayers import get_auto_range_constraint_initializer from .qlayers import get_quantizer +from .quantizers import quantized_relu_po2 +from .quantizers import quantized_po2 from .safe_eval import safe_eval +from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer class QBatchNormalization(BatchNormalization, PrunableLayer): @@ -68,25 +69,52 @@ def __init__( gamma_quantizer=None, mean_quantizer=None, variance_quantizer=None, + gamma_constraint=None, + beta_constraint=None, # use quantized_po2 and enforce quadratic approximation # to get an even exponent for sqrt beta_range=None, gamma_range=None, **kwargs): + if gamma_range is not None: + warnings.warn('gamma_range is deprecated in QBatchNormalization layer.') + + if beta_range is not None: + warnings.warn('beta_range is deprecated in QBatchNormalization layer.') + + self.gamma_range = gamma_range + self.beta_range = beta_range + self.activation = activation + + # We know the optimal settings for gamma and variance for now, so if the + # user has not specified them. + # If user really did not want quantizers, the user would have used + # BatchNormalization instead. + + if gamma_quantizer is None: + gamma_quantizer = quantized_relu_po2(6, 2048) + if variance_quantizer is None: + variance_quantizer = quantized_relu_po2( + 6, quadratic_approximation=True) + if beta_quantizer is None: + beta_quantizer = quantized_po2(5) + self.beta_quantizer = beta_quantizer self.gamma_quantizer = gamma_quantizer self.mean_quantizer = mean_quantizer self.variance_quantizer = variance_quantizer - self.activation = activation - self.beta_range = beta_range - self.gamma_range = gamma_range self.beta_quantizer_internal = get_quantizer(self.beta_quantizer) self.gamma_quantizer_internal = get_quantizer(self.gamma_quantizer) self.mean_quantizer_internal = get_quantizer(self.mean_quantizer) self.variance_quantizer_internal = get_quantizer(self.variance_quantizer) + if hasattr(self.gamma_quantizer_internal, '_set_trainable_parameter'): + self.gamma_quantizer_internal._set_trainable_parameter() + if hasattr(self.variance_quantizer_internal, '_set_trainable_parameter'): + self.variance_quantizer_internal._set_trainable_parameter() + self.quantizers = [ self.gamma_quantizer_internal, self.beta_quantizer_internal, @@ -94,36 +122,40 @@ def __init__( self.variance_quantizer_internal ] - if center and beta_quantizer and beta_range: - beta_constraint = Clip(-beta_range, beta_range) - else: - beta_constraint = None - kwargs.pop('beta_constraint', None) - - if scale and gamma_quantizer and gamma_range: - gamma_constraint = Clip(-gamma_range, gamma_range) - else: - gamma_constraint = None - kwargs.pop('gamma_constraint', None) + if scale and self.gamma_quantizer: + gamma_constraint, gamma_initializer = ( + get_auto_range_constraint_initializer( + self.gamma_quantizer_internal, + gamma_constraint, + gamma_initializer) + ) + + if center and self.beta_quantizer: + beta_constraint, beta_initializer = ( + get_auto_range_constraint_initializer( + self.beta_quantizer_internal, + beta_constraint, + beta_initializer) + ) if kwargs.get('fused', None): - warning.warn('batch normalization fused is disabled ' - 'in qkeras qnormalization.py.') + warnings.warn('batch normalization fused is disabled ' + 'in qkeras qnormalization.py.') del kwargs['fused'] if kwargs.get('renorm', None): - 
warning.warn('batch normalization renorm is disabled '
-                   'in qkeras qnormalization.py.')
+      warnings.warn('batch normalization renorm is disabled '
+                    'in qkeras qnormalization.py.')
       del kwargs['renorm']
 
     if kwargs.get('virtual_batch_size', None):
-      warning.warn('batch normalization virtual_batch_size is disabled '
-                   'in qkeras qnormalization.py.')
+      warnings.warn('batch normalization virtual_batch_size is disabled '
+                    'in qkeras qnormalization.py.')
       del kwargs['virtual_batch_size']
 
     if kwargs.get('adjustment', None):
-      warning.warn('batch normalization adjustment is disabled '
-                   'in qkeras qnormalization.py.')
+      warnings.warn('batch normalization adjustment is disabled '
+                    'in qkeras qnormalization.py.')
       del kwargs['adjustment']
 
     super(QBatchNormalization, self).__init__(
diff --git a/qkeras/qpooling.py b/qkeras/qpooling.py
index 53fb75f0..b73e6333 100644
--- a/qkeras/qpooling.py
+++ b/qkeras/qpooling.py
@@ -13,9 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 
 from tensorflow.keras.layers import AveragePooling2D
-
 from .qlayers import QActivation
 
@@ -42,3 +44,4 @@ def _call(x):
     return x
 
   return _call
+
diff --git a/qkeras/quantizers.py b/qkeras/quantizers.py
index 55fee1d0..f3e8ec6d 100644
--- a/qkeras/quantizers.py
+++ b/qkeras/quantizers.py
@@ -13,21 +13,81 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-import tensorflow.compat.v2 as tf
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import warnings
+import logging
 import numpy as np
 import six
-import warnings
-
+import tensorflow.compat.v2 as tf
 from tensorflow.keras import initializers
+import tensorflow.keras.backend as K
 from tensorflow.keras.utils import deserialize_keras_object
-
 from .safe_eval import safe_eval
 
 #
 # Library of auxiliary functions
 #
 
+
+def get_weight_scale(quantizer, x):
+  """Gets the scales of weights for (stochastic_)binary and ternary quantizers.
+
+  Arguments:
+    quantizer: A binary or ternary quantizer class.
+    x: A weight tensor.
+
+  Returns:
+    Weight scale per channel for binary and ternary
+    quantizers with auto or auto_po2 alpha/threshold.
+  """
+  if hasattr(quantizer, "scale") and quantizer.scale is not None:
+    return K.eval(quantizer.scale)
+  return 1.0
+
+
+def _get_scale(alpha, x, q):
+  """Gets scaling factor for scaling the tensor per channel.
+
+  Arguments:
+    alpha: A float or string. When it is a string, it should be either "auto"
+      or "auto_po2", and
+      scale = sum(x * q, axis=all but last) / sum(q * q, axis=all but last)
+    x: A tensor object. Its elements are in float.
+    q: A tensor object. Its elements are in quantized format of x.
+
+  Returns:
+    A scaling factor tensor or scalar for scaling the tensor per channel.
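  Editor's note (illustration, not part of the patch): for a binary
  quantizer q = sign(x) we have q * q == 1 everywhere, so the formula above
  reduces to scale = mean(|x|) per output channel -- the familiar
  XNOR-Net-style scaling. A NumPy check under that assumption:

    import numpy as np
    x = np.random.randn(64, 32)                    # 32 output channels
    q = np.sign(x)
    scale = (x * q).mean(0) / (q * q).mean(0)      # what _get_scale computes
    assert np.allclose(scale, np.abs(x).mean(0))   # equals mean absolute value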
+ """ + + if isinstance(alpha, six.string_types) and "auto" in alpha: + assert alpha in ["auto", "auto_po2"] + x_shape = x.shape.as_list() + len_axis = len(x_shape) + if len_axis > 1: + if K.image_data_format() == "channels_last": + axis = range(len_axis - 1) + else: + axis = range(1, len_axis) + qx = K.mean(tf.math.multiply(x, q), axis=axis, keepdims=True) + qq = K.mean(tf.math.multiply(q, q), axis=axis, keepdims=True) + else: + qx = K.mean(x * q, axis=0, keepdims=True) + qq = K.mean(q * q, axis=0, keepdims=True) + scale = qx / (qq + K.epsilon()) + if alpha == "auto_po2": + scale = K.pow(2.0, + tf.math.round(K.log(scale + K.epsilon()) / np.log(2.0))) + elif alpha is None: + scale = 1.0 + elif isinstance(alpha, np.ndarray): + scale = alpha + else: + scale = float(alpha) + return scale + + def smooth_sigmoid(x): """Implements a linear approximation of a sigmoid function.""" @@ -69,7 +129,7 @@ def set_internal_sigmoid(mode): elif mode == "smooth": _sigmoid = smooth_sigmoid elif mode == "real": - _sigmoid = tf.sigmoid + _sigmoid = tf.keras.backend.sigmoid def binary_tanh(x): @@ -87,15 +147,15 @@ def smooth_tanh(x): return 2.0 * smooth_sigmoid(x) - 1.0 -def stochastic_round(x): +def stochastic_round(x, precision=0.5): """Performs stochastic rounding to the first decimal point.""" - s = tf.sign(x) - s += (1.0 - tf.abs(s)) * (2.0 * tf.round(tf.random.uniform(tf.shape(x))) - - 1.0) - t = tf.floor(x) - (s - 1.0) / 2.0 - p = tf.abs(x - t) - f = s * (tf.sign(p - tf.random.uniform(tf.shape(p))) + 1.0) / 2.0 - return t + f + scale = 1.0 / precision + scale_x = x * scale + fraction = scale_x - tf.floor(scale_x) + + result = tf.where(fraction < tf.random.uniform(tf.shape(x)), + tf.math.floor(scale_x), tf.math.ceil(scale_x)) + return result / scale def stochastic_round_po2(x): @@ -105,8 +165,8 @@ def stochastic_round_po2(x): y = tf.abs(x) eps = tf.keras.backend.epsilon() log2 = tf.keras.backend.log(2.0) + x_log2 = tf.round(tf.keras.backend.log(y + eps) / log2) - sign = tf.sign(x) po2 = tf.cast(pow(2.0, tf.cast(x_log2, dtype="float32")), dtype="float32") left_val = tf.where(po2 > y, x_log2 - 1, x_log2) right_val = tf.where(po2 > y, x_log2, x_log2 + 1) @@ -117,10 +177,18 @@ def stochastic_round_po2(x): # use y as a threshold to keep the probabliy [2**left_val, y, 2**right_val] # so that the mean value of the sample should be y x_po2 = tf.where(y < val, left_val, right_val) + """ + x_log2 = stochastic_round(tf.keras.backend.log(y + eps) / log2) + sign = tf.sign(x) + po2 = ( + tf.sign(x) * + tf.cast(pow(2.0, tf.cast(x_log2, dtype="float32")), dtype="float32") + ) + """ return x_po2 -def _round_through(x, use_stochastic_rounding=False): +def _round_through(x, use_stochastic_rounding=False, precision=0.5): """Rounds x but using straight through estimator. We use the trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182). @@ -139,12 +207,14 @@ def _round_through(x, use_stochastic_rounding=False): Arguments: x: tensor to perform round operation with straight through gradient. use_stochastic_rounding: if true, we perform stochastic rounding. + precision: by default we will use 0.5 as precision, but that can overriden + by the user. Returns: Rounded tensor. """ if use_stochastic_rounding: - return x + tf.stop_gradient(-x + stochastic_round(x)) + return x + tf.stop_gradient(-x + stochastic_round(x, precision)) else: return x + tf.stop_gradient(-x + tf.round(x)) @@ -166,7 +236,6 @@ def _ceil_through(x): return x + tf.stop_gradient(-x + tf.ceil(x)) - # # Activation functions for quantized networks. 
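# Editor's note: a NumPy sketch of the rewritten stochastic_round above (an
# illustration under stated assumptions, not code from the patch). A value is
# rounded up with probability equal to its fractional part on the
# 1/precision grid, so the rounding is unbiased: E[stochastic_round(x)] == x.

import numpy as np

def stochastic_round_np(x, precision=0.5):
  scale = 1.0 / precision
  scale_x = x * scale
  fraction = scale_x - np.floor(scale_x)
  # round up exactly when the uniform draw falls below the fractional part
  up = np.random.uniform(size=np.shape(x)) < fraction
  return np.where(up, np.ceil(scale_x), np.floor(scale_x)) / scale

samples = stochastic_round_np(np.full(100000, 0.3))
print(samples.mean())  # ~0.3, even though each sample is 0.0 or 0.5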
# @@ -181,8 +250,8 @@ class quantized_bits(object): # pylint: disable=invalid-name In general, we want to use a quantization function like: - a = (pow(2,bits)-1 - 0) / (max(x) - min(x)) - b = - min(x) * a + a = (pow(2,bits) - 1 - 0) / (max(x) - min(x)) + b = -min(x) * a in the equation: @@ -211,6 +280,8 @@ class quantized_bits(object): # pylint: disable=invalid-name integer: number of bits to the left of the decimal point. symmetric: if true, we will have the same number of values for positive and negative numbers. + alpha: a tensor or None, the scaling factor per channel. + If None, the scaling factor is 1 for all channels. keep_negative: if true, we do not clip negative numbers. use_stochastic_rounding: if true, we perform stochastic rounding. @@ -219,33 +290,123 @@ class quantized_bits(object): # pylint: disable=invalid-name """ def __init__(self, bits=8, integer=0, symmetric=0, keep_negative=1, - use_stochastic_rounding=False): + alpha=None, use_stochastic_rounding=False): self.bits = bits self.integer = integer self.symmetric = symmetric self.keep_negative = (keep_negative > 0) + self.alpha = alpha self.use_stochastic_rounding = use_stochastic_rounding + # "auto*" |-> symmetric + if isinstance(self.alpha, six.string_types): + self.symmetric = True + self.scale = None + + def __str__(self): + flags = [str(self.bits), str(self.integer), str(int(self.symmetric))] + if not self.keep_negative: + flags.append("keep_negative=" + str(int(self.keep_negative))) + if self.alpha: + flags.append("alpha=" + str(self.alpha)) + if self.use_stochastic_rounding: + flags.append("use_stochastic_rounding=" + + str(int(self.use_stochastic_rounding))) + return "quantized_bits(" + ",".join(flags) + ")" def __call__(self, x): """Computes fixedpoint quantization of x.""" - + # quantized_bits with "1" bit becomes a binary implementation. unsigned_bits = self.bits - self.keep_negative + m = pow(2, unsigned_bits) + m_i = pow(2, self.integer) - # quantized_bits with "1" bit becomes a binary implementation. + if self.alpha is None: + scale = 1.0 + elif isinstance(self.alpha, six.string_types): + # We only deal with the symmetric case right now. + assert self.symmetric + len_axis = len(x.shape) + if len_axis > 1: + if K.image_data_format() == "channels_last": + axis = range(len_axis - 1) + else: + axis = range(1, len_axis) + else: + axis = [0] + + # we will use this implementation for the scale for QKeras 0.7 + levels = 2**self.bits - 1 + scale = (K.max(x, axis=axis, keepdims=True) - + K.min(x, axis=axis, keepdims=True)) / levels + if "po2" in self.alpha: + scale = K.pow(2.0, + tf.math.round(K.log(scale + K.epsilon()) / np.log(2.0))) + for _ in range(5): + v = tf.floor(tf.abs(x) / scale + 0.5) + mask = v < (levels - 1) / 2 + z = tf.sign(x) * tf.where(mask, v, tf.ones_like(v) * (levels - 1) / 2) + scale = _get_scale("auto_po2", x, z) + + # z is an integer number, so we must make the scale * m and z / m + scale = scale * m + + # we will not use "z" right now because of stochastic_rounding + # this is still under test. + + # if "new" in self.alpha: + # z = z / m + # self.scale = scale + # return x + tf.stop_gradient(-x + scale * z) + x = x / scale + else: + scale = self.alpha + # quantized_bits with "1" bit becomes a binary implementation. 
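# Editor's note (not part of the patch): the "auto"/"auto_po2" branch above
# is a short alternating optimization -- quantize with the current scale,
# then refit the scale to the quantized values via _get_scale -- and the
# final "scale * m" / "x / scale" bookkeeping normalizes x so the integer
# grid used by the fixed-point branch below is left unchanged.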
if unsigned_bits > 0: - m = pow(2, unsigned_bits) - m_i = pow(2, self.integer) p = x * m / m_i xq = m_i * tf.keras.backend.clip( - _round_through(p, self.use_stochastic_rounding), + _round_through(p, self.use_stochastic_rounding, precision=1.0), self.keep_negative * (-m + self.symmetric), m - 1) / m else: xq = tf.sign(x) xq += (1.0 - tf.abs(xq)) if not self.keep_negative: xq = (xq + 1.0) / 2.0 - return x + tf.stop_gradient(-x + xq) + + self.scale = scale + return x + tf.stop_gradient(-x + scale * xq) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + self.symmetric = True + self.use_stochastic_rounding = True + + def max(self): + """Get maximum value that quantized_bits class can represent.""" + unsigned_bits = self.bits - self.keep_negative + + if unsigned_bits > 0: + return ((1.0 - np.power(2.0, -unsigned_bits)) * + np.power(2.0, self.integer)) + else: + return 1.0 + + def min(self): + """Get minimum value that quantized_bits class can represent.""" + if not self.keep_negative: + return 0.0 + + unsigned_bits = self.bits - self.keep_negative + + if unsigned_bits > 0: + if self.symmetric: + return -( + (1.0 - np.power(2.0, -unsigned_bits)) * np.power(2.0, self.integer)) + else: + return -np.power(2.0, self.integer) + else: + return -1.0 @classmethod def from_config(cls, config): @@ -253,16 +414,12 @@ def from_config(cls, config): def get_config(self): config = { - "bits": - self.bits, - "integer": - self.integer, - "symmetric": - self.symmetric, - "keep_negative": - self.keep_negative, - "use_stochastic_rounding": - self.use_stochastic_rounding + "bits": self.bits, + "integer": self.integer, + "symmetric": self.symmetric, + "alpha": self.alpha, + "keep_negative": self.keep_negative, + "use_stochastic_rounding": self.use_stochastic_rounding } return config @@ -286,21 +443,88 @@ class bernoulli(object): # pylint: disable=invalid-name Remember that E[dL/dy] = E[dL/dx] once we add stochastic sampling. Attributes: - alpha: allows one to specify multiplicative factor for number generation. + alpha: allows one to specify multiplicative factor for number generation + of "auto" or "auto_po2". + temperature: amplifier factor for sigmoid function, making stochastic + less stochastic as it moves away from 0. + use_real_sigmoid: use real sigmoid for probability. Returns: Computation of round with stochastic sampling with straight through gradient. 
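  Editor's note (illustration, not from the patch): with use_real_sigmoid
  the expected output is E[q] = sigmoid(temperature * x / std(x)), i.e. a
  probabilistic 0/1 gate whose mean follows a sharpened sigmoid of the
  input. A scalar NumPy check, taking std = 1 as the non-"auto" branch does:

    import numpy as np
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    x, temperature = 0.2, 6.0
    p = sigmoid(temperature * x)
    q = (np.random.uniform(size=200000) < p).astype(float)
    print(q.mean(), p)  # both ~0.77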
""" - def __init__(self, alpha=1.0): + def __init__(self, alpha=None, temperature=6.0, use_real_sigmoid=True): self.alpha = alpha + self.bits = 1 + self.temperature = temperature + self.use_real_sigmoid = use_real_sigmoid + self.default_alpha = 1.0 + self.scale = None + + def __str__(self): + flags = [] + if self.alpha is not None: + flags.append("alpha=" + str(self.alpha)) + if self.temperature != 6.0: + flags.append("temperature=" + str(self.temperature)) + if not self.use_real_sigmoid: + flags.append("use_real_sigmoid=" + str(int(self.use_real_sigmoid))) + return "bernoulli(" + ",".join(flags) + ")" def __call__(self, x): - p = _sigmoid(x / self.alpha) - k_sign = tf.sign(p - tf.random.uniform(tf.shape(p))) - k_sign += (1.0 - tf.abs(k_sign)) - return x + tf.stop_gradient(-x + self.alpha * (k_sign + 1.0) / 2.0) + if isinstance(self.alpha, six.string_types): + assert self.alpha in ["auto", "auto_po2"] + + if isinstance(self.alpha, six.string_types): + len_axis = len(x.shape) + + if len_axis > 1: + if K.image_data_format() == "channels_last": + axis = range(len_axis - 1) + else: + axis = range(1, len_axis) + else: + axis = [0] + + std = K.std(x, axis=axis, keepdims=True) + K.epsilon() + else: + std = 1.0 + + if self.use_real_sigmoid: + p = tf.keras.backend.sigmoid(self.temperature * x / std) + else: + p = _sigmoid(self.temperature * x/std) + r = tf.random.uniform(tf.shape(x)) + q = tf.sign(p - r) + q += (1.0 - tf.abs(q)) + q = (q + 1.0) / 2.0 + + q_non_stochastic = tf.sign(x) + q_non_stochastic += (1.0 - tf.abs(q_non_stochastic)) + q_non_stochastic = (q_non_stochastic + 1.0) / 2.0 + + # if we use non stochastic binary to compute alpha, + # this function seems to behave better + scale = _get_scale(self.alpha, x, q_non_stochastic) + self.scale = scale + return x + tf.stop_gradient(-x + scale * q) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + + def max(self): + """Get the maximum value bernoulli class can represent.""" + if self.alpha is None or (isinstance(self.alpha, six.string_types) and + "auto" in self.alpha): + return 1.0 + else: + return self.alpha + + def min(self): + """Get the minimum value bernoulli class can represent.""" + return 0.0 @classmethod def from_config(cls, config): @@ -320,54 +544,114 @@ class stochastic_ternary(object): # pylint: disable=invalid-name Attributes: x: tensor to perform sign opertion with stochastic sampling. bits: number of bits to perform quantization. - alpha: ternary is -alpha or +alpha.` + alpha: ternary is -alpha or +alpha, or "auto" or "auto_po2". threshold: (1-threshold) specifies the spread of the +1 and -1 values. + temperature: amplifier factor for sigmoid function, making stochastic + less stochastic as it moves away from 0. + use_real_sigmoid: use real sigmoid for probability. + number_of_unrolls: number of times we iterate between scale and threshold. Returns: Computation of sign with stochastic sampling with straight through gradient. 
""" - def __init__(self, alpha=1.0, threshold=0.25): + def __init__(self, alpha=None, threshold=None, temperature=8.0, + use_real_sigmoid=True, number_of_unrolls=5): self.bits = 2 self.alpha = alpha self.threshold = threshold assert threshold != 1.0 + self.default_alpha = 1.0 + self.default_threshold = 0.33 + self.temperature = temperature + self.use_real_sigmoid = use_real_sigmoid + self.number_of_unrolls = number_of_unrolls + self.scale = None + + def __str__(self): + flags = [] + if self.alpha is not None: + flags.append("alpha=" + str(self.alpha)) + if self.threshold is not None: + flags.append("threshold=" + str(self.threshold)) + if self.temperature != 8.0: + flags.append("temperature=" + str(self.temperature)) + if not self.use_real_sigmoid: + flags.append("use_real_sigmoid=0") + if self.number_of_unrolls != 5: + flags.append("number_of_unrolls=" + str(self.number_of_unrolls)) + return "stochastic_ternary(" + ",".join(flags) + ")" def __call__(self, x): - # we right now use the following distributions for fm1, f0, fp1 - # - # fm1 = ((1-T)-p)/(1-T) for p <= (1-T) - # f0 = 2*p/clip(0.5+T,0.5,1.0) for p <= 0.5 - # 2*(1-p)/clip(0.5+T,0.5,1.0) for p > 0.5 - # fp1 = (p-T)/(1-T) for p >= T - # - # threshold (1-T) determines the spread of -1 and +1 - # for T < 0.5 we need to fix the distribution of f0 - # to make it bigger when compared to the other - # distributions. - - p = _sigmoid(x / self.alpha) # pylint: disable=invalid-name - - T = self.threshold # pylint: disable=invalid-name - - ones = tf.ones_like(p) - zeros = tf.zeros_like(p) + # right now we only accept stochastic_ternary in parameters + + assert isinstance(self.alpha, six.string_types) + assert self.alpha in ["auto", "auto_po2"] + if self.alpha is None: + scale = self.default_alpha + elif isinstance(self.alpha, six.string_types): + scale = 1.0 + else: + scale = float(self.alpha) - T0 = np.clip(0.5 + T, 0.5, 1.0) # pylint: disable=invalid-name + len_axis = len(x.shape) + if len_axis > 1: + if K.image_data_format() == "channels_last": + axis = range(len_axis-1) + else: + axis = range(1, len_axis) + else: + axis = [0] - fm1 = tf.where(p <= (1 - T), ((1 - T) - p) / (1 - T), zeros) - f0 = tf.where(p <= 0.5, 2 * p, 2 * (1 - p)) / T0 - fp1 = tf.where(p <= T, zeros, (p - T) / (1 - T)) + x_mean = K.mean(x, axis=axis, keepdims=True) + x_std = K.std(x, axis=axis, keepdims=True) - f_all = fm1 + f0 + fp1 + m = K.max(tf.abs(x), axis=axis, keepdims=True) + scale = 2.*m/3. 
+ for _ in range(self.number_of_unrolls): + T = scale / 2.0 + q_ns = K.cast(tf.abs(x) >= T, K.floatx()) * K.sign(x) + scale = _get_scale(self.alpha, x, q_ns) - c_fm1 = fm1 / f_all - c_f0 = (fm1 + f0) / f_all + x_norm = (x - x_mean) / x_std + T = scale / (2.0 * x_std) - r = tf.random.uniform(tf.shape(p)) + if self.use_real_sigmoid: + p0 = tf.keras.backend.sigmoid(self.temperature * (x_norm - T)) + p1 = tf.keras.backend.sigmoid(self.temperature * (x_norm + T)) + else: + p0 = _sigmoid(self.temperature * (x_norm - T)) + p1 = _sigmoid(self.temperature * (x_norm + T)) + r0 = tf.random.uniform(tf.shape(p0)) + r1 = tf.random.uniform(tf.shape(p1)) + q0 = tf.sign(p0 - r0) + q0 += (1.0 - tf.abs(q0)) + q1 = tf.sign(p1 - r1) + q1 += (1.0 - tf.abs(q1)) + + q = (q0 + q1) / 2.0 + self.scale = scale + return x + tf.stop_gradient(-x + scale * q) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + if self.threshold is None: + self.threshold = self.default_threshold + + def max(self): + """Get the maximum value that stochastic_ternary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return 1.0 + else: + return self.alpha - return x + tf.stop_gradient(-x + self.alpha * tf.where( - r <= c_fm1, -1 * ones, tf.where(r <= c_f0, zeros, ones))) + def min(self): + """Get the minimum value that stochastic_ternary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return -1.0 + else: + return -self.alpha @classmethod def from_config(cls, config): @@ -375,10 +659,11 @@ def from_config(cls, config): def get_config(self): config = { - "alpha": - self.alpha, - "threshold": - self.threshold, + "alpha": self.alpha, + "threshold": self.threshold, + "temperature": self.temperature, + "use_real_sigmoid": self.use_real_sigmoid, + "number_of_unrolls": self.number_of_unrolls } return config @@ -386,30 +671,123 @@ def get_config(self): class ternary(object): # pylint: disable=invalid-name """Computes an activation function returning -alpha, 0 or +alpha. + Right now we assume two type of behavior. For parameters, we should + have alpha, threshold and stochastic rounding on. For activations, + alpha and threshold should be floating point numbers, and stochastic + rounding should be off. + Attributes: x: tensor to perform sign opertion with stochastic sampling. bits: number of bits to perform quantization. - alpha: ternary is -alpha or +alpha. Threshold is also scaled by alpha. - threshold: threshold to apply "dropout" or dead band (0 value). + alpha: ternary is -alpha or +alpha. Alpha can be "auto" or "auto_po2". + threshold: threshold to apply "dropout" or dead band (0 value). If "auto" + is specified, we will compute it per output layer. use_stochastic_rounding: if true, we perform stochastic rounding. Returns: Computation of sign within the threshold. 
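  Editor's note: the scale starts at 2 * max|x| / 3 (threshold max|x| / 3),
  which is the closed-form optimum for x uniform on [-m, m] from
  https://arxiv.org/abs/1605.04711, and the unrolled loop then alternates
  "threshold = scale / 2 -> requantize -> refit scale" to adapt it to the
  actual distribution. A per-tensor NumPy sketch of that path (an
  illustration under stated assumptions, not code from the patch):

    import numpy as np

    def ternary_np(x, unrolls=5):
      scale = 2.0 * np.abs(x).max() / 3.0
      for _ in range(unrolls):
        q = np.where(np.abs(x) >= scale / 2.0, np.sign(x), 0.0)
        # least-squares refit of the scale, as _get_scale does per channel
        scale = (x * q).sum() / ((q * q).sum() + 1e-9)
      return scale * q, scale

    xq, alpha = ternary_np(np.random.randn(1024))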
""" - def __init__(self, alpha=1.0, threshold=0.33, use_stochastic_rounding=False): + def __init__(self, alpha=None, threshold=None, use_stochastic_rounding=False, + number_of_unrolls=5): self.alpha = alpha self.bits = 2 self.threshold = threshold self.use_stochastic_rounding = use_stochastic_rounding + self.default_alpha = 1.0 + self.default_threshold = 0.33 + self.number_of_unrolls = number_of_unrolls + self.scale = None + + def __str__(self): + flags = [] + if self.alpha is not None: + flags.append("alpha=" + str(self.alpha)) + if self.threshold is not None: + flags.append("threshold=" + str(self.threshold)) + if self.use_stochastic_rounding: + flags.append( + "use_stochastic_rounding=" + str(int(self.use_stochastic_rounding))) + if self.number_of_unrolls != 5: + flags.append( + "number_of_unrolls=" + str(int(self.number_of_unrolls))) + return "ternary(" + ",".join(flags) + ")" def __call__(self, x): - if self.use_stochastic_rounding: - x = _round_through( - x, use_stochastic_rounding=self.use_stochastic_rounding) - return x + tf.stop_gradient( - -x + self.alpha * tf.where(tf.abs(x) < self.threshold, - tf.zeros_like(x), tf.sign(x))) + if isinstance(self.alpha, six.string_types): + # parameters + assert self.alpha in ["auto", "auto_po2"] + assert self.threshold is None + else: + # activations + assert not self.use_stochastic_rounding + assert not isinstance(self.threshold, six.string_types) + + if self.alpha is None or isinstance(self.alpha, six.string_types): + scale = 1.0 + elif isinstance(self.alpha, np.ndarray): + scale = self.alpha + else: + scale = float(self.alpha) + + # This is an approximiation from https://arxiv.org/abs/1605.04711 + # We consider channels_last only for now. + if isinstance(self.alpha, six.string_types): + # It is for parameters + # first, compute which asix corresponds to the channels. + # TODO(hzhuang): support channels_first + len_axis = len(x.shape.as_list()) + if len_axis == 1: + axis = None + elif K.image_data_format() == "channels_last": + axis = range(len_axis - 1) + else: + axis = range(1, len_axis) + + # This approximation is exact if x ~ U[-m, m]. 
For x ~ N(0, m) + # we need to iterate a few times before we can coverge + m = K.max(tf.abs(x), axis=axis, keepdims=True) + scale = 2 * m / 3.0 + x_orig = x + for _ in range(self.number_of_unrolls): + thres = scale / 2.0 + if self.use_stochastic_rounding: + # once we scale the number precision == 0.33 works + # well for Uniform and Normal distribution of input + x = scale * _round_through( + x_orig / scale, + use_stochastic_rounding=self.use_stochastic_rounding, + precision=0.33) + q = K.cast(tf.abs(x) >= thres, K.floatx()) * tf.sign(x) + scale = _get_scale(self.alpha, x, q) + else: + if self.threshold is None: + thres = self.default_threshold + else: + thres = self.threshold + q = K.cast(tf.abs(x) >= thres, K.floatx()) * tf.sign(x) + + self.scale = scale + return x + tf.stop_gradient(-x + scale * q) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + self.use_stochastic_rounding = True + + def max(self): + """Get the maximum value that ternary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return 1.0 + else: + return self.alpha + + def min(self): + """Get the minimum value that ternary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return -1.0 + else: + return -self.alpha @classmethod def from_config(cls, config): @@ -417,12 +795,10 @@ def from_config(cls, config): def get_config(self): config = { - "alpha": - self.alpha, - "threshold": - self.threshold, - "use_stochastic_rounding": - self.use_stochastic_rounding + "alpha": self.alpha, + "threshold": self.threshold, + "use_stochastic_rounding": self.use_stochastic_rounding, + "number_of_unrolls": self.number_of_unrolls } return config @@ -435,32 +811,93 @@ class stochastic_binary(object): # pylint: disable=invalid-name Attributes: x: tensor to perform sign opertion with stochastic sampling. - alpha: binary is -alpha or +alpha.` + alpha: binary is -alpha or +alpha, or "auto" or "auto_po2". bits: number of bits to perform quantization. + temperature: amplifier factor for sigmoid function, making stochastic + behavior less stochastic as it moves away from 0. + use_real_sigmoid: use real sigmoid from tensorflow for probablity. Returns: Computation of sign with stochastic sampling with straight through gradient. 
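  Editor's note (clarification, not from the patch): the return expression
  used by every quantizer in this file, x + tf.stop_gradient(-x + scale * q),
  evaluates to scale * q in the forward pass while its gradient is that of
  the identity on x -- the straight-through estimator that keeps these
  non-differentiable quantizers trainable.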
""" - def __init__(self, alpha=1.0): + def __init__(self, alpha=None, temperature=6.0, use_real_sigmoid=True): self.alpha = alpha self.bits = 1 + self.temperature = temperature + self.use_real_sigmoid = use_real_sigmoid + self.default_alpha = 1.0 + self.scale = None + + def __str__(self): + flags = [] + if self.alpha is not None: + flags.append("alpha=" + str(self.alpha)) + if self.temperature != 6.0: + flags.append("temperature=" + str(self.temperature)) + if not self.use_real_sigmoid: + flags.append("use_real_sigmoid=" + str(int(self.use_real_sigmoid))) + return "stochastic_binary(" + ",".join(flags) + ")" def __call__(self, x): - assert self.alpha != 0 - p = _sigmoid(x / self.alpha) - k_sign = tf.sign(p - tf.random.uniform(tf.shape(x))) - # we should not need this, but if tf.sign is not safe if input is - # exactly 0.0 - k_sign += (1.0 - tf.abs(k_sign)) - return x + tf.stop_gradient(-x + self.alpha * k_sign) + if isinstance(self.alpha, six.string_types): + assert self.alpha in ["auto", "auto_po2"] + if isinstance(self.alpha, six.string_types): + len_axis = len(x.shape) + if len_axis > 1: + if K.image_data_format() == "channels_last": + axis = range(len_axis-1) + else: + axis = range(1, len_axis) + else: + axis = [0] + std = K.std(x, axis=axis, keepdims=True) + K.epsilon() + else: + std = 1.0 + + if self.use_real_sigmoid: + p = tf.keras.backend.sigmoid(self.temperature * x / std) + else: + p = _sigmoid(self.temperature * x / std) + + r = tf.random.uniform(tf.shape(x)) + q = tf.sign(p - r) + q += (1.0 - tf.abs(q)) + q_non_stochastic = tf.sign(x) + q_non_stochastic += (1.0 - tf.abs(q_non_stochastic)) + scale = _get_scale(self.alpha, x, q_non_stochastic) + self.scale = scale + return x + tf.stop_gradient(-x + scale * q) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + self.use_stochastic_rounding = True + + def max(self): + """Get the maximum value that stochastic_binary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return 1.0 + else: + return self.alpha + + def min(self): + """Get the minimum value that stochastic_binary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return -1.0 + else: + return -self.alpha @classmethod def from_config(cls, config): return cls(**config) def get_config(self): - config = {"alpha": self.alpha} + config = { + "alpha": self.alpha, + "temperature": self.temperature, + "use_real_sigmoid": self.use_real_sigmoid, + } return config @@ -476,34 +913,100 @@ class binary(object): # pylint: disable=invalid-name x: tensor to perform sign_through. bits: number of bits to perform quantization. use_01: if True, return {0,1} instead of {-1,+1}. - alpha: binary is -alpha or +alpha. + alpha: binary is -alpha or +alpha, or "auto", "auto_po2" to compute + automatically. use_stochastic_rounding: if true, we perform stochastic rounding. Returns: Computation of sign operation with straight through gradient. 
""" - def __init__(self, use_01=False, alpha=1.0, use_stochastic_rounding=False): + def __init__(self, use_01=False, alpha=None, use_stochastic_rounding=False): self.use_01 = use_01 self.bits = 1 self.alpha = alpha self.use_stochastic_rounding = use_stochastic_rounding + self.default_alpha = 1.0 + self.scale = None + + def __str__(self): + flags = [] + if self.use_01: + flags.append("use_01=" + str(int(self.use_01))) + if self.alpha is not None: + flags.append("alpha=" + str(self.alpha)) + if self.use_stochastic_rounding: + flags.append( + "use_stochastic_rounding=" + str(self.use_stochastic_rounding)) + return "binary(" + ",".join(flags) + ")" def __call__(self, x): - assert self.alpha != 0 + if isinstance(self.alpha, six.string_types): + assert self.alpha in ["auto", "auto_po2"] + if self.alpha is None: + scale = self.default_alpha + elif isinstance(self.alpha, six.string_types): + scale = 1.0 + elif isinstance(self.alpha, np.ndarray): + scale = self.alpha + else: + scale = float(self.alpha) + if self.use_stochastic_rounding: - x = self.alpha * _round_through( - x / self.alpha, use_stochastic_rounding=self.use_stochastic_rounding) + len_axis = len(x.shape.as_list()) + if len_axis == 1: + axis = None + elif K.image_data_format() == "channels_last": + axis = range(len_axis - 1) + else: + axis = range(1, len_axis) + + # if stochastic_round is through, we need to scale + # number so that the precision is small enough. + # This is especially important if range of x is very + # small, which occurs during initialization of weights. + m = K.max(tf.abs(x), axis=axis, keepdims=True) + m = tf.where(m > 1.0, tf.ones_like(m), m) + f = 2 * m + + x = f * _round_through( + x / f, use_stochastic_rounding=self.use_stochastic_rounding, + precision=0.125 + ) k_sign = tf.sign(x) if self.use_stochastic_rounding: k_sign += (1.0 - tf.abs(k_sign)) * ( 2.0 * tf.round(tf.random.uniform(tf.shape(x))) - 1.0) - else: - k_sign += (1.0 - tf.abs(k_sign)) + # if something still remains, just make it positive for now. 
+ k_sign += (1.0 - tf.abs(k_sign)) if self.use_01: k_sign = (k_sign + 1.0) / 2.0 - return x + tf.stop_gradient(-x + self.alpha * k_sign) + + scale = _get_scale(self.alpha, x, k_sign) + self.scale = scale + return x + tf.stop_gradient(-x + scale * k_sign) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + self.use_stochastic_rounding = True + + def max(self): + """Get maximum value that binary class can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return 1.0 + else: + return self.alpha + + def min(self): + """Get minimum value that binary class can respresent.""" + if self.use_01: + return 0.0 + elif self.alpha is None or isinstance(self.alpha, six.string_types): + return -1.0 + else: + return -self.alpha @classmethod def from_config(cls, config): @@ -511,12 +1014,9 @@ def from_config(cls, config): def get_config(self): config = { - "use_01": - self.use_01, - "alpha": - self.alpha, - "use_stochastic_rounding": - self.use_stochastic_rounding + "use_01": self.use_01, + "alpha": self.alpha, + "use_stochastic_rounding": self.use_stochastic_rounding } return config @@ -557,6 +1057,14 @@ def __init__(self, bits=8, integer=0, use_sigmoid=0, self.use_sigmoid = use_sigmoid self.use_stochastic_rounding = use_stochastic_rounding + def __str__(self): + flags = [str(self.bits), str(self.integer)] + if self.use_sigmoid or self.use_stochastic_rounding: + flags.append(str(int(self.use_sigmoid))) + if self.use_stochastic_rounding: + flags.append(str(int(self.use_stochastic_rounding))) + return "quantized_relu(" + ",".join(flags) + ")" + def __call__(self, x): m = pow(2, self.bits) m_i = pow(2, self.integer) @@ -573,20 +1081,33 @@ def __call__(self, x): 0.0, 1.0 - 1.0 / m) return xq + def _set_trainable_parameter(self): + pass + + def max(self): + """Get the maximum value that quantized_relu can represent.""" + unsigned_bits = self.bits + + if unsigned_bits > 0: + return ((1.0 - np.power(2.0, -unsigned_bits)) * + np.power(2.0, self.integer)) + else: + return 1.0 + + def min(self): + """Get the minimum value that quantized_relu can represent.""" + return 0.0 + @classmethod def from_config(cls, config): return cls(**config) def get_config(self): config = { - "bits": - self.bits, - "integer": - self.integer, - "use_sigmoid": - self.use_sigmoid, - "use_stochastic_rounding": - self.use_stochastic_rounding + "bits": self.bits, + "integer": self.integer, + "use_sigmoid": self.use_sigmoid, + "use_stochastic_rounding": self.use_stochastic_rounding } return config @@ -611,6 +1132,14 @@ def __init__(self, bits=8, integer=0, symmetric=0, u=255.0): self.symmetric = symmetric self.u = u + def __str__(self): + flags = [str(self.bits), str(self.integer)] + if self.symmetric or self.u != 255.0: + flags.append(str(int(self.symmetric))) + if self.u != 255.0: + flags.append(str(self.u)) + return "quantized_ulaw(" + ",".join(flags) + ")" + def __call__(self, x): non_sign_bits = self.bits - 1 m = pow(2, non_sign_bits) @@ -623,20 +1152,42 @@ def __call__(self, x): (1.0 * self.symmetric) / m, 1.0 - 1.0 / m) return xq + def _set_trainable_parameter(self): + pass + + def max(self): + """Get the maximum value that quantized_ulaw can represent.""" + unsigned_bits = self.bits - 1 + + if unsigned_bits > 0: + return ((1.0 - np.power(2.0, -unsigned_bits)) * + np.power(2.0, self.integer)) + else: + return 1.0 + + def min(self): + """Get the minimum value that quantized_ulaw can represent.""" + unsigned_bits = self.bits - 1 + + if unsigned_bits > 0: + if 
self.symmetric: + return -(1.0 - np.power(2.0, -unsigned_bits)) * np.power( + 2.0, self.integer) + else: + return -np.power(2.0, self.integer) + else: + return -1.0 + @classmethod def from_config(cls, config): return cls(**config) def get_config(self): config = { - "bits": - self.bits, - "integer": - self.integer, - "symmetric": - self.symmetric, - "u": - self.u + "bits": self.bits, + "integer": self.integer, + "symmetric": self.symmetric, + "u": self.u } return config @@ -666,6 +1217,14 @@ def __init__(self, bits=8, integer=0, symmetric=0, self.symmetric = symmetric self.use_stochastic_rounding = use_stochastic_rounding + def __str__(self): + flags = [str(self.bits), str(self.integer)] + if self.symmetric or self.use_stochastic_rounding: + flags.append(str(int(self.symmetric))) + if self.use_stochastic_rounding: + flags.append(str(int(self.use_stochastic_rounding))) + return "quantized_tanh(" + ",".join(flags) + ")" + def __call__(self, x): non_sign_bits = self.bits - 1 m = pow(2, non_sign_bits) @@ -677,6 +1236,33 @@ def __call__(self, x): (1.0 * self.symmetric) / m, 1.0 - 1.0 / m) return xq + def _set_trainable_parameter(self): + pass + + def max(self): + """Get the maximum value that quantized_tanh can represent.""" + unsigned_bits = self.bits - 1 + if unsigned_bits > 0: + return ((1.0 - np.power(2.0, -unsigned_bits)) * + np.power(2.0, self.integer)) + else: + return 1.0 + + def min(self): + """Get the minimum value that quantized_tanh can represent.""" + if not self.keep_negative: + return 0.0 + + unsigned_bits = self.bits - 1 + if unsigned_bits > 0: + if self.symmetric: + return -(1.0 - np.power(2.0, -unsigned_bits)) * np.power( + 2.0, self.integer) + else: + return -np.power(2.0, self.integer) + else: + return -1.0 + @classmethod def from_config(cls, config): return cls(**config) @@ -694,6 +1280,7 @@ def get_config(self): def _clip_power_of_two(x_abs, min_exp, max_exp, + max_value, quadratic_approximation=False, use_stochastic_rounding=False): """Clips a tensor using power-of-two quantizer. @@ -703,6 +1290,7 @@ def _clip_power_of_two(x_abs, x_abs: A tensor object. Its elements should be non-negative. min_exp: An integer representing the smallest exponent. max_exp: An integer representing the largest exponent. + max_value: A float or None. If it is None, we clip the value to max_value. quadratic_approximation: An boolean representing whether the quadratic approximation is applied. use_stochastic_rounding: An boolean representing whether the stochastic @@ -712,11 +1300,18 @@ def _clip_power_of_two(x_abs, A tensor object, the values are clipped by min_exp and max_exp. """ - eps = tf.keras.backend.epsilon() - log2 = np.log(2.0) # if quadratic_approximation is True, round to the exponent for sqrt(x), # so that the return value can be divided by two without remainder. - x_eps_pad = tf.where(x_abs < eps, eps, x_abs) + log2 = np.log(2.0) + + # When the elements of x_abs are small than the keras epsilon, + # we just overwrite x_abs with eps + eps = tf.keras.backend.epsilon() + x_filter = tf.where(x_abs < eps, eps, x_abs) + if max_value is not None: + # If the elements of x_filter has value larger than x_value, clip it. 
+ x_filter = tf.where(x_filter >= max_value, + tf.ones_like(x_filter) * max_value, x_filter) def power_of_two_clip(x_abs, min_exp, max_exp, quadratic_approximation, use_stochastic_rounding): @@ -740,7 +1335,7 @@ def power_of_two_clip(x_abs, min_exp, max_exp, quadratic_approximation, x_clipped = tf.where( x_abs < eps, tf.ones_like(x_abs) * min_exp, - power_of_two_clip(x_eps_pad, min_exp, max_exp, quadratic_approximation, + power_of_two_clip(x_filter, min_exp, max_exp, quadratic_approximation, use_stochastic_rounding)) return x_clipped @@ -791,14 +1386,12 @@ def _get_min_max_exponents(non_sign_bits, need_exponent_sign_bit, """ effect_bits = non_sign_bits - need_exponent_sign_bit + min_exp = -2**(effect_bits) if quadratic_approximation: - if effect_bits % 2: - effect_bits = effect_bits - 1 - min_exp = -2**(effect_bits) max_exp = 2**(effect_bits) else: - min_exp = -2**(effect_bits) max_exp = 2**(effect_bits) - 1 + return min_exp, max_exp @@ -808,8 +1401,8 @@ class quantized_po2(object): # pylint: disable=invalid-name Attributes: bits: An integer, the bits allocated for the exponent, its sign and the sign of x. - max_value: default is None, or a non-negative value to put a constraint for - the max value. + max_value: An float or None. If None, no max_value is specified. + Otherwise, the maximum value of quantized_po2 <= max_value use_stochastic_rounding: A boolean, default is False, if True, it uses stochastic rounding and forces the mean of x to be x statstically. quadratic_approximation: A boolean, default is False if True, it forces the @@ -824,44 +1417,68 @@ def __init__(self, self.bits = bits self.max_value = max_value self.use_stochastic_rounding = use_stochastic_rounding + # if True, round to the exponent for sqrt(x), + # so that the return value can be divided by two without remainder. 
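# Editor's note: a NumPy sketch of the power-of-two idea (illustration, not
# code from the patch) -- keep the sign and a rounded, clipped exponent so
# that a multiply becomes a bit shift in hardware:
#
#   import numpy as np
#   def po2_np(x, min_exp, max_exp):
#     exp = np.clip(np.round(np.log2(np.abs(x) + 1e-30)), min_exp, max_exp)
#     return np.sign(x) * 2.0 ** exp
#
# With quadratic_approximation the exponent is additionally rounded to an
# even number, so exp / 2 has no remainder and sqrt(2.0 ** exp) remains an
# exactly representable power of two -- which is why QBatchNormalization
# above picks it for the variance quantizer.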
self.quadratic_approximation = quadratic_approximation - - def __call__(self, x): need_exponent_sign_bit = _need_exponent_sign_bit_check(self.max_value) non_sign_bits = self.bits - 1 - min_exp, max_exp = _get_min_max_exponents(non_sign_bits, - need_exponent_sign_bit, - self.quadratic_approximation) - eps = tf.keras.backend.epsilon() - if min_exp < np.log2(eps): - warnings.warn( - "QKeras: min_exp in po2 quantizer is smaller than tf.epsilon().") - if self.max_value: - max_exp = np.minimum(max_exp, np.round(np.log2(self.max_value + eps))) + self._min_exp, self._max_exp = _get_min_max_exponents( + non_sign_bits, need_exponent_sign_bit, self.quadratic_approximation) + def __str__(self): + flags = [str(self.bits)] + if self.max_value is not None or self.use_stochastic_rounding: + flags.append(str(int(self.max_value))) + if self.use_stochastic_rounding: + flags.append(str(int(self.use_stochastic_rounding))) + if self.quadratic_approximation: + flags.append( + "quadratic_approximation=" + str(int(self.quadratic_approximation))) + return "quantized_po2(" + ",".join(flags) + ")" + + def __call__(self, x): x_sign = tf.sign(x) x_sign += (1.0 - tf.abs(x_sign)) x_abs = tf.abs(x) - x_clipped = _clip_power_of_two(x_abs, min_exp, max_exp, + x_clipped = _clip_power_of_two(x_abs, self._min_exp, self._max_exp, + self.max_value, self.quadratic_approximation, self.use_stochastic_rounding) return x + tf.stop_gradient(-x + x_sign * pow(2.0, x_clipped)) + def _set_trainable_parameter(self): + pass + + def max(self): + """Get the maximum value that quantized_po2 can represent.""" + return self._max_exp + + def min(self): + """Get the minimum value that quantized_po2 can represent.""" + return self._min_exp + @classmethod def from_config(cls, config): return cls(**config) def get_config(self): - """Gets config.""" + """Gets configugration of the quantizer. + + Returns: + A dict mapping quantization configuration, including + bits: bitwidth for exponents. + max_value: the maximum value of this quantized_po2 can represent. + use_stochastic_rounding: + if True, stochastic rounding is used. + quadratic_approximation: + if True, the exponent is enforced to be even number, which is + the closest one to x. + """ config = { - "bits": - self.bits, - "max_value": - self.max_value, - "use_stochastic_rounding": - self.use_stochastic_rounding, - "quadratic_approximation": - self.quadratic_approximation + "bits": self.bits, + "max_value": self.max_value, + "use_stochastic_rounding": self.use_stochastic_rounding, + "quadratic_approximation": self.quadratic_approximation } return config @@ -890,23 +1507,40 @@ def __init__(self, # if True, round to the exponent for sqrt(x), # so that the return value can be divided by two without remainder. 
@@ -890,23 +1507,40 @@ def __init__(self,
     # if True, round to the exponent for sqrt(x),
     # so that the return value can be divided by two without remainder.
     self.quadratic_approximation = quadratic_approximation
+    need_exponent_sign_bit = _need_exponent_sign_bit_check(self.max_value)
+    self._min_exp = -2**(self.bits - need_exponent_sign_bit)
+    self._max_exp = 2**(self.bits - need_exponent_sign_bit) - 1
+
+  def __str__(self):
+    flags = [str(self.bits)]
+    if self.max_value is not None or self.use_stochastic_rounding:
+      # keep the positional slot filled even when max_value is None
+      flags.append(
+          str(int(self.max_value)) if self.max_value is not None else "None")
+    if self.use_stochastic_rounding:
+      flags.append(str(int(self.use_stochastic_rounding)))
+    if self.quadratic_approximation:
+      flags.append(
+          "quadratic_approximation=" + str(int(self.quadratic_approximation)))
+    return "quantized_relu_po2(" + ",".join(flags) + ")"
 
   def __call__(self, x):
-    need_exponent_sign_bit = _need_exponent_sign_bit_check(self.max_value)
-    min_exp = -2**(self.bits - need_exponent_sign_bit)
-    max_exp = 2**(self.bits - need_exponent_sign_bit) - 1
-    eps = tf.keras.backend.epsilon()
-    if min_exp < np.log2(eps):
-      warnings.warn("QKeras: min_exp in quantized_relu_po2 quantizer "
-                    "is smaller than tf.epsilon().")
-    if self.max_value:
-      max_exp = np.minimum(max_exp, np.round(np.log2(self.max_value + eps)))
     x = tf.maximum(x, 0)
-    x_clipped = _clip_power_of_two(x, min_exp, max_exp,
+    x_clipped = _clip_power_of_two(x, self._min_exp, self._max_exp,
+                                   self.max_value,
                                    self.quadratic_approximation,
                                    self.use_stochastic_rounding)
     return x + tf.stop_gradient(-x + pow(2.0, x_clipped))
 
+  def _set_trainable_parameter(self):
+    pass
+
+  def max(self):
+    """Get the maximum value that quantized_relu_po2 can represent."""
+    return self._max_exp
+
+  def min(self):
+    """Get the minimum value that quantized_relu_po2 can represent."""
+    return self._min_exp
+
   @classmethod
   def from_config(cls, config):
     return cls(**config)
diff --git a/qkeras/safe_eval.py b/qkeras/safe_eval.py
index 9a964652..da4042cd 100644
--- a/qkeras/safe_eval.py
+++ b/qkeras/safe_eval.py
@@ -24,6 +24,9 @@ from pyparsing import Regex
 from pyparsing import Suppress
 
+import logging
+from tensorflow import keras
+
 
 def Num(s):
   """Tries to convert string to either int or float."""
@@ -68,7 +71,7 @@ def safe_eval(eval_str, op_dict, *params, **kwparams):  # pylint: disable=invalid-name
   """Replaces eval by a safe eval mechanism."""
 
   function_split = eval_str.split("(")
-  quantizer = op_dict[function_split[0]]
+  quantizer = op_dict.get(function_split[0], None)
 
   if len(function_split) == 2:
     args, kwargs = GetParams("(" + function_split[1])
@@ -80,6 +83,11 @@ def safe_eval(eval_str, op_dict, *params, **kwparams):  # pylint: disable=invalid-name
   for k in kwparams:
     kwargs[k] = kwparams[k]
 
+  # not found in op_dict: must be a plain Keras activation name
+  if quantizer is None:
+    logging.info("keras dict %s", function_split[0])
+    quantizer = keras.activations.get(function_split[0])
+
   if len(function_split) == 2 or args or kwargs:
     return quantizer(*args, **kwargs)
   else:
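With the op_dict.get fallback above, safe_eval resolves quantizer strings first and plain Keras activation names second; roughly (the op_dict contents here are illustrative):

from qkeras.safe_eval import safe_eval
from qkeras.quantizers import quantized_relu

op_dict = {"quantized_relu": quantized_relu}

q = safe_eval("quantized_relu(4,0)", op_dict)  # parsed and instantiated
act = safe_eval("softmax", op_dict)  # unknown name: keras.activations.get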
diff --git a/qkeras/utils.py b/qkeras/utils.py
index 3596c29d..20d7c9f6 100644
--- a/qkeras/utils.py
+++ b/qkeras/utils.py
@@ -13,14 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 
 import copy
 import json
 import six
+import types
 
 import tensorflow as tf
 import tensorflow.keras.backend as K
 from tensorflow.keras import initializers
+
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Conv1D
 from tensorflow.keras.models import model_from_json
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
@@ -28,6 +35,8 @@ from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
 
 import numpy as np
+import matplotlib.pyplot as plt
+
 from .qlayers import QActivation
 from .qlayers import QDense
@@ -35,26 +44,25 @@ from .qconvolutional import QConv1D
 from .qconvolutional import QConv2D
 from .qconvolutional import QDepthwiseConv2D
+from .qnormalization import QBatchNormalization
 from .qpooling import QAveragePooling2D
-from .quantizers import quantized_bits
-from .quantizers import bernoulli
-from .quantizers import stochastic_ternary
-from .quantizers import ternary
-from .quantizers import stochastic_binary
 from .quantizers import binary
+from .quantizers import bernoulli
+from .quantizers import get_weight_scale
+from .quantizers import quantized_bits
 from .quantizers import quantized_relu
 from .quantizers import quantized_ulaw
 from .quantizers import quantized_tanh
 from .quantizers import quantized_po2
 from .quantizers import quantized_relu_po2
-from .qnormalization import QBatchNormalization
-
+from .quantizers import stochastic_binary
+from .quantizers import stochastic_ternary
+from .quantizers import ternary
 from .safe_eval import safe_eval
 
 
 #
 # Model utilities: before saving the weights, we want to apply the quantizers
 #
-
 def model_save_quantized_weights(model, filename=None):
   """Quantizes the model for inference and saves it.
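The next hunk reworks quantize_activation to read the activation straight out of the layer config (whether it is a string, a function, or an object) and to skip linear or missing activations; a sketch of the resulting behavior, on hypothetical layer_config dicts:

layer_config = {"activation": "relu"}
quantize_activation(layer_config, activation_bits=4)
# layer_config["activation"] is now "quantized_relu(4)"

layer_config = {"activation": "linear"}
quantize_activation(layer_config, activation_bits=4)
# linear (and absent) activations are left untouched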
@@ -144,25 +152,39 @@ def model_save_quantized_weights(model, filename=None):
   return saved_weights
 
 
-def quantize_activation(layer_config, custom_objects, activation_bits):
+def quantize_activation(layer_config, activation_bits):
   """Replaces activation by quantized activation functions."""
   str_act_bits = str(activation_bits)
-
   # relu -> quantized_relu(bits)
   # tanh -> quantized_tanh(bits)
   #
   # more to come later
-
-  if layer_config["activation"] == "relu":
+  if layer_config.get("activation", None) is None:
+    return
+  if isinstance(layer_config["activation"], six.string_types):
+    a_name = layer_config["activation"]
+  elif isinstance(layer_config["activation"], types.FunctionType):
+    a_name = layer_config["activation"].__name__
+  else:
+    a_name = layer_config["activation"].__class__.__name__
+
+  if a_name == "linear":
+    return
+  if a_name == "relu":
     layer_config["activation"] = "quantized_relu(" + str_act_bits + ")"
-    custom_objects["quantized_relu(" + str_act_bits + ")"] = (
-        quantized_relu(activation_bits))
-
-  elif layer_config["activation"] == "tanh":
+  elif a_name == "tanh":
     layer_config["activation"] = "quantized_tanh(" + str_act_bits + ")"
-    custom_objects["quantized_tanh(" + str_act_bits + ")"] = (
-        quantized_tanh(activation_bits))
+
+
+def get_config(quantizer_config, layer, layer_class, parameter=None):
+  """Returns the quantizer entry for a layer from quantizer_config.
+
+  The lookup is by layer name first, then by layer class; if parameter is
+  given, the named sub-entry (e.g. "kernel_quantizer") is returned.
+  """
+  quantizer = quantizer_config.get(layer["name"],
+                                   quantizer_config.get(layer_class, None))
+
+  if quantizer is not None and parameter is not None:
+    quantizer = quantizer.get(parameter, None)
+
+  return quantizer
 
 
 def model_quantize(model,
@@ -239,32 +261,9 @@ def model_quantize(model,
 
   # let's make a deep copy to make sure our objects are not shared elsewhere
   jm = copy.deepcopy(json.loads(model.to_json()))
   custom_objects = copy.deepcopy(custom_objects)
-
   config = jm["config"]
   layers = config["layers"]
 
-  custom_objects["QDense"] = QDense
-  custom_objects["QConv1D"] = QConv1D
-  custom_objects["QConv2D"] = QConv2D
-  custom_objects["QDepthwiseConv2D"] = QDepthwiseConv2D
-  custom_objects["QAveragePooling2D"] = QAveragePooling2D
-  custom_objects["QActivation"] = QActivation
-
-  # just add all the objects from quantizer_config to
-  # custom_objects first.
- - for layer_name in quantizer_config.keys(): - - if isinstance(quantizer_config[layer_name], six.string_types): - name = quantizer_config[layer_name] - custom_objects[name] = safe_eval(name, globals()) - - else: - for name in quantizer_config[layer_name].keys(): - custom_objects[quantizer_config[layer_name][name]] = ( - safe_eval(quantizer_config[layer_name][name], globals())) - for layer in layers: layer_config = layer["config"] @@ -273,92 +272,69 @@ def model_quantize(model, if layer["class_name"] == "Dense": layer["class_name"] = "QDense" - # needs to add kernel/bias quantizers - - kernel_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QDense", - None))["kernel_quantizer"] - - bias_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QDense", None))["bias_quantizer"] - + kernel_quantizer = get_config( + quantizer_config, layer, "QDense", "kernel_quantizer") + bias_quantizer = get_config( + quantizer_config, layer, "QDense", "bias_quantizer") layer_config["kernel_quantizer"] = kernel_quantizer layer_config["bias_quantizer"] = bias_quantizer # if activation is present, add activation here - - quantizer = quantizer_config.get(layer["name"], - quantizer_config.get( - "QDense", None)).get( - "activation_quantizer", None) + quantizer = get_config( + quantizer_config, layer, "QDense", "activation_quantizer") if quantizer: layer_config["activation"] = quantizer custom_objects[quantizer] = safe_eval(quantizer, globals()) else: - quantize_activation(layer_config, custom_objects, activation_bits) - - elif layer["class_name"] == "Conv2D": - layer["class_name"] = "QConv2D" + quantize_activation(layer_config, activation_bits) + elif layer["class_name"] in ["Conv1D", "Conv2D"]: + q_name = "Q" + layer["class_name"] + layer["class_name"] = q_name # needs to add kernel/bias quantizers + kernel_quantizer = get_config( + quantizer_config, layer, q_name, "kernel_quantizer") - kernel_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QConv2D", None)).get( - "kernel_quantizer", None) - - bias_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QConv2D", None)).get( - "bias_quantizer", None) - + bias_quantizer = get_config( + quantizer_config, layer, q_name, "bias_quantizer") layer_config["kernel_quantizer"] = kernel_quantizer layer_config["bias_quantizer"] = bias_quantizer # if activation is present, add activation here - - quantizer = quantizer_config.get(layer["name"], - quantizer_config.get( - "QConv2D", None)).get( - "activation_quantizer", None) + quantizer = get_config( + quantizer_config, layer, q_name, "activation_quantizer") if quantizer: layer_config["activation"] = quantizer - custom_objects[quantizer] = safe_eval(quantizer, globals()) else: - quantize_activation(layer_config, custom_objects, activation_bits) + quantize_activation(layer_config, activation_bits) elif layer["class_name"] == "DepthwiseConv2D": layer["class_name"] = "QDepthwiseConv2D" - # needs to add kernel/bias quantizers - - depthwise_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QDepthwiseConv2D", None)).get( - "depthwise_quantizer", None) - - bias_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QDepthwiseConv2D", None)).get( - "bias_quantizer", None) + depthwise_quantizer = get_config(quantizer_config, layer, + "QDepthwiseConv2D", "depthwise_quantizer") + bias_quantizer = get_config(quantizer_config, layer, + "QDepthwiseConv2D", "bias_quantizer") layer_config["depthwise_quantizer"] = 
depthwise_quantizer layer_config["bias_quantizer"] = bias_quantizer - # if activation is present, add activation here - - quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QDepthwiseConv2D", None)).get( - "activation_quantizer", None) + quantizer = get_config(quantizer_config, layer, + "QDepthwiseConv2D", "activation_quantizer",) if quantizer: layer_config["activation"] = quantizer - custom_objects[quantizer] = safe_eval(quantizer, globals()) else: - quantize_activation(layer_config, custom_objects, activation_bits) + quantize_activation(layer_config, activation_bits) elif layer["class_name"] == "Activation": - quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QActivation", None)) + quantizer = get_config(quantizer_config, layer, "QActivation") + # this is to avoid softmax from quantizing in autoq + if quantizer is None: + continue # if quantizer exists in dictionary related to this name, # use it, otherwise, use normal transformations @@ -372,28 +348,34 @@ def model_quantize(model, quantizer = quantizer[layer_config["activation"]] if quantizer: layer_config["activation"] = quantizer - custom_objects[quantizer] = safe_eval(quantizer, globals()) else: - quantize_activation(layer_config, custom_objects, activation_bits) - - elif layer["class_name"] == "AveragePooling2D": - layer["class_name"] = "QAveragePooling2D" - - quantizer = quantizer_config.get(layer["name"], None) - - # if quantizer exists in dictionary related to this name, - # use it, otherwise, use normal transformations + quantize_activation(layer_config, activation_bits) - if quantizer: - layer_config["activation"] = quantizer - custom_objects[quantizer] = safe_eval(quantizer, globals()) - else: - quantize_activation(layer_config, custom_objects, activation_bits) + elif layer["class_name"] == "BatchNormalization": + layer["class_name"] = "QBatchNormalization" + # needs to add kernel/bias quantizers + gamma_quantizer = get_config( + quantizer_config, layer, "QBatchNormalization", + "gamma_quantizer") + beta_quantizer = get_config( + quantizer_config, layer, "QBatchNormalization", + "beta_quantizer") + mean_quantizer = get_config( + quantizer_config, layer, "QBatchNormalization", + "mean_quantizer") + variance_quantizer = get_config( + quantizer_config, layer, "QBatchNormalization", + "variance_quantizer") + + layer_config["gamma_quantizer"] = gamma_quantizer + layer_config["beta_quantizer"] = beta_quantizer + layer_config["mean_quantizer"] = mean_quantizer + layer_config["variance_quantizer"] = variance_quantizer # we need to keep a dictionary of custom objects as our quantized library # is not recognized by keras. 
- qmodel = model_from_json(json.dumps(jm), custom_objects=custom_objects) + qmodel = quantized_model_from_json(json.dumps(jm), custom_objects) # if transfer_weights is true, we load the weights from model to qmodel @@ -402,7 +384,8 @@ def model_quantize(model, if layer.get_weights(): qlayer.set_weights(copy.deepcopy(layer.get_weights())) - return qmodel, custom_objects + return qmodel + def _add_supported_quantized_objects(custom_objects): @@ -411,7 +394,6 @@ def _add_supported_quantized_objects(custom_objects): custom_objects["QConv1D"] = QConv1D custom_objects["QConv2D"] = QConv2D custom_objects["QDepthwiseConv2D"] = QDepthwiseConv2D - custom_objects["QAveragePooling2D"] = QAveragePooling2D custom_objects["QActivation"] = QActivation custom_objects["QBatchNormalization"] = QBatchNormalization custom_objects["Clip"] = Clip @@ -441,6 +423,7 @@ def quantized_model_from_json(json_string, custom_objects=None): return qmodel + def load_qmodel(filepath, custom_objects=None, compile=True): """ Load quantized model from Keras's model.save() h5 file. @@ -456,7 +439,7 @@ def load_qmodel(filepath, custom_objects=None, compile=True): considered during deserialization. compile: Boolean, whether to compile the model after loading. - + # Returns A Keras model instance. If an optimizer was found as part of the saved model, the model is already @@ -471,37 +454,90 @@ def load_qmodel(filepath, custom_objects=None, compile=True): # let's make a deep copy to make sure our objects are not shared elsewhere custom_objects = copy.deepcopy(custom_objects) - + _add_supported_quantized_objects(custom_objects) - - qmodel = tf.keras.models.load_model(filepath, custom_objects=custom_objects, compile=compile) - + + qmodel = tf.keras.models.load_model(filepath, custom_objects=custom_objects, + compile=compile) return qmodel def print_model_sparsity(model): - """Prints sparsity for the pruned layers in the model.""" - - def _get_sparsity(weights): - return 1.0 - np.count_nonzero(weights) / float(weights.size) - - print("Model Sparsity Summary ({})".format(model.name)) - print("--") - for layer in model.layers: - if isinstance(layer, pruning_wrapper.PruneLowMagnitude): - prunable_weights = layer.layer.get_prunable_weights() - elif isinstance(layer, prunable_layer.PrunableLayer): - prunable_weights = layer.get_prunable_weights() - elif prune_registry.PruneRegistry.supports(layer): - weight_names = prune_registry.PruneRegistry._weight_names(layer) - prunable_weights = [getattr(layer, weight) for weight in weight_names] - else: - prunable_weights = None - if prunable_weights: - print("{}: {}".format( - layer.name, ", ".join([ - "({}, {})".format(weight.name, - str(_get_sparsity(K.get_value(weight)))) - for weight in prunable_weights - ]))) - print("\n") \ No newline at end of file + """Prints sparsity for the pruned layers in the model.""" + + def _get_sparsity(weights): + return 1.0 - np.count_nonzero(weights) / float(weights.size) + + print("Model Sparsity Summary ({})".format(model.name)) + print("--") + for layer in model.layers: + if isinstance(layer, pruning_wrapper.PruneLowMagnitude): + prunable_weights = layer.layer.get_prunable_weights() + elif isinstance(layer, prunable_layer.PrunableLayer): + prunable_weights = layer.get_prunable_weights() + elif prune_registry.PruneRegistry.supports(layer): + weight_names = prune_registry.PruneRegistry._weight_names(layer) + prunable_weights = [getattr(layer, weight) for weight in weight_names] + else: + prunable_weights = None + if prunable_weights: + print("{}: {}".format( + 
layer.name, ", ".join([ + "({}, {})".format(weight.name, + str(_get_sparsity(K.get_value(weight)))) + for weight in prunable_weights + ]))) + print("\n") + + +def quantized_model_debug(model, X_test, plot=False): + """Debugs and plots model weights and activations.""" + outputs = [] + output_names = [] + + for layer in model.layers: + if layer.__class__.__name__ in [ + "QActivation", "QBatchNormalization", "Activation", "QDense", + "QConv2D", "QDepthwiseConv2D" + ]: + output_names.append(layer.name) + outputs.append(layer.output) + + model_debug = Model(inputs=model.inputs, outputs=outputs) + + y_pred = model_debug.predict(X_test) + + print("{:30} {: 8.4f} {: 8.4f}".format( + "input", np.min(X_test), np.max(X_test))) + + for n, p in zip(output_names, y_pred): + layer = model.get_layer(n) + print("{:30} {: 8.4f} {: 8.4f}".format(n, np.min(p), np.max(p)), end="") + if plot and layer.__class__.__name__ in ["QConv2D", "QDense", "QActivation"]: + plt.hist(p.flatten(), bins=25) + plt.title(layer.name + "(output)") + plt.show() + alpha = None + for i, weights in enumerate(layer.get_weights()): + if hasattr(layer, "get_quantizers") and layer.get_quantizers()[i]: + weights = K.eval(layer.get_quantizers()[i](K.constant(weights))) + if i == 0 and layer.__class__.__name__ in [ + "QConv1D", "QConv2D", "QDense"]: + alpha = get_weight_scale(layer.get_quantizers()[i], weights) + # if alpha is 0, let's remove all weights. + alpha_mask = (alpha == 0.0) + weights = np.where( + alpha_mask, + weights * alpha, + weights / alpha + ) + if plot: + plt.hist(weights.flatten(), bins=25) + plt.title(layer.name + "(weights)") + plt.show() + print(" ({: 8.4f} {: 8.4f})".format(np.min(weights), np.max(weights)), + end="") + if alpha is not None and isinstance(alpha, np.ndarray): + print(" a({: 10.6f} {: 10.6f})".format( + np.min(alpha), np.max(alpha)), end="") + print("") diff --git a/requirements.txt b/requirements.txt index 538b777c..01c4ceff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ pyasn1<0.5.0,>=0.4.6 requests<3,>=2.21.0 pyparsing pytest>=4.6.9 -tensorflow-model-optimization>=0.2.1 \ No newline at end of file +tensorflow-model-optimization>=0.2.1 +matplotlib>=3.1.2 diff --git a/setup.py b/setup.py index 593600d6..1bf91245 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ setuptools.setup( name="QKeras", - version="0.6.0", + version="0.7.0", author="Claudionor N. Coelho", author_email="nunescoelho@google.com", maintainer="Hao Zhuang", @@ -43,6 +43,7 @@ "scipy>=1.4.1", "pyparser", "setuptools>=41.0.0", + "tensorflow-model-optimization>=0.2.1", ], setup_requires=[ "pytest-runner", diff --git a/tests/automatic_conversion_test.py b/tests/automatic_conversion_test.py new file mode 100644 index 00000000..d1e1c710 --- /dev/null +++ b/tests/automatic_conversion_test.py @@ -0,0 +1,95 @@ +# Copyright 2019 Google LLC +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import pytest
+from tensorflow.keras.layers import *
+from tensorflow.keras.models import *
+
+from qkeras import *
+from qkeras.utils import model_quantize
+
+
+def create_network():
+  xi = Input((28, 28, 1))
+  x = Conv2D(32, (3, 3))(xi)
+  x = Activation("relu")(x)
+  x = Conv2D(32, (3, 3), activation="relu")(x)
+  x = Activation("softmax")(x)
+  x = QConv2D(32, (3, 3), activation="quantized_relu(4)")(x)
+  return Model(inputs=xi, outputs=x)
+
+
+def test_linear_activation():
+  m = create_network()
+
+  assert m.layers[1].activation.__name__ == "linear", (
+      "activation should default to linear")
+
+
+def test_linear_activation_conversion():
+  m = create_network()
+
+  d = {
+      "QConv2D": {
+          "kernel_quantizer": "binary",
+          "bias_quantizer": "binary",
+          "activation_quantizer": "binary"
+      }
+  }
+  qq = model_quantize(m, d, 4)
+
+  assert str(qq.layers[1].activation) == "binary()"
+
+
+def test_no_activation_conversion_to_quantized():
+  m = create_network()
+  d = {"QConv2D": {"kernel_quantizer": "binary", "bias_quantizer": "binary"}}
+  qq = model_quantize(m, d, 4)
+  assert qq.layers[2].__class__.__name__ == "Activation"
+  assert qq.layers[4].__class__.__name__ == "Activation"
+
+
+def test_automatic_conversion_from_relu_to_qr():
+  m = create_network()
+  d = {
+      "QConv2D": {
+          "kernel_quantizer": "binary",
+          "bias_quantizer": "binary"
+      }}
+  qq = model_quantize(m, d, 4)
+  assert str(qq.layers[3].activation) == "quantized_relu(4,0)"
+
+
+def test_conversion_from_relu_activation_to_qr_qactivation():
+  m = create_network()
+  d = {
+      "QConv2D": {
+          "kernel_quantizer": "binary",
+          "bias_quantizer": "binary"
+      },
+      "QActivation": {
+          "relu": "ternary"
+      }
+  }
+  qq = model_quantize(m, d, 4)
+  assert qq.layers[2].__class__.__name__ == "QActivation"
+  assert str(qq.layers[2].quantizer) == "ternary()"
+  assert qq.layers[4].__class__.__name__ == "Activation"
+
+
+if __name__ == "__main__":
+  pytest.main([__file__])
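Taken together, these tests pin down the conversion contract of model_quantize: quantizer_config entries are looked up by layer name or by layer class, plain "relu" activations become quantized_relu(activation_bits), and softmax is left alone. A condensed usage sketch, reusing create_network from the test above:

m = create_network()
q_dict = {
    "QConv2D": {
        "kernel_quantizer": "binary",
        "bias_quantizer": "binary",
    }
}
# every Conv2D becomes a QConv2D with binary weights; its "relu"
# activation is rewritten to quantized_relu(4)
qmodel = model_quantize(m, q_dict, 4)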
diff --git a/tests/qactivation_test.py b/tests/qactivation_test.py
index 72e51a8a..d2d82349 100644
--- a/tests/qactivation_test.py
+++ b/tests/qactivation_test.py
@@ -14,6 +14,9 @@
 # limitations under the License.
 # ==============================================================================
 """Test activation from qlayers.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 import numpy as np
 from numpy.testing import assert_allclose
@@ -27,10 +30,12 @@ from qkeras import quantized_relu
 from qkeras import quantized_relu_po2
 from qkeras import smooth_sigmoid
+from qkeras import stochastic_binary
+from qkeras import stochastic_ternary
 from qkeras import ternary
 
 
 @pytest.mark.parametrize(
-    'bits, max_value, use_stochastic_rounding, quadratic_approximation, ' +
+    'bits, max_value, use_stochastic_rounding, quadratic_approximation, '
    'test_values, expected_values', [
        # bits=4 without max_value. Therefore the max exponent is 4 when
        # quadratic approximation is enabled. The max and min values from this
@@ -382,5 +387,66 @@ def test_stochastic_round_quantized_relu_po2(test_values, expected_values):
   assert_allclose(res, expected_values, rtol=1e-01, atol=1e-6)
 
 
+def test_stochastic_binary():
+  np.random.seed(42)
+
+  x = np.random.uniform(-0.01, 0.01, size=10)
+  x = np.sort(x)
+
+  s = stochastic_binary(alpha="auto_po2")
+  ty = np.zeros_like(x)
+  ts = 0.0
+
+  n = 1000
+
+  for _ in range(n):
+    y = K.eval(s(K.constant(x)))
+    scale = K.eval(s.scale)[0]
+    ts = ts + scale
+    ty = ty + (y / scale)
+
+  result = (ty/n).astype(np.float32)
+  scale = np.array([ts/n])
+
+  expected = np.array(
+      [-1., -1., -1., -0.852, 0.782, 0.768, 0.97, 0.978, 1.0, 1.0]
+  ).astype(np.float32)
+  expected_scale = np.array([0.003906])
+
+  assert_allclose(result, expected, atol=0.1)
+  assert_allclose(scale, expected_scale, rtol=0.1)
+
+
+def test_stochastic_ternary():
+  np.random.seed(42)
+
+  x = np.random.uniform(-0.01, 0.01, size=10)
+  x = np.sort(x)
+
+  s = stochastic_ternary(alpha="auto_po2", temperature=8)
+  ty = np.zeros_like(x)
+  ts = 0.0
+
+  n = 1000
+
+  for _ in range(n):
+    y = K.eval(s(K.constant(x)))
+    scale = K.eval(s.scale)[0]
+    ts = ts + scale
+    ty = ty + (y / scale)
+
+  result = (ty/n).astype(np.float32)
+  scale = np.array([ts/n])
+
+  expected = np.array(
+      [-0.998, -0.992, -0.992, -0.208, 0.048, 0.04, 0.448, 0.606,
+       0.987, 0.998]
+  ).astype(np.float32)
+  expected_scale = np.array([0.007812])
+
+  assert_allclose(result, expected, atol=0.1)
+  assert_allclose(scale, expected_scale, rtol=0.1)
+
+
 if __name__ == '__main__':
   pytest.main([__file__])
diff --git a/tests/qalpha_test.py b/tests/qalpha_test.py
new file mode 100644
index 00000000..85ba0df3
--- /dev/null
+++ b/tests/qalpha_test.py
@@ -0,0 +1,127 @@
+# Copyright 2020 Google LLC
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Test get_weight_scale function with auto and auto_po2 modes of quantizers.py.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +import logging +from numpy.testing import assert_allclose +import pytest +from tensorflow.keras import backend as K +from qkeras import binary +from qkeras import get_weight_scale +from qkeras import ternary + + +# expected value if input is uniform distribution is: +# - alpha = m/2.0 for binary +# - alpha = (m+d)/2.0 for ternary + + +def test_binary_auto(): + """Test binary auto scale quantizer.""" + + np.random.seed(42) + N = 1000000 + m_list = [1.0, 0.1, 0.01, 0.001] + + for m in m_list: + x = np.random.uniform(-m, m, (N, 10)).astype(K.floatx()) + x = K.constant(x) + + quantizer = binary(alpha="auto") + q = K.eval(quantizer(x)) + + result = get_weight_scale(quantizer, q) + expected = m / 2.0 + logging.info("expect %s", expected) + logging.info("result %s", result) + assert_allclose(result, expected, rtol=0.02) + + +def test_binary_auto_po2(): + """Test binary auto_po2 scale quantizer.""" + + np.random.seed(42) + N = 1000000 + m_list = [1.0, 0.1, 0.01, 0.001] + + for m in m_list: + x = np.random.uniform(-m, m, (N, 10)).astype(K.floatx()) + x = K.constant(x) + + quantizer_ref = binary(alpha="auto") + quantizer = binary(alpha="auto_po2") + + q_ref = K.eval(quantizer_ref(x)) + q = K.eval(quantizer(x)) + + ref = get_weight_scale(quantizer_ref, q_ref) + + expected = np.power(2.0, np.round(np.log2(ref))) + result = get_weight_scale(quantizer, q) + + assert_allclose(result, expected, rtol=0.0001) + + +def test_ternary_auto(): + """Test ternary auto scale quantizer.""" + + np.random.seed(42) + N = 1000000 + m_list = [1.0, 0.1, 0.01, 0.001] + + for m in m_list: + x = np.random.uniform(-m, m, (N, 10)).astype(K.floatx()) + x = K.constant(x) + + quantizer = ternary(alpha="auto") + q = K.eval(quantizer(x)) + + d = m/3.0 + result = np.mean(get_weight_scale(quantizer, q)) + expected = (m + d) / 2.0 + assert_allclose(result, expected, rtol=0.02) + + +def test_ternary_auto_po2(): + """Test ternary auto_po2 scale quantizer.""" + + np.random.seed(42) + N = 1000000 + m_list = [1.0, 0.1, 0.01, 0.001] + + for m in m_list: + x = np.random.uniform(-m, m, (N, 10)).astype(K.floatx()) + x = K.constant(x) + + quantizer_ref = ternary(alpha="auto") + quantizer = ternary(alpha="auto_po2") + + q_ref = K.eval(quantizer_ref(x)) + q = K.eval(quantizer(x)) + + ref = get_weight_scale(quantizer_ref, q_ref) + + expected = np.power(2.0, np.round(np.log2(ref))) + result = get_weight_scale(quantizer, q) + + assert_allclose(result, expected, rtol=0.0001) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/qconvolutional_test.py b/tests/qconvolutional_test.py index 90d1a0a9..075fc7c7 100644 --- a/tests/qconvolutional_test.py +++ b/tests/qconvolutional_test.py @@ -14,12 +14,15 @@ # limitations under the License. 
# ============================================================================== """Test layers from qconvolutional.py.""" - +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import os -import tempfile import numpy as np from numpy.testing import assert_allclose import pytest +import tempfile + from tensorflow.keras import backend as K from tensorflow.keras.layers import Activation from tensorflow.keras.layers import Flatten @@ -27,6 +30,8 @@ from tensorflow.keras.models import Model from tensorflow.keras.backend import clear_session +from qkeras import binary +from qkeras import ternary from qkeras import QActivation from qkeras import QDense from qkeras import QConv1D @@ -41,7 +46,7 @@ from qkeras import extract_model_operations -# TODO: +# TODO(hzhuang): # qoctave_conv test # qbatchnorm test @@ -50,9 +55,9 @@ def test_qnetwork(): x = QSeparableConv2D( 32, (2, 2), strides=(2, 2), - depthwise_quantizer="binary", - pointwise_quantizer=quantized_bits(4, 0, 1), - depthwise_activation=quantized_bits(6, 2, 1), + depthwise_quantizer=binary(alpha=1.0), + pointwise_quantizer=quantized_bits(4, 0, 1, alpha=1.0), + depthwise_activation=quantized_bits(6, 2, 1, alpha=1.0), bias_quantizer=quantized_bits(4, 0, 1), name='conv2d_0_m')( x) @@ -60,7 +65,7 @@ def test_qnetwork(): x = QConv2D( 64, (3, 3), strides=(2, 2), - kernel_quantizer="ternary", + kernel_quantizer=ternary(alpha=1.0), bias_quantizer=quantized_bits(4, 0, 1), name='conv2d_1_m', activation=quantized_relu(6, 3, 1))( @@ -68,7 +73,7 @@ def test_qnetwork(): x = QConv2D( 64, (2, 2), strides=(2, 2), - kernel_quantizer=quantized_bits(6, 2, 1), + kernel_quantizer=quantized_bits(6, 2, 1, alpha=1.0), bias_quantizer=quantized_bits(4, 0, 1), name='conv2d_2_m')( x) @@ -76,7 +81,7 @@ def test_qnetwork(): x = Flatten(name='flatten')(x) x = QDense( 10, - kernel_quantizer=quantized_bits(6, 2, 1), + kernel_quantizer=quantized_bits(6, 2, 1, alpha=1.0), bias_quantizer=quantized_bits(4, 0, 1), name='dense')( x) @@ -90,7 +95,6 @@ def test_qnetwork(): model = quantized_model_from_json(json_string) # generate same output for weights - np.random.seed(42) for layer in model.layers: all_weights = [] @@ -126,27 +130,27 @@ def test_qnetwork(): assert np.all(all_weights == all_weights_signature) # test_qnetwork_forward: - expected_output = np.array([[0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 0.e+00, 0.e+00, 6.e-08, 1.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00 ,0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 0.e+00, 0.e+00, 5.e-07, 1.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00 ,1.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00 ,0.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 1.e+00, - 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 1.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00]]).astype(np.float16) - + expected_output = np.array( + [[0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 0.e+00, 0.e+00, 6.e-08, 1.e+00], + [0.e+00, 
0.e+00, 0.e+00, 0.e+00, 0.e+00,
+       0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00],
+      [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00,
+       0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00],
+      [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00,
+       0.e+00, 0.e+00, 0.e+00, 5.e-07, 1.e+00],
+      [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00,
+       0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00],
+      [0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00,
+       0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00],
+      [0.e+00, 0.e+00, 0.e+00, 0.e+00, 1.e+00,
+       0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00],
+      [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00,
+       1.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00]]).astype(np.float16)
   inputs = 2 * np.random.rand(10, 28, 28, 1)
   actual_output = model.predict(inputs).astype(np.float16)
   assert_allclose(actual_output, expected_output, rtol=1e-4)
@@ -157,25 +161,25 @@ def test_qconv1d():
   x = Input((4, 4,))
   y = QConv1D(
       2, 1,
-      kernel_quantizer=quantized_bits(6, 2, 1),
+      kernel_quantizer=quantized_bits(6, 2, 1, alpha=1.0),
       bias_quantizer=quantized_bits(4, 0, 1),
       name='qconv1d')(
           x)
   model = Model(inputs=x, outputs=y)
 
-  #Extract model operations
+  # Extract model operations
   model_ops = extract_model_operations(model)
 
   # Assertion about the number of operations for this Conv1D layer
-  assert model_ops['qconv1d']["number_of_operations"] == 32
+  assert model_ops['qconv1d']['number_of_operations'] == 32
 
   # Print qstats to make sure it works with Conv1D layer
-  print_qstats(model) 
+  print_qstats(model)
 
   # reload the model to ensure saving/loading works
-  json_string = model.to_json()
-  clear_session()
-  model = quantized_model_from_json(json_string)
+  # json_string = model.to_json()
+  # clear_session()
+  # model = quantized_model_from_json(json_string)
 
   for layer in model.layers:
     all_weights = []
@@ -189,16 +193,15 @@ def test_qconv1d():
             10.0 * np.random.normal(0.0, np.sqrt(2.0 / input_size), shape))
     if all_weights:
       layer.set_weights(all_weights)
-
   # Save the model as an h5 file using Keras's model.save()
   fd, fname = tempfile.mkstemp('.h5')
   model.save(fname)
   del model  # Delete the existing model
 
-  # Returns a compiled model identical to the previous one
+  # Return a compiled model identical to the previous one
   model = load_qmodel(fname)
 
-  #Clean the created h5 file after loading the model
+  # Clean the created h5 file after loading the model
   os.close(fd)
   os.remove(fname)
 
@@ -207,12 +210,6 @@ def test_qconv1d():
   inputs = np.random.rand(2, 4, 4)
   p = model.predict(inputs).astype(np.float16)
 
-  '''
-  y = np.array([[[0.1309, -1.229], [-0.4165, -2.639], [-0.08105, -2.299],
-                 [1.981, -2.195]],
-                [[-0.3174, -3.94], [-0.3352, -2.316], [0.105, -0.833],
-                 [0.2115, -2.89]]]).astype(np.float16)
-  '''
   y = np.array([[[-2.441, 3.816], [-3.807, -1.426], [-2.684, -1.317],
                  [-1.659, 0.9834]],
                 [[-4.99, 1.139], [-2.559, -1.216], [-2.285, 1.905],
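The h5 round trip exercised in test_qconv1d is the supported way to persist quantized models; reduced to its essentials (a sketch; model stands for any QKeras model):

import os
import tempfile

from qkeras.utils import load_qmodel

fd, fname = tempfile.mkstemp(".h5")
model.save(fname)  # plain Keras save works on QKeras layers
model = load_qmodel(fname)  # re-registers the quantized custom objects
os.close(fd)
os.remove(fname)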
diff --git a/tests/qlayers_test.py b/tests/qlayers_test.py
index 1fc60ff8..5e1f0f84 100644
--- a/tests/qlayers_test.py
+++ b/tests/qlayers_test.py
@@ -14,10 +14,13 @@
 # limitations under the License.
 # ==============================================================================
 """Test layers from qlayers.py."""
-
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 import numpy as np
 from numpy.testing import assert_allclose
 import pytest
+import logging
 from tensorflow.keras import backend as K
 from tensorflow.keras.layers import Activation
 from tensorflow.keras.layers import Flatten
@@ -51,27 +54,34 @@ def qdense_util(layer_cls,
 
 
 @pytest.mark.parametrize(
-    'layer_kwargs, input_data, weight_data, bias_data, expected_output', [
-        ({
-            'units': 2,
-            'use_bias': True,
-            'kernel_initializer': 'glorot_uniform',
-            'bias_initializer': 'zeros'
-        }, np.array([[1, 1, 1, 1]], dtype=K.floatx()),
-         np.array([[10, 20], [10, 20], [10, 20], [10, 20]],
-                  dtype=K.floatx()), np.array([0, 0], dtype=K.floatx()),
-         np.array([[40, 80]], dtype=K.floatx())),
-        ({
-            'units': 2,
-            'use_bias': True,
-            'kernel_initializer': 'glorot_uniform',
-            'bias_initializer': 'zeros',
-            'kernel_quantizer': 'quantized_bits(2,0)',
-            'bias_quantizer': 'quantized_bits(2,0)',
-        }, np.array([[1, 1, 1, 1]], dtype=K.floatx()),
-         np.array([[10, 20], [10, 20], [10, 20], [10, 20]], dtype=K.floatx()),
-         np.array([0, 0], dtype=K.floatx()), np.array([[2, 2]],
-                                                      dtype=K.floatx())),
+    'layer_kwargs, input_data, weight_data, bias_data, expected_output',
+    [
+        (
+            {
+                'units': 2,
+                'use_bias': True,
+                'kernel_initializer': 'glorot_uniform',
+                'bias_initializer': 'zeros'
+            },
+            np.array([[1, 1, 1, 1]], dtype=K.floatx()),
+            np.array([[10, 20], [10, 20], [10, 20], [10, 20]],
+                     dtype=K.floatx()),  # weight_data
+            np.array([0, 0], dtype=K.floatx()),  # bias
+            np.array([[40, 80]], dtype=K.floatx())),  # expected_output
+        (
+            {
+                'units': 2,
+                'use_bias': True,
+                'kernel_initializer': 'glorot_uniform',
+                'bias_initializer': 'zeros',
+                'kernel_quantizer': 'quantized_bits(2,0,alpha=1.0)',
+                'bias_quantizer': 'quantized_bits(2,0)',
+            },
+            np.array([[1, 1, 1, 1]], dtype=K.floatx()),
+            np.array([[10, 20], [10, 20], [10, 20], [10, 20]],
+                     dtype=K.floatx()),  # weight_data
+            np.array([0, 0], dtype=K.floatx()),  # bias
+            np.array([[2, 2]], dtype=K.floatx())),  # expected_output
     ])
 def test_qdense(layer_kwargs, input_data, weight_data, bias_data,
                 expected_output):