diff --git a/configs/pomdp/cartpole/f/mlp.yml b/configs/pomdp/cartpole/f/mlp.yml new file mode 100644 index 0000000..183c051 --- /dev/null +++ b/configs/pomdp/cartpole/f/mlp.yml @@ -0,0 +1,38 @@ +seed: 73 +cuda: -1 # use_gpu +env: + env_type: pomdp + env_name: CartPole-F-v0 + + num_eval_tasks: 20 # num of eval episodes + +train: + # 500*200 = 100k steps + num_iters: 500 # number meta-training iterates + num_init_rollouts_pool: 5 # before training + num_rollouts_per_iter: 1 + buffer_size: 1e6 + batch_size: 256 + +eval: + eval_stochastic: false # also eval stochastic policy + log_interval: 1 # num of iters + save_interval: -1 + log_tensorboard: true + +policy: + arch: mlp + algo: sacd # only support sac-discrete + + dqn_layers: [128, 128] + policy_layers: [128, 128] + lr: 0.0003 + gamma: 0.99 + tau: 0.005 + + # sac alpha + entropy_alpha: null + automatic_entropy_tuning: true + target_entropy: 0.7 # the ratio: target_entropy = ratio * log(|A|) + alpha_lr: 0.0003 + diff --git a/configs/pomdp/cartpole/v/rnn.yml b/configs/pomdp/cartpole/v/rnn.yml new file mode 100644 index 0000000..1d014e1 --- /dev/null +++ b/configs/pomdp/cartpole/v/rnn.yml @@ -0,0 +1,49 @@ +seed: 73 +cuda: 0 # use_gpu +env: + env_type: pomdp + env_name: CartPole-V-v0 + + num_eval_tasks: 20 # num of eval episodes + +train: + # 50*200 = 10k steps + num_iters: 50 # number meta-training iterates + num_init_rollouts_pool: 5 # before training + num_rollouts_per_iter: 1 + + num_updates_per_iter: 1.0 + # buffer params + buffer_size: 1e6 + batch_size: 32 # to tune based on sampled_seq_len + sampled_seq_len: -1 # -1 is all, or positive integer + sample_weight_baseline: 0.0 + +eval: + eval_stochastic: false # also eval stochastic policy + log_interval: 1 # num of iters + save_interval: -1 + log_tensorboard: true + +policy: + separate: True + arch: lstm # [lstm, gru] + algo: sacd # only support sac-discrete + + action_embedding_size: 8 # no action input + state_embedding_size: 32 + reward_embedding_size: 8 + rnn_hidden_size: 128 + + dqn_layers: [128, 128] + policy_layers: [128, 128] + lr: 0.0003 + gamma: 0.9 + tau: 0.005 + + # sacd alpha + entropy_alpha: null + automatic_entropy_tuning: true + target_entropy: 0.7 # the ratio: target_entropy = ratio * log(|A|) + alpha_lr: 0.0003 + diff --git a/configs/pomdp/lunalander/f/mlp.yml b/configs/pomdp/lunalander/f/mlp.yml new file mode 100644 index 0000000..923cc04 --- /dev/null +++ b/configs/pomdp/lunalander/f/mlp.yml @@ -0,0 +1,38 @@ +seed: 73 +cuda: -1 # use_gpu +env: + env_type: pomdp + env_name: LunarLander-F-v0 + + num_eval_tasks: 20 # num of eval episodes + +train: + # 500*1000 = 500k steps + num_iters: 500 # number meta-training iterates + num_init_rollouts_pool: 5 # before training + num_rollouts_per_iter: 1 + buffer_size: 1e6 + batch_size: 256 + +eval: + eval_stochastic: false # also eval stochastic policy + log_interval: 1 # num of iters + save_interval: -1 + log_tensorboard: true + +policy: + arch: mlp + algo: sacd # only support sac-discrete + + dqn_layers: [128, 128] + policy_layers: [128, 128] + lr: 0.0003 + gamma: 0.99 + tau: 0.005 + + # sac alpha + entropy_alpha: null + automatic_entropy_tuning: true + target_entropy: 0.7 # the ratio: target_entropy = ratio * log(|A|) + alpha_lr: 0.0003 + diff --git a/configs/pomdp/lunalander/v/rnn.yml b/configs/pomdp/lunalander/v/rnn.yml new file mode 100644 index 0000000..77c3433 --- /dev/null +++ b/configs/pomdp/lunalander/v/rnn.yml @@ -0,0 +1,49 @@ +seed: 73 +cuda: 0 # use_gpu +env: + env_type: pomdp + env_name: LunarLander-V-v0 + + 
num_eval_tasks: 20 # num of eval episodes + +train: + # 200*1000 = 200k steps + num_iters: 200 # number meta-training iterates + num_init_rollouts_pool: 5 # before training + num_rollouts_per_iter: 1 + + num_updates_per_iter: 0.2 + # buffer params + buffer_size: 1e6 + batch_size: 32 # to tune based on sampled_seq_len + sampled_seq_len: -1 # -1 is all, or positive integer + sample_weight_baseline: 0.0 + +eval: + eval_stochastic: false # also eval stochastic policy + log_interval: 4 # num of iters + save_interval: -1 + log_tensorboard: true + +policy: + separate: True + arch: lstm # [lstm, gru] + algo: sacd # only support sac-discrete + + action_embedding_size: 8 # no action input + state_embedding_size: 32 + reward_embedding_size: 8 + rnn_hidden_size: 128 + + dqn_layers: [128, 128] + policy_layers: [128, 128] + lr: 0.0003 + gamma: 0.99 + tau: 0.005 + + # sacd alpha + entropy_alpha: null + automatic_entropy_tuning: true + target_entropy: 0.7 # the ratio: target_entropy = ratio * log(|A|) + alpha_lr: 0.0003 + diff --git a/configs/pomdp/pendulum/v/varibad.yml b/configs/pomdp/pendulum/v/varibad.yml deleted file mode 100644 index 116d0e5..0000000 --- a/configs/pomdp/pendulum/v/varibad.yml +++ /dev/null @@ -1,72 +0,0 @@ -seed: 73 -cuda: 0 # use_gpu -env: - env_type: pomdp - env_name: Pendulum-V-v0 - - num_eval_tasks: 20 # num of eval episodes - -train: - # 250*200 = 50k steps - num_iters: 250 # number meta-training iterates - num_init_rollouts_pool: 5 # before training - num_rollouts_per_iter: 1 - - rl_updates_per_iter: 200 - vae_updates_per_iter: 200 - policy_batch_size: 6400 # 32*200 - vae_batch_num_rollouts: 32 - - log_interval: 5 # num of iters - save_interval: 100 # -1 - log_tensorboard: true - -policy: - buffer_size: 1e6 - # bamdp related - sample_embeddings: false # (otherwise: pass mean) obs_dim + 2*task_dim - switch_to_belief_reward: null # when to switch from R to R+; None is to not switch - - policy: sac - dqn_layers: [128, 128] - policy_layers: [128, 128] - lr: 0.0003 - # sac alpha - entropy_alpha: 0.01 # tend to be det policy... 
- automatic_entropy_tuning: true - alpha_lr: 0.0003 - - gamma: 0.9 - tau: 0.005 - -vae: - buffer_size: 1e6 - task_embedding_size: 5 # dim of latent space - - optim: - vae_lr: 0.001 - rew_loss_coeff: 1.0 - state_loss_coeff: 1.0 # (vs reward loss) - kl_weight: 0.1 - kl_to_gauss_prior: false - train_by_batch: true # false is by split - - # encoder - encoder: - aggregator_hidden_size: 128 - layers_before_aggregator: [] - layers_after_aggregator: [] - action_embedding_size: 0 # no action input - state_embedding_size: 32 - reward_embedding_size: 8 - - decoder: - disable_stochasticity_in_latent: false - # decoder: reward function r(s,m) - decode_reward: true - reward_decoder_layers: [64, 32] - rew_pred_type: deterministic # gaussian, deterministic - input_prev_state: false - input_action: false - # decoder: state transition p(s'|s,a,m) - decode_state: false # not used diff --git a/docs/acknowledge.md b/docs/acknowledge.md index 33b8156..00bb07e 100644 --- a/docs/acknowledge.md +++ b/docs/acknowledge.md @@ -8,6 +8,7 @@ We acknowledge the following repositories that greatly shaped our implementation - https://github.com/oist-cnru/Variational-Recurrent-Models for providing the pomdp VRM algorithm and environments - https://github.com/quantumiracle/Popular-RL-Algorithms for inspiring the recurrent policies design - https://github.com/lmzintgraf/varibad for inspiring the recurrent policies design and providing learning curve data +- https://github.com/ku2482/sac-discrete.pytorch for providing the SAC-discrete code Please cite their work if you also find their code useful to your project: ``` @@ -58,4 +59,10 @@ Please cite their work if you also find their code useful to your project: note={\url{http://www.deepreinforcementlearningbook.org}}, year={2020} } +@article{christodoulou2019soft, + title={Soft actor-critic for discrete action settings}, + author={Christodoulou, Petros}, + journal={arXiv preprint arXiv:1910.07207}, + year={2019} +} ``` diff --git a/docs/run_commands.md b/docs/run_commands.md index b7fab9c..d881d44 100644 --- a/docs/run_commands.md +++ b/docs/run_commands.md @@ -4,7 +4,7 @@ Before start running any experiments, we suggest to have a good plan of *environment series* based on difficulty level. As it is hard to analyze and varies from algorithm to algorithm, we provide some rough estimates: -1. Extremely Simple as a Sanity Check: Pendulum-V (also shown in our minimal example jupyter notebook) +1. Extremely Simple as a Sanity Check: Pendulum-V (also shown in our minimal example jupyter notebook) and CartPole-V (for discrete action space) 2. Simple, Fast, yet Non-trivial: Wind (require precise inference and control), Semi-Circle (sparse reward). Both are continuous gridworlds, thus very fast. 3. Medium: Cheetah-Vel (1-dim stationary hidden state), `*`-Robust (2-dim stationary hidden state), `*`-P (could be roughly inferred by 2nd order MDP) 4. Hard: `*`-Dir (relatively complicated dynamics), `*`-V (long-term inference), `*`-Generalize (extrapolation) @@ -44,6 +44,12 @@ python PPO/main.py --config configs/pomdp/ant_blt/p/ppo_rnn.yml \ python VRM/run_experiment.py configs/pomdp/ant_blt/p/vrm.yml ``` +Mar 2022: we support recurrent SAC-discrete for POMDPs with **discrete action space**. Take CartPole-V as example: +``` +python policies/main.py --cfg configs/pomdp/cartpole/v/rnn.yml --target_entropy 0.7 +``` +See [this PR for detailed instructions](https://github.com/twni2016/pomdp-baselines/pull/1). 
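+For SAC-discrete, the `--target_entropy` flag is a ratio rather than the raw entropy target: the code sets the actual target to `ratio * log(|A|)` (see the `target_entropy` comment in the config files). A minimal sketch of that conversion, assuming CartPole's two actions:
+```python
+import numpy as np
+
+ratio = 0.7        # value passed via --target_entropy
+num_actions = 2    # |A| for CartPole
+target_entropy = ratio * np.log(num_actions)  # ~= 0.49 nats
+```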
+ ### Meta RL {Semi-Circle, Wind, Cheetah-Vel} in the paper, corresponding to `configs/meta/`. Among them, Cheetah-Vel requires MuJoCo, and Semi-Circle can serve as a sanity check. Wind looks simple but not very easy to solve. diff --git a/envs/pomdp/__init__.py b/envs/pomdp/__init__.py index 4a0c76c..42c6697 100644 --- a/envs/pomdp/__init__.py +++ b/envs/pomdp/__init__.py @@ -29,6 +29,52 @@ max_episode_steps=200, ) +register( + "CartPole-F-v0", + entry_point="envs.pomdp.wrappers:POMDPWrapper", + kwargs=dict( + env=gym.make("CartPole-v0"), partially_obs_dims=[0, 1, 2, 3] + ), # angle & velocity + max_episode_steps=200, # reward threshold for solving the task: 195 +) + +register( + "CartPole-P-v0", + entry_point="envs.pomdp.wrappers:POMDPWrapper", + kwargs=dict(env=gym.make("CartPole-v0"), partially_obs_dims=[0, 2]), + max_episode_steps=200, +) + +register( + "CartPole-V-v0", + entry_point="envs.pomdp.wrappers:POMDPWrapper", + kwargs=dict(env=gym.make("CartPole-v0"), partially_obs_dims=[1, 3]), + max_episode_steps=200, +) + +register( + "LunarLander-F-v0", + entry_point="envs.pomdp.wrappers:POMDPWrapper", + kwargs=dict( + env=gym.make("LunarLander-v2"), partially_obs_dims=list(range(8)) + ), # angle & velocity + max_episode_steps=1000, # reward threshold for solving the task: 200 +) + +register( + "LunarLander-P-v0", + entry_point="envs.pomdp.wrappers:POMDPWrapper", + kwargs=dict(env=gym.make("LunarLander-v2"), partially_obs_dims=[0, 1, 4, 6, 7]), + max_episode_steps=1000, +) + +register( + "LunarLander-V-v0", + entry_point="envs.pomdp.wrappers:POMDPWrapper", + kwargs=dict(env=gym.make("LunarLander-v2"), partially_obs_dims=[2, 3, 5, 6, 7]), + max_episode_steps=1000, +) + ### Below are pybullect (roboschool) environments, using BLT for Bullet import pybullet_envs diff --git a/envs/pomdp/wrappers.py b/envs/pomdp/wrappers.py index 85429f1..2ad0c22 100644 --- a/envs/pomdp/wrappers.py +++ b/envs/pomdp/wrappers.py @@ -16,9 +16,13 @@ def __init__(self, env, partially_obs_dims: list): dtype=np.float32, ) - # if continuous actions, make sure in [-1, 1] - # NOTE: policy won't use action_space.low/high, just set [-1,1] - # this is a bad practice... + if self.env.action_space.__class__.__name__ == "Box": + self.act_continuous = True + # if continuous actions, make sure in [-1, 1] + # NOTE: policy won't use action_space.low/high, just set [-1,1] + # this is a bad practice... 
+ else: + self.act_continuous = False def get_obs(self, state): return state[self.partially_obs_dims].copy() @@ -28,12 +32,13 @@ def reset(self): return self.get_obs(state) def step(self, action): - # recover the action - action = np.clip(action, -1, 1) # first clip into [-1, 1] - lb = self.env.action_space.low - ub = self.env.action_space.high - action = lb + (action + 1.0) * 0.5 * (ub - lb) - action = np.clip(action, lb, ub) + if self.act_continuous: + # recover the action + action = np.clip(action, -1, 1) # first clip into [-1, 1] + lb = self.env.action_space.low + ub = self.env.action_space.high + action = lb + (action + 1.0) * 0.5 * (ub - lb) + action = np.clip(action, lb, ub) state, reward, done, info = self.env.step(action) diff --git a/policies/learner.py b/policies/learner.py index fce1bd9..9668d49 100644 --- a/policies/learner.py +++ b/policies/learner.py @@ -5,6 +5,7 @@ import math import numpy as np import torch +from torch.nn import functional as F import random import gym @@ -191,8 +192,14 @@ def check_env_class(env_name): raise ValueError # get action / observation dimensions - assert self.train_env.action_space.__class__.__name__ == "Box" - self.act_dim = self.train_env.action_space.shape[0] + if self.train_env.action_space.__class__.__name__ == "Box": + # continuous action space + self.act_dim = self.train_env.action_space.shape[0] + self.act_continuous = True + else: + assert self.train_env.action_space.__class__.__name__ == "Discrete" + self.act_dim = self.train_env.action_space.n + self.act_continuous = False self.obs_dim = self.train_env.observation_space.shape[0] # include 1-dim done logger.log("obs_dim", self.obs_dim, "act_dim", self.act_dim) @@ -242,7 +249,7 @@ def init_train( self.policy_storage = SimpleReplayBuffer( max_replay_buffer_size=int(buffer_size), observation_dim=self.obs_dim, - action_dim=self.act_dim, + action_dim=self.act_dim if self.act_continuous else 1, # save memory max_trajectory_len=self.max_trajectory_len, add_timeout=False, # no timeout storage ) @@ -254,7 +261,7 @@ def init_train( self.policy_storage = SeqReplayBuffer( max_replay_buffer_size=int(buffer_size), observation_dim=self.obs_dim, - action_dim=self.act_dim, + action_dim=self.act_dim if self.act_continuous else 1, # save memory sampled_seq_len=sampled_seq_len, sample_weight_baseline=sample_weight_baseline, ) @@ -426,7 +433,11 @@ def collect_rollouts(self, num_rollouts, random_actions=False): if random_actions: action = ptu.FloatTensor( [self.train_env.action_space.sample()] - ) # (1, A) + ) # (1, A) for continuous action, (1) for discrete action + if not self.act_continuous: + action = F.one_hot( + action.long(), num_classes=self.act_dim + ).float() # (1, A) else: # policy takes hidden state as input for rnn, while takes obs for mlp if self.policy_arch == "mlp": action, _, _, _ = self.agent.act(obs, deterministic=False) @@ -469,7 +480,13 @@ def collect_rollouts(self, num_rollouts, random_actions=False): if self.policy_arch == "mlp": self.policy_storage.add_sample( observation=ptu.get_numpy(obs.squeeze(dim=0)), - action=ptu.get_numpy(action.squeeze(dim=0)), + action=ptu.get_numpy( + action.squeeze(dim=0) + if self.act_continuous + else torch.argmax( + action.squeeze(dim=0), dim=-1, keepdims=True + ) # (1,) + ), reward=ptu.get_numpy(reward.squeeze(dim=0)), terminal=np.array([term], dtype=float), next_observation=ptu.get_numpy(next_obs.squeeze(dim=0)), @@ -490,9 +507,15 @@ def collect_rollouts(self, num_rollouts, random_actions=False): break # has to manually break if self.policy_arch == 
"memory": # add collected sequence to buffer + act_buffer = torch.cat(act_list, dim=0) # (L, dim) + if not self.act_continuous: + act_buffer = torch.argmax( + act_buffer, dim=-1, keepdims=True + ) # (L, 1) + self.policy_storage.add_episode( observations=ptu.get_numpy(torch.cat(obs_list, dim=0)), # (L, dim) - actions=ptu.get_numpy(torch.cat(act_list, dim=0)), # (L, dim) + actions=ptu.get_numpy(act_buffer), # (L, dim) rewards=ptu.get_numpy(torch.cat(rew_list, dim=0)), # (L, dim) terminals=np.array(term_list).reshape(-1, 1), # (L, 1) next_observations=ptu.get_numpy( diff --git a/policies/main.py b/policies/main.py index 6672874..5939407 100644 --- a/policies/main.py +++ b/policies/main.py @@ -15,7 +15,8 @@ FLAGS = flags.FLAGS flags.DEFINE_string("cfg", None, "path to configuration file") -flags.DEFINE_string("algo", "sac", "[td3, sac]") +flags.DEFINE_string("algo", None, "[td3, sac, sacd]") +flags.DEFINE_float("target_entropy", None, "for [sac, sacd]") flags.DEFINE_integer("seed", None, "seed") flags.DEFINE_integer("cuda", None, "cuda device id") flags.DEFINE_boolean( @@ -29,7 +30,10 @@ v = yaml.load(open(FLAGS.cfg)) # overwrite config params -v["policy"]["algo"] = FLAGS.algo +if FLAGS.algo is not None: + v["policy"]["algo"] = FLAGS.algo +if FLAGS.target_entropy is not None: + v["policy"]["target_entropy"] = FLAGS.target_entropy if FLAGS.seed is not None: v["seed"] = FLAGS.seed if FLAGS.cuda is not None: @@ -72,7 +76,7 @@ arch, algo = v["policy"]["arch"], v["policy"]["algo"] assert arch in ["mlp", "lstm", "gru"] -assert algo in ["td3", "sac"] +assert algo in ["td3", "sac", "sacd"] if arch == "mlp": if oracle: algo_name = f"oracle_{algo}" @@ -95,6 +99,9 @@ exp_id += "_shared" exp_id += "/" +if algo in ["sac", "sacd"] and "target_entropy" in v["policy"]: + exp_id += f"ent-{v['policy']['target_entropy']}/" + if arch in ["lstm", "gru"]: exp_id += f"len-{v['train']['sampled_seq_len']}/bs-{v['train']['batch_size']}/" exp_id += f"baseline-{v['train']['sample_weight_baseline']}/" diff --git a/policies/models/policy_mlp.py b/policies/models/policy_mlp.py index a3f8ef5..6aa79f8 100644 --- a/policies/models/policy_mlp.py +++ b/policies/models/policy_mlp.py @@ -11,7 +11,7 @@ from torch.optim import Adam import torchkit.pytorch_utils as ptu from torchkit.networks import FlattenMlp -from torchkit.continous_actor import DeterministicPolicy, TanhGaussianPolicy +from torchkit.actor import DeterministicPolicy, TanhGaussianPolicy, CategoricalPolicy class ModelFreeOffPolicy_MLP(nn.Module): @@ -23,6 +23,7 @@ class ModelFreeOffPolicy_MLP(nn.Module): TD3_name = "td3" SAC_name = "sac" + SACD_name = "sacd" def __init__( self, @@ -52,17 +53,28 @@ def __init__( self.gamma = gamma self.tau = tau - assert algo in [self.TD3_name, self.SAC_name] + assert algo in [self.TD3_name, self.SAC_name, self.SACD_name] self.algo = algo # q networks - use two network to mitigate positive bias + if self.algo in [self.TD3_name, self.SAC_name]: + extra_input_size = action_dim + output_size = 1 + else: # sac-discrete + extra_input_size = 0 + output_size = action_dim + self.qf1 = FlattenMlp( - input_size=obs_dim + action_dim, output_size=1, hidden_sizes=dqn_layers + input_size=obs_dim + extra_input_size, + output_size=output_size, + hidden_sizes=dqn_layers, ) self.qf1_optim = Adam(self.qf1.parameters(), lr=lr) self.qf2 = FlattenMlp( - input_size=obs_dim + action_dim, output_size=1, hidden_sizes=dqn_layers + input_size=obs_dim + extra_input_size, + output_size=output_size, + hidden_sizes=dqn_layers, ) self.qf2_optim = 
Adam(self.qf2.parameters(), lr=lr) @@ -81,16 +93,26 @@ def __init__( self.target_noise = target_noise self.target_noise_clip = target_noise_clip - else: # sac: automatic entropy coefficient tuning + elif self.algo == self.SAC_name: self.policy = TanhGaussianPolicy( obs_dim=obs_dim, action_dim=action_dim, hidden_sizes=policy_layers ) + else: # sac-discrete + self.policy = CategoricalPolicy( + obs_dim=obs_dim, action_dim=action_dim, hidden_sizes=policy_layers + ) + + if self.algo in [self.SAC_name, self.SACD_name]: self.automatic_entropy_tuning = automatic_entropy_tuning if self.automatic_entropy_tuning: if target_entropy is not None: - self.target_entropy = float(target_entropy) + if self.algo == self.SAC_name: + self.target_entropy = float(target_entropy) + else: # sac-discrete: beta * log(|A|) + self.target_entropy = float(target_entropy) * np.log(action_dim) else: + assert self.algo == self.SAC_name self.target_entropy = -float(action_dim) self.log_alpha_entropy = torch.zeros( 1, requires_grad=True, device=ptu.device @@ -102,14 +124,6 @@ def __init__( self.policy_optim = Adam(self.policy.parameters(), lr=lr) - def forward(self, obs): - if self.algo == self.TD3_name: - action = self.policy(obs) - else: - action, _, _, _ = self.policy(obs) - q1, q2 = self.qf1(obs, action), self.qf2(obs, action) - return action, q1, q2 - def act( self, obs, deterministic=False, return_log_prob=False, use_target_policy=False ): @@ -125,11 +139,16 @@ def act( -1, 1 ) # NOTE return action, mean, None, None - # sac - action, mean, log_std, log_prob = self.policy( - obs, deterministic=deterministic, return_log_prob=return_log_prob - ) - return action, mean, log_std, log_prob + elif self.algo == self.SAC_name: + action, mean, log_std, log_prob = self.policy( + obs, deterministic=deterministic, return_log_prob=return_log_prob + ) + return action, mean, log_std, log_prob + else: + action, prob, log_prob = self.policy( + obs, deterministic=deterministic, return_log_prob=return_log_prob + ) + return action, prob, log_prob, None def update(self, batch): obs, next_obs = batch["obs"], batch["obs2"] # (B, dim) @@ -145,23 +164,43 @@ def update(self, batch): torch.randn_like(next_action) * self.target_noise ).clamp(-self.target_noise_clip, self.target_noise_clip) next_action = (next_action + action_noise).clamp(-1, 1) # NOTE - - else: + elif self.algo == self.SAC_name: next_action, _, _, next_log_prob = self.act( next_obs, return_log_prob=True ) + else: + _, next_prob, next_log_prob, _ = self.act( + next_obs, return_log_prob=True + ) # (B, A), (B, A) + + if self.algo in [self.TD3_name, self.SAC_name]: + next_q1 = self.qf1_target(next_obs, next_action) # (B, 1) + next_q2 = self.qf2_target(next_obs, next_action) + else: + next_q1 = self.qf1_target(next_obs) # (B, A) + next_q2 = self.qf2_target(next_obs) - next_q1 = self.qf1_target(next_obs, next_action) - next_q2 = self.qf2_target(next_obs, next_action) min_next_q_target = torch.min(next_q1, next_q2) - if self.algo == self.SAC_name: + if self.algo in [self.SAC_name, self.SACD_name]: min_next_q_target += self.alpha_entropy * (-next_log_prob) + if self.algo == self.SACD_name: # E_{a'\sim \pi}[Q(s',a')], (B, 1) + min_next_q_target = (next_prob * min_next_q_target).sum( + dim=-1, keepdims=True + ) + q_target = reward + (1.0 - done) * self.gamma * min_next_q_target - q1_pred = self.qf1(obs, action) - q2_pred = self.qf2(obs, action) + if self.algo in [self.TD3_name, self.SAC_name]: + q1_pred = self.qf1(obs, action) + q2_pred = self.qf2(obs, action) + else: + action = action.long() 
# (B, 1) + q1_pred = self.qf1(obs) + q2_pred = self.qf2(obs) + q1_pred = q1_pred.gather(dim=-1, index=action) + q2_pred = q2_pred.gather(dim=-1, index=action) qf1_loss = F.mse_loss(q1_pred, q_target) # TD error qf2_loss = F.mse_loss(q2_pred, q_target) # TD error @@ -182,14 +221,27 @@ def update(self, batch): new_action, _, _, _ = self.act( obs, deterministic=True, use_target_policy=False ) - else: + elif self.algo == self.SAC_name: new_action, _, _, log_prob = self.act(obs, return_log_prob=True) - min_q_new_actions = self._min_q(obs, new_action) + else: + _, new_prob, log_prob, _ = self.act(obs, return_log_prob=True) + + if self.algo in [self.TD3_name, self.SAC_name]: + q1 = self.qf1(obs, new_action) + q2 = self.qf2(obs, new_action) + else: + q1 = self.qf1(obs) + q2 = self.qf2(obs) + + min_q_new_actions = torch.min(q1, q2) policy_loss = -min_q_new_actions - if self.algo == self.SAC_name: + if self.algo in [self.SAC_name, self.SACD_name]: policy_loss += self.alpha_entropy * log_prob + if self.algo == self.SACD_name: # E_{a\sim \pi}[Q(s,a)] + policy_loss = (new_prob * policy_loss).sum(axis=-1, keepdims=True) # (B,1) + policy_loss = policy_loss.mean() # update policy network @@ -197,34 +249,34 @@ def update(self, batch): policy_loss.backward() self.policy_optim.step() - if self.algo == self.SAC_name and self.automatic_entropy_tuning: - alpha_entropy_loss = -( - self.log_alpha_entropy.exp() * (log_prob + self.target_entropy).detach() - ).mean() + if self.algo in [self.SAC_name, self.SACD_name]: + if self.algo == self.SACD_name: # -> negative entropy (B, 1) + log_prob = (new_prob * log_prob).sum(axis=-1, keepdims=True) - self.alpha_entropy_optim.zero_grad() - alpha_entropy_loss.backward() - self.alpha_entropy_optim.step() + current_log_prob = log_prob.mean().item() - self.alpha_entropy = self.log_alpha_entropy.exp().detach().item() + if self.automatic_entropy_tuning: + alpha_entropy_loss = -self.log_alpha_entropy.exp() * ( + current_log_prob + self.target_entropy + ) + + self.alpha_entropy_optim.zero_grad() + alpha_entropy_loss.backward() + self.alpha_entropy_optim.step() + + self.alpha_entropy = self.log_alpha_entropy.exp().item() outputs = { "qf1_loss": qf1_loss.item(), "qf2_loss": qf2_loss.item(), "policy_loss": policy_loss.item(), } - if self.algo == self.SAC_name: + if self.algo in [self.SAC_name, self.SACD_name]: outputs.update( - {"policy_entropy": -log_prob.mean().item(), "alpha": self.alpha_entropy} + {"policy_entropy": -current_log_prob, "alpha": self.alpha_entropy} ) return outputs - def _min_q(self, obs, action): - q1 = self.qf1(obs, action) - q2 = self.qf2(obs, action) - min_q = torch.min(q1, q2) - return min_q - def soft_target_update(self): ptu.soft_update_from_to(self.qf1, self.qf1_target, self.tau) ptu.soft_update_from_to(self.qf2, self.qf2_target, self.tau) diff --git a/policies/models/policy_rnn.py b/policies/models/policy_rnn.py index 8920785..959ed7b 100644 --- a/policies/models/policy_rnn.py +++ b/policies/models/policy_rnn.py @@ -34,6 +34,7 @@ class ModelFreeOffPolicy_Separate_RNN(nn.Module): TD3_name = Actor_RNN.TD3_name SAC_name = Actor_RNN.SAC_name + SACD_name = Actor_RNN.SACD_name def __init__( self, @@ -69,7 +70,7 @@ def __init__( self.gamma = gamma self.tau = tau - assert algo in [self.TD3_name, self.SAC_name] + assert algo in [self.TD3_name, self.SAC_name, self.SACD_name] self.algo = algo # Critics @@ -77,6 +78,7 @@ def __init__( obs_dim, action_dim, encoder, + algo, action_embedding_size, state_embedding_size, reward_embedding_size, @@ -114,8 +116,12 @@ def 
__init__( self.automatic_entropy_tuning = automatic_entropy_tuning if self.automatic_entropy_tuning: if target_entropy is not None: - self.target_entropy = float(target_entropy) + if self.algo == self.SAC_name: + self.target_entropy = float(target_entropy) + else: # sac-discrete: beta * log(|A|) + self.target_entropy = float(target_entropy) * np.log(action_dim) else: + assert self.algo == self.SAC_name self.target_entropy = -float(action_dim) self.log_alpha_entropy = torch.zeros( 1, requires_grad=True, device=ptu.device @@ -128,8 +134,6 @@ def __init__( # use separate optimizers self.critic_optimizer = Adam(self.critic.parameters(), lr=lr) self.actor_optimizer = Adam(self.actor.parameters(), lr=lr) - logger.log(self.critic) - logger.log(self.actor) @torch.no_grad() def get_initial_info(self): @@ -202,23 +206,35 @@ def forward(self, actions, rewards, observs, dones, masks): torch.randn_like(new_actions) * self.target_noise ).clamp(-self.target_noise_clip, self.target_noise_clip) new_actions = (new_actions + action_noise).clamp(-1, 1) # NOTE - else: + elif self.algo == self.SAC_name: new_actions, new_log_probs = self.actor( prev_actions=actions, rewards=rewards, observs=observs ) + else: + new_probs, new_log_probs = self.actor( + prev_actions=actions, rewards=rewards, observs=observs + ) next_q1, next_q2 = self.critic_target( prev_actions=actions, rewards=rewards, observs=observs, - current_actions=new_actions, - ) # (T+1, B, 1) + current_actions=new_actions + if self.algo in [self.TD3_name, self.SAC_name] + else new_probs, + ) # (T+1, B, 1 or A) + min_next_q_target = torch.min(next_q1, next_q2) - if self.algo == self.SAC_name: + if self.algo in [self.SAC_name, self.SACD_name]: min_next_q_target += self.alpha_entropy * ( -new_log_probs - ) # (T+1, B, 1) + ) # (T+1, B, 1 or A) + + if self.algo == self.SACD_name: # E_{a'\sim \pi}[Q(h',a')], (T+1, B, 1) + min_next_q_target = (new_probs * min_next_q_target).sum( + dim=-1, keepdims=True + ) # q_target: (T, B, 1) q_target = ( @@ -231,8 +247,20 @@ def forward(self, actions, rewards, observs, dones, masks): prev_actions=actions, rewards=rewards, observs=observs, - current_actions=actions[1:], # (T, B, 1) - ) # (T, B, 1) + current_actions=actions[1:], + ) # (T, B, 1 or A) + + if self.algo == self.SACD_name: + stored_actions = actions[1:] # (T, B, A) + stored_actions = torch.argmax( + stored_actions, dim=-1, keepdims=True + ) # (T, B, 1) + q1_pred = q1_pred.gather( + dim=-1, index=stored_actions + ) # (T, B, A) -> (T, B, 1) + q2_pred = q2_pred.gather( + dim=-1, index=stored_actions + ) # (T, B, A) -> (T, B, 1) # masked Bellman error: masks (T,B,1) ignore the invalid error # this is not equal to masks * q1_pred, cuz the denominator in mean() @@ -250,24 +278,38 @@ def forward(self, actions, rewards, observs, dones, masks): if self.algo == self.TD3_name: new_actions, _ = self.actor( prev_actions=actions, rewards=rewards, observs=observs - ) # (T+1, B, dim) - else: + ) # (T+1, B, A) + elif self.algo == self.SAC_name: new_actions, log_probs = self.actor( prev_actions=actions, rewards=rewards, observs=observs - ) # (T+1, B, dim) + ) # (T+1, B, A) + else: + new_probs, log_probs = self.actor( + prev_actions=actions, rewards=rewards, observs=observs + ) # (T+1, B, A) q1, q2 = self.critic( prev_actions=actions, rewards=rewards, observs=observs, - current_actions=new_actions, - ) # (T+1, B, 1) - min_q_new_actions = torch.min(q1, q2) # (T+1,B,1) + current_actions=new_actions + if self.algo in [self.TD3_name, self.SAC_name] + else new_probs, + ) # (T+1, B, 1 or A) 
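+        # SAC-discrete note: here the critics return Q-values for all |A| actions,
+        # so the policy loss below takes an exact expectation under the categorical
+        # policy (via new_probs) instead of evaluating a single sampled action.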
+ min_q_new_actions = torch.min(q1, q2) # (T+1,B,1 or A) policy_loss = -min_q_new_actions - if self.algo == self.SAC_name: # Q(h(t), pi(h(t))) + H[pi(h(t))] + if self.algo in [ + self.SAC_name, + self.SACD_name, + ]: # Q(h(t), pi(h(t))) + H[pi(h(t))] policy_loss += self.alpha_entropy * log_probs + if self.algo == self.SACD_name: # E_{a\sim \pi}[Q(h,a)] + policy_loss = (new_probs * policy_loss).sum( + axis=-1, keepdims=True + ) # (T+1,B,1) + policy_loss = policy_loss[:-1] # (T,B,1) remove the last obs # masked policy_loss policy_loss = (policy_loss * masks).sum() / num_valid @@ -280,11 +322,14 @@ def forward(self, actions, rewards, observs, dones, masks): self.soft_target_update() ### 4. update alpha - if self.algo == self.SAC_name: + if self.algo in [self.SAC_name, self.SACD_name]: # extract valid log_probs + if self.algo == self.SACD_name: # -> negative entropy (T+1, B, 1) + log_probs = (new_probs * log_probs).sum(axis=-1, keepdims=True) with torch.no_grad(): current_log_probs = (log_probs[:-1] * masks).sum() / num_valid current_log_probs = current_log_probs.item() + if self.automatic_entropy_tuning: alpha_entropy_loss = -self.log_alpha_entropy.exp() * ( current_log_probs + self.target_entropy @@ -301,7 +346,7 @@ def forward(self, actions, rewards, observs, dones, masks): "qf2_loss": qf2_loss.item(), "policy_loss": policy_loss.item(), } - if self.algo == self.SAC_name: + if self.algo in [self.SAC_name, self.SACD_name]: outputs.update( {"policy_entropy": -current_log_probs, "alpha": self.alpha_entropy} ) @@ -325,6 +370,12 @@ def update(self, batch): # all are 3D tensor (T,B,dim) actions, rewards, dones = batch["act"], batch["rew"], batch["term"] _, batch_size, _ = actions.shape + if self.algo == self.SACD_name: + # for discrete action space, convert to one-hot vectors + actions = F.one_hot( + actions.squeeze(-1).long(), num_classes=self.action_dim + ).float() # (T, B, A) + masks = batch["mask"] obs, next_obs = batch["obs"], batch["obs2"] # (T, B, dim) diff --git a/policies/models/policy_rnn_shared.py b/policies/models/policy_rnn_shared.py index d42bcf3..a3d4ef3 100644 --- a/policies/models/policy_rnn_shared.py +++ b/policies/models/policy_rnn_shared.py @@ -9,7 +9,7 @@ from torch.optim import Adam from utils import helpers as utl from torchkit.networks import FlattenMlp -from torchkit.continous_actor import DeterministicPolicy, TanhGaussianPolicy +from torchkit.actor import DeterministicPolicy, TanhGaussianPolicy import torchkit.pytorch_utils as ptu from utils import logger diff --git a/readme.md b/readme.md index 9eaa661..8c2ef08 100644 --- a/readme.md +++ b/readme.md @@ -24,6 +24,7 @@ There are many other (more complicated or specialized) methods for POMDPs and it Note that current repo should be run smoothly. DONE: +* Mar 2022: introduce recurrent [SAC-discrete](https://arxiv.org/abs/1910.07207) for **discrete action** space and see [this PR for instructions](https://github.com/twni2016/pomdp-baselines/pull/1) * Feb 2022: simplify `--oracle` commands, and upload the plotting scripts * Jan 2022: introduce new meta RL environments (*-Dir), and replace re-implementation of off-policy variBAD with original implementation * Dec 2021: add some command-line arguments to overwrite the config file and save the updated one @@ -49,7 +50,7 @@ pip install -r requirements.txt ``` The `requirements.txt` file includes all the dependencies (e.g. 
PyTorch, PyBullet) used in our experiments (including the compared methods), but there are two exceptions: -- To run Cheetah-Vel and Ant-Dir in meta RL, you have to install [MuJoCo](https://github.com/openai/mujoco-py) on your own (it is free now!) +- To run Cheetah-Vel and Ant-Dir in meta RL, you have to install [MuJoCo](https://github.com/openai/mujoco-py) (>=2.0.2) on your own (it is free now!) - To run robust RL and generalization in RL experiments, you have to install [Roboschool](https://github.com/openai/roboschool). - We found it hard to install Roboschool from scratch, therefore we provide a docker file `roboschool.sif` in [google drive](https://drive.google.com/file/d/1KpTpVwoU02AI7uQrk2T9hQ6s15EISRTa/view?usp=sharing) that contains Roboschool and the other necessary libraries, adapted from [SunBlaze repo](https://github.com/sunblaze-ucb/rl-generalization). - To download and activate the docker file by singularity (tested in v3.7) on a cluster (on a single server should be similar): @@ -80,7 +81,7 @@ To run our implementation, Markovian, and Oracle, in simply ``` export PYTHONPATH=${PWD}:$PYTHONPATH python3 policies/main.py --cfg configs///.yml \ - [--algo {td3,sac} --seed --cuda --oracle] + [--algo {td3,sac,sacd} --seed --cuda --oracle] ``` where `algo_name` specifies the algorithm name: - `mlp` correspond to **Markovian** policies diff --git a/torchkit/continous_actor.py b/torchkit/actor.py similarity index 68% rename from torchkit/continous_actor.py rename to torchkit/actor.py index eedd10b..7e48c9e 100644 --- a/torchkit/continous_actor.py +++ b/torchkit/actor.py @@ -1,13 +1,16 @@ import numpy as np import torch from torch import nn as nn +import torch.nn.functional as F +from torch.distributions import Categorical from torchkit.distributions import TanhNormal from torchkit.networks import Mlp LOG_SIG_MAX = 2 LOG_SIG_MIN = -20 +PROB_MIN = 1e-8 class DeterministicPolicy(Mlp): @@ -139,3 +142,61 @@ def forward( action = tanh_normal.sample() return action, mean, log_std, log_prob + + +class CategoricalPolicy(Mlp): + """Based on https://github.com/ku2482/sac-discrete.pytorch/blob/master/sacd/model.py + Usage: SAC-discrete + ``` + policy = CategoricalPolicy(...) 
+ action, _, _ = policy(obs, deterministic=True) + action, _, _ = policy(obs, deterministic=False) + action, prob, log_prob = policy(obs, deterministic=False, return_log_prob=True) + ``` + NOTE: action space must be discrete + """ + + def __init__(self, obs_dim, action_dim, hidden_sizes, init_w=1e-3, **kwargs): + self.save_init_params(locals()) + super().__init__( + hidden_sizes, + input_size=obs_dim, + output_size=action_dim, + init_w=init_w, + **kwargs, + ) + self.obs_dim = obs_dim + self.action_dim = action_dim + + def forward( + self, + obs, + deterministic=False, + return_log_prob=False, + ): + """ + :param obs: Observation, usually 2D (B, dim), but maybe 3D (T, B, dim) + :param deterministic: If True, do not sample + :param return_log_prob: If True, return a sample and its log probability + return: action (*, B, A), prob (*, B, A), log_prob (*, B, A) + """ + action_logits = super().forward(obs) # (*, A) + + prob, log_prob = None, None + if deterministic: + action = torch.argmax(action_logits, dim=-1) # (*) + assert ( + return_log_prob == False + ) # NOTE: cannot be used for estimating entropy + else: + prob = F.softmax(action_logits, dim=-1) # (*, A) + distr = Categorical(prob) + # categorical distr cannot reparameterize + action = distr.sample() # (*) + if return_log_prob: + log_prob = torch.log(torch.clamp(prob, min=PROB_MIN)) + + # convert to one-hot vectors + action = F.one_hot(action.long(), num_classes=self.action_dim).float() # (*, A) + + return action, prob, log_prob diff --git a/torchkit/recurrent_actor.py b/torchkit/recurrent_actor.py index 0b7a52b..476a9ba 100644 --- a/torchkit/recurrent_actor.py +++ b/torchkit/recurrent_actor.py @@ -2,13 +2,14 @@ import torch.nn as nn from torch.nn import functional as F from utils import helpers as utl -from torchkit.continous_actor import DeterministicPolicy, TanhGaussianPolicy +from torchkit.actor import DeterministicPolicy, TanhGaussianPolicy, CategoricalPolicy import torchkit.pytorch_utils as ptu class Actor_RNN(nn.Module): TD3_name = "td3" SAC_name = "sac" + SACD_name = "sacd" LSTM_name = "lstm" GRU_name = "gru" RNNs = { @@ -35,7 +36,7 @@ def __init__( self.obs_dim = obs_dim self.action_dim = action_dim - assert algo in [self.TD3_name, self.SAC_name] + assert algo in [self.TD3_name, self.SAC_name, self.SACD_name] self.algo = algo ### Build Model @@ -86,12 +87,18 @@ def __init__( action_dim=self.action_dim, hidden_sizes=policy_layers, ) - else: + elif self.algo == self.SAC_name: self.policy = TanhGaussianPolicy( obs_dim=self.rnn_hidden_size + state_embedding_size, action_dim=self.action_dim, hidden_sizes=policy_layers, ) + else: # SAC-Discrete + self.policy = CategoricalPolicy( + obs_dim=self.rnn_hidden_size + state_embedding_size, + action_dim=self.action_dim, + hidden_sizes=policy_layers, + ) def get_hidden_states( self, prev_actions, rewards, observs, initial_internal_state=None @@ -136,12 +143,16 @@ def forward(self, prev_actions, rewards, observs): # 4. 
Actor if self.algo == self.TD3_name: - new_actions, log_probs = self.policy(joint_embeds), None - else: # SAC + new_actions = self.policy(joint_embeds) + return new_actions, None # (T+1, B, dim), None + elif self.algo == self.SAC_name: new_actions, _, _, log_probs = self.policy( joint_embeds, return_log_prob=True ) - return new_actions, log_probs # (T+1, B, dim), (T+1, B, 1) or None + return new_actions, log_probs # (T+1, B, dim), (T+1, B, 1) + else: # sac-d + _, probs, log_probs = self.policy(joint_embeds, return_log_prob=True) + return probs, log_probs # (T+1, B, dim), (T+1, B, dim) @torch.no_grad() def get_initial_info(self): @@ -203,10 +214,14 @@ def act( -1, 1 ) # NOTE action_tuple = (action, mean, None, None) - else: - # sac + elif self.algo == self.SAC_name: action_tuple = self.policy( joint_embeds, False, deterministic, return_log_prob ) - + else: + # sac-discrete + action, prob, log_prob = self.policy( + joint_embeds, deterministic, return_log_prob + ) + action_tuple = (action, prob, log_prob, None) return action_tuple, current_internal_state diff --git a/torchkit/recurrent_critic.py b/torchkit/recurrent_critic.py index 1e72814..596aed6 100644 --- a/torchkit/recurrent_critic.py +++ b/torchkit/recurrent_critic.py @@ -8,11 +8,19 @@ class Critic_RNN(nn.Module): + TD3_name = Actor_RNN.TD3_name + SAC_name = Actor_RNN.SAC_name + SACD_name = Actor_RNN.SACD_name + LSTM_name = Actor_RNN.LSTM_name + GRU_name = Actor_RNN.GRU_name + RNNs = Actor_RNN.RNNs + def __init__( self, obs_dim, action_dim, encoder, + algo, action_embedding_size, state_embedding_size, reward_embedding_size, @@ -25,6 +33,7 @@ def __init__( self.obs_dim = obs_dim self.action_dim = action_dim + self.algo = algo ### Build Model ## 1. embed action, state, reward (Feed-forward layers first) @@ -40,10 +49,10 @@ def __init__( ) self.rnn_hidden_size = rnn_hidden_size - assert encoder in Actor_RNN.RNNs + assert encoder in self.RNNs self.encoder = encoder - self.rnn = Actor_RNN.RNNs[encoder]( + self.rnn = self.RNNs[encoder]( input_size=rnn_input_size, hidden_size=self.rnn_hidden_size, num_layers=rnn_num_layers, @@ -57,20 +66,27 @@ def __init__( elif "weight" in name: nn.init.orthogonal_(param) + if self.algo in [self.TD3_name, self.SAC_name]: + extra_input_size = action_dim + output_size = 1 + else: # sac-discrete + extra_input_size = 0 + output_size = action_dim + ## 3. build another obs+act branch self.current_state_action_encoder = utl.FeatureExtractor( - obs_dim + action_dim, rnn_input_size, F.relu + obs_dim + extra_input_size, rnn_input_size, F.relu ) ## 4. build q networks self.qf1 = FlattenMlp( input_size=self.rnn_hidden_size + rnn_input_size, - output_size=1, + output_size=output_size, hidden_sizes=dqn_layers, ) self.qf2 = FlattenMlp( input_size=self.rnn_hidden_size + rnn_input_size, - output_size=1, + output_size=output_size, hidden_sizes=dqn_layers, ) @@ -90,7 +106,7 @@ def forward(self, prev_actions, rewards, observs, current_actions): """ For prev_actions a, rewards r, observs o: (T+1, B, dim) a[t] -> r[t], o[t] - current_actions a': (T or T+1, B, dim) + current_actions (or action probs for discrete actions) a': (T or T+1, B, dim) o[t] -> a'[t] NOTE: there is one timestep misalignment in prev_actions and current_actions """ @@ -112,18 +128,26 @@ def forward(self, prev_actions, rewards, observs, current_actions): # 2. another branch for state & **current** action if current_actions.shape[0] == observs.shape[0]: # current_actions include last obs's action, i.e. 
we have a'[T] in reaction to o[T] - curr_embed = self.current_state_action_encoder( - torch.cat((observs, current_actions), dim=-1) - ) # (T+1, B, dim) + if self.algo in [self.TD3_name, self.SAC_name]: + curr_embed = self.current_state_action_encoder( + torch.cat((observs, current_actions), dim=-1) + ) # (T+1, B, dim) + else: + curr_embed = self.current_state_action_encoder(observs) # (T+1, B, dim) # 3. joint embeds joint_embeds = torch.cat( (hidden_states, curr_embed), dim=-1 ) # (T+1, B, dim) else: # current_actions does NOT include last obs's action - curr_embed = self.current_state_action_encoder( - torch.cat((observs[:-1], current_actions), dim=-1) - ) # (T, B, dim) + if self.algo in [self.TD3_name, self.SAC_name]: + curr_embed = self.current_state_action_encoder( + torch.cat((observs[:-1], current_actions), dim=-1) + ) # (T, B, dim) + else: + curr_embed = self.current_state_action_encoder( + observs[:-1] + ) # (T, B, dim) # 3. joint embeds joint_embeds = torch.cat( (hidden_states[:-1], curr_embed), dim=-1 @@ -133,4 +157,4 @@ def forward(self, prev_actions, rewards, observs, current_actions): q1 = self.qf1(joint_embeds) q2 = self.qf2(joint_embeds) - return q1, q2 # (T or T+1, B, 1) + return q1, q2 # (T or T+1, B, 1 or A) diff --git a/utils/helpers.py b/utils/helpers.py index 0d8aff3..b25e705 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -43,10 +43,12 @@ def get_dim(space): def env_step(env, action): # action: (A) - # return: all 2D tensor shape (B=1, dim) for vae input - # squeeze if action should be scalar else unchanged - action = ptu.get_numpy(action.squeeze(dim=-1)) + # return: all 2D tensor shape (B=1, dim) + action = ptu.get_numpy(action) + if env.action_space.__class__.__name__ == "Discrete": + action = np.argmax(action) # one-hot to int next_obs, reward, done, info = env.step(action) + # move to torch next_obs = ptu.from_numpy(next_obs).view(-1, next_obs.shape[0]) reward = ptu.FloatTensor([reward]).view(-1, 1)
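To summarize the discrete-action plumbing introduced above: the categorical policy emits one-hot actions, `utils/helpers.env_step` converts them to integer indices for gym's `Discrete` spaces, the replay buffers store only the 1-dim action index to save memory, and the updates recover per-action Q-values with `gather` (or re-expand to one-hot for the recurrent model's input). A minimal standalone sketch of these conversions, assuming a toy 4-action space (illustrative only, not the repo's classes):
```python
import numpy as np
import torch
import torch.nn.functional as F

num_actions = 4  # toy |A|

# 1) policy side: sample from a categorical head, return a one-hot action
logits = torch.randn(1, num_actions)
idx = torch.distributions.Categorical(logits=logits).sample()      # (1,)
action_onehot = F.one_hot(idx, num_classes=num_actions).float()    # (1, A)

# 2) env side: one-hot -> integer index for gym's Discrete action space
env_action = int(np.argmax(action_onehot.numpy()))

# 3) buffer side: store only the 1-dim action index
stored_action = torch.argmax(action_onehot, dim=-1, keepdim=True)  # (1, 1)

# 4) update side: pick the Q-value of the taken action from a (B, A) critic output,
#    or re-expand to one-hot when the action is fed back into a recurrent model
q_all = torch.randn(1, num_actions)                                 # critic output (B, A)
q_taken = q_all.gather(dim=-1, index=stored_action)                 # (B, 1)
action_onehot_again = F.one_hot(
    stored_action.squeeze(-1), num_classes=num_actions
).float()                                                           # (1, A)
```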