Skip to content

Commit

Permalink
eval/fit return dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
halil93ibrahim committed Mar 30, 2022
1 parent cc1fbac commit e6a6e97
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def fit(self, env=None, logging_path='', silent=False, verbose=True):
print('env should be gym.Env')
return
self.last_checkpoint_time_step = 0
self.mean_reward = -10
self.logdir = logging_path
if isinstance(self.env, DummyVecEnv):
self.env = self.env.envs[0]
Expand All @@ -85,6 +86,7 @@ def fit(self, env=None, logging_path='', silent=False, verbose=True):
self.env = DummyVecEnv([lambda: self.env])
self.agent.set_env(self.env)
self.agent.learn(total_timesteps=self.iters, callback=self.callback)
return {"last_20_episodes_mean_reward": self.mean_reward}

def eval(self, env):
"""
Expand All @@ -108,7 +110,7 @@ def eval(self, env):
sum_of_rewards += rewards
if dones:
break
return sum_of_rewards
return {"rewards_collected": sum_of_rewards}

def save(self, path):
"""
Expand Down Expand Up @@ -161,14 +163,14 @@ def callback(self, _locals, _globals):
x, y = ts2xy(load_results(self.logdir), 'timesteps')

if len(y) > 20:
mean_reward = np.mean(y[-20:])
self.mean_reward = np.mean(y[-20:])
else:
return True

if x[-1] - self.last_checkpoint_time_step > self.checkpoint_after_iter:
self.last_checkpoint_time_step = x[-1]
check_point_path = Path(self.logdir,
'checkpoint_save' + str(x[-1]) + 'with_mean_rew' + str(mean_reward))
'checkpoint_save' + str(x[-1]) + 'with_mean_rew' + str(self.mean_reward))
self.save(str(check_point_path))

return True
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_infer(self):
self.assertTrue((action < self.env.action_space.n), "Actions above discrete action space dimensions")

def test_eval(self):
episode_reward = self.learner.eval(self.env)
episode_reward = self.learner.eval(self.env)["rewards_collected"]
self.assertTrue((episode_reward > -100), "Episode reward cannot be lower than -100")
self.assertTrue((episode_reward < 100), "Episode reward cannot pass 100")

Expand Down

0 comments on commit e6a6e97

Please sign in to comment.