import torch

from rlpyt.algos.pg.base import PolicyGradientAlgo, OptInfo
from rlpyt.agents.base import AgentInputs, AgentInputsRnn
from rlpyt.utils.tensor import valid_mean
from rlpyt.utils.quick_args import save__init__args
from rlpyt.utils.buffer import buffer_to, buffer_method
from rlpyt.utils.collections import namedarraytuple
from rlpyt.utils.misc import iterate_mb_idxs

LossInputs = namedarraytuple("LossInputs",
    ["agent_inputs", "action", "return_", "advantage", "valid", "old_dist_info"])


class PPO(PolicyGradientAlgo):
    """
    Proximal Policy Optimization algorithm. Trains the agent by taking
    multiple epochs of gradient steps on minibatches of the training data at
    each iteration, with advantages computed by generalized advantage
    estimation. Uses clipped likelihood ratios in the policy loss.
    """

    def __init__(
            self,
            discount=0.99,
            learning_rate=0.001,
            value_loss_coeff=1.,
            entropy_loss_coeff=0.01,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=1.,
            initial_optim_state_dict=None,
            gae_lambda=1,
            minibatches=4,
            epochs=4,
            ratio_clip=0.1,
            linear_lr_schedule=True,
            normalize_advantage=False,
            ):
        """Saves input settings."""
        if optim_kwargs is None:
            optim_kwargs = dict()
        save__init__args(locals())

    def initialize(self, *args, **kwargs):
        """
        Extends base ``initialize()`` to initialize learning rate schedule, if
        applicable.
        """
        super().initialize(*args, **kwargs)
        self._batch_size = self.batch_spec.size // self.minibatches  # For logging.
        if self.linear_lr_schedule:
            self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer=self.optimizer,
                lr_lambda=lambda itr: (self.n_itr - itr) / self.n_itr)  # Step once per itr.
            self._ratio_clip = self.ratio_clip  # Save base value.

    def optimize_agent(self, itr, samples):
        """
        Train the agent, for multiple epochs over minibatches taken from the
        input samples. Organizes agent inputs from the training data, and
        moves them to device (e.g. GPU) up front, so that minibatches are
        formed within device, without further data transfer.
        """
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
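        # Descriptive note (added): if the agent maintains running observation
        # normalization statistics, update them with this batch of
        # observations before the training forward passes below use them.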
        if hasattr(self.agent, "update_obs_rms"):
            self.agent.update_obs_rms(agent_inputs.observation)
        return_, advantage, valid = self.process_returns(samples)
        loss_inputs = LossInputs(  # So can slice all.
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
        )
        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
        T, B = samples.env.reward.shape[:2]
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches
        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
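                # Descriptive note (added): flat indexes enumerate the [T, B]
                # batch with time fastest (idx = b * T + t), so idx % T
                # recovers the time index and idx // T the batch index.
                # Recurrent agents keep whole trajectories and slice only
                # along B.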
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                self.optimizer.zero_grad()
                rnn_state = init_rnn_state[B_idxs] if recurrent else None
                # NOTE: if not recurrent, will lose leading T dim, should be OK.
                loss, entropy, perplexity = self.loss(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                opt_info.loss.append(loss.item())
                opt_info.gradNorm.append(torch.tensor(grad_norm).item())  # Backwards compatible.
                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())
                self.update_counter += 1
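        # Descriptive note (added): when the linear schedule is enabled, both
        # the learning rate and the clipping epsilon decay linearly toward
        # zero over the n_itr training iterations.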
        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr

        return opt_info

    def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info,
            init_rnn_state=None):
        """
        Compute the training loss: policy_loss + value_loss + entropy_loss.
        Policy loss: min(likelihood_ratio * advantage, clip(likelihood_ratio, 1 - eps, 1 + eps) * advantage)
        Value loss: 0.5 * (estimated_value - return) ^ 2
        Calls the agent to compute forward pass on training data, and uses
        the ``agent.distribution`` to compute likelihoods and entropies. Valid
        for feedforward or recurrent agents.
        """
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
        else:
            dist_info, value = self.agent(*agent_inputs)
        dist = self.agent.distribution
        ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
            new_dist_info=dist_info)
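        # Descriptive note (added): clipped surrogate objective -- take the
        # elementwise minimum of the unclipped and clipped ratio terms, so the
        # policy gains no additional objective value from pushing the
        # likelihood ratio outside [1 - ratio_clip, 1 + ratio_clip].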
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
            1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = - valid_mean(surrogate, valid)

        value_error = 0.5 * (value - return_) ** 2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = - self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
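

if __name__ == "__main__":
    # Minimal construction sketch (added, not part of the original module):
    # only the PPO class defined above is assumed. Pairing the algorithm with
    # an agent, sampler, and runner is handled elsewhere in rlpyt and is
    # omitted here; the hyperparameter values below are illustrative, not
    # prescriptive.
    algo = PPO(
        learning_rate=3e-4,
        ratio_clip=0.2,    # Clipping epsilon for the likelihood ratio.
        epochs=4,          # Gradient epochs over each sampled batch.
        minibatches=4,     # Minibatches per epoch.
        gae_lambda=0.95,   # Lambda for generalized advantage estimation.
    )
    # save__init__args stores the constructor arguments as attributes.
    print(algo.learning_rate, algo.ratio_clip, algo.epochs)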