-
Notifications
You must be signed in to change notification settings - Fork 0
/
animate.py
43 lines (35 loc) · 1.3 KB
/
animate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gymnasium as gym
import torch
from agent import DDPGAgent
from argparse import ArgumentParser
from utils import save_animation
def _run_episode(env, agent):
    """Roll out one episode, choosing every action with ``agent.choose_action``.

    Returns:
        A ``(frames, total_reward)`` tuple. ``frames`` holds the value of
        ``env.render()`` captured before each step; ``total_reward`` is the
        sum of the rewards returned by ``env.step``.
    """
    frames = []
    total_reward = 0
    state, _ = env.reset()
    terminated = truncated = False
    while not (terminated or truncated):
        frames.append(env.render())
        action = agent.choose_action(state)
        state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
    return frames, total_reward


def generate_animation(env_name, num_episodes=10):
    """Save the highest-reward episode of a checkpointed DDPG agent as a GIF.

    Builds a ``DDPGAgent`` for the Gymnasium environment ``env_name``, loads
    its checkpoints, runs ``num_episodes`` episodes and writes the frames of
    the episode with the largest total reward to
    ``environments/<env_name>.gif`` via ``save_animation``.

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        num_episodes: How many episodes to roll out before picking the best
            one. Defaults to 10; must be at least 1.

    Raises:
        ValueError: If ``num_episodes`` is less than 1, or if no episode's
            total reward compared greater than ``-inf`` (e.g. every total
            was NaN), so there is no episode to save.
    """
    from pathlib import Path

    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = gym.make(env_name, render_mode="rgb_array")
    try:
        agent = DDPGAgent(
            env_name,
            env.observation_space.shape,
            env.action_space.shape,
            tau=0.001,
        )
        agent.to(device)
        agent.load_checkpoints()

        best_total_reward = float("-inf")
        best_frames = None
        for _ in range(num_episodes):
            frames, total_reward = _run_episode(env, agent)
            # Only the best episode's frames are kept, so memory stays
            # bounded by roughly two episodes' worth of frames.
            if total_reward > best_total_reward:
                best_total_reward = total_reward
                best_frames = frames
    finally:
        # Release the environment even if agent loading or a rollout fails;
        # the captured frames remain usable after closing.
        env.close()

    if best_frames is None:
        raise ValueError(
            f"no episode of {env_name!r} produced a total reward greater "
            "than -inf; nothing to save"
        )

    output_path = f"environments/{env_name}.gif"
    # Namespaced ids such as "ALE/Pong-v5" map to nested directories, and
    # "environments/" itself may not exist yet.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    save_animation(best_frames, output_path)
if __name__ == "__main__":
    # Command-line entry point: the Gymnasium environment id is the single,
    # mandatory option; everything else is handled by generate_animation.
    cli = ArgumentParser()
    cli.add_argument("-e", "--env", required=True, help="Environment name from Gymnasium")
    generate_animation(cli.parse_args().env)