Commit e7e34c9

Author: transedward
Add ram for test, not done yet
1 parent 174b57e commit e7e34c9

5 files changed (+139, -21 lines)


dqn_learn.py

Lines changed: 37 additions & 9 deletions
@@ -19,6 +19,7 @@

 USE_CUDA = torch.cuda.is_available()
 dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
+
 class Variable(autograd.Variable):
     def __init__(self, data, *args, **kwargs):
         if USE_CUDA:
@@ -27,6 +28,16 @@ def __init__(self, data, *args, **kwargs):

 OptimizerSpec = namedtuple("OptimizerSpec", ["constructor", "kwargs", "lr_schedule"])

+# Check if the parameters of the model update accordingly.
+def check_norm(model):
+    total_norm = 0.0
+    for p in model.parameters():
+        param_norm = p.grad.data.norm(2.0)
+        total_norm += param_norm ** 2.0
+    total_norm = total_norm ** (1.0 / 2.0)
+    return total_norm
+
+
 def dqn_learing(
     env,
     q_func,
@@ -94,10 +105,10 @@ def dqn_learing(

     if len(env.observation_space.shape) == 1:
         # This means we are running on low-dimensional observations (e.g. RAM)
-        input_shape = env.observation_space.shape
+        input_arg = env.observation_space.shape[0]
     else:
         img_h, img_w, img_c = env.observation_space.shape
-        input_shape = (img_h, img_w, frame_history_len * img_c)
+        input_arg = frame_history_len * img_c
     num_actions = env.action_space.n

     # Construct an epilson greedy policy with given exploration schedule
@@ -106,14 +117,14 @@ def select_epilson_greedy_action(model, obs, t):
         eps_threshold = exploration.value(t)
         if sample > eps_threshold:
             obs = torch.from_numpy(obs).type(dtype).unsqueeze(0) / 255.0
-            # Detach variable from the current graph since we don't want gradients to propagated
-            return model(Variable(obs)).detach().data.max(1)[1].cpu()
+            # Use volatile = True if variable is only used in inference mode, i.e. don’t save the history
+            return model(Variable(obs, volatile=True)).data.max(1)[1].cpu()
         else:
             return torch.IntTensor([[random.randrange(num_actions)]])

     # Initialize target q function and q function
-    Q = q_func(input_shape[2], num_actions).type(dtype)
-    target_Q = q_func(input_shape[2], num_actions).type(dtype)
+    Q = q_func(input_arg, num_actions).type(dtype)
+    target_Q = q_func(input_arg, num_actions).type(dtype)

     # Construct optimizer with adaptive learning rate
     # https://discuss.pytorch.org/t/adaptive-learning-rate/320
@@ -148,6 +159,7 @@ def construct_optimizer(t):
         # previous frames.
         # recent_observations: shape(img_h, img_w, frame_history_len) are input to to the model
         recent_observations = replay_buffer.encode_recent_observation().transpose(2, 0, 1)
+        # recent_observations = replay_buffer.encode_recent_observation()

         # Choose random action if not yet start learning
         if t > learning_starts:
@@ -176,9 +188,11 @@ def construct_optimizer(t):
             obs_batch, act_batch, rew_batch, next_obs_batch, done_mask = replay_buffer.sample(batch_size)
             # Convert numpy nd_array to torch variables for calculation
             obs_batch = Variable(torch.from_numpy(obs_batch.transpose(0, 3, 1, 2)).type(dtype) / 255.0)
+            # obs_batch = Variable(torch.from_numpy(obs_batch).type(dtype) / 255.0)
             act_batch = Variable(torch.from_numpy(act_batch).long())
             rew_batch = Variable(torch.from_numpy(rew_batch))
             next_obs_batch = Variable(torch.from_numpy(next_obs_batch.transpose(0, 3, 1, 2)).type(dtype) / 255.0, volatile=True)
+            # next_obs_batch = Variable(torch.from_numpy(next_obs_batch).type(dtype) / 255.0, volatile=True)
             done_mask = torch.from_numpy(done_mask)

             if USE_CUDA:
@@ -190,22 +204,36 @@ def construct_optimizer(t):
             # We choose Q based on action taken.
             current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))
             # Compute next Q value, based on which acion gives max Q values
-            next_max_Q_values = Variable(torch.zeros(batch_size))
-            next_max_Q_values[done_mask == 0] = target_Q(next_obs_batch).max(1)[0]
+            next_max_Q_values = Variable(torch.zeros(batch_size).type(dtype))
+            # # Detach variable from the current graph since we don't want gradients to propagated
+            next_max_Q_values[done_mask == 0] = target_Q(next_obs_batch).detach().max(1)[0]
             # Compute Bellman error, use huber loss to mitigate outlier impact
             bellman_error = F.smooth_l1_loss(current_Q_values, rew_batch + (gamma * next_max_Q_values))
             # Run backward pass and clip the gradient
+            Q.zero_grad()
             bellman_error.backward()
-            nn.utils.clip_grad_norm(Q.parameters(), grad_norm_clipping)
+
+            if check_norm(Q) > grad_norm_clipping:
+                print('Before clipping gradient:')
+                print('total_norm: ', check_norm(Q))
+                nn.utils.clip_grad_norm(Q.parameters(), grad_norm_clipping)
+                print('After clipping gradient:')
+                print('total_norm: ', check_norm(Q))
             # Perfom the update
             optimizer = construct_optimizer(t)
             optimizer.step()
+            # print('After update Q:')
+            # check_norm(Q)
             num_param_updates += 1

             # Periodically update the target network by Q network to target Q network
             if num_param_updates % target_update_freq == 0:
+                # print('Before update target:')
+                # check_norm(target_Q)
                 for target_param, param in zip(target_Q.parameters(), Q.parameters()):
                     target_param.data = param.data.clone()
+                # print('After update target:')
+                # check_norm(target_Q)

         ### 4. Log progress
         episode_rewards = get_wrapper_by_name(env, "Monitor").get_episode_rewards()
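
The change above replaces the unconditional nn.utils.clip_grad_norm call with a check_norm diagnostic that clips (and prints the norms) only when the global gradient norm exceeds grad_norm_clipping. Below is a minimal standalone sketch of that pattern, written against current PyTorch (clip_grad_norm_ rather than the older Variable-era clip_grad_norm used in this repo); the toy model and threshold are illustrative, not part of the commit.

import torch
import torch.nn as nn

def check_norm(model):
    # Global L2 norm over all parameter gradients: sqrt(sum of squared per-parameter norms).
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.norm(2).item() ** 2
    return total_norm ** 0.5

# Toy model and one backward pass, just to populate .grad.
model = nn.Linear(128, 4)
loss = model(torch.randn(32, 128)).pow(2).mean()
loss.backward()

grad_norm_clipping = 10.0
if check_norm(model) > grad_norm_clipping:
    print('Before clipping:', check_norm(model))
    # Rescales all gradients in place so their global norm is at most the threshold.
    nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clipping)
    print('After clipping:', check_norm(model))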

dqn_model.py

Lines changed: 20 additions & 1 deletion
@@ -7,7 +7,7 @@ def __init__(self, in_channels=4, num_actions=18):
         Initialize a deep Q-learning network as described in
         https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf
         Arguments:
-            input_channel: number of channel of input.
+            in_channels: number of channel of input.
                 i.e The number of most recent frames stacked together as describe in the paper
             num_actions: number of action-value to output, one-to-one correspondence to action in game.
         """
@@ -24,3 +24,22 @@ def forward(self, x):
         x = F.relu(self.conv3(x))
         x = F.relu(self.fc4(x.view(x.size(0), -1)))
         return self.fc5(x)
+
+class DQN_RAM(nn.Module):
+    def __init__(self, in_features=4, num_actions=18):
+        """
+        Initialize a deep Q-learning network for testing algorithm
+        in_features: number of features of input.
+        num_actions: number of action-value to output, one-to-one correspondence to action in game.
+        """
+        super(DQN_RAM, self).__init__()
+        self.fc1 = nn.Linear(in_features, 256)
+        self.fc2 = nn.Linear(256, 128)
+        self.fc3 = nn.Linear(128, 64)
+        self.fc4 = nn.Linear(64, num_actions)
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = F.relu(self.fc3(x))
+        return self.fc4(x)
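
As a quick sanity check of the new RAM head, a forward pass with shapes matching Atari RAM observations; the 128-dimensional input and 6-action output are illustrative values (Pong-ram), not fixed by the commit, and the call uses a plain tensor under current PyTorch rather than the Variable wrapper used in dqn_learn.py.

import torch
from dqn_model import DQN_RAM

# Atari RAM observations are 128 bytes; Pong exposes 6 discrete actions (illustrative).
model = DQN_RAM(in_features=128, num_actions=6)
obs = torch.rand(32, 128)   # fake batch of normalized RAM states
q_values = model(obs)       # shape (32, 6): one Q-value per action
print(q_values.shape)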

main.py

Lines changed: 0 additions & 10 deletions
@@ -21,16 +21,6 @@ def stopping_criterion(n):
     return lambda env: get_wrapper_by_name(env, "Monitor").get_total_steps() >= n

 def main(env, num_timesteps):
-    # Get Atari games.
-    benchmark = gym.benchmark_spec('Atari40M')
-
-    # Change the index to select a different game.
-    task = benchmark.tasks[3]
-
-    # Run training
-    seed = 0 # Use a seed of zero (you may want to randomize the seed!)
-    env = get_env(task, seed)
-
     # This is just a rough estimate
     num_iterations = float(num_timesteps) / 4.0

ram.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+import gym
+import torch.optim as optim
+
+from dqn_model import DQN_RAM
+from dqn_learn import OptimizerSpec, dqn_learing
+from utils.gym import get_ram_env, get_wrapper_by_name
+from utils.schedule import PiecewiseSchedule, LinearSchedule
+
+BATCH_SIZE = 32
+GAMMA = 0.99
+REPLAY_BUFFER_SIZE=1000000
+LEARNING_STARTS=50000
+LEARNING_FREQ=4
+FRAME_HISTORY_LEN=1
+TARGER_UPDATE_FREQ=10000
+GRAD_NORM_CLIPPING=10
+
+def stopping_criterion(n):
+    # notice that here t is the number of steps of the wrapped env,
+    # which is different from the number of steps in the underlying env
+    return lambda env: get_wrapper_by_name(env, "Monitor").get_total_steps() >= n
+
+def main(env, num_timesteps=int(4e7)):
+    # This is just a rough estimate
+    num_iterations = float(num_timesteps) / 4.0
+
+    # define learning rate and exploration schedules below
+    lr_multiplier = 1.0
+    lr_schedule = PiecewiseSchedule([
+        (0, 1e-4 * lr_multiplier),
+        (num_iterations / 10, 1e-4 * lr_multiplier),
+        (num_iterations / 2, 5e-5 * lr_multiplier),
+    ], outside_value=5e-5 * lr_multiplier)
+
+    optimizer = OptimizerSpec(
+        constructor=optim.Adam,
+        kwargs=dict(eps=1e-4),
+        lr_schedule=lr_schedule
+    )
+
+    exploration_schedule = PiecewiseSchedule([
+        (0, 0.2),
+        (1e6, 0.1),
+        (num_iterations / 2, 0.01),
+    ], outside_value=0.01)
+
+    dqn_learing(
+        env=env,
+        q_func=DQN_RAM,
+        optimizer_spec=optimizer,
+        exploration=exploration_schedule,
+        stopping_criterion=stopping_criterion(num_timesteps),
+        replay_buffer_size=REPLAY_BUFFER_SIZE,
+        batch_size=BATCH_SIZE,
+        gamma=GAMMA,
+        learning_starts=LEARNING_STARTS,
+        learning_freq=LEARNING_FREQ,
+        frame_history_len=FRAME_HISTORY_LEN,
+        target_update_freq=TARGER_UPDATE_FREQ,
+        grad_norm_clipping=GRAD_NORM_CLIPPING
+    )
+
+if __name__ == '__main__':
+    # Get Atari games.
+    env = gym.make('Pong-ram-v0')
+
+    # Run training
+    seed = 0 # Use a seed of zero (you may want to randomize the seed!)
+    env = get_ram_env(env, seed)
+
+    main(env)
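
dqn_learn.py reads these schedules through .value(t) (e.g. exploration.value(t) for the epsilon threshold). A small sketch of how the exploration schedule above behaves, assuming PiecewiseSchedule interpolates linearly between its (timestep, value) endpoints and returns outside_value beyond the last one:

from utils.schedule import PiecewiseSchedule

num_iterations = 4e7 / 4.0  # same rough estimate as in main()
exploration_schedule = PiecewiseSchedule([
    (0, 0.2),
    (1e6, 0.1),
    (num_iterations / 2, 0.01),
], outside_value=0.01)

# epsilon starts at 0.2, reaches 0.1 after 1M steps, 0.01 by num_iterations / 2,
# and stays at outside_value (0.01) after that.
for t in [0, 500000, 1000000, 3000000, 100000000]:
    print(t, exploration_schedule.value(t))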

utils/gym.py

Lines changed: 11 additions & 1 deletion
@@ -5,7 +5,7 @@
 from gym import wrappers

 from utils.seed import set_global_seeds
-from utils.atari_wrapper import wrap_deepmind
+from utils.atari_wrapper import wrap_deepmind, wrap_deepmind_ram

 def get_env(task, seed):
     env_id = task.env_id
@@ -21,6 +21,16 @@ def get_env(task, seed):

     return env

+def get_ram_env(env, seed):
+    set_global_seeds(seed)
+    env.seed(seed)
+
+    expt_dir = '/tmp/gym-results'
+    env = wrappers.Monitor(env, expt_dir, force=True)
+    env = wrap_deepmind_ram(env)
+
+    return env
+
 def get_wrapper_by_name(env, classname):
     currentenv = env
     while True:
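
Unlike get_env, which builds the environment from a benchmark task, get_ram_env takes an already-constructed env and applies the Monitor and RAM-specific DeepMind wrappers (wrap_deepmind_ram itself lives in utils/atari_wrapper.py and is not shown in this commit). A short usage sketch, mirroring ram.py:

import gym
from utils.gym import get_ram_env, get_wrapper_by_name

env = gym.make('Pong-ram-v0')   # 128-byte RAM observations instead of pixels
env = get_ram_env(env, seed=0)  # seed, Monitor into /tmp/gym-results, then wrap_deepmind_ram

# The Monitor wrapper installed here is what stopping_criterion queries for step counts.
steps = get_wrapper_by_name(env, "Monitor").get_total_steps()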
