From 29f646d04ec9211c16452250848f9f9056593758 Mon Sep 17 00:00:00 2001 From: Ahmed Allam Date: Fri, 8 Sep 2023 12:24:32 +0300 Subject: [PATCH] Added an example for a vision transformer (vit) --- examples/vision_transformer.ipynb | 543 ++++++++++++++++++++++++++++++ 1 file changed, 543 insertions(+) create mode 100644 examples/vision_transformer.ipynb diff --git a/examples/vision_transformer.ipynb b/examples/vision_transformer.ipynb new file mode 100644 index 00000000..2b109dc3 --- /dev/null +++ b/examples/vision_transformer.ipynb @@ -0,0 +1,543 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "x96qLutZxINh" + }, + "source": [ + "# Vision Transformer (ViT)\n", + "\n", + "This example builds a vision transformer model using Equinox, an implementation based on the paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.\n", + "\n", + "\n", + "!!! cite \"Reference\"\n", + "\n", + " [arXiv link](https://arxiv.org/abs/2010.11929)\n", + "\n", + " ```bibtex\n", + " @article{dosovitskiy2020image,\n", + " title={An image is worth 16x16 words: Transformers for image recognition at scale},\n", + " author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},\n", + " journal={arXiv preprint arXiv:2010.11929},\n", + " year={2020}\n", + " }\n", + " ```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "3-NddIhhxINj" + }, + "outputs": [], + "source": [ + "from typing import List\n", + "import functools\n", + "\n", + "import numpy as np\n", + "\n", + "import einops\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import optax\n", + "\n", + "# We'll use PyTorch to load the dataset.\n", + "import torch\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import equinox as eqx" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "id": "bYi-XlXRxINl" + }, + "outputs": [], + "source": [ + "# Hyperparameters\n", + "lr = 0.0002\n", + "dropout_rate = 0.2\n", + "beta1 = 0.9\n", + "beta2 = 0.999\n", + "batch_size = 32\n", + "patch_size = 4\n", + "num_patches = 64\n", + "num_steps = 500000\n", + "image_size = (64, 64, 1)\n", + "embedding_dim = 512\n", + "hidden_dim = 256\n", + "num_heads = 12\n", + "num_layers = 12\n", + "height, width, channels = image_size" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "Vcwi4un6CMu_" + }, + "outputs": [], + "source": [ + "# Load the MNIST dataset using torchvision\n", + "transform = transforms.Compose(\n", + " [\n", + " transforms.Resize((height, width)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5,), (0.5,)),\n", + " ]\n", + ")\n", + "\n", + "train_dataset = torchvision.datasets.MNIST(\n", + " \"MNIST\",\n", + " train=True,\n", + " download=True,\n", + " transform=transform,\n", + ")\n", + "\n", + "test_dataset = torchvision.datasets.MNIST(\n", + " \"MNIST\",\n", + " train=False,\n", + " download=True,\n", + " transform=transform,\n", + ")\n", + "\n", + "trainloader = torch.utils.data.DataLoader(\n", + " train_dataset, batch_size=batch_size, shuffle=True\n", + ")\n", + "\n", + "testloader = torch.utils.data.DataLoader(\n", + " test_dataset, batch_size=batch_size, shuffle=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k3FO6HXqQEBq" + }, + "source": [ + "Let's load some example data, and see some sample MNIST digits." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 400 + }, + "id": "asQV91nfCMvA", + "outputId": "77f8436f-36f1-459e-95d7-3f73623df5d1" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAF/CAYAAADZ4XhyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9eXDl6VXfj7/vvu/7fqUrtaRuTfdMz+BtbDBrUVAOpgKVhCIxMaECpkgVCSEkFSqkAjHwtRP+CHYgIbGppBJCHKAMhYGA7SEwtmemu6dnWq1dV7r7vu/b74/5ndMfqaVudU93617peVWpPNbW9370fD7Pec55n/eRTSaTCQQCgUAgEAgEFwb5Wb8AgUAgEAgEAsGzRQSAAoFAIBAIBBcMEQAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQwSAAoFAIBAIBBcMEQAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQwSAAoFgaviRH/kRyGQyxOPxs34pAoFAcK4RAaBAMGO0Wi3823/7b3H9+nUYjUZoNBoEg0F86EMfwj//5/8cOzs7Z/0Sp5ZoNAqtVnvWL0MgEAjOHOVZvwCBQHB6Go0GPvjBD+L27dtYWFjAD//wD8PhcKBYLOIb3/gGfvmXfxmxWAyxWOysX6pAIBAIphgRAAoEM8Sv/dqv4fbt2/gH/+Af4Dd/8zchk8kOfX1vbw+9Xu+MXp1AIBAIZgVRAhYIZohXX30VAPCTP/mT9wV/ADA3N4fl5eVDn/vyl7+Mj3/841haWoLRaITRaMRLL72E3/zN3zz235DJZPjwhz+MVCqFH/qhH4LT6YTJZML3fu/3Ynd3FwBw9+5dfPSjH4XdbofJZMIP/MAPIJfLHfo98XgcMpkMP/IjP4I7d+7ge7/3e2G1WmE0GvFd3/VdeOONNx7pvb/yyiv4yEc+AqfTCY1Gg8XFRfzLf/kv0W63H+n3HOVzn/scZDIZPve5z+GLX/wi3vve90Kv1yMQCODnf/7nMR6PAQCf//znce3aNeh0OoTDYfx//9//d9/vSqfT+Ff/6l/hfe97H9xuNzQaDaLRKD7xiU8gn88f++/H43H8rb/1t2C322E0GvEt3/IteOWVV/ALv/ALkMlk+MpXvvKursUXvvAFfMu3fAvcbje0Wi38fj++4zu+A1/4whfe1XUTCASzjcgACgQzhMPhAABsbm7i+eefP9XP/Mqv/Aq2t7fxvve9D9///d+ParWKL33pS/iH//AfYmNjA5/+9Kfv+5lKpYIPfvCD8Hq9+NjHPobNzU384R/+IdbX1/EHf/AH+NCHPoQXX3wRH//4x/HGG2/gC1/4AsrlMv7iL/7ivt+1u7uLl19+GdevX8dP/MRPYH9/H7/7u7+Lb/7mb8Zf/MVf4L3vfe9D38NnP/tZ/ORP/iSsVis+8pGPwO124/XXX8cv/dIv4ctf/jK+/OUvQ61Wn+p6nMTv/d7v4U//9E/x0Y9+FC+//DL+6I/+CL/4i7+IyWQCi8WCX/zFX8T3fd/34cMf/jC+8IUv4Gd/9mfh8Xjw9/7e3+Pf8corr+DTn/40vv3bvx3vfe97oVKpcPPmTXz2s5/Fn/zJn+DGjRuwWCz8/alUCh/4wAeQyWTw3d/93XjhhRewsbGB7/zO78S3fdu3vetr8dnPfhaf+MQn4PP58P3f//1wOBzIZrP4xje+gd/7vd/D3/ybf/NdXTOBQDDDTAQCwczwB3/wBxMAE5PJNPkn/+SfTP7kT/5kUiwWH/gzu7u7931uMBhMvvM7v3OiUCgm+/v7h74GYAJg8tM//dOHPv8TP/ETEwATq9U6+bVf+zX+/Hg8nnzP93zPBMDkjTfe4M/v7e3x7/q5n/u5Q7/rS1/60gTA5Lnnnjv0+Y997GMTAJO9vT3+3J07dyZKpXJy7dq1+97rJz/5yQmAyac+9akHXgMiEolMNBrNoc/91//6XycAJiqVavKNb3yDP1+v1ydut3ui1+snXq93srOzw187ODiYqNXq+15/LpebNBqN+/7dz3/+8xMAk1/8xV889Pkf/uEfngCY/NIv/dKhz//Wb/0WX7svf/nLj30trl+/PlGr1ZNcLnffa3rYuhEIBOcbEQAKBDPGpz/96YnRaOQAAcAkFotNfvInf3Kyubl56t/zhS98YQJg8rnPfe7Q5wFMjEbjpNVqHfr8K6+8wv/WeDw+9LXf/u3fngCY/Jf/8l/4cxQAWq3WY4Oib//2b58AmLz++uv8ueMCwH/0j/7RBMDklVdeue93jEajicvlmrz44ounes8PCgD//t//+/d9/8c//vEJgMm//tf/+r6vfdu3fdtEoVBMBoPBQ//d8Xg8MZvNkw9/+MP8uW63O9FoNBO32z3pdrv3ff/S0tJ9AeCjXovr169PDAbDpFwuP/Q1CgSCi4UoAQsEM8Y//sf/GD/2Yz+GL33pS/jrv/5rvP766/j617+OX//1X8dv/dZv4Xd+53fwN/7G3+DvbzQa+NSnPoXf//3fx87ODlqt1qHfl06n7/s3FhcXodfrD33O5/MBAK5evXqf/pC+dtzveuGFF2A0Gu/7/Ic+9CH8+Z//OW7evIkXX3zxxPf7ta99DQDwJ3/yJ/jzP//z+76uUqmwvr5+4s+fluNK6vS+TvraaDRCLpdDIBDgz/+f//N/8Bu/8Ru4ceMGKpUKRqMRf016fTY2NtDr9fDSSy9Bo9Ec+t0ymQwf+MAHsLGxcejzj3ot/vbf/tv42Z/9WayuruKHfuiH8K3f+q344Ac/CLPZ/IArIRAILgIiABQIZhCTyYQf/MEfxA/+4A8CAGq1Gv7Fv/gX+MxnPoMf/dEfRSqVglqtRr/fx4c//GHcuHEDL7zwAv7u3/27cDgcUCqViMfj+PznP39s1/BxAYJSqXzo1waDwX1f83g8x74H+nytVnvgey2XywCAX/qlX3rg971bnsR7/vSnP42f+Zmfgcvlwnd913chGAxCp9MBeKeDW3qt6/U6AMDtdh/7eo67bo96LX7mZ34GDocDn/3sZ/HpT38an/rUp6BUKvG93/u9+Pf//t9jbm7uVL9HIBCcP0QAKBCcAywWC/7Df/gP+KM/+iPs7+/jrbfewosvvog/+IM/wI0bN/CjP/qj+M//+T8f+pn/+T//Jz7/+c8/9dd2tDv46OelTRHHQcFXvV6HyWR6si/uCTIcDvFv/s2/gc/nw61btw4FdpPJBL/6q7966PvpfZ3UHXzcdXvUayGTyfDxj38cH//4x1EqlfCXf/mX+B//43/gf/2v/4WtrS3cvn0bCoXi1O9RIBCcH4QNjEBwTpDJZDAYDIc+R1NBvu/7vu++7//Lv/zLZ/K6bt68iWazeeK//8ILLzzw56lLmMqf00qxWEStVsP73//++7J6r7/+OjqdzqHPLS0tQaPR4I033rgvCzuZTNjyR8q7uRYOhwMf/ehH8Tu/8zv4tm/7NqytrWF7e/uRf49AIDgfiABQIJghfuM3fgOvvfbasV/7/d//fdy9exdWqxWrq6sAgEgkAgD4f//v/x363q9+9av4T//pPz3dF/v/p1qt3leyJA3b6urqA/V/APCJT3wCSqUSP/VTP4WDg4Njf//Nmzef6Gt+HNxuN3Q6HW7cuHHIj69SqeCnfuqn7vt+jUbD/om/9mu/duhrv/3bv32srvFRr8VXvvIVTCaTQ98zGAy4lCzG4gkEFxdRAhYIZog//uM/xo//+I9jYWEBL7/8Mvx+P1qtFm7evIm//Mu/hFwux2c+8xluKvjIRz6CaDSKX/3VX8Xbb7+N1dVVbGxs4A//8A/x/d///fjf//t/P/XX/KEPfQif/exn8fWvfx3ve9/7EI/H8bu/+7vQ6XT3laWPY3V1FZ/5zGfwEz/xE1haWsL3fM/3IBaLodFoYHd3F1/96lfxIz/yI/iP//E/PvX38iDkcjk+8YlP4NOf/jSuXbuGj3zkI6jX6/jjP/5jRCIR+P3++37mk5/8JP7v//2/+Lmf+zl89atfZR/AP/zDP8R3f/d340tf+hLk8nvn9Ee9Fh/96EdhNpvxvve9D5FIBIPBAH/2Z3+GtbU1/MAP/AAfEAQCwcVDBIACwQzxK7/yK3j55ZfxZ3/2Z3jllVeQyWQAAIFAAB/72MfwUz/1U4cyakajEX/xF3+Bf/pP/yleeeUVfOUrX8GVK1fw3//7f4fH43kmAeD8/Dw++9nP4md/9mfx67/+6xiNRvjwhz+MX/7lX35o9o/4sR/7MTz//PP4d//u3+GVV17BF7/4RVgsFoTDYfz0T/80Pvaxjz3ld3E6PvnJT8Jut+Nzn/scPvOZz8Dj8eDv/J2/g1/4hV/grKyUUCiEV199Ff/sn/0z/Omf/im++tWv4sUXX8Sf/umf4nd/93cB3N+A8ijX4pOf/CS+9KUv4Rvf+Aa++MUvwmAwIBaL4bOf/Sx+9Ed/9OleDIFAMNXIJkfrAwKBQPAEiMfjmJubw8c+9jF87nOfO+uXM3N88IMfxKuvvoparXasjY5AIBC8G4QGUCAQCM4QyuJK+W//7b/hr/7qr/Ad3/EdIvgTCARPBVECFggEgjNkdXUVL7zwAi5fvgyFQoFbt27hK1/5CkwmEz71qU+d9csTCATnFBEACgQCwRny4z/+4/jiF7+I119/Ha1WCy6XCz/0Qz+En//5n8fy8vJZvzyBQHBOERpAgUAgEAgEgguG0AAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQwSAAoFAIBAIBBcMEQAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQwSAAoFAIBAIBBcMEQAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQwSAAoFAIBAIBBcMEQAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQwSAAoFAIBAIBBcMEQAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQwSAAoFAIBAIBBcMEQAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQwSAAoFAIBAIBBcMEQAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQwSAAoFAIBAIBBcMEQAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQwSAAoFAIBAIBBcMEQAKBAKBQCAQXDBEACgQCAQCgUBwwRABoEAgEAgEAsEFQ3nab5TJZE/zdcw8k8nk2M+L6/ZgTrpugLh2D0OsucdDrLnHR6y5x0Nct8dD3KuPz4OuHSEygAKBQCAQCAQXDBEACgQCgUAgEFwwRAAoEAgEAoFAcMEQAaBAIBAIBALBBUMEgAKBQCAQCAQXjFN3AU8LMpns0IdcLodcLj+2I+i475V+33g8xnA4xGg04o/TdM4IBAKBQHDeof3y6P46mUzEXnkOmJkAkAI4pVIJlUoFjUYDnU4HvV4PvV4PtVoNuVwOhULBP6NQKKBWq6HRaKBWq6HVaqFWq/nrrVYLxWIR5XIZhUIB5XIZvV7vLN6eQCAQCARnjjRholAooFQqDyVPxuMx+v0+RqMRxuOxCARnmJkKACnwMxgMMJvNsNvt8Hg8cLvdHAQqlUpeqBqNBiaTCWazGRaLBVarFXq9HjKZDJPJBNlsFuvr61hfX8edO3fQarVEACgQCASCC4tMJoNCoYBKpTqUQKF9td/vo9PpoNPpAIAIAmeYqQ4AFQoFZ/qMRiMHcfThcDjg9XrhcrlgMBigUqkOBYBarRZGoxFGoxF2ux1msxkGgwEymQzj8RiJRAKTyQStVgvpdBpK5VRfDsEZQlIDrVYLk8kEg8GAyWSCbreLbreLXq+HTqeD0Wh01i9VMIVIsyparZYrF0ajEVqtluUo/X4f/X4fvV7vgZvqZDLBYDBAr9fDcDjEYDDAeDzm56VKpbrvdw4Gg2f4jqcLafXIYrHAZrNBq9ViNBqh3++jXq+jVquh3W5f2IBGqVRCo9FAr9fzXms0GqHT6XjfBIB6vY5MJoNCoYBWq4VOp3Oh19YsM7URj1wuh1qthsvlgs/nQyAQQCgUQiQSgcPhgMlk4uBOWgKWy9/pa6FTjFqt5vKvRqMB8M7Dk25ySnPTA1ogOA6lUgmdTgeXy4WFhQXMzc1xFjmbzSKfzyOXy6HVap31SxVMITKZDEqlElqtFi6XC8FgEHNzc5ifn4fb7Ua/30etVkO9XkelUkGpVEK32z3x943HYzQaDZTLZdRqNTQaDQwGA9jtdvh8PlgsFj6UVKtV5PP5C71Jq1QqGAwG2Gw2LC8v44UXXoDb7Ua320WlUuEqUCKRwGAwwHA4POuX/MzRarVwOp3w+/2IxWKYm5uDy+XixAntralUCjdu3MDdu3eRzWZRKpUwHA4vZNA860xlAEgnEbPZjEAggEuXLmFxcZE/nE4nZ/wUCgXkcjkHdRTYUXBH2oXJZIJer4der4fRaIThcMgPTjrBjMfjs37rTwR63xTc0jWiDMRkMsFwOOQGmPF4LDJXD0GtVsNkMsHr9eLKlSu4fv06BoMBtra2oNFoMBwOUalUzm0AKD0cSf/76BqT3nO0YZzEZDLh5ivpvSf93HkQm9NhVq/Xw2w2IxwOY3l5GdeuXcO1a9cQDofR6XRYi5zJZJDJZNBut0/8naPRCKVSCblcDtlsFpVKBf1+H4FAAAsLC3A4HOh2u6jX68jn8xyASpve6L/POzKZDGq1GmazGV6vF5cvX8aHP/xhRKNRtFot5PN5qFQq5HI55HI5fh7O+ro7LXQPm81mBINBLC4u4tq1a7hy5Qr8fj/MZjOMRiOUSiUmkwl2d3cxGo3Q6XQwHA7RarXQarUuzPU6T0xlAGgwGBAIBHgxXrp0CfPz8/D7/fB6vTCbzdBoNFAoFLxBdLtdNJtNNJtNDIdDjMdj1g3KZDJ0Oh1OV9PJOJ1OY3NzE7u7u8jn8+j3+2f91t8V0qynwWCAyWQ6lCnV6XTQaDQYjUYoFAooFAqoVqtoNptotVoXYjN4HGQyGfR6Pfx+PxYWFjA/P39o085ms7wezxvSTnvpB6HVamG1WmGxWLikqdPpOON+UladDiGNRgPVapWzXXQwq9fraDabnI2Zxc2Frp1KpYLD4UAwGEQ4HMbc3BwuXbqEhYUFBAIB2O129Pv9Q/pmygqexGg0Qq1WQ7lcRrlcRr1ex3A45OyiyWRirValUkEul0OhUEClUjn0UavVzv19T/evx+PB3NwcwuEwPB4PXC4XTCYTZDIZbDYb9Hr9oYaHWVxzj4pCoWCJVTQa5cAvFoshHA7DbrdDr9dDq9Xyfe9wOBCJRFAul9HpdFAsFlGpVM7FYe2iMZUBoNFoRCwWw+rqKpaXl7GwsAC/3w+j0ciZP8pkURaBTnKFQoFPJnK5HDqdDgqFApVKhTt9aXPJZrNIJBLIZrOo1Woz3wBCOhedTgeHwwGfzwe/3w+32w2v18s6yF6vh7t372J9fR0HBwfIZDJCv/YAaAMJh8NYXFxEOByG1+tFrVaDyWSCWq3mg8Z5QyoIp45AaaBrs9kQiUQQDAbhdrvhcrl4nZnNZqhUqvt+J2VYer0eMpkM9vf3Ua1WAQC9Xg+pVAr7+/sYj8dot9szm405Wva9fPkyrl27xkGI2+3ma0SBtcFggN1uR7fbfWBFYjweo9vtot1uo9Pp8PcbjUaYzWbWtw0GA7TbbTSbTRQKBcTjcf7Y3d099wc/Wr8Gg4EPcMFgEBaLBVqtFkqlEt1uFwaDgQ9xF0kOpFAoYDQa4Xa7EYvF8NJLL+Hq1atwOByHrpH00KfX6xEMBnlN7ezsQC6XX1jt5CwzVQEgiXQdDgfm5+fx3HPPIRaLIRKJwGazsahZKnymsm6pVEIqlUIul+Msn1Kp5FNdqVRCJpNBsVhEo9FAvV5HqVRCPp9HvV4/F4tXpVLBarXC6XQiEokgGo0iHA7D7/fD7/fD6XTCarWi0+lAq9VCoVBAq9VCp9NBp9NxAEwb9GAwYAE5ZWFm/Ro9LlQCttlssFqtMJvNGAwG3Dh03q4LlYVIimEymaDVarkzkDZIu93O96jP54PX64XD4eBGLantEkH3cbfbRSqVgsPhQKlUAgB0u12YTCb+dwqFAorF4kwdzuggptVqueHg0qVLuHLlCq5evYpgMAiPx3NIV0U/A5xuLUllHFTKnUwm7JSgUChYEkPPynK5DIfDAbPZDLVajX6/j3a7zYff8xYISrNbfr8fkUgEsVgMfr+frz0966TyoYsAHeZI1jI3N4eVlRWsrKxgcXERGo3mvsCPUKvVrJ+Mx+PsrjHLUipaK2QVp1Qq+X6kZx2tjeM8EYnxeMwxibT5alrX1dQEgHT6tdlsiEajmJ+fx/z8PDweD3Q6HUajEZeLqtUqSqUS/zd9UEmDPIrkcjk0Gg3kcjmazSaXO6lzs9lsnpvMl0wm48xpLBbD/Pw8otEofD4fBy3UMKPT6TA3Nwe5XA6Px4NCoYBSqcQi8V6vh3q9jnq9zqWjWq12aKO5aJDWpdlscmmOdDDnsYSuUqmg1+vhcDgQi8UQjUZhs9n4IUkPQaPRCKfTCafTCYvFAovFckhucFxn/WQy4Uyiy+XCZDKB2+0GAAwGAzidTvh8Puzt7eGtt95Cu92eqQCQ9GYul4sbPWKxGC5duoRAIACbzcbPJSknme6ehDQ7S4cz0mPS1+hz9D2hUAgqlYozs3K5HIlEArlcDo1G48leiDNGq9Vy8+Dy8jKXNj0eD7RaLbrdLorFIr//er3OyYVZDGJOi1wuh16vP6Rpvnr1Kq5cuQKPx/PA4A+4dxh2Op3weDzw+XzIZrOoVquo1Wozee10Oh0ikQjC4TA/y4xGI99LwOkCwHa7fagxkKqO09pUNFUBIJ3UKACMRqMwmUxQKBTodrsolUpIJBJIJBLY399HMplEOp3mh5c02iYNIP0B6RQsFZifFxE0aY2sVisWFxfxTd/0TazhcDqd95Xv1Go1wuEwrFYrYrEYGo0GWq0W37jNZhO5XA7pdBobGxvcQNPtdrnkftEYDodot9u8zihD2mq1UKvVWHZwXlCr1VzefeGFF/Diiy9y5kSv17MEg4IL6Yc0CKG1ctyhQalUwmazwWAw8PeNRiMEg0HEYjEEAgF0u13s7e2hUqk80/f/uMhkMmi1WtjtdszPz+PFF1/ECy+8gEgkArvdDqvVemJg/KhQ2Ziu7WQyObZ8Sf8WZRlJ70ZZXjogn8cAMBQK4fnnn8eVK1ewvLyM+fl5zoCSNnJvbw/pdBq1Wg3dbneqMzZPAoVCAZPJxGXf559/Hi+//DIikQhLEh50CFGpVDCZTBgOh3C73fD7/chkMhiPx2i1WjPZbW4wGDA/P4/3vOc9iMVi8Pl8cDqdh6odpwkAq9Uq1tbWsLa2ho2NDYxGI9bnTiNTEwBSwwb5/VksFhgMBqjVahaL7+/v49atW9jb20MqlUI2m+USUafTOfGmPfoHPE+QPs1oNCIUCmFxcRErKysIBoNwuVzQaDRoNpsol8uYTCa8+QyHQ76RtVotbDYbl+B7vR58Ph/cbjeX8GQyGUqlEsrl8oUMAHU6HbxeL0KhEKxWK/uskdC+1+vN5Mn3OORyOSwWCy5duoSrV6/i2rVrWF5ehtvthkajgUql4hJuu91GtVpFr9fjwwUFgkqlkoNkKjGSVx0Fkjqdjs1mKWjU6XSwWCzo9/twOBzH6ginDemkIqvVyiW15eVlXLp0iSsZ0g3lSf7b0v897msAODjX6XQspQHeyVqkUim+t2f9/qYgmLSXc3NziEQi8Hg8h2QJvV4PhUIBiUQC+XyeGwjP4z4B3DsAUKJlZWUFV65cweXLlxGJROB0Og9lvE5Cqm01m81wOBxwOp2oVqsz66WrVqvh9XqxvLyMpaUluFyuY589dMg6+jmiVqvxdZbJZGzVRJKNabu3puqvRZkDqZnzcDhEp9NBqVTC+vo6XnnlFezt7aHdbrMA+mHlofN6QwPvXDOLxYJgMMiLd25uDhaLBUqlEq1WC/v7+9jZ2cF4PIbD4eCW/qN2MVS6UigUsNvtHEDS17a2ti7ktBS5XA6n04lr167hfe97H1wuF5v3kg613++fm3Umk8ngdDrx3ve+F+973/sQDAbh9Xqh1Wq5475er6NYLCKdTuPg4AClUgkajYYPcGSW3Wq1kEwm2dduOBzCZrMhFAohGAyy2NxkMrGBMXl6WiwWbuKadqSbotvt5oaP+fl5OBwO6PX6U22uTxsq61ksFkSjUahUKpRKJaytrSGdTnMzyayvZZlMBo1GA4vFAo/HA6fTyVlPotfroVgscvB73p9rpIm32+1YWlrCyy+/jKtXr8Ln88FsNj/UtukodI3NZjOsVushTeusQb0HpJnX6XScXZcmkI5mAelz9P91Oh0CgQB7DhcKBeRyOc4EigDwBKghgQT2Wq0WwDs3aaPRQD6fx/7+Pu7evYuDg4NDZY+LjEKhgNVqxfz8PJaWlhAKheB0OgG8I6jPZrPY2trCrVu3MBgM2CSWMjSUrdFoNHwCtFqtsNlsMBqNPJWAjGoPDg7O+B2fDQaDAeFwGLFYDGq1mjdJ8lM7D5umFMoK22w2AEClUuGAl6wfstksHy7y+fwhSxiLxQKz2YxGo4F4PI5cLod2u43hcAiPx8PWL/1+nzvXR6MRl59mrRtTqVTCYrHA6XRifn4ey8vLWFxchNfr5U7xJw2tt6Nrj67bSdePNGAkCaGmlGQyiWq1OvNZMFpPZrMZNpsNLpeL9avSAKXf76NarSKXy6FSqZz7AJCaBElicfnyZaysrPDB66izxtF1Jc1s0f9XqVTcSEi/Y9agw5vBYOAGNul9c/ReOM6flJ5ZSqWSrXNqtRqi0SgSiQTvo9NWHp+KAJAWEjWAUAYLeGfsTDqd5ixCv98/N6W2JwFl6xYXFxGLxeBwOAC8o0XIZDIspH/rrbe4a4vmIQP3rj219i8uLiIajSIQCHA2JxAIoFarYXd3dybKcU8aaXmPNFNH1+B56pCeTCaoVqvY2NgAcK9rl3S0dBgolUoolUooFotoNpuHNgNqFul2uzytggKLwWDAnXY6nQ5+vx9qtZqbFsins16vz0yTllarRTgcZkH9wsICB39PK4Mp7fClblapMTd1Mp4UBKpUKuh0OtjtdoRCIWQyGda6zuozliQEbrcboVAIPp8PdrsdJpPpPl/KwWDAbhD1ev3ca/8oO0VWODQOT7pGqFRJ1kLSgEWn0913mKHn3qxeN2ngJjWzP/o9wL1AkKo+o9GI9wZKYEndE6xWK8uyBoMBarUams3mM3+PD+LMA0DpBSPhdCQS4QCw0Wggm80imUyiUqlMrZjyrKDmD+qaputWrVaxv7+PtbU13L17FxsbG6jX61xSo82DTj86nQ6XLl1iR3eDwcAG0h6PB5VKhbVvF5WjWZVZfvA9iMlkgnq9js3NTZTLZW4QkE7RkRqrU1MMPUylsgLKGtLX6cNgMLA3m7QreDwes3kx/dvTHoyQDjcSieDFF1/Ec889h7m5OTidzvt8E58U0lnAlEkdjUbclEMjMAEcGwTS34o2Kr/fD5/Ph3K5jGKx+MRf77OC/hZer5elC9LypPTeHQwG7A5B2enzeD8TWq0WXq8X8/Pz8Pl8MJlMh7p9pfc2uUBQUEz7jEajuS8AlE6UmrXrR88puk+Pu08Ieq/dbpebAaWWT0qlkg+yAHiSGbmUxOPxZ/nWTsWZB4CkS7BarXC5XKzXICPTWq2GZDKJRCKBcrk8dSnUs4asbqjkRjqtRqOBdDqN/f19ZDIZlMtlNBoNNBoN7s6kYFqhUPAAcBpn1u/3eYQSTRWRusFfZGiAPE2pOI8l4FarhVQqhUqlwgEg6RypA/o03nEUaJA5udPphMvl4g3a5XJBp9MBADqdDk/oOTg4wMbGBjKZzANn4p4ldH+o1Wo2Xp+bmzuUXTktj5pJIU0R2UzUajW0220eo0nZe5vNxnpM0iVJXz8FjDqdjsvCswgdZNVqNZfhl5aW4PV6WfNM0NqlyVHSJqXzDGWpDAYDJwJkMhln9Wk9kcVavV7ngNpkMnEQSJBRO40wnEaN28OgphjKEB/d3+jw2W63Ua/XUa1WUS6XUSqV0Gq1+P5xu92IRqNwu918XfV6PVwuF8LhMA4ODmAwGKBUKqfKSu1M73YSkVqtVni9XrjdbjgcDl5sNKg7mUwiHo9zCVhwGDqFkAaDTrbZbBaZTAa1Wo3LRJT1kwYs0s6/o6cherA+rUzGrEGnQNJ0kD7yPGUPqAxbLBahUqnY0JQeXHQNTpOZozm4DocDzz//PK5evcpeeHa7HW63G1ar9VBAs7a2htu3b2N9fZ0bvqYR0tJZLBb4fD4+vJLNyKNw9Jo+bC0NBgMUCgUeZZlIJFAqldhJwW63w+/3IxwOIxgMcpn9uA5G6Rz1WUUul0Or1cJkMiEQCODy5ct47rnn4PP5DgW+NEGFNnLy9Tzv3n8A+O8snb1NXqb1eh17e3u4ffs2dnd3ORFA7hLhcJgzqwSNbUwmk9jb25vJBI1arYbVaoXD4Ti2iYUOvNVqFdvb29jb20MikUAqlUK1WuUEzMrKClQqFY+ppYMYzeR2u90wGo1swD4tlcypCQA9Hg/cbjd7VFGqtVqtIp1OI51Oo16vv/OilcpzMyj+SUAneSq/TSYTtNttFItFNjilVP5xJzRpJxOVj6TB3qwI8Z8FZP1CA9Cps/W8rUUyxAUOl7pP03xFBwfpJJFgMIhr167hwx/+MG8mUkNpMh0/ODjA2toaXnvtNWxubrJZ+zRCoybJENfj8fAB9kFSCWnARffkUZH4w9ZSr9fDwcEB1tfXcfv2bWxtbSGTybBJr9frxeLiIvr9PuurqaHn6Ovo9/tsjj8tG9OjQhux2WyG1+vFpUuXcOnSJVgslkPBOGWtisUi8vk8T0G5CAEgBb9kaN9sNqHRaPjgtbGxga997Wu4c+cO6vU6ut0ufD4fxuMxzGYzOp3OoWs0Ho/RaDSQyWSQSqVmzgOQMvgmkwlWq5WnY0nvPTrsVyoVbG1t4ebNm9jZ2eGEFB08ms0m5ufnEQgEuKOagkua0U0TU+ien4b94szz/dITK0XPwDsPOJph2Ww20ev1uFOVNmCpt9h524AfB9pMaRxNs9lkIf2DHm7SeZlOpxMOh4PLcsPh8NDf4Lw/JB8GnXqz2SyKxSJ3sT7IOX8WoYzUo0KHCPKm9Hg8CIfDWF5extWrVxEOh+FwOPigUqvVUCgUkEwmsb+/j729PayvryORSKBSqbB+cBqRehZSo8HDuiHpcNZsNtFut9lLkbR3pHk8TQYwlUrxZkTleqVSiWaziclkAofDgXa7fWzjHG1CdMhOJpNIpVJoNBozeY9LO1JNJhN3opNshYLdbrfL82t3dnZQKpV4Hznv+0e320Umk8HGxgY3JZjNZlQqFeTzeezs7GBjYwPZbBbAO3szzaamTL30YCOVg8xqFlXa3He0wiXVIycSCWxvb/OzqVAosJxKrVYjk8lgd3cXHo8HoVAIHo+H5SFGoxEOhwOBQIAHV1BF5ayZigCQZqwajUY2maWaO+nRaGi10WhEt9vl09tgMDi0SZz3m/i0kEExXb+TbkzqYqL0Nc1zNRgMkMlk6HQ6h8ols3aDP2kGgwFKpRIODg6QzWb5YHI0a3oRoeYPaujy+XxYWVnBe97zHqysrMDn88HhcEChUPAhbn9/H+vr61hbW8Pm5ibi8TgqlQrq9frUd/zT6Z/sRuj59aB1MBqN0Gw2kclkkM/nUalUUCwWsbu7y3Y6p5ETHA0kqYmB5ByUwaGJSEez+KPRiA/Z+XweW1tb2NnZmdmRhrQRk7k4jSIkKYt0/nQ+n8f6+jru3r2LQqGAbrd7IQLAZrOJ/f19VKtVJBIJrK+vQ6PRcABYLpe50dJms8HpdCIajWJhYQGLi4vw+XycGLgIjMdjvld3d3extbWFra0tNr6ne4vmbG9sbLD/LjXMkL6WdKm03qrV6lTcZ2cWAJK2jISSgUCAR68A9050TqcTCwsL3IhgtVrR6/XYu4keflTCoI9Wq3WudFmPCnn7kRXHcdYlND/YbrcfsgegDOB4PEa9XhcaTAlS2wCaldtoNNgc+aJBZV5aa+S/FgqFMDc3h9XVVVy/fh1zc3M8krFWq6FarbKG7a233sLdu3exs7PDViSzgHTs28NKv51Oh2eZ0/2UTqc58xePx7Gzs4NCofDYwYh0xqvH44HX671v+gVBh2yaq57P51EqlR7rOkwDJOZ3Op3swybNytNEmlKphFQqxdpJCnguwj7R7/dRLBZ5HaZSKSiVStTrdVQqFT5wUaIlEAggEomwnY7NZoNKpeKsWDKZRCaTQbPZPFcBNOkkaa7vxsYG1tbWsLe3h2KxeKgpjTLLNKwin8+j0Wgccj0gk+lYLMb+qclkcirK5WcSAEpd861WKwcf5MBNZWFq0fb5fGg2m1Cr1dwd3Gq1eOGSZUSxWES5XEYmk2FTU+DiZQXlcjkMBgPcbjd8Ph8AHCs8VSgUcLvdPO5rdXUVwWAQBoMBKpWKyyWbm5s8KeC8m6U+DJVKBa/XC41Gw92VJOzN5/O85i4KJMuw2+0sH/B6vYhGo4jFYmzFoVAoOPDL5/NIpVI4ODjgk3UikZi5QfLU3UxTc2h4/FEmkwnK5TK2trY48Njb20M+n2cdKXVdvpuNVKVSwe/3Y3FxEUtLS7h8+TJWV1d5hrPUAoW8HAuFAm/+swzpHClY0ev1/DWaXkPB3/7+Po9/k85AP+9QsNLr9XgNyGQy1p6S9IAyWGQNRNOjKKNaLBZx48YN3L59Gzdv3kQul+Ofn3VI+kKyjL29Pdy4cQNvvvkmCoXCYx1OFQoFzGYz5ufnkc/ncefOnamRC51JAEgPTrPZDLfbjWAwiHA4DLfbzX461KpusVgwPz/Pi0uaySIxL41bITGqXq/nCJx8yma9y+00SMfUkJ7P4/HwSV/qW0SZm0AggKtXr+L9738/otEoaxfIJDWbzWJnZwfb29soFAoXJgCUmj9TSYlsJij4UyqVbJ+QSCTOvUciXRO6LkqlkjfdQCDAG4bf70c0GkUkEoHBYABwz88zk8ng4OAAe3t7iMfjODg4YL2ftDtxFqCDFgW/JwWAVEra29vDzZs3OQgsFov8XKLS7UnBHx2GpUbPUl+78XgMnU6HcDiMa9eu4bnnnsP8/Dzm5uZgNpvvs3eR2n5QuX2WUSgU3PxC+4i07C19v4VCAdlsFqVS6dzZNz0I0uxJ/f6OOkIA9yaGkByImhfo+Vav17G1tYXXXnuNs6hSh4BZRHovkVSgXq8jk8lgc3MTOzs7nNk7DvI7JTna0QkqZrMZk8kEoVCIM9T0vWf5zDuTAFCtVnPgR6aUVqv10NxPaWfrSdCJhIS/drudu4n9fj/29/dxcHCATCaDdrt9qLPxPEH6FlpMUm9Au92OUqnEY5BMJhPMZjOMRiOsViuWlpZ4hJzdbodGo0Gr1UIikUA8HsedO3d4lFej0TiX1+8otO5oNFAgEMDy8jKsVitvwPSwII0kDZI/z5AnpNFoZI2Qz+fjLJ/D4YDdbofVaoXRaESv10OtVkM+n+cDWjabRS6X4w9qfJiGcsijQiVX8jU8LtAihsMha+4ajQZr906zYZJPJ5V3adwe6dyos1Uul2N5eRnLy8uYn5+Hx+PhbL4Usvkpl8vI5XIzad9xFOkoUdJiEtLxZrRJ0/PyonJcsEYHPHLmcLvdcDqdMBgMvA+TdrRer6NcLrPDxKwG0kcdNKQHKjqYHRfUSaH7qVKpoFAo8F4gTchQM4jT6UQgEEAwGOTy+1naXJ1JAKjRaOD1erG8vIxLly6xWeejzhKkxhC1Ws3ZxG63i1gshitXrmB9fR1f//rXMRwO+bR9nrQKAO57sFHmk7JUVquVS5UGg4FPdaQPmp+fx+Li4qESFk2BuHnzJu7cuYP9/X1UKhUWS59npIayFosFi4uLWF1dxXPPPQe73X6ojNbr9VAqlZDJZLgj+LxCa8pms8Hr9WJubg5LS0ucNSYfLcq8kN6KbF2o5Endc1Lt7qxeN7lcDqPRCLfbzY1Tp7F/of8+7XOI9G10GCFdFs1NH41GnJ2PRqNYXl4+JOU47nWQx2omkzkXJWBqyDEajYeyVYLTQ88+Sh6QLZtOp+NMIVkWUff6NHfpn4aTAkD6mvTrD7JD63a7rO+j5lSyViMdoEwmg81mQzAYxNzcHE+j6XQ6ZxaTnEkAKNUYeL1ebten0zNdFOnoqYdBukGr1Qqn04lgMAi9Xo9Go4FWq8W/m1Lgs7xopdDDvFgsolQqcSecwWCAy+VCMBhEp9PBYDCAUqlEKBRCMBhEIBCAz+djA1uTyQSZTMa6v62tLdy+fRvxeBz5fP7U9hSzjtTGxOFwIBKJYHV1FQsLCzCbzQDuTRJoNBp801OJfZahhxVlOancLZ124fV6WWe2urqKaDQKm83G2S8ycCdriEqlgnQ6jXg8zhpdsiahzMEslX2lSH3EyG/uuPuD1pTBYIDNZoPL5UK32z2kU3sQWq2WmzrC4TAWFhYQjUbhcDhgsVgwHA75HvV4PNzIddKmdTTLcR7uabrG1JA0LRqrs0I64ozu6eOQjnKT+ih6PB72rqNDHTXS1Ot1PsTNuoUOZTTpIEqxBiUB6J61WCwYDAbodrv3vVepprZUKnFT4HA4PDRjmLKA1PRKjXAymexiBYAADmlZjpZ5m80me1uR/9xxg7qlG5bD4UAwGGTndxoyf/XqVdY0KJVKtk84L9ms4XCIQqGAt99+m1P1tClHo1FotVpu51cqlTxthRY1lYPlcjmazSZKpRLi8TjrswqFAg+Hn9Wb/FGgTJfFYoHL5YLP52NfJ51Ox91h0tJmLpdDrVab2QBQetKliTImk4lHM5Kswm6385rxer0IBAKw2+3cCEMnaAqgzWYzB42NRoPvy2q1ilqtxpKCWdYOHZ0PfdL3GAwGhMNhDAYDeDweFIvFU5d+aKKAw+HgbKPD4YDZbOZufY1Gg263C7PZzM+Bk14TdS97PB5Eo1Hkcrn7xsTNKsK0/h0o2KAKGWWgjjIcDnlEqM1mw8LCAq5cuYKFhQWuClEHOU2XSqVSKBaLbDU0q/uoNHlC41I7nQ5P8iFbtIWFBR6FeZJPJmkqpeMFu90u+1DStSdpEVXlzjpTfWZdwBS40YWRltbq9TrW19fx5ptvctv6cZG3VKgfjUbxwgsvQKFQwOVywWw2w+PxQKlUwul0srat2WwCwMynronhcIhcLodbt25Bo9HA5/MhHA7DYrHw6J5YLIZms8liVOq0ps2e3M/J8+jg4AD7+/s8feW44Pu8IvV2owCQLIo0Gg36/T6azSZn/kjTNqsZQGnGj4x0dTodvF4vFhYWsLS0hOXlZSwuLvIhihqItFotryNqDgHuNXnZbDZ4PB5EIhFMJhNYrVbYbDbk83me5kPZwlkNBKWj1B702s1mM6LRKCwWC0+ROa3ujuQINLOUpqjQ32IymcBgMHAH50mbvRSj0Qi/34/BYIBkMgm9Xn+mmYh3g3QvedD7ftR5y7MI3cvUBOh0Otkb8Th9arvdRi6Xg0wmg9frxdWrV/HSSy8dkgXJ5XIOFCmbn8/nTzVkYNrpdDrI5/MwmUzcEOVwODhpFAgEsLKygmazyc1Dx923ZKpOwR8FgPRcpXuLmgppH6YM7VlxJgGgdJ6qtBOGxiFls1lsb2/j9u3bnIE6qfuUdAu1Wo3nGsZiMczPz8Nms8HhcECpVHLWYTKZYHd3l/9tej2zCnn1HRwcwOFwYHt7mzM2JpOJMwXU+k8ZVwr8pJ1PJKQmA8uLepqWBjg6nY4foFQGIeNe0rU1Go2Zm5JCf1ulUslGxhSgORwO+Hw+RCIRRCIRRKNRBINBaDQa1pqSBpLMUKUlJhKSU1fmcDiE0WhEvV7nTsxMJsONRc1m81BjxKxMFRiPx2i1WuxCYLFYDgnmCcoqS7v/HsU2gzwntVotG44f/TceZfawVM/Z7/fh9/vh8XiQy+V4wtIsPBPpfeh0Oq5qUAb06MZKDUmFQmGms/XHQfefXq9nqYDL5WJ7Jp1Od0hiJaXT6fBYPK/Xi+vXr2NpaYkbiBQKBVvHZLNZ3LlzB2tra8jlcjOd/SMoZiBz7Lt37wIA+2fa7XZEIhG2Q9PpdFwyP9rpS2VjlUrFB+OjmXhyOLFarZwBpGt8FvfcmQSAFOyREJyCQOpMSyaT7IxfrVYfmq2Ty+XodDqo1Wo4ODjA9evXOcihUtT8/Dx3wSqVSpRKpXMx2YI6AMfjMfb393Hnzh0YDAYsLi5ibm4OarX6kDaLOg+pq5A2EvIqksvlKJVKCAQCcLlcPO9x1m/0d4M0U93v95FMJnHnzh1sbm4in89P/cSKk6BsZyAQwKVLlxCLxRAOh3lcm9FohE6nY1sgCtRarRYbtVP3vdlsPvSgowOEy+WCXq9HKBRCv99nrUyxWGRNYD6f58xCOp1mfdG0b9Kj0QiVSgV7e3twOBzsQUfWN1Lo5K9Wqx/ZMoMOuXQfP4mDGQWV1NQTi8XYQ7VUKs3E/a5QKFiqMT8/j1AoBJfLdZ8hN8l+0uk0tre3kUqlDpn5zjp0L7pcLnzgAx/ABz7wAQQCAbZSoyDjOEcN0vW1Wi2uGEmNtMnepNlsIpFI4LXXXsONGzce2xNvWmm1Wtja2oJWq0Wr1cLzzz/PjgeBQAC5XI7lLpSokr5/up+sVit/HD2I0H1MAaDZbD50nc/iep5JADgajdBoNJDP55HJZNh+hJzKSf+XzWbZqPNhw+cpis/n81wG1ul0CAQC/JDTaDRQKpUoFovY29vjmv20bzQPgwLkQqGA3d1d7uYl/QHNwiQRK4m/x+Mxn1jIaoK8ASORCBKJBJe3zkvJ/N0yGAxQrVaRTqeRzWZn1hqHNH9arRZerxcrKyu4du0a5ubmEI1Godfr2RCVGjmq1SoqlQoajQYHj+QXZrPZOMMgLcVRVyEZEctkMi6j1+t1lEolFAoFxONxmEwmGI1GnvJD/nT9fn8qM1Kj0QjVahXxeJynbdAmcRS6D89a80OcZJ9Vq9XY123aUSqVsFgsCAaDiEaj8Pv9fHBRqVRcaer1eqhWq8hmszg4OEA+nz8XAaB0IpHFYsHCwgKuX7+OD33oQ2wzQrp36R4qlX0A71R/BoMBZ1SlQQvJM/r9Pj/3aCDALB56T6Lb7SKZTAJ4J5vu9XrZx9Rms7G0KhgMQi6Xo1wuo9Vq8TVVKpVssSaVaRyFOvopW017rnRPfpacSQBI6WQAvMHs7u4CeCclvbu7i1KpdOqgg250Mm+Mx+N47bXX2BOPhKxWqxVerxfz8/NYWVnBZDLB/v4+yuXyVG4wjwLpqYrFIg4ODrhzq1wu881MwmC1Ws02EPTwoOwfnaoXFxfRaDSgUCjQ6/XQarUAYCY2hieN9KakdUIP1FnVFEmbNVwuF+bm5hCLxeB0OrnMQTYhOzs72NnZYeE3ec5RAOh0OmE2mzm4od9N86Xp69L5rDRBhLJQNDt4YWEB+/v72NvbQzKZRKlUQrlcnkoD8tFohFqthmQyyZmoaDSK0Wj0xDJ1T4uj3d5kE6VWq6f6dUtRKBSwWq2IRCJYWFhgKx4KYIbDIVqtFqrVKvtQ5nI51jXPMqSzNZlM8Pv9uHTpElZXV7G8vAy73Q6tVssBXq/X40QH/c1J3kLWa3R4O84KhT6vVqvZFJ/cOWbx2Xccg8EAtVoNcrkcPp+PK4Qkh3K5XLh8+TI6nQ7u3LmD9fV1rqbR847+HkdNyKVIqyY2m42DRbqeFyIA7Pf77AdGZQeXy8U6FnroP0pmhcoqzWYTyWSSgz+/349wOMynQpfLhXA4jKWlJbRaLR6AfR4Wcr/fR6lUYr3HeDxGsVjkm5oelqTHarfbMBqNAMD2CTKZDCaTCXNzc5wBSqVSyOfz58o24kFIm5RI23b0Pc+yfQlwr2yk0+k4i0ejGIF7kzu2t7dx48YN3Lp1izPy/X6f9S5msxkWiwUmk+mQnECj0UCv18PtdmNubg4ej4c7zk0mE2f7yByVxseVy2Wsr6/DZDKx3o1kItMWbFMlI51Ow2w2IxaLsS3GwxoSnjTS7M5poTVOzT+zZp9CTW2UAaTDC+muqGuTKk25XI6bF2a96kONHna7HQsLC3jve9+La9euIRgMcgWIMn+0z3U6HQ7kqExOAf/RZkyCKgVSdwSLxcIBi1QLN0335qNC9zJZKpXLZTQaDb5Odrsdy8vLrAMnE/vhcMjXxmw2w2w28yGXgkOCNPhSaxmTyQSdTodut3smB68zKwGTzx+l6MvlMqefSZPwqBss1dHr9ToAwOl0Yn9/H+FwmJ36KeNB9gcbGxtP4y2eCaTVKBQK/EAfj8eclqYMTafTQTqdRjKZhFqtRj6f505Xu90OhUIBh8OBXq+HZDIJn8+HYrF4yLrjvCLNbNlsNs4oUPAhbV6ahUaFk5CW9svlMg4ODqDX6zkAqFarSCaT2Nvbw9raGnZ2djgTNxgMOHNUrVZZa0Q/S00gWq0W+XwetVoNbrebgz+aIkINSkajkX+H0Wjka6rVamEwGKDRaHjjbrfbx4qwz5KjXajP+kFOvpTkO0a639N0xZKxb6vVQqfTObOsjjQjSdm7h1UbpFlsMoA+Wr4kM/JEIoFischNLrN639LfVKfTceb+0qVLWFxcRDgchl6vZ319q9VCo9HgJqV2u81yjFAoxA00D1on9HfRarVwu91YWVlBv99HOp1GLpfjfZwat07jGEH6fGB63Djo2U69BKVSCblcjp9L5IxAso9qtQqtVsv3DCVXgsEguyUcBz03SYNP+/RZdQOfmQ+gdDB1rVbjlmm5XM4PpMe5ScmUkU5+8Xgcbrcbo9Ho0Lgg6nyjG+A8MBwO0Ww2IZPJYDQa+UTi9XoRCoVgs9mg0Wi45H7z5k2MRiO2OllaWsLq6ipcLhcbSdPXCoUCALDB5XmFZorS+yZtFz0gSEdJbf6zapFDwV+tVsPW1hYA4O7du3wv0CQP2jxqtRoHGXTvkjaPSm1Hxc7UbHX0QUpdxuSxGIlE4HA4oNVqodfr4ff7uURMXY3UFEYb2TR0IMrlcuh0Op4FTP5hz7L8S8+7SqXCQn466B433UD6czRBqNlsolwuc5boWa9nqT6SMlK9Xu9dB2r9fh/5fB5bW1ssK5p142Iq11JzwurqKlZWVhAIBGA0GnnvLJVKSKVSSKVSrNvrdrucce/1erDb7XA4HA/1sCSdYSQSwWg0gsPhwObmJvb29lAsFtlho9VqPXTIApmnk06WnqPT8vcgb8BCoYBEIgGj0chNbjTP98qVK1CpVIhGozwG1GAwIBAIcCb6OCmFVBdNEhmpBv/CBYBkJUEPIakX4Ek3qfSUfTS9Sh90qiiXy0gkEtwZRgPbKSiiRhG5XH4uSpvkRSSTydg6RyaTcZmEJqNQ5+KtW7fQarV4Yki/3+eNjLRZFATSxlsqlc76bT5VpOO9/H4/LBYLFAoFZ0roZE0Zk1kNAOk+IdP1crl8SLtCuqGjUzvovUon6pA+7zjtkNTsnSaK2O12nodZq9X4XqY5rmRJ4/F4YLfbYbfbOTNIJZdpyBzIZDLo9XrY7Xb2HqXyz9OAAm/6oCxZrVZDKpVCrVaDzWbDZDLhLOxJdk5Hs3/Ulf2sN2Op3IImechkMj5gPAgKHE8KckkSE4/HkUgkUKlUHsl+Z9qgjKdWq4XNZkM4HMbly5exuLgIt9sNlUrFjZDxeBybm5vY3t5GOp1GJpPBaDSCy+VCKBSC3++/bxY1rQn6t6T6P61Wy89D2k+NRiNSqRSXQyuVClQqFQfuUskQ/a9ULweA96hpeYZOJhP2Btzb24PNZkMgEOA9kQ54VqsVc3NzKBQKKJfLUKvVh0za1Wo1j84jORodyCgAlE5aunAZQEK66B4EaRGoxEEXjn4HleYoS0EbE6XBO50Oi7PJo0zasTPLmzlxdNzM3NwcFhcXOZPV6XSQSqXYCqFSqaDZbHI50+l0skYwEAiwjouaREhfc16hNUYnVKPRyJ3UJFPIZrN805OWaBbXjPSQRdN2pGULkmY8zKD5YZupVGxPWUHq7KUAk/SGZKNC86tJ4E4/R53qGxsbJ04HepbQ63E6nbwpnlT6eTfQ34o0ueSnSJY8hUKB7XNCoRAWFxe5LEXNXcf9TgoiqRGPKjHPMkCS/v0oYKAD+UlI71PaC44ro5HUqF6vcxPEWR8a3g3U9OLxeLCwsICFhQVEIhHYbDaMRiOWNG1sbCAejyOZTCKTybBZs81mw9zcHJaXlxEOh3n8J9HtdnmOLV1X6uKn/1WpVByc03SbSqXCWUDplB86JHc6HbYgI8ssem5OWzZ2Mpmg1WohkUhw5tPlcnHHPDWTSudz1+t1yOVy1keaTCbI5XI+WDUaDRiNRvZWBO5lQmlsKwWWz5ozDwBPi1SsbDKZYLVaD83SpNp9s9lkXQJlOWizINNa6mYyGo0c3NBinOUHBM1DdjgciMViuH79Oq5cucLWCMViETs7O7h9+zYymQzPP6zX69wRvbW1xWaqbrebgyGTycSn8/OINBMh7XajDEq73WYbiUwmwx2xs64DpPV+dN1LT+1P6gFNp2Eyzq5UKsjlctjb28Pc3BxWV1extLSESCSCUCjEOky9Xs/3PG12+XyeZSJndf2pBEylapqH/KQhqQzZZCUSCSQSCZ5Ck8lkkE6n0e/3ceXKFZa7yGQyntZy3O88GgCS5+qz3pCl01Qow/ugwEC6F1A3K73Po1Whfr/PAcisS1doqtXy8jJ3/IZCIWi1WtRqNXa/+Ou//mukUilu2KJO/OXlZXzgAx/A888/z35/0ia3druNTCbDzQ+kzwVwSE5AljN+vx/dbhe9Xg+dTgeNRoPNtpPJJLLZLJeIS6US68hHoxHfu2d9iDvKZDJBo9FAPB7n8ZUUuPn9fjidTm6yNBgMcDgcbKFDWT3y66xWq9ja2kIqlYLf72fdH+011DhitVpRKpXOJLlypgGg9KJRTVxaDyexKGX76Ia3WCzst0XpY5pjW6lUWM9C2RtpCYv+XWk2kVrhZz24oTR9LBbD0tISFhYWEAwGMZm8M14vmUxiY2MDW1tbyOVyfCKmh20ul8Pu7i6cTieCwSCcTieftOlvM+vX6CSkJzKj0Qir1crj9ORyOQeAiUTi0PSP88CzOvhIZR9UYq7X67xJkI0TeQ1SyYlOz1RerVQqqFQq0Ol0nG04i9IeGct7PB74fD5YLJYTA0A6iFIgJx1A/7BNUDqwvlAocFMDjSKkLINSqYTf7+fSKR1ojkJ/AzLdz2azqNVq6HQ6T+bCPAbS4A/AQwNA0ouSNovsN6SbqLQyNMsHNYLM+v1+PyKRCLxeL6xWKwaDAer1OlKpFOLxOLa3t9niy2w2IxQKIRaL4dq1a7h27RpWVlbYo67T6XBWOZPJYHt7G7VaDRaLhZu1XC4XbDYbHyRIJmO1WgHcO0iSr26xWITf70c+n+fScC6XY29faaJmmoI/gu7TbrfLNlcymYxfMzWmUVZUo9HwPUpl9cFggIODA9y9excHBwcYjUZYXFzkf4O0lRTTnJU/6JkFgNKZhXa7ncsV1BZNJzwK9ij9qlKp2LOKNEs0DaNcLrPJ8+7uLlqtFrRa7SGPHTpp0uek3ZzTuBgfBZPJxGagS0tLcLvdUCgUyGazSCaTePvtt7Gzs8PlIumJeDweo9lsIpVKIZPJ3DclZZbsIR4HOtHRQ8/tdsPtdvM4pE6nw6W2Wq0289mEaUDaeVcoFDCZTJDP53FwcICdnR3Mzc1xqYs0R8FgEM8//zxUKhXW19dx9+5d7O3tsZThWW7y9PxyuVzwer0nBoCkKyoUCjz9pFgsolAooFQqoVqtPvB1U3BETV5kyE16VOr+peaaUCjEso/jNpbJZIJarYa9vb1D1++skGaZpRrTk57HVHqnIMVut8NisRzqRD+PSLueKQihhhky8G6325DJZLDb7QiHw+yRSJN+AoEA69MGgwFnu3Z2dhCPx3FwcIB6vc77LjU20ChIAFyFI7sZ0rXRbFuqRIVCIS4BU1PFwcEB/3u5XG4qy8DEcDhEJpPB66+/jkqlgmw2i/n5edYmU0xC5V5yPKADXjqdxtbWForFItxuN+sdgXsVp+PGsj5LziQAlNpE0Aglr9cLp9PJ5oh0wiNhJdXdpRpAWnzj8ZhLwIVCgTfufD7P5TxKywJg8TOVhmdZGEzI5XIYDAZEIhFcuXIFkUgEFosF4/EYpVIJGxsbWF9fx8HBAfsXSd8z+UXl8/lD5bWLAj206MTncDhgsVj4YdntdpHNZkUA+IShsnC1WmVPPcoW5HI5jEYjWCwWtq+gjn7SZzabTZ7s8Kz/JlRitVgssFqtXHE4DtoEqSHh4OAABwcHSKVSfD8+6N8hqHOXtJk0A5bE6sFgEH6/n3VLR8tKpL+iCSZra2tIJBJot9tP5qI8Bkcbix4G3avUIERZYqpQHJ1edFZTFp40VDakRkYKeCn71u12eba32WzGysoKLl++zDYxHo+Hu1Pb7TaazSYymQzW19dx69Yt7O3t8eGfSpOFQoF/N2UNDQYD79VkY0KendTgYbfb+ZpTQyY5bxgMBs6wkdZ4GgPA0WiEUqnEZuLULU9NNGS4Tc03+/v7KBQKrHUsFotIp9OsQz16j9P1oiDywmgA1Wo1XC4X3G43p6fD4TBsNhvPH6ULQwOZKRNDmcOjFgc0fJkaQzQaDarVKjumRyIRmEymQ8EipaJpIU7jIjwN0mtFAbPRaOQOwUQigfX1dezu7nIn3NH3SnqZVquFZrN54QIctVrNm6jD4bjPU4w2cOq2nPVJAtOEVH9Lm8V4PIZCoYDb7YbP52MBOnXdAkAul4PFYjkzCYd0AhHZWB1n/0Al11KphEQigd3dXR53WSqVUKvVHhicSCc30MGZ9IfkbkCZv+XlZTidTj4gS18LZRBJDhKPx7G3t4dcLjdTo9G0Wi08Hg+i0SgHFNL3KrVCoeYWqZn4rEKHfJfLdUhzSnsc6fzItiQWi2F+fh4+nw8OhwMajQaDwQCdTgfJZJIn7mxtbWF7exuZTAbVapVLmeS72Wq12DMWAPR6PRwOB9sfkX8s6ekpo0VINZ4KhYJ15ySBOFptmhaolEsDFdRqNTqdDrLZLBwOB3vrKpVK9t8lHS3NWG42m1AqlXwQkULPM/IDvDABoEajgdfrxaVLl7C0tITl5WXEYjE+xVEZhQI6+pDav9CCkqZSKXDUaDSccqWA0WQywWKxsNiVpo3QQpzVLCCVBWjANN2MWq0WvV6Psw53795lYetxD0FpOa7Vas38w/JRUavVcDqdPE9UOseRSni5XA7pdPpCBsjPCspEU1nz4OAAPp+PMxJUepLL5XA4HDAYDGcaANIGRmMT6WB69Pu63S5r7vb39xGPxzm78rBnD2VZqMxJ5TaaoBIIBLjsS+J+acWDIN/HTCbDpbj9/X0Ui8WZ0rPqdDrWOnu9XhbWA/eC7Uqlgnw+zyV26gKexWc8QZIDOhTR2qfJHoFAAG63G+PxGCaTCS6XCy6Xi02ypdrPzc1NvPbaa1hbW0OhUGBdM+nCKWtVrVaRyWQOOUDQ9CCHw4FoNIpYLMYmyLQHS6FMucPhgFKpPKT/LZVKyOfzU/s8pcMpTRij3gI6kEqNy6XNpxSbyOVyPrDS75M2G0rL+RcmAFSr1fB4PFhaWsKVK1ewuLiIaDQKrVbLehcSS1O5o9Pp8AMWAGfs6I9AgSNZl7hcrvs8jqgZgoTUqVSKrTxmtfuXSkC0GdDJZDQaoV6vI5/P89SPQqHwwBtN6sV1nrU0x6FSqWCxWOB0OvlkJ91U6NRKGVTB04EOImTHQMFVu93mh6p0dBkdFs9CP0PVhEqlgmKxyKLu4zZAKfS8otP/ww5aFOzRZk4TUtxu96GyLxnbn3T/DgYDVCoVJJNJJBIJZDIZFAqFY8tT08jRKRiUrT86e5WMsQuFAhtkk73NrB9qpW4F9N80kpGaYyhbTObrJGGhhqtkMok7d+7gzTffxPr6Ot9f0ioYXc96vX6f5y6VoW02G2dXyZhdpVLBarVy4oUylBQEAmDrNen9O61Q3EDdztVq9Vi/06MaVuCeVvK4Q4dUzyn1F3zWnFkA6Ha7sbi4iPn5eS5Z0AmFTqmlUokf+nTqJQNJ6lijtLPBYIBerz9WCEy/lwwe6fTz9ttvc7lpViFvqFgshlgsxqesWq3GXaulUolH1pz0XqnDLBAIIBAIsJfRRUFq/0J6FilSo3HB00Nq02S1WvnDaDTySD4qUdHGc1aj4Wh+aCaTQTKZ5CydNHtM78lgMCAYDKLb7cLpdGJubg6tVutU/w5lT0gKQ44IpD0k42xauyfdt0dHo5XLZb5+s/AMlM6vNpvNh0Y1SiGNY6lUOpT5m/V7l8aQHRwccBaQ9k6yS6JGScoAU9CSz+eRSCQ4+7y1tcUNH1K/T+K4gOYok8kE8XicD8fkTUkzmUmiQPcuALTbbRSLRezv7yOZTKJer8/E2gPwyM//B32/1MuSsvoXpglEpVLB5XKx9o+CNvJsKhQKbGY5HA75ZENGiwDQbDahUCjg8Xjg8XhYX3CcEz+d1KvVKtLpNNbX1/GNb3wDOzs77IE3qygUCtjtdi6nOxwOyOVyNJtNpNNpFqY+bKwSGVlGIhFEo1GegAEcnkBwXqEAUKvV3qfjknriCZ4sx3XFkT8ndWJLs7LSeZ2UfaAM/rNen5RlT6fTLIoPBALHfq/JZGKT9WazyR5tp4ECSyoBH7XNos7Lk+b/SrMYxWKRTYIrlQpv/tMOrQ1pE4TFYuEAkN4zZbvK5TLPj55Vec9RSB+7vb0No9GIyWTCHq1Wq5XLizQphyb9pFIpbGxsYG1tDXfu3MHOzg7r32lfeJRnGwWLg8EA/X4f5XIZuVwO5XIZlUoFHo+Htfsul4v1hwB4X9rZ2cHe3t5DO+DPC0enlx01Mz8ri7Uzy7/STNVer3fIY47KIwaDAWazmQNAmt5BxpRSGxlpPV6KdNRRJpPhG2Fzc5OHg8/6pi6Xy3nUWzAY5MwdGbzSRvOwErdCoYDFYmHrAIvFwg8S6e85zzfrUUuC0WjEYuhqtTqzUz+mBZJwSB+GNJCePshclUbxLS4u8roeDodcwjo4OMCdO3d40sFZNHFRNrJer3MwSuvnaJlIq9XCbrdDp9Pxxnla2YnU+oOMZh9mGiudHtJqtVCv17G/v4/d3V0cHBwgl8vNnJaVJD7UrW+1WlkHR93R7XabA4zt7W0UCoVTB9rTDq3/zc1NAOA1RKV/mtIhlf/kcjn+u9NM5HQ6/a47b6VNW6R9o+eltCHE4XDAarVyANjpdLC2tobd3V0Ui0WeUnLRoC512lvf7dzrx+VMAkDpkG6lUolgMMhRsNFohM/ng1qtRjAYPFQCphIvOcXTBkK6m6PZv8FggGq1ikKhgPX1ddy5cwfr6+vY2dnhZohZ39DJlJMyoSSQf5zfY7VaMT8/j/n5ec600pQAKqfMqlbyYZDOT2rMS6fbdDqNVCp1plYZsw5lWKX3KOmFaO1Sxs/n88Hr9cLtdvNoJZlMxh14r7/+Om7cuIGdnR0kk8kzDWSOjqCkh/jRIJCeX2Qa+yhZF/JZo7Leae5vqnrQ4XdnZ4ezQMlkEtVqdaY2X6lmirJ/0hKwNOjZ3d3Fm2++ibfeeouznOeBwWDADRM0do1sScj5gbrSk8kk7t69i+3tbeRyOW6GIc37k4I0+5RprNVqvJeTSbJUUjMcDrkkfZGaDaUSImpipTGOuVzuzHS4ZxYAZrNZrK+vcwcNANYKUJewy+UCcP/kDjJ/lg5EpwcezfylB0Iul2PR6+3bt7G9vX2uTh4UHJMGUtpBTRkV6ZgkWoBST0WyQAkGgwiHw3C5XFCpVGi1WtxNVywWZy5j8ChQxpQ6BqmEUiwWsb29jf39/RM7qAWHOXq/ks6F5iuTTk2lUnHndSAQ4MCP9G46nY4zO7VaDblcDjs7O7h58yZef/11zv6d1emZyo21Wo0dBSgTcjRDRwHwk4Q2FGk1hdYn6bLK5TL7/W1vb/MkG3qts7Sej8oFjrpCUDBOgeBZ+UM+LUhz2u122b6M5mgHg0FYLBa2ednf38fbb7+NjY0NnvH8NLLkUgsn0l5K/06EtFlC6Knf4Ti5xrPmTALAbreLRCIBhUKBRqOBcrmMQqHAegGTycSpfSoZ0clXLpezJot0faTzAO49lGkkDc3JJM+rYrHINhPnAXrwtdttdDod1gjRyKBKpYJMJgODwcCnDGr4sFgsh8pt165dg9/vh1wu587Gg4MDHhc1K92Cj0O322WhNIn1u90uMpnMId3MRX9onQaSZpBHHWXzKLCj8XpKpZJnjpLGTa/XswFru91Go9HgUVKJRALJZBJ7e3tsuHqWDQy0ISeTSeh0OsRiMdTrdbZhepodjnTfD4dDlMtlJBKJQ13+JLbPZrP8DKQxXFK7j1lZz1LLHbIOyefzAMD6UJoqRfuI2WwGcC9LdR6gCkyr1UI2m+Vq2u7uLk+9om5vMq2nDt2n/bc+TePIReWoBpCaNyORCOsin/QB8TScWQCYzWY5w0SpbBoC7/f7AYDLkMDhRgSau1ev13keJnUFUzq6WCwin88jmUwik8nwTELKYp2XUibpkBqNBprNJoxGI+tkvF4vj/oxm82o1Wro9XpQKpWw2+3w+XyIRCJsGErGqgBQLpexv7+P/f19pFIp9iw7L9ftKL1ej/VllCHWarXI5XLY3NzE7u7uhcoAHn1gHf3ag6CmBbfbjWg0iqWlJczPz7NPGRn3Hv2gUh7d0zQ7lLps9/f3kc/nD3X/nuXfYzweo16vo9PpQK1WI5vNolKp8AjGJ+nuL91c6VlIzga5XA7r6+vY2tri8l673UYqlWLvslarxVrJWRx9Sc+50WjEz/ZsNsuyAtKSAoDdbofNZoPZbEa/358pk+uHQRk38o+sVquHun6BewEvGS3P2t/6vHFcxpMCQPKYXV9fvzgB4Gg04qYCujCdToe7icjMlErCRzciSvU3Gg3k83kUCgW+yWmmbbVa5Xmb5XIZzWaTOwbPE6PRCJVKBXt7e3C5XFxu02q13JW4vLyMRqOBQCCAfr8PpVLJbvLBYJADb5odWq1WkUwmsbW1xc0y5KR/Hsrmx0FrkkrANPSbSmuzYpXxpKBxT3q9njcQqXTgOGgD0mq1PE+ZZpEGg0E4nU5eY7SRUVan3+9zJpumVCSTSeRyOd7sM5kMarXas7wMD4TKr4PBgLNwGxsbAACPx8PWHMeNZHsUSFxP5e7BYMAHYDqo3b17F/v7+/x8o8Awn8+zdnfWD280xo7mvhcKBbbBofVE14b8Y2cpy3lapGXX86JvPI9IG7HoOTEajdjDkZxN6DlxYXwA6cIAYBsWslPY3NzkcXAnGZpSJrDX66HVah0q6dJJsdvtsleY1CrivDEcDpHP53Hr1q1D4+BUKhW7w8tkMrjdbjQaDe6qNplMPB2FJizI5XLWve3u7mJtbQ0HBwf8c+c5AKLNg6yIaLOl62g0GtkU9bxtKEehKRtXr16F3+/nAxeV2Y7z2pRChw+bzXZobBWNgaJsFDUYFQoFNqNtNpsc7BWLRTQaDdTrdbasmFY6nQ7i8Ti+/vWvo1wuY2FhAeFwmLXM7yYAJP++fD6ParXKXb3pdBqJRAKpVIoH0dNzkLp/KWt/Hu5dqb6xVquhWCzC4/FwsNdqtfjQXywWudHlPD73BdOPdKgFJRLoOUrVAZrkctzknmfBmdnAUFaBsi2lUomtXejinFSCkpZDKBiUbsr0/+n0R/97HjduKomMx2NotVrMz88jHA7DYrHwYHC9Xo9wOHzo9EFlA9JYTiYTNBoN1Go1JJNJbG9vY2trC9lsljMI5/H6EdSWT8LqUqnEY4to8sS72cRnCZlMBovFwmMaKQCk0i7NIAXuHcakPpFqtRpGo/GQ4z8ADvoajQaq1Sqq1SpyuRxSqRQHe9RQUS6X0Wg0DmVypnkjp85Lalhpt9ucOSYj6Mel2WyyiS/NGy2Xy2ztQdo/qc6Nno3n8b7t9/solUrsv+hyuSCTyVAqldjuiyacnOeqhWC6oT2FpBqNRgOtVout68jy7qh597Ncr2c6h0WqZxHjtR4PyniWy2UcHBzg5s2bAMADwWkOJFnoSDOrMpmMS0lkkr23t4e1tTVsbW1xe/pF8L+jNViv17G7u4tXX30VVquVGw9oSPp5vw4E2ZbY7Xb22dRqtbye6LRKpsw00J3E+o1GgzVnFPiRDIPGu1EDGFlUkG3JUW+sWdjAh8MhTzWgUmWhUMD29jbcbvd900EeBbK/yeVyqNVqaLVa3P1JHb0XCfL6UyqVbD9iNpu5JE7WJ+fdt1Qw3VAGkBJc+/v7PMLQ6XRyI5xarWbfY6PRyPrOZ8F0D+ITPBTSIY3HYySTSbz66qtIJpNwuVzsq0Y+ax6Phy04pF5ENDFkfX0db775Jra3t5HNZlEul89t6fwoJCmoVqu4c+cOUqkUW+GQjGCaS5BPGgrmaGwj6QGlZs4A2COyXC6zSJ/K6VS6pMYaKoNQhzWJ1ClTdvRjlsTr5EpAWeR8Po+1tTWezfpuyjv0u7vdLmdESeR/Eb0pW60WZ0Pv3r3LWnGpNrxSqYjgT3CmUKc+jaBdX1/nWd52u537G6hpk0yzn2VCTASA5wDKOlSrVXa/JzsEl8sFr9fLcyM9Hg9MJhP/XLlc5g7LjY0N3L17l/3VLopJJ0HBC23gF5XJZIJ2u41EIgG73c4ZO8piSddEpVLB7u4ustks201ItZSFQgH7+/solUqHgjsSRJ+X9SW1Zel2u6hUKmf9ks4t1P1arVbP+qUIBCci9ekslUrY2dmBzWaD1+vloQ1kdUdd7MfNoX+aiADwHEFdlNQ1SJ2JRqPxUClY2m7e6XR4lBWV456Vb5RgOplMJigWi3jjjTd4/ZB1y1G63S53TlMALe3IpHIlGa9LP8T6EggE553xeMx+oU6nE+FwGG63Gy6XCyaT6Uz7E0QAeI6g7BX5AkqbaaRG2tLGGumGTBu42JwvNpPJhLVm29vbx64b4qSAjv77aIOIWFcCgeAiMR6P0W63uRoSj8dht9vR7XbhcDjYw5c0z8/yGSkCwHOG9DRxEbR7gqfDtHfdCgQCwSwgnWKTyWRw+/ZtNBoNbtTsdru4c+cOMpkMjyF9Vsgmpww3z3pm3bRz0mUU1+3BPGj5iWv3YMSaezzEmnt8xJp7PMR1ezzOy71KHb96vR5WqxVms5mn2AyHQ1QqFR5YQfrod8tpQjsRAD4hxA3+eJyXG/wsEGvu8RBr7vERa+7xENft8Thv96p0xKb09ZMd3pMs/57md4kSsEAgEAgEAsFTZtoGUjz72SMCgUAgEAgEgjNFBIACgUAgEAgEF4xTawAFAoFAIBAIBOcDkQEUCAQCgUAguGCIAFAgEAgEAoHggiECQIFAIBAIBIILhggABQKBQCAQCC4YIgAUCAQCgUAguGCIAFAgEAgEAoHggnHqSSCzOHblWSJG/Twe523Uz7NErLnHQ6y5x0esucdDXLfHQ9yrj89pHP5EBlAgEAgEAoHggiFmAQsEJyCTySCXy/mkScO6hXe64ElCg+Glaw24NzdU+iEQCARPChEACgRHUCqVUCqV0Ov1sNlssNlsGA6HqFaraDQa6HQ66Ha7GI/HZ/1SBTOMQqGAWq2GXq+HxWKB1WqFXq+HQqGATCZDp9NBrVZDo9FAs9lEq9XCcDg865ctEAjOCSIAFAgkyOVyqNVqGAwGeDweXL58GVeuXEG328X6+jr29vaQz+dRKpXQ6/XO+uUKZhi1Wg2LxQK3243FxUUsLy/D7/dDpVJBLpejUChge3sbu7u7SCQS6Pf7IgAUCARPDBEACgQS5HI5dDodLBYLgsEgrl27hm/+5m9Gq9WCQqFAv99Ht9tFrVYTAaDgXaFWq2G32xEOh3Ht2jW8//3vx+LiIrRaLeRyORKJBBwOB9RqNXq9HgqFAjqdzlm/bIFAcE4QAaBAgHs6LJVKBZvNhkgkgoWFBYTDYfh8PjQaDXg8HrhcLuRyOSgUirN+yYIZhfR+JpMJkUgEzz33HJaWlhAKheDxeDgDOJlM0G630ev1kM/nsb29fdYvXSAQnCNEACgQ4N6mrFar4XQ6sbCwgJWVFYRCIRiNRozHYzidTjidThiNRhEACh4bmUwGhUIBq9WK5eVlvPTSS5ifn4fD4YBKpeK1ZTabsbCwgOFwiI2NDWg0mjN+5QKB4Dwx8wHg0Q66o12bo9EIo9HojF+lYNpRKBTQarWw2WwIBoO4dOkSLl26BJ/PB71ej263C71eD71eD7VaDblcOCg9CpRdVavVUCqVUKlUUKlUAE72q5LL5VAoFJDL5dBoNNDpdFCr1RiPxxiPx+j1emi1Wuh2u+h0OjPRJCGTyaDRaGA0GuH1ejE3N4eFhQV4vV4YDIZDBwsqETscjvu+JhAIBO+WmQ8AaYOgDYXKJwBYr9Xr9UTHpuCBUOk3EAggGo1icXER8/PzsFqtUKvVZ/3yZh6VSgWz2QyLxQKTyQSLxcIdrycFgEqlElqtFjqdDk6nEz6fD1arFcPhEL1eD6VSCQcHB0ilUkilUkgmk6jVas/4nT0acrkcRqMRHo8HwWAQXq8Xdrsder2eA2JiMplgOBxiOBxiNBoJGxiBQPBEmckAkDJ9SqUSGo2GswN6vR5arZY3lXa7jVqthnq9jn6/j9FoJAJBwbFoNBrY7XYEAgGEQiGEQiH4fD6oVCpMJhPOJAtPtsdDoVDAbDbD4/HA7XbD7XbD4XBAqVQ+MAA0GAwwmUwIBAJYWFiA0+nEYDBAp9NBKpXCxsYG1tfXoVar0el0MBgMprpbVi6Xw2Aw8Dpzu92wWq0wGAz3TTYYDofodrtotVoYDAZizQkEgifKTAWAFPipVCpotVoYjUY4nU44HA7Y7XbY7XZYLBYulVQqFezv7yORSKBcLqNaraLdbp/xuxBME7Sm9Ho93G43b8omkwlKpZI3YfJh63a7GA6HYjN+RJRKJUwmEzweDyKRCMLhMFuenHQtqSyv0+ngcDjgdrthNpsxGo3Q6/U4YNJqtTAYDNBoNNjZ2UE6nUapVJpK6Qdp/+bm5jA/Pw+73Q6VSnVf8DeZTNDtdlEoFJDL5VCv16c2qBUIBLPJzASAtFHTpmCxWOByuRCLxTA/P49gMAifzweXy8Ulu0wmg1u3bsFoNGJ3dxf9fh+dTkds3gKGMsk6nQ4ejwehUAgulwsajQaTyQS9Xg+NRgP1ep0DQJGNeXSOBoBLS0uIRqPQarUPnJOqUCgOlYJJ3qHVauFwOKDRaOBwOGC1WmE2m6HT6TAcDlGr1aY2ALTZbIjFYojFYrDZbCdq+9rtNnK5HFKpFKrVKgaDwTN+tQKB4DwzMwGgQqHgUq/dboff70c4HMbi4iIWFxcRCoXg9/vhdrs5AEylUjy+azgcotlsotPpcInoPG7icrmcRfNarRZqtRoymQyTyQT9fh+tVgv9fv+RS+EymYzF+xqNhjMuAPh3dzodLsPNQrmdGhMow+T3+xEKheB0OjmQqFarSKfTODg4QCaTQbFYRLPZnMrg4llCzVcUoFEQM5lMMB6PWbtG9xj5K5rNZjidTni9XoRCIeh0uofeh5Qdo/+lNW4wGKDX62E2mzEej9Hv91EsFrG9vT11TTp0nUwmE2eaA4EATCbTfa91PB5jMBigXq8jmUwiHo+jXC6LDKBAIHiiTH0ASBuNWq2GzWaD0+lELBbDysoKFhYWWFNE5V+NRsObkcViwfz8PABgNBqh2+1iNBqhUqmgXq+fuxM12ZjodDr4fD4Eg0HY7XbWWeVyOWxtbSGbzbKw/LQolUqYzWZYrVbukvV4PADeubaFQgF7e3vY399HpVJBs9mceqNkKv26XC6Ew2FEo1FEIhHOLHW7Xezv7+PWrVvY2tpCIpFAKpVCqVQ6d2vnUZGOyzOZTDAYDADeWQv9fh/1eh21Wu3QdaJ7GbgXxJHf3eMiDTClv3/a0Gq1sFqt8Pl8fFCl55X0NUu7m3O5HHZ2drCzs4NisYh+v3+G70AgEJw3pjoAlJZ9qRMwFovh2rVr+MAHPoClpSXOclEHsLScYjAYMDc3B5PJxCdqctKnTNV5ggJACnxfeOEFRKNR7i68e/cums0myuUyZ2pOs/lKA/BQKITnn38e3/qt34qlpSXIZDIMh0Nsbm7ir/7qrzjzR93X0wx1ZHq9XkQiEUSjUYTDYZjNZiiVSlSrVezt7eHVV1/F+vo6SqUS6vU6er3ehd6MKRus0+lgs9ng9XrhcDgAvNN53263kc1m79NLkn0LrT3678cNAEejEXcE93o9DAYD/vemKbtP2U+Xy4VIJMIyA9KZSgPA0WiETqeDSqWCTCaDra0tbG5uolQqXdg1J70+R4P8o9nhkzi6Jo6uj2lbM0+bo9fxuPd+2r1B+r+P8vWjDXXTfP2PO1wefW/H6XiP+9+TPncWTHUASBs0zctcWlrC5cuXsbq6isXFRQQCgQfe+CqVCkqlEnK5HJFIBOVyGYPBADKZDM1mk7Ng56WcR12Tdrv9kJedWq3mzMydO3ewv78PAA997zKZjLVXdrsdi4uLuHz5Mp5//nlcvnwZ0WgUg8EArVYLmUyGmyaGw+HUl3+Bd9aX2WxGMBhEJBKBy+WC0WgEALRaLZRKJWQyGRwcHCCZTKLZbKLb7Z7xqz57pAFNKBTC4uIigsEg5HI5l83j8Ti0Wi3y+TxqtRp/PpPJwG63w+l0wmq1Qq/XP9ZDkJokms0mqtUq9vf3sb29jXg8jmq1OjXrj8rkRqOR78lwOAyLxXJf8wdpTrPZLHZ3d7G+vo79/X0UCgW02+1z85x6ENKJPPTspkSATqeDwWCAVquFRqNhX0nylqSfB+5trCR/6fV6aLfb6HQ6hyx16CDSbDbRaDS4SjQt6+dJIjW7N5lMMJvNUKvVrHEGwDpbOsDTs5yuh0Kh4N9Dv0ur1R6yX6MkDGl3NRrNsZIMqhS0Wi3+20zTIUcqe6JGM51Ox2tMpVJBr9fDYDDw+6f1BoAPpt1ul2Vnw+EQ/X7/0NcGg8GpkzFPmqkOABUKBRwOB+bm5rC4uIjnnnsOV69eRSgUgt1uP1W5h7JXXq8Xq6urUKlUGAwGKBQK6Ha75+rBqlar2WqDNJLhcBhyuRz9fh8ejwcOhwMWi4W7Wx/03sm6w+VycUbxpZdeQiwWg9PpZLF9oVBgH7ZMJoNKpTITeiW5XA6LxYJoNIpoNAqLxQLgneCPMjDFYhH1ev2h1+oiIQ2cV1dXcf36dVy6dAkqlQqj0QjlchkbGxtwOBxYX1/H5uYmarUaMpkMH7o0Gg0fWB5ns51MJqjVashms0in04jH49ja2kIymUS1Wp2a9adQKKBSqWC1WjE/P4+rV68iGo0ea/symUzQarUQj8fxjW98A7du3UIqlWIbmPMYlEiRyn0MBgNvrPQ58k50Op2H/CTJU/Ik3ed4PGYtbz6fZ19YCgTb7TYSiQT29vZQKBQwGAzOZaMXBWxmsxlzc3O4dOkS9Ho99vf3sbe3h8lkApfLBZVKhUqlgnK5jFarxRUPWssU7FBVyG63w2QycSWOvHmpUctutx/rpVqpVBCPx5FIJJDJZJDNZqfmulOQq9VqYTKZ4HK5WLpBhw2ycwoGg5zNpzU4Ho9Rr9eRz+dRKBS4/6DRaKBWq6FaraJYLHJViaoZz5qpDwDNZjN8Ph/m5+exuLiIpaUlOByO+252aVmJTt30gFUqlbBarfwz+XweBwcHLOafplPH46BQKPhaBYNBxGIxRCIReDweWCwWnpRAJVm6Ng8LoMkceW5uDisrK1hdXcXq6ircbjdUKhU6nQ7y+Tz29vawu7uLTCaDcrmMdrs9E5sVZRUcDgccDgdno1qtForFInK5HMrlMusZZ+E9PU2kpV+Hw4FoNMpZeco0j8djlEolAO/ILKrVKg4ODlAqlVCpVNDr9Q4Ffo+bARyPx6hUKmwCnUgkcHBwgFqtNjXlPGpcMxqN3PgxPz/P02WkUEDSbreRTqextraGra0tlEol9Hq9qXg/T4Ojnd7k+2i1WmGxWHjqjkajgd/vx9zcHPx+PywWC6xWKwcgRqPx0DNNmgGkNUmm4d1ul6/3eDxGq9WC3W7nxjnaoM/bPU8Bm9/vx/LyMp5//nlYLBbYbDbo9XqMx2MEAgFoNBoUCgXk83nU63W02210u91DGT3y4HU4HHC5XPy3omys1FqLmuqOksvlYLVauRGsXq9zJvIsofem1+vhcDj44BGNRuHxeKBQKDAej2E0GhGJRDA3N8cZfXrvdEBNp9Ns49RqtVCv11EqlVAoFJDNZjnJ0Gq1eJ95lhW0qQ4AgcPTAKjB47iTHmmBBoMBP3ilEblGo+EOxFAohFgsxoPWm83mzD5gpaOlwuEwrl69yhkZm82Gfr+PZDKJra0t3Lp1izsKH5b5pGvodrtx6dIlXLlyBcFgkOfgUvZvd3cXt27dwp07d5DNZvmhOSvX87jxgd1uF5VKBYVC4ZCJ+Ky8p6cFZWZIYrCwsIC5uTmelkKeftJNgK4vdbaStcnbb7+NfD4PtVr92AEgBZj00el0pmrDViqVbFfl9/s5A280GjmLQIxGIza4rlaryOVyKJVK6Ha753rdUXncbDbD6/UiGo0iEAjAbrfDarVCq9XywcNiscDhcLDdD41m1Ov1rKU8GgDSv2EymTjwlpY1qTzscrkQCASwu7uLra0tbGxssO5yWrLJ7xaTyYTFxUUsLy9jZWUFKysrsNls/IwHwPpnKom3Wi3OAlLJU61Wc9BuMBhgNBqh0+nYDYDue7VaDaPRCJPJdKzVkVqt5sw2dbzL5fIzrbTIZDIYDAauokWjUVy6dAmBQAA2mw0Gg4F9SBUKBa+nVqvF2VHqS6CJP3q9nnsO2u022u02Go0GB9mZTAaJRAL7+/uHKk7PgqkOAOlmpgwXTWQYDof33ezk19bpdPjiS7UJtCmRhczCwgKq1Sqy2eyhuv2sQVksm82GcDiMa9eu4T3veQ+cTif0ej3q9ToSiQRu3LiBN998E/v7+2wpcdKNJvXGc7vdWFhYwNLSErxeL2cuBoMBa73efPNNbG1tIZ/PT00K/zTQGpKupfF4jHa7jXK5jGKxiEaj8Vi2OecNOmiYzWbOZsViMYRCIZjN5gdm5KVWTNQxXq/XsbOz8666gOlZMBwOuQFkmqDOea/Xyx6lpHuUboh0baj7l8pD1Wr13MsOpNY4S0tLeOmll3Dp0iU4nU7YbDa2sQLAGywdLijYkFoQnbSWNBoNX/+jAvzhcAi/34/5+XmEQiEolUou251Vae5JI5PJYDKZMD8/j+vXr2NxcRHRaJQrPFQFo0Ca3jfN2O71eofsv6Q2UPRxnF3T0a9J/z5KpZJ1cPv7+9Bqtc/4qhyG3g9NHlpeXuaEis/nAwDO0pMsQy6Xc3mXkiYmk4mDXpPJxFUO6TOx1+uhWq2iXC7j4OAAb7/9NttnkT7wWeyjUx0A0macz+cRj8cBvGOO6nA4YDKZeMFQ+phS1gaDAS6XC06nk8sDJFqV+gi63e4zX3SPy1F7nEgkglgshnA4DI/Hwyn3TqeDQqHAWotqtYp+v//AxUWdxGRZQdkLOh32ej3WyFH5LZ/Po9VqzcSGJQ1wqeSk1+shk8n4xsxkMkin09w4dJGh62WxWBCJRLC4uMinYrvdzsJoGsPWaDRQqVQ4iKFDAX10u90L0UxD2S2aVmQ0GrkyIWU0GnHWb3t7G+l0Go1G41yvOwoMTCYT/H4/VxmWl5cRi8VgtVpZV3Uc0oMFra+jhzSpNRBNkCLv0qO/i6ooo9EIm5ubnNGaVluh00INNVSuDYfDPDjBbrfDbDY/8D32+31uVCDP1KMzqwnp34B+JzU9UMmdPqgsTxres5Y6SMdORqNRLC8v47nnnsPS0hJCoRD0ej1r9iqVCiqVClqt1iEbKgoAKXtNukeSMNCoWmqcoWCRGnGoYa7ZbHKF7mknHqY6ABwOhygWi5DL5ahUKtja2oLdbofP50M4HGbrifF4jGKxyB1zdrudG0cWFhYOTRCgh7LL5WLdxyze5FQW0Wq1cLvdWF5eZn0kleKopESNGpVK5aHBH/COZ5nP52MtIY3gom6udrvN3bHZbBbVapUNpmch+0f+dVRSIg0LvbdCoYCDgwP2NJx1jei7gdYZbSDLy8u4fv06FhYW4PV6Wag/mUxY5JzL5ZBMJrG3t4dMJnNhxy9Sichms8FisUCr1R4rXxkMBsjlcnjrrbdw+/ZtHBwcsF3VeYSCEmoSoAaz5eVlhEIhWK3WEztHgXsZU+nHcRINWrtUrqSGkuPsPGizpnImaQ9ncW+QQk4OVqsVXq+XD/Q2m+1Ue59UTkVB+3FI/yYA2HuWSp6kQaeO2MFggFQqhdu3b3O3e7VaPbP9Q61Ws6sB6SMvX74Mn88HjUaDarWKzc1NbG5uHtL1UYBGASDpBp1OJx82qPJIdllmsxkGgwEKhQIGgwFOp5NLzuVymX2K6To9zWsy1QEgnYzb7TZSqRSfZAKBAFZWVuD3+zldnc/nsbW1hXQ6Db/fj2azyRo2St8C9+aLWiwWGI3GE08z0w4Ff5SpW1pawsLCAhs/08mtXq9zOZOyfw+CrHcCgQAWFhYQDofhdDp5wY5GIzSbTaRSKezu7iKbzaJWq/GIvVkLAG02G2eUaVJKsVhk0+dZ6Wh+WkgbGbxeLxYXF3Ht2jUEAgFYLBY+XFG5qFwuH9K05HK5c69jOwnyLzWZTJz9O27DHQ6HKBQKuHv3Lt566y1uVDivHJUTzM/PY3V1FbFYDHa7HQaD4ZClxlE5wWAw4M5UyiYfN9mJsn6UudLr9ZzZkwZ3FChShYg0brMe/AGHze79fj9LEahphhpiHqTbpr8BZfOOg/4u3W6XJVej0YgN4ZvNJme26G8Wj8dx48YN3L1799Ae8qyhINnr9XJT2+XLl7G0tASNRoN+v49cLof19XW88cYb2N/fRzqdRqVSuS8AlFYfKQBUqVTcyBoIBLgiYLVa+cPtdiMYDLL1WKlU4pLwhQ0AKYs1HA75hqUOHOCdbl4KAKnjplarwWg08onjaEmS6u+NRgPtdnsmyywkVPX5fAiFQrh06RLm5ubg9XphMpkAAM1mE7lcDvF4nE8r1GF0EtJMTywWw+XLl9kYWSaT8bg3miiytbWFXC43deL7B0GnfWoIslqtvOF0u13uAKxWq/xQmoWy9pOGJBN6vR52ux1utxvhcPi+KRaUpaHDGvnx0fzaWb3HnhRUPn9QQCHVnRYKBTQajXO95sixgLp6Q6EQvF4vbDYbdDrdoSwTGYuT7oqaE6SlspOqDwqFgsttVDlyu91cXpaWg6lEKdWUzsoz7Tgo06nRaHiAQiwWg8vlglqtRqvVQq1WQ71e52v6bu5TGsXY6/X4GUv2O1QhosCPpCIk7SqXy2dms0XWNiSjosDP4/FAo9GgVqthf38fa2trWFtbw/b2NrLZLEql0qESMACuIPV6PdTrdU4uKZVKlEolFItFuN3uQ403S0tLnJ0lDXA6nUaxWOSg+mmuw6kOAIF7otHRaMTBHi2Yg4MD/j4KTuhnjnZ3EsPhEI1GA9lsFsVicSazEzKZDGazGfPz8+yNODc3B5fLBY1Gg+FwiHK5jN3dXU5ZN5vNB6aTpV1bbrcbi4uLuHr1KjweD0wmE5f4qtUqUqkU7t69i7fffpu1G7OCNAB0uVyw2WzQarX8/iqVCqrVKh8QLmoDCJ3ijUYjb9Tz8/Pwer0wm833lTOHwyFKpRJ2dnZw9+5dtmQ5b1YaTwM66Lbb7QthOUR60uM6LKVrivSi5XIZuVyO/UZzuRxyuRw3aZGB89FnG1VJ9Ho9QqEQVldXcenSJUQiEajV6kOzzMkOTGqAPMvz4qUSIZfLheXlZSwvL7NsqlQqIR6PszVOOp1+V1INun40aIE8QekgTckYuqb0/dS4eVbOESqVCgaDgW2tVldXEY1GYTKZMBwOkUwm8frrr+PmzZtYW1tj+7jjDhzS5o5Wq8VrWSaTIZPJYHd3l+eX+3w+tFotOJ1OuFwubrDsdDrY2dlBOp3mitTTZOoDQOD+sSmdTuc+jQydePR6PQuvyX+ISglkHUFNEclk8qlf4CcJla8NBsMhrcLCwgLcbjd0Oh37WmUyGayvr+POnTtIJpMnevNJ3eEpyyOdtELdTHSzkr7r4OAAmUwGnU5n5kqkGo3m0AgzagSiQ0Sn02GtytFTKT3cjn48qFxEtgFkUzQLkwbUajWsViv8fj8WFxexsrJyqBP8QVogMoqmw8NRyPKESnnS6QuzuuG+G6gEZbFYYLfbD5XkKNCmkudx64ZKRbNyDek5TbPdqTOaGj6GwyEHw/l8Hvv7+/y8IcPgfD7PHp3UpXoUygCSG4JWq+V55jTHHLiXgaVGHJplTvfqrEH3n9PpRDAYxLVr17C6uopwOAylUskGzLdv38bGxgYHgO9mL5RW66SVOir9TtMeQdIMqm54PB4sLS1heXkZ0WgUDoeDbc6oQ/f27dtcmn3QmqAml5MkHEqlEmq1GuVyGXNzc2i1Wiyz0Wq18Hg8h+6Hk3SwT4qZCABPg9Sk1uVyYXFxEbFYDDabjb2FhsMhms0mz9jc39+fKQ9AsjLw+/2HdArkNUTBX7lc5rT122+/jUKhcGKWjh6SJpMJCwsLePnll/Hiiy8iGo0e8teiGyKZTCKdTrNR6iyekqlTy+fzcSPQaSHbCukUAgqST7oO3W6XzT9pHvU0N5bQBh0MBrG0tISrV6/i2rVrPIHnuM5MEjrPz89DoVAgFAqhXC7ft+4o00qldmnZUyoiv0goFAq4XC6srKxgNBohHo8jlUpBLpfD4/HwhkG+gEehgIkOY7N4T0rpdrt8yKQqxu7uLkqlEgcUdFB7kE8fldDIyoRKyEfLjUflC1QxOU5CNO2Q5MDr9eI973kPrl+/jlgshmg0CqPRiGaziWKxiDt37uD111/HxsYGl4HfTQn46CAGSrpMYwVFo9HA4/GwlRXp5+fn52GxWHg9ZLNZJBIJxONxZDIZNBqNd/1ejmpYz9pf9twEgBRFUzqXhq5L7T263S7/YemP2m63Z+ZhqVKpuDy7vLyMhYUFRCIRmEwm1rBJT7LxeBzxeJwfZCRUlYqf6YRMOpGXXnqJHeKlovXBYIB6vc6lF3qIzsq1A+4dEqhTi3RHZEZMhwTp+yItnNT41O12s9CXMhgPMjSmJia9Xo98Pn9fhmHariFlpBwOB0+WWVpagsvlOuS7JoV0XXTP0WZ7dAOdTCZoNBooFotc0qNNizKvR+ePnndUKhUcDgcWFhY4+Ca7iEgkAp1Oh3Q6fWKZrtPpoF6vczlNmr2e1gCGspb9fp8/6G9P5ck7d+5gfX0d6+vr2N3dRbPZPDTBQ2otdBzUuEA+sVSiPHpNaE71wcEBa7weJpmZVqjxxel0YnV1FR/84AfZALvZbKJcLmNnZwfr6+u4e/cutre3n+gsWpJfSf//tF1DlUoFl8uFhYUFXLt2DdevX0c0GuVOcBobGI/Hsb+/j1QqxZq8d/tepE1N0mtD9wNVRugZ+LSv3bkIAEnsSmJKqeEqWVRQd2cymWQR56x5bWk0Gvh8PiwvL2NxcZF9DGlDptQzCW7p4Seda0gG2cC9DjEqidCYKgr+6HdKSzKUwaJyNJU1p+0mPwo9GHU6Hb9fmiuq0Wg4YKHgjBqPKECmTi2PxwOv1wuv18ticuqQlgbW0odgr9fD3NwcBzz0UCmXyyiXy1Npn0NrhkoWGo3mgR3z1JhEsgsqeR8N4mjSCnUHZjIZBINBZDIZ5PN5dsKnTM1FCASlYxxpbnAgEOCNSqvVIhwOn6i3pYMtrSfyK2s0GqjX61P3jKOSa7FYRDqdht1uh91uR6VSQalUQjqd5sCPOi5rtdq7uk+USiWMRiPPrSWbEtJtlUol7O/vIx6Pc8Vk2u7JhyH1oqNxejabDSqVCs1mE4lEgmVBlFGle+xJMu3XjQzayeeWGmMKhQJyuRwymQxnoDc2NlCpVJ7YQUrqQGE0GrlZZjAYoNFocLNIqVR66LSuJ/J6nupvf4ZIy3oul4tH1sjlcr64mUwG+/v77F03ax2eFACurKwcsnyRNsr0ej0eO0MGlPQ1GnFjtVo5SLFYLHC73fD7/WzDQCP3gHsj9jqdDpdf+v0+i2fpd0/7daRrYTKZeJ0Eg0E4HA6o1WoWK1MA0uv1OECmjPLly5cRi8UQDAbZA0/qdC8dfSYNAunv0mq1sL+/j9u3b3NHGTnKA9P34KTgi7KUR0/3UshCg0rkJ+nRjgruSeB/cHCAra0txONxpNNpAOAT8SxP6jkNZIhMwV8oFEKn04FMJuOuWKlW8igUAJZKJZ7Lvb+/j2Qyyc+CaYKspPL5PIxGI6xWK8xmM8bjMXZ3d/ljb2+Puy3f7SGTdL8ej+e+rH+v1+OJDHt7eygWi1Mt0TgJqb8ijU9Vq9Xo9XpsMk5ek6lUCo1G41zfVydBlRyq4hgMBgyHQ+zv7+PWrVvY3d1lqVOpVHqi84lVKhXvuWazmdchuQDk83lks1kUCgURAJ4GKuvRHzQYDMLtdvPMWuDeAyebzSKVSnF0PW0PxoehUChYmE9NH1KRqNTx3mazIRgM8kY7Ho95c7Hb7Ryk0GL0er081osyPUfT1dQVSv55tEDPqoX/UZAaGkvfg9FoBHCvxE1rg/wQPR4PwuEwlpaW8Pzzz2NxcZHHVFFmudPp8OmbZuKqVKpDwSFpKaXXnv5NAFyimoYHMpXOyCqCOi0po37chASSE0izhA97L+PxGI1Gg5txqIHLZDJBp9OxZpLG8U1jOelJQIEeNYIA967d0RFax73/brfL2i6r1XpoNis5ApD9xjRcP2kGkDLsSqUS3W4XW1tb2N3dRTqdfiISHcri63Q6mM1m9vykyhAdcEneksvlWI86a9CznwJAytpTc1uz2WTrl9FoxN6U0oOetPw4Cw1FjwM9q2iqCXmYUsPH1tYWMpkMisXiE+3Ip/vc7XYjEAhwdp+cSaTd7fV6/ZmswZkOAKlEReXfQCDAVhXSMXGj0QitVgulUonHxc3qDU5B3nHdQdS5STez0WjE5cuXWeNBE1BMJhNvLPRgpKkY0oYI6uaiU0skEuFGG9JVUnfwLATTR0u0dA1JI0RZlH6/D71eD5PJhKWlJTz33HOYn59HOBzmLk3SsFHGisoK9EGbMHVtkxkw2feQbxadPKmrcVrWZbvdRjqd5qyxy+XCeDzmqRYndQFLA8MHdUbT9ZfOyVQqlXA6nYhGo8hms0gmk9jd3cX29jaXyqf9oPFuOG5KxWmgZx39LGWtbTYbrFYrdnd3kUgkUC6Xp2JTl1ouUdCRz+cxGo14AzzNyMqHIdWFk+Gu1Lyc9gYyMO50OjNt/SQNbKQ2TbQveL1eRKNRziZLg75+v88awUajwc92khGd9Zp5klBzCtmayeVy1Ot1Nv/P5/NPfAa8dJxmNBrF5cuXEQwGYTQa+ZBGziSkdX0WzHQASAueGj+oq8fr9UKn03H2ioKUQqEw0wHgUeHz0bIczQU2GAyw2WyIxWJc5qZ5lwaD4dD8Yzo10qgkEsISpPWTjp3TarVcLi2XyycGA9OKNJtCD0EqRxaLRQwGA7ZRuHbtGj7wgQ8gGAyyXyAdJNbX1/Haa69hbW2NRwl5vV643W52ejebzWw5Q9eSOqzppK1QKDAcDqdmXVJmk4J6GptIBxDKMBz9Gfrfh20WR+ez0lxv8uKiWcLxeBx//dd/zYbS9LeaJR7WqPAkoEMZZdO8Xi/LG8jnst1u8/o662tIOtDBYIBms4lSqcSlMGng8W7vBbL7oGw/aXal2VFqRCGZyzRlSh8VaZVDqVSytoxGwQUCATQaDZ7SQc9tah4kyx0KgJrNJjtAnPWaeZLQ+qvVavxeK5UKyyYoCfAkM3/0d7HZbJifn8eVK1cQDodhMBgwGAxQLBaxt7eHRCLxTBtTZzIAlI5doYfd5cuXcenSJR5dplar0el0WFS8ubmJvb09pFIp1Gq1qdhoH5XxeIxarYZsNss2JKQjoM1ZrVaz+73VauX3OZlMOJt31MbjaNeW1BWf9H2kx+r1eryZ1Gq1pyIiflZIDXjr9ToHPVarFS6XC5FIBCsrK3yjNptNFAoF1ihtbGzgrbfews7ODlQqFR8waBg4ZQJJQxiNRuHz+eB0OuF2uxGNRnliTblcRjabfaIdee8G2hyLxSK2t7cPTQ+gcjBlGOh7aSrDgwT0VAahrDNlRsmXi3zaHA4HlEol4vE4z9GdJcNx4J70hMzFn6bpPMkNSKZBATod6GhqQalUQq1WO/OMPT1XaNLH04C6q8PhMEKhEMsMyLez3++jUqkgk8mgXC6j3W7PtIUONbS0Wi3kcjlsbGywrECtVmMwGMBut7OPp0qlYo9J2itDoRCy2Szr0EqlEhvjn5cgsNfroVAoYHNzE8ViEUqlEo1GA3t7e9wb8CSfwdQg53K5MDc3h7m5OfYbVKlUqFarKBQKSCQSyGazz3QO+MwGgDSy7LnnnsMLL7yApaUlRKNR+P1+7q7J5/PshffWW2/h7t27SKVSPDpo1uj1esjn89jc3AQANvYEcF9gJx1BBdxrzz9Nto4yYlJTZGkX8FtvvYU7d+5ge3ubU+WziHQEV7FY5BOv3+/Hiy++iKWlJbjdbjaS3d7extbWFttT0DxIqadTo9FAOp3mUziVX+bn57G8vIzLly/j2rVrcDqd8Hq9/BqSyST29/cP/a5poN1uY2dnB+VymYXRyWQSJpOJO9hotCLNAK5WqyceCsjbjkZTzc/PY25ujgMVurcVCgUfcAwGw6HGpFlhMBiwvYjL5Xpmmyi5IpBpNzWKabVabGxsYGtr68wDwGeBWq1mz9RLly5xYoAOLt1uF+l0micvPK1A9FlBBw4KYjudDnZ3d9m2ihwMgsEgdDodj3KUWvKQW8bW1ha2t7exvb2NnZ2dmdB5n5ZOp4N4PI5qtcpWZ8PhEJVKhe/RJ3kIUCqVPF51ZWWFZWpkUUelaNIdPsuZyM8kAKSSDwlUSZxKgnFpxol8cMgbSnohKPNHY7xisRiuXr2Kl156CdFoFHa7HWazGQB4Zi0NcN7Z2WEdzKwu5F6vh2w2i7t373LmqtVqcQlNKs4/bgGRPoSu/9H5pCT8pQ49Coroo9FooFqtYmNjA7u7u8jlcjN7LYF7+lAS+pIfYjQa5U5rCurIl+zOnTtYW1vD5uYm6vX6ofIeSQ2Aw1lVi8XCaX2z2cwyBavVCrlcjlwux93XdFKfFmhmJ9mKtNttlEolzjJRB1u9XsfBwQHi8Th7Zh2HXC6H3+9HLpdDuVzm0y51qVN2gu5zmk5DD2YSqU+jf+JRqKxP4u56vY5ut8udvafV+5FmSeoRRqMx6VrRAZAOfSqVijPQlOmWviaphnfar+PjQpKYSCSCQCAAi8Vy6BnZ6XSQzWb5WfYsN96nAZU2yZqr2WwilUrB7XbD5/NhYWGBdbcUBNO6Ic2zy+WCy+XiLD05HYzHY86SUiZ+Vq/VYDBAqVRCuVw+9PmnIdWgJiS73Y6FhQVcunQJPp+PnRLICiudTiOfz7P29Vnx1ANAekjRpAC/389lMNoAKcNEtXjy46lUKqzHoMyAXq/H3Nwcrl+/ziPLIpEIDxIHgFqtxinejY0N7OzsIJPJoNVqzXTA0uv1kEqlMBqNuBRJljfkB3hSACiXy1ljRV5/5IdF30/dUOl0Grdv38bm5iYqlQparRba7TZnBWmhzmrpl5B2kLvdbjY7npubg9vthlwuRyqVwvb2NjY3N7G2toadnR1O0x/3/o/r1qTTNV1D2rxJk0nBFGW5nvb4n0eFAuVGo8HjE+kQQQe3breLSqWCer3+wAeYTCZDtVrF3t4e+wEWCgUsLy/j0qVLCIVCnJnQ6/WIxWLo9/uwWCx8YKRy/bTLOGj0ZLVaRaVSQblcRrVaZa3eg3wVpfT7fZ4EUq1WWRNJGjeTycTNOQaDAQaDAWq1mnWWJDynkutkMsHe3h4ymQwKhcJUHTieJLSGSGZA1wS4N2c4n8/j4OAAhUJh5gNAgg4M1JwntQmKx+NwOByHZBc02choNHK2XavVYmFhgZsDg8Egtra2sLm5iXw+P/O6wGfR2EIHMb1ez/tKJBLhTn9qIlxbW0M8Hudu/Wd5XZ9qACidM0vtzysrK7h06RKi0ShCoRDkcjnrinK5HFKpFHZ2dvjhKZPJWChP3VyxWAwvv/wyXnrpJdhsNi5HkecfCcg3Nzexs7OD/f39h25MswBlAKvVKhKJBN+0Wq32PnuOo4tboVBgbm4O/X6fMwXS+ZvSh0YqlcLNmzfxta99jQNAmiYi9XCb9YclaSbJ7gV4JytANi/D4ZCvxZ07d7C1tYVcLsdC8dNCukqpxQJ1clPgRxYqj5IZepZIG0OKxSJrToF7mWPKUD3s95ApLY3Ho8kLRqMRbrebNVo6nQ5zc3Ow2WzQ6/VsYUGb97QHgKPRiE3TyeS4VCpxw8ZpAkBqjCgWi2wgTjNJKcvndrsRDAb5YE0NXVR5MRgMCAQCUKvVfM2USiV6vR4fss8bdLjTarVs1n5cAEjm7GSKPOvPNII0uWQRRBozMnXX6/VsSEwTjajZy+v1IhwOIxKJwO/389g0nU7HBxCpPZjgfqTrjyqWlImm7F+hUMD6+jreeustHBwcPPHO49PwVAJA2lhp8oTT6YTH40E0GsXy8vIhM125XM7lRafTCYfDAYvFwl1LNNWCSkJOpxNXr17F0tISQqEQ1Go1FAoFL3QadUOlulQqxaWXR1msRzfhaVjolOKnaR+1Wo2Dh5MGR0u7LanlnL5f+jAcjUY8Q3h9fZ1nJZN4/TxvEmQjRJuiRqPBZDLhhpu9vT3E43HuID/tWqANmIJMCizJe6vX6/HaJ8PeaS5tUkDzbkXK1DRCDUTD4RA2m41nB5NNBxkkq9VqhMNhxGIx9uaiWdTTDJVeyfIknU7j4OCArYGoYnES9LO5XA47OztYW1tDIpFALpdDtVplM3aXy4VcLodAIICFhQXuCgbuaYENBgPG4zFCoRA/P+jAPe3r7lGhjZe6f0kaRHovadNDoVBAsVicGSur00LPdJJNSKFqGgXHxWKRp7HYbDaUSiUA79gJ0XQfanLb2dlBqVRCpVKZ6YaZp4nU8sXr9SISiSAWi8Hv97NErdFosPxge3ubJQjPuqr2VAJAMtG12WwIBAJYWVnB5cuXEQqF4Ha7+YYkPzqycqHTbCgUwuXLl3k6Aw1httlsbPfi9/s560ezHMkFf2NjA3fv3sXe3h6XOB41+KNgiqYQTNtphzJxZHMjzchIkY5Ak8vlsNlscDqdh4yyKcuQzWZx69Yt3LhxA4lEgv2gZjnV/zCk01KUSiVrrVqtFvL5PI8GqlQqj6yXIvmDXq+Hx+PB4uIiQqEQjEYjNwhIzckpwzbrpfXTQll+2ojpBDwajfiAQuvXYrFgbm6Oy6g0LWSakRoNUzMIjah0u92w2WwP/HnKUMXjcbz99tt44403kEwmObND1ZVEIsFZ7Ha7zV3UdACRarAdDgcGgwEflI1GI4/tOy/3OTVeuVwuboCwWq0cFFOjg3TSw6xNhXo3SDWhtH+QJYpOp0M2m+XnII3LpL08Go2yX2mr1bowz6rTQveZRqOB2+3GlStXcPXqVVy+fBlut5tnnpfLZWQyGcTjcSQSiTPLxD+1AJAMZOfn53H9+nW8//3v56CN2s8pANHpdBxgHbUcIe3CaDTi0xx5OSkUCtZWZbNZbG5u4s0338TGxgZrFchT6jSvmTJCR0solK2YJn8oacntQQaylA2gUqPVaj3UDUcNOP1+H7lcDm+//TZu376NbDbLOqtpec9PGqmIXqFQ8IORyt7SDMGjinNpc6ZrTqP2AoEAz8otFArY29vDwcEBTyCZFhNa6Ui7o4PLnxTSUi7p2o5mo+hvZDQa4ff7USgUsL29fZ+V0TRCzzPgHV1yMpmETqeDy+XC4uIihsPhiSV/OpRR5pDGeFHj1Xg85p+jbnMKeK5evYrhcHhfiZkmBCmVSuRyObZFoXnB5yEAIp9Kh8PB5UuHwwGz2czJAvK8o1J6o9GYaTurR4XWJQV/rVaL92OFQoFiscjfS81ZNpsNLpcL4XCYTesLhcLUyzCeJRT80aSpaDSK5557Di+99BJCoRBPj6pWq0ilUkgmk0gkEshkMmd2AHtqASAJHwOBADweD58iTgpUjlo80OlZp9PBaDTyZAsaXyOTyXhaQTKZxNbWFt5++22sr69zt+9pRb0kpibTUDLxJTF6pVLB7u4uMpnMVC34B42HIug0HAwG4ff7eQwSZTipC5bK59SN1Gw2z1VZ6GHQ9aCAmPy0aNTbaa1ZpGVfsl5YXFzE0tISgsEgzGYz+v0+6vU6tra2cOvWLayvryOTyRwy7T5raGQejeiSeqQ9yTVB5RLSsB6XyZZmsaT/f1agTGehUIBWq8XBwQHS6fQhY+IHzVim33FUY0lrhabQeL1e2O12zvYfvU5kRD6ZTA5ph6ex8ehxkclkMJlMiEQiWFxchMfj4c5y4J1nZb1eRyaTQSaTYU/EabjnHhVppYpe/2nvzZPuY/qd8Xicv4fkCkqlEn6/H5FIhC2hqMx8UfaJk6DnmM1m4+c92X6Fw2Gu+pRKJbz99tt4++23sba2hmKxeKaHj6cSANKJndrPKZh6lIe2tEOT/MbIOkbaxr+zs4MbN25gfX2dRx6RQfFpFyU9QO12O6LRKHcWG41GqFQq7O7uYjQaoVgsTlUAeBrUajV3uM7Pz8NkMgG496Cgzt9arcZ+eJSOnsWH4ruFTsfU8UwZ5NNeC8pY6XQ6BAIBXL16Faurq1hZWUEgEOApKolEAnfu3MHXvva1Q1rLacjCUAOGx+OBTqdjWxfSqDyp1yiVf5hMJvZOPC4IpHU6TVnS00KZTtJW+Xw+7O3t8YH4JI9DMnamySt0SKXfSZl7s9mMhYUFrK6uYmFhAXa7/ZDfHUHPVOm8arKNmaWA+kFQALi4uIjLly/D6/Uemm5EpctUKoV0Ov1Mx249SShQk66b0yQEHgbZOpGf52g0Yj9OvV4Pv9+Per2Ozc1N6HQ6tmubtX3xSUNjBz0eD9773vfi27/92zE3N8dWOjRqb29vD6+++ipeffVVzkCf5fp7agGg1WpFKBTC3NwcXC7XqS0PpFBK9UE/S6dcarkmcStNvqDylbSkRTcOdSdbrVZ4PB74/X5Eo1FcunQJc3Nz7JekVCpx8+bNmTOipXKI1+vl9ySdA0wzktPpNHZ3d3FwcIBKpfJMncifJUeHnlNWT/p3pbVEG660O/dBJ11pZzVpWVdXV/HCCy9gZWWFDcrJm2t9fZ0tinK53FTYKtBDTK/XIxgMYmlpCTabDblcDvl8no1SSRv6uIPS6RrT4SQYDGJxcREul4uDwKPTaWhaSqFQ4GH2s4R0wgsNnaesCmXjpAGbtJxEWWSaKgLcMzFvt9twOp14/vnnsbq6ilgsxrOaT2oKo4wrfc95yf4RRqMRwWAQ8/PzPG1BGjCXy2UcHBxgf3+f9eWzBpXzSUdK1QpqEnw374mavShb2mw22YqJvALJVuck7flFg56dZrMZkUgEq6urCAaDLCsiCcj6+jr7yE7DdJWnEgBSKpSyaWRy+6TR6/VYXFyE0WhEOBzG5uYmtre3kU6nkcvlUKvVDonKpTNvaZKIz+eD3+9HIBCAz+fjwfc0+5TK0LN2SpbObXW73YjFYgiFQjAYDADA3ZelUgkbGxu4desWNjY2UK/Xz/iVPx2kVizSwe96vZ7/vsC97LXL5UKtVoPb7YbZbEapVDqxS4uyOCaTiQeuz83NYWVlBSsrKwgGg9Dr9RiPxygUClhbW8Obb76J/f19du6fhk2ImgR8Ph8uX76M97znPfD7/ezNmUwmsbe3h/39fc4UP85hgTqvKXP1Td/0Tbh+/Tqi0SgMBsN9He1kQUPTUmbRukRqOp5KpfDGG28AAB8+yStVehih2aFyuRw6nQ5LS0uHjLNrtRpKpRK0Wi1ba9ntdhiNxnMX1J0WqQbQ7XbDZDLx4a3X66FeryOfzyMejyOZTKJer0/FvfeoGAwGLC4uYnV1FTKZDIlEguU7hULhXXXI06GEdH/UtElBHzlvCO4hzdZLrdWompTL5bC2toa7d++yldg0lM6fahew1+uF1+s9tME+DNId0E35IO2PVqtFMBhk93Kz2cylXKPRiEKhgG63i36/z38cjUYDjUYDnU4Hv9+PhYUFRCIRDv6oFKVQKA4NJn8WxpFPCukp32Qywel0IhQKsR5GOgOXRPW3b9/G3t7ezI9DOgm6GSkAbDQaqNVqMBgMnA0B3rl2lPl1uVxwOBxsjUCm5YS0S5U6L+fn53H58mUsLy8jEokgGAzCYrFgMBigXq/zKfDu3btsKH3Wp0CCrAsCgQCWlpbwwgsvYG5ujg3ad3Z2YDAYuDGk3W4/UgBI2SbS2pI04Zu+6Zvw3HPPwW63s6E5QUF7vV5HNptFMplEtVqduQAQuNd9Sc8lrVaLxcVFRKNR9usj3TEAvn/JqFx6UBiNRiiVSshms5DL5QgEAnC5XLw5P+iwOkvPskeBMssajYYtYOh60AxzGhKQSqWQz+enRnP7qGg0GgQCATz//PMsYaIkC92T0uYt+ji6hx79b6piuFwuHhvndrthsVhYW0rr87yuo0dB6nVsNBp5Khc1j5IRfDqdxsbGBra3t1Eul6dGYvVUAkBpEEcfp0kVkz8YdQBLM3fSUUfSgJA+T1MuqKwUiUQOmRhLdTSUAbTZbPD7/XC73bBarbBarVAoFKwTyeVyXK4pFApTs1E/DLlcztMBIpEIPB4PzGYzb67UdVkoFHBwcICDgwOkUqmZ8FZ7XKSWHBRMHBwc8Fqgh6d0EzEajdxNSI0y0k5gOlRQkO3z+RCJRBCNRtnwczweczdxOp3mGcrZbBaNRmOqtDNkiUPNVkajEWazmbuZ6dBAzQg0JP40FkmU9SND+HA4jLm5OVy9ehXhcBg2m+2+Mih1bJJFBZWh2+32VDw8HwfqvCQfsLW1NRiNRiwuLiIcDrORM10H2nCPdj3TM5XkMVar9b7rdxxSSyupZmzWN3LaC2i6BXmd0mGCHCVoBCGt41l931I9p9lsRjAY5OCOGhc7nQ56vR7vqSSLogPC0T2Z5FAulwuhUAjhcBjz8/OIRqOwWCyQyd6ZW0u+pbRHz+q9+G6hNadWq+FwODA3N4dYLMayg2aziWw2i0QigfX1dZ45TQ2W08BTCwDJxLTT6fBD7EFZQGmJpNVqodlssikvuZYfHXdG0OnZ7/fDarVibm4OnU6Hfw+dtqVD5WkKA5UAabQVtcWXSiVsbm7i5s2buH379tR1AD8IyuSQBtPj8Rxywu92uzxZIB6P82mYOuLOI5RJAsAnsr29PTY6JYNO4J53Is0GDofDkMvlhxoQyOrIbrfD4XCwe77P5+PJFWRRVCwWsb29zfZEOzs73P01LQ8C4LB7Pd0n1NCiUChYNwm8Y2S6t7cHpVL5QP2i1KqEJvlEIhE8//zzuHbtGqLRKPx+P0+lkd7bNLc2m80inU6jWCxyg9c0XbdHhbRoxWIRb7/9NmemST96tOv5OKhRhwK+0zbZUQBIG/e0+Zs+LpSFMRgMx85F7/V6KJVK7P3X6/Vm+n1LjZ4pG0jdzmTaXK1W2eOQ5CtSTbN0LyWP2HA4zDr4hYUF9qvU6/VoNBpotVqoVqu8P09LJussoDVH+8TCwgKWlpZ4mtHRGfI0c3qaJs48lQBwPB6j0WggnU5jf3+fvftoU5GaK9PDjwLGVqvF5blOp8OzLk0mE5d46QaXerjRRAyr1conWgrmKJCkQE/6AKTXS106pBHJZDK4c+cO3nzzTWxubqJcLs/EpkOnErqZaa6tdLPo9/uoVCpIJBI4ODjg4O+8Nn8QpLmo1WpIpVKc9Q2FQjwontYnaYlcLhdisRgb5lIAJJPJeLoH2b0Eg0G2LKKuz1wuh0Qiwebk8Xics3/T9uCUZhCkDVOUeafxeJPJhCdP1Go1Niam9yP9GdpsKJvqdruxvLyMq1ev4sqVK6zTkmqEKaAkuw4a51goFHjjmZYH6ONA2eharYZ4PI7hcAiTycTXgmb6Su2ajkId1JQBfFjwR4dyKudns9lDGulZvp7APd0aTZM6qjmnrGu5XD4XUz+kU4zMZvMhj0etVssa3UqlgmaziXa7jdFoxNY/0iCQfp/T6cTc3Bzm5uZw6dIldo2Qy+U8KYpGGlI1ZBb2xKcFJVo8Hg9isRhfM2rMIbeHu3fvcrPfo0ySehY8lQBwMBgglUrhtddeQ71eZ30d6Qik2YRms8kzPmmhUiBGAaDRaGSxtNfrhdVqhcVi4UHWJBw/Sd9AnYVkB0CZGdrQW60W3yyFQoGnP1B5lEqj0/SHOw7Ksup0OjgcDi5HOp3O+zZYCgAzmcxMdlU+LjSLNplM8nhBms5Ba0m6brxeL0ajEUKhEEsUgHuNH/RzNJe53W5jf3+fr20ul+MMViqVQqVSmVrTWTqU0fuUmoBTedhqtWI8HiMWiyGXy0Emk3FWhdYQlaXI546MZD0eD3w+H4LBIKLRKFwu16F51MC97lYK/tbW1vDWW29hfX0d+Xyey07Tfi+eBuoOVCqV2NnZgcViwWg0QjgcRigUus+z8zhOk/WjTF+tVsP6+jrW19d5tvU02Q+9G1QqFWdhIpEIBy7SZAM996f1/nsUlEolTCYTu1eQtIemzDSbTbRaLbTbbe4QHo/H0Gq1h8rj0olX0gOt2+3mjKJ0Wkg6nebZ3bN+Dd8tOp0OkUiEJ30sLy/D5/NBq9Wi2+2yr24ikUA+nz+1L/Gz5KkEgP1+H4lEAt1uF4lEAqFQCPPz86y102g0HOhRWp4WVafT4fmo0tItBTRzc3NsLu1yuTAej9m24yikeZA+COihS/qlVquFcrmMRCLBnVSZTIbnvlJr/SxsOlIdDF0vqRUCcLgcQu93GtrRnxXUUZpKpdDtdlkyYLfbOatCAYlKpYLL5YLFYjnUCCQNiijTRb+7Wq3irbfewo0bN5BMJnluK1l2TLNnFj3oyUqCSkxSobPFYoFKpUIsFkO73eYshFwu56yKTqfjQI9GP5Ldi8/n48MbadakpTpp2XdnZwdvvfUWXnvtNfZKpIPYtN+Lp2EwGPBGurOzw01Gg8GAM4B0T5+2vHvcf5O0plgs4s0338RXvvIVbG9vo1AooFKp/P/Ye8/mtrJsS3ABILz3AGHpjWxmZb169Wa6O/pn98eZ6I4pm0aiJHoShPfeu/mQsbYumJRNpQhQd0UwyqTIFC7PPWefvZf5bDufZQI92Ciq4Sid4Nj9oYwtlV67oVBICjm/3y8ddKZq0fGAI2DucbfV9hwfs7NMegfXaaVSEQPobzkGjo0mq9WKzc1N/PWvf8XBwQHC4TBsNpvUMOR9p9NpyTBfNvwhBSDHNyTct1ot1Ot1ieQxGAyyKOv1uhBzmYVKzsJ4PIZer5eWNscXmUwGfr9fbit3tfzfBZKB6WnG7l+hUEA+n0elUkG1WhWl4SpxZBj75nA4ZNxGfptOp0O320Wr1cL19TVSqRQymQzK5fI39zIr48aurq7w008/YTweI5lMIhQKCVWBXWN2jkmoJrguZrOZrPHz83P861//wsuXLyVCjuOXZS+yWXwxXSebzQp9w2q1ymWKMY+j0Ui6gslkUgpAjs6ZAMT1yPEcVXLA2+4UD+dWq4WrqyucnZ3J2Jzj32Wxy/lS4IWUgjMAMpXo9/viTuDxeGCxWBYUwkpQ2c5YPeVllReeZrMpcXLHx8dyAVqFycb7oFSWe71exONxcZ5QXizYwapUKhJ9t8pgp5zZ2LxQKadhwNs1xq4n9zRlY4RQuiRwzEs6VrVaxfn5OV6/fi1hCw/pXfxYsPB2OBzY3NzE9vY2tre3ZYpEdf719TVOTk4klYyxqsuGP1QEwnHSYDBAqVQSoQU5BeRK3e448OZC4r7y5xQKBVitVvlyOBwyKvkYTCaThfY4/920tOCmuIq3RJ1OJ0R7j8cDj8ezEGnXbrdxcnKCFy9eyOKsVqsrl6zwJcB0hlQqheFwiJubm4UXmeIjp9MpNi7M672tnhwOh8hkMri4uBCOH30DPzVJ5D4xHo/lpmoymcQ2KBaLLXh/6XQ64Uz6fD7s7e0tKHMZBcmihZc48o+UYNePYpl8Po8XL17g3//+N05OTlAsFh98Ms10OpXuZqPRQKlUwunpKfb39/H8+XPs7e2996I7HA4ln5XUFq65+XyOarWKVCqFy8tLyUjnpWSViz/grWDLZDIJzcDn8y1cMgCIkXg2m13absynYDgcIpfL4ejoCJPJBOFwGF6vFzabTS5rwKKxPYAFfu/trjJ5klwvNzc3SKfT0vVjsdlsNpdKyfo1sba2Jpy/R48eYX9/H6FQCDabDTqdDo1GA2dnZ/jnP/8pxTLftWXEH5aozkKOvAvGIAGQcSzxvggb5c/pdruyoJXGi5/iM8hxCG/aym7Ql4jSuU8YDAb4fD7xBFNav7BLlUqlcHJygpubG7nlfYtg4cFuL0nTuVwONpsNZrNZqAdutxuj0UhSKJTrYzabYTAYyI2vUChgPB4v8OdWBewAMgoqlUrJ5sZDFXiruqf67V2fU9mFUP5vdqk4niIVJJvN4ubmBi9evMCLFy9wfX0tPp4PtfgDfn0+vHwqpxyNRkMuGLSruovqQn9JdvM5RuaeViqVcHFxIRe+Vqu1YGe0qlAKIcgT9/v9culVpn9wjZGSseodwOFwiHw+j6OjI/T7fVSrVfHrU8Yqkj5w+x0k2Fzh+dpoNER4dXZ2JokpuVxugX6xanvb54KFMlPD3G43tra28OzZMxweHiKRSMDpdGI2m6HRaODq6gpv3rzBzz//jKurK1GcL+v+9YcVgEp8qQWjXHzKTuFwOPzolA5+L0nuD+EWTFitVsRiMezv70sniy//dDpFr9dDqVRCoVBAq9Vaypb01wbV4O12Gzc3N2i1WtK1MhqNUuhMp1OJQLt9eRmPx6KQ48u+ymuKRS0Vf71e78618iGrEuDt+6a8XNFTsdVqoVgsIpvNysg5n88jk8kIafohvZ8fgjI5gJFxAJBKpURQc9ekYzAYCKeZl1oeOPP5HO12Wy447/pdriKUdk10iFCKAhlp1mq1UK1WUa/XRbW+6s9gMBggn88DAPL5PDwej4gs+TzcbrfQf94FOkLw2TAXnokitVoNtVpNxFffGthoUmoQdnd3sb+/LwLLtbU1lMtlvHjxAq9evcLLly9l9LvsnqVfpQD80lAWguQPfe73P5TDhV2ZWCyGvb09RCIRiX3jwdLtdlGpVFAsFpfOhPi+wDVA/ikTP5SZ0TqdTjqGd73MLAIfQvFHkDNVq9UW4uruKvre1TlXensqzYc5ZioWi7i4uMCbN29weXmJbDYrBeftLta3AtJeptMpcrkc6vU63rx5s5BWcxtcm8qxr/KfsQumFAc8FNCHjUp8JjFotdqFkSY7/TQxXvVnMBqNpJtJo3aKN6iIjkajCIVC76VH9Xo9ZDIZpNNp4UaSj6tMwlrmIuaPgvKCEQwG8fjxYzx9+hRbW1viYrC2tobZbIZisYgff/wR//znP8UBotvtLr1R9koWgAQ3t28Z9KuzWCwIh8OIxWKIxWILIzuSe/v9Pmq1GqrVqixOFb+C3QIVb0eSpVJJVPzRaFS4REqCPdNB+v3+QoHByxmFXkrrjU6ng3q9jkqlgpubG5yfnyOTyQhBf5k3zK8BCmM4IldxN2jmzxhGpdcsALnQKfOruU5X/VLBLv1gMFjw7+SFlSLHSqXy3gKw3++LTRXFgA/JIPxzwOfIzl8gEMDBwcGC1Qv5uL1eD41GAzc3N7i4uMDFxYV4L64CzWClC0AVvy5Wh8OBcDiMzc1NSaNwOp2iXlUWgFRALzMxVcX9Yj6fo9frIZfLSbSiz+cDAAQCgQWrHAadl8tl6dzRIJaikpubG9RqNTlUuBa5eVarVbF9+lYPHRWfDqp/fT4fAoEAHA7HgqfkYDCQ+MWTkxOUy2Xhkz6kdaZ0I+DFi1xl+ky+CzQHZ2Gs/DnfKnjRtdvtiMViODg4wKNHj7C3t4doNAq73S6uGqSwXF5eIp/Po9ForFSHWS0AVxy05QiFQojFYgiHw/D7/WJozNQTknybzaYkAHzLL7mK94MKZo1Gg2AwKD6J3NhIsq/VamJ8Td9MHiaj0Qi5XA5v3rxBoVBY4KWxy6BU/H9LfD8Vvx9arVYESkxRURY7o9EIpVIJ5+fnuLq6Qq1WW+n83/dBSZWYzWYiBrpt9XLX9ym79t8y2Ek1Go2w2+3w+/1IJpN4/PgxHj9+LGNf0s7q9Tqy2SzOzs5weXkp6Tqr1GFWC8AVhzK/VZlpTN7abDYTYm+r1RKvsFVZoCruB8r0iKurK+h0OhQKBQQCAXi9XjloGflI0QZtlDjCrFaryOfzqNVq6ppT8cVANwibzSYJMy6X6zepMsr0j4dsJaSEkh+v4uNAU3+OfaPRKOLxOPb397G7uyudv9lshnK5vCBco2UOhZWrtM+pBeADgNIa5zZJn+pfejg9BPsHFV8PvV4PV1dXqFQqcsEwGAxiSEwPP6XIQCn4GA6HSxmBpGJ1wT2OcWihUAiRSOQ3BaAKFR8Ldv44Tdvd3cXh4SH29vawtbUFv98PjUaDTqcj4QEnJyeoVquo1WpiwbRq+5z6tjwQKAu/2WwmKq56vY7r62tpUy9bGLWK5QZzo+v1+n3/VVSoAPBW/OF0OuH3+xEKhRAKhSSmUIWKTwWjBKPRKHZ2dsTkmYJKjUaDcrmMTCYjUZ/n5+fodDoS8zkYDO77Y3wy1ALwAUIZE5ROp/HLL7/gp59+wvHxMSqViloAqlChYmVhNBrh8XgQiUQQDocRCoXg9XphsVg+OhBAhQolrFYr9vb28Oc//xk7OzvCp7fZbNBqtZKjfXR0hOPjY5yenqJQKIjvJs3/Vw1qAfgAQK81OrmXy2W0223k83mcnZ3JjSWbzS5EmalQoULFqmFtbQ1WqxVOpxMulwtOp1NiB5XG98qDWRU5qHgfTCYT4vE4vvvuO+zu7sLtdsNqtYqgJpPJ4OjoCP/4xz+QTqdF8EFR3Kr6laoF4IpjNpuh0+kgl8vBZDJhOBzi7OxMOoCFQgHn5+di07GKtxQVKlSoIJiWwtEb84+1Wi0MBoMYQJdKJTGA7na7qvOBineC1i9Op1Oy4AGgXC7j8vISL1++xPHxMdLpNCqVigQHrDrUAnDFMZ1O5SZSr9dxdnYGq9UqEXm9Xg/NZlOSP9QNUIUKFauMyWQi0456vS4JHzTwHQ6HKJfLyOVyKBaLkn+8zJmsKu4fer1e0lQ0Go3YWP373//Gjz/+iLOzMxQKBblMPASoBeCKg6kNw+EQ1WpVHOGV8VzfurGnChUqHg6Y2tNsNlEoFHB9fQ2LxSLdm06ng0wmg9PTU2SzWTQaDZl+qPugirswnU7RarVQKBSg1+uxtraG0WiEN2/e4MWLF3jz5o1E7z2kKZpaAD4QKG+2q+JCrkKFChWfCvr7tVotnJ+fQ6PR4PLyEmazGSaTCYPBAPV6HcViEdfX1zL9ULt/Kt6FTqeDo6MjAIDb7YZWq8V0OkUqlcLp6SkqlcqDTM/SzD/ySvQ+N3EVeOfNUn1u78f7lp/67N4Pdc19HtQ19/lYljVH71Oz2QyLxQKTyQSdTgetVovJZCJJNP1+X6K57rP7tyzPbdXwtd5VvV4Pm80Gm80mI2CluHIwGEhiyqrgY9a7WgB+Iagv+OdBPYw/H+qa+zyoa+7zoa65z4P63D4P6rv6+fiiBaAKFSpUqFChQoWKhwHtff8FVKhQoUKFChUqVHxdqAWgChUqVKhQoULFNwa1AFShQoUKFSpUqPjGoBaAKlSoUKFChQoV3xjUAlCFChUqVKhQoeIbg1oAqlChQoUKFSpUfGNQC0AVKlSoUKFChYpvDB8dBaeaLr4fqtHn50E1+vx8qGvu86Cuuc+HuuY+D+pz+zyo7+rn42MsntUOoAoVKlSoUKFCxTeGj+4AqlChQsUfjdu3eo1GI1/Kf87b7Ww2w3w+v9ecVxUqVKhYRagFoAoVKu4da2trMBgM0Ov10Gq10Gg0WFtbg9FohNlshtlshtVqhcFgwHw+x3Q6RafTQbVaRaPRwHA4xGQywWw2u++PokKFChUrAbUAVKFCxb1Co9HAYDDA6XRKkbe2tgaz2QyPxwOfz4dgMIhwOAyn04nZbIbxeIxUKoUXL17g+PgY9XodvV5PLQBVqFCh4iOhFoAqVKi4N+h0OqytrcFmsyEQCMDn88FkMkGv18NutyMcDiMSiSCZTGJzcxM+nw/j8Rj9fh9HR0fodrsolUoYDAYYDocYj8f3/ZFUqFChYiWgFoAqVKi4F2i1WthsNrjdbsTjcezu7iIej8NsNksH0O12w+fzwe/3S3E4HA7R6/UwHA6h1WphNpthMpmg0+nu+yOpUKFCxcpALQBVqFBxL9DpdHC5XNjY2MCzZ8/wH//xH9jf34fZbBYOoJIXCAD1eh3ZbBa5XA7pdBr9fh8GgwFGo1H+jAoVKlSo+DBWpgDU6XQwGo0wmUy/IYsDwGAwQKfTUXlAHwEernq9HjabDS6XCw6HA2tra5jP5xgOh/Ise70e+v0+RqPRg1ZaKp+JXq8XHppOp4NWq4XBYIDNZoPRaHzvz5lOp5hMJphMJhgMBhiNRhgOhxgOhxiNRhiNRpjNZt/0GjUYDDCZTHA4HNjY2MDTp0/x3Xff4cmTJ9je3oZOp8N4PMZwOES320Wz2US73Uaj0UC1WkUul0M+n0epVEKlUpFu4CqvT51OJ3uczWaDw+GQTijw6/7Gr36/j36/j/F4jMlkstKfW8XyQK/Xw2g0yhf3QK1Wi9lshm63i3q9jsFgcN9/1U+GTqcTIZnZbIbFYoHFYoFOp8N8PsdoNJJ3azqdAnjLTTYajdDr9VJr8M+ORiNMp1NMp1OMRiM5J4GP8+BbBqxMAajX6+Hz+RAIBOB2u2Gz2RZ+geVyGefn50in0xgOh/f9111q6HQ6WCwWOBwObG5u4vnz59J5mc1mqNVquLi4wPX1NXK5HHK5HBqNBqbT6YMtXNbW1mC1WqUg9ng8sNvtUqz4/X7E43F4vd73GpCyaGm326hUKiiXy6hUKqhUKqhWq2i1WgsbxbcGjUYDi8WCUCiEaDSKx48f409/+hMODw/h8/mg0WjQ6/XQbDZRqVSQyWSQSqWk4KtWq2i32+h0Ouh2u1IMdbvdlX3vNRoN9Ho9TCYTfD4fdnZ2sL+/j3A4DIfDgfl8jlKphEKhgEKhIAVwq9VCt9tVeY8qvghMJhMCgYDQLTweD6xWK4xGI2azGc7OzvDzzz8jn8/f91/1k6DVamE0GhEKhRCLxRCPxxGNRhGJRGA2mzGZTNDtdmWy0Ov1APx6Jrjdbvj9fjgcDuj1esxmM1SrVRQKBdTrdQyHQwwGA9RqNeRyOdTrdcznc7GnWnasVAHo8XiwtbWFcDgMr9cLu90uvJ9UKoXhcIhWq4VWqyXFyqr8Ir4meBtyOp1IJpP4z//8T/xf/9f/BYfDgdlshnw+j59//hkulwtra2vodDpot9sPsvjTarUiQqDiNBQKIRwOw+PxwGQywWq1IhaL4fDwEOvr6+/9eYPBAI1GA7VaDdlsFul0Gul0GtlsVsQN1WpV1ue3tDY1Gg10Oh3sdjui0Sj29/fx+PFjPH78GPF4HAaDAePxGPV6HcViETc3Nzg+PsarV69wfX2NfD6Per2+8OyU/7mq65MXMpfLhVgshmfPnuGvf/0rNjY24Ha7AQDpdBrX19e4uLiQw6hQKAAAut3ug7qcsRuv0+mg0WhktM+Jz11fSijXw2w2WzgLlOtFxdt30mAwwOv1Ih6PI5lMIhKJYH19HU6nEyaTCZPJBGazGcViEa1WS6Ycy76HsavudDoRjUbx6NEjHB4eYnd3F3t7e7BarZhMJqjX6zg9PcXV1RXa7TaAX2uOQCCAWCwmZ8FsNkM2m8X19TWKxSJ6vR663S7y+bys29FohMlkIuuOXcJl3KNWpgA0GAzw+XzY3NxEMpmE3++XAnA+n8Nms2E8HkOn06FQKKDRaKDdbkurdpkX6X1A+eKzG+h0OsVjbWdnB7PZDO12G+l0GjqdDpPJBBqN5sE8S61WC7vdDpfLhXA4jI2NDcTjcfh8Pni9XjgcDhkBsPvsdDrf+zPNZjMMBoMIE5xOJwKBACKRCKLRKC4vL3FycoJsNiuq1YfyPN8HjtHZZdja2sKjR4+wubkJr9eLtbU1tNttWW+pVArX19c4Pz/HxcWF3Li73e59f5QvBhY3RqMRgUAAGxsb2Nvbw97eHuLxOEKhEGw2G4BfDa9J2fB4PAiFQri6usLp6Sny+bxQNh7CWuJl3+VywWKxCC2DI0pSgEjZWFtbk2dJi6But4ter4dOpyPdYtJZlOO7bx181jxb9/b2sLGxgUAgAK/XK1O20WiEVquFSqUCk8mEQqGAWq2GwWCA8Xi8tM/SbDYjGAwiEong4OAAjx49wt7eHmKxGLxeL0wmE6bTKfR6vRS5/X4fwFuOMjuApEhRoBYKhaQDWK/XsbW1hXK5jEajId150qg6nQ5arRYGg8FSFYIrUwByBMxF6vf7YbPZ5KW32WwyXrq6usLV1RXy+bxsCA9hY/yjwCQFPiObzYZEIiFdhqOjI+HCLeuL/jnQarVwOByIxWI4ODjA8+fPsbu7KwePyWSCVquFTqeTTuB8Pn9vEby2tiaHlsVigd/vRywWQ71eR6VSgcfjQa/XQ6vVAgC5HT50sNBxOBxSAD5+/BiRSAQOh0Nu4dlsFicnJzg7O8Pl5SXy+TxyuRxardaDG5vzEmY2mxEOh/H48WM8efIEW1tbCAaDsNls0Ov1AACPxwOz2Qyv14tYLIbt7W0cHR1Bo9HIgTIYDB7EWjIajQgGg9IBJV+LFA2bzSbm4CaTSTrrRKfTQaVSkbE5uaLVahWVSkX2sYfwrH4vOBrd39/H4eEhnjx5gmQyCbvdDrPZDK1WK3zmra0tOWuPjo4wm83QaDSky7WMMJvNiEajODg4kGlDMpmULjrfQYvFgvX1dTgcDvksvLSazeYFIdra2hocDgeGw6GsIzoT1Ot15HI5ZLNZVCoV1Go1VCoVFItFzGYz6ZwCWDhz7wtLXwCyxc82bigUQiQSgc/nk9vxdDqFTqfDbDaDyWSCzWYTAis3xmWpuJcBPIztdrs8K+UYhUWMw+GAxWKRG/ZDA5+Dy+VCKBRCPB7H1tYWHA6HdBpuj5rYIb09TuJ4ir52LADn8znG4zH8fj+CwSCGw6EUNrPZTDaRhwqO2K1WKwKBAEKhEHZ2drC1tYVEIgGHwwGtVot6vY50Oo2TkxOcnp7i7OwMNzc3qNfrkvTx0EDqgc/nQywWw87ODvb29hAKhWT0RlitVlitVulIBINBaDQadLtd6WiRB7nsY7l3ge+P1WrF+vq6PAubzQar1Sr7ld1ul6LQYrFI1x349VBtt9tS/GWzWeTzeSkE3W43yuUySqUSms0mxuPxN9kgWFtbw9raGlwul9AxDg8Psb29jfX1dRm/s7jT6XQIh8PSieX4l7SMdru9lKIkg8EAj8cjE5hIJIJQKCSfj1hbW5P1xb//XTQDCkPsdrt8r7KQ63Q6SKfTCIVCKBaLwgP3eDyw2WziWcpONKdA9/XOLnUBqFRm3r7tabXaBdsHm80mC9Rut8Pv98PtdmMymcgtZVU3xi8Ng8EgXZh4PA6n07ngoTYYDFAul3Fzc4NSqYROpyNt/of0/Fi80ViYSlOuLSUPiRsALxRs/XPT4yiKCjNl8UiCPwAEAgEkEgmUSiVMJhO02+0H19lSgl2/UCiEg4MDGXFubGzAarUKzSCVSuGXX37Bzz//jGw2i1KphFqthl6vJzfmhwZ2/jY2NrC5uYlYLIZAIAC73S4drduHES9j8/kcsVhMCuN+vy+jJ3KQVg18T7xeLxKJhHBuTSYTzGazuEDQCYJffPcIq9UKr9crI/NQKIRGo4Fms4larYabmxtcXFzg5uZGRnbLVrj8kWCR7XA45NK7u7uLSCQCg8EgF4nBYCA+m7zQBoNB2R+9Xq9M23K5HJrNJlqt1lKJkrhv8+/ML+V+/r4ISV5KlN9zF++U4NRnbW0NHo9H+PPkhZfLZSkMeTFpNBryzn7tRtXSF4DcFKxWKywWC4xGoxzMhFarlbEb+VxbW1uw2+3I5XK4uroSHuC38pK/CxqNBiaTCcFgELu7uzJmuV0AFgoF6VS12+0Hm7M6Ho8xGAzQarVQLpeFB8kXnDdeXjZokdNut9HtdmWzNBqN0jUF3tp6kJtEo2K/349EIoFKpYJGoyFE/ocIHh5+vx+bm5v44Ycf8Je//AWxWAwOhwMmk0l4RZeXl/jpp5/wt7/9Dc1mE6PRSC4dD7FDqtVqYbVapfuytbWFaDQKn88Hg8Eg1BZeupQHF9cTR1Zra2vIZrO4ubmR93RVC0AmwiSTSRwcHCAajUqBp+za3D6Up9Op/HeO6KxWK3w+HyaTiVgxtdttnJ2dLag6yZ18aBfcd4FCrPX1dWxsbGB7exu7u7vw+XwYjUaoVqvSeSf/2ev1wmw2i0WRy+VCPB6H3++XvU2r1WIwGCxVMc2/x11n13w+l/H2XV3g2xOd200nZRGovOyz20chCG1ixuMxKpUKrq6ucHl5iTdv3gCAqJDvY59b6gKQhZ3P58P6+roof81m829c/3ngmkwmOYhrtRoSiQSi0Sh0Oh2azaZIvL9FGAwGWK1WGXdubGwgEonA6XQu3KDH4zHa7TZqtRpardZKj5Xeh9lshn6/j2q1ilQqJRY4wWAQfr8fTqfzN0bEo9EIzWZTvOk4gqN/m9frRTgcFsEI7YrIJ+Gtm50IJXfpoYCjdYvFIvzKJ0+e4PHjx9ja2oLL5QLw67PM5/M4OjrCzz//jPPzcxQKhZX0GftYkM6iLP7of+j3+2E2mwH8WtA0m00UCgV0Oh0RI9lsNslM5vgzHo9je3sb+Xwea2tryOfzS03MvwsajUaeydbWFmKxGILBINxuN7RarVAplD5tLNpYcCh9O2+L3DjOHAwGUkgaDAY5N6rVKprNpggAHiJYsBgMBoRCIRweHgonzm63YzAYiBUKu1Rra2vwer2/yeOmTyV/F5x8aLVaVCqVpbFlmk6n6HQ6IgodDoeYTCZywSqVSri6ukK5XP7NGUdNAb0DjUbjQt3B4pAUF9Ym7EzfBXJ5Sa9it5rPvd1uf1VF/1IXgFThJBIJbG9vIxaLwe1231kA3v4+k8kk33twcCBk1n6//+AKmY8BN9j19XVsbm6KmpocG+XNhjcWjjqX6Ub3JTGbzdDpdJDNZtFoNHB9fQ273Q6v1wuPxwOn0yndBz6f8XiMVqsll4lut4vRaLRQXG9tbclYb3NzU5ItgF+pCvF4HO12G0dHRw+yAGQnPhAI4NGjR/hv/+2/4enTpwiFQrDb7dBoNBgOh2g0Gjg7O8P/8//8P/jpp5+QzWaXanz0R4DK83A4jL29PTx//hzff/+9uBoAkI5VKpXC//f//X/IZDKiVmeBFI/HpbCx2+3Y3d1Fr9eDTqcTOsMq2Z1oNBo4HA7s7e3h6dOniEaj8vkmk4lcvGq1Gvr9vnyufr8v5v8sQniZNRqNoiamoIudU6vVCqfTKYfw+fk5rq6uMBgMHuReR1DwkEwm8V//9V949uwZTCYTRqMRcrkcjo6OcHJyglwuh0KhILx6p9OJ7e1tPHnyBDs7OwiFQnC5XFI88n8bjUacn58jm80uhfvGcDhEuVxGOp1GLBZDp9PBaDQSZfP19TX+1//6X3jx4oUow/l31ul0cLvdCAQC4olIYQz/Oc2l19fXsbW1BZPJ9F6+PEUpDocDXq8XyWQSl5eX+Pe//y10InJ6vwaWsgDUarXC5QuFQtjc3MTu7q6MPPR6PabTqZAnZ7OZVOI8rLnRsttChc63xvdQgh0ZZq4Gg0G4XC4ZOSnBmzU32of4vObzuRS5jUZDOFa0xLFarbKm+FKzXU8TYo4PSFUIhUJoNpvodDrQ6XQLhSTH7x6PR6wFHlJ+LTswdrtdVJyHh4d4+vQpDg4OZJ11u12Uy2VkMhm8fv0aL168wMnJyYMXxABvR79+vx/RaBSJRAKJREI8ENvtNnq9HtrtNi4vL/Hjjz/i9PQUTqcTPp8P1WoVs9kMOp0OPp8PdrsdVqsV8Xgc4/EYnU4H5XJZbChW5ZlqNBpxH9jZ2UEgEIDRaJTP1Gw2kc/nkclkxKdtOp2i3++j0+lgOp2KSpjvKw3HQ6EQ3G63jIVdLpd092ltNZvNpKvPg/ihQDme5EQtFosJ9YCepRcXFzg6OsLr169RKBRQLpelWDKbzahWqzJKpy2RyWRCLBaD3W6Xjux0OkW73Uar1br3ydFwOBR3AXqyUkmu1+vR7XaRy+VwcnIiYhb+7snjC4VC8Pl80nVXqoHtdjscDgdarZZ0/Zgcwm60kr7BiRJH6TyDe72eTD+q1eq3XQAaDAa43W4Eg0Hs7OzgyZMnODw8RDgcXtgUmKgwmUxEzWm32+UXxMqcPCOKGr4l/zVCo9HA6XRiY2MD+/v7iEajknTxviLkW+BN8vNxzNTtdiX66C7SMCPduPZYKE+nU9RqNaTTaWg0GrHsIH+GMXLkcb2PTLxqIF2DAqy9vT0RfZDXRr5VJpORg+bo6AjFYvGb8WXjJYMKVq6JbreLRqOBRqOBer2OarWKo6MjpFIp5HI5VKtVlEoltFot9Ho9NBoNbG9vY3t7W6gHk8lEVIeDwQClUgnlcnklxpq8HHHUaLPZoNPp0G63cXNzg+vra+FO1Wo1AG8vqYwB5OiN+xmtPch129nZQTweh9FolIIzFotJwVIqlaRw6ff7D6YIJP3EbreLCC0cDsNms2E6naJQKODVq1d4+fIlLi4uJGVGaWg8n8+Ry+Wg1Wql69psNhGNRqXATiaTEiWazWZRLBbvVeEKvJ3YFAoFHB8fY21tDY1GAzs7O9jY2IDf78fjx4/R6/VwfHyMi4sLoaDMZjOx6yINQ+mYwffYbrcLfSCfzy90DOkmoYzVIz2B3cJgMIh4PI6DgwPpdne73a/yzJayAGTrPhaLYXNzEwcHB9jd3RXLkn6/Lzws3nLZTbFarVIAUuwwGo2QzWbFU0rpzv2tgB3RZDKJ3d1dhEIh6XDdVYTc9gb8VsANjJ0BpfKL//x2CgXJxORt8ZnFYjFUq1V0Oh0hEj9UsABkmgDHRVtbW8Lj6vf7aLVaSKVS+PHHH/H3v/8d2WwW1Wr1m7mQsQvA0aNOp5MCpFgsiu9hoVDA6empKKKBX0frjUZDOmLT6RQejwcOhwNut1tSZljI0AFhFQpA4Nf92u12S2dTq9Wi0+kgk8ng1atXePPmDd68eYNyuQzg7R7F9/D2Zc1sNov9R7vdhtVqRTAYlEOc4ziDwYBKpSLpDqTAPIQCkBcO8o9JIQiFQrBYLBgOh+L1+vr1a6RSKZTL5QUOKS/G5XJZCmRShHjRdbvdiEajMJvNqNfr+Omnn+TSd580BDotcD+nB6ter0ckEoHH48Hh4SFGo5HQgTqdDoC3HPHxeCzuEMozgZ1Ri8Ui3eNyuYyNjQ25+LPbSGEhaxPuA+wyxuNxUfGTg/lNFoBs3YfDYWxubiKRSEiblGM05vZls1lx1w4GgzCbzfD5fMIB4TiPKQ6hUAiFQkF4Jd9CAcjWP6PfqOhSjkpugx3WRqPxoG043oXP2bR4GFFpWK/XFy4oD73A4eXL7/cjEokgFoshGo2KSnA0GqFYLCKVSuHNmzc4Pz9HKpWSPM2H/GyUYKeLI0mOHsvl8sKYKpPJ4OrqCvV6XXjLtH/hyNLv90tmMAtKjqTsdrt0ulYFjO0i2V6j0cheREPdUqkkBfGHYDQapUPIcbCyw8iRKMd8iUQCtVpNuMHLIGL4vVB66K6vr2N7exv7+/sIBAKYzWao1+vil0hLkrsuDNzbhsMhdDodbDYbjEYj/H6/JIdYrVYpCHm5Ufqn3gd4Oe/1eqjVaiIWIjVAmfxx17vyvkYRVc/koCrPzVKpJHnKHo8HwWAQoVBooQnAC4vSQULJYf0aWKoCkBUy1WC7u7syqmT7lG3Zi4sLvHnzRjbIra0taW/Tc40bCsdS7Mg8pBf8Q1DmjHo8HlEqKZWtStCcuFarLQRefysH9O+F0leQ/EJ2E5T2Mg8NPBQCgQDW19cRDAblIFhbW0O9Xhey8+vXr5FOp4WQvSpChS8BdkoDgQBcLhfG4zGKxSJyuRwymQxSqRRubm6QSqUka1RJUeC7ycjLWq0mHWYetrczc1cdyjzfT9mH2FnVarW4ubnB6ekprFYrhsOhCArZiWHKFCP1Hoo9E62YfD6feCs+evRIngML6kql8tFpO/w+KtB5BlOpTq9ecuCWAVSAk0t6fHwsnpLsutXr9U8SobEzOp/P5XubzSYymQxcLpdcNhhDRxPzZcLSFIB8Eek7ROUvuQpcSDSOvbq6wsuXL2XxAcCjR4/kQFFW1+TIrK+vo1wuo9lsIpfL3efH/Wpgm9rlcsHtdkuMErupSrB46fV6qFarcsB8C4XylwJvnEqO4NeU9X9tKJM+2EmJRqPStTeZTBiPx6jX67i4uMBPP/2Es7MzVCqVb7K7zNEjs6Y5esrn87i5uZHiL5VKodVq/ab7oBxNlstlIa5brVZ5p5XmtasOJQ3lU336ptOprDHugUajUSIgrVarkPLpa9dut5HNZh8MXUO53qLRqPBGO50OCoUCbm5ukMvlUKvVRADzIQyHQ1SrVaytraFYLKJer6PX6wnNRRnWsCwXEF6e6BdpMplEOc4Rb71e/6T9SEn9YfevUqmI/RDFNu12G06nE7u7u3f+HF5w7uOsWIoCkMUf2/DPnj3D3t4eotEoPB6PEKVp2ks7AKrdqOKs1+uo1+uS38dbMQ8on88nGcIPSX35PiiLYEYomUym34x/5/O5cLQKhYL4QDEFRO0Afhz4vBndxc3wocbpkVcVi8Xw5MkTPHr0SLJsjUYjer2e5Pty7Fsul9Hr9b4JCsZtcD+iXxiNxSuVCrLZrDyfbrd752FMpaXygvEtcHU/5/PxgGaH5urqStSrFGZRDcw9kgblD+V84ATI4/EIt1Kn06FWqy0IPxqNhgg2PgRm33LcyZxl2rYpDcuXqQvNBke320WpVMJsNlsQp31qB5BQqp+n0+lCN56ef5yE3P77cOJGz89qtfpVfVDvvQDkQjEajUgkEvhv/+2/4U9/+pN0/5jwwQq+1WrJF9v1tOao1+uoVCpS4PCBsw3udrsXxlLfAuh9xZxDKpNuZyFyLF4oFJDJZJDP56XA/ta6NL8HvGzQaoI2HSSeP7SD2uFwYH9/H8+ePcPh4aEozBmHVy6X8ebNG/z44484OzsTJetDTZb5GLBDoNfrhSZAY/BUKiXuBh/CQ+jw3YW7xtj8/z71M/NwbrVauLq6QrPZhNFolLxlJv0o+YcP5bLGZ2Y2m+HxeOB2u+XScXNzg3/84x/417/+JSbYH1v8cPSpTFDiSJg+n/z3L9Nz5N47HA5RqVTQbreh0+lkjTDF5PeA3UVlxvfGxgaCweBCtrfSxJwFaSqVQj6fX6B9/NG49yqIt2EqYZ48eYInT57A6/XC6XTKC9/v91Eul5FKpSQxgAKQ+Xwu1gjlcllucnTVVwohaL3wUDfP21DeADmSu2s8xAKQfCSqVx9yTu0fAYoh3G632BKxGNJqtSvd9VKOF2l/sLOzg2fPnuH7779HMpkUk9PRaCQRby9fvsQvv/yCVCqFZrP5TVIKeBgbDIaFbNtOpyMG9bSBub1GlEH17BzSmNfn8y1MO1YZPIR7vZ4UZVSv+v1+GZvTfP1TCpbBYCAGu8ViEY1GQ34OlcRKS6dVh06nEy9dv9+PeDyOUCgErVaLRqOBXC6Hy8tLXF5efpZX5Hw+x2g0QqPRQCaTEWGh3+8XFw+v1yu2Wcuy77Hw6vV6XyQVTOmxqNQbUMS6v7+P7e1thEKhhQKQXUdyBtPpNG5ublAsFr9qWMW9F4D071tfX0ckEpHIN2V6AgD0ej1cXl7ib3/7G168eIGbm5sFjgx9tCqVCvx+/4NPFPhYkJzv9/vhdrtlnH4blLzzNkcvPBWfBnpDMS6JUXAPpQDkhS0Wi4lH57Nnz7C7uytK/U6ng1KphEwmgxcvXuDo6AgXFxdf1eB02aCMJVPmmt/Fxb0Ni8WywIn2er1yqG9tbcHj8cBkMq18p34wGKDRaKBWq8mBarVaEYlERIFKoQIL5c/Zo26PKDmGa7Va4nywyu8p8KtlEAtnpmHFYjGMx2OUSiVUq1WZnn3uPj+ZTFCtVnF1dQWLxSKZ33a7XQy9AUih/RBBDjSTPWiwzRzrYDAIn88Hp9O5UABOJhPp+tN/MJPJiNn218K9FoBKhRKLPyYkKAmk7PDRQPbNmzcolUrSKmWsS6fTEQf8Vd8MvxQ4kmS82V1RNeTKsANYKpW+2ci83wsWgLTi4EH/ECgHLP4YAfX999/j+fPn2NraQiQSEXulVquFdDotRs+np6fIZDJL1Qn42uCz47rgxUAZNXiblsFuq8PhwMbGBr7//ntsbW1hfX0doVBIbCM40VA+22Ubv30I7NKVy2Xk83kZo5lMJgQCAQC/FhKcTDAF5GP5j3y+yu4rqTDsxlSrVeE9r/L5QesXJk1Eo1Ekk0n4fD4ZM1KE9anCGiU4Ncrn83A6nSKioP/u+vo6SqXSg9j7CCUVYW1tDUajUUQ2sVgMe3t7+P777/Hdd9/J2JfrjC4mTPwpFAo4Pz/HycmJeFC2Wq2veu7e22+GD5C5mJubm5LRSAKuck7OeCSGOivtI9jCJwlT7Vy9BQ+QcDgMn8+3cAsB3iqQGEGTTqeRzWbRbrfV5/gZYMHNKLm7YvZWEUr7kkgkgp2dHezu7iKRSMDr9cJgMAgNI5vN4uzsDK9evcLZ2Zncar/V4o8jIrvdLlQMi8UCs9ksRrHKjiD3M4oUaKy9t7eHZDK5QOdQpgsAb/fVZSPgfwjz+RytVgvn5+fCzXO5XMJb9vv9SCaTUpxRdcki5kN7lcVigdvtRjgcRjKZRDgclmdIL8abmxtcXl5KBNoqg2NY7vs2mw2TyQT5fB5nZ2fIZrNfJG2CBY0yR5fnNs+Wh9JI4LvFFB/mxnPkG4lEkEwmsb29jUAgIOEUyoudkmd/cnKC4+NjnJ6eCvfva5+591IA8mbLNjXl6evr68LbA952phg90+/3xVdNLU4+DmxPB4NBMeVVHgpUFHa7XWlJZzIZyXFU8WlYW1sTDupDEhtpNBqJVtzb28Pe3h62t7cRiUSErsEuSjqdxvn5OV6+fIlUKqVeJgDh7fn9fng8HlHjK9cKFeM8NM1ms2Qqb21tyR7JeCml2pDg3soicFUwn8/RaDRwfHwMrVYLj8eDZDIJq9UqtlXxeBwajQbT6RSVSgWFQkGMnt+3vugtG4lEsLu7i52dHcRiMZmIDAYD1Ot1XF9fi1BplQtAdgA9Hg/W19fh8XgkQSudTuP4+BjpdPoPiRvjmU0u58cqi5cdfK/ISfX5fHKZ2NraQjQaRSQSgd/vF/qP0paJ6HQ6OD8/l8nI+fk5MpkM6vX6vay5ezmdSIYmWX59fR2JREIKFN4iRqMRms0mqtUqMpkMyuWyCBO+1W7Cx4KeTDabTeKVuCgJEnQZj1Or1VAsFiWa6yG8uF8LPHTZ5QkGg2IxQfB5s5Pd7/eXfh1zbMYDdGtrCwcHB9ja2kI4HBbBR7PZlKSPy8tLXF1dSfb2XYeMUlDyLs86pb3CfYfK/x4oL7ws8vhFvigTVJgpOh6P4ff7sbW1hd3dXSSTSeFHvwv8XTkcDrhcrgXF/7I/v/l8jm63i0wmA6PRiK2tLezs7AixnjnBa2trQgeiaIZK6vF4/BsPTlKE7HY7otEodnZ2kEgkEAgEYDabxVeWNjy5XA7NZnNlR8Bca1arFYFAANFoFF6vVxK0isUi0uk0qtXq744IZGfbYrFIMohWq5URZ7PZxGAwWIlzhGNd7uNra2uyN/GL728oFEI4HEYsFpP3c319HYFAQCzm7vLY5Vpj94+2WLxw3MdZcC8FoFKZyoiUUCgEl8slvjw8KHkrY6u0UqlgMBgs/cF5n+CG53Q6EYvFZATAw0A5Mifxmqayg8FgIQdSxYdB6T9H7YlEQi407GhTZUhD5OPjYxSLxaXvNJhMJkSjURlDHh4eCsHZZrNhMBiIcvzi4gLn5+dyq32XnYGyS8VRyl3Gu+wk8NK3qgpNpc2EcpIxn89hMpng9Xqxu7uL8XiMSCSCXq+H0WgEt9stNhKRSGShe3/Xc6UPWyKRwGAwQKlUwtXVFTqdzhexuPijQV5ftVrF9fU1Xrx4gcFggHg8jnA4LNOMRCKBv/zlLwgEAuIE0e12RcRRqVTE8H82m8lz4bvJCDC+j+l0Gvl8XlJVVnXCRJGWMvptY2MDPp9PRuzNZnMh4vP3XAp4zqyvryMajcLtdmNtbQ2DwUA6tMtuI8aLKL2DuR8pOdxKegY1C6FQSAQerF2oxr8N2uZQ6U+uabPZXOiS3scF7d4KQLvdLg+PD5PcD3oMNRoNXF5e4h//+Adevnwp9iT9fn8lX9CvBa1WKy9mPB6H3+8XPprS94immEz9oDLsIfE2vgaUN+54PI7NzU1sbm4iEAjIoc3nXavVcHJyghcvXiCfzy+1JQqNTBOJBL7//nvs7+9jd3cXGxsbQtCv1WoyVjo+PhZCMy8T/Dn8z9vdMLfbLTmitzEajVCtVqHVaheKmFU0Plb6prXbbeH7GAwGeL1e7O3twe12S6rCYDCQaD2OlW7zd5XPgM+VXm8AkE6n4fP5UKvVxOJkmcECn/u+Xq/HaDQSJwN2AmOxGBwOBx49eiTFNCO+stksLi8vcXp6Ku+dTqeD1+tFJBKR7p/FYsFwOES5XJY0DD77VZ1+3C4AyUlzOp2YTqdoNptoNpvodDpfZPqg0+nk35NIJODxeMTbkgVgu91e6maC0ofY6XRKWAQ7erTzYtKOcs9yOp1Cx6CY6y5wmtnpdOR30Gq15HdwO7bxwYtADAaDcDzYoaJVBvDrxl+v1yUe6fz8XEw81eLvw+CInaMgBp8r3e05+i2Xyzg7O8Pr169xeXmJdru9cofr+6DkQ/Hrc4jxyjgqjpr4881mM8LhMJ4+fYrnz59jc3NTEmcMBoN0f1qtFvL5PE5PT3F2draUHcDbCjfGMrL4YzoP8Ks3Z6FQwMXFhVi9XF9fo1QqYTweL4TGm81myd4ksd9ms0lCwV0FIBN+qtXqQjRht9tFt9td6oNFCXYA2u02qtUqSqUS8vk8AoEAPB6PqPTNZjMCgYDwnNlV5t7IUS7HnCwqNRqN/BnyqinWicVi0oVZhUxvKnJzuZxYtDAEIBwOy3tFnzmlIMTv9yMQCMDr9cLn86FQKMga2djYEBsdNhqazSZKpZL4rzH1aJXPF767NP93uVxYW1tDs9lEvV6X4u9LTHk4+WDOPMef7OTS0HyZnyfV+X6/H9vb29ja2kIoFJJ15HA4xFN4bW0NBoNBOoRKvcL7wOKO/PBIJIL9/X14PB40Gg00m020220RNXECAPzxxeC9FYDBYBA7OztIJpNwuVy/KU4qlQpSqRTS6TRKpRLq9fo3rST8VCgl6kq/MVp10PcqnU7jp59+wt/+9jdks1k0m837/qt/UbBAI5GcPI9PAUfmAOQmRxNRFtrb29v47//9v+P7779HIBCA3W6XQ3s0GoljfiaTwcXFhRTby+RXqezOUZyQTCaxubkpRGeHwwHgV1/OWq2GVCqFN2/e4OjoSMZo/X5f+Gjs9Hu9XhEi+f1++f/olajkShJU/3c6HVxfXwtx+ubmRkYqq4LxeIxWqwWNRoObmxtcXFzIyIwCEI7ESc/gRY6XN767/X5fRFvskIbDYRE18GeGw2Hs7++L4Wy9Xr/vx/BRYOeX43+us/39fTx+/BjxeBxOp1NsdIxGo/BUKZp59uyZeAZOp1O4XC5Eo9EF42yaQt/c3KBcLn/VCK4/AspOu5LPxqxb5Yj7SxR/FEWYzeZ3JkwtO4xGo1At/vznP+M///M/sb6+LpdVFn3KJgL5ux8LFso6nQ7JZBIGgwGbm5uo1WoL58L19TXy+TwKhQKazeZXqXXupQBUKlPZ/VMeyuSvpFKphdb8+0YY5At1u93fcASVi5W+PNxUl50c/Tlg/JvD4ZCNUpkUoOQkVKtVpFIpnJ6eol6vP4hnwc2PI4pgMCiWD59ry3LbLqfZbGI+n8NiscDn82F3dxdPnz7F7u7uQreGY/ZyuYzLy0tcXFyIOGLZRnIcZSvd7Hd3d8WiiRc1ZcrH69evcXp6Kikfo9EIa2trcLvd8Hg8CAQCQvOIRCIIh8MIBALS/eLvQ3loKbu0fK/D4bAIKIBfu4PM7lyFrGrlCLZYLOL8/FxGS7SB4aWNkYF3+XUyNqpYLKJYLCKfz0Ov10tXx+12w+FwwGAwwOVyIRQKif3TqkQRTqdTtNttdLtd9Pt9tNttEQDOZjMMBoOFz8UDWWl9QmsSdkr1er0oioFf32e+lzTgXfUCEMDCOJHFGDukFJ79Xi6tssvI2Et2xFZJeU7uH8fYOzs7ePr0KUKhkDwf5XPk/+Y7xDNhNpvJM3lX8UthCS9n0WgUnU5HBEhXV1dwu93Ch+Y7PRwO5ff1R3RS79WjQunGrgSJ5VTIfMgfhyMCvszRaHSBW8UK3Ol0wuv1yiapFD08JGi1WthsNqyvryMWi0mn4TbIpXoofk1cS+ykWK1WJBIJPH78GIlEQqw3Pifonc+IZtnlchkAYLVa4fP58OjRI/h8PlHC0di21+uhWCzi+PgYv/zyC16+fIlqtbqUnWyj0bjAY9zd3cXW1pbwiDQajXRAz8/P8be//Q0//fQTrq6uhDpAjszOzg4ODw8RiUQW3jmaILNDyk70bDaT3w95NSziDQYD/H4/9vf3pauh0+nk4KYB7SqA6Qmnp6cwGAzSNQkEAlIQAvjNnsjRb6VSwcuXL/H69WvJ67ZYLKjX62g2m0gmk0gkEsL15R67imCxx9/vZDIRfiDpQ+wm03eT3T2uE1Iw2NlWTkCazSYqlYpEw33ITmYVobRSY/H3e8FRptvtFsNnOgKsra2t1DnCApAjbCUNbTKZSKGrPD95FvByOhgMhHrxrqQtgpMRdhJ5AbTb7WLaHQqFcHV1hUKhIJnFjDH80s926UzKlI7wmUxGVL/v++Dz+Rz9fh+VSgXpdBrb29u/KQBJ8mQottPphNlslg7CQwGl+U6nU8wp3W73J7WsVxUs/tj99Pl82NnZwX/+53/i0aNHkqWqLADvEhTcZaDLl77dbiOdTqNQKAAAHA4H3G63KMF42DJar16vI5vN4vj4GP/+979xfn4u3cNlAYsEpi7s7Ozg+fPnePr0KTY2NoRfpuwan5+f4+9//ztevHghxHkaG4dCITx58gT/9//9fyORSIhg5LZZMYnpuVwOo9FIspPJEWQxzU7u1tYWLBYLptOpvN8cra5KATidTlGv13F+fg6tViscXT5/i8Xym++haIuX3NevX+N//+//LZQNh8MhB8Ta2ho8Ho+M6lcdHFdSQcnx+c7ODra3t7GxsYGNjQ2sr69jPp/LxUE5ruO7xovZeDzGcDiU+FCORh/SOQC87VaNRiMpAL9Et5wFoM/nQyAQQDgcRjAYhNFolInaMu1v74LSoJ0Waaw/OElkQacUYLFmGAwGwt0jxehjzlkWgXxejOyj2t/pdMLpdOLs7GyhQcai80vi3oygb3f+2MXjiI1jDvI4PrSgSLKu1WpoNpsLLzNbvVarVUxXyY37nG7QsoLFj8vlkpeTpOmHYkj8LlDJRXshKuA4lmUHkEUF8HYMRC4evRMZF6XMaWULvtfryTMGILc3Fpc8YLrdLnK5nIzXyV1btlETu+MWiwWhUAgbGxs4ODhYSPmgUXi1WkU+n0cqlcLr169xfX2NVqslRGkq6JLJJB49eoSNjQ24XC4MBgM5ZLvdrrybtOEoFAoYj8ew2+1iX8Sfyc2QfMLJZIJQKIT19XXUajUUCoWV6nDxgJnNZhK7WK1W4fP55HC+fflgTCMvEzc3N0ilUigUCkIa58XG6/VK0f4QuvrsYLHbwvNAo9FI97her6NSqSCZTEKj0YidmDJRSgkWh2wKcI3zgF3VLiA5oxRcsRs3HA5Rq9VQq9U+2Ez5GDClxe/3C4/XYDBgMpmIHY9ydLmsYHHcbrdRKpVwfX0t3D+KZbjvsADkehyNRrIeO50OnE4nNjY24Pf7pQP9Pi4kaxLyVy0Wi/xeqHxnpGgmk5GmA4v4L/Vcl6YqYLu/UqnIaKPRaHzUzYyHLsnOt/mC7AwBEEEAx0yrdHh8CAaDQYy1qZijP9FDKnRvQ2l+ykzGx48f47vvvpNbFTlWwNuuHzsq9Eoj/43xPna7fSFSi4cGyfbAr8+cX2tra7IJVqtVXFxc4Oeff8arV69wc3Mjbu/LtCkyKpA30N3dXRweHkrnbjKZyKgslUqJ1cvx8THq9TrMZrMIRRKJBOLx+MJIqN1uiyl0Pp9HsVgU8Qa7pMx2pWjJ6XTKeIkCFP4+jEYjbDYbXC6XFIarRDpnN48HCPPL32XLwXVKo2IKFnjI8hnm83msra0hkUig1WrB4/E8SH4z8JYixPPi+voasVhMOjHsUPFCdhtUfnq9XmxtbaFarUoCyyoULu+CkvtNigXwNn6sWCyi2+3+7s9GF49IJAKPxwOj0SjTkVKp9JvAhmVdg/P5HJ1OBzc3N5hMJmi327i5uRGzcXYAOY3g91B5zi5yt9tFKBRCs9lEPB6Xou59tQX3XbfbvUD7sFgsiEajCylAZ2dn4ltZrVbR7Xa/mHvEVy0AyekhGZ+HJvkq3W5XCsBKpSKGlR+zYJW5wbcXHQtAZetVqex5KDAajfB6vcIjUKZ/PKTPqQS5PdzQGSv4/Plz/OUvf0EoFJLPzhEmu3n9fh83Nzc4OjpCvV6Hy+WSn8GiT3lR4K2NpGdgkXSt1WoxHA4lEo1mtq9evZLLzB/B4/g9oCcn87gZORYKhaDT6eSzUO37008/4dWrVxImzyLtu+++w87ODjY2NhAIBORSVi6X8ebNG7x8+VIEMN1uF8DbQlzJVWOB7fV6hcvm8XjgdDoxn88XyOfscqxSAQhA9jNlaoUy7YTrk4Ui85XPz88lNL7T6YitCwtEACiXy9JlXUae6ZcAP2+j0UA+n4fRaESpVILL5cLW1pYoffmeKdfZfD6XTpnb7cbGxoYYlrPrw2SRVXt+Wq0WZrP5NwUgaRuVSuWL2KhxD6QXHr17G40GcrmcGEAvewEI/Pps+D5R2KbRaESAwUspnyXfTfJIaXUTiUSEI84J4/uaLnq9Hn6/X9551kScUrGR0263Ybfbpbjm2ly5ApBZoh6PR7gbjDaiMovj32q1uuDJ9KEFxMo5GAyK0/tt01T+OR40n+sHt4zg57JYLKJmikajoih6yN0/2l1Q8JJIJLC5uYmdnR2JjmJCQL1eR71eF35Hp9NBOp2WDqDNZhNFWDabRTweF+Wq8pbGi8xdmE6nEq3H0RTNy5dJrap0wHe5XGLmGgwGYbfbodPp5H28vLzE0dERjo+PcXl5KQIYWjkdHBxgf38fwWAQVqtVOLz0PHzx4gVOTk6Qy+VQKpUwHA7lPeTYnTdtKpCpGmYXmzdqUj0qlYpYzqxat0Zp1svPyU4KJyFUbVLpe319jevra1xdXaFcLv/mIqEUdC3LGvujwC4MCzQaPvMCoRR68bmwYzOZTITmQU82FoV6vR5utxv5fB65XE66ZavyPNn5ZN40M7rZHFHaWf0eUJiljH5rNpu4ubnBixcvcHp6ikajsfTFHwDp+vLCxcspi1edTrdAGwIWRSD0hGURV61WpZj7UAHIMysSiYhTAidVyn2Re0Q0GpWJCW3Ifi++agHocDiwu7uLZ8+e4eDgAJubmwgGg6LU6vf7qNVqos792JePP3tzcxOPHz/G5uYm7Hb7V/hUywEepna7HfF4HPv7+xLV9VCK3HfB4XDg8PAQT58+FcXq+vq62BJ0Oh3hrV1cXODi4kJiopjK0Ol0RPFlNpuFY7a7u4sffvhBFIYfA3YW+bNpZbFsBrNcMywAY7EY4vE4vF6vJDDQuPrNmzf4+9//jsvLS9RqNfR6PTFOff78OR4/fozd3V1otVopGk9OTvDq1Ssxcc/n8/IclMUf+S9OpxOBQADBYBDxeFzGybFYDOvr67DZbAB+vbGXy2WkUilks9mlTxq4DRYpVP1tbGxgZ2cH4XAYFotFRu63DbZzuZyoVZU8Sv5MHshGo1EmG3eJmR4SbmfRsoNvt9uF/6aM4Wu32+j3+7BYLNIto7clbckikQh++ukndLtdyRde9iKGoKAhFApJ+hO5aO/L3P5UKH8eLyzdbhfn5+f45z//iePjY1Sr1ZUpnpV2XeRn8+9+2wZG+c+UFy6KVc/PzxdESO+CXq+XSzftZ5RxrVzbACTGMJlMotVqoVgsfrHP/lVHwFarFfF4XDpUzKgFIDFPNDn9GK6UkvgfiUTEi40kaEJ5A6TzNqOlVunweBd4oFqtVqyvr2Nzc1NuE3cdALyxFQoFGRktU3HyKbBYLEgkEnj69Ck2NzcRiUTgcDgwGo3QbDaRTqfx5s0bHB8fiwEzC0Bl8aB0z7fb7aKq/NS8TOV4gB5ky8gp4gHodDrh9/ulAPR4PNDpdMKNefPmDU5OTnB6eirRdVSv7u/vy3P3+/3SYb26usKbN2/w888/C2et0+ksWB8wpYYiD/IuGcMUiUQQDAZFsT+bzdBqtVAoFJBOp8U0td1uL92zfR8ouvF6vfI5WeBSFMMoszdv3khCD/etu/jQ7FxxosIuwkPu/AO/rmGLxQKv1ytG2Iy8pP8ds285AqVoKRwOi32Mz+cTQ3eLxYJOp4NsNivdlmUSbX0I7IDyIkAxgdfrhdfrRa1Wg1ar/eRzj9MCk8kEv9+PYDCIQCAAs9ksljqZTAaXl5dIp9MfTd1aFsxms981ViX3+2NhMBik+TCbzeD3+7GzswO32y0FIItAnktsRHxJQedX7QAyciUUCsHpdP7uDYpkVHYj9vf3sb+/L7dpggTVarWKm5sbZLNZuU0vWxTX54C3FPKn6ClGU9nbGAwGyGQyePHiBY6Pj1Eul1fGRuM2mKXKw89kMmE8HqNUKiGbzeL169f417/+hZOTE+kuk+Cr3KBMJpO02nd2dsRmIplM3mnNseqgbxULrmg0uuDl1Wg05NldXFyg0WjIJc1iscDv9+Pw8BCPHj1CIBCQvN5cLofr62uk02lJBqGqjXwaRiFtbW0hGAyKktBqtYp/I/83xy9MrTk9PcXFxQVSqRTK5fIXyTT9mqBHZygUWkjv4Gg7k8ngp59+wj/+8Q9cXl6iVCqh2Wy+97JK4vju7i5isZjwsh56B5D2QNFoVC4vLP6AX9NqLi8vcXJyspD3y4436QuxWExi+MbjMRKJBDY2NjAYDJDNZlciQg942+ggT2wymUCr1YpClQKNfD7/yT9bq9VKl3Rvbw+Hh4fY29uTRJVqtSrJFl/Kb/Ahg53T2zZEX9tC56t2AJmjRw++25WssrX8ro1LmRjAjTSZTIppbSKRkPxDQumgz1EKH/oqHR7vgnKc53A4xNTydguaLetut4tsNouTkxNcXFyslJHubfAzs5tM4YIyqeKXX37B2dnZQkePtyu26j0eD2KxGHZ3d3FwcICDgwMkEgm43e4Fc08lSV/JJyWUcWpKhfCXiF/6kmDcGy9QtAzS6/ULeayZTEZ8C/lZOGZigWw0GjEajRZsOWitw840HfAdDgd2dnbw/fff4/nz56JWV47YlYR9dlMLhQIuLy9xfHwsQohGo7EyYyaCXSvaNNFCg7m0l5eXePnyJf71r38hl8uJBcq7/CpZUEajUezs7GB9fR1Wq1U879iJXqa19yVAYZ/D4UAkEpECkOt3Npuh3W7j+voaP/30E05OTpDNZlGtViX6i7xfirocDgcmk4lwidnRbjQa9/1xPwrsYvGCO51OodFoYLPZEAwGhcv8qWNgPmun0ykjS6YDtVotVCoVFAoFVKtVtNvtlSmYvyaU6VQ0KL9dA72r+FOeOV96v/uqBaDSOfs2J+p2gsNt5SoPVv4Zh8OBUCgkIfV7e3sySlH6twFvq+12uy2B2KugUPpU3I7Tug1a5VSrVRSLRWSzWZRKpZUeAQNvfb3II+XvuVwuL/BJ+fJwDEm1FikE+/v72NnZQSKRQDQaFSI1VepKuyHm3dJugs+bnTUapEajURGhLJNhMTOSlZ02pSrebDYjHA5ja2tLCj6aL1utVmxvb0v3ilwgo9Eo7+V8PofD4UC/35fLidPpXOi+kHNIzhaVd81mU0b0VNrlcjnp/uVyOfR6vZUsajhKo/ms0WjEfD5Ho9HA9fU1zs7OkMvl0Gw20e/337k/UaFOu4hkMomtrS0EAgGsra2JFZFSePNQ9jqeAzQjjkaj8r6ura3JRSSbzSKVSuHq6grpdBrVahWtVkvOHqvVikAgINQZv98vncDt7W1Uq1Wk0+mVidBT8vF6vZ4Ihfj+fYw/3W3wPKFPKBstVqsVvV4P2WwWb968watXr5DJZN67Zr9V8NJnt9vhdrvh8/nErJ1rL5lMSgPjtj8yqXHtdvuL73tftQCk6WKr1frNCI7cGJrqcnypzMWkeo7dmq2tLTx9+hSPHj2SDUA5AiBms5lsCtxYH1rx9zGgpUSpVEKhUEChUJCR6Ko/CxZn5P3Q+LTT6QCAHLTAr8WNz+eDz+eD1+uFx+PBxsYGHj16hK2tLbjd7oXUEK5BriN2BrhGlSox2kuMx2NEo1EkEgm0222xFlimApACjNvxeBqNBk6nE9vb2zAYDGKuTrqEyWTC7u6uxMMpO1GBQAAA5HI2m80W6Alut1u4fyyA6J/I5JRMJoObmxsUi0UpBIvF4sL/p0z6WSWwALTZbBI9NZ1OUavVZFxZKBQ+WLBxv3S73QiHw0gkErJ2tVotms0misWicCW/lGpwGUDut81mg9frFYUk+VP9fh+lUgnpdBrpdHohUpSKz/F4LPQRs9ksUYSkRezs7KBUKuHFixcrM0ZnB5Disy8hYGHxaLVaEQqFcHh4iM3NTdkX6HV6dHQk3owqFkEHlHA4LF6ryWQSLpcLFotFMrs5ubzLCJ6pNd1u94ueIV+tAKTrNjszoVBoYYFyQ3O5XNKuLpfLsnHdXojMKmVqAdv/yoOMPlo0smXR02g0Hlzsz4egTKggMZredMsoUvhUUO3HTh8LHJLjjUajeDZxbMTxo9/vRzwex97eHmKxmFiO8GeyM8WXMJ/Po1QqwWq1yi1bqeCiITVvdsxyrNVqS5M3yoKWXXlGRSk/QyQSgdlslgOFG49er0c4HF6gGShHm8pim/+M3lbk/vLf3Wq1hKNbq9XE2Pfy8lKizljQl0olEX2s6oVF2QGkVxspKtVqVQRK79rk2f1ibF8sFpNRfDAYBABRvmezWWSz2Y+K01wl8N2maXsgEEAgEBDbq3q9jlwuh8vLS4kTpS2JMsPVYrEgm83KHjAYDGSNhsNhETmsCkjdYJOFSmalXc7Hgpc6i8WycMmg4KtWqyGbzeLq6kr8KRkJ+RDAiy3tWJR1BRX27I4y1pKCP4LfS2eJjY0N7O3tiVMJnSoYTKFsJPDc6fV6kopGA/gv+Yy/agHIlvHV1RVCoRC63a5I1ckVCofDMiL2+/2iBmYXwWQySdufCjpupLe9/ZQ3wVevXoktRblcXtkOwu8BCxouVCW/aNXBIoMd4Ol0KqNdpspQ2cuRJjlY7CTQi43riMaelN7TeoQcUqfTKWNmFpRMD+E6JZm81WrJ6HIZLGFopsvRTiwWE2d6vk+M1HK73QsHCV3smcKh9KHk2JxeWcqCXKPRCFeQFxB+UaBTr9elQ83xPZ0BVk3wcReUueROp1O6yFxz7/OYBLAQ9bi5uYnDw0M8fvwYsVhMDJGvrq5wcnKC8/NzKZofysEMQKgX7CYzQ5pc3WazifPzc7x+/Vr8/G5PfJT8Ul7wZrPZb+LUVklJPZlMhP9cqVRQr9fFkJmfTemB+z56AXlqzAZn4RIIBDCbzZDL5fDmzRtcX1+LwfQy7GtfAnxGSk49PRUByLNhF5rep+RAElyjPp9PBIbsVgcCAZhMJskPvk3b6vf7MgmhE0M6nRZR3ZfCVx0B03jXZrNhY2MDnU5HDhlyOljk0TOQG5cyO8/hcMDr9YpBLE0Xb99wer0e0uk0Xr9+jaOjI7x8+RLX19dfNEpllaCMsWHxt+oHqhJ6vV44ok6nE/F4XPh35JNRxerz+UTgQeNNXiKU495msyn8sx9//BFnZ2diJ+H1elGtVsWegw78ykSLzc1NTCYT5HI52O121Ot1KYzuE4PBAKVSCdPpFD6fT+xzyOVjh8VisdzZcaPQRUlcVlppMOqMY3nyk1j8ZTIZZDIZZLNZZDIZFItFEY+QY6lMyOB6XXVwH1OOwNlJoMiG/GflfsbikIT+jY0NPH36FP/xH/+BnZ0dGcdXq1UcHx/jl19+wcXFBSqVihRADwV8x1n8kUtFAUi73RYBGP0nb39+Xlq4H7JAVDoq3PV7WGbwwjqfz1Eul1GtVtFoNKDRaOQSxoKD0Xfch5STOK5F8tOePXuGH374AclkEl6vF5VKRc7Vy8tL1Ot1uSAuazOBv0Plf971e1WK+MivpbAUeNto4P5IGlGn0xGPWfIumWYUi8UQCATEdogdP2Uxrvz7UKh5dXWFf//733j9+jXOzs5wc3Mj+oUvha/aARwOhxIoXyqVUKlUYLPZFl5gVsOMSlG+mMqWKqOg7gJf6kqlgrOzM0kiyGazonh9SBvix4K3lGKxiHa7/SAOVKXyzWq1ygGrJNq73W4pRqbT6UJxw82RP4Nt98FgIF6J2WwWp6eneP36Na6vr9FqtWTUwtG6xWJBMpkUSxRGrAGQrEiPx4N6vS4FzX2CI4Z6vS5eifP5HNVqVawyOJYgT1Cr1Upnj7YlLKr5TtFZn6HwStEX8Os+wOdaLBZRKBSQz+dFnEWS8zIfJr8XSr7qbDYTUZLf70c4HJaOJxWCNPhlZF8ymcTGxga2t7eRSCRgtVpFjXl8fIzXr1/j5OQEpVJJKB4PCcp3nGcGBUw0Yq9UKrLPKbuf7MCSShSLxRCLxeDz+aSjreyUrRK4j2k0GjSbTaEUsEiOx+N49OgRZrOZCGI4Jh6NRuINSgqL3+/H/v4+njx5InzgYrGIs7MznJ6eik3RMq8xZbeOa8ZsNgsHl2bLhFJsSts6WisBb6lo5PGyC93r9eB2u1Gr1aQAJP2ItnfsUr/PnkkZSffixQu8ePECl5eXKBQK7/QB/V3P54v+tA+ABNVarSYRR1QAKr3WKO9X5jkqq2QWgndBqYTK5XJ4/fo1fvrpJ+RyOeH+3Xf35T5AnlEqlcLl5SUqlcqDGAuRINtqtcRuheMbdvd4WWAniy8xlcNcl/TJKhQKKJVK8sX/TQEC7WRarRaurq7Q6/UQCATwww8/iPUCx8A6nU68L0OhEGq1mow177vAYZczl8thPp+jWCwiHA7LgchDIBgMSqwex2Y01200GgtFM9X25PPRF4yUC3KBWeTwWSgL9FXm+H0I5GlxTMcphsvlQjKZlJi8yWQi9iM2mw3JZBLJZBLxeFwiCtltrdfrknJzdHSE169f4+bm5kFxst6H21m/vNywuFHu9/QODIfDODg4wPfff48nT54gEonIhe12V2yVwMtZt9tFpVJBqVSCzWYTn1S73Y6dnR2JFszlcjLFMBqNMg0grzSZTCKRSMDr9SKTyeDo6Ag///wzXr58iVQqhXa7vdTTNL1eD5fLtUAX8Pv9SCQSory9DSUVgzZBymYTx7W0+jIajZhMJggGgwvUMl6e2fGjwONdxd9sNkOlUsGPP/6IX375BS9fvsTJyYloIf6IptVXLQB5s2fOJTkKTqdTcvfY7TOZTHIjex84rlOSe9nlOj4+FusIWnA89OLvfYcnTULp0s4FzN/LKmIwGCCfz+Ps7Ay1Wk1ioWw2G+x2O8xm80JXmYcrDwdygCqVipDmOZoslUpicMpEAKVwiYXLZDJZeEmVUT7s7rCDw5SG+wY/w3g8lqK0VquhUCggl8sJMT4UCglnxWg0yjPjJY6dOwpdmINcq9WEtKwsAJX/buX4eBUP288BJyHkPjJZhYbG5GoxwxqAeCcyQYkJKVT1ZzIZvHnzBi9fvpRR0SpFcX1psPvv9/tht9sX3llmsHKEfnBwgO3tbTidTphMpgXbDQq8VukZsgDudDooFAq4ubmRfPNgMAiz2Yz19XUxYPf5fDKNo/VTMpkUI3yqpKfTKUqlEo6OjvDixQtcX1+jWq0u/QVDKZYKhULw+XwinOLv/S4om02kqL3rzxE+n++dP4dgg+o2D38ymWA0GuHs7Aw//vgjfv75Z7Ev6na7f9g6vJeTiAXbaDSSbgoXknID/BABV6nqorq1Uqng/PxcskjZoXnoI6XbN2CGoysXHz2uaILaarWg0WikE7OKxXGj0cCLFy8k4onFlsfjkfgjv98vI01G8HAE2Wq1ZHyWz+elqKFCWumpdZdgRnkBeReUqrFlg9LbkCMkRq6xaCXJnupdjsuVxZ0y+m44HMqzo/BmVS8YXxosADudjvB5NBqNCJNMJpPYEvX7fQCQUZTf74fT6RReJkdFr169km5BsViUeKmHut+9D1qtFn6/H//xH/8Bu92+YP0CvFWwKzuqHM3pdDpMJhN0Oh2Uy2W5GK0ayINMpVJyllJ9T2EH4wM3NzdlvzMYDGJQTq86AHIJvri4EFuhVQlRsFqtSCaT0uVlhJ3f74fb7b6TRqbkCZLO8zE80I/5M8ruf6vVknOY3rWZTAYXFxe4ubn5Kq4R91IAKrN52UUhiXQ8HovM/y5PP36/sohUcgtvbm7w888/429/+xsuLy9Xurj5VNDfiiHmSuNPRvGFQiEkEgk0m80FPuSqqqKbzSZevXqF6+tr4a/Y7XaxEorFYpKNTFpBuVzG8fExzs/PxSKCnCG+dBxH8utjDtRVHB3xXeS7RDsSrh0KZDi+UF4y2FHn5YrP6PZz+xbevY8FnzcLQL53FotFFNfRaHTBUkJJ3idntd/vo1Ao4Pj4GC9fvhR+6pfyf1t2vOtd02g08Pl8+P777xGJROQCp7QwCgaDiMfj8Pl8v/HB5AWoWCyiWq2upH0O6T6ZTAaDwUC6YOSQut1uscIiXarT6WBtbU2mJnweVOVfX1/j/Pwc6XRaIhhX4bmYTCZEo1E8evQI8Xhc/PbYDLjrUn57KqEsdDmhJD4mblHZvePv5fr6Wjqv5XJZKHFKIRwv1H8k7qUAHI1GKJfLuLi4AABJp6BVjNFolJxMAFKkKAu/6XQqfCMSXhn1Rnl6qVS6j4/31aFs+6fTaZycnMDv94vFCXkK7CSQo0X1Fp/jKtzobmM8HkvKBsPKbTabjCF5uw0Gg5J4obzNkoZAb0T6zH0KyGuj/xZ5rauiHlQWbcs+0ll1sHjjhbVYLEoUHrsNZrMZLpdLDqfxeIxWq4VqtSpCmXq9jrOzMxwdHeHi4kJI4stKxv+S4EWXzQM2EKistFgsWF9fh9PpXIhFA34tphl/yOQFAKJar1aruL6+lshBpa3HKoH70Ww2QyqVwqtXryS3dzabwel0wmq1wul0yhnACxtDExqNBgqFAtLpNK6vr0XlyoJ6FQpAJecYgDhEvKvw4+en7RS/OK1U2ulQkHpXrK2yVlE2u8rlMq6ursSfslaroVqtSiHIZ/u1Ovj3UgD2+32k02mMx2Pk83kx5CUPyel04vnz59jc3BTSNDuEytFxq9VCPp8Xc2eKASqViiRAfAugqrRWq+Ho6AgajQaJRAKRSETUp4ztoreiMimjXq8v5XjyU0DuDru97LLUajWk02kp/pgVTM855QHyubd9rtFyuYxyuSwd7FXyEFPxdUCLB/KzGF9JRaLZbF4Y1wG/8lzT6TQuLy+lS10ul5HL5XBzc4N8Pv9ZF5dVxXQ6lb1eqWSluItxjFT7K4sVirOoaicnmJfim5sbvHr1Cj/++CNOT0/RaDRWotC5DSWXMZfL4ccffxTngm63Kz66DodDBJWtVgulUgn5fB6pVArX19eS8ctuFS8Zq7LWaAXn9XrFQukurp7SI5dm//l8Xsa1tVpN6hCKUMkzDQaDd46SlWI3vrO052m32+h0OlJgsuD8msUfcE8F4HA4RD6fR7VaRTabRSAQgNPpFCWm3++XQ3w2m8lLTqIkOUiNRgOpVArpdFo6WCyGvqVOBhdvo9HA8fGxjPFuE/C5SCmDJ2fuXaP2VQI3PL7ELPKo5lLaOnCN3B5dfu6LR85NNpvFzc2NjAWUGyu9J9WR6LcNTjmozvT7/eKfSGsJmu3yfa1WqxK5lUqlJBWA3W1ynL+VdcXpD4WE5KjZ7XZYrVahLLDYU47tlBYvyiScSqWCTCaD09NTvHr1Ci9evEA2m12ZUedt8EwYDodih8PuEr+UosvZbIZ6vY6bmxucnp7i6OhogVOq5EGv0qSIBTBN6iORiFywlFDS0orFIi4uLnB5eYl0Oo1MJoN8Pi/2X3QtcbvdiMViSCQSd6qJB4MBOp2OeMlSnHUXvUiZVf81cW8cQGVLlcULfcmGwyGOjo6EzKyMtOHDI5GaVh2rymH7khgMBiiXywtq6FQqtZByAfzqNUSX8Ww2+6C6B9z4viYmkwlKpRJevnyJbrcrKk1uMs1mEycnJ0ilUuLkvoqHiorfD+59VGm+fv0azWZT1OEc/7rdbukANptNOZCKxaJQHjje/BbGvkow8aJer+Py8hIOhwPtdlsEDHyWSpNdvV4vvC/grRqT5PvLy0tcXl7i4uICp6enKJVKD8JAm6ItmtEbDAYMBgMUCgWcnp7C7XYLt7fRaCCXywlHLZPJiO3Vqu5ZTDxiR30wGODNmzd3TrzYOKrX69IBLJfLKJVKaDabwtmlMpgNlkajcWdkIBtVpLiVy2Xxjl2WZ6mZf+Tf5I/qEPEXQwUWxQsulwtWqxUApC2q7NSwcOQI774LmPfF6nwtKE0qTSbTgh+ekvDKEQr92+5TKPO+5bcqXUnGGJJzSU4InzcPfGUn8Et0qJdhza0i7nvNkevHKCmLxbLQmVKmUAAQZSqtiLh+7sMwexnWHH3aGLe4vr6OSCSCeDyOZDIJl8sl7yAnAOR+sVPDUV82mxXe9MnJCW5ubkSVSX+7L/F87/O5sQhmh9lqtcrZwOkPLya0cuK5wCncfZ2vv/ddVSZ3MKmJAQC3wXeJEyRSgygK5LOgOpiXinfRfZT2dEqe5dd6Xz/m33PvBeBDwTJsjKuI+z6MVxnqmvs8qGvu87FMa45+sWazWdTTsVhMIh7p/8kC0OVySeAAjbbp/Xl1dYWrqyuUy+U/pNhZpue2SlDf1c/Hx5R29+9Iq0KFChUqVHwi2K1R/vdqtSodGaUF1vtGwO12+7MdAFSoWGWoHcAvBPWG93lQb3ifD3XNfR7UNff5WMY1xyLvtmnvXX5tt1MZlF9/5HhuGZ/bKkB9Vz8fagdQhQoVKlQ8aJBUT1WvChUqPg6rbf6mQoUKFSpUqFCh4pOhFoAqVKhQoUKFChXfGNQCUIUKFSpUqFCh4huDWgCqUKFChQoVKlR8Y1ALQBUqVKhQoUKFim8MH20Do0KFChUqVKhQoeJhQO0AqlChQoUKFSpUfGNQC0AVKlSoUKFChYpvDGoBqEKFChUqVKhQ8Y1BLQBVqFChQoUKFSq+MagFoAoVKlSoUKFCxTcGtQBUoUKFChUqVKj4xqAWgCpUqFChQoUKFd8Y1j72D2o0mj/y77HyeJedovrc3o/32VCqz+79UNfc50Fdc58Pdc19HtTn9nlQ39XPx8dYPKsdQBUqVKhQoUKFim8MH90BVKFChYqPgUajgVarlRv6bDbDfD7/qBupChUqVKj4OlALQBUqVHwRrK2tYW1tDRaLBW63G263G5PJBI1GA+12G/1+H4PBALPZ7L7/qipUqFDxzUMtAFWoUPG7odVqYTAYYLVaEQwGcXh4iEePHmEwGOD4+BhXV1colUqoVqsYDof3/ddVoUKFim8eagGoQsXvgEajkS+dTgetVitft0eh8/kc0+kU4/EYk8kE8/n8QXTDNBoN1tbWYLVa4fF4EIvF8OzZM/z3//7f0e12YTAY5HO3Wi21AFShQoWKJYBaAKpQ8RlgwafT6aDX62G1WmG32+FwOGCxWGCz2WA2m2EymWAwGAD8WgBWq1Wk02kUi0V0u130+31Mp9N7/jSfD51OB6PRCKvVivX1dcTjcezu7mJrawuRSAT9fh/VahXdbhftdhuFQgHdblflA6pQoULFPUMtAFWo+ESws7e2tgaTyQSLxYJAIIB4PI5IJIJQKIRAIACv1wuv1wur1QoAmE6nODk5wf/7//6/+Pnnn1EqlTAej1e6AGTnz+v1YmNjA48ePcL+/j6SySQcDgfMZjO2t7cxHA6Rz+dxdnYGnU6H2Wz2ILqfKlSoULGqUAtAFSo+ATqdDhaLBRaLBS6XC263G4FAAOFwGNFoFNFoFOFwWApAj8cDq9UKjUaDyWQCu92OTqeD8XgMnU6HXq+H8Xi8sipZnU4Hs9kMl8uFSCSC3d1d7O3tIRgMwmw2w2g0IhwOo9FowOVywWAwQKvVruRnXXasra3BbrfD6XTCarVibW1N1liz2USr1cJ4PJb1pmJ1Qc6t0WiE2WyGxWKB2WyGXq/H2toatNr3O7zdpqVMJhMMBgP0ej0Mh0OMRiOhqkwmk5W+pH4MjEaj7NdGo1EmPMR0OsVsNpPnMhwOMRgM5D/H4/FKXmrVAlCFik+AwWCA3+/H+vo6NjY2sLOzg42NDXi9XhkBW61WWK1WmM1mGAwGaDQazOdzaDQaeL1ePH36FAAwmUxEFMHNdtXAAtDpdMLn82F9fR3hcBh2ux16vR7T6VQOKBZ/Kr48tFotzGYzNjY28PjxYyQSCdjtdphMJuRyOfz00084OztDo9FAq9VaybWm4i3W1tbgcrnknYvH4wiHw3C5XPLuvQ8sACeTCfr95nxJpgAASqNJREFUPjqdDkqlEnK5HAqFglwY2u02er0e+v3+yhU3nwKHw4HvvvsOf/rTn+D3+2E0GhcK6dFohF6vh3a7jVqthkqlglKphGKxiFKpJC4Hq/aM1AJQhYqPgFarhV6vh9PpxPr6OnZ2dvD48WM8f/4cBwcHsNlsdwpBlN0ujUYDh8OBra0tTKdTpFIpvHr1CtVqFdPpVArFVQKfi8VigdPphNfrhc/ng16vh0ajwXQ6hVarFb6kUhSj4suAPEyXy4XNzU38+c9/xtOnT+F2u2G1WnF8fIxOp4NqtYrRaIRut7vUBSApFvzS6XR3rhmlAEv5Pe/DfD7HZDKRTuh0Ol2pd46f02KxwOfzYWNjA7u7uzg8PMTm5iZ8Ph98Ph+MRuMHf45Go8F4PEa73Uaj0UA6ncbp6Smur69RKpVQqVRQLpdRrVal+0VPz4cGi8WCnZ0d/I//8T+wsbEBk8kEk8kkz2k4HKLRaKDRaCCXyyGTyeDm5gZXV1eyr7GTukpQC0AVKj4CNpsNfr8f0WgU+/v72Nvbw87ODtbX1+Fyue7ccO86tFhEulwuWCwWGdOxMFq1zfX2Ya38PPP5HMPhEJVKBcViEc1mE6PRaOUO3WWHxWLB+vo6kskk9vf3sbm5iWg0CrvdDoPBAIfDAZvNBpPJ9FHjwfuERqMROyGr1SqjTYvFcuffW6/Xw2AwwGAwyJ9Tju6UmM1mGI/HKJfLyGazKJVK6HQ66Pf7K7EeNRoNTCYTzGYzAoEAdnZ28PTpU2xvbyORSGB9fR1Op/OjOoD8edPpFEajESaTCTqdDgaDAT6fD7VaDbVaDdlsFldXV0ilUtINfIiFoEajgdFohM1mg8PhkNE69+XxeAyDwQCLxQKj0QiHwwGXyyVf5+fnODk5WZm1RKgFoAoVH4BGo4HdbsfGxgYODg6wv7+P3d1drK+vw+PxvPPAuQvs1nA8bDAYpGB6iOj1eigUCkin06jVahgMBmoB+IVhs9mwsbGBZ8+e4eDgALFYDF6vVy4lXGccay1rB5aXCb1eD7fbDb/fD7/fD6/XC7fb/Zv3jPQDi8UCu90u5uN3XcbIse12uzg+PsY//vEPsWEaDocrwXHjmN/tdiMSieDw8BA//PADkskkXC4XbDab7Ccfer+4BrRarXDeeDmNx+PodDpotVq4vLyE2WzGfD5HLpcTzuCqcpbfB36m259rPp9Dq9XKBcpkMsHr9SIQCMiXXq9HoVBAsVhcqeey9AUgRwB6vR5msxlms3nhFqvsmpCEqSSzkrypQsWngv52vBVvb2/j0aNH2NrawsbGBtxut9ycPxbslBkMBuj1+gXPwIeG+XyO0WiEZrOJWq0m4pdV2iBXAWazGdFoFLu7u0gmk/D7/bBarZhOp+j3+/Kl9J9cRtBOyev1IpFIIB6PIxQKSRG4trZ4XFGQxa4NVfcmk+k3P5vFXrfbhcViQbfblTOi0+msREdL2R31eDwIhULY3NxEPB6XFJ679hFlYaOkoyj9SznydDqdmE6nGA6H6Pf7sFgs0rU3mUwwGo2oVCrSOX0oZ+tsNkO73Ua5XJaO+Wg0klqDz9ZoNMJoNMoUh13qRqOBly9f4ubmZqWmHEtdAPKBm81meDweJJNJbGxswOVyyS+GC346ncqGl06ncXZ2hnw+Ly/5KvwyVCwXqKp0uVzY2NjA/v4+9vf3EQwG4XQ65Ub4ucXbQyz6boMH7yocsKsKg8EAj8eDcDgMj8cDk8kkEXzlchlXV1dIp9MoFArodDpLyVMiPzaZTGJzcxNbW1vY2tqCz+eD3W6HzWb7TZf8thKW3pu3C0UAC8VPJBLBd999h+l0il6vh3q9jtlshslkshIFjbJwU3KO37WfUMnLz8eOFr9Xr9cLz5JfwK/PNxQK4cmTJ3A6nYhEIri8vJSRcKlUejDjYNYNP/74I2q1GjweD7xeL2w2mxR5RqMRer1enpder4fD4cB0OkUoFEIwGITH40Gr1UK3212JrvLSFoDKdAGXy4VEIoE///nP+Otf/4poNCpyd2UByG7Djz/+KK19ACJjX/VF+ilQplDcdQNU8WHodDrY7XaEw+GF8S8VrR8q/m7fuB8qbh8c7+pCKP/zW4byWf3e50JeGAtAjkAnkwlqtRqur69xcXGBq6sr5PN5DAaDpSwAtVotXC4X9vb28P3332N/fx/b29twuVzQ6XTvLOruEovMZrN3ikbW1tYQDodhNpsxm81wfX2Ny8tLDIfDlbLx4CiWX8qJmBI8B/v9PobDoRSBPD9ZQBuNxoVJBovDQCAAh8OBRCKBWCyGQCAAs9mMXq+HRqMhHeVVf6/7/T6ur68BALlcDuvr6wiFQiKqcbvdUgwqn5fVaoVOp0M4HBb7r1WiFSxlAWgymYRkGQqFxHLju+++w8HBAdxut4yS+NJzE6ClRqfTgdFoRDqdRj6fR7vdRrfb/SZiqPR6PXw+H4LBICwWC1qtFprNJjqdjnREVXwYvB2bzWbpBN7uPt/e/NjxGo/H6Pf76PV6mM1msNlssNvtdx5kqwyDwQCXywW/3w+n0ympJwTHSXwvV/2g+FwouzbkrOl0OvFc49fHPiOOA00mE9xuNzweD5xOpwiLRqORKH8rlQqazSa63e7SFTjkoFksFvGRPDw8RDKZxPr6upioKzGZTGQ9jcfjO38mi0Zld4v/zGAwSFeRNk2rAuXeUq/XkclkMB6P39kBnM1maLVaaDQa6HQ6GI1GmEwmImjgCJ3Pgr8Pu90Ou92+wLGkErjRaOD6+lr+3cu2pj4Ho9EI5XIZk8kE9XodxWIRPp8PHo9H3i3u4fR9dbvd0Ov14iG4sbGBYrGIyWSyMpGXS3kaWSwWbGxsYHt7W77IbXG73RiNRigWi2i32/KSs1NosViwtbUFk8mESCSCly9f4uTkBOl0GrlcDqPR6MEfQkajEbu7u/iv//ovhEIhXF1d4eTkBKlUCvl8/pt4Bn8EbttO8P9TggVPt9tFoVBAPp/HZDJBLBbDxsYG7Hb7nd+3qjAajQiFQtja2kI4HIbFYln4bNPpVGLg2GX5FsHuk8lkQigUQiKRgMlkEmsJfs1ms4+aVmi1WlitVvj9fkQiEXg8HthsNulMkH/Ji+8ydv2AX2kWDocDfr8fm5ub2Nvbw/b29jvFHABkvN1qtYTjrey06/V6SejhF7tbfC6dTgftdlvMfFdlQjSZTKQIu7q6EkXquzrJ9BrN5XKoVqui4mV0pcvlkiKH6mG3243NzU2ZcgBvBSgU5pAC0+/3H4RR9GQyQbPZxGAwkMKaUZ4c/5Ij+fTpU/zpT3/Czs6ORH+63W7s7u7KXpfNZtHpdO77Y30QS1UAsu3MBfjs2TNRXAYCAcznc4zHY5RKJZycnKBcLosNgNfrRSQSka7Xzs6O2CBwU2RHhnyIhwqDwYBwOIznz5+LQgwAxuMxms0mGo3Gvf79VgXK2za9sur1Okwm04LPndJbrNfrScc1k8kgnU7LobS+vi4F4EOB0Wj8zfiRoA1Mo9FArVaTbui3iLW1NUmPicfjePToEVwuFyqVCiqVCtLpNLRa7ULawLsKEhaTbrcbGxsbwpVTrkty2ljgLOtz1+v1cLlciEajiMfjiMViCAaD0Ov1mM/nkrLALxZvtVoNrVYLg8FAulqE3W4XQQhH5Mrx5mQyQa/XQ7fblZ+/zOIYgr/Tfr+PRqOBTCaDyWQCo9H4TrHjZDJBpVJBNptFpVJBt9vFeDyG1WoVIQMTMOx2O4xGI/x+P0ajEQCIoIiiNRZEtzurqw6KR4fDIVqt1kLHnhMfnU4nIhm73Q6LxYL5fC4d7HA4jHK5LJ3BVcDSFIBKhVM4HMbOzg4ODg4Qj8fhdDoxn89Rq9VQKpVwenqKn3/+GdlsVjqAwWAQiURC/JDC4TD8fj92d3eh1+ulFd7r9UTB9NCKQC5YElVp0Ovz+ZBIJIQTtIp+c/eB6XSKdruNfD6Py8tLvH79Wg5eu90OrVYrsUDNZlPGbaVSCbVaDc1mE81mEy6XC+FweCVGAp8KKgitVqsctMpDod/vo1gsIpPJoNFo3Dmye+jQaDQwm83w+XyIRCLY39/Hd999h2AwiG63i1qthlevXknRwwSGu/Yn0l0Ysff8+XM8e/YMoVDoN+PMVXjH2XEix8zhcECn00lXqdVqoVqtolqtolaroVqtotvtAnjbtanVauh2u5K2Ew6HhUPIcSYPZF7WBoOBHPhKccQyQ9m95EWhVqthbW3tnR04FrukAFGhCrwtelqtFsrlshR5LpdLuoZbW1vY3t5GJBKR0TuLZ1KqVr37p8TtTio78nzvdDqdZJqbTCbhrvJSRmeHVcFSFYB0sw+Hw9jc3MT+/r7kh/Z6PZTLZZydneGXX37B3//+d9zc3AinIxQKoVgsyhjF7/fD4/Fgc3MTdrsd7XYbmUwGlUpFXp6HUgDytkIrAKPRKH5QBoMBbrcb0WgUhUIBNpttYcP4FALvsm+QXxqTyURGl3a7HW/evIFWq8X6+jr8fr9wTjudDgqFAlKpFG5ubpDJZFAul+VGnkgk8OjRo4Uuw0N4lkoLCY5JlJ2W+XyOfr+PUqmEfD6PZrP5YN65T4GyAEwkEtjd3cXjx48Rj8cxnU7RaDSg1+ulw8xD+0MFYCgUwtOnT/HkyZM7bVKUCuxl7gByf2KEoFarleitSqWC6+trXF1d4ebmBtfX1xgMBnIulMtlZDIZ1Go1zOdzrK2tSeFHbqTysysLwH6/L+fAKhQxLADZBWw2myJEe9d+osz6VY7L+XMo5lCqia1WKyqVinDnaYLP59btdsVaiCrghwo+P057+v2+rElSMDY3N6UjumppR0tTAGq1WtjtdkQiESQSCYRCIZFYl8tlFItFnJyc4NWrVzg5OUEmk0GxWJTChy3r+XwOl8uFra0taV/P53OEQiHEYjFUKhVMp1N0Op2V7sjc9kekVJ2LMh6Pw+VyiYDB6/VifX0d29vbC5YAo9EI/X4fg8HgvS8yDxNuIt+CvyIpB5PJBOVyGefn55hOp8jn8/B6vULi7/V6KJVKyGazyOVychFhceT1epd6DPc5oIKQ+ccOh+NOC47pdCq+nN8a91TpI0kng729PSSTSQSDQXi9XgCQYoXq8g+N1uhLZrPZEAgE4Pf7pRtBKC1OWq3W0gpwePGnunI2my1QLjKZjKQs3NzcIJ1OYzKZwOfzwWw2o1wuI5fLSTFksVgWUhtuj+Jms5k0EzgSXYXij1AW8zzzfs/PUV4y2DDgSFkJ5TiUnUIWPQ9BBfwhKC/uPHfveleV/PBVwNIUgDqdTrh/W1tb8Hg8AIBqtSpWBicnJ3j9+jXS6TRarZbI/SeTCTqdDnK5HDQaDaLRKBqNBobDoXTFPB4PYrEY6vW6HNgcJawi6IJvtVoRCASwvr6OYDAIt9uNQCCAvb09BINBMYQdDoeIRqP47rvv4Pf7hU/TbDZRLBZRq9Xeu6GQD0dF50O/+Skxn8/RbrdxdXWFRqMhxTbXHlXndM+n4fFDHbXz0HY4HPD5fMIhstlswkFT8VbharPZxLT38PAQiURC1K3KLh0vVe97r5S5t1Qg3pXwMR6P0Wg0cHNzg0KhgH6//4d/3t8Ldra0Wi2q1SpKpRJSqRTOzs5wcnKCQqGAVqsFAMLp5rum0+lgs9kkJSMWi0k0mnL8Ox6P0Wq1kM/nhaj/rexjHwJjAwOBAJLJJHZ2doRDvra2hvl8DpvNBpfLBafTCYfDIbznb+UZ0hosGo0imUzC5/PBYDCs7OdfigKQJHmSmhOJhDhxs/P38uVLnJ2d4fz8XMYkvHnQ4Z23lnK5LHwHFoDsjAUCAaTT6U9Kb1gWKO1uLBYLnE4nPB4PNjY25GWlSisYDMLv98NsNssIIBKJQK/XIxaLSeevXC4jlUohl8u9tyPK9jf5H59qsH3biHTVQL5LoVBYCJ1XrkGlmSzXtNJw9aFAaRXhcrngdrsXLEg4NuFFYRUI9n8EGB/ldDoRDAaxsbGB3d1dhEIhmEymhfEcv8bj8Xu760ouEukejPJSgkrZQqGAcrm8tBmlLMrY9atWq2i1WkIbuLy8xPn5OS4uLtBsNjEejxdSGdhZpx3ObTEJKS98R2nRwW59u91e2cP790LpaKDVamXUS8eCnZ0dxGIxOJ1OrK2tYTabiek2vwwGg1jyPFQou5+0gYnFYojFYhJRyIbKqtld3XsByJus1WqFz+dDPB6Hz+dDv99HtVrF8fExXr9+jdPTU2SzWTSbzfcWKkpTUBZ5tKIol8solUpotVorxUXi5k6RjN1ulxc1FovJphcKheByueBwOMTDiSNitq6tVqvEQrFLEIlEUK1W3/sSky/CThd5IB8DeiwVi0Xhga3apstx+ceOXThOoXfUbX7cqoPcW46MlKpAktOLxSKurq5Qr9eFbL8qG+OXAMe0vJDRUJajSY7Hu92ueLU1m81Pith6F9+IxSUPpWV933jJPz4+RrPZhNPphEajQb1eR6VSQS6XQzabFcUvL5DsOpE65HK5RDj45MkTsSRiR1rJ/avVajJO5iTpW4NSKEg/SYpn2ExIJBIIBoOw2+3QaDQYDodot9siKOl2uw9+EsRxLz2JE4kEnj17hq2tLXEcmU6nqNfrYir+obN0mXDvBeDa2prw1AKBABKJBNxutzzMly9f4ujoCJeXl8IjugtKMrrFYpHih0rNarWKy8tLnJ6eolgs/i7+xNcGb2lGoxE+nw/r6+vY2dnBo0ePsL+/j0AgAKfTuSDXVyalMLLMYrHA4/HIQcyihmPd973IHLPTP4tK6o9Bt9vF0dERfvzxR+lEPHTeiLL7Qxf5213AVf78yi7UbfLzYDBAJpPBq1ev8PLlS5RKpQW15bcCHhxMFXC73dI1Ial8MBiIkrVcLqNWq2EwGPxuXpqSs7vMHneM4Op0OrBardKt6/V66Pf7C2pTrh0mLWg0Gin+4vE4vv/+e/z1r3/F5uam2OLwnWMx3O12hcSfyWQeVJ7tx0LpuEEjaKfTif39ffz1r3/FwcEBPB6PcMi5XtmUyefzYi7+LrHSQwCnODQp/+GHH/D06VPJXyYPfDweo1wu4/Xr13jx4sUHp2nLhHstANl2ZtTW5uYmgsEgjEYjms0mrq+vcX19jVwuJ+rdd4GFJF3NlYR0epHlcjlkMhnU6/WVKgB1Op34rdEsdWtrS8jk9COaTqeyYJVdAR7QdxmrKo1n2eq+3VXgz+aG3Ov1PilRhF2Ny8vLB+Ud9S4oKQ3xeBzJZBIejwcGg2HBNkbpQ7bMh/RtkIJASgUPbv5eR6OReHVeXl7K+/atHLR8fywWi3D/YrEYPB4PjEajFCWDwQD5fB4XFxdIp9Ni1Pu+A5XvscVi+Y2/Hcepo9FI8kgZ/basa2s8HqNer6PT6UinjqPa239vXjrMZjPcbjdcLhcCgQCCwSA2Nzfx/PlzHB4eIhAIyN7Pn0Ork1QqhfPzc+Tz+Qfth6q8oJlMJmmIABArK6/XK5dTu92OnZ0dPHv2DJubm+IpybVMylCtVkOlUkGr1ZJJ0kOFsgD0+XzY2dnB06dPEQ6H4XK5YDQaf5PKcn19jVqttjLP5d4KQCrZPB4PHj9+jO+++w6PHz+G0+kUM0b6O33oYVJtR6NVxtqwOlfK3SkCWaXDiCrCaDSKR48e4U9/+hPcbjcmkwmur6+lqPB4PNjb20MsFvtoI0oSr0ejkSjn7uKrkffFjqLdbv/oLgWd0nkjZ0dzWQ+l3wOuRbPZjEAggP39fTx+/BjRaBRGo1F88VKpFIrFomykq1QA6nQ6eDwe7OzsYGdnBx6PZ2HNkH+WzWZRKBTQ6/UefMdXCb4rLpcLsVgM+/v72NraEr4Q0Wq1cHJygn/96184PT1Fs9n8oACEAfRer1e4WQTdDXgYVSoV4UIv637H0exdcYrK/4+f3Wg0Yn19HY8fP8b29raI30KhEEKhEJxOp1wyeWltt9tIp9M4OjrCq1ev8OrVK9Tr9fv4uF8NPA+pFKcgBvi1Mx0Oh5FIJERBrtfr4fF4EAwGZWJ0uwnAZBk6aDz091mZKsNJjtIwW6fTiZ8kVeuNRkP281XAvRSAylikQCCA3d1dfP/994jH4zCbzcKFISH4Y0iV5Ft5vV44HA4pVjguUPLXVomkqfRHDAaDWF9fRywWAwCJt2MRQX+69fX1jy4A2S1otVqiWHxXBNPtQ1x56CtD2W93EEej0cKm8pA7gCz+nE4nIpEItre3sbOzg0AgAIPBIKrrdDqNcrksPJpVWY/Ar79rh8MhHFSa9xKkCyhHmh8qQNh9BhaFNasIcpqZTrSxsYFoNPqb59Tr9XB9fY1Xr14hlUqh0+l8cB3o9XrYbLaF6C5iPp8L15lWROwALjPuUj4rs3vZyeKokrnwz58/x/r6OgKBABwOh1z6lWNfpmZks1mcnJzgxYsXYh9GA9937UdKYdeqpIUAi2dGIBAQ8ZHSdiiRSGBrawuBQEDWpDL14vYzUWacP5T834+B0uCZnVTSrIBFFT+AhT/Li9cy+3DeSwHI/EeXy4VIJCIRbmazWTawQqGAQqEgdi7vA8ctHCWHQiGYzWZMJhN0u10Zt5FHsgovsRLsAFosFlQqFfz8888YDofIZrMolUqYTqci8PiUmxm5NtlsFplMBsPh8LOECsquIA2BKQoAgEajgWq1KnybVXv+nwKSqZPJJDY2NhCJRODz+YQDSOU1LS1W6TJCsFhTGo7fPkg5elPSC/i9d/13ipX0ej1Go5GMxpWb5yo8J5o+BwIBRKNRhEKhha6B8tLETNdyuYxms/lRFwEWl1Rd3x4BK4n6zABehed2G1xj5IczICAej2N3dxePHj0SwSA9KG+DF5F6vY5utwu9Xg+/3w+HwwHg7WF9u9sFvDUxb7fbqNVqoqZeBW6XRqOB0+nE9va2+E4mEgkpePkc+Cw+NI1Rmo/b7XaJWH3IF3ngbdeTaT21Wg2NRkNslyiAY3rZ06dPodVqxWOy0WgIh1VZDC4T7qUA1Ov10iFh8ed2u6HRaNBut1EoFJDL5YSn8aGihgVgJBLBzs4OIpEIzGYzxuMx2u22cExWcTMkYZc5q+l0GsfHx6Ia7Ha7oqCORCKfdDubz+fodDpIp9N4+fKl2Od86ian1WrhdDoRCoXg9/sXFMgA0Ol0kM/n0el0pCBYtd/Dx8JisSAWi+Hg4ABbW1uizGZHejAYoFKpoFAovDPuaxWgtCP5WIsbfo/SzoiHCHldNpsNnU4HjUZjgQ+3SgkqFosF6+vroqLk5e02/5UFCt/jT6G6KG13CHZoGEj/JcQk9wGlz6Hdbkc4HEY0GsXBwYEoML1eL1wuF0wmEwwGw29+htIerFarYTgcwuFwYGtrC2azWbzsaEB9ew3PZjPU63Vks1lcXFzg6OhIUoGWHexuUtQRi8Xg9/ths9kA/Lpfm81moft86Gzl99Bqx+PxwGQyPfgCkFF5FGgVi0WUSiWYTCaYTCaxYHI6nUgmkwCAQCCA6+trpFIpZLNZ5PN5eb7LuHa+agHIA4Mv9fb2NjY2NsTVnbmrNC8lB/BDhyQXtM/nQzQahc/ng16vlxiqbDaLarWKwWDwlT7plwU3s36/j0wmg0wms2AXoZTq34YyXB54a7tDB3eOgIvFoghuaLb6sdBqtXC73QiHw1hfX4fdbhdFMvBW6ceCZ9luQb8XHLmw+7e1tYW9vT0kEgl4vV7ZaEkYJpGaBfEqQukfdhdY3FGdT8oHu4bsvPD7SbRmAVir1SS7m7FT73MBWBawE88CkJ9JWaTQ+6/f74sCnzFm/Od3dcppvcOOGD3uCHYAW62WdABXaX3xfCAVhQbasVgMm5ubwhWPRCK/Gfe+D2wQ6PV6uZAFg0ExLzebzXcWgFQLG41GtNttlMtlURMvc+eeExnaqlGApOySfmpcGQtAj8eDUCiEaDSKUqkkezrH5BS0rRKn+X3gZ6rVashkMri8vATw6zvq9/tlL3O73dLYokk2z0GTyYRqtSr6g2VqgHzVApA8vXA4jL29PXz//ffY2dmB3+/HbDZDqVTC69ev8fr1a7GOeN9C4gFEQ1Sr1Sr8v/l8jmaziaurK7x8+RLX19dot9tL8+A/FuT1ZDIZCURXdkfoN8eFdzsSajQaoVKpoFwuA4C08Fmg8YDmYUJbik+BRqMR9WG9XofZbF7wvRuPxyiVSmg0GittBv0u6HQ6yXk9PDzE48ePsbe3h/X1ddhsNiGk3x4n9Hq9lTqgPxbK7iAtmcxmM4LBoFgW2Wy2BcER9waa+5K6wXF5NptFKpVCtVpd2rXDvcjpdIqZLpMCCOVosdlsYj6fS1rPdDrFeDyWnFpefJX7nMlkgsvlgt/vlzxcgpy3er0uiuJVWl9KP9hEIoFEIoFIJIJQKCRdQHafPlTAcNTJRCklD9ViscDhcMBms4kA4vbPms1mMnbmpa3T6cBsNqNUKqFWqy1t957CmuFwKAK/+Xy+cC7w837oXSJ9g9QGOlH0+334/X5RBDebTdTrdZm2PTSFcLvdxvn5OTQaDUqlEmKxGMLhMPx+vzSwqGngRIyc/UgkgouLC5ydnS2Yvi8DvloBqLxBxGIxHB4e4ocffhDBwnA4RC6Xw9HREY6OjlAqlT6oXmOHgeowytlNJpOYM56dneGXX36RDtSyHh7vAse0qVRKsmdHo5EEVPOQZSQXb7PsGvZ6PeTzeZyenkKn08mYFnjLgbHb7aLSpWL6UzEYDNBqtaSro9yglX6DyzKGV7rg838Td93Q3ndr0+v1CAaDePr0Kb7//nvs7+8jmUzC4XBI4gOV2p1OB9VqFdVqdeU6NLehFGvcfj7s5DidTpjNZvj9fuzs7GB3dxfhcBgej0cUm8CigTutiZrNpnh3/vLLL+KXtwzr5zb4Lipzf1kAKoUaHCtVq1U0m005LPhnGM/INBke3Mrn6fP5ZLR8WwSiLABXaX0phQv09Pvuu+8kbYFjdLPZ/NE8Zb1eLxxB5aWTtlj8Ur77/O/s5AYCAcxmM6EQ6fV6sZVZlkP8LrDDzC76h/bdj3mneH6vra3B5XJhd3cXhUIB+XweuVwO6XQaqVQKWq1WutzL+K5+DtrtNi4uLkRhH41G5R3f3NwUaxiPxwOHw4FIJCIWYOl0GhaLRcSWvIwswyTsDy8AlZm17Pzt7+9jY2MDTqdTPMOKxSJevXqF09PTj47o4eiXghJyz3Q6nUjWlSPgVZWuM7+SqmYAIrrweDwLZF9K/ZlLWygU8Pr1axwdHUlsD4tGdlx4qHCM/DkbG2+cywylHyIPE4vFsqCm5CFK0RAxGAzkZqtUfLEjE41Gsb29ja2tLayvry94vo1GI3Q6HZRKJRQKBTn8H6oJrcFggN/vl9gzZRQarSecTifsdvt7OUjdbleiHFutFq6vr1EoFITbtkzvMoVtpEIw+cNqtf6mYCGh3uv14uDgABaLRUbbw+EQ9XpdeGvswPCSm0wmsb+/j3g8LmvsNpRRhavC09JoNHA4HOK19vTpUxwcHEie+V1j2ndBqbzkRVl5KVaaS/NnUgHLVAxy3NiRJMeLF3I6VfAyvox4n3qZz4cXDF42budRs/Os9BQkD9Xn80nKDdd8MBjE1dUVTk9PUSgURAi26uB0i3QUXkar1SoqlQo2NjawsbEhYlYKZVj7kP5jsVjkDCCl5T47pX94AajX6+H1ehEOh/H48WP85S9/we7uLiwWCwaDAdLpNM7Pz3F6eorj42OkUik0Go2PEjOQT8iFxxEoR5K9Xm9BLLHMN7YPQdlhoQN+MpnE1tYWnj9/jj//+c/Y3t6Gw+HAfD5HpVIR09OffvoJL168EBsK2il4vV7pWFCIcxeh+iGA/CkSmQOBgHBZYrGYKOSGwyFKpRJyuRxqtRqAXzfLarWKVColY7rpdAqTySS2O4lEAvF4HOFwGHa7fUFZyCQaGpuTO/Mx9iirCIvFgq2tLbmweDyeBcNZWil86EBnB2c+n6NUKuHo6EhUi8s23uQ+t7m5iY2NDfj9flgslt98Tprf82IcDocXOnU8aGq1miTt8BA2GAzwer2IRqMIh8O/4RaS++VwOEQksioZ1BqNBj6fD3/+85/xl7/8RUZspKZ8yueg9RfV5LwwcBJRLpclVhRYTJFyOp2ikOXzdTgciMfjsFqtmM/nIgqYzWZoNptLtQ4BLEw13mVzw+JwPp8LH5fJNMzvns1m4g3LsTt/F6RdsbufSCSwvb0tLhXsPjL5adX3OfLlealgcymVSuHo6EgobUzm4tlKSppOpxNqyMuXL3F+fo5SqXTvsXFfpQB0u91IJpPY3d3F4eEhotGo5Dyen5/j5cuXePPmDfL5PEqlEgaDgdySCeWNTv7ya2uwWq1ywFAVx9sMCyZlaPq7wD+vDGdfFtxFBqftDcdqm5ubWF9fh0ajQbfbRbVaxcXFhTzb8/NzeL1exONxtFqtBeI5OTG3VYWrCCX/TPl7Z0Qgs1mpQGcR7fF4xPGefDOmz8znc+Tzefk55GfR6DeZTGJnZ0c4SuxCcz11u10Zw5+fn6NSqaycGfmnwGg0ihUTUwesVuuCAbjysCE1QDnq5KZJ+xzaqdhsNgyHw6VTuNKgORwOIxQKyfj/duFCVT8Ph/X19YX3m6IspS0M1b/swDgcjjs7i3xeoVAIkUhEFIsfM025b7Dbtr6+jmQyKW4Cer1e1M18FsqYO+BtmhH3bnbr2enr9Xry/cPhEPl8HtfX12IGrdVqYbVaxUc2kUhgOp3KQc51zBx1eixyv1jGyQepBuxW2mw2OdP47lGQQH87TppoWzKZTMTEn9QqAJIfzD0V+LWgDIfDMho/PT1FPp9HvV5/EMI/1gfMg6dQrVgsioCVDYRYLIb19XW4XC4RgYRCIdjtdthstoUOK0Vg9zXR+CoFoNPpRDgchs/ng1arFW7emzdvcHl5iVQqhVwuh0ajgdFoJC8bFxxJrSRH88VnkchMQ962lQT0YDCIeDz+m9xaJQmWB9J0OkWv10Oz2VzaTVOprKRzO/l7Ss5do9HA9fU1zs/PZWz2LvHFQzFn5u+eggKHwyGHIQ9limWcTidcLhe8Xi98Pp/E55Hn43A40O12ZX1UKhVEo1Hk83m02230+324XC4kk0kkk0npWHAdKnl/5XIZFxcX+Omnn3BycoJGo7FU48svDVIz+B7ejiwbjUbo9XrodrsolUrI5/Po9XoiAuDvjCKH2WwmI3uDwXCnb9t9g903ils+5iJ113vHjFZ2ZFj0cG1zFHxXR2xtbQ1+vx97e3uYz+eoVqs4Pz8X8dUyFcy3wT2YIzYWIADkogBAuso8iHkZYDxlp9OREIFms4lOpyMFIPCWTkMRFvB2OsCxfLPZlKKFRQ4pIqFQCE+ePJE/X6/Xly5SjsVfpVLBzc2N7EHKYq3dbkuxZrfb4XA4MBwOUSwWxTFjPB6LXRudJkh5IQeV7zUv23a7HX6/H8lkEoVCAQAWBE0PBUpj7Pl8jkKhgF9++QXValVSaUKhkCTVUHC0vr6O3d1d6UgznILr/WufC1+FA8hFYbPZ5AVkBBJHvsobL28dSgUlv49VOADpyFDRRXNK3ph5G97a2vqNsIEbKIs8yr2r1apYriyj5YSyuOW4QkkiB95GcaXTaVxdXaFWq733syyTLP33gAekxWKBx+ORDim9w9gpJo9Fr9fLF8ngVGUGg0E5MOfzORqNBuLxOMrlMur1OtrtNlwuFzY3N5FIJGRz5IiEt2wmYlxeXuLly5e4vLzEYDB4EM/7XWDRp3wf+Z7x3Wo2m6hUKri4uMDx8bGoxz0eDw4ODoTXy4NFyc1axgKQa0/5uZUXyLvERneBo2E6Gdz+/ruSdgh2W81mM2azGU5PT2Gz2aSLtswFIIAF31Yl74/dvPl8LvGetAVqt9uiGCcni8rxarUqBaDST5LjTz4PpT+l1+vFdDoVbnk4HAYA2R8o6nG5XGi1Wvj555/v63G9F/Sv4/mqLGLH4zEajQYqlQrG4zE8Hg88Hg96vR5ubm5QLBbR6/UwGAyEY848c3atmXikPHeUOc3RaBTxeFzsc243YB4ClJzJcrkskYNerxeBQACJRAK7u7vo9/uIxWLSYU4mk5jP52J7R2eD+3hH//ACkIqpfD4PjUaDYrGITqeDN2/e4OrqSrohTJFwOBzweDzwer0iSmDRxxxC3uZsNhtisRgSicRCJBX5Mm63G5ubm9Bqtb9p0ysLQHZrRqMR8vk81tbW5JbE8dQygIUtFc8ej0dsNfR6vViNKGX5tIy5a2FxM6RbPpVuqwhyLNxuN3w+H2KxGKLRKHZ2dnBwcIB4PC5CoXd1Z3iosoOnBIsPl8sl69But8vtmBshVdTKRJuTkxNcXV2hWCyi3W7/sQ9iSUDBDbm4PLC73S7q9boYq7JL3e124XQ60ev1EAwGhaKgHOffpdpcFnC/IH+UdhC8EDDphN27D3Xd7xodv+t7eIEbj8cy8mO6yKrEdvFAvLq6gtlsRj6fl+iybrcrXDKdTrcwDWLXT5nFSo4ehVafwrXVaDRoNBrS5VcWiVR5UzjG0fCygUK2SqUi9i1KLirPZI5nedkaDAZSkNCjkl3CTqcj3pMsrm9ubmTfM5lMYvNEFWwymUSlUkE+n5cx56qeL+8C3z2OctmJbrfbIvJgB5sqd3IBaQc2nU6Rz+dRLBYfXgHY7/eRSqUwGo1EGTMej5HL5VCpVKDT6cRdnLLqcDgsXT0lZ4gtf77M9L9TjvK40I1GI3w+Hw4ODhAOh3/zYJUjYLZyR6MRrq6uxLS3UCjIS7IMUHa4HA4HfD4fAoGAJE2wS1qpVKSg+1BreTgcolqtIpvNolKpLGXX82PAeKOdnR1sbW3JaDYUCkmRfBcn6y7c9axol8Px8Hg8Fh6X8mdSxVksFnF5eYmzszOcnp7i6urqQd6CPwSOojjqZcRjsViUcVOtVhP+jMlkWhjZrQpGoxGKxaKoTTUaDVqtlhQNfF89Ho/sUR9byLL4UBaOt22LePFIpVI4PT3FmzdvcHZ2tjDOXGZwZP3LL78gn88LrYc8O+V64Oel8wD3bo6CB4MBut2u5CB/yW77XfYqy4b5fI7BYCBZ4+SPKv85PQJnsxna7TYqlYoUenz/aCNWLBbR7/fFM5FG5Ha7XUbAgUAAf/nLX/Cf//mfsNlsiEQiaDabyOVySKVS8rtZFvuTPwqkMSh5zjSANhgM2NzclAnT7u6uTJxevnwpFjFfE394AUh/v3K5vDBmIyeF4g0mgzx9+lRGakpRhxKsulkQkRfELgGVTVTgrK+vS6dPKX9XZoxyM5nNZsjn88hms3KzXBYoLUycTic8Ho/Yt6ytrYkHGG9oSjLvXZsgNwJ+z8fkLi8jSKwPBALY39/H06dPpQB0Op3SheHv+fY6UKqrb/9crivyUummf5ubxZ/B8Xsul8PJyQl+/vlnnJ2drXQSzeeCBxEV1BTBZDIZ5PN5VCoV6Qwwz3XZuWrvAg3X2W3X6XRotVqyP/l8PnkfKRL62AKQ4zZmkN7+Xh40jUYDqVQKP/30E16+fImbmxvp4Cw75WA+/9W4v9fr4fLycsHK5i6KinLvVnpRKi1gPsdwXikiu12kcxJFDuEyRzmSw0exyl1G10qrl9viGj43/pxarbaQ7sMvIplMwmaz4fDwULxmo9GoNCjo7LGqDYaPBfmXLLBJddFoNCKG29jYkO620WiEXq9HvV7H6enpV//7/uEFIEcTd93oNRqNbFz5fF6EDK1WC16vV4pDEm6pngN+5TGw00deEP99SjUYb4Jsw3JkoMwGVnYYs9ksLi8vpYBaphec6l+PxyMKIxa/wK/dViqrs9msbP40lb0LHNEpw+NXCUqDXAp+ksnkgoiAv3t2BLgG2KJXdkiUIiJ6gAUCAbHUeNeFhOAYjnFvxWJRbuKr9my/BCaTCbrdrqj+r66uJMqw0+nI78/r9WJ7exuHh4dIJBLS/V8VUJgwm81QLBZxenqKer0u42t2AL1e78IY+EOgPx49FHmpvd11LpfLYvt0fn4uKtcPmekvE9jVu49OOQs/juqZlqQstufzX1OZcrkczs7OkMlklvZSx8vu78XHduysVqt0+em2wfP6rkvLQ4ZS4UuOYDqdxuvXr2E2m6V5Y7FY4PV6RR9xH3SCeyUwzOdz9Ho9FAoFNJtN5PN5vH79GqFQCPF4HNFoFG63eyFbz263i5mnTqfD+vq6RJoBbwPR+/2+cGEqlYrwOqrVKnK5HPL5PLrdrvw9WATy8G40GksnBGFRogxC580N+JUrk0ql8Pr1a1xfX6PVakkr/67RJzcJGh8vm7fax0Cv1wu5NhQKCR+UvEje2NvttnD3yEspFotiCMvNUhm5FQgEcHh4KD5ZH+NdxwKz2WwKJ4m/h1V7tr8XyvVFflwmk0G5XBb1HHmVGxsb+NOf/oQffvhB3vuPjataFnDvqdVqGAwGuL6+XigseIn9mCQLZSJFOBzG//yf/3MhYlF5Een1ekin0zg+Pha6QT6fl+6Nig/jdqSoy+WCw+H4TdReu93G9fU1jo6OcHNzg06nc49/6+XBZDJBvV5HOp0WNTDP1VV5f780lPsfNRAGgwGxWAyxWEymSpx+3Idn570zWNkdZOG1traGXC6HYrGIXC4Ht9sNh8OxUACSL0ROltfrlfi34XCISqWCTCYjfCPm0LZaLSkAc7ncgjJYWQTSPX0VoBSwkJhL65dOp/PB2xsPLRZBq9It4FjDYrFICgcjedxutwh/yuWyHIjsOrEbVSqVFrhCPARI8r4tSPgYcP3QnoLjgFV5rl8aXF/sxrfbbSFEk/oRj8fx/PlzPHnyBLu7u9LlYleNwhES85d1pMk9hETw3wsWJIPBAM+ePcN0Or2zczgYDJDJZHB6eorr62tRJH6ra+5TQVUvPQC3t7fF05NqbL7TtVpNwguy2ezSdgD/aNzmpPJio2xIfOtQirNoxeT1epHL5VCtVsXU3mAw3JlH/TVw7wWgErPZTBRKANBoNER9ydszTWKNRiNisRi8Xi82NjZgsVhEtXR2doa///3vMoZhN4+H8m01MaHkkywjaCpcrVblM5F0St4Zu5v0VHyoLyPHtLQNef78Ofb29uD1eqX4a7fbODs7w//5P/8Hx8fHUux1u12xllAW++zyaTQaMSCmXcwyWo+sIqhK9Pv9CIVC2N3dxdOnT3F4eIhkMgm73S5drtFohFqtJsa9+Xxe/NuW9R39kqCHKu2LlMpT5Vrs9/solUq4vr4Wwv5Dfe+/NOhZmUgk8OzZMxweHmJnZwfb29vw+/0wm82yx1Jgx/zber2+kpzp3wvaPNE+S6vViu0W05AALJ0/4n2BRWC325X9LJvNYjwew2633+vfbakKQHbger2eqFOVqjeOU2w2m0REsdBh96/VaiGVSuFf//oXfvnll4XR5m2S8F3//mUGR9/kl9H3ij5ENzc3oq6+i6DMookvLbBIel4lngaLM3rxPX36FPF4HE6nc8Fj6eTkBP/85z/xyy+/yEjsNgFap9OJHyDzG+12O5xOp1goKMd2d/mzEVR98pJCDswqdVc/BKVoipcQEu5vP4/bvCqXywUASCQS2NzcxJMnT/DDDz9gY2MDNpsNZrNZfgb3gOvra6RSKZRKJUmxeSjP8l1gZ8DhcMDr9UpSj9Fo/M0zVhbKLEqWfS9bBnAs73a7sb29jR9++AHPnj1DKBSCz+cTwRd56rT4KRQK32SXldQDvsv0FlxbW0M0GkUkEpEc9Ha7LVx8hiys4ppUekQqR9qf+lmoMaBROX1P35dO9jWwVAUg8T4CKzdGAEIu1Wq1ciDRFJT2EsyEXMXFdxssctvtNjKZDH788UeR8jcaDVxeXiKbzcqYTLk58SBmHquSS6TkISyjr9VtKHlVJNIyC1qv10vG9C+//IJffvkFmUxmIV+aSjYqxV0ul2QD+/1+6biEQiHEYjGxo+BNTqlCv60WNJvNCAQC2NjYEM5bo9FAp9NZ8BVbZTAD9fr6Wix2mKV8+3JBLuXm5qbE562trSEcDiMajSKZTIrtE8cgyotgoVDA+fk5rq6uRNTwkIrp94GJNi6XSw7buy5oHJXTEuUhrLE/GmwkRCIRbGxs4ODgQCzIaBlFl4RKpYKzszOcnZ3h5cuXyOVyDzbH+12gzypFcX6/X2LyjEaj2Jo4HA4AWJi0Ka1lVukc1mg0IjLV6XTSOKAt0afQxJT+vVarVRwl7vu8Xf7T/g6wkHE6nXIjpvya471Op/OH+EDdJ8in6na7uLm5wXg8xtnZGYC3zu+FQkGUf0p7ExZMDodD0gJYBFmtVhkx3feC/FgoO0vk79hsNuh0OnS7XVxeXuKf//wnjo6ORHSg7PqxeAwGg9jY2MDu7i4ODg6QTCZhtVoXUmaYY8sbLa0S2O1TEsXNZjMikQjW1tYkiqlSqQDAgzmcp9MparUazs/P4XK5EAqFEA6HRdGrLAAtFgtCoRAMBgP8fr+Iatxut5i9O53OBYENb8vdbheFQgFnZ2e4vr5eKWPjLwGlz+n7SOKcbLDb8lD2uz8K3A+dTqckBR0eHiIWiwnvT6fTSZJPsVjE69ev8fe//x0XFxcoFAorwxH/UtBqtXC73dja2sLu7i42NjYQiUQk+tBut8Pr9cJiscg5TBFco9EQF4RVenfZbGJjQXn+8n372J/Dc9bpdMLr9cp5pUxSuQ+sxmmvAB8mOzcsZCgkqdfraDabC/mRDwXsjM5mMxnzsvhQig/uKnp5m+GBYrFYpFujTFpYpRGw0pmf+an0Ncxmszg5OcH19bWEnvNzsrBjLM/+/j4ePXqEJ0+eIJlMSga1EtzUSCkYjUaiymZuNT2dSO4Nh8NYX1+XjgG5rauO2WyGVquFdDoNt9uNjY0NhEIhzOdzSePhOqL7vcVigc/nk6xvjpA4Wlf6gzKnNJ1OI5VKCf+v2+0+iAL6Y0APSrPZLBeS2wWgUrTGL7UAvBukjFDhb7FYkEwmcXh4iMePH2NzcxN+v38hL7fT6aBUKomw7s2bN8hkMiuRbav8vGazeSEKjvgUlb3BYMDGxgaePHmC/f19bG5uIhKJyHusFIHQWYFpLRTCrRpYANrtdoma5dlbqVQWouBu04KUX7zw0qaMOcqcLAH3Rz9buQIQgPA2WEVrtVr0+33UajXkcjmxYXiIUGYX05gUeHsYvOsAUN5m6EHE0fkq+wDyJaP6kqbWlUpFLgLT6VS6hexIJRIJJBIJxONx+e8+n088JTnmpTdZvV5HKpUSVTkV64wtDAaDiEQi0kHgiIlih1arhVKpdN+P64uA3XbaPrx+/RpGoxE7OzvY3Nxc4PGxy8wDiZ1TmhoruTUk22ezWVxcXODk5ARv3rxBPp9Hu93+5ixNeHjcdTkjHUHpdUoutFoA/ha0JnE4HAiFQohEItjd3V0o/sj5G41GYt1xcnKC169fI5VKSaLDKowy+Xld/397Z/rTxtWF8QfbIQYveMNgYwMuCWTpKlWp2n7o31+1UhsS433B28x4bDMe8NjGOLwfXp2Ta15CCEleGPv+JFQpUASXmXvPPcvzBALY3t7GkydPWHyYuG3wQZeRcDiMra0t7pEkKTLqc6Z3Wfy+dpaBoTOT1Ari8ThWV1dRq9WQy+W43YraAcT2Iko2UOYvlUrh5cuX+OGHH7C7u8sX5fu+tNk6ACQBRZr67HQ6qNfraLfbc5f9E7nqZEJ86GWjkofb7eYAkErnVGoT0/R2gH5PMQtiWRba7TZb+NGLSVNrgUCAzc1//PFHvHz5EvF4HOvr6zNWcbS+VIocjUYs7pvNZlEul1Gv1+FyubCxsYFEIoFnz54BAPdp0SRxNBpFPB6Hqqr3nu7/UlBv1MXFBRqNBrxeL2/8wWCQ+/wI2hDF3190eaCSOvl9Hx8f46+//mI3C/IPX5TsH/De/o0GFa5O/tLlj6R1hsMhHyYPucx2Gx9kcR+7+t/bQt+f1tHj8WB9fR3xeBz7+/t4+fIlDg4O2K2CLn7Ue2oYBl9uDg8PcXx8DMMwMBqNHnxQQ9WeUCiEZDKJn3/+GX/88Qd2d3dnvuZqoPah7wW8b5uh4TYamhN7fkUVjbs6sTwkxCz85uYmDwgdHR3xudDv9wGA267o0kaDgNSidHBwgN9++w3ffvst902Tf/DH7Fq/JrYLAOmPQpqAlG0RpRDa7fbc+67edhMSp3yXl5e5VOlyufgFJR1BO7kGUD8GDcRQLyRZ95EPdDweZwcV8ove2dnBwcEBnj59ilAoBL/fzwELZbfEPpZer4dSqYTXr18jn89zUOJ0OtHtdnFycgIAPKnpcDi4NByLxdiia14CQACcbTYMA7VajZ/F8XiMZrPJTgqrq6twu93cLL68vIzLy0sOWGgTHAwGME0Tuq7j33//RTqdRqFQQK/XY9eeRUDsbQ2Hw4jFYjwkc7U/l/qRrrraPNRD1+v1YmNjA8FgkAMucbqeBvnIsUnU0RQHXK7boyjQo6yLOHlOmfh4PI54PI5vvvkG+/v7SCQSWFtbg8fj4ffeNE0oioJGo4FMJoOjoyOUy2V0Oh3bDH4sLS3x3vPs2TO8ePEC+/v72NnZmfkakZuemasXj+ugvdg0TTQaDZTLZZRKJVv7ywPvS+k+nw8bGxvY29vj9zAcDqPdbrNr2GQyYeceKo3TENeLFy/w/PlzbG1tsWYx6RVns1koinIvZXLbBYAA2AKOeg+WlpYwHA6hqioqlQpUVbVlz8HXQMwk0I2N/v1zxtrvEyrNjkYjqKqKN2/e4PLycmao4LvvvsPOzg6XHMlCj2z01tbWeIhIvMGen5/DMAx2C6nX66hWq6hUKjMC25ZlcQl9NBphZWUFoVCItdro8InH47AsC5lMxjYDNp/C+fk5dF1nz990Oo1oNIpEIoFkMomNjQ1Eo1FEIhEuG5FfsmEYODk5Qa/X48ytoiioVCqoVCrQNI1L+IuC0+mE1+tFMBjE9vY2Dg4OsL+/j/X19ZkLBD2r5FduhwAwFArh1atX+P7773kQTZTBIO9Ueh7IxanT6fAhS6oHV6EKBx26Pp8PkUiEPcEp00/etBT4Ub8vtcEoioJsNou3b98in8+jXC5DURS+sNiBpaUl+Hw+7O3tsS+6x+O5ti8cuF0G8GNfR2LkNC2dy+VQLpcftF3ebaCMJj1foVAIXq8XsVgMP/30E1qtFltbjkYjrgpRlo8SAcFgENFolHUlTdNEsVjE33//jcPDQ5TL5S8iHv+p2PJEEqcvqTQyHo/ZeLnf79v61vEloexXIBBg55QPpewf6sFxHTSObxgGCoUCptMp9vb28OTJE5ZzoWk+KgFT0Cf2XliWhYuLC86AWpYFTdOgqioajQYHI/V6HaqqwjTNmY2Qmp3ppkeuNeTtGAgEEIlE7s3r8WszmUxgmiY3zJNcxPb2NlKpFJLJJGKxGDY3NzkQnEwm6Ha76Ha7fMArisIaa6RxSdN2i8Ty8jIikQhSqRQODg6wt7eHRCLxPxODZFtJbkfksPKQ18vv92N/fx+//vorotEovF4vt6JQD2+v14Ou61AUBYqiQFVVaJqGdrvNbjDXXe7Jb5mGA9fW1hCLxbC/v897AtmK0h5Ik5xUPaLLx5s3b7js2+l0YJrmg17Xq5D8UiQSQTKZRDQa/Sy9uav6nmK/OWVndV1HNpvF69evkcvleFqaBkLsCJ2NNGBF60qi16lUCq1WC4lEgo0ZHj16xPsd6ciKSQbLstDpdNBsNpHJZLjaQe0F/29sfyKJhzEFMXbLaH1NlpeX2eIomUzC6/UCmO29uri4YF0ju6wb/fwkFUKyDjTFu7a2xocmldWo7DQejzn7JEoVGIbBGQgS2aYghW541/VdTiYTdDodZDIZHrTxer0zXqJ2ma7+VOjvIEqRmKaJVquF8/NzqKqKtbU1LsV5vV7+u4mOLKZpsl/3YDB48MHM12J1dRVPnjzB77//zllsCpToEKHLT7vdRjqdRjqdRqfTefCZ0uXlZZ6GDIfD3Ioift7lcsHr9SIUCiGRSMA0TfT7fRbPpYvBVajHNxKJcBbe7/ezZh3pnIpyQ2SfqWkaSqUSCoUCyuUyarUaD33Ypex7HWKf33X7z232egr+xK+lwcHT01NomgZFUVCr1VAoFJDP59FsNjlja2cNXjHLTlaylmWxddvjx48RDofhcrkQjUYxmUzgdDq5BLyyssJVN+pT1zQNb9++RTqdRiaTQbVaxcnJyb1pxNo6ALTrg/X/hLySU6kUEokElwLEwI+yX3YLAC8vLzEajaDrOqbTKZLJJM7PzznYEzMM1AtJMiOqqqLZbLItD2Wfut0uv+jD4ZBFP2/S8Hv37h1PCJLIcTgcxmQygd/vnxv9vw8hXrioHYMyNrQBkiSF0+mcySDQB01bi84Bi8jKygp2d3fxyy+/4NmzZ/D7/TycBLzvvaTBpHQ6jWw2i16v9+DXjKoRdEESBwgA8Ht7cXGBjY2Nmf2JdE7Jv/sqbrcb6+vrLOVCzxpVikSnI7qwULa/WCzi8PAQh4eHKJVK6Pf7nGl86Gt6E3TB/5Dz1U3clEQhx612u418Po+joyMUi0XUajXU63WYpskT6XarLInQsBtdTikIpP5SuqysrKzw70k9g+LUPr2vlmVBURS8fv0af/75JxqNBrsb3dfwlq0DQMnHIaPzRCKBzc1N7kEgzcRWqwVd1zmAsVuzPd1GnU4nms0mCoUCZwNFdwng/aZmGAYURUG73ebeM03T0Ol0Znyjb7sWdFOkjFcul+Pbod/vR7/fZy3AeUcM7hbh9/0aiIcmTc3S80UHSafTwfHxMRqNBjRNs0W5nILW4+NjWJbFPXgOh2PmgwI2h8Mxk1kOBAIIBALXDviRMkQgEJixyxMlhuiSS5PT5DJDPWulUgmtVosHTh76en4IGrJSVRWlUmmm+nHV1vLq/0fZZcrMX7cHkji2oii8dtVqFbquo9vt2rbkexUxA6iqKrLZLLueBAIBTjCQfSjNJFDgbVkWP2uUvab+0nK5jJOTE5ydnd1rb6kMAOcc6kOLx+PchErlo2q1ilwux/p2dHOzE7RpjUYj1Go1TKdTFAoFnna+utlR1pCa58UyJOmpfc7mb5omcrkcDMPgDYGs6c7Ozr7EryyZY6gsWSqVWNzc6XSyhyj5gLdaLeRyOZbIsUOp7fT0FEdHR3C73djd3WXdTBpQE/24qcxGmXsScF9aWro2KHE6nZyVEYM/yn6RbzV5hFerVVSrVZRKJRwfH0PTNA5e7FQJuY7Ly0v0+31WRphOp2w9Bvw3W3qTpSCJsLdarWuzrTTpS3275D8/b0Lt1N4zGAxQr9cBAPV6HfF4HFtbW9xa4Pf7sb6+zg5QpM9JU76NRgO1Wg2NRgPVahX1ev1ey74itgsARYVt0YPVzi/s50JrAby/8ZL0i8fj4aZVEjqeTCbQdR2lUgmlUgn1ep17a+77gbwLlKqnbKbL5eJswnXQoSCWR66WSu76PJENXb1enzERJ69WieQm6HKWz+exurqK1dVVXFxc8HAEtS7UajWW2fjQZOxD4/T0FPl8HpPJhOV9YrEYiwm73W6WDvL5fLyHidaLZIlHH2LPN5V3x+Px/7S5UJ8pSTpR6Zwuv6KOot3PEnLqqVQq3LO2sbHBnu8ej2emRCmK3g+HQ3S7XZ7mJVktgoLLWq0GVVXR7/dhmiYHM3Z4Dj8FeiYUReF1SSaTMxcY6rF3uVwsZTQYDFjiJZ/Pc6aZNIrJ0eu+18tWAaB4G3S73WzSDNhfdfyu0GQSuS+Mx2O8e/eOfQf39vZ4Eoy0xAaDAZd/NU3jCUI7Bn+EWIa9T2T5U/I5jMdjqKqKdDqNwWAARVHg9/txenrKmpTUC0fZv/s+RG4LlYCn0ykGgwE6nQ7C4TBn/shPdnNzE9vb29ja2uLsp+gl63a7eaiDBOBJLogCPsr4UaafhknIJrJaraLZbPKUr91aXz4G6fJdXl4ik8ng8ePHqNVqWF1dvdZXmgJAkmhqNBpoNpvXZgDpb0ftMtQnPa9QoEbPFfWS93o9zgCWy2UUCgX4fD6WBtN1nTOA4ln7kLBdAEiNvXRTJNuueVEf/1QcDgdWVlYQiUTgcDhgmiYmkwlisRgLHu/t7WF9fR0ejwcOhwPn5+fo9/vQNA26rsOyLFsHfxLJvECZbMMwUCwW4fF48OjRoxmpIsomD4fDe9EOuyvj8Rjdbpclg/L5PA8JOZ1OhMNhJJNJdtUJhUJwOBxsv0iZTr/fj52dHd7vut0uT+mPx2OcnZ1x4z719vZ6PZ7yp94s6vWdx72PguWLiwvkcjmoqjrTUnAdojEADcFdFxhTlouUI+xyAfkSkPj9aDSCoih8eSHXD5pqp/fVsixey4doTmGbAFAs/QLg297Z2RlPiIlp6HkPAkV7t2g0iqdPn8Lr9cIwDEynU2xtbSGVSiGVSmF3dxdra2t49+4d98DQ5Gun04FlWXO/XhKJHSAZHdM07/tH+eKQ1/NwOIRhGDOfIxtBXdc5S0LZFkVRoGkaBx2BQACqqiIcDuP09HRGponKb/1+H4ZhsJi0YRgs4L4IAcvVYSxd1+/7R5oLqId8Xio8tgkAxRIveYbS1NF0OkWlUoGu63xrmfeAhjKhPp8POzs7ePXqFba2tjgVHwwG2X0hGAxieXkZg8EA1WoVmUwGxWIRjUYD3W534dwWJBLJw0KUZaEewX/++QdLS0vsdEISQaTxt7KywkLs4vDW1Ylpy7JsP9krkXwNbBMAEhSBd7tdnvocjUaoVCoczMxbP8d1kCey2+1GMBhEMpnEwcEBT9CRBQ0pwJ+fn0PTNBQKBRag1DQN/X5/IXsnJRLJw4KCtH6/j3q9zqVKUeAfeF/9EGVixM+LyQJRpFzucRLJLLYKACmtfXJygmw2i36/zwEgeUfOczOqiDi1pes6isUinE4nNjc3EQ6HuQdhPB7j5OQE3W4X5XKZp98URVmYcohEInn4iGXL+x7mkkgWAdsFgDR98+bNG5RKJd40SONpUQJA8mFcWlrC8fExXC4XTk9P8fz5cy6TU4mkUqmgWCzyODqZV8tNViKRSCSSxcRWASDw3yCQZAGu6hgtUoqfAl8SnCStLDI59/l8cLvdGA6HXPYtl8vc9ydLIhKJRCKRLC62CwCBm30KFw0SPjUMg3siFUWB2+3Go0eP2IBaUZS5VGuXSCQSiUTy6dgyAJS8h1wwSJup2+2ytha5UIzHY26wlmVfiUQikUgkS5e3TKVd5x0oec+HllGu283c9PjJtbsZ+czdDfnM3R35zN0NuW53Q76rd+c2od2tA0CJRCKRSCQSyXzg+PiXSCQSiUQikUjmCRkASiQSiUQikSwYMgCUSCQSiUQiWTBkACiRSCQSiUSyYMgAUCKRSCQSiWTBkAGgRCKRSCQSyYIhA0CJRCKRSCSSBUMGgBKJRCKRSCQLhgwAJRKJRCKRSBaM/wCaH9nOQ/ButwAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "images, labels = next(iter(trainloader))\n", + "\n", + "plot_sample = images.permute(0, 2, 3, 1)\n", + "\n", + "fig, axes = plt.subplots(nrows=4, ncols=8, figsize=(8, 4))\n", + "fig.suptitle(\"Sample Images\", y=1.02, fontsize=14)\n", + "\n", + "for ax, image in zip(sum(axes.tolist(), []), plot_sample):\n", + " ax.imshow(image, cmap=\"gray\")\n", + " ax.set_axis_off()\n", + "\n", + "plt.subplots_adjust(wspace=0.1, hspace=0.1)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h-Q4A5H8OQLs" + }, + "source": [ + "Now Let's start by making the patch embeddings layer that will turn images into embedded patches to be processed then by the attention layers." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "SFo1GzZvxINl" + }, + "outputs": [], + "source": [ + "class PatchEmbedding(eqx.Module):\n", + " linear: eqx.nn.Embedding\n", + " patch_size: int\n", + "\n", + " def __init__(\n", + " self,\n", + " input_channels: int,\n", + " output_shape: int,\n", + " patch_size: int,\n", + " key: jr.PRNGKey,\n", + " ):\n", + " self.patch_size = patch_size\n", + "\n", + " self.linear = eqx.nn.Linear(\n", + " self.patch_size**2 * input_channels,\n", + " output_shape,\n", + " key=key,\n", + " )\n", + "\n", + " def __call__(self, x):\n", + " x = einops.rearrange(\n", + " x,\n", + " \"c (h ph) (w pw) -> (h w) (c ph pw)\",\n", + " ph=self.patch_size,\n", + " pw=self.patch_size,\n", + " )\n", + " x = jax.vmap(self.linear)(x) # [H'*W', D]\n", + "\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dOJp3mZNOjJW" + }, + "source": [ + "After that, we implement the attention block which is the core of the transformer architecture." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "mDg-L_9ixINm" + }, + "outputs": [], + "source": [ + "class AttentionBlock(eqx.Module):\n", + " layer_norm: eqx.nn.LayerNorm\n", + " attention: eqx.nn.MultiheadAttention\n", + " linear1: eqx.nn.Sequential\n", + " linear2: eqx.nn.Sequential\n", + " dropout: eqx.nn.Dropout\n", + "\n", + " def __init__(\n", + " self,\n", + " input_shape: int,\n", + " hidden_dim: int,\n", + " num_heads: int,\n", + " dropout_rate: float,\n", + " key: jr.PRNGKey,\n", + " ):\n", + " keys = jr.split(key, 3)\n", + "\n", + " self.layer_norm = eqx.nn.LayerNorm(input_shape)\n", + " self.attention = eqx.nn.MultiheadAttention(num_heads, input_shape, key=keys[0])\n", + "\n", + " self.linear1 = eqx.nn.Linear(input_shape, hidden_dim, key=keys[1])\n", + " self.dropout = eqx.nn.Dropout(dropout_rate)\n", + " self.linear2 = eqx.nn.Linear(hidden_dim, input_shape, key=keys[2])\n", + "\n", + " def __call__(self, x, enable_dropout, key):\n", + " input_x = self.layer_norm(x)\n", + " x = x + self.attention(input_x, input_x, input_x)\n", + "\n", + " input_x = self.layer_norm(x)\n", + " input_x = jax.vmap(self.linear1)(input_x)\n", + " input_x = jax.nn.gelu(input_x)\n", + "\n", + " keys = jr.split(key, num=2)\n", + "\n", + " input_x = self.dropout(input_x, inference=not enable_dropout, key=keys[0])\n", + " input_x = jax.vmap(self.linear2)(input_x)\n", + " input_x = self.dropout(input_x, inference=not enable_dropout, key=keys[1])\n", + "\n", + " x = x + input_x\n", + "\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t_RB1ip0PEk4" + }, + "source": [ + "Lastly, we build the full Vision Transformer model, which is composed of embeddings layers, a series of transformer blocks, and a classification head." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "nG6fLPhyQEBx" + }, + "outputs": [], + "source": [ + "class VisionTransformer(eqx.Module):\n", + " patch_embedding: PatchEmbedding\n", + " positional_embedding: jax.Array\n", + " cls_token: jax.Array\n", + " attention_blocks: List[AttentionBlock]\n", + " dropout: eqx.nn.Dropout\n", + " mlp: eqx.nn.Sequential\n", + "\n", + " def __init__(\n", + " self,\n", + " embedding_dim: int,\n", + " hidden_dim: int,\n", + " num_heads: int,\n", + " num_layers: int,\n", + " dropout_rate: float,\n", + " patch_size: int,\n", + " num_patches: int,\n", + " num_classes: int,\n", + " key: jr.PRNGKey,\n", + " ):\n", + " keys = jr.split(key, 5)\n", + "\n", + " self.patch_embedding = PatchEmbedding(\n", + " channels, embedding_dim, patch_size, keys[0]\n", + " )\n", + "\n", + " self.positional_embedding = jr.normal(keys[1], (num_patches + 1, embedding_dim))\n", + "\n", + " self.cls_token = jr.normal(keys[2], (1, embedding_dim))\n", + "\n", + " self.attention_blocks = [\n", + " AttentionBlock(embedding_dim, hidden_dim, num_heads, dropout_rate, keys[3])\n", + " for _ in range(num_layers)\n", + " ]\n", + "\n", + " self.dropout = eqx.nn.Dropout(dropout_rate)\n", + "\n", + " self.mlp = eqx.nn.Sequential(\n", + " [\n", + " eqx.nn.LayerNorm(embedding_dim),\n", + " eqx.nn.Linear(embedding_dim, num_classes, key=keys[4]),\n", + " ]\n", + " )\n", + "\n", + " def __call__(self, x, enable_dropout, key):\n", + " x = self.patch_embedding(x)\n", + "\n", + " x = jnp.concatenate((self.cls_token, x), axis=0)\n", + "\n", + " x = x + self.positional_embedding[x.shape[0]]\n", + "\n", + " key, subkey = jr.split(key)\n", + "\n", + " x = self.dropout(x, inference=not enable_dropout, key=subkey)\n", + "\n", + " for block in self.attention_blocks:\n", + " key, subkey = jr.split(key)\n", + " x = block(x, enable_dropout, key=subkey)\n", + "\n", + " x = x[0]\n", + " x = self.mlp(x)\n", + "\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "id": "agBSRsXVxINn" + }, + "outputs": [], + "source": [ + "@eqx.filter_value_and_grad\n", + "def compute_grads(\n", + " model: VisionTransformer, images: jnp.ndarray, labels: jnp.ndarray, key\n", + "):\n", + " logits = jax.vmap(model, in_axes=(0, None, 0))(images, True, key)\n", + " loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)\n", + "\n", + " return jnp.mean(loss)\n", + "\n", + "\n", + "@eqx.filter_jit\n", + "def step_model(\n", + " model: VisionTransformer,\n", + " optimizer: optax.GradientTransformation,\n", + " state: optax.OptState,\n", + " images: jnp.ndarray,\n", + " labels: jnp.ndarray,\n", + " key,\n", + "):\n", + " loss, grads = compute_grads(model, images, labels, key)\n", + " updates, new_state = optimizer.update(grads, state, model)\n", + "\n", + " model = eqx.apply_updates(model, updates)\n", + "\n", + " return model, new_state, loss\n", + "\n", + "\n", + "def train(\n", + " model: VisionTransformer,\n", + " optimizer: optax.GradientTransformation,\n", + " state: optax.OptState,\n", + " data_loader: torch.utils.data.DataLoader,\n", + " num_steps: int,\n", + " print_every: int = 1000,\n", + " key=None,\n", + "):\n", + " losses = []\n", + "\n", + " def infinite_trainloader():\n", + " while True:\n", + " yield from data_loader\n", + "\n", + " for step, batch in zip(range(num_steps), infinite_trainloader()):\n", + " images, labels = batch\n", + "\n", + " images = images.numpy()\n", + " labels = labels.numpy()\n", + "\n", + " key, *subkeys = jr.split(key, num=batch_size + 1)\n", + " subkeys = jnp.array(subkeys)\n", + "\n", + " (model, state, loss) = step_model(\n", + " model, optimizer, state, images, labels, subkeys\n", + " )\n", + "\n", + " losses.append(loss)\n", + "\n", + " if (step % print_every) == 0 or step == num_steps - 1:\n", + " print(f\"Step: {step}/{num_steps}, Loss: {loss}.\")\n", + "\n", + " return model, state, losses" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "y3Bm_Xln-rSp", + "outputId": "867bf4d3-326c-422a-89e8-c2ff63f893b1" + }, + "outputs": [], + "source": [ + "key = jr.PRNGKey(2003)\n", + "\n", + "model = VisionTransformer(\n", + " embedding_dim=embedding_dim,\n", + " hidden_dim=hidden_dim,\n", + " num_heads=num_heads,\n", + " num_layers=num_layers,\n", + " dropout_rate=dropout_rate,\n", + " patch_size=patch_size,\n", + " num_patches=num_patches,\n", + " num_classes=10,\n", + " key=key,\n", + ")\n", + "\n", + "optimizer = optax.adamw(\n", + " learning_rate=lr,\n", + " b1=beta1,\n", + " b2=beta2,\n", + ")\n", + "\n", + "state = optimizer.init(eqx.filter(model, eqx.is_array))\n", + "\n", + "model, state, losses = train(model, optimizer, state, trainloader, num_steps, key=key)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "X4GPbpuMQEB1" + }, + "source": [ + "And now let's see how the vision transformer performs on the MNIST dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pZu5pMRZW3tF", + "outputId": "d74d06b7-0340-4e4f-b723-8e5d532aecb5", + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: 97.5%\n" + ] + } + ], + "source": [ + "accuracies = []\n", + "\n", + "for batch in range(len(test_dataset) // batch_size):\n", + " images, labels = next(iter(testloader))\n", + "\n", + " logits = jax.vmap(functools.partial(model, enable_dropout=False))(\n", + " images.numpy(), key=jax.random.split(key, num=batch_size)\n", + " )\n", + "\n", + " predictions = jnp.argmax(logits, axis=-1)\n", + "\n", + " accuracy = jnp.mean(predictions == labels.numpy())\n", + "\n", + " accuracies.append(accuracy)\n", + "\n", + "print(f\"Accuracy: {np.sum(accuracies) / len(accuracies) * 100}%\")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "vit", + "language": "python", + "name": "vit" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}