import numpy as np
import mxnet as mx
from mxnet import gluon, autograd, nd
from utils import OrthoInit
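
# `OrthoInit` lives in utils.py, which is not shown on this page. The class below
# is a hypothetical sketch of what such a gain-scaled orthogonal initializer looks
# like under the standard mx.init.Initializer API (MXNet's built-in
# mx.init.Orthogonal behaves similarly); the real utils.OrthoInit may differ.
class OrthoInitSketch(mx.init.Initializer):
    def __init__(self, gain=1.0):
        super(OrthoInitSketch, self).__init__(gain=gain)
        self.gain = gain

    def _init_weight(self, name, arr):
        # Orthogonalize a Gaussian matrix via SVD, then scale by the gain.
        shape = arr.shape
        flat = np.random.normal(0.0, 1.0, (shape[0], int(np.prod(shape[1:]))))
        u, _, v = np.linalg.svd(flat, full_matrices=False)
        q = u if u.shape == flat.shape else v
        arr[:] = self.gain * q.reshape(shape)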


class ActorCritic_Discrete(gluon.Block):
    '''
    No activation is applied to the action logits: forward returns unnormalized
    logits, and the OpenAI Baselines-style sampling, log_prob and entropy methods
    below all work directly on that scale.
    '''
    def __init__(self, state_dim, action_dim, args, **kwargs):
        super(ActorCritic_Discrete, self).__init__(**kwargs)
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.args = args
        '''
        Earlier fully-connected architecture, kept for reference:
        with self.name_scope():
            self.l1 = gluon.nn.Dense(in_units=self.state_dim, units=100, activation='relu')
            self.value = gluon.nn.Dense(in_units=100, units=1)
            self.logits = gluon.nn.Dense(in_units=100, units=action_dim)
        '''
        with self.name_scope():
            self.conv1 = gluon.nn.Conv2D(channels=32, kernel_size=8, strides=4, use_bias=True, activation='relu')
            self.conv2 = gluon.nn.Conv2D(channels=64, kernel_size=4, strides=2, use_bias=True, activation='relu')
            self.conv3 = gluon.nn.Conv2D(channels=64, kernel_size=3, strides=1, use_bias=True, activation='relu')
            self.l1 = gluon.nn.Dense(units=512, activation='relu')
            self.logits = gluon.nn.Dense(units=self.action_dim)
            self.value = gluon.nn.Dense(units=1)
        # Orthogonal initialization following the OpenAI Baselines convention:
        # gain sqrt(2) for the ReLU layers, a small gain (0.01) for the policy
        # logits so the initial policy is near-uniform, and gain 1.0 for the value head.
        self.conv1.collect_params().initialize(OrthoInit(np.sqrt(2)), ctx=self.args.ctx)
        self.conv2.collect_params().initialize(OrthoInit(np.sqrt(2)), ctx=self.args.ctx)
        self.conv3.collect_params().initialize(OrthoInit(np.sqrt(2)), ctx=self.args.ctx)
        self.l1.collect_params().initialize(OrthoInit(np.sqrt(2)), ctx=self.args.ctx)
        self.logits.collect_params().initialize(OrthoInit(0.01), ctx=self.args.ctx)
        self.value.collect_params().initialize(OrthoInit(1.0), ctx=self.args.ctx)
        # from_logits=False: the loss applies the log-softmax internally, so
        # log_prob below can be fed raw logits.
        self.act_loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=True, from_logits=False)

    def forward(self, x):
        x = x / 255.  # scale raw pixel observations to [0, 1]
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.l1(x)
        value = self.value(x)
        # Return raw, unnormalized logits; applying a softmax here would make
        # SoftmaxCrossEntropyLoss in log_prob normalize twice.
        logits = self.logits(x)
        return value, logits

    def get_value(self, x):
        value, _ = self(x)
        return value

    def choose_action(self, x):
        _, logits = self(x)
        action = self.sample(logits)
        return action

    def sample(self, logits):
        # Gumbel-max trick: argmax over logits plus Gumbel noise is an exact sample
        # from the corresponding softmax distribution, so it works on raw logits.
        u = nd.random.uniform(shape=logits.shape, ctx=logits.context)
        return nd.argmax(logits - nd.log(-nd.log(u)), axis=-1)

    def log_prob(self, logits, action):
        '''
        action : integer action indices
        logits : unnormalized; SoftmaxCrossEntropyLoss applies the log-softmax itself
        '''
        # One scalar per sample (the negative cross-entropy), not a vector.
        return -self.act_loss(logits, action)

    def entropy(self, logits):
        # Entropy of the softmax policy, computed from the raw logits.
        probs = nd.softmax(logits, axis=1)
        return -nd.sum(probs * nd.log(probs + 1e-8), axis=1)
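
# A hypothetical usage sketch (the `Args` stub and the Atari-style shapes are
# illustrative assumptions, not part of this repo): the discrete model expects
# NCHW image batches, e.g. four stacked 84x84 frames.
#
#   class Args: ctx = mx.cpu()
#   net = ActorCritic_Discrete(state_dim=84 * 84 * 4, action_dim=6, args=Args())
#   obs = nd.random.uniform(0, 255, shape=(2, 4, 84, 84))
#   value, logits = net(obs)           # value: (2, 1), logits: (2, 6)
#   actions = net.choose_action(obs)   # integer action indices, shape (2,)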


class ActorCritic_Gaussian(gluon.Block):
    def __init__(self, state_dim, action_dim, action_bound, args, **kwargs):
        super(ActorCritic_Gaussian, self).__init__(**kwargs)
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_bound = action_bound
        self.args = args
        with self.name_scope():
            self.l1 = gluon.nn.Dense(in_units=self.state_dim, units=100, activation='relu')
            self.mu = gluon.nn.Dense(in_units=100, units=action_dim, activation='tanh')
            self.sigma = gluon.nn.Dense(in_units=100, units=action_dim, activation='softrelu')
            self.value = gluon.nn.Dense(in_units=100, units=1)

    def forward(self, x):
        x = self.l1(x)
        mu = self.mu(x) * self.action_bound[1]  # rescale the tanh output to the action range
        sigma = self.sigma(x)
        val = self.value(x)
        return val, mu, sigma

    def choose_action(self, x):
        _, mu, sigma = self(x)
        return self.sample(mu, sigma)

    def get_value(self, x):
        val, _, _ = self(x)
        return val

    def sample(self, mu, sigma):
        # Reparameterized sample: mu + sigma * eps, with eps ~ N(0, 1)
        epsilon = nd.random_normal(shape=mu.shape, loc=0., scale=1., ctx=self.args.ctx)
        return mu + sigma * epsilon

    def entropy(self, sigma):
        # Entropy of a diagonal Gaussian: sum_i [log(sigma_i) + 0.5 * log(2 * pi * e)]
        return nd.sum(nd.log(sigma + 1e-8) + .5 * np.log(2.0 * np.pi * np.e), axis=-1)

    def log_prob(self, x, mu, sigma):
        # Diagonal Gaussian log-density, summed over action dimensions so each
        # sample yields a single scalar (matching the discrete head's log_prob).
        return nd.sum(-0.5 * np.log(2.0 * np.pi) - nd.log(sigma + 1e-8)
                      - (x - mu) ** 2 / (2 * sigma ** 2 + 1e-8), axis=-1)
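

if __name__ == '__main__':
    # Minimal smoke test for the Gaussian head. The `Args` stub and the
    # Pendulum-like dimensions are illustrative assumptions, not part of this repo.
    class Args:
        ctx = mx.cpu()

    args = Args()
    net = ActorCritic_Gaussian(state_dim=3, action_dim=1, action_bound=[-2.0, 2.0], args=args)
    # Unlike the discrete model, this class does not initialize its own parameters.
    net.collect_params().initialize(ctx=args.ctx)
    state = nd.random.uniform(shape=(1, 3), ctx=args.ctx)
    value, mu, sigma = net(state)
    action = net.choose_action(state)
    print('value:', value.asnumpy(), 'action:', action.asnumpy())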