Skip to content

Commit

Permalink
dynamically compose logic for observations preprocessing to support discrete obs-spaces; related to #16
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitry12 committed Nov 15, 2018
1 parent bb995f5 commit 80ed26a
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,29 @@ def main(*, hparams, random_name=''):
writer = tf.contrib.summary.create_file_writer(log_dir_name)
rl_writer = tf.contrib.summary.create_file_writer(log_dir_name + 'rl')

# only works for Box
observations_space_dim_count = env.observation_space.shape[0]
action_space_cardinality = env.action_space.n # only works for Discrete
def vector_to_tf_constant(x):
    """Convert an observation into a single-row float tensor.

    Args:
        x: A flat sequence of numbers, or a bare scalar (gym ``Discrete``
            spaces produce plain integers, which have no ``len()``).

    Returns:
        A tensor of dtype ``tf.keras.backend.floatx()`` with shape
        ``(1, len(x))`` for a flat sequence and ``(1, 1)`` for a scalar.
    """
    values = tf.constant(x, dtype=tf.keras.backend.floatx())
    # Reshape instead of passing shape=(1, len(x)) to tf.constant: the result
    # is the same for flat vectors, and scalars no longer fail on len().
    return tf.reshape(values, (1, -1))

def passthrough(tensor):
    """Return ``tensor`` unchanged.

    Serves as the no-op preprocessing step for observations that need no
    further transformation.
    """
    return tensor

def one_hot(tensor, dims):
    """One-hot encode ``tensor`` to depth ``dims`` and flatten the result.

    Args:
        tensor: Tensor (or tensor-like value) of category indices. Values
            whose dtype ``tf.one_hot`` does not accept as indices are cast
            to ``tf.int32`` first.
        dims: Depth of the one-hot axis.

    Returns:
        The one-hot encoding passed through ``tf.layers.flatten``.
    """
    indices = tf.convert_to_tensor(tensor)
    # tf.one_hot only accepts uint8/int32/int64 indices; float tensors (such
    # as the ones vector_to_tf_constant builds) would be rejected otherwise.
    if indices.dtype not in (tf.uint8, tf.int32, tf.int64):
        indices = tf.cast(indices, tf.int32)
    return tf.layers.flatten(tf.one_hot(indices, dims))

# Choose the observation preprocessing once, based on the kind of
# observation space the environment exposes.
if isinstance(env.observation_space, gym.spaces.Box):
    observations_space_dim_count = env.observation_space.shape[0]

    def preprocess_obs(x):
        """Turn a Box observation into a single-row float tensor."""
        return passthrough(vector_to_tf_constant(x))
elif isinstance(env.observation_space, gym.spaces.Discrete):
    # NOTE(review): kept at 1 as in the original, although the one-hot
    # encoding below is env.observation_space.n wide - confirm how this
    # value is consumed.
    observations_space_dim_count = 1

    def preprocess_obs(x):
        """One-hot encode a Discrete observation into a single-row tensor."""
        # Discrete observations are bare integers: they have no len(), and
        # tf.one_hot needs integer indices, so build an int32 index tensor
        # directly instead of going through the float vector helper.
        index = tf.reshape(tf.cast(tf.constant(x), tf.int32), (1, -1))
        return one_hot(index, env.observation_space.n)
else:
    # Fail here with a clear message rather than with a NameError on the
    # first preprocess_obs call.
    raise NotImplementedError(
        'Unsupported observation space: {!r}'.format(env.observation_space))

# Work out how many classes the policy head has to produce.
if isinstance(env.action_space, gym.spaces.Discrete):
    action_space_cardinality = env.action_space.n
elif (isinstance(env.action_space, gym.spaces.Tuple)
        and all(isinstance(s, gym.spaces.Discrete)
                for s in env.action_space.spaces)):
    # NOTE(review): sub-space sizes are summed, not multiplied - confirm the
    # policy output is meant to be per-sub-space logits laid side by side.
    action_space_cardinality = sum(s.n for s in env.action_space.spaces)
else:
    # Fail here with a clear message rather than with a NameError when the
    # model is constructed.
    raise NotImplementedError(
        'Unsupported action space: {!r}'.format(env.action_space))

pv_model = P_and_V_Model(classes=action_space_cardinality, mlp_layers=hparams['MLP_LAYERS'],
mlp_units=hparams['MLP_UNITS'], v_mlp_layers=hparams['V_MLP_LAYERS'], p_mlp_layers=hparams['P_MLP_LAYERS'])
Expand Down Expand Up @@ -134,12 +154,9 @@ def main(*, hparams, random_name=''):
steps_in_current_episode = 0
total_reward_in_current_episode = 0

def vector_to_tf_constant(x): return tf.constant(
x, dtype=tf.keras.backend.floatx(), shape=(1, len(x)))

for _ in range(hparams['TRANSITIONS_IN_EXPERIENCE_BUFFER']):
p_logits, v_logit = pv_model(
vector_to_tf_constant(observation))
preprocess_obs(observation))
p_distribution = tf.distributions.Categorical(logits=p_logits)
action = p_distribution.sample()
tf.contrib.summary.histogram(
Expand Down Expand Up @@ -178,7 +195,7 @@ def vector_to_tf_constant(x): return tf.constant(
steps_in_current_episode += 1
total_reward_in_current_episode += reward

_, last_v_logit = pv_model(vector_to_tf_constant(observation))
_, last_v_logit = pv_model(preprocess_obs(observation))

# Each slice of experience now contains:
# - observation
Expand Down

0 comments on commit 80ed26a

Please sign in to comment.