Added testing during training to evaluate the current training status
little-nem committed Jun 4, 2019
1 parent 42af539 commit 5d96047
Showing 2 changed files with 62 additions and 8 deletions.
3 changes: 3 additions & 0 deletions environment.py
@@ -20,6 +20,9 @@ def get_env(self):
         """
         return self._env, self._params
 
+    def close_env(self):
+        self._env.close()
+
     def get_goal(self):
         return
 
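The new close_env method completes the wrapper's sample/get/close lifecycle that main.py relies on: sample randomized parameters, build the environment, roll out an episode, then release the environment. As a rough, minimal sketch of such a wrapper (the class name, constructor, and parameter-sampling logic below are assumptions for illustration; only the sample_env/get_env/close_env method names come from this commit):

import random
import gym

class RandomizedEnvironment:
    # Illustrative stand-in; the real class lives in environment.py.
    def __init__(self, experiment="FetchReach-v1"):
        self._experiment = experiment
        self._env = None
        self._params = []

    def sample_env(self):
        # Placeholder randomization; the real wrapper would perturb
        # environment parameters here.
        self._params = [random.uniform(0.9, 1.1)]
        self._env = gym.make(self._experiment)

    def get_env(self):
        # Returns the current sampled environment and its parameters.
        return self._env, self._params

    def close_env(self):
        # The method added by this commit: frees the underlying gym env
        # so a fresh one can be sampled for the next episode.
        self._env.close()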
67 changes: 59 additions & 8 deletions main.py
@@ -10,7 +10,11 @@
 experiment = "FetchReach-v1"
 env = gym.make(experiment)
 
-# Hyperparameters
+# Program hyperparameters
+TESTING_INTERVAL = 50 # number of updates between two evaluations of the policy
+TESTING_ROLLOUTS = 100 # number of rollouts performed to evaluate the current policy
+
+# Algorithm hyperparameters
 BATCH_SIZE = 32
 BUFFER_SIZE = 100000
 MAX_STEPS = 50 # WARNING: defined in multiple files...
@@ -26,14 +30,16 @@
 replay_buffer = ReplayBuffer(BUFFER_SIZE)
 
-# should be done per episode
-randomized_environment.sample_env()
-env, env_params = randomized_environment.get_env()
 
 success = 0
 
 for ep in range(EPISODES):
     # generate a rollout
 
+    # generate an environment
+    randomized_environment.sample_env()
+    env, env_params = randomized_environment.get_env()
+
+
     # reset the environment
     current_obs_dict = env.reset()
 
@@ -53,7 +59,6 @@
 
     # rollout the whole episode
     while not done:
-        env.render()
         obs = current_obs_dict['observation']
         history = episode.get_history()
 
@@ -109,9 +114,6 @@
         # WARNING FIXME: needs padding
         t_batch += episodes[i].get_terminal()[1:]
 
-
-        #s_batch, a_batch, r_batch, t_batch, s2_batch, history_batch, env_batch, goal_batch = replay_buffer.sample_batch(BATCH_SIZE)
-
     target_action_batch = agent.evaluate_actor_batch(agent._actor.predict_target, next_s_batch, goal_batch, history_batch)
 
     predicted_actions = agent.evaluate_actor_batch(agent._actor.predict, next_s_batch, goal_batch, history_batch)
@@ -136,3 +138,52 @@
     # Update target networks
     agent.update_target_actor()
     agent.update_target_critic()
+
+    randomized_environment.close_env()
+
+    # perform policy evaluation
+    if ep % TESTING_INTERVAL == 0:
+        success_number = 0
+
+        for test_ep in range(TESTING_ROLLOUTS):
+            randomized_environment.sample_env()
+            env, env_params = randomized_environment.get_env()
+
+            current_obs_dict = env.reset()
+
+            # read the current goal, and initialize the episode
+            goal = current_obs_dict['desired_goal']
+            episode = Episode(goal, env_params, MAX_STEPS)
+
+            # get the first observation and first fake "old-action"
+            # TODO: decide if this fake action should be zero or random
+            obs = current_obs_dict['observation']
+            last_action = env.action_space.sample()
+
+            episode.add_step(last_action, obs, 0)
+
+            done = False
+
+            # rollout the whole episode
+            while not done:
+                obs = current_obs_dict['observation']
+                history = episode.get_history()
+
+                action = agent.evaluate_actor(agent._actor.predict_target, obs, goal, history)
+
+                new_obs_dict, step_reward, done, info = env.step(action[0])
+
+                new_obs = new_obs_dict['observation']
+
+                episode.add_step(action[0], new_obs, step_reward)
+
+                total_reward += step_reward
+
+                current_obs_dict = new_obs_dict
+
+            if info['is_success'] > 0.0:
+                success_number += 1
+
+            randomized_environment.close_env()
+
+        print("Testing at episode {}, success rate : {}".format(ep, success_number/TESTING_ROLLOUTS))
