-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtestrun.py
67 lines (51 loc) · 1.6 KB
/
testrun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gym
import gym.spaces
from env.rocketlander import RocketLander
from util.pid import PID_Benchmark
import imageio
import base64
import IPython
def embed_mp4(filename):
    """Embed an mp4 file in the notebook.

    Args:
        filename: Path of the mp4 file to embed.

    Returns:
        An ``IPython.display.HTML`` object whose markup is a ``<video>``
        element carrying the file's bytes inline as a base64 data URI.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Read through a context manager so the file handle is released
    # deterministically, even when the read itself fails.
    with open(filename, 'rb') as video_file:
        video = video_file.read()
    b64 = base64.b64encode(video)
    tag = '''
<video width="640" height="480" controls>
<source src="data:video/mp4;base64,{0}" type="video/mp4">
Your browser does not support the video tag.
</video>'''.format(b64.decode())
    return IPython.display.HTML(tag)
def create_policy_eval_video(policy, filename, num_episodes=5, fps=60):
    """Record ``policy`` driving the module-level ``env`` and embed the video.

    Args:
        policy: Callable that takes an observation and returns the action
            passed to ``env.step``.
        filename: Output path without extension; ``".mp4"`` is appended.
        num_episodes: Number of episodes to record back to back.
        fps: Frame rate handed to the imageio writer.

    Returns:
        Whatever ``embed_mp4`` returns for the written file.
    """
    video_path = filename + ".mp4"
    with imageio.get_writer(video_path, fps=fps) as writer:
        for _episode in range(num_episodes):
            obs = env.reset()
            # Capture the initial frame before any action is applied.
            writer.append_data(env.render(mode='rgb_array'))
            finished = False
            while not finished:
                obs, _reward, finished, _info = env.step(policy(obs))
                writer.append_data(env.render(mode='rgb_array'))
    return embed_mp4(video_path)
# Initialize the PID algorithm and the environment it controls.
pid = PID_Benchmark()
env = RocketLander(continuous=True)
PRINT_DEBUG_MSG = True

# Record a video of the PID controller first. This resets and steps `env`
# itself, so any observation obtained before this call would be stale.
create_policy_eval_video(pid.pid_algorithm, 'PID', fps=60)

try:
    for _ in range(10):
        # Start every episode from a fresh reset and keep the observation it
        # returns, so the first action is computed from the current state
        # rather than from the previous episode's terminal observation.
        observation = env.reset()
        done = False
        while not done:
            env.render()
            # env.action_space.sample() can be substituted here to fly a
            # random policy instead of the PID controller.
            action = pid.pid_algorithm(observation)
            observation, reward, done, info = env.step(action)
            if PRINT_DEBUG_MSG:
                print("Action Taken ", action)
                print("Observation ", observation)
                print("Reward Gained ", reward)
                print("Info ", info, end='\n\n')
        print("Simulation done.")
finally:
    # Release the environment (and its render window) even if a step raised.
    env.close()