# -*- coding: utf-8 -*-
# @Time : 9/8/21 10:33 PM
# @Author : Tingfeng Li, <tl601@cs.rutgers.edu>, Rutgers University.
import torch


def gather_nd(params, indices, name=None):
    """Gather single elements of `params` at N-d locations.

    `indices` must be a 2-D tensor of the form [[a, b, ..., c], ...], where
    each row gives the full coordinates of one element of `params`.
    """
    indices = indices.t().long()           # [ndim, num_points]
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0])     # flat-index accumulator
    m = 1
    # Convert N-d coordinates into flat (row-major) offsets.
    for i in reversed(range(ndim)):
        idx += indices[i] * m
        m *= params.size(i)
    return torch.take(params, idx)


def _sequence_mask(sequence_length, max_len=None):
    """Return a [batch_size, max_len] boolean mask where entry (i, j) is True
    iff j < sequence_length[i]."""
    if max_len is None:
        max_len = int(sequence_length.max())
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len, device=sequence_length.device).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand)).long()
    return seq_range_expand < seq_length_expand


def _build_discounts_matrix(T, gamma):
    """Build the T x T lower-triangular matrix of discounts. For example, for T = 3:

        D = [[1,       0,     0],
             [gamma,   1,     0],
             [gamma^2, gamma, 1]]

    With R the N x T matrix of incremental rewards, the discounted future
    rewards are R * D.
    """
    def roll(x, n):
        return torch.cat((x[-n:], x[:-n]))

    ltri_mask = _sequence_mask(torch.arange(0, T) + 1, T)    # lower-triangular ones
    power_ltri = torch.cumsum(ltri_mask.long(), dim=0)       # discount exponents
    power_ltri = roll(power_ltri, 1)
    power_ltri[0] = 0
    gamma_mat = torch.ones((T, T)) * gamma
    gamma_ltri = gamma_mat.pow(power_ltri.float())
    gamma_ltri *= ltri_mask.float()                          # zero out entries above the diagonal
    return gamma_ltri


def _get_generated_probabilities(input_batch_size, logits_seq, seq_len, coo_actions):
    # TODO
    """Return a [batch_size, seq_len] tensor with the probability the policy
    assigned to each action that was drawn.

    `coo_actions` holds one (batch, step, action) coordinate row per drawn
    action, in batch-major order.
    """
    dists = torch.softmax(logits_seq, dim=2)
    coords = torch.as_tensor(coo_actions, device=dists.device)
    r_dists = gather_nd(dists, coords)
    return r_dists.view((input_batch_size, seq_len))


def get_policy_loss(incremental_rewards, gamma, logits_seq, coo_actions):
    """Compute REINFORCE-style policy-gradient loss terms.

    `incremental_rewards` is a [batch_size, seq_len] tensor where each entry
    is the incremental reward for an action on a data point.

    :return: per-step loss terms -log(pi) * advantage, the probability of the
        last action in each sequence, and the advantage matrix.
    """
    batch_size, T = incremental_rewards.shape

    # Form the matrix of discounts to apply.
    gamma_ltri = _build_discounts_matrix(T, gamma).to(incremental_rewards.device)

    # Compute future discounted rewards as a [batch_size, seq_len] matrix.
    future_rewards = torch.mm(incremental_rewards, gamma_ltri)

    # Compute the baseline (batch mean) and the advantage.
    baseline = torch.mean(future_rewards, dim=0, keepdim=True)
    advantages = future_rewards - baseline

    # Weight the log-probability of each drawn action by its detached advantage.
    policy = _get_generated_probabilities(batch_size, logits_seq, T, coo_actions)
    return -(torch.log(policy) * advantages.detach()), policy[:, -1], advantages
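

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It builds random
# logits and rewards, picks one action per step, and calls get_policy_loss.
# The tensor shapes, the action-space size, gamma = 0.9, and the greedy action
# selection below are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, seq_len, num_actions = 4, 5, 10      # assumed demo sizes

    logits_seq = torch.randn(batch_size, seq_len, num_actions)
    incremental_rewards = torch.rand(batch_size, seq_len)

    # Build the (batch, step, action) coordinate list that gather_nd expects;
    # here the "drawn" action is simply the argmax of the logits.
    actions = logits_seq.argmax(dim=2)
    coo_actions = [[b, t, actions[b, t].item()]
                   for b in range(batch_size) for t in range(seq_len)]

    loss_terms, last_probs, advantages = get_policy_loss(
        incremental_rewards, gamma=0.9,
        logits_seq=logits_seq, coo_actions=coo_actions)

    print(loss_terms.shape, last_probs.shape, advantages.shape)
    # Expected: torch.Size([4, 5]) torch.Size([4]) torch.Size([4, 5])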