-
Notifications
You must be signed in to change notification settings - Fork 0
/
WAIL_regularizers.py
58 lines (48 loc) · 1.95 KB
/
WAIL_regularizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Fast implementation of regularizers in WAIL
Wasserstein Adversarial Imitation Learning
https://arxiv.org/abs/1906.08113#
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
"""
import torch
import math
def l2_reg(g_sa, e_sa, g_o, e_o, wail_epsilon):
    """Return the L2 regularization term used by the WAIL discriminator.

    For every (expert i, generator j) pair the margin
    ``e_o[i] - g_o[j] - ||e_sa[i] - g_sa[j]||_2`` is computed. Its positive
    part is squared, averaged over all pairs and scaled by
    ``-1 / (4 * wail_epsilon)``.

    Args:
        g_sa: concatenated agent/generator data, i.e.
            ``torch.cat([g_states, g_actions], 1)``, shape ``(n_g, d)``.
        e_sa: concatenated expert data, i.e.
            ``torch.cat([e_states, e_actions], 1)``, shape ``(n_e, d)``.
        g_o: discriminator output for ``g_sa``, i.e. ``D(g_sa)``; one scalar
            per row, shape ``(n_g, 1)`` or ``(n_g,)``.
        e_o: discriminator output for ``e_sa``, i.e. ``D(e_sa)``; one scalar
            per row, shape ``(n_e, 1)`` or ``(n_e,)``.
        wail_epsilon: hyper-parameter epsilon used in the WAIL algorithm.

    Returns:
        A 0-dim tensor holding the (non-positive) regularization value.
    """
    # Discriminator gap for every pair, laid out as (n_e, n_g).
    gap = e_o.reshape(-1, 1) - g_o.reshape(1, -1)
    # Euclidean ground cost in the SAME (n_e, n_g) layout. The gap and the
    # cost must refer to the same (expert, generator) pair; building one of
    # them as (n_g, n_e) and flattening silently pairs transposed entries.
    cost = torch.sqrt(((e_sa.unsqueeze(1) - g_sa.unsqueeze(0)) ** 2).sum(dim=-1))
    # Only pairs that violate the constraint gap <= cost are penalized.
    violation = torch.clamp(gap - cost, min=0)
    return -(violation ** 2).mean() / (4 * wail_epsilon)
def entropy_reg(g_sa, e_sa, g_o, e_o, wail_epsilon):
    """Return the entropic regularization term for the WAIL discriminator.

    For every (expert i, generator j) pair the margin
    ``e_o[i] - g_o[j] - ||e_sa[i] - g_sa[j]||_2`` is computed, and the
    result is ``-wail_epsilon * mean(exp(margin / wail_epsilon))`` over all
    pairs.

    NOTE(review): earlier code computed ``torch.exp(reg).sum()`` but
    discarded it and returned a clamped-square penalty; it also read
    ``args.wail_epsilon`` from an undefined global instead of the
    ``wail_epsilon`` argument. The exponential form and the use of ``mean``
    (matching ``l2_reg``) are the presumed intent -- confirm against the
    training loss that consumes this value.

    Args:
        g_sa: concatenated agent/generator data, i.e.
            ``torch.cat([g_states, g_actions], 1)``, shape ``(n_g, d)``.
        e_sa: concatenated expert data, i.e.
            ``torch.cat([e_states, e_actions], 1)``, shape ``(n_e, d)``.
        g_o: discriminator output for ``g_sa``, i.e. ``D(g_sa)``; one scalar
            per row, shape ``(n_g, 1)`` or ``(n_g,)``.
        e_o: discriminator output for ``e_sa``, i.e. ``D(e_sa)``; one scalar
            per row, shape ``(n_e, 1)`` or ``(n_e,)``.
        wail_epsilon: hyper-parameter epsilon used in the WAIL algorithm.

    Returns:
        A 0-dim tensor holding the (negative) regularization value.
    """
    # Discriminator gap for every pair, laid out as (n_e, n_g).
    gap = e_o.reshape(-1, 1) - g_o.reshape(1, -1)
    # Euclidean ground cost in the SAME (n_e, n_g) layout so that each gap is
    # compared against the distance of its own (expert, generator) pair.
    cost = torch.sqrt(((e_sa.unsqueeze(1) - g_sa.unsqueeze(0)) ** 2).sum(dim=-1))
    # Use the function argument, not a module-level ``args`` object.
    scaled_margin = (gap - cost) / wail_epsilon
    return -wail_epsilon * torch.exp(scaled_margin).mean()