-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
39 lines (29 loc) · 1.46 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import argparse
import os
import gym
import gym_insertion # noqa: F401
from stable_baselines import SAC
def main():
    """Roll out a saved SAC checkpoint in the ``insertion-v0`` environment.

    Command-line arguments:
        checkpoint_path: path to the saved SAC checkpoint to load.
        --host: IP of the server handed to the environment.
        --port: port used to connect to the server.
        --use_coord: if set, the environment is asked for coordinate
            observations instead of images.

    The environment is created, the checkpoint is loaded, and the environment
    is then stepped and rendered for 10000 steps with the actions predicted by
    the loaded model. The environment is reset whenever an episode finishes
    and closed when the rollout ends (normally or through an exception).
    """
    parser = argparse.ArgumentParser("Insertion, Manual mode")
    parser.add_argument('checkpoint_path', type=str, help='Path to checkpoint')
    parser.add_argument('--host', default="192.168.2.121", type=str,
                        help='IP of the server (default is a Windows#2)')
    parser.add_argument('--port', default=9090, type=int,
                        help='Port that should be used to connect to the server')
    # The two literals are concatenated by the parser; the trailing space keeps
    # "be" and "coordinates" from running together in the --help output.
    parser.add_argument('--use_coord', action="store_true",
                        help=('If set, the environment\'s observation space will be '
                              'coordinates instead of images'))
    args = parser.parse_args()

    # NOTE(review): stable_baselines is already imported at module level when
    # this runs; if TensorFlow only reads this variable at import time, the
    # assignment is too late to silence its logging -- confirm.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # NOTE(review): the settings are passed as a single keyword literally named
    # ``kwargs``; presumably the environment's constructor expects that -- verify.
    env = gym.make('insertion-v0', kwargs={'host': args.host, "port": args.port, "use_coord": args.use_coord})
    try:
        print(f"Observation space: {env.observation_space}")
        print(f"Action space: {env.action_space}")

        # ``load`` is a classmethod that *returns* the restored model. Calling
        # it on an existing instance and dropping the result leaves that
        # instance untrained, so the returned model must be the one used.
        # The policy class is restored from the checkpoint, so no
        # MlpPolicy/CnnPolicy model has to be built beforehand.
        model = SAC.load(args.checkpoint_path, env=env)

        obs = env.reset()
        for _ in range(10000):
            action, _states = model.predict(obs)
            obs, _reward, done, _info = env.step(action)
            env.render()
            if done:
                # Start a fresh episode instead of stepping a finished one.
                obs = env.reset()
    finally:
        # Release the environment even if the rollout fails.
        env.close()
# Script entry point: start the rollout only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()