implement support for multi-actions; related to #16
dimitry12 committed Nov 19, 2018
1 parent 723213b commit f3e3d61
Showing 1 changed file with 48 additions and 16 deletions.
64 changes: 48 additions & 16 deletions ppo.py
@@ -6,11 +6,41 @@
 import tensorflow as tf
 import tensorflow.contrib.eager as tfe
 import json
+import os
 
 
+def get_actions_and_neglogp(p_logits, postprocess_preds, actions=None):
+    p_logits = postprocess_preds(p_logits)
+    p_distributions = [tf.distributions.Categorical(
+        logits=_p_logits) for _p_logits in p_logits]
+
+    if actions is None:
+        # batch of 1
+        actions = [tf.reshape(_p_distribution.sample(), (1,))
+                   for _p_distribution in p_distributions]
+
+        # only log if we are actively generating actions
+        for i, _p_distribution in enumerate(p_distributions):
+            tf.contrib.summary.histogram(
+                'policy_probabilities_' + str(i), _p_distribution.probs)
+    else:
+        # actions must be a list of (batchsize,1) tensors
+        # for zip-map to work
+        actions = tf.split(actions, actions.shape[1], axis=1)
+        actions = [tf.reshape(a, (a.shape[0],)) for a in actions]
+
+    def neg_log_p_ac_func(p, a):
+        return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=p, labels=a)
+
+    neg_log_p_ac = tf.add_n([neg_log_p_ac_func(_p_logits, _action)
+                             for _p_logits, _action in zip(p_logits, actions)])
+
+    return actions, neg_log_p_ac
+
+
 def get_transformations(env):
     def vector_to_tf_constant(x, dtype=tf.keras.backend.floatx()):
-        return tf.constant(x, dtype=dtype, shape=(1, len(x)))
+        return tf.constant(tf.cast(x, dtype), dtype=dtype)
 
     def passthrough(tensor):
         return tensor
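
For orientation (this example is not part of the commit): a minimal sketch of calling the new helper in scoring mode, assuming get_actions_and_neglogp from the hunk above is in scope and a made-up Tuple(Discrete(3), Discrete(2)) action space, so the policy head emits 3 + 2 = 5 concatenated logits.

import tensorflow as tf

tf.enable_eager_execution()  # ppo.py runs in TF 1.x eager mode


def postprocess_preds(x):
    # mirrors the helper returned by get_transformations below:
    # split the concatenated head logits into one tensor per sub-action
    return tf.split(x, [3, 2], axis=1)


# one concatenated logits row covering both heads (3 + 2 = 5 columns)
p_logits = tf.constant([[0.1, 0.5, -0.2, 1.0, -1.0]])

# scoring mode: pass the previously taken multi-action,
# shaped (batch, number of sub-actions)
taken = tf.constant([[1, 0]])
actions, neg_log_p_ac = get_actions_and_neglogp(
    p_logits, postprocess_preds, actions=taken)

# neg_log_p_ac sums the per-head cross-entropies, i.e. it is the
# negative log-probability of the joint action under independent heads
print(neg_log_p_ac.numpy())  # shape (1,)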
@@ -29,11 +59,15 @@ def preprocess_obs(x): return one_hot(
             vector_to_tf_constant(x, dtype=tf.int32), env.observation_space.n)
 
     if isinstance(env.action_space, gym.spaces.Discrete):
-        action_space_cardinality = env.action_space.n  # only works for Discrete
+        action_space_cardinalities = [
+            env.action_space.n]  # only works for Discrete
     elif isinstance(env.action_space, gym.spaces.tuple_space.Tuple) and not [s for s in env.action_space.spaces if not isinstance(s, gym.spaces.Discrete)]:
-        action_space_cardinality = sum([s.n for s in env.action_space.spaces])
+        action_space_cardinalities = [s.n for s in env.action_space.spaces]
 
-    return (observations_space_dim_count, preprocess_obs, action_space_cardinality, lambda x: x)
+    def postprocess_preds(x):
+        return tf.split(x, action_space_cardinalities, axis=1)
+
+    return (observations_space_dim_count, preprocess_obs, sum(action_space_cardinalities), postprocess_preds)
 
 
 def calculate_gae(*, hparams, ADVANTAGE_LAMBDA, rewards, episode_dones, predicted_values, episode_done, last_v_logit):
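
As a hedged illustration of the new return values (the Tuple space here is hypothetical): for Tuple(Discrete(3), Discrete(2)) the model's policy head is sized sum(action_space_cardinalities) = 5, and postprocess_preds recovers the per-sub-action logits:

import gym
import tensorflow as tf

tf.enable_eager_execution()

space = gym.spaces.Tuple((gym.spaces.Discrete(3), gym.spaces.Discrete(2)))
cardinalities = [s.n for s in space.spaces]  # [3, 2]
head_width = sum(cardinalities)              # 5 policy logits in total

batch_logits = tf.zeros((4, head_width))
per_action_logits = tf.split(batch_logits, cardinalities, axis=1)
# shapes: (4, 3) and (4, 2), one logits tensor per sub-action
print([tuple(t.shape.as_list()) for t in per_action_logits])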
@@ -162,14 +196,12 @@ def main(*, hparams, random_name=''):
         total_reward_in_current_episode = 0
 
         for _ in range(hparams['TRANSITIONS_IN_EXPERIENCE_BUFFER']):
-            if isinstance(observation, int):
-                observation = np.array([observation])
             p_logits, v_logit = pv_model(
-                preprocess_obs(observation))
-            p_distribution = tf.distributions.Categorical(logits=p_logits)
-            action = p_distribution.sample()
-            tf.contrib.summary.histogram(
-                'policy_probabilities', p_distribution.probs)
-            neg_log_p_ac = tf.nn.sparse_softmax_cross_entropy_with_logits(
-                logits=p_logits, labels=action)
+                preprocess_obs([observation]))
+            action, neg_log_p_ac = get_actions_and_neglogp(
+                p_logits, postprocess_preds)
 
             # gym envs overwrite observations
             observations.append(observation.copy())
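
The diff does not show the buffer bookkeeping, but for the actions= branch of get_actions_and_neglogp to work at training time, the per-step action lists must end up as a (batch, number of sub-actions) tensor. A plausible shape-matching sketch (names are hypothetical):

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()

actions_buffer = []
for _ in range(3):  # pretend 3 rollout steps with 2 sub-actions each
    step_actions = [tf.constant([1]), tf.constant([0])]  # list of (1,) tensors
    actions_buffer.append([a.numpy()[0] for a in step_actions])

# layout expected by get_actions_and_neglogp(..., actions=...)
old_taken_actions_batch = np.asarray(actions_buffer, dtype=np.int32)
print(old_taken_actions_batch.shape)  # (3, 2)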
@@ -179,7 +211,7 @@ def main(*, hparams, random_name=''):
             neg_log_p_ac_s.append(neg_log_p_ac)
 
             observation, reward, episode_done, infos = env.step(
-                action.numpy()[0])
+                [a.numpy()[0] for a in action])
             rewards.append(reward)
             total_reward += reward
             if (hparams['RENDER']):
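
env.step now receives a list with one scalar per sub-space, which is what gym's Tuple spaces expect; a quick check against a hypothetical space:

import gym

space = gym.spaces.Tuple((gym.spaces.Discrete(3), gym.spaces.Discrete(2)))
action = [1, 0]  # one scalar per sub-space, as now passed to env.step
assert space.contains(tuple(action))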
@@ -202,7 +234,7 @@ def main(*, hparams, random_name=''):
             steps_in_current_episode += 1
             total_reward_in_current_episode += reward
 
-        _, last_v_logit = pv_model(preprocess_obs(observation))
+        _, last_v_logit = pv_model(preprocess_obs([observation]))
 
         # Each slice of experience now contains:
         # - observation
@@ -264,9 +296,9 @@ def main(*, hparams, random_name=''):
 
             with tf.GradientTape() as tape:
                 train_p_logits, train_v_logit = pv_model(
-                    observations_batch)
-                neg_log_p_ac = tf.nn.sparse_softmax_cross_entropy_with_logits(
-                    logits=train_p_logits, labels=old_taken_actions_batch)
+                    preprocess_obs(observations_batch))
+                _, neg_log_p_ac = get_actions_and_neglogp(
+                    train_p_logits, postprocess_preds, actions=old_taken_actions_batch)
 
                 # Only care about how proximate the updated policy is
                 # relative to the probability of the *taken* action.
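
Because the per-head negative log-probabilities are summed, the PPO probability ratio exp(old_neg_log_p - new_neg_log_p) computed from them equals the product of per-sub-action ratios; a quick numeric check (values made up):

import numpy as np

# per-head -log pi(a|s) under the old and updated policies (made-up numbers)
old_nlp = np.array([0.9, 0.4])
new_nlp = np.array([0.7, 0.5])

joint_ratio = np.exp(old_nlp.sum() - new_nlp.sum())
per_head_ratio = np.exp(old_nlp - new_nlp)
assert np.isclose(joint_ratio, per_head_ratio.prod())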