From 244ebe82e09b8ce6743a15c399cc8b0107a4e3d7 Mon Sep 17 00:00:00 2001 From: Hao Zhuang Date: Tue, 24 Mar 2020 13:37:18 -0700 Subject: [PATCH] This patch is for better numerical behaviors: - Add alpha, get_scale function for binary and ternary quantizers. - Control the initial weight distribution with respect to fan-in of layers. PiperOrigin-RevId: 302736657 Change-Id: I296d8fe267f9fb47e45ad3bdddac1d7a332c155c --- CHANGELOG | 5 +- README.md | 1 + examples/example_ternary.py | 113 ++++ notebook/QKerasTutorial.ipynb | 869 +++++++++++++++++++++++++ notebook/images/figure1.png | Bin 0 -> 46857 bytes notebook/images/figure2.png | Bin 0 -> 42218 bytes qkeras/__init__.py | 2 +- qkeras/qconvolutional.py | 231 ++++--- qkeras/qlayers.py | 100 +-- qkeras/qnormalization.py | 88 ++- qkeras/qpooling.py | 5 +- qkeras/quantizers.py | 980 ++++++++++++++++++++++++----- qkeras/safe_eval.py | 10 +- qkeras/utils.py | 318 +++++----- requirements.txt | 3 +- setup.py | 3 +- tests/automatic_conversion_test.py | 95 +++ tests/qactivation_test.py | 68 +- tests/qalpha_test.py | 127 ++++ tests/qconvolutional_test.py | 91 ++- tests/qlayers_test.py | 54 +- 21 files changed, 2604 insertions(+), 559 deletions(-) create mode 100644 examples/example_ternary.py create mode 100644 notebook/QKerasTutorial.ipynb create mode 100644 notebook/images/figure1.png create mode 100644 notebook/images/figure2.png create mode 100644 tests/automatic_conversion_test.py create mode 100644 tests/qalpha_test.py diff --git a/CHANGELOG b/CHANGELOG index 69c4cc86..398ad913 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,2 +1,3 @@ -v0.5, 07/18/2019 -- Initial release. -v0.6, 12/03/2019 -- Support tensorflow 2.0 and tf.keras +v0.5, 2019/07 -- Initial release. +v0.6, 2020/03 -- Support tensorflow 2.0, tf.keras and python3. +v0.7, 2020/03 -- Enhancemence of binary and ternary quantization. diff --git a/README.md b/README.md index 4864114b..060afb62 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,7 @@ that accept an alpha parameter, we need to specify a range of alpha, and for po2 type of quantizers, we need to specify the range of max_value. + ### Example Suppose you have the following network. diff --git a/examples/example_ternary.py b/examples/example_ternary.py new file mode 100644 index 00000000..9f11dc66 --- /dev/null +++ b/examples/example_ternary.py @@ -0,0 +1,113 @@ +# Copyright 2020 Google LLC +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import # Not necessary in a Python 3-only module +from __future__ import division # Not necessary in a Python 3-only module +from __future__ import google_type_annotations # Not necessary in a Python 3-only module +from __future__ import print_function # Not necessary in a Python 3-only module + +from absl import app +from absl import flags +import matplotlib +import numpy as np + +matplotlib.use('TkAgg') +import matplotlib.pyplot as plt + + +FLAGS = flags.FLAGS + + +def _stochastic_rounding(x, precision, resolution, delta): + """Stochastic_rounding for ternary. + + Args: + x: + precision: A float. The area we want to make this stochastic rounding. + [delta-precision, delta] [delta, delta+precision] + resolution: control the quantization resolution. + delta: the undiscountinued point (positive number) + + Return: + A tensor with stochastic rounding numbers. + """ + delta_left = delta - precision + delta_right = delta + precision + scale = 1 / resolution + scale_delta_left = delta_left * scale + scale_delta_right = delta_right * scale + scale_2_delta = scale_delta_right - scale_delta_left + scale_x = x * scale + fraction = scale_x - scale_delta_left + # print(precision, scale, x[0], np.floor(scale_x[0]), scale_x[0], fraction[0]) + + # we use uniform distribution + random_selector = np.random.uniform(0, 1, size=x.shape) * scale_2_delta + + # print(precision, scale, x[0], delta_left[0], delta_right[0]) + # print('x', scale_x[0], fraction[0], random_selector[0], scale_2_delta[0]) + # rounddown = fraction < random_selector + result = np.where(fraction < random_selector, + scale_delta_left / scale, + scale_delta_right / scale) + return result + + +def _ternary(x, sto=False): + m = np.amax(np.abs(x), keepdims=True) + scale = 2 * m / 3.0 + thres = scale / 2.0 + ratio = 0.1 + + if sto: + sign_bit = np.sign(x) + x = np.abs(x) + prec = x / scale + x = ( + sign_bit * scale * _stochastic_rounding( + x / scale, + precision=0.3, resolution=0.01, # those two are all normalized. 
+ delta=thres / scale)) + # prec + prec *ratio) + # mm = np.amax(np.abs(x), keepdims=True) + return np.where(np.abs(x) < thres, np.zeros_like(x), np.sign(x)) + + +def main(argv): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + # x = np.arange(-3.0, 3.0, 0.01) + # x = np.random.uniform(-0.01, 0.01, size=1000) + x = np.random.uniform(-10.0, 10.0, size=1000) + # x = np.random.uniform(-1, 1, size=1000) + x = np.sort(x) + tr = np.zeros_like(x) + t = np.zeros_like(x) + iter_count = 500 + for _ in range(iter_count): + y = _ternary(x) + yr = _ternary(x, sto=True) + t = t + y + tr = tr + yr + + plt.plot(x, t/iter_count) + plt.plot(x, tr/iter_count) + plt.ylabel('mean (%s samples)' % iter_count) + plt.show() + + +if __name__ == '__main__': + app.run(main) diff --git a/notebook/QKerasTutorial.ipynb b/notebook/QKerasTutorial.ipynb new file mode 100644 index 00000000..cc115ea7 --- /dev/null +++ b/notebook/QKerasTutorial.ipynb @@ -0,0 +1,869 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "##### Copyright 2020 Google LLC\n", + "#\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# QKeras Lab Book" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "__QKeras__ is a quantization extension to Keras that provides drop-in replacement for some of the Keras layers, especially the ones that creates parameters and activation layers, and perform arithmetic operations, so that we can quickly create a deep quantized version of Keras network.\n", + "\n", + "According to Tensorflow documentation, Keras is a high-level API to build and train deep learning models. It's used for fast prototyping, advanced research, and production, with three key advantages:\n", + "\n", + "- User friendly
\n", + "Keras has a simple, consistent interface optimized for common use cases. It provides clear and actionable feedback for user errors.\n", + "\n", + "- Modular and composable
\n", + "Keras models are made by connecting configurable building blocks together, with few restrictions.\n", + "\n", + "- Easy to extend
\n", + "Write custom building blocks to express new ideas for research. Create new layers, loss functions, and develop state-of-the-art models.\n", + "\n", + "__QKeras__ is being designed to extend the functionality of Keras using Keras' design principle, i.e. being user friendly, modular and extensible, adding to it being \"minimally intrusive\" of Keras native functionality.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Related Work\n", + "\n", + "__QKeras__ has been implemented based on the work of _\"B.Moons et al. - Minimum Energy Quantized Neural Networks\"_ , Asilomar Conference on Signals, Systems and Computers, 2017 and _“Zhou, S. et al. DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients,”_ but the framework should be easily extensible. The original code from QNN can be found below.\n", + "\n", + "https://github.com/BertMoons/QuantizedNeuralNetworks-Keras-Tensorflow\n", + "\n", + "__QKeras__ extends QNN by providing a richer set of layers (including SeparableConv2D, DepthwiseConv2D, ternary and stochastic ternary quantizations), besides some functions to aid the estimation for the accumulators and conversion between non-quantized to quantized networks. Finally, our main goal is easy of use, so we attempt to make QKeras layers a true drop-in replacement for Keras, so that users can easily exchange non-quantized layers by quantized ones." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Layers Implemented in QKeras\n", + "\n", + "The following layers have been implemented in __QKeras__.\n", + "\n", + "- __`QDense`__\n", + "\n", + "- __`QConv1D`__\n", + "\n", + "- __`QConv2D`__\n", + "\n", + "- __`QDepthwiseConv2D`__\n", + "\n", + "- __`QSeparableConv2D`__ (depthwise + pointwise expanded, extended from MobileNet SeparableConv2D implementation)\n", + "\n", + "- __`QActivation`__\n", + "\n", + "- __`QAveragePooling2D`__ (in fact, a AveragePooling2D stacked with a QActivation layer for quantization of the result, so this layer does not exist)\n", + "\n", + "- __`QBatchNormalization`__\n", + "\n", + "- __`QOctaveConv2D`__\n", + "\n", + "It is worth noting that not all functionality is safe at this time to be used with other high-level operations, such as with layer wrappers. For example, `Bidirectional` layer wrappers are used with RNNs. If this is required, we encourage users to use quantization functions invoked as strings instead of the actual functions as a way through this, but we may change that implementation in the future.\n", + "\n", + "__`QSeparableConv2D`__ is implemented as a depthwise + pointwise quantized expansions, which is extended from the `SeparableConv2D` implementation of MobileNet. With the exception of __`QBatchNormalization`__, if quantizers are not specified, no quantization is applied to the layer and it ends up behaving like the orgininal unquantized layers. On the other hand, __`QBatchNormalization`__ has been implemented differently as if the user does not specify any quantizers as parameters, it uses a set up that has worked best when attempting to implement quantization efficiently in hardware and software, i.e. 
`gamma` and `variance` with po2 quantizers (as they become shift registers in an implementation, with the variance po2 quantizer further constrained to use a quadratic approximation because we take the square root of the variance to obtain the standard deviation), `beta` using a po2 quantizer to maintain the dynamic range aspect of the center parameter, and `mean` remaining unquantized, as it inherits the properties of the previous layer.\n", + "\n", + "The activation functionality has been migrated to __`QActivation`__, although __QKeras__ also recognizes the `activation` parameter used in convolutional and dense layers.\n", + "\n", + "We have improved the setup of quantization: convolution, dense and batch normalization layers now notify the quantizers when they are used as internal parameters, so the user does not need to worry about setting up the options that work best for `weights` and `bias`, such as `alpha` and `use_stochastic_rounding` (although users may override the automatic setup).\n", + "\n", + "Finally, in the current version, we have eliminated the need to set up the range of the quantizers, like `kernel_range` in __`QDense`__. This is now computed automatically and internally. Although we kept these parameters for backward compatibility, they will be removed in the future." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Activation Layers and Quantizers Implemented in __QKeras__\n", + "\n", + "Quantizers and activation layers are treated interchangeably in __QKeras__.\n", + "\n", + "The quantizers and their parameters are listed below.\n", + "\n", + "- __`smooth_sigmoid(x)`__\n", + "\n", + "- __`hard_sigmoid(x)`__\n", + "\n", + "- __`binary_sigmoid(x)`__\n", + "\n", + "- __`smooth_tanh(x)`__\n", + "\n", + "- __`hard_tanh(x)`__\n", + "\n", + "- __`binary_tanh(x)`__\n", + "\n", + "- __`quantized_bits(bits=8, integer=0, symmetric=0, keep_negative=1, alpha=None, use_stochastic_rounding=False)(x)`__\n", + "\n", + "- __`bernoulli(alpha=None, temperature=6.0, use_real_sigmoid=True)(x)`__\n", + "\n", + "- __`stochastic_ternary(alpha=None, threshold=None, temperature=8.0, use_real_sigmoid=True)(x)`__\n", + "\n", + "- __`ternary(alpha=None, threshold=None, use_stochastic_rounding=False)(x)`__\n", + "\n", + "- __`stochastic_binary(alpha=None, temperature=6.0, use_real_sigmoid=True)(x)`__\n", + "\n", + "- __`binary(use_01=False, alpha=None, use_stochastic_rounding=False)(x)`__\n", + "\n", + "- __`quantized_relu(bits=8, integer=0, use_sigmoid=0, use_stochastic_rounding=False)(x)`__\n", + "\n", + "- __`quantized_ulaw(bits=8, integer=0, symmetric=0, u=255.0)(x)`__\n", + "\n", + "- __`quantized_tanh(bits=8, integer=0, symmetric=0, use_stochastic_rounding=False)(x)`__\n", + "\n", + "- __`quantized_po2(bits=8, max_value=None, use_stochastic_rounding=False, quadratic_approximation=False)(x)`__\n", + "\n", + "- __`quantized_relu_po2(bits=8, max_value=None, use_stochastic_rounding=False, quadratic_approximation=False)(x)`__\n", + "\n", + "The __`stochastic_*`__ functions and __`bernoulli`__ rely on stochastic versions of the activation functions, so they are best suited for weights and biases. They draw a random number from a uniform distribution, compare it against the `sigmoid` of the input x, and the result is based on the expected value of the activation function. Please refer to the papers, or to the documentation in qkeras/quantizers.py, if you want to understand the underlying theory. 
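A rough sketch of this idea is shown below. It is an illustration only, not the QKeras implementation (which also handles the scale/alpha adjustment), and the helper name `stochastic_binary_draw` is made up for this example: each draw is +/-1, but the average over many draws follows the smooth expected value.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

# Illustrative sketch of a stochastic binary quantizer: the probability of
# emitting +1 grows with sigmoid(temperature * x), so a single draw is +/-1
# while the expected output follows a smooth curve.
def stochastic_binary_draw(x, temperature=6.0):
    p = sigmoid(temperature * x)                       # probability of returning +1
    return np.where(np.random.uniform(size=x.shape) < p, 1.0, -1.0)

x = np.sort(np.random.normal(0.0, 0.02, size=1000))
mean_draw = np.mean([stochastic_binary_draw(x) for _ in range(500)], axis=0)
# mean_draw approximates 2 * sigmoid(temperature * x) - 1 = tanh(temperature * x / 2)
```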
The parameter `temperature` determines how steep the sigmoid function behaves, and the default values seem to work fine.\n", + "\n", + "As we lower the number of bits, rounding becomes problematic because it adds a bias to the number system. Numpy attempts to reduce the effect of this bias by rounding to even instead of rounding to infinity. Recent results (_\"Suyog Gupta, Ankur Agrawal, Kailash Gopalakrishnan, Pritish Narayanan; Deep Learning with Limited Numerical Precision\"_ [https://arxiv.org/abs/1502.02551]) suggested using stochastic rounding, which uses the fractional part of the number as a probability to round up or down. We can turn on stochastic rounding in some quantizers by setting `use_stochastic_rounding` to `True` in __`quantized_bits`__, __`binary`__, __`ternary`__, __`quantized_relu`__, __`quantized_tanh`__, __`quantized_po2`__, and __`quantized_relu_po2`__. Please note that if one is considering an efficient hardware or software implementation, we should avoid setting this flag to `True` in activations, as it may affect the efficiency of the implementation. In addition, as mentioned before, we already set this flag to `True` in some quantized layers when the quantizers are used as weights/biases.\n", + "\n", + "The parameter `bits` specifies the number of bits for the quantization, and `integer` specifies how many bits of `bits` are to the left of the decimal point. Finally, in our experience training networks with __`QSeparableConv2D`__, it is advisable to allocate more bits between the depthwise and the pointwise quantization, and both __`quantized_bits`__ and __`quantized_tanh`__ should use symmetric versions for weights and bias in order to properly converge and eliminate the bias.\n", + "\n", + "We have substantially improved the stochastic rounding implementation in __QKeras__ $>= 0.7$, and added a symbolic way to compute alpha in __`binary`__, __`stochastic_binary`__, __`ternary`__, __`stochastic_ternary`__, __`bernoulli`__ and __`quantized_bits`__. Right now, the scale and the threshold (for ternary and stochastic_ternary) can be computed independently of the distribution of the inputs, which is required when using these quantizers on weights.\n", + "\n", + "The main problem in using very small bit widths in large deep learning networks stems from the fact that weights are initialized with variance roughly $\\propto \\sqrt{1/\\tt{fanin}}$, but during training the variance shifts outwards. If the smallest quantization representation (the threshold in ternary networks) is smaller than $\\sqrt{1/\\tt{fanin}}$, we run the risk of having the weights stuck at 0 during training. So the weights need to dynamically adjust to the variance shift from initialization to the end of training. This can be done by scaling the quantization. \n", + "\n", + "The scale is computed using the formula $\\sum(\\tt{dot}(Q,x))/\\sum(\\tt{dot}(Q,Q))$, which is described in several papers, including _Mohammad Rastegari, Vicente Ordonez, Joseph Redmon, Ali Farhadi, \"XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks\"_ [https://arxiv.org/abs/1603.05279]. The scale is computed for each output channel, which makes our implementation sometimes behave like a mini-batch normalization adjustment. \n", + "\n", + "For __`ternary`__ and __`stochastic_ternary`__, we iterate between scale computation and threshold computation, as presented in _K. Hwang and W. 
Sung, \"Fixed-point feedforward deep neural network design using weights +1, 0, and −1,\" 2014 IEEE Workshop on Signal Processing Systems (SiPS), Belfast, 2014, pp. 1-6_ which makes the search for threshold and scale tolerant to different input distributions. This is especially important when we need to consider that the threshold shifts depending on the input distribution, affecting the scale as well, as pointed out by _Fengfu Li, Bo Zhang, Bin Liu, \"Ternary Weight Networks\"_ [https://arxiv.org/abs/1605.04711]. \n", + "\n", + "When computing the scale in these quantizers, if `alpha=\"auto\"`, we compute the scale as a floating point number. If `alpha=\"auto_po2\"`, we enforce the scale to be a power of 2, meaning that an actual hardware or software implementation can be performed by just shifting the result of the convolution or dense layer to the right or left by checking the sign of the scale (positive shifts left, negative shifts right), and taking the log2 of the scale. This behavior is compatible with shared exponent approaches, as it performs a shift adjustment to the channel.\n", + "\n", + "We have implemented a method for each quantizer called __`_set_trainable_parameter`__ that instructs __QKeras__ to set best options when this quantizer is used as a weight or for gamma, variance and beta in __`QBatchNormalization`__, so in principle, users should not worry about this.\n", + "\n", + "The following pictures show the behavior of __`binary`__ vs stochastic rounding in __`binary`__ vs __`stochastic_binary`__ (Figure 1) and __`ternary`__ vs stochastic rounding in __`ternary`__ and __`stochastic_ternary`__ (Figure 2). We generated a normally distributed input with mean 0.0 and standard deviation of 0.02, ordered the data, and ran the quantizer 1,000 times, averaging the result for each case. Note that because of scale, the output does not range from $[-1.0, +1.0]$, but from $[-\\tt{scale}, +\\tt{scale}]$.\n", + "\n", + "\n", + "\"Binary
Figure 1: Behavior of binary quantizers
\n", + "\"Ternary
Figure 2: Behavior of ternary quantizers
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using QKeras\n", + "\n", + "__QKeras__ works by tagging all variables and weights/bias created by Keras as well as output of arithmetic layers by quantized functions. Quantized functions can be instantiated directly in __`QDense`__/__`QConv2D`__/__`QSeparableConv2D`__ functions, and they can be passed to __`QActivation`__, which act as a merged quantization and activation function.\n", + "\n", + "In order to successfully quantize a model, users need to replace layers that create variables (trainable or not) (`Dense`, `Conv2D`, etc) by their equivalent ones in __Qkeras__ (__`QDense`__, __`QConv2D`__, etc), and any layers that perform math operations need to be quantized afterwards.\n", + "\n", + "Quantized values are clipped between their maximum and minimum quantized representation (which may be different than $[-1.0, 1.0]$), although for `po2` type of quantizers, we still recommend the users to specify the parameter for `max_value`.\n", + "\n", + "An example of a very simple network is given below in Keras." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", + " from ._conv import register_converters as _register_converters\n" + ] + } + ], + "source": [ + "import six\n", + "import numpy as np\n", + "import tensorflow.compat.v2 as tf\n", + "\n", + "from tensorflow.keras.layers import *\n", + "from tensorflow.keras.models import Model\n", + "from tensorflow.keras.datasets import mnist\n", + "from tensorflow.keras.utils import to_categorical" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def CreateModel(shape, nb_classes):\n", + " x = x_in = Input(shape)\n", + " x = Conv2D(18, (3, 3), name=\"conv2d_1\")(x)\n", + " x = Activation(\"relu\", name=\"act_1\")(x)\n", + " x = Conv2D(32, (3, 3), name=\"conv2d_2\")(x)\n", + " x = Activation(\"relu\", name=\"act_2\")(x)\n", + " x = Flatten(name=\"flatten\")(x)\n", + " x = Dense(nb_classes, name=\"dense\")(x)\n", + " x = Activation(\"softmax\", name=\"softmax\")(x)\n", + " \n", + " model = Model(inputs=x_in, outputs=x)\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def get_data():\n", + " (x_train, y_train), (x_test, y_test) = mnist.load_data()\n", + " x_train = x_train.reshape(x_train.shape + (1,)).astype(\"float32\")\n", + " x_test = x_test.reshape(x_test.shape + (1,)).astype(\"float32\")\n", + "\n", + " x_train /= 256.0\n", + " x_test /= 256.0\n", + "\n", + " x_mean = np.mean(x_train, axis=0)\n", + "\n", + " x_train -= x_mean\n", + " x_test -= x_mean\n", + "\n", + " nb_classes = np.max(y_train)+1\n", + " y_train = to_categorical(y_train, nb_classes)\n", + " y_test = to_categorical(y_test, nb_classes)\n", + "\n", + " return (x_train, y_train), (x_test, y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "(x_train, y_train), (x_test, y_test) = get_data()\n", + "\n", + "model = CreateModel(x_train.shape[1:], y_train.shape[-1])" + ] + }, + { + "cell_type": "code", + 
"execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 60000 samples, validate on 10000 samples\n", + "Epoch 1/3\n", + "60000/60000 [==============================] - 7s 122us/sample - loss: 0.2103 - accuracy: 0.9393 - val_loss: 0.0685 - val_accuracy: 0.9781\n", + "Epoch 2/3\n", + "60000/60000 [==============================] - 6s 108us/sample - loss: 0.0642 - accuracy: 0.9808 - val_loss: 0.0575 - val_accuracy: 0.9817\n", + "Epoch 3/3\n", + "60000/60000 [==============================] - 7s 112us/sample - loss: 0.0457 - accuracy: 0.9860 - val_loss: 0.0502 - val_accuracy: 0.9844\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fit(x_train, y_train, epochs=3, batch_size=128, validation_data=(x_test, y_test), verbose=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Great! it is relatively easy to create a network that converges in MNIST with very high test accuracy. The reader should note that we named all the layers as it will make it easier to automatically convert the network by name." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_1 (InputLayer) [(None, 28, 28, 1)] 0 \n", + "_________________________________________________________________\n", + "conv2d_1 (Conv2D) (None, 26, 26, 18) 180 \n", + "_________________________________________________________________\n", + "act_1 (Activation) (None, 26, 26, 18) 0 \n", + "_________________________________________________________________\n", + "conv2d_2 (Conv2D) (None, 24, 24, 32) 5216 \n", + "_________________________________________________________________\n", + "act_2 (Activation) (None, 24, 24, 32) 0 \n", + "_________________________________________________________________\n", + "flatten (Flatten) (None, 18432) 0 \n", + "_________________________________________________________________\n", + "dense (Dense) (None, 10) 184330 \n", + "_________________________________________________________________\n", + "softmax (Activation) (None, 10) 0 \n", + "=================================================================\n", + "Total params: 189,726\n", + "Trainable params: 189,726\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The corresponding quantized network is presented below." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from qkeras import *\n", + "\n", + "def CreateQModel(shape, nb_classes):\n", + " x = x_in = Input(shape)\n", + " x = QConv2D(18, (3, 3),\n", + " kernel_quantizer=\"stochastic_ternary\", \n", + " bias_quantizer=\"quantized_po2(4)\",\n", + " name=\"conv2d_1\")(x)\n", + " x = QActivation(\"quantized_relu(2)\", name=\"act_1\")(x)\n", + " x = QConv2D(32, (3, 3), \n", + " kernel_quantizer=\"stochastic_ternary\", \n", + " bias_quantizer=\"quantized_po2(4)\",\n", + " name=\"conv2d_2\")(x)\n", + " x = QActivation(\"quantized_relu(2)\", name=\"act_2\")(x)\n", + " x = Flatten(name=\"flatten\")(x)\n", + " x = QDense(nb_classes,\n", + " kernel_quantizer=\"quantized_bits(3,0,1)\",\n", + " bias_quantizer=\"quantized_bits(3)\",\n", + " name=\"dense\")(x)\n", + " x = Activation(\"softmax\", name=\"softmax\")(x)\n", + " \n", + " model = Model(inputs=x_in, outputs=x)\n", + " \n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "in converted code:\n\n build/bdist.linux-x86_64/egg/qkeras/qlayers.py:1150 call *\n outputs = tf.keras.backend.conv2d(\n /usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/keras/backend.py:4889 conv2d\n data_format=tf_data_format)\n /usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/ops/nn_ops.py:899 convolution\n name=name)\n /usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/ops/nn_ops.py:1010 convolution_internal\n name=name)\n /usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/ops/gen_nn_ops.py:969 conv2d\n data_format=data_format, dilations=dilations, name=name)\n /usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/framework/op_def_library.py:477 _apply_op_helper\n repr(values), type(values).__name__, err))\n\n TypeError: Expected float32 passed to parameter 'filter' of op 'Conv2D', got of type 'stochastic_ternary' instead. 
Error: Expected float32, got of type 'stochastic_ternary' instead.\n", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mqmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCreateQModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mCreateQModel\u001b[0;34m(shape, nb_classes)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mkernel_quantizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"stochastic_ternary\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mbias_quantizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"quantized_po2(4)\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m name=\"conv2d_1\")(x)\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mQActivation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"quantized_relu(2)\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"act_1\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m x = QConv2D(32, (3, 3), \n", + "\u001b[0;32m/usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/base_layer.pyc\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 771\u001b[0m not base_layer_utils.is_in_eager_or_tf_function()):\n\u001b[1;32m 772\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mauto_control_deps\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAutomaticControlDependencies\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0macd\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 773\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcast_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 774\u001b[0m \u001b[0;31m# Wrap Tensors in `outputs` in `tf.identity` to avoid\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 775\u001b[0m \u001b[0;31m# circular dependencies.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/autograph/impl/api.pyc\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 235\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint:disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'ag_error_metadata'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 237\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mag_error_metadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 238\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0;32mraise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: in converted code:\n\n build/bdist.linux-x86_64/egg/qkeras/qlayers.py:1150 call *\n outputs = tf.keras.backend.conv2d(\n /usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/keras/backend.py:4889 conv2d\n data_format=tf_data_format)\n /usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/ops/nn_ops.py:899 convolution\n name=name)\n /usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/ops/nn_ops.py:1010 convolution_internal\n name=name)\n /usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/ops/gen_nn_ops.py:969 conv2d\n data_format=data_format, dilations=dilations, name=name)\n /usr/local/google/home/hzhuang/anaconda2/lib/python2.7/site-packages/tensorflow_core/python/framework/op_def_library.py:477 _apply_op_helper\n repr(values), type(values).__name__, err))\n\n TypeError: Expected float32 passed to parameter 'filter' of op 'Conv2D', got of type 'stochastic_ternary' instead. Error: Expected float32, got of type 'stochastic_ternary' instead.\n" + ] + } + ], + "source": [ + "qmodel = CreateQModel(x_train.shape[1:], y_train.shape[-1])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'qmodel' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtensorflow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAdam\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m qmodel.compile(\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"categorical_crossentropy\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mAdam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.0005\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'qmodel' is not defined" + ] + } + ], + "source": [ + "from tensorflow.keras.optimizers import Adam\n", + "\n", + "qmodel.compile(\n", + " loss=\"categorical_crossentropy\",\n", + " optimizer=Adam(0.0005),\n", + " metrics=[\"accuracy\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Train on 60000 samples, validate on 10000 samples\n", + "Epoch 1/10\n", + "60000/60000 [==============================] - 52s 869us/sample - loss: 0.5034 - accuracy: 0.8428 - val_loss: 0.2422 - val_accuracy: 0.9276\n", + "Epoch 2/10\n", + "60000/60000 [==============================] - 49s 813us/sample - loss: 0.2080 - accuracy: 0.9371 - val_loss: 0.1981 - val_accuracy: 0.9415\n", + "Epoch 3/10\n", + "60000/60000 [==============================] - 50s 832us/sample - loss: 0.1703 - accuracy: 0.9503 - val_loss: 0.1454 - val_accuracy: 0.9571\n", + "Epoch 4/10\n", + "60000/60000 [==============================] - 49s 813us/sample - loss: 0.1448 - accuracy: 0.9573 - val_loss: 0.1296 - val_accuracy: 0.9616\n", + "Epoch 5/10\n", + "60000/60000 [==============================] - 48s 806us/sample - loss: 0.1225 - accuracy: 0.9639 - val_loss: 0.1267 - val_accuracy: 0.9636\n", + "Epoch 6/10\n", + "60000/60000 [==============================] - 49s 818us/sample - loss: 0.1074 - accuracy: 0.9686 - val_loss: 0.1092 - val_accuracy: 0.9662\n", + "Epoch 7/10\n", + "60000/60000 [==============================] - 48s 801us/sample - loss: 0.1020 - accuracy: 0.9698 - val_loss: 0.1041 - val_accuracy: 0.9718\n", + "Epoch 8/10\n", + "60000/60000 [==============================] - 50s 825us/sample - loss: 0.0979 - accuracy: 0.9712 - val_loss: 0.1175 - val_accuracy: 0.9662\n", + "Epoch 9/10\n", + "60000/60000 [==============================] - 49s 823us/sample - loss: 0.0910 - accuracy: 0.9726 - val_loss: 0.1015 - val_accuracy: 0.9711\n", + "Epoch 10/10\n", + "60000/60000 [==============================] - 50s 836us/sample - loss: 0.0845 - accuracy: 0.9755 - val_loss: 0.1044 - val_accuracy: 0.9710\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qmodel.fit(x_train, y_train, epochs=10, batch_size=128, validation_data=(x_test, y_test), verbose=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You should note that we had to lower the learning rate and train the network for longer time. On the other hand, the network should not involve in any multiplications in the convolution layers, and very small multipliers in the dense layers.\n", + "\n", + "Please note that the last `Activation` was not changed to __`QActivation`__ as during inference we usually perform the operation `argmax` on the result instead of `softmax`.\n", + "\n", + "It seems it is a lot of code to write besides the main network, but in fact, this additional code is only specifying the sizes of the weights and the sizes of the outputs in the case of the activations. Right now, we do not have a way to extract this information from the network structure or problem we are trying to solve, and if we quantize too much a layer, we may end up not been able to recover from that later on." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting a Model Automatically\n", + "\n", + "In addition to the drop-in replacement of Keras functions, we have written the following function to assist anyone who wants to quantize a network.\n", + "\n", + "__`model_quantize(model, quantizer_config, activation_bits, custom_objects=None, transfer_weights=False)`__\n", + "\n", + "This function converts an non-quantized model (such as the one from `model` in the previous example) into a quantized version, by applying a configuration specified by the dictionary `quantizer_config`, and `activation_bits` specified for unamed activation functions, with this parameter probably being removed in future versions.\n", + "\n", + "The parameter `custom_objects` specifies object dictionary unknown to Keras, required when you copy a model with lambda layers, or customized layer functions, for example, and if `transfer_weights` is `True`, the returned model will have as initial weights the weights from the original model, instead of using random initial weights.\n", + "\n", + "The dictionary specified in `quantizer_config` can be indexed by a layer name or layer class name. In the example below, conv2d_1 corresponds to the first convolutional layer of the example, while QConv2D corresponds to the default behavior of two dimensional convolutional layers. The reader should note that right now we recommend using __`QActivation`__ with a dictionary to avoid the conversion of activations such as `softmax` and `linear`. In addition, although we could use `activation` field in the layers, we do not recommend that. \n", + "\n", + "`{\n", + " \"conv2d_1\": {\n", + " \"kernel_quantizer\": \"stochastic_ternary\",\n", + " \"bias_quantizer\": \"quantized_po2(4)\"\n", + " },\n", + " \"QConv2D\": {\n", + " \"kernel_quantizer\": \"stochastic_ternary\",\n", + " \"bias_quantizer\": \"quantized_po2(4)\"\n", + " },\n", + " \"QDense\": {\n", + " \"kernel_quantizer\": \"quantized_bits(3,0,1)\",\n", + " \"bias_quantizer\": \"quantized_bits(3)\"\n", + " },\n", + " \"act_1\": \"quantized_relu(2)\",\n", + " \"QActivation\": { \"relu\": \"quantized_relu(2)\" }\n", + "}`\n", + "\n", + "In the following example, we will quantize the model using a different strategy.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [], + "source": [ + "config = {\n", + " \"conv2d_1\": {\n", + " \"kernel_quantizer\": \"stochastic_binary\",\n", + " \"bias_quantizer\": \"quantized_po2(4)\"\n", + " },\n", + " \"QConv2D\": {\n", + " \"kernel_quantizer\": \"stochastic_ternary\",\n", + " \"bias_quantizer\": \"quantized_po2(4)\"\n", + " },\n", + " \"QDense\": {\n", + " \"kernel_quantizer\": \"quantized_bits(4,0,1)\",\n", + " \"bias_quantizer\": \"quantized_bits(4)\"\n", + " },\n", + " \"QActivation\": { \"relu\": \"binary\" },\n", + " \"act_2\": \"quantized_relu(3)\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "conv2d_1 kernel: stochastic_binary(alpha=auto_po2) bias: quantized_po2(4)\n", + "act_1 quantizer: binary()\n", + "conv2d_2 kernel: stochastic_ternary(alpha=auto_po2,threshold=0.33) bias: quantized_po2(4)\n", + "act_2 quantizer: quantized_relu(3,0)\n", + "dense kernel: quantized_bits(4,0,1,alpha=auto_po2,use_stochastic_rounding=1) bias: quantized_bits(4,0,1,alpha=auto_po2,use_stochastic_rounding=1)\n", + "\n", + "Model: \"model_1\"\n", + 
"_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_2 (InputLayer) [(None, 28, 28, 1)] 0 \n", + "_________________________________________________________________\n", + "conv2d_1 (QConv2D) (None, 26, 26, 18) 180 \n", + "_________________________________________________________________\n", + "act_1 (QActivation) (None, 26, 26, 18) 0 \n", + "_________________________________________________________________\n", + "conv2d_2 (QConv2D) (None, 24, 24, 32) 5216 \n", + "_________________________________________________________________\n", + "act_2 (QActivation) (None, 24, 24, 32) 0 \n", + "_________________________________________________________________\n", + "flatten (Flatten) (None, 18432) 0 \n", + "_________________________________________________________________\n", + "dense (QDense) (None, 10) 184330 \n", + "_________________________________________________________________\n", + "softmax (Activation) (None, 10) 0 \n", + "=================================================================\n", + "Total params: 189,726\n", + "Trainable params: 189,726\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "from qkeras.utils import model_quantize\n", + "\n", + "qmodel = model_quantize(model, config, 4, transfer_weights=True)\n", + "\n", + "for layer in qmodel.layers:\n", + " if hasattr(layer, \"kernel_quantizer\"):\n", + " print(layer.name, \"kernel:\", str(layer.kernel_quantizer_internal), \"bias:\", str(layer.bias_quantizer_internal))\n", + " elif hasattr(layer, \"quantizer\"):\n", + " print(layer.name, \"quantizer:\", str(layer.quantizer))\n", + "\n", + "print()\n", + "qmodel.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [], + "source": [ + "qmodel.compile(\n", + " loss=\"categorical_crossentropy\",\n", + " optimizer=Adam(0.001),\n", + " metrics=[\"accuracy\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 60000 samples, validate on 10000 samples\n", + "Epoch 1/10\n", + "60000/60000 [==============================] - 43s 722us/sample - loss: 0.0796 - accuracy: 0.9757 - val_loss: 0.0968 - val_accuracy: 0.9725\n", + "Epoch 2/10\n", + "60000/60000 [==============================] - 44s 726us/sample - loss: 0.0717 - accuracy: 0.9776 - val_loss: 0.0853 - val_accuracy: 0.9747\n", + "Epoch 3/10\n", + "60000/60000 [==============================] - 45s 748us/sample - loss: 0.0690 - accuracy: 0.9789 - val_loss: 0.0878 - val_accuracy: 0.9740\n", + "Epoch 4/10\n", + "60000/60000 [==============================] - 44s 733us/sample - loss: 0.0636 - accuracy: 0.9800 - val_loss: 0.0816 - val_accuracy: 0.9753\n", + "Epoch 5/10\n", + "60000/60000 [==============================] - 45s 758us/sample - loss: 0.0572 - accuracy: 0.9822 - val_loss: 0.0861 - val_accuracy: 0.9753\n", + "Epoch 6/10\n", + "60000/60000 [==============================] - 44s 734us/sample - loss: 0.0565 - accuracy: 0.9819 - val_loss: 0.0819 - val_accuracy: 0.9763\n", + "Epoch 7/10\n", + "60000/60000 [==============================] - 46s 765us/sample - loss: 0.0489 - accuracy: 0.9842 - val_loss: 0.0859 - val_accuracy: 0.9758\n", + "Epoch 8/10\n", + "60000/60000 [==============================] - 47s 783us/sample - loss: 0.0485 - accuracy: 
0.9845 - val_loss: 0.0889 - val_accuracy: 0.9771\n", + "Epoch 9/10\n", + "60000/60000 [==============================] - 44s 737us/sample - loss: 0.0477 - accuracy: 0.9850 - val_loss: 0.0729 - val_accuracy: 0.9791\n", + "Epoch 10/10\n", + "60000/60000 [==============================] - 45s 742us/sample - loss: 0.0484 - accuracy: 0.9843 - val_loss: 0.0796 - val_accuracy: 0.9780\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qmodel.fit(x_train, y_train, epochs=10, batch_size=128, validation_data=(x_test, y_test), verbose=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "in addition to __`model_quantize`__, __QKeras__ offers the additional utility functions.\n", + "\n", + "__`BinaryToThermometer(x, classes, value_range, with_residue=False, merge_with_channels, use_two_hot_encoding=False)`__\n", + "\n", + "This function converts a dense binary encoding of inputs to one-hot (with scales).\n", + "\n", + "Given input matrix `x` with values (for example) 0, 1, 2, 3, 4, 5, 6, 7, create a number of classes as follows:\n", + "\n", + "If classes=2, value_range=8, with_residue=0, a true one-hot representation is created, and the remaining bits are truncated, using one bit representation.\n", + "\n", + "`\n", + "0 - [1,0] 1 - [1,0] 2 - [1,0] 3 - [1,0]\n", + "4 - [0,1] 5 - [0,1] 6 - [0,1] 7 - [0,1]\n", + "`\n", + "\n", + "If classes=2, value_range=8, with_residue=1, the residue is added to the one-hot class, and the class will use 2 bits (for the remainder) + 1 bit (for the one hot)\n", + "\n", + "`\n", + "0 - [1,0] 1 - [1.25,0] 2 - [1.5,0] 3 - [1.75,0]\n", + "4 - [0,1] 5 - [0,1.25] 6 - [0,1.5] 7 - [0,1.75]\n", + "`\n", + "\n", + "The arguments of this functions are as follows:\n", + "\n", + "`\n", + "x: the input vector we want to convert. typically its dimension will be\n", + " (B,H,W,C) for an image, or (B,T,C) or (B,C) for for a 1D signal, where\n", + " B=batch, H=height, W=width, C=channels or features, T=time for time\n", + " series.\n", + "classes: the number of classes to (or log2(classes) bits) to use of the\n", + " values.\n", + "value_range: max(x) - min(x) over all possible x values (e.g. for 8 bits,\n", + " we would use 256 here).\n", + "with_residue: if true, we split the value range into two sets and add\n", + " the decimal fraction of the set to the one-hot representation for partial\n", + " thermometer representation.\n", + "merge_with_channels: if True, we will not create a separate dimension\n", + " for the resulting matrix, but we will merge this dimension with\n", + " the last dimension.\n", + "use_two_hot_encoding: if true, we will distribute the weight between\n", + " the current value and the next one to make sure the numbers will always\n", + " be < 1.\n", + "`\n", + "\n", + "__`model_save_quantized_weights(model, filename)`__\n", + "\n", + "This function saves the quantized weights in the model or writes the quantized weights in the file `filename` for production, as the weights during training are maintained non-quantized because of training. Typically, you should call this function before productizing the final model. The saved model is compatible with Keras for inference, so for power-of-2 quantization, we will not return `(sign, round(log2(weights)))`, but rather `(-1)**sign*2**(round(log2(weights)))`. 
We also return a dictionary containing the name of the layer and the quantized weights, and for power-of-2 quantizations, we will return `sign` and `round(log2(weights))` so that other tools can properly process that.\n", + "\n", + "__`load_qmodel(filepath, custom_objects=None, compile=True)`__\n", + "\n", + "Load quantized model from Keras's model.save() h5 file, where filepath is the path to the filename, custom_objects is an optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization, and compile instructs __QKeras__ to compile the model after reading it. If an optimizer was found as part of the saved model, the model is already compiled. Otherwise, the model is uncompiled and a warning will be displayed. When compile is set to `False`, the compilation is omitted without any warning.\n", + "\n", + "__`print_model_sparsity(model)`__\n", + "\n", + "Prints sparsity for the pruned layers in the model.\n", + "\n", + "__`quantized_model_debug(model, X_test, plot=False)`__\n", + "\n", + "Debugs and plots model weights and activations. It is usually useful to print weights, biases and activations for inputs and outputs when debugging a model. model contains the mixed quantized/unquantized layers for a model. We only print/plot activations and weights/biases for quantized models with the exception of Activation. X_test is the set of inputs we will use to compute activations, and we recommend that the user uses a subsample from the entire set he/she wants to debug. if plot is True, we also plot weights and activations (inputs/outputs) for each layer.\n", + "\n", + "__`extract_model_operations(model)`__\n", + "\n", + "As each operation depends on the quantization method for the weights/bias and on the quantization of the inputs, we estimate which operations are required for each layer of the quantized model. For example, inputs of a __`QDense`__ layer are quantized using __`quantized_relu_po2`__ and weights are quantized using __`quantized_bits`__, the matrix multiplication can be implemented as a barrel shifter + accumulator without multiplication operations. Right now, we return for each layer one of the following operations: `mult`, `barrel`, `mux`, `adder`, `xor`, and the sizes of the operator.\n", + "\n", + "We are currently refactoring this function and it may be substantially changed in the future.\n", + "\n", + "__`print_qstats(model)`__\n", + "\n", + "Prints statistics of number of operations per operation type and layer so that user can see how big the model is. This function utilizes __`extract_model_operations`__.\n", + "\n", + "An example of the output is presented below.\n", + "\n", + "`Number of operations in model:\n", + " conv2d_0_m : 25088 (smult_4_8)\n", + " conv2d_1_m : 663552 (smult_4_4)\n", + " conv2d_2_m : 147456 (smult_4_4)\n", + " dense : 5760 (smult_4_4)\n", + "\n", + "Number of operation types in model:\n", + " smult_4_4 : 816768\n", + " smult_4_8 : 25088`\n", + "\n", + "In this example, smult_4_4 stands for 4x4 bit signed multiplication and smult_4_8 stands for 8x4 signed multiplication.\n", + "\n", + "We are currently refactoring this function and it may be substantially changed in the future.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the quantized network `qmodel`, let's print the statistics of the model and weights." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Number of operations in model:\n", + " conv2d_1 : 109512 (smux_1_8)\n", + " conv2d_2 : 2985984 (smux_2_1)\n", + " dense : 184320 (smult_4_3)\n", + "\n", + "Number of operation types in model:\n", + " smult_4_3 : 184320\n", + " smux_1_8 : 109512\n", + " smux_2_1 : 2985984\n", + "\n", + "Weight profiling:\n", + " conv2d_1_weights : 162 (1-bit unit)\n", + " conv2d_1_bias : 18 (4-bit unit)\n", + " conv2d_2_weights : 5184 (2-bit unit)\n", + " conv2d_2_bias : 32 (4-bit unit)\n", + " dense_weights : 184320 (4-bit unit)\n", + " dense_bias : 10 (4-bit unit)\n" + ] + } + ], + "source": [ + "print_qstats(qmodel)" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input -0.5451 0.9960\n", + "conv2d_1 -4.6218 4.0295 ( -1.0000 1.0000) ( -0.5000 0.5000) a( 0.125000 0.500000)\n", + "act_1 -1.0000 1.0000\n", + "conv2d_2 -21.2500 14.2500 ( -1.0000 1.0000) ( -0.2500 -0.1250) a( 0.125000 0.250000)\n", + "act_2 0.0000 0.8750\n", + "dense -52.1094 39.4062 ( -0.5000 0.3750) ( -0.1250 0.1250) a( 1.000000 1.000000)\n", + "softmax 0.0000 1.0000\n" + ] + } + ], + "source": [ + "from qkeras.utils import quantized_model_debug\n", + "\n", + "quantized_model_debug(qmodel, x_test, plot=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Where the values in `conv2d_1 -4.6218 4.0295 ( -1.0000 1.0000) ( -0.5000 0.5000) a( 0.125000 0.500000)` corresponde to min and max values of the output of the convolution layer, weight ranges (min and max), bias (min and max) and alpha (min and max)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.15" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebook/images/figure1.png b/notebook/images/figure1.png new file mode 100644 index 0000000000000000000000000000000000000000..8ec06971d65a0307becbd83ae089db544b707c2e GIT binary patch literal 46857 zcmeFZWmr{h+dsGnk!}g4Q)yIC=~5{Lk&+G(>F$&U2@z0HI;2yPZV*MKJETNPO1kG< zKJPm-#~d^NnUC|~<#@cG`(|&~Uh9hU{MEI?Rqx*=!oP@*APA9yyqp?>U^62K<{T~# z{Dvoh(F^_$+et=29T$GQaUX}lf8*K9KXO745@YlaMy_Oz1^n`+^KEVChjve#-HaX0 z5F2A>duuyqYYP)5S2IT^3p?9ue1d!eJWNlWo$W>V`TvjC`Rp9c`72gKsS$(;QIL~X zcTf5=>7jS;;OzXyaQ);b-Ac7S87@V&MHS zEpl(a{`e7`oSeL`EX&UiZ&4S z#}Bw(hKqO0-MxGMYGUMENO)M7B}N=+R#{7n%wsb%W=YA$PqOeBo_vngbi%^IJ^lR! ztE%u3KQAlGytosu$X={pi$hOOKQcBZqoNX%%0x!i%YBy=o05uZc6W8W%60W(QgSkL zf}=AR8CmQn0X*r8S9o}W)bilz>3=zT*o4Z$(|;!u41ckusVVJ~ER3C} zslblPV2jH0u&_m!rYUo{)r!EzXK<6~4_Xe>h7A_Fk_A2f{=)x8Pe-@$m6VD~^LZ@!zw0uf*EP4WIB4M<=i%WQ$W39w)vmC-a3_Z0 zL}o7i?d!zE@bBN1c?=uq7P?dK!pFdn90Y7o! 
[... base85-encoded binary data for notebook/images/figure1.png omitted ...]

diff --git a/notebook/images/figure2.png b/notebook/images/figure2.png
new file mode 100644
index 0000000000000000000000000000000000000000..465e75d98f63d4e1fb341a71107b783db2907d47
GIT binary patch
literal 42218
[... base85-encoded binary data for notebook/images/figure2.png omitted ...]
zfvsRu%56Nr%gamXX(6G^M%yn|S8@$g;SSQsJ9KSi6M{JMP#{LI7=f+jkjNVV*5nOdgMWU=@FPCbv?acNF)|} zGY4}&e)y2~!u@M;XtvtZI6-^IV&rRM5G=3-2Si6}FRaK=c8pHKv1cV!RVzVvJX=bf z48m4UNC?f&K0qcIfaUx4!h8)G9DqZ^)`M!jaglr~S|edmj9E+7dU}`}!1M5+VN!kK z!-v$QeSx&yqBj_f*B0P@R0Z{V#cC0iwe_f;R$jYyjW~-B5B@mT^zG2=CR85qpap@Q zGal#mD#T~grlxQeKjX>TH#3420 z5SY*6|;BLGQ7ztL1~BsX}N9t_T*Yo3}Wy?<35mTZcGOXGDSs2)DNkbC9txi zZr@IN_3HFse+B+H+7^kCiUNxCRO5HhD7Ce>gpV!gp6}y3e3%m`gmj#Ey%F9-?S-)m zIPv)w8gyqa2R7Oqyn5>?79r{U-3V;~7KGnUzV*ED8zHGY6CJcw&9q%lPak;ryY2$< zGyPn8`LwT2^3MtPfXoe!;Ys}36^78Zp-DJTjP4)cQ3ei|_3$Wa)Z2wr2y4OG7y&&W z7?u8RRuF!4W1cjup$Hy)!$lzsI}rQiMk$ehexc2Hc^$J>n8wmqzQ&=-;ubMF+%tKCh8^lJx>kheziD= zlDxRl+7-NqQ2JC!oAuo^vmTy*M4wwjn~8|WHxm18gbq*QuhQP{M!#7j$0R_ z8tCro+J=S})L1;^z-$;6RS`nCY-~RIXV^PtVybv8^FlS`L`Kxw=N>(3vG z)lawid!z;O4GEQyAtLVG<0ob^@LKObe8AifQ2o`Ym!qff24KFnLruvC#}lx0zkqMjZLF-25~2`CFu{QZ)rK<~43a_A>Lo-+@1Q`G zRnof(qHYxvQ#6Xz>-qWlq&dmn!6A3*jHV_J@)2fWa%JV_M#Gf(9^@jI6>=PIh!DG% z=S@L-auOGn`Y_rNTD*8&?QuOP;Fl+*2P%+dR`xKY0uB+YD4jV_F-Kt65e141SaXoKPCFnGU3G(y8E}gI8JbC$F`% zv{qoBp~x0Joj{uHz*m9;_+sRa3z7=bI8P7wooo#bHe@q3{xq`W6q>42o8L^csRP8- z#EPxKO%F&(*+FKDgQ!ojML!yVps39^FZjR8zkFd+lUX?EG5d2viYnF;fM(8LGO|Vs z#`{Ik=t}nX2k`s&kB67{iVWzThxE&z*u}i7%)kF@+J=+&#l2w=B?A?}Ea=-3e~o3s zc0x%Yd_Z)RH8$>&ANGV>nZzMHtgzVF8&IAwH9~HKJZ`efb^?5lj?QBegA!I8>0CwM zTS;v#GajS9Z6)+h_;5&|LP!sQ4Fx9m1@$@j9epuY;^xhpot>Q&#A5XB44ClY#)bN! zBw`(iTjsdawD?IMsEmwZS_s^=<1fO)S>Z_Z$KZXs8|;HwDrg>RK`laDWhb1Txk@e^mNF7sn3!6_85Y7oP^QP8Wz&Rv>@D6gTX z*SxQkJfZE&>qP`Gh*&Yg;ok;uhXl1APlb*$G5mc)Z*MOdDnMparKYA@y>NS-pHFIH z2t|B+d?hH2xs2KHiCUHy*HJ*bgL&T4gSZL54E*QFo0jiF$ 1: + if K.image_data_format() == "channels_last": + axis = range(len_axis - 1) + else: + axis = range(1, len_axis) + qx = K.mean(tf.math.multiply(x, q), axis=axis, keepdims=True) + qq = K.mean(tf.math.multiply(q, q), axis=axis, keepdims=True) + else: + qx = K.mean(x * q, axis=0, keepdims=True) + qq = K.mean(q * q, axis=0, keepdims=True) + scale = qx / (qq + K.epsilon()) + if alpha == "auto_po2": + scale = K.pow(2.0, + tf.math.round(K.log(scale + K.epsilon()) / np.log(2.0))) + elif alpha is None: + scale = 1.0 + elif isinstance(alpha, np.ndarray): + scale = alpha + else: + scale = float(alpha) + return scale + + def smooth_sigmoid(x): """Implements a linear approximation of a sigmoid function.""" @@ -69,7 +129,7 @@ def set_internal_sigmoid(mode): elif mode == "smooth": _sigmoid = smooth_sigmoid elif mode == "real": - _sigmoid = tf.sigmoid + _sigmoid = tf.keras.backend.sigmoid def binary_tanh(x): @@ -87,15 +147,15 @@ def smooth_tanh(x): return 2.0 * smooth_sigmoid(x) - 1.0 -def stochastic_round(x): +def stochastic_round(x, precision=0.5): """Performs stochastic rounding to the first decimal point.""" - s = tf.sign(x) - s += (1.0 - tf.abs(s)) * (2.0 * tf.round(tf.random.uniform(tf.shape(x))) - - 1.0) - t = tf.floor(x) - (s - 1.0) / 2.0 - p = tf.abs(x - t) - f = s * (tf.sign(p - tf.random.uniform(tf.shape(p))) + 1.0) / 2.0 - return t + f + scale = 1.0 / precision + scale_x = x * scale + fraction = scale_x - tf.floor(scale_x) + + result = tf.where(fraction < tf.random.uniform(tf.shape(x)), + tf.math.floor(scale_x), tf.math.ceil(scale_x)) + return result / scale def stochastic_round_po2(x): @@ -105,8 +165,8 @@ def stochastic_round_po2(x): y = tf.abs(x) eps = tf.keras.backend.epsilon() log2 = tf.keras.backend.log(2.0) + x_log2 = tf.round(tf.keras.backend.log(y + eps) / log2) - sign = tf.sign(x) po2 = tf.cast(pow(2.0, tf.cast(x_log2, dtype="float32")), dtype="float32") left_val = tf.where(po2 > y, x_log2 - 1, x_log2) right_val = tf.where(po2 
> y, x_log2, x_log2 + 1) @@ -117,10 +177,18 @@ def stochastic_round_po2(x): # use y as a threshold to keep the probabliy [2**left_val, y, 2**right_val] # so that the mean value of the sample should be y x_po2 = tf.where(y < val, left_val, right_val) + """ + x_log2 = stochastic_round(tf.keras.backend.log(y + eps) / log2) + sign = tf.sign(x) + po2 = ( + tf.sign(x) * + tf.cast(pow(2.0, tf.cast(x_log2, dtype="float32")), dtype="float32") + ) + """ return x_po2 -def _round_through(x, use_stochastic_rounding=False): +def _round_through(x, use_stochastic_rounding=False, precision=0.5): """Rounds x but using straight through estimator. We use the trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182). @@ -139,12 +207,14 @@ def _round_through(x, use_stochastic_rounding=False): Arguments: x: tensor to perform round operation with straight through gradient. use_stochastic_rounding: if true, we perform stochastic rounding. + precision: by default we will use 0.5 as precision, but that can overriden + by the user. Returns: Rounded tensor. """ if use_stochastic_rounding: - return x + tf.stop_gradient(-x + stochastic_round(x)) + return x + tf.stop_gradient(-x + stochastic_round(x, precision)) else: return x + tf.stop_gradient(-x + tf.round(x)) @@ -166,7 +236,6 @@ def _ceil_through(x): return x + tf.stop_gradient(-x + tf.ceil(x)) - # # Activation functions for quantized networks. # @@ -181,8 +250,8 @@ class quantized_bits(object): # pylint: disable=invalid-name In general, we want to use a quantization function like: - a = (pow(2,bits)-1 - 0) / (max(x) - min(x)) - b = - min(x) * a + a = (pow(2,bits) - 1 - 0) / (max(x) - min(x)) + b = -min(x) * a in the equation: @@ -211,6 +280,8 @@ class quantized_bits(object): # pylint: disable=invalid-name integer: number of bits to the left of the decimal point. symmetric: if true, we will have the same number of values for positive and negative numbers. + alpha: a tensor or None, the scaling factor per channel. + If None, the scaling factor is 1 for all channels. keep_negative: if true, we do not clip negative numbers. use_stochastic_rounding: if true, we perform stochastic rounding. @@ -219,33 +290,123 @@ class quantized_bits(object): # pylint: disable=invalid-name """ def __init__(self, bits=8, integer=0, symmetric=0, keep_negative=1, - use_stochastic_rounding=False): + alpha=None, use_stochastic_rounding=False): self.bits = bits self.integer = integer self.symmetric = symmetric self.keep_negative = (keep_negative > 0) + self.alpha = alpha self.use_stochastic_rounding = use_stochastic_rounding + # "auto*" |-> symmetric + if isinstance(self.alpha, six.string_types): + self.symmetric = True + self.scale = None + + def __str__(self): + flags = [str(self.bits), str(self.integer), str(int(self.symmetric))] + if not self.keep_negative: + flags.append("keep_negative=" + str(int(self.keep_negative))) + if self.alpha: + flags.append("alpha=" + str(self.alpha)) + if self.use_stochastic_rounding: + flags.append("use_stochastic_rounding=" + + str(int(self.use_stochastic_rounding))) + return "quantized_bits(" + ",".join(flags) + ")" def __call__(self, x): """Computes fixedpoint quantization of x.""" - + # quantized_bits with "1" bit becomes a binary implementation. unsigned_bits = self.bits - self.keep_negative + m = pow(2, unsigned_bits) + m_i = pow(2, self.integer) - # quantized_bits with "1" bit becomes a binary implementation. + if self.alpha is None: + scale = 1.0 + elif isinstance(self.alpha, six.string_types): + # We only deal with the symmetric case right now. 
+ assert self.symmetric + len_axis = len(x.shape) + if len_axis > 1: + if K.image_data_format() == "channels_last": + axis = range(len_axis - 1) + else: + axis = range(1, len_axis) + else: + axis = [0] + + # we will use this implementation for the scale for QKeras 0.7 + levels = 2**self.bits - 1 + scale = (K.max(x, axis=axis, keepdims=True) - + K.min(x, axis=axis, keepdims=True)) / levels + if "po2" in self.alpha: + scale = K.pow(2.0, + tf.math.round(K.log(scale + K.epsilon()) / np.log(2.0))) + for _ in range(5): + v = tf.floor(tf.abs(x) / scale + 0.5) + mask = v < (levels - 1) / 2 + z = tf.sign(x) * tf.where(mask, v, tf.ones_like(v) * (levels - 1) / 2) + scale = _get_scale("auto_po2", x, z) + + # z is an integer number, so we must make the scale * m and z / m + scale = scale * m + + # we will not use "z" right now because of stochastic_rounding + # this is still under test. + + # if "new" in self.alpha: + # z = z / m + # self.scale = scale + # return x + tf.stop_gradient(-x + scale * z) + x = x / scale + else: + scale = self.alpha + # quantized_bits with "1" bit becomes a binary implementation. if unsigned_bits > 0: - m = pow(2, unsigned_bits) - m_i = pow(2, self.integer) p = x * m / m_i xq = m_i * tf.keras.backend.clip( - _round_through(p, self.use_stochastic_rounding), + _round_through(p, self.use_stochastic_rounding, precision=1.0), self.keep_negative * (-m + self.symmetric), m - 1) / m else: xq = tf.sign(x) xq += (1.0 - tf.abs(xq)) if not self.keep_negative: xq = (xq + 1.0) / 2.0 - return x + tf.stop_gradient(-x + xq) + + self.scale = scale + return x + tf.stop_gradient(-x + scale * xq) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + self.symmetric = True + self.use_stochastic_rounding = True + + def max(self): + """Get maximum value that quantized_bits class can represent.""" + unsigned_bits = self.bits - self.keep_negative + + if unsigned_bits > 0: + return ((1.0 - np.power(2.0, -unsigned_bits)) * + np.power(2.0, self.integer)) + else: + return 1.0 + + def min(self): + """Get minimum value that quantized_bits class can represent.""" + if not self.keep_negative: + return 0.0 + + unsigned_bits = self.bits - self.keep_negative + + if unsigned_bits > 0: + if self.symmetric: + return -( + (1.0 - np.power(2.0, -unsigned_bits)) * np.power(2.0, self.integer)) + else: + return -np.power(2.0, self.integer) + else: + return -1.0 @classmethod def from_config(cls, config): @@ -253,16 +414,12 @@ def from_config(cls, config): def get_config(self): config = { - "bits": - self.bits, - "integer": - self.integer, - "symmetric": - self.symmetric, - "keep_negative": - self.keep_negative, - "use_stochastic_rounding": - self.use_stochastic_rounding + "bits": self.bits, + "integer": self.integer, + "symmetric": self.symmetric, + "alpha": self.alpha, + "keep_negative": self.keep_negative, + "use_stochastic_rounding": self.use_stochastic_rounding } return config @@ -286,21 +443,88 @@ class bernoulli(object): # pylint: disable=invalid-name Remember that E[dL/dy] = E[dL/dx] once we add stochastic sampling. Attributes: - alpha: allows one to specify multiplicative factor for number generation. + alpha: allows one to specify multiplicative factor for number generation + of "auto" or "auto_po2". + temperature: amplifier factor for sigmoid function, making stochastic + less stochastic as it moves away from 0. + use_real_sigmoid: use real sigmoid for probability. Returns: Computation of round with stochastic sampling with straight through gradient. 
""" - def __init__(self, alpha=1.0): + def __init__(self, alpha=None, temperature=6.0, use_real_sigmoid=True): self.alpha = alpha + self.bits = 1 + self.temperature = temperature + self.use_real_sigmoid = use_real_sigmoid + self.default_alpha = 1.0 + self.scale = None + + def __str__(self): + flags = [] + if self.alpha is not None: + flags.append("alpha=" + str(self.alpha)) + if self.temperature != 6.0: + flags.append("temperature=" + str(self.temperature)) + if not self.use_real_sigmoid: + flags.append("use_real_sigmoid=" + str(int(self.use_real_sigmoid))) + return "bernoulli(" + ",".join(flags) + ")" def __call__(self, x): - p = _sigmoid(x / self.alpha) - k_sign = tf.sign(p - tf.random.uniform(tf.shape(p))) - k_sign += (1.0 - tf.abs(k_sign)) - return x + tf.stop_gradient(-x + self.alpha * (k_sign + 1.0) / 2.0) + if isinstance(self.alpha, six.string_types): + assert self.alpha in ["auto", "auto_po2"] + + if isinstance(self.alpha, six.string_types): + len_axis = len(x.shape) + + if len_axis > 1: + if K.image_data_format() == "channels_last": + axis = range(len_axis - 1) + else: + axis = range(1, len_axis) + else: + axis = [0] + + std = K.std(x, axis=axis, keepdims=True) + K.epsilon() + else: + std = 1.0 + + if self.use_real_sigmoid: + p = tf.keras.backend.sigmoid(self.temperature * x / std) + else: + p = _sigmoid(self.temperature * x/std) + r = tf.random.uniform(tf.shape(x)) + q = tf.sign(p - r) + q += (1.0 - tf.abs(q)) + q = (q + 1.0) / 2.0 + + q_non_stochastic = tf.sign(x) + q_non_stochastic += (1.0 - tf.abs(q_non_stochastic)) + q_non_stochastic = (q_non_stochastic + 1.0) / 2.0 + + # if we use non stochastic binary to compute alpha, + # this function seems to behave better + scale = _get_scale(self.alpha, x, q_non_stochastic) + self.scale = scale + return x + tf.stop_gradient(-x + scale * q) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + + def max(self): + """Get the maximum value bernoulli class can represent.""" + if self.alpha is None or (isinstance(self.alpha, six.string_types) and + "auto" in self.alpha): + return 1.0 + else: + return self.alpha + + def min(self): + """Get the minimum value bernoulli class can represent.""" + return 0.0 @classmethod def from_config(cls, config): @@ -320,54 +544,114 @@ class stochastic_ternary(object): # pylint: disable=invalid-name Attributes: x: tensor to perform sign opertion with stochastic sampling. bits: number of bits to perform quantization. - alpha: ternary is -alpha or +alpha.` + alpha: ternary is -alpha or +alpha, or "auto" or "auto_po2". threshold: (1-threshold) specifies the spread of the +1 and -1 values. + temperature: amplifier factor for sigmoid function, making stochastic + less stochastic as it moves away from 0. + use_real_sigmoid: use real sigmoid for probability. + number_of_unrolls: number of times we iterate between scale and threshold. Returns: Computation of sign with stochastic sampling with straight through gradient. 
""" - def __init__(self, alpha=1.0, threshold=0.25): + def __init__(self, alpha=None, threshold=None, temperature=8.0, + use_real_sigmoid=True, number_of_unrolls=5): self.bits = 2 self.alpha = alpha self.threshold = threshold assert threshold != 1.0 + self.default_alpha = 1.0 + self.default_threshold = 0.33 + self.temperature = temperature + self.use_real_sigmoid = use_real_sigmoid + self.number_of_unrolls = number_of_unrolls + self.scale = None + + def __str__(self): + flags = [] + if self.alpha is not None: + flags.append("alpha=" + str(self.alpha)) + if self.threshold is not None: + flags.append("threshold=" + str(self.threshold)) + if self.temperature != 8.0: + flags.append("temperature=" + str(self.temperature)) + if not self.use_real_sigmoid: + flags.append("use_real_sigmoid=0") + if self.number_of_unrolls != 5: + flags.append("number_of_unrolls=" + str(self.number_of_unrolls)) + return "stochastic_ternary(" + ",".join(flags) + ")" def __call__(self, x): - # we right now use the following distributions for fm1, f0, fp1 - # - # fm1 = ((1-T)-p)/(1-T) for p <= (1-T) - # f0 = 2*p/clip(0.5+T,0.5,1.0) for p <= 0.5 - # 2*(1-p)/clip(0.5+T,0.5,1.0) for p > 0.5 - # fp1 = (p-T)/(1-T) for p >= T - # - # threshold (1-T) determines the spread of -1 and +1 - # for T < 0.5 we need to fix the distribution of f0 - # to make it bigger when compared to the other - # distributions. - - p = _sigmoid(x / self.alpha) # pylint: disable=invalid-name - - T = self.threshold # pylint: disable=invalid-name - - ones = tf.ones_like(p) - zeros = tf.zeros_like(p) + # right now we only accept stochastic_ternary in parameters + + assert isinstance(self.alpha, six.string_types) + assert self.alpha in ["auto", "auto_po2"] + if self.alpha is None: + scale = self.default_alpha + elif isinstance(self.alpha, six.string_types): + scale = 1.0 + else: + scale = float(self.alpha) - T0 = np.clip(0.5 + T, 0.5, 1.0) # pylint: disable=invalid-name + len_axis = len(x.shape) + if len_axis > 1: + if K.image_data_format() == "channels_last": + axis = range(len_axis-1) + else: + axis = range(1, len_axis) + else: + axis = [0] - fm1 = tf.where(p <= (1 - T), ((1 - T) - p) / (1 - T), zeros) - f0 = tf.where(p <= 0.5, 2 * p, 2 * (1 - p)) / T0 - fp1 = tf.where(p <= T, zeros, (p - T) / (1 - T)) + x_mean = K.mean(x, axis=axis, keepdims=True) + x_std = K.std(x, axis=axis, keepdims=True) - f_all = fm1 + f0 + fp1 + m = K.max(tf.abs(x), axis=axis, keepdims=True) + scale = 2.*m/3. 
+ for _ in range(self.number_of_unrolls): + T = scale / 2.0 + q_ns = K.cast(tf.abs(x) >= T, K.floatx()) * K.sign(x) + scale = _get_scale(self.alpha, x, q_ns) - c_fm1 = fm1 / f_all - c_f0 = (fm1 + f0) / f_all + x_norm = (x - x_mean) / x_std + T = scale / (2.0 * x_std) - r = tf.random.uniform(tf.shape(p)) + if self.use_real_sigmoid: + p0 = tf.keras.backend.sigmoid(self.temperature * (x_norm - T)) + p1 = tf.keras.backend.sigmoid(self.temperature * (x_norm + T)) + else: + p0 = _sigmoid(self.temperature * (x_norm - T)) + p1 = _sigmoid(self.temperature * (x_norm + T)) + r0 = tf.random.uniform(tf.shape(p0)) + r1 = tf.random.uniform(tf.shape(p1)) + q0 = tf.sign(p0 - r0) + q0 += (1.0 - tf.abs(q0)) + q1 = tf.sign(p1 - r1) + q1 += (1.0 - tf.abs(q1)) + + q = (q0 + q1) / 2.0 + self.scale = scale + return x + tf.stop_gradient(-x + scale * q) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + if self.threshold is None: + self.threshold = self.default_threshold + + def max(self): + """Get the maximum value that stochastic_ternary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return 1.0 + else: + return self.alpha - return x + tf.stop_gradient(-x + self.alpha * tf.where( - r <= c_fm1, -1 * ones, tf.where(r <= c_f0, zeros, ones))) + def min(self): + """Get the minimum value that stochastic_ternary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return -1.0 + else: + return -self.alpha @classmethod def from_config(cls, config): @@ -375,10 +659,11 @@ def from_config(cls, config): def get_config(self): config = { - "alpha": - self.alpha, - "threshold": - self.threshold, + "alpha": self.alpha, + "threshold": self.threshold, + "temperature": self.temperature, + "use_real_sigmoid": self.use_real_sigmoid, + "number_of_unrolls": self.number_of_unrolls } return config @@ -386,30 +671,123 @@ def get_config(self): class ternary(object): # pylint: disable=invalid-name """Computes an activation function returning -alpha, 0 or +alpha. + Right now we assume two type of behavior. For parameters, we should + have alpha, threshold and stochastic rounding on. For activations, + alpha and threshold should be floating point numbers, and stochastic + rounding should be off. + Attributes: x: tensor to perform sign opertion with stochastic sampling. bits: number of bits to perform quantization. - alpha: ternary is -alpha or +alpha. Threshold is also scaled by alpha. - threshold: threshold to apply "dropout" or dead band (0 value). + alpha: ternary is -alpha or +alpha. Alpha can be "auto" or "auto_po2". + threshold: threshold to apply "dropout" or dead band (0 value). If "auto" + is specified, we will compute it per output layer. use_stochastic_rounding: if true, we perform stochastic rounding. Returns: Computation of sign within the threshold. 
""" - def __init__(self, alpha=1.0, threshold=0.33, use_stochastic_rounding=False): + def __init__(self, alpha=None, threshold=None, use_stochastic_rounding=False, + number_of_unrolls=5): self.alpha = alpha self.bits = 2 self.threshold = threshold self.use_stochastic_rounding = use_stochastic_rounding + self.default_alpha = 1.0 + self.default_threshold = 0.33 + self.number_of_unrolls = number_of_unrolls + self.scale = None + + def __str__(self): + flags = [] + if self.alpha is not None: + flags.append("alpha=" + str(self.alpha)) + if self.threshold is not None: + flags.append("threshold=" + str(self.threshold)) + if self.use_stochastic_rounding: + flags.append( + "use_stochastic_rounding=" + str(int(self.use_stochastic_rounding))) + if self.number_of_unrolls != 5: + flags.append( + "number_of_unrolls=" + str(int(self.number_of_unrolls))) + return "ternary(" + ",".join(flags) + ")" def __call__(self, x): - if self.use_stochastic_rounding: - x = _round_through( - x, use_stochastic_rounding=self.use_stochastic_rounding) - return x + tf.stop_gradient( - -x + self.alpha * tf.where(tf.abs(x) < self.threshold, - tf.zeros_like(x), tf.sign(x))) + if isinstance(self.alpha, six.string_types): + # parameters + assert self.alpha in ["auto", "auto_po2"] + assert self.threshold is None + else: + # activations + assert not self.use_stochastic_rounding + assert not isinstance(self.threshold, six.string_types) + + if self.alpha is None or isinstance(self.alpha, six.string_types): + scale = 1.0 + elif isinstance(self.alpha, np.ndarray): + scale = self.alpha + else: + scale = float(self.alpha) + + # This is an approximiation from https://arxiv.org/abs/1605.04711 + # We consider channels_last only for now. + if isinstance(self.alpha, six.string_types): + # It is for parameters + # first, compute which asix corresponds to the channels. + # TODO(hzhuang): support channels_first + len_axis = len(x.shape.as_list()) + if len_axis == 1: + axis = None + elif K.image_data_format() == "channels_last": + axis = range(len_axis - 1) + else: + axis = range(1, len_axis) + + # This approximation is exact if x ~ U[-m, m]. 
For x ~ N(0, m) + # we need to iterate a few times before we can coverge + m = K.max(tf.abs(x), axis=axis, keepdims=True) + scale = 2 * m / 3.0 + x_orig = x + for _ in range(self.number_of_unrolls): + thres = scale / 2.0 + if self.use_stochastic_rounding: + # once we scale the number precision == 0.33 works + # well for Uniform and Normal distribution of input + x = scale * _round_through( + x_orig / scale, + use_stochastic_rounding=self.use_stochastic_rounding, + precision=0.33) + q = K.cast(tf.abs(x) >= thres, K.floatx()) * tf.sign(x) + scale = _get_scale(self.alpha, x, q) + else: + if self.threshold is None: + thres = self.default_threshold + else: + thres = self.threshold + q = K.cast(tf.abs(x) >= thres, K.floatx()) * tf.sign(x) + + self.scale = scale + return x + tf.stop_gradient(-x + scale * q) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + self.use_stochastic_rounding = True + + def max(self): + """Get the maximum value that ternary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return 1.0 + else: + return self.alpha + + def min(self): + """Get the minimum value that ternary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return -1.0 + else: + return -self.alpha @classmethod def from_config(cls, config): @@ -417,12 +795,10 @@ def from_config(cls, config): def get_config(self): config = { - "alpha": - self.alpha, - "threshold": - self.threshold, - "use_stochastic_rounding": - self.use_stochastic_rounding + "alpha": self.alpha, + "threshold": self.threshold, + "use_stochastic_rounding": self.use_stochastic_rounding, + "number_of_unrolls": self.number_of_unrolls } return config @@ -435,32 +811,93 @@ class stochastic_binary(object): # pylint: disable=invalid-name Attributes: x: tensor to perform sign opertion with stochastic sampling. - alpha: binary is -alpha or +alpha.` + alpha: binary is -alpha or +alpha, or "auto" or "auto_po2". bits: number of bits to perform quantization. + temperature: amplifier factor for sigmoid function, making stochastic + behavior less stochastic as it moves away from 0. + use_real_sigmoid: use real sigmoid from tensorflow for probablity. Returns: Computation of sign with stochastic sampling with straight through gradient. 
""" - def __init__(self, alpha=1.0): + def __init__(self, alpha=None, temperature=6.0, use_real_sigmoid=True): self.alpha = alpha self.bits = 1 + self.temperature = temperature + self.use_real_sigmoid = use_real_sigmoid + self.default_alpha = 1.0 + self.scale = None + + def __str__(self): + flags = [] + if self.alpha is not None: + flags.append("alpha=" + str(self.alpha)) + if self.temperature != 6.0: + flags.append("temperature=" + str(self.temperature)) + if not self.use_real_sigmoid: + flags.append("use_real_sigmoid=" + str(int(self.use_real_sigmoid))) + return "stochastic_binary(" + ",".join(flags) + ")" def __call__(self, x): - assert self.alpha != 0 - p = _sigmoid(x / self.alpha) - k_sign = tf.sign(p - tf.random.uniform(tf.shape(x))) - # we should not need this, but if tf.sign is not safe if input is - # exactly 0.0 - k_sign += (1.0 - tf.abs(k_sign)) - return x + tf.stop_gradient(-x + self.alpha * k_sign) + if isinstance(self.alpha, six.string_types): + assert self.alpha in ["auto", "auto_po2"] + if isinstance(self.alpha, six.string_types): + len_axis = len(x.shape) + if len_axis > 1: + if K.image_data_format() == "channels_last": + axis = range(len_axis-1) + else: + axis = range(1, len_axis) + else: + axis = [0] + std = K.std(x, axis=axis, keepdims=True) + K.epsilon() + else: + std = 1.0 + + if self.use_real_sigmoid: + p = tf.keras.backend.sigmoid(self.temperature * x / std) + else: + p = _sigmoid(self.temperature * x / std) + + r = tf.random.uniform(tf.shape(x)) + q = tf.sign(p - r) + q += (1.0 - tf.abs(q)) + q_non_stochastic = tf.sign(x) + q_non_stochastic += (1.0 - tf.abs(q_non_stochastic)) + scale = _get_scale(self.alpha, x, q_non_stochastic) + self.scale = scale + return x + tf.stop_gradient(-x + scale * q) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + self.use_stochastic_rounding = True + + def max(self): + """Get the maximum value that stochastic_binary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return 1.0 + else: + return self.alpha + + def min(self): + """Get the minimum value that stochastic_binary can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return -1.0 + else: + return -self.alpha @classmethod def from_config(cls, config): return cls(**config) def get_config(self): - config = {"alpha": self.alpha} + config = { + "alpha": self.alpha, + "temperature": self.temperature, + "use_real_sigmoid": self.use_real_sigmoid, + } return config @@ -476,34 +913,100 @@ class binary(object): # pylint: disable=invalid-name x: tensor to perform sign_through. bits: number of bits to perform quantization. use_01: if True, return {0,1} instead of {-1,+1}. - alpha: binary is -alpha or +alpha. + alpha: binary is -alpha or +alpha, or "auto", "auto_po2" to compute + automatically. use_stochastic_rounding: if true, we perform stochastic rounding. Returns: Computation of sign operation with straight through gradient. 
""" - def __init__(self, use_01=False, alpha=1.0, use_stochastic_rounding=False): + def __init__(self, use_01=False, alpha=None, use_stochastic_rounding=False): self.use_01 = use_01 self.bits = 1 self.alpha = alpha self.use_stochastic_rounding = use_stochastic_rounding + self.default_alpha = 1.0 + self.scale = None + + def __str__(self): + flags = [] + if self.use_01: + flags.append("use_01=" + str(int(self.use_01))) + if self.alpha is not None: + flags.append("alpha=" + str(self.alpha)) + if self.use_stochastic_rounding: + flags.append( + "use_stochastic_rounding=" + str(self.use_stochastic_rounding)) + return "binary(" + ",".join(flags) + ")" def __call__(self, x): - assert self.alpha != 0 + if isinstance(self.alpha, six.string_types): + assert self.alpha in ["auto", "auto_po2"] + if self.alpha is None: + scale = self.default_alpha + elif isinstance(self.alpha, six.string_types): + scale = 1.0 + elif isinstance(self.alpha, np.ndarray): + scale = self.alpha + else: + scale = float(self.alpha) + if self.use_stochastic_rounding: - x = self.alpha * _round_through( - x / self.alpha, use_stochastic_rounding=self.use_stochastic_rounding) + len_axis = len(x.shape.as_list()) + if len_axis == 1: + axis = None + elif K.image_data_format() == "channels_last": + axis = range(len_axis - 1) + else: + axis = range(1, len_axis) + + # if stochastic_round is through, we need to scale + # number so that the precision is small enough. + # This is especially important if range of x is very + # small, which occurs during initialization of weights. + m = K.max(tf.abs(x), axis=axis, keepdims=True) + m = tf.where(m > 1.0, tf.ones_like(m), m) + f = 2 * m + + x = f * _round_through( + x / f, use_stochastic_rounding=self.use_stochastic_rounding, + precision=0.125 + ) k_sign = tf.sign(x) if self.use_stochastic_rounding: k_sign += (1.0 - tf.abs(k_sign)) * ( 2.0 * tf.round(tf.random.uniform(tf.shape(x))) - 1.0) - else: - k_sign += (1.0 - tf.abs(k_sign)) + # if something still remains, just make it positive for now. 
+ k_sign += (1.0 - tf.abs(k_sign)) if self.use_01: k_sign = (k_sign + 1.0) / 2.0 - return x + tf.stop_gradient(-x + self.alpha * k_sign) + + scale = _get_scale(self.alpha, x, k_sign) + self.scale = scale + return x + tf.stop_gradient(-x + scale * k_sign) + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + self.use_stochastic_rounding = True + + def max(self): + """Get maximum value that binary class can respresent.""" + if self.alpha is None or isinstance(self.alpha, six.string_types): + return 1.0 + else: + return self.alpha + + def min(self): + """Get minimum value that binary class can respresent.""" + if self.use_01: + return 0.0 + elif self.alpha is None or isinstance(self.alpha, six.string_types): + return -1.0 + else: + return -self.alpha @classmethod def from_config(cls, config): @@ -511,12 +1014,9 @@ def from_config(cls, config): def get_config(self): config = { - "use_01": - self.use_01, - "alpha": - self.alpha, - "use_stochastic_rounding": - self.use_stochastic_rounding + "use_01": self.use_01, + "alpha": self.alpha, + "use_stochastic_rounding": self.use_stochastic_rounding } return config @@ -557,6 +1057,14 @@ def __init__(self, bits=8, integer=0, use_sigmoid=0, self.use_sigmoid = use_sigmoid self.use_stochastic_rounding = use_stochastic_rounding + def __str__(self): + flags = [str(self.bits), str(self.integer)] + if self.use_sigmoid or self.use_stochastic_rounding: + flags.append(str(int(self.use_sigmoid))) + if self.use_stochastic_rounding: + flags.append(str(int(self.use_stochastic_rounding))) + return "quantized_relu(" + ",".join(flags) + ")" + def __call__(self, x): m = pow(2, self.bits) m_i = pow(2, self.integer) @@ -573,20 +1081,33 @@ def __call__(self, x): 0.0, 1.0 - 1.0 / m) return xq + def _set_trainable_parameter(self): + pass + + def max(self): + """Get the maximum value that quantized_relu can represent.""" + unsigned_bits = self.bits + + if unsigned_bits > 0: + return ((1.0 - np.power(2.0, -unsigned_bits)) * + np.power(2.0, self.integer)) + else: + return 1.0 + + def min(self): + """Get the minimum value that quantized_relu can represent.""" + return 0.0 + @classmethod def from_config(cls, config): return cls(**config) def get_config(self): config = { - "bits": - self.bits, - "integer": - self.integer, - "use_sigmoid": - self.use_sigmoid, - "use_stochastic_rounding": - self.use_stochastic_rounding + "bits": self.bits, + "integer": self.integer, + "use_sigmoid": self.use_sigmoid, + "use_stochastic_rounding": self.use_stochastic_rounding } return config @@ -611,6 +1132,14 @@ def __init__(self, bits=8, integer=0, symmetric=0, u=255.0): self.symmetric = symmetric self.u = u + def __str__(self): + flags = [str(self.bits), str(self.integer)] + if self.symmetric or self.u != 255.0: + flags.append(str(int(self.symmetric))) + if self.u != 255.0: + flags.append(str(self.u)) + return "quantized_ulaw(" + ",".join(flags) + ")" + def __call__(self, x): non_sign_bits = self.bits - 1 m = pow(2, non_sign_bits) @@ -623,20 +1152,42 @@ def __call__(self, x): (1.0 * self.symmetric) / m, 1.0 - 1.0 / m) return xq + def _set_trainable_parameter(self): + pass + + def max(self): + """Get the maximum value that quantized_ulaw can represent.""" + unsigned_bits = self.bits - 1 + + if unsigned_bits > 0: + return ((1.0 - np.power(2.0, -unsigned_bits)) * + np.power(2.0, self.integer)) + else: + return 1.0 + + def min(self): + """Get the minimum value that quantized_ulaw can represent.""" + unsigned_bits = self.bits - 1 + + if unsigned_bits > 0: + if 
self.symmetric: + return -(1.0 - np.power(2.0, -unsigned_bits)) * np.power( + 2.0, self.integer) + else: + return -np.power(2.0, self.integer) + else: + return -1.0 + @classmethod def from_config(cls, config): return cls(**config) def get_config(self): config = { - "bits": - self.bits, - "integer": - self.integer, - "symmetric": - self.symmetric, - "u": - self.u + "bits": self.bits, + "integer": self.integer, + "symmetric": self.symmetric, + "u": self.u } return config @@ -666,6 +1217,14 @@ def __init__(self, bits=8, integer=0, symmetric=0, self.symmetric = symmetric self.use_stochastic_rounding = use_stochastic_rounding + def __str__(self): + flags = [str(self.bits), str(self.integer)] + if self.symmetric or self.use_stochastic_rounding: + flags.append(str(int(self.symmetric))) + if self.use_stochastic_rounding: + flags.append(str(int(self.use_stochastic_rounding))) + return "quantized_tanh(" + ",".join(flags) + ")" + def __call__(self, x): non_sign_bits = self.bits - 1 m = pow(2, non_sign_bits) @@ -677,6 +1236,33 @@ def __call__(self, x): (1.0 * self.symmetric) / m, 1.0 - 1.0 / m) return xq + def _set_trainable_parameter(self): + pass + + def max(self): + """Get the maximum value that quantized_tanh can represent.""" + unsigned_bits = self.bits - 1 + if unsigned_bits > 0: + return ((1.0 - np.power(2.0, -unsigned_bits)) * + np.power(2.0, self.integer)) + else: + return 1.0 + + def min(self): + """Get the minimum value that quantized_tanh can represent.""" + if not self.keep_negative: + return 0.0 + + unsigned_bits = self.bits - 1 + if unsigned_bits > 0: + if self.symmetric: + return -(1.0 - np.power(2.0, -unsigned_bits)) * np.power( + 2.0, self.integer) + else: + return -np.power(2.0, self.integer) + else: + return -1.0 + @classmethod def from_config(cls, config): return cls(**config) @@ -694,6 +1280,7 @@ def get_config(self): def _clip_power_of_two(x_abs, min_exp, max_exp, + max_value, quadratic_approximation=False, use_stochastic_rounding=False): """Clips a tensor using power-of-two quantizer. @@ -703,6 +1290,7 @@ def _clip_power_of_two(x_abs, x_abs: A tensor object. Its elements should be non-negative. min_exp: An integer representing the smallest exponent. max_exp: An integer representing the largest exponent. + max_value: A float or None. If it is None, we clip the value to max_value. quadratic_approximation: An boolean representing whether the quadratic approximation is applied. use_stochastic_rounding: An boolean representing whether the stochastic @@ -712,11 +1300,18 @@ def _clip_power_of_two(x_abs, A tensor object, the values are clipped by min_exp and max_exp. """ - eps = tf.keras.backend.epsilon() - log2 = np.log(2.0) # if quadratic_approximation is True, round to the exponent for sqrt(x), # so that the return value can be divided by two without remainder. - x_eps_pad = tf.where(x_abs < eps, eps, x_abs) + log2 = np.log(2.0) + + # When the elements of x_abs are small than the keras epsilon, + # we just overwrite x_abs with eps + eps = tf.keras.backend.epsilon() + x_filter = tf.where(x_abs < eps, eps, x_abs) + if max_value is not None: + # If the elements of x_filter has value larger than x_value, clip it. 
+ x_filter = tf.where(x_filter >= max_value, + tf.ones_like(x_filter) * max_value, x_filter) def power_of_two_clip(x_abs, min_exp, max_exp, quadratic_approximation, use_stochastic_rounding): @@ -740,7 +1335,7 @@ def power_of_two_clip(x_abs, min_exp, max_exp, quadratic_approximation, x_clipped = tf.where( x_abs < eps, tf.ones_like(x_abs) * min_exp, - power_of_two_clip(x_eps_pad, min_exp, max_exp, quadratic_approximation, + power_of_two_clip(x_filter, min_exp, max_exp, quadratic_approximation, use_stochastic_rounding)) return x_clipped @@ -791,14 +1386,12 @@ def _get_min_max_exponents(non_sign_bits, need_exponent_sign_bit, """ effect_bits = non_sign_bits - need_exponent_sign_bit + min_exp = -2**(effect_bits) if quadratic_approximation: - if effect_bits % 2: - effect_bits = effect_bits - 1 - min_exp = -2**(effect_bits) max_exp = 2**(effect_bits) else: - min_exp = -2**(effect_bits) max_exp = 2**(effect_bits) - 1 + return min_exp, max_exp @@ -808,8 +1401,8 @@ class quantized_po2(object): # pylint: disable=invalid-name Attributes: bits: An integer, the bits allocated for the exponent, its sign and the sign of x. - max_value: default is None, or a non-negative value to put a constraint for - the max value. + max_value: An float or None. If None, no max_value is specified. + Otherwise, the maximum value of quantized_po2 <= max_value use_stochastic_rounding: A boolean, default is False, if True, it uses stochastic rounding and forces the mean of x to be x statstically. quadratic_approximation: A boolean, default is False if True, it forces the @@ -824,44 +1417,68 @@ def __init__(self, self.bits = bits self.max_value = max_value self.use_stochastic_rounding = use_stochastic_rounding + # if True, round to the exponent for sqrt(x), + # so that the return value can be divided by two without remainder. 
self.quadratic_approximation = quadratic_approximation - - def __call__(self, x): need_exponent_sign_bit = _need_exponent_sign_bit_check(self.max_value) non_sign_bits = self.bits - 1 - min_exp, max_exp = _get_min_max_exponents(non_sign_bits, - need_exponent_sign_bit, - self.quadratic_approximation) - eps = tf.keras.backend.epsilon() - if min_exp < np.log2(eps): - warnings.warn( - "QKeras: min_exp in po2 quantizer is smaller than tf.epsilon().") - if self.max_value: - max_exp = np.minimum(max_exp, np.round(np.log2(self.max_value + eps))) + self._min_exp, self._max_exp = _get_min_max_exponents( + non_sign_bits, need_exponent_sign_bit, self.quadratic_approximation) + def __str__(self): + flags = [str(self.bits)] + if self.max_value is not None or self.use_stochastic_rounding: + flags.append(str(int(self.max_value))) + if self.use_stochastic_rounding: + flags.append(str(int(self.use_stochastic_rounding))) + if self.quadratic_approximation: + flags.append( + "quadratic_approximation=" + str(int(self.quadratic_approximation))) + return "quantized_po2(" + ",".join(flags) + ")" + + def __call__(self, x): x_sign = tf.sign(x) x_sign += (1.0 - tf.abs(x_sign)) x_abs = tf.abs(x) - x_clipped = _clip_power_of_two(x_abs, min_exp, max_exp, + x_clipped = _clip_power_of_two(x_abs, self._min_exp, self._max_exp, + self.max_value, self.quadratic_approximation, self.use_stochastic_rounding) return x + tf.stop_gradient(-x + x_sign * pow(2.0, x_clipped)) + def _set_trainable_parameter(self): + pass + + def max(self): + """Get the maximum value that quantized_po2 can represent.""" + return self._max_exp + + def min(self): + """Get the minimum value that quantized_po2 can represent.""" + return self._min_exp + @classmethod def from_config(cls, config): return cls(**config) def get_config(self): - """Gets config.""" + """Gets configugration of the quantizer. + + Returns: + A dict mapping quantization configuration, including + bits: bitwidth for exponents. + max_value: the maximum value of this quantized_po2 can represent. + use_stochastic_rounding: + if True, stochastic rounding is used. + quadratic_approximation: + if True, the exponent is enforced to be even number, which is + the closest one to x. + """ config = { - "bits": - self.bits, - "max_value": - self.max_value, - "use_stochastic_rounding": - self.use_stochastic_rounding, - "quadratic_approximation": - self.quadratic_approximation + "bits": self.bits, + "max_value": self.max_value, + "use_stochastic_rounding": self.use_stochastic_rounding, + "quadratic_approximation": self.quadratic_approximation } return config @@ -890,23 +1507,40 @@ def __init__(self, # if True, round to the exponent for sqrt(x), # so that the return value can be divided by two without remainder. 
self.quadratic_approximation = quadratic_approximation + need_exponent_sign_bit = _need_exponent_sign_bit_check(self.max_value) + self._min_exp = -2**(self.bits - need_exponent_sign_bit) + self._max_exp = 2**(self.bits - need_exponent_sign_bit) - 1 + + def __str__(self): + flags = [str(self.bits)] + if self.max_value is not None or self.use_stochastic_rounding: + flags.append(str(int(self.max_value))) + if self.use_stochastic_rounding: + flags.append(str(int(self.use_stochastic_rounding))) + if self.quadratic_approximation: + flags.append( + "quadratic_approximation=" + str(int(self.quadratic_approximation))) + return "quantized_relu_po2(" + ",".join(flags) + ")" def __call__(self, x): - need_exponent_sign_bit = _need_exponent_sign_bit_check(self.max_value) - min_exp = -2**(self.bits - need_exponent_sign_bit) - max_exp = 2**(self.bits - need_exponent_sign_bit) - 1 - eps = tf.keras.backend.epsilon() - if min_exp < np.log2(eps): - warnings.warn("QKeras: min_exp in quantized_relu_po2 quantizer " - "is smaller than tf.epsilon().") - if self.max_value: - max_exp = np.minimum(max_exp, np.round(np.log2(self.max_value + eps))) x = tf.maximum(x, 0) - x_clipped = _clip_power_of_two(x, min_exp, max_exp, + x_clipped = _clip_power_of_two(x, self._min_exp, self._max_exp, + self.max_value, self.quadratic_approximation, self.use_stochastic_rounding) return x + tf.stop_gradient(-x + pow(2.0, x_clipped)) + def _set_trainable_parameter(self): + pass + + def max(self): + """Get the maximum value that quantized_relu_po2 can represent.""" + return self._max_exp + + def min(self): + """Get the minimum value that quantized_relu_po2 can represent.""" + return self._min_exp + @classmethod def from_config(cls, config): return cls(**config) diff --git a/qkeras/safe_eval.py b/qkeras/safe_eval.py index 9a964652..da4042cd 100644 --- a/qkeras/safe_eval.py +++ b/qkeras/safe_eval.py @@ -24,6 +24,9 @@ from pyparsing import Regex from pyparsing import Suppress +import logging +from tensorflow import keras + def Num(s): """Tries to convert string to either int or float.""" @@ -68,7 +71,7 @@ def safe_eval(eval_str, op_dict, *params, **kwparams): # pylint: disable=invali """Replaces eval by a safe eval mechanism.""" function_split = eval_str.split("(") - quantizer = op_dict[function_split[0]] + quantizer = op_dict.get(function_split[0], None) if len(function_split) == 2: args, kwargs = GetParams("(" + function_split[1]) @@ -80,6 +83,11 @@ def safe_eval(eval_str, op_dict, *params, **kwparams): # pylint: disable=invali for k in kwparams: kwargs[k] = kwparams[k] + # must be Keras activation object if None + if quantizer is None: + logging.info("keras dict %s", function_split[0]) + quantizer = keras.activations.get(function_split[0]) + if len(function_split) == 2 or args or kwargs: return quantizer(*args, **kwargs) else: diff --git a/qkeras/utils.py b/qkeras/utils.py index 3596c29d..20d7c9f6 100644 --- a/qkeras/utils.py +++ b/qkeras/utils.py @@ -13,14 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import copy import json import six +import types import tensorflow as tf import tensorflow.keras.backend as K from tensorflow.keras import initializers + +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Conv1D from tensorflow.keras.models import model_from_json from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper @@ -28,6 +35,8 @@ from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer import numpy as np +import matplotlib.pyplot as plt + from .qlayers import QActivation from .qlayers import QDense @@ -35,26 +44,25 @@ from .qconvolutional import QConv1D from .qconvolutional import QConv2D from .qconvolutional import QDepthwiseConv2D +from .qnormalization import QBatchNormalization from .qpooling import QAveragePooling2D -from .quantizers import quantized_bits -from .quantizers import bernoulli -from .quantizers import stochastic_ternary -from .quantizers import ternary -from .quantizers import stochastic_binary from .quantizers import binary +from .quantizers import bernoulli +from .quantizers import get_weight_scale +from .quantizers import quantized_bits from .quantizers import quantized_relu from .quantizers import quantized_ulaw from .quantizers import quantized_tanh from .quantizers import quantized_po2 from .quantizers import quantized_relu_po2 -from .qnormalization import QBatchNormalization - +from .quantizers import stochastic_binary +from .quantizers import stochastic_ternary +from .quantizers import ternary from .safe_eval import safe_eval # # Model utilities: before saving the weights, we want to apply the quantizers # - def model_save_quantized_weights(model, filename=None): """Quantizes model for inference and save it. 
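Below is a minimal, illustrative sketch (not part of the patch) of the quantizer_config dictionary that model_quantize() resolves through the get_config() helper introduced in the next hunk. The layer name "conv2d_1" is a hypothetical example; entries keyed by a layer name override class-wide entries such as "QConv2D".

# Sketch only: class-wide defaults plus one per-layer override, looked up as
# get_config(quantizer_config, layer, layer_class, parameter).
quantizer_config = {
    "QConv2D": {
        "kernel_quantizer": "ternary",
        "bias_quantizer": "quantized_bits(4,0,1)",
    },
    "conv2d_1": {  # hypothetical layer name; takes precedence over "QConv2D"
        "kernel_quantizer": "binary",
        "bias_quantizer": "quantized_bits(4,0,1)",
    },
    "QActivation": {"relu": "quantized_relu(4)"},
}
# qmodel = model_quantize(model, quantizer_config, 4)  # 4 = activation_bits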
@@ -144,25 +152,39 @@ def model_save_quantized_weights(model, filename=None): return saved_weights -def quantize_activation(layer_config, custom_objects, activation_bits): +def quantize_activation(layer_config, activation_bits): """Replaces activation by quantized activation functions.""" - str_act_bits = str(activation_bits) - # relu -> quantized_relu(bits) # tanh -> quantized_tanh(bits) # # more to come later - - if layer_config["activation"] == "relu": + if layer_config.get("activation", None) is None: + return + if isinstance(layer_config["activation"], six.string_types): + a_name = layer_config["activation"] + elif isinstance(layer_config["activation"], types.FunctionType): + a_name = layer_config["activation"].__name__ + else: + a_name = layer_config["activation"].__class__.__name__ + + if a_name == "linear": + return + if a_name == "relu": layer_config["activation"] = "quantized_relu(" + str_act_bits + ")" - custom_objects["quantized_relu(" + str_act_bits + ")"] = ( - quantized_relu(activation_bits)) - - elif layer_config["activation"] == "tanh": + elif a_name == "tanh": layer_config["activation"] = "quantized_tanh(" + str_act_bits + ")" - custom_objects["quantized_tanh(" + str_act_bits + ")"] = ( - quantized_tanh(activation_bits)) + + +def get_config(quantizer_config, layer, layer_class, parameter=None): + """Returns search of quantizer on quantizer_config.""" + quantizer = quantizer_config.get(layer["name"], + quantizer_config.get(layer_class, None)) + + if quantizer is not None and parameter is not None: + quantizer = quantizer.get(parameter, None) + + return quantizer def model_quantize(model, @@ -239,32 +261,9 @@ def model_quantize(model, # let's make a deep copy to make sure our objects are not shared elsewhere jm = copy.deepcopy(json.loads(model.to_json())) custom_objects = copy.deepcopy(custom_objects) - config = jm["config"] - layers = config["layers"] - custom_objects["QDense"] = QDense - custom_objects["QConv1D"] = QConv1D - custom_objects["QConv2D"] = QConv2D - custom_objects["QDepthwiseConv2D"] = QDepthwiseConv2D - custom_objects["QAveragePooling2D"] = QAveragePooling2D - custom_objects["QActivation"] = QActivation - - # just add all the objects from quantizer_config to - # custom_objects first. 
- - for layer_name in quantizer_config.keys(): - - if isinstance(quantizer_config[layer_name], six.string_types): - name = quantizer_config[layer_name] - custom_objects[name] = safe_eval(name, globals()) - - else: - for name in quantizer_config[layer_name].keys(): - custom_objects[quantizer_config[layer_name][name]] = ( - safe_eval(quantizer_config[layer_name][name], globals())) - for layer in layers: layer_config = layer["config"] @@ -273,92 +272,69 @@ def model_quantize(model, if layer["class_name"] == "Dense": layer["class_name"] = "QDense" - # needs to add kernel/bias quantizers - - kernel_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QDense", - None))["kernel_quantizer"] - - bias_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QDense", None))["bias_quantizer"] - + kernel_quantizer = get_config( + quantizer_config, layer, "QDense", "kernel_quantizer") + bias_quantizer = get_config( + quantizer_config, layer, "QDense", "bias_quantizer") layer_config["kernel_quantizer"] = kernel_quantizer layer_config["bias_quantizer"] = bias_quantizer # if activation is present, add activation here - - quantizer = quantizer_config.get(layer["name"], - quantizer_config.get( - "QDense", None)).get( - "activation_quantizer", None) + quantizer = get_config( + quantizer_config, layer, "QDense", "activation_quantizer") if quantizer: layer_config["activation"] = quantizer custom_objects[quantizer] = safe_eval(quantizer, globals()) else: - quantize_activation(layer_config, custom_objects, activation_bits) - - elif layer["class_name"] == "Conv2D": - layer["class_name"] = "QConv2D" + quantize_activation(layer_config, activation_bits) + elif layer["class_name"] in ["Conv1D", "Conv2D"]: + q_name = "Q" + layer["class_name"] + layer["class_name"] = q_name # needs to add kernel/bias quantizers + kernel_quantizer = get_config( + quantizer_config, layer, q_name, "kernel_quantizer") - kernel_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QConv2D", None)).get( - "kernel_quantizer", None) - - bias_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QConv2D", None)).get( - "bias_quantizer", None) - + bias_quantizer = get_config( + quantizer_config, layer, q_name, "bias_quantizer") layer_config["kernel_quantizer"] = kernel_quantizer layer_config["bias_quantizer"] = bias_quantizer # if activation is present, add activation here - - quantizer = quantizer_config.get(layer["name"], - quantizer_config.get( - "QConv2D", None)).get( - "activation_quantizer", None) + quantizer = get_config( + quantizer_config, layer, q_name, "activation_quantizer") if quantizer: layer_config["activation"] = quantizer - custom_objects[quantizer] = safe_eval(quantizer, globals()) else: - quantize_activation(layer_config, custom_objects, activation_bits) + quantize_activation(layer_config, activation_bits) elif layer["class_name"] == "DepthwiseConv2D": layer["class_name"] = "QDepthwiseConv2D" - # needs to add kernel/bias quantizers - - depthwise_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QDepthwiseConv2D", None)).get( - "depthwise_quantizer", None) - - bias_quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QDepthwiseConv2D", None)).get( - "bias_quantizer", None) + depthwise_quantizer = get_config(quantizer_config, layer, + "QDepthwiseConv2D", "depthwise_quantizer") + bias_quantizer = get_config(quantizer_config, layer, + "QDepthwiseConv2D", "bias_quantizer") layer_config["depthwise_quantizer"] = 
depthwise_quantizer layer_config["bias_quantizer"] = bias_quantizer - # if activation is present, add activation here - - quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QDepthwiseConv2D", None)).get( - "activation_quantizer", None) + quantizer = get_config(quantizer_config, layer, + "QDepthwiseConv2D", "activation_quantizer",) if quantizer: layer_config["activation"] = quantizer - custom_objects[quantizer] = safe_eval(quantizer, globals()) else: - quantize_activation(layer_config, custom_objects, activation_bits) + quantize_activation(layer_config, activation_bits) elif layer["class_name"] == "Activation": - quantizer = quantizer_config.get( - layer["name"], quantizer_config.get("QActivation", None)) + quantizer = get_config(quantizer_config, layer, "QActivation") + # this is to avoid softmax from quantizing in autoq + if quantizer is None: + continue # if quantizer exists in dictionary related to this name, # use it, otherwise, use normal transformations @@ -372,28 +348,34 @@ def model_quantize(model, quantizer = quantizer[layer_config["activation"]] if quantizer: layer_config["activation"] = quantizer - custom_objects[quantizer] = safe_eval(quantizer, globals()) else: - quantize_activation(layer_config, custom_objects, activation_bits) - - elif layer["class_name"] == "AveragePooling2D": - layer["class_name"] = "QAveragePooling2D" - - quantizer = quantizer_config.get(layer["name"], None) - - # if quantizer exists in dictionary related to this name, - # use it, otherwise, use normal transformations + quantize_activation(layer_config, activation_bits) - if quantizer: - layer_config["activation"] = quantizer - custom_objects[quantizer] = safe_eval(quantizer, globals()) - else: - quantize_activation(layer_config, custom_objects, activation_bits) + elif layer["class_name"] == "BatchNormalization": + layer["class_name"] = "QBatchNormalization" + # needs to add kernel/bias quantizers + gamma_quantizer = get_config( + quantizer_config, layer, "QBatchNormalization", + "gamma_quantizer") + beta_quantizer = get_config( + quantizer_config, layer, "QBatchNormalization", + "beta_quantizer") + mean_quantizer = get_config( + quantizer_config, layer, "QBatchNormalization", + "mean_quantizer") + variance_quantizer = get_config( + quantizer_config, layer, "QBatchNormalization", + "variance_quantizer") + + layer_config["gamma_quantizer"] = gamma_quantizer + layer_config["beta_quantizer"] = beta_quantizer + layer_config["mean_quantizer"] = mean_quantizer + layer_config["variance_quantizer"] = variance_quantizer # we need to keep a dictionary of custom objects as our quantized library # is not recognized by keras. 
- qmodel = model_from_json(json.dumps(jm), custom_objects=custom_objects) + qmodel = quantized_model_from_json(json.dumps(jm), custom_objects) # if transfer_weights is true, we load the weights from model to qmodel @@ -402,7 +384,8 @@ def model_quantize(model, if layer.get_weights(): qlayer.set_weights(copy.deepcopy(layer.get_weights())) - return qmodel, custom_objects + return qmodel + def _add_supported_quantized_objects(custom_objects): @@ -411,7 +394,6 @@ def _add_supported_quantized_objects(custom_objects): custom_objects["QConv1D"] = QConv1D custom_objects["QConv2D"] = QConv2D custom_objects["QDepthwiseConv2D"] = QDepthwiseConv2D - custom_objects["QAveragePooling2D"] = QAveragePooling2D custom_objects["QActivation"] = QActivation custom_objects["QBatchNormalization"] = QBatchNormalization custom_objects["Clip"] = Clip @@ -441,6 +423,7 @@ def quantized_model_from_json(json_string, custom_objects=None): return qmodel + def load_qmodel(filepath, custom_objects=None, compile=True): """ Load quantized model from Keras's model.save() h5 file. @@ -456,7 +439,7 @@ def load_qmodel(filepath, custom_objects=None, compile=True): considered during deserialization. compile: Boolean, whether to compile the model after loading. - + # Returns A Keras model instance. If an optimizer was found as part of the saved model, the model is already @@ -471,37 +454,90 @@ def load_qmodel(filepath, custom_objects=None, compile=True): # let's make a deep copy to make sure our objects are not shared elsewhere custom_objects = copy.deepcopy(custom_objects) - + _add_supported_quantized_objects(custom_objects) - - qmodel = tf.keras.models.load_model(filepath, custom_objects=custom_objects, compile=compile) - + + qmodel = tf.keras.models.load_model(filepath, custom_objects=custom_objects, + compile=compile) return qmodel def print_model_sparsity(model): - """Prints sparsity for the pruned layers in the model.""" - - def _get_sparsity(weights): - return 1.0 - np.count_nonzero(weights) / float(weights.size) - - print("Model Sparsity Summary ({})".format(model.name)) - print("--") - for layer in model.layers: - if isinstance(layer, pruning_wrapper.PruneLowMagnitude): - prunable_weights = layer.layer.get_prunable_weights() - elif isinstance(layer, prunable_layer.PrunableLayer): - prunable_weights = layer.get_prunable_weights() - elif prune_registry.PruneRegistry.supports(layer): - weight_names = prune_registry.PruneRegistry._weight_names(layer) - prunable_weights = [getattr(layer, weight) for weight in weight_names] - else: - prunable_weights = None - if prunable_weights: - print("{}: {}".format( - layer.name, ", ".join([ - "({}, {})".format(weight.name, - str(_get_sparsity(K.get_value(weight)))) - for weight in prunable_weights - ]))) - print("\n") \ No newline at end of file + """Prints sparsity for the pruned layers in the model.""" + + def _get_sparsity(weights): + return 1.0 - np.count_nonzero(weights) / float(weights.size) + + print("Model Sparsity Summary ({})".format(model.name)) + print("--") + for layer in model.layers: + if isinstance(layer, pruning_wrapper.PruneLowMagnitude): + prunable_weights = layer.layer.get_prunable_weights() + elif isinstance(layer, prunable_layer.PrunableLayer): + prunable_weights = layer.get_prunable_weights() + elif prune_registry.PruneRegistry.supports(layer): + weight_names = prune_registry.PruneRegistry._weight_names(layer) + prunable_weights = [getattr(layer, weight) for weight in weight_names] + else: + prunable_weights = None + if prunable_weights: + print("{}: {}".format( + 
layer.name, ", ".join([ + "({}, {})".format(weight.name, + str(_get_sparsity(K.get_value(weight)))) + for weight in prunable_weights + ]))) + print("\n") + + +def quantized_model_debug(model, X_test, plot=False): + """Debugs and plots model weights and activations.""" + outputs = [] + output_names = [] + + for layer in model.layers: + if layer.__class__.__name__ in [ + "QActivation", "QBatchNormalization", "Activation", "QDense", + "QConv2D", "QDepthwiseConv2D" + ]: + output_names.append(layer.name) + outputs.append(layer.output) + + model_debug = Model(inputs=model.inputs, outputs=outputs) + + y_pred = model_debug.predict(X_test) + + print("{:30} {: 8.4f} {: 8.4f}".format( + "input", np.min(X_test), np.max(X_test))) + + for n, p in zip(output_names, y_pred): + layer = model.get_layer(n) + print("{:30} {: 8.4f} {: 8.4f}".format(n, np.min(p), np.max(p)), end="") + if plot and layer.__class__.__name__ in ["QConv2D", "QDense", "QActivation"]: + plt.hist(p.flatten(), bins=25) + plt.title(layer.name + "(output)") + plt.show() + alpha = None + for i, weights in enumerate(layer.get_weights()): + if hasattr(layer, "get_quantizers") and layer.get_quantizers()[i]: + weights = K.eval(layer.get_quantizers()[i](K.constant(weights))) + if i == 0 and layer.__class__.__name__ in [ + "QConv1D", "QConv2D", "QDense"]: + alpha = get_weight_scale(layer.get_quantizers()[i], weights) + # if alpha is 0, let's remove all weights. + alpha_mask = (alpha == 0.0) + weights = np.where( + alpha_mask, + weights * alpha, + weights / alpha + ) + if plot: + plt.hist(weights.flatten(), bins=25) + plt.title(layer.name + "(weights)") + plt.show() + print(" ({: 8.4f} {: 8.4f})".format(np.min(weights), np.max(weights)), + end="") + if alpha is not None and isinstance(alpha, np.ndarray): + print(" a({: 10.6f} {: 10.6f})".format( + np.min(alpha), np.max(alpha)), end="") + print("") diff --git a/requirements.txt b/requirements.txt index 538b777c..01c4ceff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ pyasn1<0.5.0,>=0.4.6 requests<3,>=2.21.0 pyparsing pytest>=4.6.9 -tensorflow-model-optimization>=0.2.1 \ No newline at end of file +tensorflow-model-optimization>=0.2.1 +matplotlib>=3.1.2 diff --git a/setup.py b/setup.py index 593600d6..1bf91245 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ setuptools.setup( name="QKeras", - version="0.6.0", + version="0.7.0", author="Claudionor N. Coelho", author_email="nunescoelho@google.com", maintainer="Hao Zhuang", @@ -43,6 +43,7 @@ "scipy>=1.4.1", "pyparser", "setuptools>=41.0.0", + "tensorflow-model-optimization>=0.2.1", ], setup_requires=[ "pytest-runner", diff --git a/tests/automatic_conversion_test.py b/tests/automatic_conversion_test.py new file mode 100644 index 00000000..d1e1c710 --- /dev/null +++ b/tests/automatic_conversion_test.py @@ -0,0 +1,95 @@ +# Copyright 2019 Google LLC +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import pytest +from tensorflow.keras.layers import * +from tensorflow.keras.models import * + +from qkeras import * +from qkeras.utils import model_quantize + + +def create_network(): + xi = Input((28,28,1)) + x = Conv2D(32, (3, 3))(xi) + x = Activation("relu")(x) + x = Conv2D(32, (3, 3), activation="relu")(x) + x = Activation("softmax")(x) + x = QConv2D(32, (3, 3), activation="quantized_relu(4)")(x) + return Model(inputs=xi, outputs=x) + + +def test_linear_activation(): + m = create_network() + + assert m.layers[1].activation.__name__ == "linear", "test failed" + + +def test_linear_activation_conversion(): + m = create_network() + + d = { + "QConv2D": { + "kernel_quantizer": "binary", + "bias_quantizer": "binary", + "activation_quantizer": "binary" + } + } + qq = model_quantize(m, d, 4) + + assert str(qq.layers[1].activation) == "binary()" + + +def test_no_activation_conversion_to_quantized(): + m = create_network() + d = {"QConv2D": {"kernel_quantizer": "binary", "bias_quantizer": "binary"}} + qq = model_quantize(m, d, 4) + assert qq.layers[2].__class__.__name__ == "Activation" + assert qq.layers[4].__class__.__name__ == "Activation" + + +def test_automatic_conversion_from_relu_to_qr(): + m = create_network() + d = { + "QConv2D": { + "kernel_quantizer": "binary", + "bias_quantizer": "binary" + }} + qq = model_quantize(m, d, 4) + assert str(qq.layers[3].activation) == "quantized_relu(4,0)" + + +def test_conversion_from_relu_activation_to_qr_qactivation(): + m = create_network() + d = { + "QConv2D": { + "kernel_quantizer": "binary", + "bias_quantizer": "binary" + }, + "QActivation": { + "relu": "ternary" + } + } + qq = model_quantize(m, d, 4) + assert qq.layers[2].__class__.__name__ == "QActivation" + assert str(qq.layers[2].quantizer) == "ternary()" + assert qq.layers[4].__class__.__name__ == "Activation" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/qactivation_test.py b/tests/qactivation_test.py index 72e51a8a..d2d82349 100644 --- a/tests/qactivation_test.py +++ b/tests/qactivation_test.py @@ -14,6 +14,9 @@ # limitations under the License. # ============================================================================== """Test activation from qlayers.py.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import numpy as np from numpy.testing import assert_allclose @@ -27,10 +30,12 @@ from qkeras import quantized_relu from qkeras import quantized_relu_po2 from qkeras import smooth_sigmoid +from qkeras import stochastic_binary +from qkeras import stochastic_ternary from qkeras import ternary @pytest.mark.parametrize( - 'bits, max_value, use_stochastic_rounding, quadratic_approximation, ' + + 'bits, max_value, use_stochastic_rounding, quadratic_approximation, ' 'test_values, expected_values', [ # bits=4 without max_value. Therefore the max exponent is 4 when # quadratic approximiation is enabled. 
The max and min values from this
@@ -382,5 +387,66 @@ def test_stochastic_round_quantized_relu_po2(test_values, expected_values):
   assert_allclose(res, expected_values, rtol=1e-01, atol=1e-6)
 
+def test_stochastic_binary():
+  np.random.seed(42)
+
+  x = np.random.uniform(-0.01, 0.01, size=10)
+  x = np.sort(x)
+
+  s = stochastic_binary(alpha="auto_po2")
+  ty = np.zeros_like(x)
+  ts = 0.0
+
+  n = 1000
+
+  for _ in range(n):
+    y = K.eval(s(K.constant(x)))
+    scale = K.eval(s.scale)[0]
+    ts = ts + scale
+    ty = ty + (y / scale)
+
+  result = (ty/n).astype(np.float32)
+  scale = np.array([ts/n])
+
+  expected = np.array(
+      [-1., -1., -1., -0.852, 0.782, 0.768, 0.97, 0.978, 1.0, 1.0]
+  ).astype(np.float32)
+  expected_scale = np.array([0.003906])
+
+  assert_allclose(result, expected, atol=0.1)
+  assert_allclose(scale, expected_scale, rtol=0.1)
+
+
+def test_stochastic_ternary():
+  np.random.seed(42)
+
+  x = np.random.uniform(-0.01, 0.01, size=10)
+  x = np.sort(x)
+
+  s = stochastic_ternary(alpha="auto_po2", temperature=8)
+  ty = np.zeros_like(x)
+  ts = 0.0
+
+  n = 1000
+
+  for _ in range(n):
+    y = K.eval(s(K.constant(x)))
+    scale = K.eval(s.scale)[0]
+    ts = ts + scale
+    ty = ty + (y / scale)
+
+  result = (ty/n).astype(np.float32)
+  scale = np.array([ts/n])
+
+  expected = np.array(
+      [-0.998, -0.992, -0.992, -0.208, 0.048, 0.04, 0.448, 0.606,
+       0.987, 0.998]
+  ).astype(np.float32)
+  expected_scale = np.array([0.007812])
+
+  assert_allclose(result, expected, atol=0.1)
+  assert_allclose(scale, expected_scale, rtol=0.1)
+
+
 if __name__ == '__main__':
   pytest.main([__file__])
diff --git a/tests/qalpha_test.py b/tests/qalpha_test.py
new file mode 100644
index 00000000..85ba0df3
--- /dev/null
+++ b/tests/qalpha_test.py
@@ -0,0 +1,127 @@
+# Copyright 2020 Google LLC
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. 
+
+# ==============================================================================
+"""Test get_weight_scale with auto and auto_po2 modes of quantizers.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import logging
+import numpy as np
+from numpy.testing import assert_allclose
+import pytest
+from tensorflow.keras import backend as K
+from qkeras import binary
+from qkeras import get_weight_scale
+from qkeras import ternary
+
+
+# Expected scale (alpha) when the input is uniformly distributed in [-m, m]:
+# - alpha = m/2.0 for binary
+# - alpha = (m+d)/2.0 for ternary
+
+
+def test_binary_auto():
+  """Test binary auto scale quantizer."""
+
+  np.random.seed(42)
+  N = 1000000
+  m_list = [1.0, 0.1, 0.01, 0.001]
+
+  for m in m_list:
+    x = np.random.uniform(-m, m, (N, 10)).astype(K.floatx())
+    x = K.constant(x)
+
+    quantizer = binary(alpha="auto")
+    q = K.eval(quantizer(x))
+
+    result = get_weight_scale(quantizer, q)
+    expected = m / 2.0
+    logging.info("expect %s", expected)
+    logging.info("result %s", result)
+    assert_allclose(result, expected, rtol=0.02)
+
+
+def test_binary_auto_po2():
+  """Test binary auto_po2 scale quantizer."""
+
+  np.random.seed(42)
+  N = 1000000
+  m_list = [1.0, 0.1, 0.01, 0.001]
+
+  for m in m_list:
+    x = np.random.uniform(-m, m, (N, 10)).astype(K.floatx())
+    x = K.constant(x)
+
+    quantizer_ref = binary(alpha="auto")
+    quantizer = binary(alpha="auto_po2")
+
+    q_ref = K.eval(quantizer_ref(x))
+    q = K.eval(quantizer(x))
+
+    ref = get_weight_scale(quantizer_ref, q_ref)
+
+    expected = np.power(2.0, np.round(np.log2(ref)))
+    result = get_weight_scale(quantizer, q)
+
+    assert_allclose(result, expected, rtol=0.0001)
+
+
+def test_ternary_auto():
+  """Test ternary auto scale quantizer."""
+
+  np.random.seed(42)
+  N = 1000000
+  m_list = [1.0, 0.1, 0.01, 0.001]
+
+  for m in m_list:
+    x = np.random.uniform(-m, m, (N, 10)).astype(K.floatx())
+    x = K.constant(x)
+
+    quantizer = ternary(alpha="auto")
+    q = K.eval(quantizer(x))
+
+    d = m/3.0
+    result = np.mean(get_weight_scale(quantizer, q))
+    expected = (m + d) / 2.0
+    assert_allclose(result, expected, rtol=0.02)
+
+
+def test_ternary_auto_po2():
+  """Test ternary auto_po2 scale quantizer."""
+
+  np.random.seed(42)
+  N = 1000000
+  m_list = [1.0, 0.1, 0.01, 0.001]
+
+  for m in m_list:
+    x = np.random.uniform(-m, m, (N, 10)).astype(K.floatx())
+    x = K.constant(x)
+
+    quantizer_ref = ternary(alpha="auto")
+    quantizer = ternary(alpha="auto_po2")
+
+    q_ref = K.eval(quantizer_ref(x))
+    q = K.eval(quantizer(x))
+
+    ref = get_weight_scale(quantizer_ref, q_ref)
+
+    expected = np.power(2.0, np.round(np.log2(ref)))
+    result = get_weight_scale(quantizer, q)
+
+    assert_allclose(result, expected, rtol=0.0001)
+
+
+if __name__ == "__main__":
+  pytest.main([__file__])
diff --git a/tests/qconvolutional_test.py b/tests/qconvolutional_test.py
index 90d1a0a9..075fc7c7 100644
--- a/tests/qconvolutional_test.py
+++ b/tests/qconvolutional_test.py
@@ -14,12 +14,15 @@
 # limitations under the License. 
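The expected values asserted in qalpha_test.py above follow from the statistics of a uniform distribution: for x ~ U(-m, m), the mean of |x| is m/2, and the mean of |x| restricted to |x| > d is (m + d)/2. A small numpy sketch of just these two facts (it does not call the quantizers themselves; d = m/3 mirrors the value used in test_ternary_auto):

import numpy as np

np.random.seed(0)
m = 1.0
x = np.random.uniform(-m, m, size=1_000_000)

# Binary case: the auto scale should approach E[|x|] = m / 2.
print(np.abs(x).mean())                  # ~0.5

# Ternary case: only values above the threshold d = m / 3 are kept,
# and their mean magnitude is (m + d) / 2.
d = m / 3.0
print(np.abs(x)[np.abs(x) > d].mean())   # ~0.667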
# ============================================================================== """Test layers from qconvolutional.py.""" - +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import os -import tempfile import numpy as np from numpy.testing import assert_allclose import pytest +import tempfile + from tensorflow.keras import backend as K from tensorflow.keras.layers import Activation from tensorflow.keras.layers import Flatten @@ -27,6 +30,8 @@ from tensorflow.keras.models import Model from tensorflow.keras.backend import clear_session +from qkeras import binary +from qkeras import ternary from qkeras import QActivation from qkeras import QDense from qkeras import QConv1D @@ -41,7 +46,7 @@ from qkeras import extract_model_operations -# TODO: +# TODO(hzhuang): # qoctave_conv test # qbatchnorm test @@ -50,9 +55,9 @@ def test_qnetwork(): x = QSeparableConv2D( 32, (2, 2), strides=(2, 2), - depthwise_quantizer="binary", - pointwise_quantizer=quantized_bits(4, 0, 1), - depthwise_activation=quantized_bits(6, 2, 1), + depthwise_quantizer=binary(alpha=1.0), + pointwise_quantizer=quantized_bits(4, 0, 1, alpha=1.0), + depthwise_activation=quantized_bits(6, 2, 1, alpha=1.0), bias_quantizer=quantized_bits(4, 0, 1), name='conv2d_0_m')( x) @@ -60,7 +65,7 @@ def test_qnetwork(): x = QConv2D( 64, (3, 3), strides=(2, 2), - kernel_quantizer="ternary", + kernel_quantizer=ternary(alpha=1.0), bias_quantizer=quantized_bits(4, 0, 1), name='conv2d_1_m', activation=quantized_relu(6, 3, 1))( @@ -68,7 +73,7 @@ def test_qnetwork(): x = QConv2D( 64, (2, 2), strides=(2, 2), - kernel_quantizer=quantized_bits(6, 2, 1), + kernel_quantizer=quantized_bits(6, 2, 1, alpha=1.0), bias_quantizer=quantized_bits(4, 0, 1), name='conv2d_2_m')( x) @@ -76,7 +81,7 @@ def test_qnetwork(): x = Flatten(name='flatten')(x) x = QDense( 10, - kernel_quantizer=quantized_bits(6, 2, 1), + kernel_quantizer=quantized_bits(6, 2, 1, alpha=1.0), bias_quantizer=quantized_bits(4, 0, 1), name='dense')( x) @@ -90,7 +95,6 @@ def test_qnetwork(): model = quantized_model_from_json(json_string) # generate same output for weights - np.random.seed(42) for layer in model.layers: all_weights = [] @@ -126,27 +130,27 @@ def test_qnetwork(): assert np.all(all_weights == all_weights_signature) # test_qnetwork_forward: - expected_output = np.array([[0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 0.e+00, 0.e+00, 6.e-08, 1.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00 ,0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00, 0.e+00, 0.e+00, 5.e-07, 1.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00 ,1.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00, - 0.e+00 ,0.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 1.e+00, - 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00], - [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, - 1.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00]]).astype(np.float16) - + expected_output = np.array( + [[0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 0.e+00, 0.e+00, 6.e-08, 1.e+00], + [0.e+00, 
0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], + [ 0.e+00 ,0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 0.e+00, 0.e+00, 5.e-07, 1.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00 ,1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00 ,0.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 1.e+00, + 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 1.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00]]).astype(np.float16) inputs = 2 * np.random.rand(10, 28, 28, 1) actual_output = model.predict(inputs).astype(np.float16) assert_allclose(actual_output, expected_output, rtol=1e-4) @@ -157,25 +161,25 @@ def test_qconv1d(): x = Input((4, 4,)) y = QConv1D( 2, 1, - kernel_quantizer=quantized_bits(6, 2, 1), + kernel_quantizer=quantized_bits(6, 2, 1, alpha=1.0), bias_quantizer=quantized_bits(4, 0, 1), name='qconv1d')( x) model = Model(inputs=x, outputs=y) - #Extract model operations + # Extract model operations model_ops = extract_model_operations(model) # Assertion about the number of operations for this Conv1D layer - assert model_ops['qconv1d']["number_of_operations"] == 32 + assert model_ops['qconv1d']['number_of_operations'] == 32 # Print qstats to make sure it works with Conv1D layer - print_qstats(model) + print_qstats(model) # reload the model to ensure saving/loading works - json_string = model.to_json() - clear_session() - model = quantized_model_from_json(json_string) + # json_string = model.to_json() + # clear_session() + # model = quantized_model_from_json(json_string) for layer in model.layers: all_weights = [] @@ -189,16 +193,15 @@ def test_qconv1d(): 10.0 * np.random.normal(0.0, np.sqrt(2.0 / input_size), shape)) if all_weights: layer.set_weights(all_weights) - # Save the model as an h5 file using Keras's model.save() fd, fname = tempfile.mkstemp('.h5') model.save(fname) del model # Delete the existing model - # Returns a compiled model identical to the previous one + # Return a compiled model identical to the previous one model = load_qmodel(fname) - #Clean the created h5 file after loading the model + # Clean the created h5 file after loading the model os.close(fd) os.remove(fname) @@ -207,12 +210,6 @@ def test_qconv1d(): inputs = np.random.rand(2, 4, 4) p = model.predict(inputs).astype(np.float16) - ''' - y = np.array([[[0.1309, -1.229], [-0.4165, -2.639], [-0.08105, -2.299], - [1.981, -2.195]], - [[-0.3174, -3.94], [-0.3352, -2.316], [0.105, -0.833], - [0.2115, -2.89]]]).astype(np.float16) - ''' y = np.array([[[-2.441, 3.816], [-3.807, -1.426], [-2.684, -1.317], [-1.659, 0.9834]], [[-4.99, 1.139], [-2.559, -1.216], [-2.285, 1.905], diff --git a/tests/qlayers_test.py b/tests/qlayers_test.py index 1fc60ff8..5e1f0f84 100644 --- a/tests/qlayers_test.py +++ b/tests/qlayers_test.py @@ -14,10 +14,13 @@ # limitations under the License. 
# ============================================================================== """Test layers from qlayers.py.""" - +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import numpy as np from numpy.testing import assert_allclose import pytest +import logging from tensorflow.keras import backend as K from tensorflow.keras.layers import Activation from tensorflow.keras.layers import Flatten @@ -51,27 +54,34 @@ def qdense_util(layer_cls, @pytest.mark.parametrize( - 'layer_kwargs, input_data, weight_data, bias_data, expected_output', [ - ({ - 'units': 2, - 'use_bias': True, - 'kernel_initializer': 'glorot_uniform', - 'bias_initializer': 'zeros' - }, np.array([[1, 1, 1, 1]], dtype=K.floatx()), - np.array([[10, 20], [10, 20], [10, 20], [10, 20]], - dtype=K.floatx()), np.array([0, 0], dtype=K.floatx()), - np.array([[40, 80]], dtype=K.floatx())), - ({ - 'units': 2, - 'use_bias': True, - 'kernel_initializer': 'glorot_uniform', - 'bias_initializer': 'zeros', - 'kernel_quantizer': 'quantized_bits(2,0)', - 'bias_quantizer': 'quantized_bits(2,0)', - }, np.array([[1, 1, 1, 1]], dtype=K.floatx()), - np.array([[10, 20], [10, 20], [10, 20], [10, 20]], dtype=K.floatx()), - np.array([0, 0], dtype=K.floatx()), np.array([[2, 2]], - dtype=K.floatx())), + 'layer_kwargs, input_data, weight_data, bias_data, expected_output', + [ + ( + { + 'units': 2, + 'use_bias': True, + 'kernel_initializer': 'glorot_uniform', + 'bias_initializer': 'zeros' + }, + np.array([[1, 1, 1, 1]], dtype=K.floatx()), + np.array([[10, 20], [10, 20], [10, 20], [10, 20]], + dtype=K.floatx()), # weight_data + np.array([0, 0], dtype=K.floatx()), # bias + np.array([[40, 80]], dtype=K.floatx())), # expected_output + ( + { + 'units': 2, + 'use_bias': True, + 'kernel_initializer': 'glorot_uniform', + 'bias_initializer': 'zeros', + 'kernel_quantizer': 'quantized_bits(2,0,alpha=1.0)', + 'bias_quantizer': 'quantized_bits(2,0)', + }, + np.array([[1, 1, 1, 1]], dtype=K.floatx()), + np.array([[10, 20], [10, 20], [10, 20], [10, 20]], + dtype=K.floatx()), # weight_data + np.array([0, 0], dtype=K.floatx()), # bias + np.array([[2, 2]], dtype=K.floatx())), #expected_output ]) def test_qdense(layer_kwargs, input_data, weight_data, bias_data, expected_output):
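The second parametrize case above is the interesting one: with kernel_quantizer set to quantized_bits(2,0,alpha=1.0), the weights of 10 and 20 saturate to the largest value representable with two bits and no integer bits, and the asserted output [[2, 2]] implies that value is 0.5 (four inputs of 1.0 times 0.5 each). A small sanity-check sketch of that arithmetic:

import numpy as np
from tensorflow.keras import backend as K

from qkeras import quantized_bits

q = quantized_bits(2, 0, alpha=1.0)
w = K.eval(q(K.constant([[10., 20.]] * 4)))  # every entry saturates to 0.5
print(np.array([[1., 1., 1., 1.]]) @ w)      # ~[[2., 2.]]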