From 5d9604792681539ee4ce5512d657402a1122784d Mon Sep 17 00:00:00 2001
From: Nemo Fournier
Date: Tue, 4 Jun 2019 11:45:37 +0200
Subject: [PATCH] Added testing during training to evaluate the current
 training status

---
 environment.py |  3 +++
 main.py        | 67 ++++++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 62 insertions(+), 8 deletions(-)

diff --git a/environment.py b/environment.py
index dfe69e9..2a6c695 100644
--- a/environment.py
+++ b/environment.py
@@ -20,6 +20,9 @@ def get_env(self):
         """
         return self._env, self._params
 
+    def close_env(self):
+        self._env.close()
+
     def get_goal(self):
         return
 
diff --git a/main.py b/main.py
index d9c7a96..fb38cc0 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,11 @@
 experiment = "FetchReach-v1"
 env = gym.make(experiment)
 
-# Hyperparameters
+# Program hyperparameters
+TESTING_INTERVAL = 50   # number of episodes between two evaluations of the policy
+TESTING_ROLLOUTS = 100  # number of rollouts performed to evaluate the current policy
+
+# Algorithm hyperparameters
 BATCH_SIZE = 32
 BUFFER_SIZE = 100000
 MAX_STEPS = 50 # WARNING: defined in multiple files...
@@ -26,14 +30,16 @@
 
 replay_buffer = ReplayBuffer(BUFFER_SIZE)
 
 # should be done per episode
-randomized_environment.sample_env()
-env, env_params = randomized_environment.get_env()
 
-success = 0
 
 for ep in range(EPISODES):
 
     # generate a rollout
 
+    # generate an environment
+    randomized_environment.sample_env()
+    env, env_params = randomized_environment.get_env()
+
+    # reset the environment
     current_obs_dict = env.reset()
@@ -53,7 +59,6 @@
 
     # rollout the whole episode
     while not done:
-        env.render()
         obs = current_obs_dict['observation']
         history = episode.get_history()
 
@@ -109,9 +114,6 @@
         # WARNING FIXME: needs padding
         t_batch += episodes[i].get_terminal()[1:]
 
-
-    #s_batch, a_batch, r_batch, t_batch, s2_batch, history_batch, env_batch, goal_batch = replay_buffer.sample_batch(BATCH_SIZE)
-
     target_action_batch = agent.evaluate_actor_batch(agent._actor.predict_target, next_s_batch, goal_batch, history_batch)
     predicted_actions = agent.evaluate_actor_batch(agent._actor.predict, next_s_batch, goal_batch, history_batch)
 
@@ -136,3 +138,52 @@
     # Update target networks
     agent.update_target_actor()
     agent.update_target_critic()
+
+    randomized_environment.close_env()
+
+    # perform policy evaluation
+    if ep % TESTING_INTERVAL == 0:
+        success_number = 0
+
+        for test_ep in range(TESTING_ROLLOUTS):
+            randomized_environment.sample_env()
+            env, env_params = randomized_environment.get_env()
+
+            current_obs_dict = env.reset()
+
+            # read the current goal, and initialize the episode
+            goal = current_obs_dict['desired_goal']
+            episode = Episode(goal, env_params, MAX_STEPS)
+
+            # get the first observation and first fake "old-action"
+            # TODO: decide if this fake action should be zero or random
+            obs = current_obs_dict['observation']
+            last_action = env.action_space.sample()
+
+            episode.add_step(last_action, obs, 0)
+
+            done = False
+
+            # rollout the whole episode
+            while not done:
+                obs = current_obs_dict['observation']
+                history = episode.get_history()
+
+                action = agent.evaluate_actor(agent._actor.predict_target, obs, goal, history)
+
+                new_obs_dict, step_reward, done, info = env.step(action[0])
+
+                new_obs = new_obs_dict['observation']
+
+                episode.add_step(action[0], new_obs, step_reward)
+
+                total_reward += step_reward
+
+                current_obs_dict = new_obs_dict
+
+            if info['is_success'] > 0.0:
+                success_number += 1
+
+            randomized_environment.close_env()
+
+        print("Testing at episode {}, success rate: {}".format(ep, success_number/TESTING_ROLLOUTS))