# -*- coding: utf-8 -*-
# @Time : 9/8/21 10:33 PM
# @Author : Tingfeng Li, <tl601@cs.rutgers.edu>, Rutgers University.
import torch

def gather_nd(params, indices, name=None):
    '''
    The input indices must be a 2-D tensor of the form [[a, b, ..., c], ...],
    where each row gives the coordinates of one element to gather from params.
    '''
    # Convert the row-per-element coordinate list into flat indices into params.
    indices = indices.t().long()
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0]).long()
    m = 1
    for i in reversed(range(ndim)):
        idx += indices[i] * m
        m *= params.size(i)
    return torch.take(params, idx)
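
# A minimal usage sketch for gather_nd (illustrative values, not part of the
# original code): each row of the index tensor addresses one element of params.
#   params = torch.arange(6.).view(2, 3)     # [[0., 1., 2.], [3., 4., 5.]]
#   idx = torch.tensor([[0, 1], [1, 2]])     # pick params[0, 1] and params[1, 2]
#   gather_nd(params, idx)                   # -> tensor([1., 5.])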

def _sequence_mask(sequence_length, max_len=None):
    """Return a [batch_size, max_len] boolean mask that is True at positions
    strictly less than the corresponding entry of sequence_length."""
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand)).long()
    return seq_range_expand < seq_length_expand
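
# A small illustration of _sequence_mask (assumed inputs, not from the original file):
#   _sequence_mask(torch.tensor([1, 3]))
#   -> tensor([[ True, False, False],
#              [ True,  True,  True]])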

def _build_discounts_matrix(T, gamma):
    """Build a lower-triangular matrix of discounts.
    For example, for T = 3:
        D = [[1,       0,     0],
             [gamma,   1,     0],
             [gamma^2, gamma, 1]]
    Then with R, our N x T incremental rewards matrix, the discounted sum is
    R * D.
    """
    def roll(x, n):
        return torch.cat((x[-n:], x[:-n]))

    # Exponents for the lower triangle: entry (i, j) with i >= j gets i - j.
    power_ltri = torch.cumsum(_sequence_mask(torch.arange(0, T) + 1, T), dim=0)
    power_ltri = roll(power_ltri, 1)
    power_ltri[0] = 0
    gamma_mat = torch.ones((T, T)) * gamma
    gamma_ltri = gamma_mat.pow(power_ltri.float())
    # Zero out the upper triangle so only current and future steps contribute.
    gamma_ltri *= _sequence_mask(torch.arange(0, T) + 1, T).float()
    return gamma_ltri
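
# Sketch of the resulting discount matrix (illustrative call, gamma = 0.9, T = 3):
#   _build_discounts_matrix(3, 0.9)
#   -> tensor([[1.00, 0.00, 0.00],
#              [0.90, 1.00, 0.00],
#              [0.81, 0.90, 1.00]])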

def _get_generated_probabilities(input_batch_size, logits_seq, seq_len, coo_actions):  # TODO
    """Return a [batch_size, seq_len] Tensor with the probability of each
    action that was drawn.
    """
    softmax_ = torch.nn.Softmax(dim=2)
    dists = softmax_(logits_seq)
    r_dists = gather_nd(dists, torch.Tensor(coo_actions).cuda())
    return r_dists.view((input_batch_size, seq_len))
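
# Note (an assumption about the intended format, not stated in the file):
# coo_actions appears to be a list of [batch_index, time_index, action_index]
# triples, one per drawn action, ordered batch-major so the final .view() maps
# them back to a [batch_size, seq_len] grid, e.g. for batch_size=2, seq_len=2:
#   coo_actions = [[0, 0, 2], [0, 1, 0], [1, 0, 1], [1, 1, 1]]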

def get_policy_loss(incremental_rewards, gamma, logits_seq, coo_actions):
    '''
    incremental_rewards is a [batch_size, seq_len] Tensor in which each entry
    is the incremental reward for an action taken on a data point.
    :return: the per-step policy-gradient loss, the probability of the last
             action in each sequence, and the advantages.
    '''
    batch_size, T = incremental_rewards.shape
    # Form the matrix of discounts to apply
    gamma_ltri = _build_discounts_matrix(T, gamma).cuda()
    # Compute future discounted rewards as a [batch_size x seq_len] matrix
    future_rewards = torch.mm(incremental_rewards, gamma_ltri)
    # Compute baseline and advantage
    baseline = torch.mean(future_rewards, dim=0, keepdim=True)
    advantages = future_rewards - baseline
    # Apply the advantage to the policy (REINFORCE with a mean baseline)
    policy = _get_generated_probabilities(batch_size, logits_seq, T, coo_actions)
    return -(torch.log(policy) * advantages.detach()), policy[:, -1], advantages
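
# A minimal end-to-end sketch (hypothetical shapes; a CUDA device is required
# because the helpers above call .cuda() on intermediate tensors):
#   batch_size, T, n_actions = 4, 5, 3
#   logits_seq = torch.randn(batch_size, T, n_actions).cuda()
#   actions = logits_seq.argmax(dim=2)          # stand-in for sampled actions
#   coo_actions = [[b, t, actions[b, t].item()]
#                  for b in range(batch_size) for t in range(T)]
#   rewards = torch.rand(batch_size, T).cuda()  # incremental rewards per step
#   loss_mat, last_prob, advantages = get_policy_loss(rewards, 0.99,
#                                                     logits_seq, coo_actions)
#   loss = loss_mat.sum()                       # scalar REINFORCE-style loss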