-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_callback.py
28 lines (22 loc) · 985 Bytes
/
custom_callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
import numpy as np
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.callbacks import BaseCallback
class CurrentTrainReward(BaseCallback):
    """Record the mean training reward once training has finished.

    When training ends, the results stored under ``log_dir`` are loaded with
    ``load_results`` and the mean of the last 100 episode rewards is stored
    in ``current_train_reward``. Until at least one episode has been
    recorded, the attribute keeps its initial value of ``-np.inf``.

    Args:
        log_dir: Directory passed to ``load_results`` to read the training
            results from.
        verbose: Verbosity level forwarded to ``BaseCallback``; the final
            reward is printed only when it is at least 1.
    """

    def __init__(self, log_dir: str, verbose: int = 1):
        super().__init__(verbose)
        self.log_dir = log_dir
        # Sentinel meaning "no completed episode has been observed yet".
        self.current_train_reward = -np.inf

    def _on_step(self) -> bool:
        """Do nothing per step and let training continue.

        Returns:
            Always ``True``.
        """
        # The reward is only evaluated in ``_on_training_end``, so no
        # per-step work (and no per-step reload of the results) is done here.
        return True

    def _on_training_end(self) -> None:
        """Compute the mean reward of the last 100 recorded episodes."""
        # Retrieve the training rewards indexed by timesteps.
        x, y = ts2xy(load_results(self.log_dir), 'timesteps')
        if len(x) > 0:
            # Mean training reward over the last 100 episodes (or over all
            # of them when fewer than 100 were recorded).
            self.current_train_reward = np.mean(y[-100:])
            # Honour the verbosity level the callback was constructed with.
            if self.verbose >= 1:
                print("Current Training Reward", self.current_train_reward)