PongAI.py
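"""Train or play Pong with a dueling Deep Q-Network (DQN) agent.

Built on keras-rl's DQNAgent and a small fully connected Q-network. Run with
-train to train the agent and save its weights next to this script; run
without arguments to load the saved weights and watch the agent play. The
Pong environment itself lives in the accompanying pong.py module.
"""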
import argparse
import os

from pong import Pong
from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy, LinearAnnealedPolicy
from rl.memory import SequentialMemory
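
# NOTE: keras-rl targets the standalone Keras API (pre-TF2); these imports and
# the Adam(lr=...) keyword below assume a matching Keras/keras-rl installation.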

if __name__ == '__main__':
    # resolve the weights file relative to the script directory
    # (os.path.join keeps the path portable across operating systems)
    curdir = os.path.dirname(os.path.abspath(__file__))
    modelfilename = os.path.join(curdir, "dqn_pong_weights.h5f")

    # construct the argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument("-train", dest='train', action='store_true',
                        help="Train the agent and save the model weights")
    # parse the command line
    args = parser.parse_args()

    # initialize the game environment
    env = Pong()
    # number of possible actions; 3 for the Pong game
    nb_actions = env.action_space.n
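
    # NOTE: the Pong class from pong.py (not shown here) is assumed to expose a
    # Gym-style interface, which is what keras-rl's fit()/test() loops expect:
    #   env.reset() -> observation
    #   env.step(action) -> (observation, reward, done, info)
    # plus the action_space.n and observation_space.shape attributes used here.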

    # build the neural network model: a small fully connected Q-network that
    # maps the flattened observation to one Q-value per action
    model = Sequential()
    model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(nb_actions, activation='linear'))
    model.summary()

    # replay buffer: keeps up to 1,000,000 past transitions to sample from
    memory = SequentialMemory(limit=1000000, window_length=1)
    # linearly annealed epsilon-greedy policy: epsilon decays from 1.0 to 0.05
    # over the first 20,000 training steps, then stays at 0.05
    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1.0,
                                  value_min=0.05, value_test=0.05, nb_steps=20000)
    # Deep Q-Network agent; enable_dueling_network=True is required for
    # dueling_type to take effect in keras-rl
    dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory,
                   nb_steps_warmup=10, target_model_update=1e-2, policy=policy,
                   enable_dueling_network=True, dueling_type='max')
    # compile the network
    dqn.compile(Adam(lr=1e-3), metrics=['mae'])

    # running in train mode
    if args.train:
        # train the agent for 1,000,000 steps, repeating each chosen action
        # for two consecutive frames
        dqn.fit(env, nb_steps=1000000, visualize=False, verbose=1,
                action_repetition=2, log_interval=100000)
        # save the trained weights
        dqn.save_weights(modelfilename, overwrite=True)
    else:
        # load the previously trained weights
        dqn.load_weights(filepath=modelfilename)
        # play 10 episodes with rendering enabled
        dqn.test(env, nb_episodes=10, visualize=True)
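
# Example usage:
#   python PongAI.py -train    # train for 1M steps and save dqn_pong_weights.h5f
#   python PongAI.py           # load the saved weights and watch 10 episodes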