From 5c5a3db94ec00f4c2a952b5e34f04184c6a2dfb0 Mon Sep 17 00:00:00 2001 From: Bernard Tan <30761156+thkkk@users.noreply.github.com> Date: Mon, 22 Nov 2021 22:21:02 +0800 Subject: [PATCH] Implement BCQPolicy and offline_bcq example (#480) This PR implements BCQPolicy, which could be used to train an offline agent in the environment of continuous action space. An experimental result 'halfcheetah-expert-v1' is provided, which is a d4rl environment (for Offline Reinforcement Learning). Example usage is in the examples/offline/offline_bcq.py. --- README.md | 1 + docs/api/tianshou.policy.rst | 5 + docs/index.rst | 1 + examples/offline/README.md | 28 ++ examples/offline/offline_bcq.py | 241 ++++++++++++++++++ .../bcq/halfcheetah-expert-v1_reward.png | Bin 0 -> 56131 bytes .../bcq/halfcheetah-expert-v1_reward.svg | 1 + test/base/test_env.py | 2 +- test/offline/__init__.py | 0 test/offline/gather_pendulum_data.py | 170 ++++++++++++ test/offline/test_bcq.py | 221 ++++++++++++++++ tianshou/policy/__init__.py | 2 + tianshou/policy/imitation/bcq.py | 213 ++++++++++++++++ tianshou/utils/net/continuous.py | 119 +++++++++ 14 files changed, 1003 insertions(+), 1 deletion(-) create mode 100644 examples/offline/README.md create mode 100644 examples/offline/offline_bcq.py create mode 100644 examples/offline/results/bcq/halfcheetah-expert-v1_reward.png create mode 100644 examples/offline/results/bcq/halfcheetah-expert-v1_reward.svg create mode 100644 test/offline/__init__.py create mode 100644 test/offline/gather_pendulum_data.py create mode 100644 test/offline/test_bcq.py create mode 100644 tianshou/policy/imitation/bcq.py diff --git a/README.md b/README.md index 512cd7697..13cfc191f 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf) - Vanilla Imitation Learning +- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf) - [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf) - [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf) - [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf) diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index b05f5be42..7292afdcc 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -109,6 +109,11 @@ Imitation :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.BCQPolicy + :members: + :undoc-members: + :show-inheritance: + .. autoclass:: tianshou.policy.DiscreteBCQPolicy :members: :undoc-members: diff --git a/docs/index.rst b/docs/index.rst index b56bce367..a7fa0da26 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -27,6 +27,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning +* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning `_ * :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning `_ * :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning `_ * :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression `_ diff --git a/examples/offline/README.md b/examples/offline/README.md new file mode 100644 index 000000000..8995ee6e2 --- /dev/null +++ b/examples/offline/README.md @@ -0,0 +1,28 @@ +# Offline + +In offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy. And the agent does not interact with environment anymore. + +Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets. + +## Train + +Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset. + +To train an agent with BCQ algorithm: + +```bash +python offline_bcq.py --task halfcheetah-expert-v1 +``` + +After 1M steps: + +![halfcheetah-expert-v1_reward](results/bcq/halfcheetah-expert-v1_reward.png) + +`halfcheetah-expert-v1` is a mujoco environment. The setting of hyperparameters are similar to the offpolicy algorithms in mujoco environment. + +## Results + +| Environment | BCQ | +| --------------------- | --------------- | +| halfcheetah-expert-v1 | 10624.0 ± 181.4 | + diff --git a/examples/offline/offline_bcq.py b/examples/offline/offline_bcq.py new file mode 100644 index 000000000..e488489e2 --- /dev/null +++ b/examples/offline/offline_bcq.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +import argparse +import datetime +import os +import pprint + +import d4rl +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv +from tianshou.policy import BCQPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import BasicLogger +from tianshou.utils.net.common import MLP, Net +from tianshou.utils.net.continuous import VAE, Critic, Perturbation + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='halfcheetah-expert-v1') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=1000000) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[400, 300]) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument('--epoch', type=int, default=200) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=1 / 35) + + parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[750, 750]) + # default to 2 * action_dim + parser.add_argument('--latent-dim', type=int) + parser.add_argument("--gamma", default=0.99) + parser.add_argument("--tau", default=0.005) + # Weighting for Clipped Double Q-learning in BCQ + parser.add_argument("--lmbda", default=0.75) + # Max perturbation hyper-parameter for BCQ + parser.add_argument("--phi", default=0.05) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only', + ) + return parser.parse_args() + + +def test_bcq(): + args = get_args() + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] # float + print("device:", args.device) + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + + args.state_dim = args.state_shape[0] + args.action_dim = args.action_shape[0] + print("Max_action", args.max_action) + + # train_envs = gym.make(args.task) + train_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + + # model + # perturbation network + net_a = MLP( + input_dim=args.state_dim + args.action_dim, + output_dim=args.action_dim, + hidden_sizes=args.hidden_sizes, + device=args.device, + ) + actor = Perturbation( + net_a, max_action=args.max_action, device=args.device, phi=args.phi + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + # vae + # output_dim = 0, so the last Module in the encoder is ReLU + vae_encoder = MLP( + input_dim=args.state_dim + args.action_dim, + hidden_sizes=args.vae_hidden_sizes, + device=args.device, + ) + if not args.latent_dim: + args.latent_dim = args.action_dim * 2 + vae_decoder = MLP( + input_dim=args.state_dim + args.latent_dim, + output_dim=args.action_dim, + hidden_sizes=args.vae_hidden_sizes, + device=args.device, + ) + vae = VAE( + vae_encoder, + vae_decoder, + hidden_dim=args.vae_hidden_sizes[-1], + latent_dim=args.latent_dim, + max_action=args.max_action, + device=args.device, + ).to(args.device) + vae_optim = torch.optim.Adam(vae.parameters()) + + policy = BCQPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + vae, + vae_optim, + device=args.device, + gamma=args.gamma, + tau=args.tau, + lmbda=args.lmbda, + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + if args.training_num > 1: + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(args.buffer_size) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + train_collector.collect(n_step=args.start_timesteps, random=True) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq' + log_path = os.path.join(args.logdir, args.task, 'bcq', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def watch(): + if args.resume_path is None: + args.resume_path = os.path.join(log_path, 'policy.pth') + + policy.load_state_dict( + torch.load(args.resume_path, map_location=torch.device('cpu')) + ) + policy.eval() + collector = Collector(policy, env) + collector.collect(n_episode=1, render=1 / 35) + + if not args.watch: + dataset = d4rl.qlearning_dataset(env) + dataset_size = dataset['rewards'].size + + print("dataset_size", dataset_size) + replay_buffer = ReplayBuffer(dataset_size) + + for i in range(dataset_size): + replay_buffer.add( + Batch( + obs=dataset['observations'][i], + act=dataset['actions'][i], + rew=dataset['rewards'][i], + done=dataset['terminals'][i], + obs_next=dataset['next_observations'][i], + ) + ) + print("dataset loaded") + # trainer + result = offline_trainer( + policy, + replay_buffer, + test_collector, + args.epoch, + args.step_per_epoch, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + ) + pprint.pprint(result) + else: + watch() + + # Let's watch its performance! + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + + +if __name__ == '__main__': + test_bcq() diff --git a/examples/offline/results/bcq/halfcheetah-expert-v1_reward.png b/examples/offline/results/bcq/halfcheetah-expert-v1_reward.png new file mode 100644 index 0000000000000000000000000000000000000000..5afa6a3adc97e5b3db6c1b5253b7d4c0d25178e1 GIT binary patch literal 56131 zcmb@uWl$W^7PdRM2DjiA+#v*a34}m!g1fsrA;AX;79h9;5AN;+cXyZIHaK6;x%E|@ zoL{$YU1}(psh;V**Iw&=pS8Oo{F9>WD^y}s5D4^2?!A->2!tR90>Nq`BLc6KSnU8j2yl?F zm712Tq_dfktCfR2rJ9wU8So+q#KpnIxzkv~$+?>r#>KIl^ow#gIdeBHb0_#WCGcWW z9Vh=zSO^#IUO`+qF*+OwLlx|D|7yRJwXQ`^k;giRzP5Qu+YN5s!1iYjF2cjO-C zfLvYqJwJQiA9xq%7dJH(_+acl__yx<8g|tbR==LU^vxW*6lOVa6?pZP68XO#VVn_e z4J2Y%@Z$gV;7t4eUjDzHi(~a9he)aX*V}-P=;NV{!}<4Z`Bbq9|MTgu!t-*!<1=ey z6%|=Er2h9w#IXwC-{8x6oGgdSl6%2-8amELw>xdg4F7zK_Vo1h#tQD=qjOW^zU!Y7JqLvEba`<=Vi~+PbUs zbG74wS&~tA$g2nw_;2FN0WVKl6^1ao!ztg!{BNq~Ypmjnwp{c*HxLpM6E||eY=h+0 zm6h1cnq@sv?^eT7xGjuFe!k6LbUm(}+vLu?i`={weYqQ(n4b?D9UVP7UG;3V;8W4hRq}YrWA9lAzM{Ij$`%E9+aV zwQ<`1gRStB$EqSG{GS{_2K{hFS;3(IVgwzK=jF&dPWxp(PzMr)?-%Cs798lJMc^}o zKt*b(nNLR+v5EziR#S9W5Xe?(^=Q|}cPl~7r)?x4EO2jyq3)+?Q-c7OCY`Glmo>CHh1@1L&`q=D#+*M^ZRN;MX&7$8z)Y{sb z{LX>AxTM5slz(r1LY4vnwCsC1Vrga7*%w8EjT#Kiyp2KT@*Xmex;hs}W1~&2}A6J{JS@mB5%lVn9fKuzZFM?T&pC`8NrIxwrY{-Z`wI zgL}V%*_x2a$jJ8h_r)|d2}ea=e2*5Ykpjg)?#*P11-QUaC6tsfK^+^8>xr8Borg0e zAJZ*Xn%pccEjwDgE)gG(frt#E=9w?K0`tsiX(1K#IDrSX+#i-OFf(_HzFdlK&1f1( zQ^$%4r1vd3v`3+_w_t(R&ppq82{FK{fcT!SrY0sQxw7;!AGhi8U;K9xH96c4r%T09 zK_Z!$1)2U$MLz$El$5Ps7Kl%H#f$Ud2$Ue&ks&x!jSc4m$6=8 ze2(Y6geG?vJ3vP8x@EZio!tN=K_x8vn^pYW-`>Z&%WCR_%$G+C7Oe_UlD;>iKah2s z$MtJ6Zf5)7H8qAn{hPBT{gYl{ zu&0FjB~j+ z_D-n(;j}&G8xp25kkZn2Q^AD#VGk$Gn7~q$w_Ko}H2r1c;o&jti+t?~h!dyTARZrS zbScD|E2Zap&LAdNHfED#?G|x&G#xE-EL6Qj!)WLSYyRSnbEAUZ|EOs&=Qg#KpP#Tq zHWpOA78kWg#;&ffv#YCd#kWqIQX>_<|Jq7~tei{SMe)VCr_e)lFPIPxMDDeX-D|Ptx%H}4 zubBXJmel@;kfiI3!t1b__Vm!;w8i0dVJbz9-MvL0FK0Vn(OE-Sw7?6~F3XKzAHV(O>GV9Y+!vvK)tkAvxcCf^6^}L?Zgo8-3nPvwaZ+w@9zpm zt8`csfsV=2-o9^7^uZMvRum_ss5#k1V?$V{~;vt?IFe;xqs0LK9>Fe zB}o1ENa4jJt=`j;(XKt;)So-N+&FZvM7DSyRrUip`M6<^_|MWtBFi(y>3giW-Hv*+ zsSc{58{S_XXf?arO+-UxbsVw*M*02wH{3G;IwEv=eK*hy7#8pTnax)s?LYCZSS@E= zXvv%=h~~q|oa+vNH6FNn6KVgl{U3GdOOT>{I_t)SlFgh!02!dhq{Jj76B83KAkV#Y zkN;o_PzTT$pRb(y6K**2J*)`(+^oA~K9HJ^riBBQzcoOgDP>|pgFzvJ4$^9Leq-OT z5dyOZt(zzSP4ZM2m;E1(wOTx#ZXB9DPWu3hiQ*pRgM*eil%Sa5hv_qYaR~KXvCYRa zqK4SpNp~fH@3cjRq&m*TvqtrIYWOZdS_3nm?0Dn9V;Mky+2rW&cFtbuuJt0 z{WLK`u%NJBD z@(wV28!`x5Qe|!T)@d%I{GJ=ZgpU1W9_!hmc;L8bG|COZQ1!z`Oz+->4uvg5$D%eqz~j+QT75%i^hs}mr% zD1Mh+Cuv*Zb6M1}1F*U?HdAwE7nkkFiIiu7&Qza6s6 z$}r++JMHVc?d#z{q!>uHym#)xKyLz#C7|;7#|>ALADjT@H+^~9&yCsaBXJl;ZLVK- zM$`8_mj)ag1Dbmq?#<`BsScY{na?~>b%x?C5Kc*e#{fWt+7SS~SRf9M6WvmSc5*0e zgA4(Xx&>w%=p#qIVtNF{<2q7>7E7|A2OBiBZ*69i zeSiLwyH3`sII(J&0E~DFMW;Hnz4n2eNOKTLu@FcCg$-I?@BCI+UOtAJ2dbhr%RX9d zB?)*~S_8--?dTLI9RBd|5Cj4`gMOf`s5I<^g_1JVPzl>_drSWmOAC?r1BegBC+Ci} zp3~kfCn&fgVUqU%j9T3uBa@pf?8AHHF8Cj&ov(xgAgljxd3!XJnE_(zhuZJzs{Zus ztof1%KqE)MF+g}RqsfIM00J=N2)i5=RJwcwIj(oWJg%%gk+BebxbeuD|L{+DY?4l; zK7>Ig5Ck;3>o-dd{bK>oF1NlS&<1gVG@Sta=I<+Q1t=jv6)R%RDZ8Srja;YBwv)u6 zC6G+eeSPiaY3&?}M++92BdA5z&}AANln3}}?5yR$l~|($-(%IQfLfO6ODU2yAQ&0G z7ji8ZgUnDdCMG6!9C(vtD!G~Z^Jf+89yBK16wg@DX4&!iZm%22Spc+$y`E&gvj1xg zZt~hI`d9k!TK*q_MgSNMHh|k8EOf-`fJ|OHVr^~hD(MW+ZWMpA%V;tu9EJZa6`fK> zWRYrt@!@n4uy*~o=UXr!o7vJ%Kz^Vn+Zq)-?E$KKd}Sq_)jLmwyhGx+_v=S*K2i+8 zl<^-Rr(+pD(DxPKi}}c4r9tR2f~zzK4ghhREZY zq|o#2mM1Vo>?j!+5SWl%24D^Q)@u!D5zaTd&_ltWDZ0_`!5sk2*|Riq|3MR=UjqUG1kmaJpa2d8G;}Z(T_0Il zSuO4CdV%<67s@;IylcPvDD?7p`JV#o4#nK6ozw3ICSKM1l^YOg@KivqS zz9*G^^D!@g(=e4<=m25Hh(7bRJ*=Gm2mAmpuQVONfflR=^OVaApn)4_8{JLRJ9hOe zAxPw&AgCb&UEUW))%^AP_9s?y-!pL*y(T>J7^DwBSUQ1|!+r@*9x%m$5-~uo99EOG zrMit=e>)cdW&mt-OIGwA`)6vZvbuUObm{3W*`W68$QKFS_kh|Fka?BHy{OVK#rgSk zz;+CD;|Mg=?G3bhxxat2{&hkF0ZzdU$^FQ_m>@oxfqY;`*YyWq|z7qY_nxtX55g?B0I&d|AX>~u=_60<$4@d%8DfBhP**dAoU9ZU7Y}OWvokZ3oaFz5tME>!T$}&9FBg`bmp60#&#Hf=lo9y< z%$RTf;Xq>rNovM_Y7WbU7)S2k(-BSm|G9toZ%m=s%L)^UNzu2_9U9vF`dj|*->;y_ z3ABSk0E0W9ZG=+fl$6*2yN%oHxiKOV78ZMs@4o%%Ty6Dny4W7ru$MVNxev_s32nx2 zM7aX7j)(~r-(CZa;l%7LuR;`YZ(iJyKlut;4Y}zH8x;};X-7CNy}Wm{>Q~Q%*7#s7 zl*&CXS9UB6Jl8O2?E+~SoSsR<6-~BH#>i@#H>XmGbsXMVvgW9@CiewF_;SDjtw}uvsg^o!`UYR*q?PvAVl0uZoFKf`WqO zXoY-_|40il`3S?(CSVXKFTb7x%^x4VC?#?VVY5FaceG|3$8H9eQ9dmWQVN$4!mbZ$U<Lz*9(3eAMLxw~7N+yL?tOfdNG^Gfee@a&pWuBP6Te!^6H(rK(AnKpdn!0~ zo?c!~h)LAL0b%{7>aTafwY-A>^)~^j~qYcN{l6u_2LdIiVwg~wNYCDcFV;jDX zvi>8Nb(b*O0klQnb2RV^>?Ea`P-t-*`<=;&$zN4FRTG_ZP-r>}D=Al?T)|W_e0n$V zSdEI5%#l33kPNx~GU&88mkbxE&<M~Ntu-y(O{2+&Rjvn?O4~Vuo!{)-{q)?LFVYTCJTB@J9t)A=hzyGC z*E`u^lCz~ccH|t3H^t!%AS_ixTA#42ub;flIP*FDO(cj$$EgtCuH#PsNiXS~c4V8L zOun;?2>wZdXT3|cGU$V_6#GvtT1c~r%rH)Hv1weCjd`ETgkX7!! z=#*#mTk=HdBrSM8uOPqGW3uH&FU2`sa+QIgi?vE*8;=oq6C=(<(^G>o$D2>04@ITw z(MEyfymR19YfWauT8!t#nMu2p_DxSW~wezs#<6_g#0+vy=5N z$hZ{QQLc*K3KC-s1~6ZnTPW@cQHz8!I8J_0e>DzXEk<+k{RF8yiFg_;)J=-=j>tagtLkR2t3lK@ zsR_TYbwwM78I7K6`LBKKNzAs^`z^+mp`%4Aw;x=5(HpvA)dJc!HsDM4sl>wI=pI!S z;ezM9HeP0M+_q(G(Zg=cAe(~#9m`_?WD}4n1xcU15AwF>e+>nqM_7q?D zSwo8P5~IQGMc4fSJZ99Z$~NWCNiXw><^3~`I4A>X-VVVRpM6T^wR#YKiYYlE)TtGF z_Tz*HB`3w+`!USifp|}fTR~B)=~8Le>IFJ4+{T5%1?~O10eG$NhXxlOa86ad4P#uT z#NQBb)^$%)n{c7H)}Y&0(~u^d25?!mETm%XS~*DfrSf#Ce3||@OOCmz{G&esO(w3V zk3{aReK6cb+U(MR6q!Nd5q~Px&W%KbnQiC^@6+W^aT`f09FU<3Exw)D$-pvT4^i`U z6W|2M-^Fl+?Om!lF{2w3KFTHZ*5jj$A{rMQCwIL^{l()iLqof|1xvc^*sARfLc^;# z{TaHc>Am^oXbdGpDK?SM6z_*5yRcH?Mg+YD>+w#~cxo68K$#3+_g3#R-@B5O$Z)9bS^B z_;Q45&P#&VxM5>*X&{<=*$NZK>Bj>{k}Q@(z(--7!ZQfnwnYHieU!8c8&(f>ADYgD zj(MJ^L7{Z4qgX&1vYC`hb9#7iOL?A(vwu&9yUVKRW=>tczJ6LCwdBJI;#1XnHa#yB z!7m9ulh6E@d^T1SqHPnfDtPEHUa{mC+}XXQ>wu3EY~>pHo}QY&B+|M^Sr_+LOMm7X zXz{pjs_6~!4G&sB5<=B!K`cBYnC(}Qs|!yE<2IruP$x)*nQx|ASg2+-GeY_8BtSae z&Xxf;YbmsyDjnBz@{W+mIs9xnp4@N^ZvGZ)-(-w4De30DSVVboV~P`p7flJ zTJ5PI#VZ(UAH}Fw0$phtDAwL*jRem_n5<35gsVx!#KrT8cyfKM4oc zg2CmF{E5z63N~B0^(Lm?&zF!MdkWnLAVRKUk9v;YsO{}JG!lHtnYiHg4~hdu^_Wpf zBK0vLiy1MpiAFb@p4J|)1LH;Mmp>`Fpe3ENfqk*Zo<+NPVc$UTv_cb*t%AG>8wGZ4 zByisV;kv(nPqr_eGS<~rp!eKT9^qx@TE`*j+IuiP_p)E zszK`5?^RWm`kn=Ryv=HJW**eqg0C?$F_S5;uHT2THo-h0n8skq(1kF#IOGj3E|V`5 z(fWk4&UW=fvI5RFs?5jYuc>F`Bn=b4{0uLRv*RklBk_Dz;r<%@Q$e&|>T4D9=FOmJ zBUWYsL0nH+QO*xN`T$M&HyZ9?r8)oiUVt#a$*uU-w3>B4ctbzkG0C`qzT?Ep6WGQ} zU$At8Z;Ettvx3*M(0%?}(l0Z-K|^0X)d>DbT^(^aeJfy@qVHESl3*n47u43ubo8F@ zlV?+w)>A;{z<9HR4SM9)n6A2__Z8YYP0`3KEF7>iXc9V-O_#{Y2@&2JZccUf3H&)| zdx-5?yOTtc*!_?(7600V+c-1j`$N!oy3iO*{I|BzJJM-`gSD;i`P~atbZ1aa%8TYQ zmwG)SMY%-~En8EUzwwZ(?Ds5Ba7>5WN z9PM{X^M~53=k(AYP5xCsR|gNvsOQfF;PwV@P-n}Rces!gW1nU(GStf}z$b9YSbC&7 zQKwXI)C`KEDoblC(C z;?5?ei*f(??ONGv;OcE}2>S(SR>>EUoP!N>oFhFPXER4|yuYUaHwCrFPF4#;tUY|N z`_O;1^;tms%MbId=b!R}U+cnF*nDti0~GYbA~uG+Q-6GW(|oInBBH$&L)7=mGDZ;a z81}s0668yvHSMs?@BQ@cfL3Q94d0BsN8abdNNvYeG3Xy39!@M#;#BlB-uNd?QrxK3*KUV| zl}~9Zqq#XFYW#{BkY3#>8ieQJV~nk_l)$}_-Tq?{{#z5+`?@zBt283A7UjKcYnQqv zw3s98d8jSkCxCH7EGuews_&L8r?5YbE?9T2vCSKKZ+!^Bbi81A2;XUV&bZmsB@s^m zUC_P>`Qa!w`buj}>;4kqsSlr0nkF{hp^R8$!nr$`olRA4-IM#_rPDLMFBi_GG|}55 zTIsAZ%)d<1f5q3Rqt*Hfl`mWxsY&}Y(X{2jo+fGy0aarQ)~0gZ87R4u&U@hVn_ams zVYEgh>9S7+Kua0Z*cI9(euFM7vc(&L*sOn#z^OJ{a_-R#)Z)1(UHj*mj?)rhsYvIu zBEscBg|hmMqkzYB0%jGP9bUX{s@guR)O4@ct~J{?3^;_TzD5je?>`*J#tqiQ7_d`t zUv&vy6Z~o{B&7!{3EN%h#>)?$#;mY`31$_P861LHmbZT*-L~?mHiYKxB*r<+XTJ8t zUo*_S(8YhdyBk%d0F&UEpjRM(NgLqN-MO5&&5ia@GXfK7V>lK}!Jfi%gMRooZplD7 z-si{(Ie=N+DBN>*@tDm_`tN(ob-ECNNS&cnHFVA0HFjoAdmgK6No7)%mlZuapT5_- zbHjPk%A#wX+o3$)eu!xU7Lyng8mPbXgePs}(A!~CTxe#Z&G#q6kB1gQ#ztA(FYQLe z7hrHj@Ur>U-~&TRYN_AvxZfN;BTPFPNhOEdf6z|kN@PkbF|vLVPANzg6~Z*p>`I(S zFC_(x{yZPuyHxfN{A)rb>f`a#gE3~pSuEfiH#QtjN*EZWfHOv-p3pLxE}*_7v3ANB zVX$Q9uQh+ig$m8cGo8N65X_U!PcCq4?yISJ;S%z(HW%0OmTdf2F#d@ZQ%>lG>F5fn zXuiHMOSVk-8aDY3YE_(E1h1bZM5#CXQnb(5@Z9_-C0co;Jt;K1$pmh1P3x=AvrE&S zGJMX6N0xkBahWM|#3Q*jw~Rc1_689Wd3U`%xG*xtZWZ~NcSHNCOf)lUt2Amuvn&(E z7+#E;QCV9>$d!)c(=>zI1e!>S`rHA@&m5<|?&U7=OFJ-fNP}i#>L}9y73UpW(3Ud4 zuQr-+ZoSqSLb86DBV3V@k{Tw z^UL~1u5p=G<%qG!q_4D!6U^(FAp{CfQJ8rN{iqejfH6yQtZsjp+03ZbOvf+b)ter- zbYD_VQ>B%zHmMHoGP*O0c=Y+BS#RP{=yavJFD?gaDL0Lhc$-O$4fvyl(Vl&4mc-En^4^ul^Nh<(&!60V5>utL(Mn*G-2*RMc- zB3-npnb+Wt@uhU-BqmicNC*Sbu;m-I%uM%gn-li%C5wJbC~DM>&_;B5n`EoUA^5M7 zvt9I4!#%ui4;8w8Pi>wyrLi35l3b%RJpXRR&Y6l|Oy>`7Pfy6rR>i9`c}2L2{mkwi z7`pVbM?D_&8j6vYZtb;?a8sh3XLGNnzT`e^^B>*%ADBGnzK~OXs)5hfGKa*OI>$F2>-PV=AKTtHffQCnx5Mk`d z=3vn#m-p5^Wg!f2ERpxKR4ifprFzQR@S89kU?)&n;9kH?Tmo!a$OHSCdQ>9nPAY3B z{D<8)thP8MHjZebvWPZqfp&jRk>3%F>Ehg(ACI~VA?KBBhQ9=vySiehEd9^Efp<>KYFl!D*@7tW5A5j*u zST+_+y+NG*euouO(dt}tFNhyJQmdK-2wSBvU9pwFV+HG0t4$B0{b^%!kGb1qWU4C9A0%ti195RYN`;*6_R;me(orOKfnjwiDR3<@^-nAhC-|KO70i z=u|n(Ox$mBp9t>dr^$>8S0VU;C7jpbczQngwUFU3yiQy?>gbr>JD&L09=ER_>(C=@D%Cb8g5)T0z-QLB`UKB&qqYt<=ei|FSBaW`L|~t zIeU5gZx<$Q)1v!8$H6)n;#@ey`@KaEdTfr`HSM`RX+n$LMaA5j>y=xc_D-^r=XJX-C!-T`5@a?OzjZRfyu~%0sHhii593)4NV9imvG>pl zf}|CIVdwesGm?%KdCM<^&a^M<8UmHyH?-~_3oF&idES}UZ*rSLFDkb`S+KVvs~>JlJ$W}{Xv zmj2lL+@d6E1_R;K2pv6V8f|3C0#nTjv7+=J;|q-kbN>GBn`d1dgpimVI#kG@TIMCj zJo8O|gItU;s1U8bp{2CTS%m6A=FIU^D43P*!@;jH7cRQ^dzW97O!;=ijR(_-U^~T| zjN(0XZb&ei+Vjb9vi8?IS}GRL>WukKSV5C1+8|i(W!~{c4t!Os!ZkV8N)c=e@=)W< z$sx|6cly((u354t<%f9Ue}oE}oOTI!jAtw`D17oFEf@&H_+_4^%#->rzh?@+s72^i zw=-`YzviK+OMlY+DUHD=JdDPdn9zz?dpl41qdG41qep$E*V%ZeLLTZB0sFh0*YKR2 z33b%^9>h#cOwPV2&dP0;VYNw+CDX4cSZ?w2Ek=p+wRJqz9x;&0o7_Q1w=aHuI=#EJ z1Bk;SOSj9oB%l4&4YS0O!H#09ro}}TkPM#%?T$(P_mk`0OJ3fkcbXL*#l?yQRbJ0a zrYfQl%=;Fn5$)h#a{^aXI6lv)Isx^LtSSB_^jxPtNb^l0E4eN|KrUlRW3`a-Qp_7S zZW8gUeYBg(%1&bh^XGKi{*Qt-j*z@hx8^pL635a}jbYjf1Kyj?u%xQo-Xt>WOx6L)h8^_X|<_A0U(L<=O17XugPpZ_j8 z?@*GRd1bhle60ANFW$FDfIQ8_naDf4@z6U{_3w? z{me~aHr3Y$hrxPYQ7w*yE+FrcdDFq}^(O60Oh59GGifFRL&23kpave>lCO)mfcA9c zzA1Tw5KBnq*uC!0I#^q{TkdLEfj)87%borqWAVe7@r+`-;kNKE8Rp<`)S(9E=Q(>L z_%VZsnf=UbNxsfr=e?K(-_ICnN;P)4UtPoDGsRACaa*8Y%2)X5Ia7f97d(e1E`+ww z%1l+7G0&XThAtVWG5MDJz3>Wi8Fh?*z3r87d}whr{H^Q)IZQt~v_ul>dcEbaA;(q9_iN$fhrNUS6u12s z`;y)z#`yPF-osa_oHIch?aJV5>j#;^ZU1Q!LfwBM3K1Oprh6N0go=PLj%WEb z0<*Vs=4i(}P)u7$)Sq0M+B#;Owvkx~(3lhjN0z;oqAR>U#F>hRzl?6E*yCk*!T8+J zV^jsWbbgYG#3{)v7M;zDV=4Z?tXcAT8e1MWzru%R_mMq>th9DvFcY{ips6{R8grYk z?DdMoVK)HrcYib${*_4tD-2KjkksJ`?|O$1Tsyn>r+QM*?DN;B?V-OU->V%oVvyLh zG1C03tY)!%tJk!17RT9QtnHks17yKf&b)1J;wxk%zc9z|N)8A9wwQ8k#x<;qwNwhr z;mdMR-9aOi_ad(QB)Vi+`qPk&*jfxmnygaUb$Wbp9zNlwmN<72B=ervYwOeTK0L@> zP$e3cIs(}k3rKr(VYV19p*o^?o_Dt(b7T!A!~Pv|bZyx0i>Wd7BZY>M$b``~Le)|| zz0r%QJc*`#`3<+{w~-#mEBbdooxPd~VB6WqxsntI^)rnX#Vd2lW@|&>Y^itS^%CVt zcfkl5)#}WU1IDm7IaD|iePJ5pQWHc279JU3^G%Ez&$K4Fqu1xXO+P--phm4a?;a;! za}DC5;7FR>apSGM2`He9+hvfV&C#diUy(Nog!r)nE?`v=u3uuz#vV}M&sOf~x*qLL zl>wvOzcWwjtiGWob#sf?{o~>_&9UO{>%1gP?bb>dhSSU@&xCz^i+b*I3k<>&mmMyA z7u}-MYW28&o|8t#Xp(Q(axlE3X9(@IuYF>juT(+{FPkD5scKLm`#WjdVX#Yvk|(t) zS(Z+QNiAA?66$5fbTXv#2#ENmbIy&TNOgbFyzABwW;f3ib8K96vBbL_NIp zWl*&vshU^4;s&dawl?muy0Z%^RgBxbeKoc4bftlT-Z(qBW<}|a@jCIVh2xl9mURo) zl~xoVg_X3${^3oslAB>+qdyxjkK2_tgVD8INu|a`zJBHG{I;% zq%|aw#et)Z7AzIo?VN`^maYsJM#r1D%^TI+Hg%T4!JY3A2{KkL_s};uLA{5ddoK5g zZZjc5@HHK*-_cW*eCg?#KjW@&l?j~cS z!#YEeQ>Y?4QuV$lY`xW|S+7U`giJ2JJ*d0LiKw+(GdMuMOZ?EKngOe&Wvzp%?@ zK;<%vX7=`r2q{|in$g$#(gq=5NA8LuGf;QJZH*CQspK#mBC2}S!-~TGjSBfTFjj~$ z1Cu~yqA|eNuhKq<78_c6Dg8mScJ?Byktrj)7N)D_ZtlfD#K6+6{h5l zAs>9rt(88XXT+qIKKg&CZlvaKB>1RH`w{hc8{IwIta?aaU{(Ic;iR|l4|XBy+(+aC z_C6(Zw2D@Tq3obQ{W|`R5^Einj0*`_X{(9Ol_*WbPi?AsAvF@ivGD^V62~Zu!vK21q~D*s5v+q1 z=cb2TR;{=fh=9w+PD;~F1Gy!r)kC=~3f9=XZ(9YZ9#)+j2mSaugd*MiokZ|?^^IkW|+1lQB*_EyxFH;-nsFS5q43#W2{0fn6 zaV53bQuvH;sjnWd01|u$_pCs%c8BvY;gdPq0@AwWznsfm1+l$RcdMdN;N#;!NC?vq zu`=@8-~S5|N~z51>v4~5z}Qj_H7@i&mRgwN*VI=Ot`TzjonnIIy*+Kyn$AL&qCz#QkK{iQ9@(m*-M zgagYU+15)fzWis-do-s&@6Q>;APZZ2qq0obZlKQiDv!J$nx+CHC%+lI_ay78szTnanSF|cVJqN5=OEHz1vy!&@Hmh%9cFMR6m!&8ZW#+AAjV}~{n+i9QrxN=W ze$D4oLDn}n)v1h=`;mODv6pVwj~5%I?Nb?cFN~RVlXv%2HFKk77n&!$Ncoy93^UAE zkwA3*_uLt${NUwn9FE?05_R@DCOTX=oZ?sL19YK7B`y7nEx5q#{C!bVDdpAuQPn1b z%Y(tbRAYntq>o-sbi4%gwuzyK)?Nk+>iSng+uFWuxpI^{VT}!ULgs3G%6RCtlXzhV z_VQ5z5^}?6dc%%Rsd{yKlEm1IjXA`_+dl{=a1tBr#yQBe=<*ELy|%A+KiQ8(ruLL& zw>f2bSTYSvvZs4dkr1w}V6-Hj*UW)wEmsj*+Tj{@6hKyMf8~j)qWl#!)dd)ZQZhR6Q5_;OPx`w9pwL8fniAWVn4w%LmeVw~DoP$1y z4Dq&?kKbgdi-?|!k_wdZIEL%eH5Krq)T&o=t)~%0BLeryN=9}5^bW_%YCX81bPEZg za*U=TG6=VNdO9**(C$dG48pfk*LTM$bi`@w{MmN=CT|n^ZnU>eoMqIrFj^yEri9i( zO|R9R8dtVzwJ32hd)inl&ayD&8Ve82%dzG@Q`;18?AIP^?&w;|d1NglL+U4*Knra3 zZAN&A5tXHS0RCxtGRcrm&qpkpnv2Z{gwH$(2o`jqsevB9srwT|tUY^Solg~k_9Ysf~CEw%YEJ4KjpK3{LvRwGyn z4#Lv|y&)ovG%_=;3*|-es4fNWU+c-O69cV8dWF|Z0Et-U_`m}cqQ;+YU$ z_7)e3UMiei4>0L5=^?Yz(scF#ND`sy#5Zt3KV3GP1}fDHaO*a z4(O-$CSw&|Ni6;ZG?PY`@wSoeyQ7c18#>&KiAU~rDOs`~s-5*QkW5pjtj)T-HJJs? z`8udS3W*pFvh4I_VrwhQc+aOhGrbZ!m?@H3p}(G#Ilm5FF)YRJ|5({-+@r&5Z2Y~; ziVq}odD`(H_?;E^@zAC8<+&Szrj#y;MMoK%KtwJi@>xoCuVaW%*>PMT`g;~hcH~9} zl1q%k`W}qgNkz2D;P(9C;i5TqtK41VWsxESwf9^&LLs6g+o@KnUKvx7`lEDEsBELq zF1(?1@A=IJp&YI1u`WaiXe%Aca>sU3jVcJdSfFZNIJ&99c?qXU@OP{DP zFk2-<60x+FicBPrPFR?#237kzZ*msn%#Dq}$4hTCw2XFcY>NydAtlCw9GR1L-sYoY ztzt5l-{1YY4}?&~ujUIgI?CdGJ|Orkz&2JR>>rrE+?Xq$WIInxv zTr`r+za8hMP5BOPfnkvno~94&`zbw?8{OS2chYN#Hw#>7>L8?iT!8DyI~aUk^8S8J zQu#hx>&!;?+j%k_!H%-ekP{}g7+Ve9E+`tFo9YzQmzj?)%l=lTF(Ftattu7oM?=3) zwsZlXH@V)jpv}U&WDm6u)hum)QKnX2v+wvy%=)PxrRbC{Ha@G!Wwj^SC5(u?^AKgh zbE1)t(;_Pa>~nna^M^c~WY4fA@hnvkZfi z)5IU#x=)e3Ixco9CH>vHQV6Mi-DpS$zA+HdT!?05H1=f&oNXj-FqglcJ^N*c8tzr~ zX-gAYm+5!HD*wnd&Yvmol1#ewKj(N77V`cHzgK&pu@K zb=l>?KjsRXIO6WBRid2~nJ?ZizsI@!acD$0cZ&O>)I9!q2!?Js;Pha_fE}lR1+eh$ zZ$L*Xkel5Z90@iKBVPt>WJEMd(Q~$h()aK7;^65QdWDcAs&kn1+e%9a_vra6FO#Ah zi{rEge;Mj7PTAGAJbx%4G%pX?YuvtLeh|2puFs|AcXWDz=hKTv3WD`Z#zGgXS^z1f6kj<1w98V9Hqj?~HKm0e>DzkIz**Gh&_ex|o`!@_24JThw9pr`~t-`v7K-7&P}Y<8!_(_8?CmS2xjN_H0>pquxIJL!gfAUv)LLm#i|Y4 z{ZN`W{CCG1P_eP4{z^ndRK@Yj`cdzg2z7Q#)wtqk=0h>i)lo!766|Qt4i6SF_%iXoflF_$-n+{~^tya{5fd-Q3v=xFLuJA!92| zHR6@jwK+4rrih9d|LE&*Nm9BC6do z;r;KGhCkGK3)xJ?3foR;(zr>n)_dsl2FTl~xyMiC;=y`N!)BTQH;M3v5#kCteF?9j zkE!@*?9Wh54g8}8$5{z?At(O`Y3>P*m7~=~QfJfLMZ}$_uaN5#-q|v9@gw`>QIU6* zvDZ_K)eCz9(Qx!Jm5}v6p+}@2?8l?Ilyl9dPM337zdLt{Yxo{5I_!yvh~sq>;>a~O zHB~xoeb8#R$IZB3DPKQND05zIfBx&CtfTYI=cKXKbdTtgPq=Dt(eM-Ip>|kCl5Y|g z2XEE%K;?taJS;B>Vo9aesPyMX|2N0J*cb1$ubJ?FN-I~$(etF3xMwGpCzaOAV{l_! zRax-r>a({fT|@q?5=wUvzk(qTt&{yq)JIRuj#}P@5CP{tEqL{*=auxcCy`*gLP0{APAHk3E$ zjVh28-6HZ^L^#6ckt?^~Oc3B|UO1E7h3{){g|y+Zl-Wd=?&{MoF{+5-T}mTWTD|x7 zO`VkX$-Akeq3okGU|0P-g_3#9tYHgF87-!Aqf@1#b5HJIU77FmiI{B)8-KEE5Ioh8 zAgfc%99+beW)cyXjBGqASW~bt_g1cme&j~gi6*;^;G2|NJ=RU-U)3QaC?m#vUnX(K z|Gb9(WY6#*m~uOKnzAx#oA=(*Fk1fYCcjcD%JLmw#ofD(9|hZME)SBpGrnWTA*F=^ ztquH<>Gr?{1z3t>_*vmT|N0xk`(x#-OCR#Y&ZVN!+*cJ&UuE5=l@LUDHSZu#r(Eab zYw6)ymOtK2H7UU^y@WoYM-10qyASqnrS`#15}AtV=)X#PMlswKx)bnGgr#^4UKm}W zPtht##L4j@6da8eK3WsxwX7NFD!aps0iD{yu{}h1ZKM{Mc)F;bjQLm4KrE1R#AKr0 zHyuvQctTb%&r-%%#3zR_Rw(Cd&ffrku$xmiOZg5Zmc(A`Fx=X{{HQS>V`eB+H(oev z*yhhe%xiX52Yj7v_xj4yh*ig6HhHkw5nA~5ug=|aTr^&M8&35C_96K(V>Ejm>xl?he5%xMh&w4DLFd<|Xg< z{q8;Yk9*JU1#8Xf?yB0gtGf16dso#HBq73rljhW>-tqah_BSQsd`NxBNOo|@oWL z$9-|!YK4BfHHbVGD3#T9;9Pl?KMv(XLyj^2hcRYsEcC8s@qmzA+-b(1W_mJ1?J)2*-k9;1PNH2Q)t^`IqVV33Z@EtUp)A&|QG?p{eb05< zdoB+;n`bI5goV!QM=asZm6e2kH!F`nt}#p-_k?pl6T%>StuBoHvKf+<+3}nt8+D4@ zSYM1IedB z7`xv&Aj`Rv|5x_OL=boODnc8KOsViKZ@zI@k_H-I?94fO`en_-F|wq(7So_C#*lo^Jz`dYkCO` zh1O6oFi54s+>|2n(D0;FNNAg`-gz5VEh-E<%$Ejq98bj4Yj9`S*(9Y#464^*ploSQ z_(<=@rI+^9mxjzz_dx_AR26z5yqHrWUWAffaU++G!%D#}0^kwI_!`lIYfMl{#9@5| z{~)KVMOx`l8bghd+Q^WQAR9d_o}Uf22eX3|87oQ94apjc3TV7mGuYlqhCq zB)@%om=Xo*-ySBF3^VU6P8WHH0sFWf)!xE{vlkajuko`azb#vlkzwfMzliycbm`sW zS;$V}NFyPyPv5eVD&77H1(9{@HbuJsQmc#TTjv^>vcuUNVK6y%KNlT4D{{EG^D&`Y z`ZtUB;Hr5;Mip?Z}(g zz8Cd>pVtbBuCSvSewk%QX6$ye{A~dOQq-fFS2l0jnGB%ICa7n4|Bfu$WYy8Y_QFfH zN#R2&`$kpL9CmWoyf9@9Zjm9 z(ky=5q44~)wiXzj#hn^%o+?EP#tJn`*TmOFerVebgFmv&C*ix79}gjoCF>275u&q?7RGilO?+!Si*psR=jU+nnnPs?dllQR6>V?#J={tpxk88Sq*KDOz(W~Rf%T`i>`~`1!T{n&_zpeOaFETO-khjamZRt9U zc7%c1z8Cah&S_G463J3K<%N?6#%I!suDd)_67UQw{Y+E;o@G^ER*pOz`ukZm9w8I3 zTwqr%#3L&sce+K?nd*2&YOI&=$M0xM9>2@TwkbSMtU^_r7Jo4xYma?_ECJ{nL^y?v~Kk8MXT-A_jm zEmXPgDR7x=&}MGJa(1#s+KF#5`15+-=bM7>_6l#yP>87D`HuF@k;a|UcJ3?kkk7c( z&YoLFqbM+3x- zmcOq0uxPr6Z(hsrZ?aD9M_vdZx<*W^_?zeH6e$%J~+)~ zIUkXPQBBX~MiYs~);SsWwp7CbO)>EKcRiv5x#%eIl0&6kA63a=ffn9Xhdxh->ZL3@ ziGhPfhU)3m0SW& z$A3~p;3Mv0`J(jwi6}N5B}=oKv0K&L!03+|5*m#bWkHV+ggM#Bm+4xJ#=qU4fr&6T zsESY6tv_iF9OEdKa1W+&<0z5zSc<((Rj8(@9A*vSs$VlQ9QxAsq?Ex=xI794LSY4! zPH{?kMQ%ll)m?7FbkRM%uX-c&;;AwGrt+}knNEeL?3YikMKf99XM5Kme=D%cq5ZiZ zhA{llZ^;0@mcHbaMEW~>lor^k`uILwMr|!_@VoGzP>PfVed--xZ!0G^ERXhAXOh(G z)Xgo4iXx3`v3#tQ)E(!(5+TZulb)$!sOOzii_byd0>*O-Hs`JzqK<;1cFAhrY5nQ9ohiXDD@1@bY}}^>0?B;G?;nyA23G$Yn#g zX0>B)U$lQuq@vB+ZD)q<x>w@tV+fkgi-pR6 zb_CkH^N00+4tN{E_5-oZ_vJaMmuJdN9O`wS1_YL9hFvgbAqxaC-Zuo}h<;WPV&8(2 z=0P_ljeLqfr4*y^BA&#KaEj*Z@G29J#Wx^JJF^u#MlMu_i3swGIvfnDsEVYIdmZ0% z-{8HUQR2#&Da$TIjyFXFJ~l@qeXMDY(Qj0+%btzVSZHvau))a1r(HL1N;dqxrxPqId7mCZvyxF7@J)rN>H zyS{FMW6^x4X1)9?LT~2a3SY{`n+BM9mms?lzYRFbxOFH8A#v;}XeoRb(&sLo+!;?)dtNU4TPcwtdC)p-0a zGeeVipNRNc+aCKu2oLguRNV%ZE{L#3)~||3VW6fH8-q;AXu*2`XD^^@7m=OteW)G@eGM3`P-TK2+?hrKjnI8f6UCP zmSpTSxt3Ds&x`!o7R*w9ITgw(M7-ezv+CR^w>s24h7Ywmcf4_N?IdP_nVm))MDv{U z@!@Z!*(jh~h#OBsp{RvJ)?Nb0{e~DE5y_D8%sSGVNFNBrwF}wGw%z6mCTanA^Bac+ul)hRp zzxOHY?vUi&`Q8(g<*Uh5A(G%!=bX+R7d`!P4O`MP(4nYS$0)RtmZou-Kr+P35Uj;n zO7-Gtn{{>kj!*O}?6SaXgu<_N?Yu?tl#oHessxQyv(~wf*kzPilCwpUA-7B~2Zjx# z43L%g*v*77_i*=`u>>_b7sS2EJfLl(FU+udat)BgFPij`FUVISDhdaNxp8QOv?Okm#)y zy|a*3y1`<0#qHOTI3dV&97dd{d0WDz@#esi94F0M^#vIw4>cFp78!kPK^!Bgg+iJ( zghbpA;)qTDvY#$)eRsFdh-AW`qzO-c3^@O`QYd2^SVp7{Yhz=Gdt{7xm34?u7eIXfd<_r!HC98w(WmX51lZo_u zIDiWo1Q<*6h|5f`_8@c=EYtcggT=FfBgL{8+?Y@gfBg0X;aCfc?#VFXeCaaI+)%ad zAJ7!K{J7x}yjBj@xmHvclD|wr6cnN7NSjbl3ux^zBW!MvMvr4T6WFfmjq81Lt%&#S z8DbDVKg}z|X#D;0%p%*7k)eH^)bQObi)kduym8=6uN?36kIWhTG0DEV#lCa#Li}xa zoOB~DGD$+*Z6#q}SKqwX6Qd>$|21N5y;d#Ta?AM2*oHlvua?$6M*S}MtG_BWJw*ik z720E~--OGW#e!UkZ!tBul*T_IO@ zxB`lRkmsjKHdY@5U6Na5^tCplP_8}f!h~Weql5n3e;L~4G!}BM4BlLJYBhfg|1ZeoVgI z6{jqh?1F>5S2;NBMDkO(PgJnXAzl?a&uFdln0X%4?#l^XR*Un0D6PMl$$OR6Q)IkD zAdbnAuIr%vPE^FNR-h&Pr!nQbEya1sS8Ei+&UFJP$8H2seN5J){k!;mt3DI`T1+7p zAJLdN6>?sRci|xL()gY&QtzZhMF!2GbsD)RF)dKywDjWwp+Rid^l>ld5Te{9IMW$N zWkCOgo3Rk`@Ka+fhyHJn0bxHXU|kygaNRGU(@26ouKRS_&&pqEomL*(aSx;dvuI^KNWLjc=2nxO6<`UXuMrei>Ld>3HJQLC7TUd&*|7Q z8sSPZy=6|9d$OP)aU@D^FoE0yT5IHte#aJr9NT-$_6;kAKurtIV@x z5vLLhyH1&t|M@LEQ3K|=v&3d5_%Qfk=6Uy)1eSW&=d3@KnCZ{_YCh^EGCMF0MIfP7 zc~{u|(k#XH68elMDxM>G6||8)adW#7?}wAF1E0-GS-E*O+;xeYq7!QZmiqZ=HjG;e zy7~#z@#3!1eSg_;!aZ$Y*hj@+qJ~i`8XT~XiZ*B{O8phoKFew$n&#{;I5X=Q&5MrI z2b`xE!-$elc{uzP9D^3q25hEy6ayklq#<)dpOW7h+(RR4=Rm5&+hL>@M}MJJW2*<& zhDKjy9ZTd-OctvH$MF^@w;GJBMxly-b+t@E>d(2$>F;+hn`F3Rq8bj-2r>36mKpD z`UTF)R%#a-z`3tB=Q1@f?7quollm=`JvTvEdsM74WS5NZ(`w$(YJM9QX5(W|l7Mr8 zrnU_y(={U@&S{@SZ2wJa-N!RNVK4iX-Kww|mHSV%OL~%}fOy{H{vLe}ZAocyhQ4=$ zmT!2qxR2N<2%f!6D2`&jOgxcD7-Hm1*|Dvgwx~h{r?SDBbOjt=ydX?1X2Zgt_dIw` zgT;6A`DDsQoQIDp-rr@n!IpK7^DyiqEeC9P>0oIhP(QMQA2K%BpvoI}n!ymqxj)=u#LgKBgf;*ZSGiBV^1F-%4fgj~@KG5b@;eB#lkR z52?}a7}YBO`i;hyW_Q-;rqYeOR41D7JC^}%F%v6u-lfoUbUfFy6!vb=o9Y(xujlLG zZ`FD|!f)7wo;-onsBb};^GgRjeosxi$Y&mjO^O%iMul;5HdLxsYVvLPWJ5Pm-V=B2 zgz%;skpvd|1s(s8am(B@HGe&*=^K|xUwT@G+1*@o_C3*-of`&!Cx$5&*T!@CW)joT zzYj*pkFwnd3G4el7y}>u{wJLd2KS`bN{az^Y3r}MjG(XU3n`%_RaJWTYKkd0ZgYGA zj^eJ(vB<|F7*X$ca~-+cB~0!yKB-x=M6tDd$~w@eionlvFpq{8#{pZR{F0|JmgtU* z2i=tr8x#F>u>7+%Uu`!8PbiE=Rwtdym|AzRn@kO>eq?bN687uk?KVn#r2h zxAjIJe~oXJN9o-D_-@?F!&varW}0VWT?%Flo!#Y}U6tcV;-NFhwBwt?Zn3_xD;}|{ zp=2_PwDK9cY;83MpS%e8@-n((HUzx#Cw>-dqwU)WOA@bTrpT-phzB<&gwj}CX3)5O zv*L}oh?P7$XENKHLR?9UR*8nI%5G;;U$)vfYlWZ*Z?Bu%xd^!zx!Mf}I>mYx0`TBP zIbR(WyX?cc?^6XF@rp}6ck-d|gWE2H5v{^7LvJD1zBGH$s4a{q#<%<__WVIphor?! z-1NnbJ~{pN_`6?sg_LlAiWV`|(0}wpY-3EK(JlhxY$sD=WENd-h#lI>v?=@1$Z>MF zYih2iHgTs1o0Rv-xgO=t_L7d&%pAyHJ`Ww#e|>30uh{-nYT0)Er$B&tPqO7yi?rEZ zEH-Lv@fY`m!htXNA@ZC^jn$`+S+7b}#J6m7RIPga!uBFc9Jm{UJ+SPQ!K z)(PQ7Z6`>9EeH|z3WPbSjgo#OU-vEIXDYsa{rrS*-*G>G2faVU8N4)=1?zqJhEsMd zy6vOjM@TA@qN#4Fv?;L+lLZ>9_L|yPM~-|xfAXsp zYxmME*k|MYbM_Fploaq_b&0$*Ng*fe&R6A$uGFslge0=nqa(6}6{`^zI<_7(HkO2b z+>P>`XohIOr#vK3sAlOVz;WR$8j4MWmpoL~=KW@L30gThPPXf(?6rl9xw!r(LjHGg zlWOkRsV8lk-DtyA1ms*Vsd+dhs=gAc$;F)Bl(PtItYd|nDr+~0hBPNX_r;wcV3K6> zyF}_nH{(lfla`$_QNE?HtGku=B}=otO*rL< zmR=9+?25H1x=GnP^=vL4JonGETQd9#aj3a*R4Fp3FhA(|`8pc%(oaM=*C9i9JBdce zfTLqn6q7}(aWsJgue###8s_tm4}y{;VNw-e6PGRB7*(vS^)d@_rG4_ismUZV8gR?! z2v#{y#UUPkHQE*f|E>l2zL&cjvlNR}4`$Xf~sp;2WPD z4}I4#-o6ES{Csi=7N(1qFh#!Up10CwVmI1Kacy#8JY>tWi6zYQOC;8sQo-KyDN^W7 zrSTe`gfY3xD&hcGU!QGmqYHBjeZb_tTwX zUbn1MH;!4UVCLvy7ah0(M`O9#GV@gYo=Q@=OIDKj29nLdTjxCQ7I0(8pA??Hl+z{8 ztX3rsqJztnSv@TO#xUOqN?hhN_`Xl3BiR7Heo~M6l%vGZ21if@{lwKM=jf_PubCi1 z0%>%-@#hjp69Cre-2Vf(CkbKj#=H#Kl7(wFh z#MJdivg+%2*`(K1x9x^-!}H7KuxkDsD^}|FHkEhjee(UEwdtlT1h^2a*2-(TX*o@- zS}s4bE}xp|q$WO1F|6(Vyc1k>v;PruB?^(Y{3MI2|H9%x$6!+^2fY+kj?$Fa-KAhe z!`A8b>d{r$Vz!5$+cy>!I{A!^h<3qQw4=-G%r(4*VWK^1t4mugr1PX3ACU9cNUULt zaOZeNb_X8LS7-h8Y@cqy?H=-ukPIk^CRtO9Ml2o?GFYehm3g)egg8a@^H*< z**5{%;*#IxH_*v63NSZS@W0H*l3FL2Z1Pp|H@n_Vdy%L5yyL-3m*+|69ZquXHibPB ztb+HlV>~W?n|Zq9P#!t4Bj37@lE0z4=UNRGFC4|teRF8L?^ZeNkUnBdw|XoU6`<@_ zTApoO0ryupx82SPyqxtRP%#OIan8zb*=oBq2Jn9RZq?EvUnfILPv!m0y}lL>J@Q00 zti6f%TO+9H{_QTM?Uf-_Z=?ke+q%r<<5jGeGwWrd3G>4*87={kBp2-(iVFp|kz6stih+J6D<2IC|pc1;E z8~kMp_I8+Q3=&yF^*}T4CtEs&4BYQ+)>#fbe*6j9YI&DLRs(EXvR>SXT(2nuj)K@v#%m@m3F_%r#<0~3R~*`ZS*wI z*&qf#T3cA*#?pkZXeAn*iW;6=9>{il%0FP>#1G+C#eteRJm&@~ZKY1f8BPAorH@N`zAG)Ca zdaUJFHiQyfWTDAWzc%$l@!ZkVI!k%z^G<+yLX7FyQ~QN0Ik#82qEjUt-T2K*NA+|K z6R1A&kPZv0A&z=;hb~i&#QYmAXr7a?QH!sI)}31lD;(D_ggMshv2)|^jdT$t5KUs} zt)h(r2JD^a)ZzzN@JKsiOctM6(}*KSqs}(tgyK*6%pZ{PZm%8L4>BXOX(B9G4&EVX z9ENEEDVr@K$SW*}!y1-2;?A$~oCdo$aLl&GlB~i@WDNGa|ZFeul{fmw{%BV^T8Oy@4F}` zqGUGZK5Q^IhI650UZObOCZtZb_GO|YFoG;OaPE|I2JBu|Q$gafXN};gj7#V$)r#!? zAvf5s^k)1520+%=UhQk*JlnZ@Cn`~Qyira_uFg+>dbJ5)jigznl``EZ>bDj5X{YhaHEib^r+2=58NtK@x7g5-TRE`_;ruS=zh1SME0UdFtAO`-9f7M z`wkOWbh1?#77LM|;n0qGEAt)lXUFA>hrwVY@$sz^S1}58CWyJSv^*1r%?}0~u_QcA>>WHqsRHi!2 zW~}7eNhh3H$7GOCcuub6`gD5DpyFLrS7^?-UH1N@|I5c$g<+)ql~Xh41nVKLxsH8L zWrkAgZYozYf5h$#*~yun*GW6%wNXBKrF_Prg#rkOxfX%Uea0;niUY##wPiHyUv#H) zQ=aDAQskw+O=7Oh^RK4S>K;2?kb*WA3lhra@O}E=lXAJ>j6zER*d>kpVdix}JQOcJ zNV;GO|L9o7n=#i~e7!XqW49#DvfhhNL}41k7-sMd$S;Uot^?=%)d*)yc_PSWz!Zkg z2iiDzQc-MMw0K)ZiL4m>(UWY;MY9L)(&3T5eh5Sos!5P8h~3y3byEqWrKwDF84$GY9AYL+|gp`jR5OzU#kQO9{rah%e^%B4RYj{rr;7liEFtHO6&)=JF~Px$dp&&_($d zUQHy$;G>>Q#uGiw9|5_ssbR`a>I`zV-xwo@T`cYl3eJDcY5v?xlW_~GYGZs6-9hX$ zERLP@c5WBYg`S^{OF6gRjnY=bh@4M!j;LPu#&Ou3`lqH@J{m%2X1u<@y|^QiOo7iY zv!B`X*;N_W_~Hx^8>L5REh6v-6guA-VE++87cuL%N6@DI!7w4J60IEnt4f)`I9B;} zBLDBV`PvQIwRFpGSWN)<_N0_$ew)Ptt0U6ULeWndD!z;xpM?#-Do>HFpy)(vqet>$hR+^6VWe?-)g6|ZZ(4(Gr6yUB zUt22MF>-6Q0GhD~KpCBZDly*SfdyZfcD5r*yYv7V}`OQKzW?EpM2W1HsoX*ZgR zI-Y(({wkh6%2nNXIA0bYgW~!Lh~#URgU8M3YTd3!AgTBv>g3M+rtHJ^v(Gev-XnO_ z)9(nwFa*Bx^E;Wn!PE2xm4>CGF~J~Su87rzbF17V3@;>?yq`qn!@Blvl^Bo0nAJcj z*;q6b2yMGrZ@PbpV4l@goD54-rFt9pd4=%^{0z3rWA#3Njv;L&!xApd&lLcB5uGJ4 zwr=Y~b;C9!3{(A~QyoKPJO-!9G1wc{Ka)tnN3hDrm?`3a^T9jaoWW0mWi_($cj6r? zu_>dqsl^Ki0#Rj9jIxn1AcUa8Dg4Czvghz@+@eO4j2!DW+&%0IwBHX1HJor zIq2ozdl!y2x7;Qk-YHhQKsO*5x@u_~g{bBuS&tbGYzufhn>m|j`995}zOQ)lA`1B{ zo4?9oEefb|WfjFkdA*#LET7_qcISHl&@fNo??GSrHkawloxiK`z6Hh^i@nsFon2d? z1Fjq$3AT)@Ay<;o>3v{L7|dk;C`4_82mBjGE@%hgcenete}jxJ z=NVgsQDwf-^^e5b`x8sLl26OYJFb=t?n_eCca<3+6h*>hbAM`s(+Il;=-1l%`NDo3 zWQ-y%^4T=0oH-Y{LIQUSk=mL>BvIW_#4DgQDZfQoU&fF$P|P>C{PXz%2Ka0WIipGR zrAb|oWid1k;8%-{*J+TmV0Y8-7sbmXWxh|GuWxzdTyf)Hr5nS6;E>y0xmyCfnBg+` zYm~{v=X>V|ih3U( zAAdAPK_Gi0{jU0RW=0c9DR73kxL(=bO|*azJqUpgkjDR#o|hn+ulRk&v<)N2)}LA= ze<_b$X2XJ()1iIKzAm^wR0rQ()Sm>r=T18n zcNe6k)1Pqf)XHbtsz_kK%|d9SPPg6yu#+4h2Ke$(C|k=zOstvt?$;C7gE+BDhL%@a zej+~DzXsF%v49fVKAMDY@!}-}1~De^^Xy1#Zx^A?YCK61c4JcXXK0we+nbs98sW8B zJ2mIfYhBE99`T6?Cl$^VzW8!38GW{6HM+2I*0*t--@umzDDC|~0{5DLy1Pd0g1{W# z9&g%0wexSYEmdujh`gVrxrmy6BDXpC=Sfiz*Pbv^Prj!DXVIK#$pH7=4NMjfGRcv574I63Rrxz<2ef#O~0fCEt zGGbmQ)mppDog$eYftn4ntsDb9n9cOdNT8R468!@I%$`Pm(TVQhJ)*Jy5Jusc~gZ`=gBOtsY;wP<~NH9(1P5+VfuiGF$8eL6H$=n5PmzkQH7?K*lrC3jih za(Md0RJ+@NDwWyjr(clS0C55)D*6E zTvp_gr&8IY);?4I=nqx;;@0akifsXC&iC&DAw8JZHa63C>DS7`x5niN zdR|*y)>gBnd1+VWlvb~~GCTw_zm$^1Qx@x0O3$~}*zE|pFCHUHjm_KN>H`c9(;BsY zS_~S@ESl55enl>LZlpRUV4$NHl$I*ws}@j}E;9T0bVJmTkdUnHh=t#^zZ2;EQn1ZW zB}EWU^@bgz=QBzzsoMC+UIL|AL5KCVuu!(CcaYsFk+84sp@(jBb92q4e_wTVwUwcv z;g0>YVSq;=x1~ks^!$8Y4>bg!?@^(ArNoF+*@%S&oUNU!NDJ{05V`~#B?A;Zy&K0? z4i3NelofyW#n4bmb#``gX?!;3r3WkgOxBLP^G=#>ZC z+f(=R^OMgOr?L6)A?<)q<0~ZErH8n7YdNU3bmVuP+T`R*ySb{j0Qt)J_csi2(m7S8 zQrRTX%cZa9tpJG`rBtievOhYJhmb?G!e zOHGW5r$pcyA(WW{o^iwBQTpU@Sa8Fv{BXP^<@(ThnvJM(>tZ^ldduz9zM4U+W@h9e za-wX`#N=%8Rbbb+?@m?574r^{hNPw}1d+L{F!iQ=KY(jw<~d;Tjd{RKXIC_cjo|pxZO;2w}08D|gYV zri5W(VU`XK4!*WDdD4uLd=ff1f#k|7gA1@vpAZ1mdp7Hzey9?u4 z45hjTX}AdIQa>R80Q9CVDu&_|N8p|m2#0bBh3>!{*Nm&)h4}h+dv8^4iOh?<_ zrqe!SGOesrK-f#jxwQH8K9#>7;&H}Hv~~ICm{T2kcv85Dy=g&CNMlUrGRr@j-yi_5 zwCOqyXhtI>OrW%lK9%|;irz}wT-KvW^N$f#3@$Q};`2;VFwmEj4%Gm00Jvfmw#C{o zE2I2lZqrx+pdAFRy(Ow`5UbBeXV7KKorbb;o95j_sxnmB8+Km;h+mV&I?I*X-!nWC zt8DH#a>92|)4(s)rG4$MYTOQWJ}c4G4pX)vD{mFm&>EYbh1tSo8eED|k( zsIqE_kHW8zhflX`5pMY!0?UR4`$t!12B7TdYqE}KF6dS7+W?u;de6&8U z_iuotJ8{c^YB6urHi8DlI=^(Yll3yHE!Ur?{jZil{@E|0jp5g83OPUh%i;NrJH_{) z4YqK2c~SR!$sIe~CuG zPN7(|wn2VEdz*j@0p zXigzzVz{S;DSaewc$nD^>Xb(U-@7d+*5icr^bmPjwU(o)$>8dKuxEP6*Fj*~rB?Xh zPwVzo>&5EK9SgkeT2)0Qd!I;k0d$JYQge;&=>U~T!M@;C1^_`aYHXXO07|M-<@|}TwEZQ9AeYr%C7;a@jQE{5 zs$Z+tZqJ^q!!HeH>==O&^85D<_~k=CU*ZFbx!Ul9lLAm@{-nH+t031rQ#n7j$R^u}j(N8_r3+oOpTpqSt@NeSg#dvmEY#h0T#UbkPU=4Z z*605azJr2j>&_c&s1ZjO_cmG!I^*#eMI#lRpPV42%x zMt)yaT}}5Ym=-5=_Ai8`#Znd^sRbacmmh&e2I%+DgU9m-LTTmU;SsIVArDYr z0&@o7?>=jQ#S_3U5Ls#ycRE^tiAE;)UB@0fYw+IQ2LKswphaw%4h>5fHsC?y>{h*Y_*#ArBe9z`@W65y-SyGu&m?^rGab|uaCEzijBwYfP3 z0gOq8ti1rHa^PWXT^xD%@qn^`N2>H)z@2{p<^b>mfKO7!{Jgs4QHk7(-oPv;rVk>z zd{8$)5mgo#BlSL8iAsge)zz)A#ptk#)X&uP-L(F()UKT;^~O#K#`X7blQpfHL}DA$>18y&+2Z*dUa7R@7uuQ^%gK$ zRx>fl#$ic=Z;(3Y0TZ3)&z}SHo%{AX+-uyo2B_?iTNjs!$~9%5;V}Lh%aOTn^q(t^ z0OAI=yAQdGiHRLu-`dgy5t5L^_pBHJ^`MVarjG+E^cxuP9zS_9E6)t{96&2PKwk$) zRXxj7b=e%pL$EChOZUU{eJ_`0%<>f6x9gj)Pckv^F&TWkj;W#E`#vUe<$xwg>il|Q zzngmtkmxz}U7yYjlUf5Bh{lSbsq$6`EopM{!PkE&J)BaOU!!E*@@ql+@328%VkxzK z>UBC~57)Sb){>5zMBD-g%N=_GM2zXvM{9D?5HM-0I!S~C_u>h%ni_I!Uzp|bx*Y08 zwMet^tyu%`zmm2h@4~e$%4y$$zHdAW-GK8!4;s9}otB3Lxyd_fEt6pIkt(Pk|5z&C3{H z=Qd4y-_rk>uG4#K+ec=C|1^G|w@7anKjI$<1Y`z)sWIt1a87L)DH8|fb2ZH9zYu5&|7nK} z^~osTZ$YAGUDHdeBJ66&|0{F`1gIz9ROAFHBXmWZ#`fwUa7>W-q3qXy!0jzGm9Q%E z2T&ste7|7l726g!3jxf}1s{{^L7t3Npyswg2B!a7iCAo4MaXUdw7a54sbR_Kzw*74 zMBT8qkDH6Y|8J&AfN!_nx1PUI9#~d;fBCZmUUO$+?;%ZgBg_X;`o968^4M=uf_Y)5 zxz%S=4;uIP1OuRn18bxth`iFfXll9wHng7#HF=N<^d87KX-=mSHiSF(FU(pO6a}Up zv(YZI1l|C->HXE#1i(xG{EMEu%g??yG3nsOo7G@+yBP;?$z7Y#Hx&WK=W*-R9D9JfAduO5NaRVsuvRIUC^c zcqG_tvr#uNmR}`6qG0u`XBB9u0mW+7+KGQhCD#3VcFNuDZKdl&Ozqk|j2`Cs|AAj_ zeABgc7RLZTNg=?(E$MrAVzAz>fe!f9vK#CMEM&C!a{tU&$|(Qmqeib|r}R0a5gzXb zPp|9N{1>A@*Zwz*V&eCI$0)BA$L83A-^tzscJAW;d^4(#rUM{=W8x~$?7CV6ciKyFj1k8`wk7Q)&lSZCEyjySb`1W@{H5%E9z3f z>7Qd8lo+)68d5Gc{_bncNJQ!!-(Lcsnmk}ps8fFGq>9#h1%JNhq<|xU*F1O;P}z=Ax#`jwyTolf6Z@oZMLe4$mjf1jY;Z$0`{TqaO9bO?cvIFKj6|@ zPj1O5DDY=bdT9p~05ElLxM>u%$^!qkzi(SzHg>&MqewTq32SJvP^F^p9Z~>kZr%yY zqkw9iZc@AN9$`-(^3R9`d`k;=ZD(22_A?$}F54z^@d#ox^(}xhH*M5MiToZ*0Kvt7 z-uCDHQ?PfIxbb($>A`b;$RzRJ0*0GpGfDktCd><%5evzmkTjKUuW{{vt&7O}dn;$k z-aQ6OhWj{oQ~20hZT~d|xj+yrheN=U5lL;lwba$zyJ7xe>THM(m^A#sc*wxijLzSw zI6XRGy_RRlOr-wx77PNmduHVSIy^@`l`?AmDdN`rA^)$p0+c}xWae93YCfGj^&g8v z|B8%mXWF@6Gq51X;XmW3qT_w?H1z+LF$0Kr#$i!IL-}BA{htGtnc&Oeju$KY={1${ ztZT#EgUoG@ze371uFH44j=lfm0>B3R2WQsyLnYHp<)!Eks#9eI#2!egl8OdA0{>b- zdK^F=Gq+lU#^Hs3O%32wGAdnA>5hKz%z!{;D8NgTCV`ND(H6)XFs4HH?UnrfzhwgE z_pC3eDmeZhhYpbJ2^pO<#hX~YQdId}p{dB&|2zQllcq==F7ftJX^h~l1cH+}rGM3( zNDfHZlQJs=_*?!Job&#ct<)UgTYB&WfikiH-IDAU&z=HapMP;sh=e4nH|66B*w=pr zcnQcodS&gFy7w1$mHl&vgU?Fk$J~GFdkEAQ0A41uRa5`-o)SS)Soz6iuG_UqP#6{2 zKW;e?$QE#}j7g&74~RG%%~J#IKyCLIehl5X;j$7Z$%>D_SMSCb?`}v}&o58Ft(PId zP#}FY1B^CaN5K{+H^;5&tXQ$rE0e{w3qPJsr5Z%{1iim+K$XgW`wqxSOY%zNZVeW~ z{B1iks3m`UTT+VDdwP7N;P$7~?#G>Zz}Ro^?2Jt@xo!7POMiyRN@Eud zVEKn!=Ie8*w)O;>8cvAaE!y|t^(>wN$L3QEE^!&yFz0<#a^+O56qsbmSA9_aa& z9)FKsynGJesy5f)>E>vhFV(ZK|Bpym(eN`c_vv1e`)NY$VzH*J2v~-|<$9NUtEe(8 z{@efi(4K9+0jhjhAmm>_>W>5j3W*5_)ERNcg$3_y?C$0gfVj+QD}1_eu3?S&KI%$` zg=!6Gar==y|6P#fBOA?vR7!A{l{AbVgV0HAc`}0cuK?NyObnbHmON#%6#)%MT1a$9%xU;5xt2 z66@RwR4_r`t6W?A43?kg%P2+g?Kv%p!sb$85S5h8bkW#ZsU2bC2=x5p+vt-so!03Z zyZO_}U4yCqxPfO=061_x*&KCD$Q!px74p#rqHdLbhg{qH`w5f)KG*^%)Zg0^3OH*o ztVXhb$_Simq;7(KR%wmlO~2LY6grhVs!RM{yXdQvnwqLz=a6|15&$Z3iF^AjCkOru zCxAJ5Dvx8weZb`&kOaSTIa-zai!OjjL!5jnpV>V=$<8)$`7pZ$-kTR3$^NPqjoJmRX@x)^ZP_n^b0{Zho|zbutShq_xPaoT(Wn9ozk8JZ7z^ zz@jwuoB88UZ!=)i@)^_Bm`gS75Z4)6XpKtP`IAcCCJxlIM|y`$tl zNKVmaR00{179)ixPoC`YAOB7b3dCLhvt3giBYR~!dwNni3>&2)&J;bmDl7-mf)!5! zhvA^2v~)qDy6f70B*zQK)+k75L=U7{-si_v(}EQtzspU|RV}8bsN~9#fG%wGoClD3 zUkYX_M`s|RyBOs;09WWI)`$-818pxWUj=<0;ZSgKo&vGaS1*r(g=`Y>| zoF-hV(HD2!de;OSEH7P)Sc#TDZaoenEjrXsCt(aeZ=&{USJ>o95`J6lQqH3bNr0%T)WyYhDf~=f zTH2L!E4+h*s03C%V#Yt$_@y;P?03m$A&e>0dz0Vi^~`GYm^`+no^Ye>)3Ltbtg zSkSh%zk}&-)HO%SAd*@^4peEYaR`7un8_Oh&(JW9RSd831p|4Wnwzl0AVgCG0(w2u4OHO$qmM(sFo0j~Paq0q9dd3zWh@^7R0jw@7uc}mDh|(B=EVV?Q}qnnW`GmW$S7WjnC12nbvQWFE{5= z{e8L9=jiUNBK!)RTfhGBAhy!%jr$Izf1>e+yTJdq^=kznC&5~-rojB3|3>dUAaJc` z=%PEB^S9a8n;qLH+;I|!Qf^ocrDS&Y+)`%t>AHOESQ1H`(x!qTaYb@lclEdX`0*fe zIv3f~b-)<6`?%{D+feoxQRQSY{#rCW7D*JQEt-F!BDDNJCT5xlIaf8CjKq8Y5S@p+ zASpwBurmL_mmZh(=QR4x>l(5`>C49%@)vR}>%hm8n=7eS*q`bHo+J~~Qz4_gq>gtt zaSq@x7}q=}>wyf~@&Dj7(w&d{hZ!ZaOnh|^73@@)4Lk>mzh?4~D@ zJW)`?h$FTwf1w-^kMb6cvLfgFMxjl08o|Y47MG!7MxiF|GU7w>v-}BJS!Z!hD{2S+ z+BNEe3W;0l0JiS6OHu;}dOEkgYot*PBR- zGR@t)tp%3a{pnG=b9REefb)8%B)T(8B@w69$2F5P>(s3FxUS&}!J3lnMRp3weK?IrJLVa+T8QcH<>7&I9QtjOc6`a#KZ(>6-aQ5@DBN8859}`i*4yIk1b9j z`4hxBG7b1^XKxR=mHTy}%)iV-0|S|m0pjH3l=jf=P6E3K$oF+&ry3ih#vo}DF~`@H zN5`Dtj5FD&Dhv$Vd8Mi}*F>0>hKQM@#|IQ#l32)>bkFpMPw6lXS*7(bdlWk6% zjcowR+=EUcfV-fNNEHKpDC}LK1Okzu%`-mL99Pbrm7W^Aa*fS&(oPjTEKm8!!)o*t zw6tVXhIc`GpY91U$xK%d&v6w{GWVY>hBCK#{Mm9r6A(Xz!D>V500InPfjm(6qc zmN>U{Q-ZmvI9LAAAzz@rA^!{K&Q9mj!Tqz)ynO9XkD=#|j#rQTeWIjMdhgzssIYa9 z&B|>0Z$Uqf>IzVW2nY6K*D&;`L4tmkr%)6V+~DHk2z)mwKeVqP=*H~}Mo`huDK)US=3a}J(0!iCg-h2T2val45Y`B<>AJ+F1 zJn|=l60Ui4YjNJeUEL(dF&kQ{NABaLDj_Ce{ zP-h4KssxB(GbI}S{PVj9oGK`QU!q#a8Y1lmorf`Px>PdM9HG*XI3Yp$>xW>VLKV*x zr`r$ME6X}===Qq~rB`0NYRHNK_SE#ZlbRrP{5+!p!>{1FsB83Tg5>yb zY>D_cYRQ5X$L5dnz2gM}eO~cko!Iv2~APY4|#IT`M@fDj?b;CEWmpkOOxfB57m z`t5bMvY%n}dtd4gpo_Lm53lzEHaFh=e1GrbL3E!C&mVb~F1&hi0RB>-RnXu#1E$E|> zkO31&K>s{ry}He1jONMr+w4G_0$|@3v2#?2k>t3U+wVz3j&QZNi%9B1y5SFIJOYD; z!3Byl&9NIk<;_|iL<$Z2j@aOlYhwRXf})(xU+dnn-97dOg`_)I0rI{3HzEy0G?yoJ zXIuy>`LR5y(f;mg5Jf~sYe3X2Bd|WOrl^KVNoXQ9=OeFr7pMxGtNO2uZqKd_lNs7= z{&!K~zvLne6j@2bWN51>*$OeSzZ-rZp-}`sxi+-f@&WsweHHyl7)Y<2HK%bLN_NCi zJF4)7KvtGf5>Hjev>B6m>%O5KU^@c>+JK+8(>i<8;Mci}Iui1|Z`>l(C zJ$FB9b5x5c4bm)->F$r8b{vMHIF8inpNlDuKn4|W$Xw)%E}Wm{HYNq0Es z!EUeu^Q}zjpjvx5)AKe=`>JMa@d8=_bfsX^%s20IaJOFz%eNVw-yg%Bby3XkM9B;Q zDrCc>2g&ST6Nw%uIfw>9K!d14=Uf$W+uj8L+-!PnoFX_7|>*-I&*allyPT^VoGfx%M zO0a^|k&|-|Y|n_zO)DGtL=19=&{3{EA?DM_PS-J=ZX~j|luNDk_T3 zq7QN6KZ^>~vg#EGeW!I`^}+&%HCz2;xL;Jjvepy3z>tu{>1ngFK}J8p#C?R|vlH?z zUu$dS3=A&0OIY{x^tcjBe1Zd`qnUlb2*=8)J2(_P{&Wt01w8D1kJ%#Qm!bg+)LU7` zM(om+BY!Wx5XABBl;SX~k4ar!n3a>4529l{^7jOBjpDXe^@E)Ib`x>?CVzfus%1Nf z3Y1*Ya<&im<2A66!|vT$xU{qKc!I8B5;jUb7DzcV$;9uvXDYQA^JP+m%T`L=K*Jt+E`IPDdaSU1D@v_ zYt*^jd+toq*LPT7*arOO?S8fO#0#lRw=-KD;GBD)1&IwHCk(tF2|eAwvu~>q=VXvmUvkvr(|%XR13OOa_Y`?^ntxz>y#A;pvHu>|AMJ{T^@l0-G`HKr0D-&dORKt)furFC2T=R) zW+dt4<|(8_5$?AIsL1c_W(} zsd2NU-=gsBN`-f^;D^g#Phmu3KJOw|vb~Ec>SO!HpEUsO> zmY+X7H6u;Tn$iH$xXk;M;w4#PlEo+nlZhJ($KR9E1v^q1m{|%PF$07d%qza{Ph&-^$@z@XphZ$q zIoO$~>FcL2)n2XZeaWh}T)3PBY?G$u z{l?apjAdl}LQB@|B7cGpt%B`swwAkn_~3m)s2Nw}``lc+FCnK)XG?umOOp$dluWx6 z@_ITKSeYz1C=Ph*<=j7J^;hZL%aq}K9I90j{DH;>GEH41&TMy2^=ZP;sRLe%&@cs6 zKGD_H4LTHr$)UA%T@yKX4u7>MOnv*HG^^C1EXHL@rw11;RNq8g^MYR+HGUbtRJ=ZT zp7NPgrm;0U`_?^qCuV#S&Be(%S$ox@nXoiu^6ld;fyKFj7JXL!^4qckr8ga1M{HKn z=m)RPaEoer+)G03lX~V#>)7v6F1R*L2En!Qr;GU7?^c!S<*;5h+&_2ZuB~k@)tNKH zgO*nyRgF&2X6qiipK;cI#1bv+{rN`N=#2Upeyt;FtUAV`L^B)WqbQl=p2TzB8tyhl z>IPwtDdAIUao$i&J;RSyTkE_|?Z4*$$s zTRSf23m-I%$N&1^D|V?h`9zq7du>C*!0`LL11Js-j-|QBP?SaMHW$beV;d6VG=H9x zd!K?z5z5xXRxy2Pz`~zgXkD8b?stsq3_S{6jSKAROK4K^U_?>zKKNi!iu0au%_>@E zN40t{4NZEVvFKnfBq~Ab+!3<`SQ-)*R){f!O~!D0QzrGgxIgsal@!36Gkg@-+$8HJpkS61eqW>Hg;$^(#SM% zWI^JzN~@-;b8v9lP25EeJi}!cFksR3PNxFt zqZ>5o{e7lh5HXPdDkv#kp*)!O;lr;3VK70`6MadkgN*dTSOW-nR~Q+u!6n}5*HM@c zm1*c!y}X7AwvhBaBjK^6Sd<5v<>pF2~T^Egu)LHcP z^`}DBJkm2Vq)@Gkrx|J<u8FwzKeVItOurE0sZ6txG%< zM1Uj*0TR2Xe}Du$jFm<&KgbcP}dEoG-|;kRr&N5#X{I_Yz@V z;8N2o#(t8%bJ4Yogp+U1{EcdHZXkv_W9mw_k@&B3|AgANYg1`zWBU$r{f1gHIL6&x zO?7PBR2pFLFU~<4DgvnPc0)-oLRa4gj*;^JJM4A>Ey}JLV2=&g{SCpV0D>2UKZQCc z|Em#q#`i?3jb(BzOp(rb3|P>B@=xb~L9E65 zVHVCxW&EeB2zOz(tPZpM!-XWtc6*}MN`F)kXS0P`;m)7$xSEO`bLB|BzcYHF?3!mz_x&CxI zzAL9UK^l8up?}**YkUPrLv~tPXw$gw-URewQ0)3$s40|(Kgk7(N(HLfCGRsDR^IDd zVn35Lg*F%j(Ve6qZx0lsNdh3ggMTsm%5Cs#`}ga=Y2?nDer8XC#Nv4F>WK4xw2e<{#in9q*;%+dJD$a1k%K`do-!U(Bn7KLs#RNbd3@ z4du-=3EzP03=CqtyqU34#Gy@(tZTXwBQ|#$4cvxb6)I?!iTX3U+;qT5pFD}R$K@|9 zwaXu_vzT4if!$GAkzd=g2y&mVZZ_?(d@cnyeY=%!d|NNjZ3q_j#jnk(y}``J(9*cV z*ttwi?s?yL(J(YL_=IWyhiim+Uv_CeZQ=gDvf<^mr5;1Al%ymXva?0XqfytiUNIJCpnS`vq7eF#*@9$+5sHlNMNcB%k;k zMAJE!N~dM?C-9aRrF9fr6pH!Kdb|r|c2xrxA)2mcc=;>Vyw&h(c`)58>TBwSgAO8# z0$u8^;dI6ogju_aLxr&HdZxSM2QUMwfGf3ib$uEI_E;rOZf?g}ZRrB+nSPu@`rEgo z6nnxsLx3bHJ*k?dPESvlx4+ZZNBPDO&4qO!`Mn-H1 zI#~~k!ZL4M?MH*4I(Yp@?D$3b61doz0tm?Rol=BO_0jWm<1ebr;*GLoNZR zEpaeRwXU`rgK{qFAvEX8;=g2qdxH39EUP!z|V7bbUQ$EN1yq^3t`)!d2 zZt>><$NrDmKs`=VPUZg!8A$G%o&VvRefh+}N#%0Fo%Emo6Bkg`^?Yu|r55SwF1q9$ zx_ZIGh!klGz&b?$MumltyrU$2ya2QQvWJ7ylc-^0VuC(V#5%EoXMR@Z_XWVaS>$l* zuy1UL4kh|R*j)Mi6g!P={YWLvZOh?496F;uN?_ZfmhdjGsGk^NI%e9Qax{V?dC6VO zpJ)9@wDp*Agy1lp3KNBf6;a=YkLG~Sp@tX%`S9@Y%uXSf`J;UB&%D61qol}Xv4vp? zo6D*sD(X#V3~b>FGfi~K6A7_T12dVrl984qTBGgXdTEqkwW>yw#bHi(9>KLy$znb< zYGX3fPgcA*D5#GQXz%s0AXUOW)VylWQYdzSFjg13ir#40E(n|eA^souv-RUIE})@Y z&dJ{J_0MvIU^Z`@cHR1N^m*eJaa$8oq>jL2^l>wb@WMIVR?zftQLJ`bphiT6E=9n< zKcd1*_S0-}Cz#;!xmFY1y?QPG9rjcj$la5sU}tlG?>KqHFl?SQP0&1c5nl^5LgH zgb6P6V;o;luV()H;&gwK9SHH#$er>1nrD>on|Of)@nXfR!n~T@)}@SOc0{gtwEUqE z+?79y2+|5?LCb1VPj?uD1QA_0ih;ZQ{Lj^(z~o0E47pO7`Bq1pT>~7>do+}NZJ%M= z2Ap3#a$f3F1=_{FO}ozRL@v-*SODsU*>66v>j?US;HR5){CE;)yU8lt>$L(Q-#@ze z8i)m7d_jSXW!b8qrYDZoWw}T!BqSvKo%tE>IKvo^d$XnZN9w0pRG5{(^p%}I^wfHP zpIG+2d_i^^zR$lUn!m(BefaJ}VLV!!eOev)JjEr^@85U0jE3l& zSA_x2oYHhEFe~_)IO65C+kWJp&t&eVN6V=D(iNEYC0`!=?*8M%hVX<9!l?a+#{Ypd z*qfM|cN!vFp<_ZQHsMdh_?&R#g$><7kHwDBX~AUZoN!nAq8a z<5E<}=*sVE?2?xiV8WCx_^kGLmfwY6Hx#Y((jFj1$v&g+c*5!o35t`|4ZTE z#(xnGs!H;NsL3G%n|J!~U%N~!EWqBTJARnX3&LrX92;eDA1+XT96KFdx8YG9(Y8+M zN|xPq6(j&K8X|F;ax7p2px9pc)TbhFS*XJG2K<=bAprtG7gmvnI%!J}uqY&P19)Y` zFKz%q@GoxgKXC%=$@8fYY1P(BM=V69tG>bvoTc$t`J`GfZ4RX(l4v^rzKhv~sMySP zIrMd#N~-~vzC$f|bz2;m#tvc2+YWK}r>R+NWG)Xl&N|?+3#)ks4G|7u<^#owKz?v6 zlBY0bp9U-n%hU>$l**O82U1&M2hgc01&tQiU}(Y_CMQIoX_=Xv2okHPq*hglg0U?e z0qKkk1$XyS5b=>P)~x@-Jrxy|-W**Y1kNGg=8cvCQ5lJU;Q*MBEkXeoIn8_C-;o<} zJd%)=saAO9!~sy1UnA#AO}z(t&Erqn)mJO7L%slX1C+B9_pxNy;UZHID;;s&*uT!@ zjyM&$&vm0_O+;gZeJWt;qZ9C(`Ra9Ho*C)cMc2o;d``{AhbK$7wI+IY(1%40&G131 z_~t(D?t5w3{q7bC@7MPQKK_IErv2U5(Ai7-#?N0AH5Pwq$+ActR+o6|$==jAZ$2Wt z2Y}}Q85yV5x0^_J;DohJOVEv-+L8vGffR!PJ724Pmby|=%2LuYJxkd^|rES z9t7&6LMdA-TL95TdY9AxpP>FfOZl4CMWuAovDRp4I}6VyJ8v|(SA{nckwcb32}>ie z`ZxT0#x+}RGN>*)^_;qT@uCR&y6du3NT}aTAVFRe0I5?c|K*`?atd3`N+{1e zh}tX}Np+}BmKM4P0)Ka+d9k`oL$vEH?zE{tLBonRcQ8wc8IjNqG$Mmg11B0QYfYbNEA)t-!3uv^+_(4N$xoZ$v@Ye zjvUx&Wi-Aa&Kw91#;tK#gks2v0D5Y4fszFKh23{9N@tSl1H|b7rVYD-aP*8E)qi9K zARulq40aQL|5-DEf?QL1Wu;d&9-{aC_dxPJO?XyAL4L=;#32S%yRqGYIt1)IwDfAj zNyP&b`Y=9mB{XWPnkPbDE0c;?0{%h{w|J_0*y<8VAaOYjEz-MoaLF{q;`Fzee;&>o zwO!7l=^b!_=EdNz29l9D6&g6zHQvvUqHXsYlImLDD8o(^J_JsSSM^px+hZXNm!vh^ zF5)`{-;i=$ol^T(v*?g!oh2_Q`S$y+BZnDl%WKq;&HQc!hzUQ3|M`vXT&w$Dm>CqG zk(Mo~ZgdA4e4Va(`&ZO7^+qFpcl@?PF4;sTsKyWn)w9H_^`FhQsiID}f}-M5l~@>_ zY=Ec-6jiX--OxOpWNcy4xs2-r=d~H2$){O-dR*MCxVvAGAhT&aNQV+O?m>d8h=_<2 z2luuSIqq446-hr|^`G%2thL}Ng18bQwAD>{DkC#9g>y&V0Pf%%>SpoD(y`v;G{x2W zPZLmVZYy(nt7FgT8OqDk@Ip!(ge7?cgQT9(b9;p!FKSbIAFN=*9JG6QHgO-(Rf2Yr zap~ufJmHP(%?ER~tnSSWi*&-Rbk9Sil7r0q3XDHK+&MK=UXHB#y}hbSvqM_2$Rf>! z4u-_WW`F)HjI4H0Tt3y69SeG@9uho4TG-jynTlQS*%of5!wy553q>h~u=2m}!vQGb zijR2-CKhP3rD(H>L(UTF*?#HjS&$XBD2x)4I;5TM~WeecrQd3b}k+0yA zNlBMfGL;|{M7|5Y4+>vsBTsA-1FH7xSGJ;Kl|S%n`jmSis)xiil;%M=^XtwZ^jyMV|FAY&A;y` z#t-^H^+iWNb#rPtk4jsT^fgR|W)s^%cS4P!%f?c9$~&t!M0_im8uMYF3YUqvJE~>t zx;^GxCucbCc4;^k9U$8$+O=AzMNAR`X<^y{gDm!?(-)+pK4|InM9vy*c+XxpY%CIR zQ7>(P6wOrG8pT)8AJ`5es{B|zFQ_8&$J`}OhJ-}gfi9M=Q`ck!$tb#(WQ81)4%LR$h7gz&Kt0S?Cv9ii(Z?qrIl+K1uuoLvky!`p&LSj8T4BT)L z_Qj)}LsN4VheJ=SjYr&$b$Q%0-c@m!wX;hrq|PMw3OBcyvbvyHg#&RDQr@ZOekHbP zTgI=JD_o{ln(S^V1{?z|L$qsB21-+>P|4U)<(QBG3v7!nmW+?Efln6Vy^|cnr#P-< zrG%Ls%89>MP0Ou;2rn{xlZFC!t!F#$3Grb(RF|7}lqKi1m`g=+_1)iWDD>8$6qOHZ z;+gH=>jCgDSUM%{iRBBv~Irx*1ZaXYR;Jt8$>QF|IbR>%&N_$?9r+o!-nT& znnm-H!}T1j?EM!0I+Wh`REla|JJg!RXJrkt9jtJ&?dhu$Im|3EAGl}%N&TutV?#>Q zE)QKwZ=I(VI>VshQy%7j$5u@4&aMS5czta{O1C2|za=L*#%=J!TGqb!Y)oJB)Kq3> zrx+e8Q-`B$S+2=MTZjG}a6A(~yvWGuHup)aN&wHFS3OFzlp5bevB-E0+VZnbE5@>9 zvDH-c@(GjND8WQ{L}IN88Md#q*^DSftA)BI$N0STIjgD%ZK+^8I4rcOVqCxe zrAPGK5w?Tyw7s}tQ9g8kEoZn8r;bh2cjQ30e-I^*tlafF^GkGF+GM6vn9Z~UEL&MyfG z@ypTEe?cRr2iwNzO#J53>;&IpjT>$DivK04l}A-C8K3)Y%qy2SuhF`laF>pitNU@M zY2xW|zmA8GiJ4WC@+K*nsq)`HJ{Dk^O_6mgm^ckhNbnP!VvBZ}Qc%e>7aD6zZeJXG zUufTNEfm^7!7G^=A8WdpYUnmcEQjqD4Qn6dfEk%b+w;&FrGytRme3ROSLU2b2ov$L z_A}0gx2vH>H?(Z!hbbgYk?GGAGvtYLz<|8==CH%s!Ey|DXldfG-=;Nt-89FKNioAN z*=iz4_OCciJ5%%Qq|$S&hpUVlrVa#MnK}cmjkjT@h1Iy_uMcv71}Ar{%#^C+CJ84d z#`iI4mAaT?`j}X?;Tqiol36nult!jv?y6N33bx?7qBJeTd~Ck z6|k=BeLYuNG;OZo^nH|Xyen)X3v(87!69(vj&g{sXkVV`n+mJ+^T=vRj=8pQiUB+5 zN=z9WGe8A$kh=@UR{7D;KNn0qAR-!>^T=yY3kT|WPcnG3BQ1(WM0~yT_*_rc96#s< zUIEn`4%bZ<#-^0&Lf2RArDLxt#%VVvOf?6>BYU2)_ivxiz0Qg#Qylr*WLPOw+6(q6 z`Q8!bnOB~<%R^mCce%LVM3JYTko`@|T`?zU=2rPHE5muZ0&GD`< zWBfMon~(h9i5DPrZ@3V`#q7TAJZ}FpDjoKz9G;RdK>(-=o*ENe@6v(4o&bnq@>hQD z<9twA7U9@2OZNpobliQ)xv6~4W|5b%{g%7Yp2YCyH~y|X(SQ1l zbfm*uO=o=-yAsfk&Tw0_CZN{xCxzt1_?BT;IrBVx9eurF{{3E@WLgp5sBwT_A(SsXl9jpU@S_0k`K7mE!pr;!$OgETBuBc z?9^bg<0i4V%I|dI&?Zp^^&AVk1$0=l6btKC1y9UfBuD)ST$C9dHC2F26Ii_3XKY}m z`Q4bQ=3DNjtue{3HOZeEm-p)khSA% z<>hI4;K2Dh6^@IHd%%mcSA{8bqV$O4*LRWW~7Up~)#{?zA`^ zAGpm2Y;4}!w}U7JQW(^#GEyEi`SiGY3FjCx)x~P#;BbLJhbkPj$*!$cw%u!dnB#LG zhnK&=4CTfG?X1}D{4T9E-=rAoB573jIgGgCWVev8TwY*gSUiD%%dBz}ET6^or}`4ElXw0HQduBKXhhkB;S}VhCS{DGR^x;&}w$9V|Z@U zK}GJVR!P?OLDy4vs$o@s-)Bx3&6eqi=r{?_?Rg#dUw)t2b&h-o9CyD+1{raEylVuu zg@ft#cC#}ht9jP4Sj2H)2{Gwi6g${8)}Z!dN}eB?gZM>kSTfLMu3=zCCbo*4?h?0} zHjSSfL~~^}IEf*EJFlp0*@3u2UUJCz4$j*AsN7g=)AdY9l-#@mPwI<0me`A7b*W&? zQ&3hXs=cD&kU-AWo)U;Y*2aLM;_+wv>Q?#5sYDKDc22Y~%dQ9#4TV~Cd}tD`W`DS^+=M|rx_NH(3!WX^3spk{NkOv2 zBJe}WrRxI$;a{ltbs?_On^)n;eKM3xKNGGyOdD-Xh>2kvF&^2jGsO6`G}R%~eK{L5 zjth&GL5C;fd}jTKX9SXNa}V0Rf=UF*Sh=q%h8sx@;aqBk9UvrmSuZih0F_B~ec_Am z=*`FWzErlzkUQ@#!P{5JJ%P=C+6$MjYbW%rb_XG;y;e+$R;)MpjcU* zMQA-f2p?zgQ9|-~-hzQqz`$y14wdYZb#E^7UTT^({Jek25G$+8*v$)T z*NyJRWt9^TXxKtw^y+$zV2L2aI^NIMX1k^iAI8FF*`O5ndj4=i(cn`#el3n8V`E0G zJJ+z2$(qmnN;k7dA(#XAN+cvhwM9!38FEKlmQ~*>cH)` zB6#484zt_H{sx{Uz1+IVASk_!k3G+KRfAxN1V%B_(@X<4SWEjg9F!6rtbxr+dKZt| zR|fh-XsE_=ue}U+S^55jFgiJ2^W7?C^C7W#p>aqoY!(6bTQV_BkFIOGuP=21@V;Dh*88Vc2EH5gjlBkVr@g(NItOLWZvw*`BYP*vGu7~r9L zUNL_rC*ko}$J7k2bar)?0qg2lHhj@&z8QH#Ay`PW!f)W+U@LUph>dhINV6*5WxENR znMm64x=S2_3MtPGECw!bkhl}?@zX&NKW?r9Oa(wy9G zUDtp&>X^z$yk%J!;sIQ4)4PfHw-N7w+XfNeHBoFjZVCwkP%C>bRzG`onnR{Fw`pDI z)oJowz^uUgDFUbb{Q1jt`jBE3PF$~7P;`t!hhY#Xm zHp-WR9X;Qzz&cn5e@AT?K|WNv%1f$)c$g*m+!gIDd>%Y$wb&)m-7QWKvFjTcKrE-s z70HY`xgF_eN*zLp#09|X(c?g=wl6#gT?b#*=;-LtP||mvMKYc@zEIIJ+jlZ_8QbMA z`I96u*`-vA$W^dGwwj@^FS>mLlwBhbHz0kF8X)E49xismix|KY5BmN$Jab0W70Ni&l4{~OD-@cl@KRt(+*YN8YA2Ss z2Ep~g=GJi^Sx|?e<98*ZYE-nevO$9G!3m(r7#hgSuVd}AYfURq1G2`C6j`Mj1>E0x(pG$9GWq$()!qo_lu-H&dnhxG_4 zgl4{=F&C*A_Tj{?=#(kxTAJ54tDT89?_{naoBUxN^Lh<>3g!e!Y%YqMOO)jvt(vy9 z>OOb(prgsa=qe zLv#$DguAzH-BPzte9s?7Ka+cJa}G7~NG0rSKbQJoTfx@-epq%P^?d3OEHVUMa*>M* zsx$JuJh`~=mrpn;}jq`yZ`{2*fsZ{gf#EdIS5!Wv)$JoX_HoVl_qbtV-Ti zo1Wc=9g6@G>}k?%MAEXy_aQCmF(KExxx9BYJh?|lO0t~NZqyl3*m7}NJ-(>(LsGP4LivK2`NKS!K?UFqT3 z_Z}NA0_$@X#oxWpv}bDx<<@;VOUI17io2%Qt`KSz$;%lJpD}6UJ~g^{_fE|SWhkY?hOLa4fh-Yn-rEdOR>F}uq1BNf0s)&%kK!J3%O$@t&DQYx z8q8qv*^&i6!PSxDAj_RRdD6Hxn33w*gWDK;D(^UCGmg42)-bwPv_hJYzp8C)^vi1N z?oNSv-L#Stbr9U*#r!*K7-BQl*WF;nnoLcOX>H&hVhX!De{Sx^Y>B*XwO>bjsysbP z%VRYW47R)tu|isqpdt*@Zl*wa0v75*A|YymnCM7D6t6twb!60cKUjfk8NQJ3$aQ55 z-I|7tWn1Nv?;16o@5a+;jdWG+eB>~DE8-=-O1`g3RD?B~+^#srNs4y3`CWgdAa!cS zrT)>n3CzNv)5|kP9WSquH>?@D;h&hBg(@lJFBYU4#v&~hR8`Xu(HGV<$X&Gr*bNlhLk7#w zI|aSicw=@VSY9TuZL%FrFS$5h$BHyNnn>mNCE&bBo>>q3o?SHKwy^EmjMB|3i4x*` z8Q8<}Q?;(cWT*<4Z$Tqb)-DrwBfUqf@OlKl={ZT|jke_0UFlVarhB<7kjL*^+q)U! zJ@0vSn(U#IfdL#BXWrO_8S73LkH3C7a$?J1K`7Dn;}uOzOpJDsH8xK56CygH0@)z1 zJ}3{GUSvSKL2iwKuO8UXG$7^&sf@5P?2kOVrt2#|(Fw+FNvYrqnk`$)GH8tIG4Px% z-z+2N9XUb+kw%MqmP#)8*VD7h5CpIruYfNZuzP%_stht)_DNF1P{KJgpb#wQ-@ZYqI zqy@u;D$}a_m>I&o^+ycb6N(E`=Ho{U^!4*ya5hCh%9ciwq$Bkxy+!?~drV3a5^lpz z9u~P}I+SUUgSED{PXDR#J)O=R(z(0$9()Gw!eFpApBM2?0oh%Su3H1Tl^~wNXZ_>$ zE1DZhTJB3Bkf?-QYTwf@7~15vNx*Uio~V-Pi8Y5kiLNd?T;J}JV|;!-^$WVYuBJus z;+?{LhM#*A|Jo9{-*ESyto7JDf6rT|l3$ly={1ON9kYGzk;rojgdN@wkqKV(kkC*Q zxIF}zQ~4GHe1|C6Zo1;~m6gdEQqwJE?QJ(baXTPu*? zX-*r=Nf48AN3iJ#+pR}hmfeFtkxO8Xgk*5OaoZ(a`6jW&y26NezkYwI9Y>E9B~*A{x>PTf#2>cU-NasOAHJUR#X-Zx zYW{40dyK%I?eY95wKkpWAIhqg1ru%0lUt=5%Y6gAgINsCVId*L)7?2ps>Y9hUL52l zQwaL6T)y1=;^fslaPLXS3YHF-Q@by=AUgqFhrJM9QC<@U<^)q$Z1sH-jkGUtbx2voCX|rFvI1myXY`02TXx8&~v@Q($ara^5dJssb#EJOSj@lTn z&W}>@SZNoE$T1!ZrZ{oJ&h-W^hY}MT8F>?kPH&mJt2$P*_^sCUiS>%OL(#pm%7c-R zRrA;&d{G~dPYaOoEKQZquf(pVhHD5)F6DJaE?O&!`y9-l{6aTC{)S;W2ScC!n2>Ts zeCtgOJ9fy^;iOtKGBQ4hw8kuNFa#_q>2vBdxD2FMjeW(cCog_(XKHJt9tb+Ei@VWu zg=AZE(QHLC`c$093=F3Fa4TNFZ$ByK+|pHj*wn&y|Ardu>eVS~Sm)|s$1<1Ohcl1y z?ufB-vXb{%J8jf4P^~xdx$uHk1;)QdylCsqMesHaH96U#y83zr@Y?iaT^}pCYjz#` zA>WwV*U@Q7pp1yIPvEU?FJ7@D(%Uy8CcpUiIThL7m_Qz z#pVJ|proVc=jVSMKuOy43H8}o8T4tJTN&J6D?A=EV|)hJal&80(@H*msQE?O%G8aL z%g2_hFdp$jIfbbj@A=;N5BK9^sM*`Mi_{(_h?`7-cLS zHfOV}q@?6Sj2SCRKRCu!DkbCSuWGaYHYGa6-sfjN7WN_jZb`wUA+fl4aNBy6MvGzc_D zI#=#%hp2h(LMdlbXMi}|Fr1N$40b+>;_TdCTkKFUUJSNVfDOE7=iaaB)zsC=ZxT18 zd$_l_+r#!w-DI+{@$tIRsa4ZL*h2*=e;HmQ-QIw+E=+X!_nDx;<>lHZ)N}3J9)*%rwekD_N9XYqB|rCQi91PZn`QRohnS47() zb)bi#ZIjoF9&BBI$0XC}wr66;1WGBFjP&^qj(=4XZlC%eiP*f&!}knAg2;~?)Ig3w N-I2MSC9M7U{{!(S@jn0n literal 0 HcmV?d00001 diff --git a/examples/offline/results/bcq/halfcheetah-expert-v1_reward.svg b/examples/offline/results/bcq/halfcheetah-expert-v1_reward.svg new file mode 100644 index 000000000..87ede75ed --- /dev/null +++ b/examples/offline/results/bcq/halfcheetah-expert-v1_reward.svg @@ -0,0 +1 @@ +1e+32e+33e+34e+35e+36e+37e+38e+39e+31e+40100k200k300k400k500k600k700k800k900k1M1.1M \ No newline at end of file diff --git a/test/base/test_env.py b/test/base/test_env.py index 7f47501c3..dbd651d14 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -134,7 +134,7 @@ def test_vecenv(size=10, num=8, sleep=0.001): SubprocVectorEnv(env_fns), ShmemVectorEnv(env_fns), ] - if has_ray(): + if has_ray() and sys.platform == "linux": venv += [RayVectorEnv(env_fns)] for v in venv: v.seed(0) diff --git a/test/offline/__init__.py b/test/offline/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py new file mode 100644 index 000000000..4c0275e69 --- /dev/null +++ b/test/offline/gather_pendulum_data.py @@ -0,0 +1,170 @@ +import argparse +import os +import pickle + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import SACPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorProb, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pendulum-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=200000) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--epoch', type=int, default=7) + parser.add_argument('--step-per-epoch', type=int, default=8000) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.125) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + + parser.add_argument("--gamma", default=0.99) + parser.add_argument("--tau", default=0.005) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) + # sac: + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', type=int, default=1) + parser.add_argument('--alpha-lr', type=float, default=3e-4) + parser.add_argument('--rew-norm', action="store_true", default=False) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument( + "--save-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl" + ) + args = parser.parse_known_args()[0] + return args + + +def gather_data(): + """Return expert buffer data.""" + args = get_args() + env = gym.make(args.task) + if args.task == 'Pendulum-v0': + env.spec.reward_threshold = -250 + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + # you can also use tianshou.env.SubprocVectorEnv + # train_envs = gym.make(args.task) + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = SACPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + reward_normalization=args.rew_norm, + estimation_step=args.n_step, + action_space=env.action_space, + ) + # collector + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + # train_collector.collect(n_step=args.buffer_size) + # log + log_path = os.path.join(args.logdir, args.task, 'sac') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + # trainer + offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + save_fn=save_fn, + stop_fn=stop_fn, + logger=logger, + ) + train_collector.reset() + result = train_collector.collect(n_step=args.buffer_size) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + pickle.dump(buffer, open(args.save_buffer_name, "wb")) + return buffer diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py new file mode 100644 index 000000000..ab98e497a --- /dev/null +++ b/test/offline/test_bcq.py @@ -0,0 +1,221 @@ +import argparse +import datetime +import os +import pickle +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector +from tianshou.env import SubprocVectorEnv +from tianshou.policy import BCQPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import MLP, Net +from tianshou.utils.net.continuous import VAE, Critic, Perturbation + +if __name__ == "__main__": + from gather_pendulum_data import gather_data +else: # pytest + from test.offline.gather_pendulum_data import gather_data + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pendulum-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[200, 150]) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--epoch', type=int, default=7) + parser.add_argument('--step-per-epoch', type=int, default=2000) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + + parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[375, 375]) + # default to 2 * action_dim + parser.add_argument('--latent_dim', type=int, default=None) + parser.add_argument("--gamma", default=0.99) + parser.add_argument("--tau", default=0.005) + # Weighting for Clipped Double Q-learning in BCQ + parser.add_argument("--lmbda", default=0.75) + # Max perturbation hyper-parameter for BCQ + parser.add_argument("--phi", default=0.05) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only', + ) + parser.add_argument( + "--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl" + ) + args = parser.parse_known_args()[0] + return args + + +def test_bcq(args=get_args()): + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] # float + if args.task == 'Pendulum-v0': + env.spec.reward_threshold = -800 # too low? + + args.state_dim = args.state_shape[0] + args.action_dim = args.action_shape[0] + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + + # model + # perturbation network + net_a = MLP( + input_dim=args.state_dim + args.action_dim, + output_dim=args.action_dim, + hidden_sizes=args.hidden_sizes, + device=args.device, + ) + actor = Perturbation( + net_a, max_action=args.max_action, device=args.device, phi=args.phi + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + # vae + # output_dim = 0, so the last Module in the encoder is ReLU + vae_encoder = MLP( + input_dim=args.state_dim + args.action_dim, + hidden_sizes=args.vae_hidden_sizes, + device=args.device, + ) + if not args.latent_dim: + args.latent_dim = args.action_dim * 2 + vae_decoder = MLP( + input_dim=args.state_dim + args.latent_dim, + output_dim=args.action_dim, + hidden_sizes=args.vae_hidden_sizes, + device=args.device, + ) + vae = VAE( + vae_encoder, + vae_decoder, + hidden_dim=args.vae_hidden_sizes[-1], + latent_dim=args.latent_dim, + max_action=args.max_action, + device=args.device, + ).to(args.device) + vae_optim = torch.optim.Adam(vae.parameters()) + + policy = BCQPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + vae, + vae_optim, + device=args.device, + gamma=args.gamma, + tau=args.tau, + lmbda=args.lmbda, + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + # buffer has been gathered + # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq' + log_path = os.path.join(args.logdir, args.task, 'bcq', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def watch(): + policy.load_state_dict( + torch.load( + os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu') + ) + ) + policy.eval() + collector = Collector(policy, env) + collector.collect(n_episode=1, render=1 / 35) + + # trainer + result = offline_trainer( + policy, + buffer, + test_collector, + args.epoch, + args.step_per_epoch, + args.test_num, + args.batch_size, + save_fn=save_fn, + stop_fn=stop_fn, + logger=logger, + ) + assert stop_fn(result['best_reward']) + + # Let's watch its performance! + if __name__ == '__main__': + pprint.pprint(result) + env = gym.make(args.task) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == '__main__': + test_bcq() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 6a842356f..174762e25 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -19,6 +19,7 @@ from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.imitation.base import ImitationPolicy +from tianshou.policy.imitation.bcq import BCQPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy @@ -44,6 +45,7 @@ "SACPolicy", "DiscreteSACPolicy", "ImitationPolicy", + "BCQPolicy", "DiscreteBCQPolicy", "DiscreteCQLPolicy", "DiscreteCRRPolicy", diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py new file mode 100644 index 000000000..2aeeb323d --- /dev/null +++ b/tianshou/policy/imitation/bcq.py @@ -0,0 +1,213 @@ +import copy +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, to_torch +from tianshou.policy import BasePolicy +from tianshou.utils.net.continuous import VAE + + +class BCQPolicy(BasePolicy): + """Implementation of BCQ algorithm. arXiv:1812.02900. + + :param Perturbation actor: the actor perturbation. (s, a -> perturbed a) + :param torch.optim.Optimizer actor_optim: the optimizer for actor network. + :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a)) + :param torch.optim.Optimizer critic1_optim: the optimizer for the first + critic network. + :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a)) + :param torch.optim.Optimizer critic2_optim: the optimizer for the second + critic network. + :param VAE vae: the VAE network, generating actions similar + to those in batch. (s, a -> generated a) + :param torch.optim.Optimizer vae_optim: the optimizer for the VAE network. + :param Union[str, torch.device] device: which device to create this model on. + Default to "cpu". + :param float gamma: discount factor, in [0, 1]. Default to 0.99. + :param float tau: param for soft update of the target network. + Default to 0.005. + :param float lmbda: param for Clipped Double Q-learning. Default to 0.75. + :param int forward_sampled_times: the number of sampled actions in forward + function. The policy samples many actions and takes the action with the + max value. Default to 100. + :param int num_sampled_action: the number of sampled actions in calculating + target Q. The algorithm samples several actions using VAE, and perturbs + each action to get the target Q. Default to 10. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critic1: torch.nn.Module, + critic1_optim: torch.optim.Optimizer, + critic2: torch.nn.Module, + critic2_optim: torch.optim.Optimizer, + vae: VAE, + vae_optim: torch.optim.Optimizer, + device: Union[str, torch.device] = "cpu", + gamma: float = 0.99, + tau: float = 0.005, + lmbda: float = 0.75, + forward_sampled_times: int = 100, + num_sampled_action: int = 10, + **kwargs: Any + ) -> None: + # actor is Perturbation! + super().__init__(**kwargs) + self.actor = actor + self.actor_target = copy.deepcopy(self.actor) + self.actor_optim = actor_optim + + self.critic1 = critic1 + self.critic1_target = copy.deepcopy(self.critic1) + self.critic1_optim = critic1_optim + + self.critic2 = critic2 + self.critic2_target = copy.deepcopy(self.critic2) + self.critic2_optim = critic2_optim + + self.vae = vae + self.vae_optim = vae_optim + + self.gamma = gamma + self.tau = tau + self.lmbda = lmbda + self.device = device + self.forward_sampled_times = forward_sampled_times + self.num_sampled_action = num_sampled_action + + def train(self, mode: bool = True) -> "BCQPolicy": + """Set the module in training mode, except for the target network.""" + self.training = mode + self.actor.train(mode) + self.critic1.train(mode) + self.critic2.train(mode) + return self + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: + """Compute action over the given batch data.""" + # There is "obs" in the Batch + # obs_group: several groups. Each group has a state. + obs_group: torch.Tensor = to_torch( # type: ignore + batch.obs, device=self.device + ) + act = [] + for obs in obs_group: + # now obs is (state_dim) + obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1) + # now obs is (forward_sampled_times, state_dim) + + # decode(obs) generates action and actor perturbs it + action = self.actor(obs, self.vae.decode(obs)) + # now action is (forward_sampled_times, action_dim) + q1 = self.critic1(obs, action) + # q1 is (forward_sampled_times, 1) + ind = q1.argmax(0) + act.append(action[ind].cpu().data.numpy().flatten()) + act = np.array(act) + return Batch(act=act) + + def sync_weight(self) -> None: + """Soft-update the weight for the target network.""" + for net, net_target in [ + [self.critic1, self.critic1_target], [self.critic2, self.critic2_target], + [self.actor, self.actor_target] + ]: + for param, target_param in zip(net.parameters(), net_target.parameters()): + target_param.data.copy_( + self.tau * param.data + (1 - self.tau) * target_param.data + ) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + # batch: obs, act, rew, done, obs_next. (numpy array) + # (batch_size, state_dim) + batch: Batch = to_torch( # type: ignore + batch, dtype=torch.float, device=self.device + ) + obs, act = batch.obs, batch.act + batch_size = obs.shape[0] + + # mean, std: (state.shape[0], latent_dim) + recon, mean, std = self.vae(obs, act) + recon_loss = F.mse_loss(act, recon) + # (....) is D_KL( N(mu, sigma) || N(0,1) ) + KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean() + vae_loss = recon_loss + KL_loss / 2 + + self.vae_optim.zero_grad() + vae_loss.backward() + self.vae_optim.step() + + # critic training: + with torch.no_grad(): + # repeat num_sampled_action times + obs_next = batch.obs_next.repeat_interleave(self.num_sampled_action, dim=0) + # now obs_next: (num_sampled_action * batch_size, state_dim) + + # perturbed action generated by VAE + act_next = self.vae.decode(obs_next) + # now obs_next: (num_sampled_action * batch_size, action_dim) + target_Q1 = self.critic1_target(obs_next, act_next) + target_Q2 = self.critic2_target(obs_next, act_next) + + # Clipped Double Q-learning + target_Q = \ + self.lmbda * torch.min(target_Q1, target_Q2) + \ + (1 - self.lmbda) * torch.max(target_Q1, target_Q2) + # now target_Q: (num_sampled_action * batch_size, 1) + + # the max value of Q + target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1) + # now target_Q: (batch_size, 1) + + target_Q = \ + batch.rew.reshape(-1, 1) + \ + (1 - batch.done).reshape(-1, 1) * self.gamma * target_Q + + current_Q1 = self.critic1(obs, act) + current_Q2 = self.critic2(obs, act) + + critic1_loss = F.mse_loss(current_Q1, target_Q) + critic2_loss = F.mse_loss(current_Q2, target_Q) + + self.critic1_optim.zero_grad() + self.critic2_optim.zero_grad() + critic1_loss.backward() + critic2_loss.backward() + self.critic1_optim.step() + self.critic2_optim.step() + + sampled_act = self.vae.decode(obs) + perturbed_act = self.actor(obs, sampled_act) + + # max + actor_loss = -self.critic1(obs, perturbed_act).mean() + + self.actor_optim.zero_grad() + actor_loss.backward() + self.actor_optim.step() + + # update target network + self.sync_weight() + + result = { + "loss/actor": actor_loss.item(), + "loss/critic1": critic1_loss.item(), + "loss/critic2": critic2_loss.item(), + "loss/vae": vae_loss.item(), + } + return result diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 1bb090cdf..761540502 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -325,3 +325,122 @@ def forward( s = torch.cat([s, a], dim=1) s = self.fc2(s) return s + + +class Perturbation(nn.Module): + """Implementation of perturbation network in BCQ algorithm. Given a state and \ + action, it can generate perturbed action. + + :param torch.nn.Module preprocess_net: a self-defined preprocess_net which output a + flattened hidden state. + :param float max_action: the maximum value of each dimension of action. + :param Union[str, int, torch.device] device: which device to create this model on. + Default to cpu. + :param float phi: max perturbation parameter for BCQ. Default to 0.05. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + + .. seealso:: + + You can refer to `examples/offline/offline_bcq.py` to see how to use it. + """ + + def __init__( + self, + preprocess_net: nn.Module, + max_action: float, + device: Union[str, int, torch.device] = "cpu", + phi: float = 0.05 + ): + # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim + super(Perturbation, self).__init__() + self.preprocess_net = preprocess_net + self.device = device + self.max_action = max_action + self.phi = phi + + def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: + # preprocess_net + logits = self.preprocess_net(torch.cat([state, action], -1))[0] + a = self.phi * self.max_action * torch.tanh(logits) + # clip to [-max_action, max_action] + return (a + action).clamp(-self.max_action, self.max_action) + + +class VAE(nn.Module): + """Implementation of VAE. It models the distribution of action. Given a \ + state, it can generate actions similar to those in batch. It is used \ + in BCQ algorithm. + + :param torch.nn.Module encoder: the encoder in VAE. Its input_dim must be + state_dim + action_dim, and output_dim must be hidden_dim. + :param torch.nn.Module decoder: the decoder in VAE. Its input_dim must be + state_dim + latent_dim, and output_dim must be action_dim. + :param int hidden_dim: the size of the last linear-layer in encoder. + :param int latent_dim: the size of latent layer. + :param float max_action: the maximum value of each dimension of action. + :param Union[str, torch.device] device: which device to create this model on. + Default to "cpu". + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + + .. seealso:: + + You can refer to `examples/offline/offline_bcq.py` to see how to use it. + """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + hidden_dim: int, + latent_dim: int, + max_action: float, + device: Union[str, torch.device] = "cpu" + ): + super(VAE, self).__init__() + self.encoder = encoder + + self.mean = nn.Linear(hidden_dim, latent_dim) + self.log_std = nn.Linear(hidden_dim, latent_dim) + + self.decoder = decoder + + self.max_action = max_action + self.latent_dim = latent_dim + self.device = device + + def forward( + self, state: torch.Tensor, action: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [state, action] -> z , [state, z] -> action + z = self.encoder(torch.cat([state, action], -1)) + # shape of z: (state.shape[:-1], hidden_dim) + + mean = self.mean(z) + # Clamped for numerical stability + log_std = self.log_std(z).clamp(-4, 15) + std = torch.exp(log_std) + # shape of mean, std: (state.shape[:-1], latent_dim) + + z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim) + + u = self.decode(state, z) # (state.shape[:-1], action_dim) + return u, mean, std + + def decode( + self, + state: torch.Tensor, + z: Union[torch.Tensor, None] = None + ) -> torch.Tensor: + # decode(state) -> action + if z is None: + # state.shape[0] may be batch_size + # latent vector clipped to [-0.5, 0.5] + z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \ + .to(self.device).clamp(-0.5, 0.5) + + # decode z with state! + return self.max_action * torch.tanh(self.decoder(torch.cat([state, z], -1)))