Baselines for TensorFlow 2.0. #978

Merged · 3 commits · Aug 8, 2019
Changes from 2 commits
2 changes: 1 addition & 1 deletion README.md
@@ -150,7 +150,7 @@ respectively. Note that these results may be not on the latest version of the code
To cite this repository in publications:

@misc{baselines,
author = {Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Wu, Yuhuai and Zhokhov, Peter},
author = {Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Tan, Zhenyu and Wu, Yuhuai and Zhokhov, Peter},
title = {OpenAI Baselines},
year = {2017},
publisher = {GitHub},
169 changes: 70 additions & 99 deletions baselines/a2c/a2c.py
@@ -1,22 +1,19 @@
import time
import functools
import tensorflow as tf

from baselines import logger

from baselines.common import set_global_seeds, explained_variance
from baselines.common import tf_util
from baselines.common.policies import build_policy
from baselines.common.models import get_network_builder
from baselines.common.policies import PolicyWithValue


from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.a2c.utils import InverseLinearTimeDecay
from baselines.a2c.runner import Runner
from baselines.ppo2.ppo2 import safemean
import os.path as osp
from collections import deque

from tensorflow import losses

class Model(object):
class Model(tf.keras.Model):

"""
We use this class to :
@@ -30,90 +27,42 @@ class Model(object):
save/load():
- Save load the model
"""
def __init__(self, policy, env, nsteps,
def __init__(self, *, ac_space, policy_network, nupdates,
ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):

sess = tf_util.get_session()
nenvs = env.num_envs
nbatch = nenvs*nsteps


with tf.variable_scope('a2c_model', reuse=tf.AUTO_REUSE):
# step_model is used for sampling
step_model = policy(nenvs, 1, sess)

# train_model is used to train our network
train_model = policy(nbatch, nsteps, sess)

A = tf.placeholder(train_model.action.dtype, train_model.action.shape)
ADV = tf.placeholder(tf.float32, [nbatch])
R = tf.placeholder(tf.float32, [nbatch])
LR = tf.placeholder(tf.float32, [])

# Calculate the loss
# Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss

# Policy loss
neglogpac = train_model.pd.neglogp(A)
# L = A(s,a) * -logpi(a|s)
pg_loss = tf.reduce_mean(ADV * neglogpac)

# Entropy is used to improve exploration by limiting the premature convergence to suboptimal policy.
entropy = tf.reduce_mean(train_model.pd.entropy())

# Value loss
vf_loss = losses.mean_squared_error(tf.squeeze(train_model.vf), R)

loss = pg_loss - entropy*ent_coef + vf_loss * vf_coef

# Update parameters using loss
# 1. Get the model parameters
params = find_trainable_variables("a2c_model")

# 2. Calculate the gradients
grads = tf.gradients(loss, params)
if max_grad_norm is not None:
# Clip the gradients (normalize)
grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
grads = list(zip(grads, params))
# zip aggregate each gradient with parameters associated
# For instance zip(ABCD, xyza) => Ax, By, Cz, Da

# 3. Make op for one policy and value update step of A2C
trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)

_train = trainer.apply_gradients(grads)

lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

def train(obs, states, rewards, masks, actions, values):
# Here we calculate advantage A(s,a) = R + yV(s') - V(s)
# rewards = R + yV(s')
advs = rewards - values
for step in range(len(obs)):
cur_lr = lr.value()

td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, LR:cur_lr}
if states is not None:
td_map[train_model.S] = states
td_map[train_model.M] = masks
policy_loss, value_loss, policy_entropy, _ = sess.run(
[pg_loss, vf_loss, entropy, _train],
td_map
)
return policy_loss, value_loss, policy_entropy


self.train = train
self.train_model = train_model
self.step_model = step_model
self.step = step_model.step
self.value = step_model.value
self.initial_state = step_model.initial_state
self.save = functools.partial(tf_util.save_variables, sess=sess)
self.load = functools.partial(tf_util.load_variables, sess=sess)
tf.global_variables_initializer().run(session=sess)
alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6)):

super(Model, self).__init__(name='A2CModel')
self.train_model = PolicyWithValue(ac_space, policy_network, value_network=None, estimate_q=False)
lr_schedule = InverseLinearTimeDecay(initial_learning_rate=lr, nupdates=nupdates)
self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr_schedule, rho=alpha, epsilon=epsilon)

self.ent_coef = ent_coef
self.vf_coef = vf_coef
self.max_grad_norm = max_grad_norm
self.step = self.train_model.step
self.value = self.train_model.value
self.initial_state = self.train_model.initial_state

@tf.function
def train(self, obs, states, rewards, masks, actions, values):
advs = rewards - values
with tf.GradientTape() as tape:
policy_latent = self.train_model.policy_network(obs)
pd, _ = self.train_model.pdtype.pdfromlatent(policy_latent)
neglogpac = pd.neglogp(actions)
entropy = tf.reduce_mean(pd.entropy())
vpred = self.train_model.value(obs)
vf_loss = tf.reduce_mean(tf.square(vpred - rewards))
pg_loss = tf.reduce_mean(advs * neglogpac)
loss = pg_loss - entropy * self.ent_coef + vf_loss * self.vf_coef

var_list = tape.watched_variables()
grads = tape.gradient(loss, var_list)
grads, _ = tf.clip_by_global_norm(grads, self.max_grad_norm)
grads_and_vars = list(zip(grads, var_list))
self.optimizer.apply_gradients(grads_and_vars)

return pg_loss, vf_loss, entropy
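
The `@tf.function`-decorated `train` method above follows the standard TF 2.x eager-training pattern: compute the loss under a `tf.GradientTape`, clip the gradients by global norm, and apply them with a Keras optimizer. A minimal, self-contained sketch of the same pattern on a toy value head (the network, shapes, and data here are placeholders for illustration, not part of this PR):

```python
import tensorflow as tf

# Toy stand-in for a policy/value network; the PR uses PolicyWithValue instead.
net = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='tanh'),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.RMSprop(learning_rate=7e-4, rho=0.99, epsilon=1e-5)
max_grad_norm = 0.5

@tf.function
def train_step(obs, returns):
    with tf.GradientTape() as tape:
        vpred = tf.squeeze(net(obs), axis=-1)
        loss = tf.reduce_mean(tf.square(vpred - returns))    # value-loss-style MSE
    variables = tape.watched_variables()
    grads = tape.gradient(loss, variables)
    grads, _ = tf.clip_by_global_norm(grads, max_grad_norm)  # same clipping as above
    optimizer.apply_gradients(zip(grads, variables))
    return loss

obs = tf.random.normal([8, 4])      # dummy batch of observations
returns = tf.random.normal([8])     # dummy returns
print(train_step(obs, returns).numpy())
```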


def learn(
@@ -185,31 +134,53 @@ def learn(

set_global_seeds(seed)

total_timesteps = int(total_timesteps)

# Get the nb of env
nenvs = env.num_envs
policy = build_policy(env, network, **network_kwargs)

# Get state_space and action_space
ob_space = env.observation_space
ac_space = env.action_space

if isinstance(network, str):
network_type = network
policy_network_fn = get_network_builder(network_type)(**network_kwargs)
policy_network = policy_network_fn(ob_space.shape)

# Calculate the batch_size
nbatch = nenvs * nsteps
nupdates = total_timesteps // nbatch

# Instantiate the model object (that creates step_model and train_model)
model = Model(policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
model = Model(ac_space=ac_space, policy_network=policy_network, nupdates=nupdates, ent_coef=ent_coef, vf_coef=vf_coef,
max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps)

if load_path is not None:
model.load(load_path)
load_path = osp.expanduser(load_path)
ckpt = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(ckpt, load_path, max_to_keep=None)
ckpt.restore(manager.latest_checkpoint)

# Instantiate the runner object
runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
epinfobuf = deque(maxlen=100)

# Calculate the batch_size
nbatch = nenvs*nsteps

# Start total timer
tstart = time.time()

for update in range(1, total_timesteps//nbatch+1):
for update in range(1, nupdates+1):
# Get mini batch of experiences
obs, states, rewards, masks, actions, values, epinfos = runner.run()
epinfobuf.extend(epinfos)

obs = tf.constant(obs)
if states is not None:
states = tf.constant(states)
rewards = tf.constant(rewards)
masks = tf.constant(masks)
actions = tf.constant(actions)
values = tf.constant(values)
policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
nseconds = time.time()-tstart

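The `load_path` handling in `learn()` above now goes through TF 2.x object-based checkpointing (`tf.train.Checkpoint` plus `tf.train.CheckpointManager`) instead of `tf_util.load_variables`. A minimal sketch of that API, assuming `model` is any `tf.keras.Model`; the `Sequential` network and the `./checkpoints` directory are placeholders, not from the PR:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])  # placeholder model
model(tf.zeros([1, 4]))  # run once so the variables exist before checkpointing

ckpt = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(ckpt, './checkpoints', max_to_keep=3)

manager.save()                           # write a checkpoint
ckpt.restore(manager.latest_checkpoint)  # restore the most recent one, as in learn()
```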
34 changes: 19 additions & 15 deletions baselines/a2c/runner.py
@@ -1,3 +1,4 @@
import tensorflow as tf
import numpy as np
from baselines.a2c.utils import discount_with_dones
from baselines.common.runners import AbstractEnvRunner
@@ -15,40 +16,37 @@ class Runner(AbstractEnvRunner):
def __init__(self, env, model, nsteps=5, gamma=0.99):
super().__init__(env=env, model=model, nsteps=nsteps)
self.gamma = gamma
self.batch_action_shape = [x if x is not None else -1 for x in model.train_model.action.shape.as_list()]
self.ob_dtype = model.train_model.X.dtype.as_numpy_dtype

def run(self):
# We initialize the lists that will contain the mb of experiences
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
mb_states = self.states
epinfos = []
for n in range(self.nsteps):
for _ in range(self.nsteps):
# Given observations, take action and value (V(s))
# We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
actions, values, states, _ = self.model.step(self.obs, S=self.states, M=self.dones)

obs = tf.constant(self.obs)
actions, values, self.states, _ = self.model.step(obs)
actions = actions._numpy()
# Append the experiences
mb_obs.append(np.copy(self.obs))
mb_obs.append(self.obs.copy())
mb_actions.append(actions)
mb_values.append(values)
mb_values.append(values._numpy())
mb_dones.append(self.dones)

# Take actions in env and look the results
obs, rewards, dones, infos = self.env.step(actions)
self.obs[:], rewards, self.dones, infos = self.env.step(actions)
for info in infos:
maybeepinfo = info.get('episode')
if maybeepinfo: epinfos.append(maybeepinfo)
self.states = states
self.dones = dones
self.obs = obs
mb_rewards.append(rewards)

mb_dones.append(self.dones)

# Batch of steps to batch of rollouts
mb_obs = np.asarray(mb_obs, dtype=self.ob_dtype).swapaxes(1, 0).reshape(self.batch_ob_shape)
mb_obs = sf01(np.asarray(mb_obs, dtype=self.obs.dtype))
mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
mb_actions = np.asarray(mb_actions, dtype=self.model.train_model.action.dtype.name).swapaxes(1, 0)
mb_actions = sf01(np.asarray(mb_actions, dtype=actions.dtype))
mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0)
mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(1, 0)
mb_masks = mb_dones[:, :-1]
@@ -57,7 +55,7 @@ def run(self):

if self.gamma > 0.0:
# Discount/bootstrap off value fn
last_values = self.model.value(self.obs, S=self.states, M=self.dones).tolist()
last_values = self.model.value(tf.constant(self.obs))._numpy().tolist()
for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
rewards = rewards.tolist()
dones = dones.tolist()
@@ -68,9 +66,15 @@

mb_rewards[n] = rewards

mb_actions = mb_actions.reshape(self.batch_action_shape)

mb_rewards = mb_rewards.flatten()
mb_values = mb_values.flatten()
mb_masks = mb_masks.flatten()
return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, epinfos

def sf01(arr):
"""
swap and then flatten axes 0 and 1
"""
s = arr.shape
return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])
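
The `sf01` helper added above replaces the explicit `batch_ob_shape` / `batch_action_shape` reshapes: it swaps the `(nsteps, nenvs)` leading axes and flattens them so each environment's trajectory ends up contiguous in the batch. A quick illustration with made-up shapes:

```python
import numpy as np

def sf01(arr):
    """Swap and then flatten axes 0 and 1 (as defined above)."""
    s = arr.shape
    return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])

arr = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # (nsteps=2, nenvs=3, obs_dim=4)
out = sf01(arr)
print(out.shape)  # (6, 4): env 0's two steps first, then env 1's, then env 2's
```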