-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
91 lines (76 loc) · 4.14 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import distutils.version
# True when the installed TensorFlow reports version 1.0.0 or newer; lets callers
# choose between the pre-1.0 and 1.0+ APIs.
use_tf100_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('1.0.0')
def normalized_columns_initializer(std=1.0):
    """Return an initializer producing weights whose columns have L2 norm ``std``.

    The returned callable takes ``(shape, dtype, partition_info)``; ``dtype``
    and ``partition_info`` are accepted but not used, and the result is always
    a float32 constant tensor.
    """
    def _initializer(shape, dtype=None, partition_info=None):
        # Gaussian samples, rescaled in place so that each column (norm taken
        # over axis 0) ends up with length `std`.
        weights = np.random.randn(*shape).astype(np.float32)
        column_norms = np.sqrt(np.square(weights).sum(axis=0, keepdims=True))
        weights *= std / column_norms
        return tf.constant(weights)

    return _initializer
def flatten(x):
    """Reshape ``x`` to rank 2, keeping axis 0 and merging all remaining axes.

    The merged size is the product of the statically known dimensions after
    the first one; axis 0 is left to be inferred (``-1``).
    """
    trailing_dims = x.get_shape().as_list()[1:]
    flat_size = np.prod(trailing_dims)
    return tf.reshape(x, [-1, flat_size])
def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None):
    """Apply a 2-D convolution followed by a bias add inside variable scope ``name``.

    Args:
        x: input tensor; the size of its axis 3 is read as the number of
            input feature maps.
        num_filters: number of output feature maps.
        name: variable scope under which the variables "W" and "b" are created.
        filter_size: (height, width) of the convolution kernel.
        stride: (vertical, horizontal) stride.
        pad: padding mode forwarded to ``tf.nn.conv2d``.
        dtype: dtype used for both the kernel and the bias variable.
        collections: graph collections forwarded to ``tf.get_variable``.

    Returns:
        The result of ``tf.nn.conv2d(x, W, ...) + b``.
    """
    with tf.variable_scope(name):
        stride_shape = [1, stride[0], stride[1], 1]
        filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters]

        # there are "num input feature maps * filter height * filter width"
        # inputs to each hidden unit
        fan_in = np.prod(filter_shape[:3])
        # each unit in the lower layer receives a gradient from:
        # "num output feature maps * filter height * filter width" /
        #   pooling size
        fan_out = np.prod(filter_shape[:2]) * num_filters
        # initialize weights uniformly in [-w_bound, w_bound]
        w_bound = np.sqrt(6. / (fan_in + fan_out))

        w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound),
                            collections=collections)
        # The bias must share the kernel's dtype; otherwise a non-default
        # `dtype` yields a float32 bias and the addition below fails.
        b = tf.get_variable("b", [1, 1, 1, num_filters], dtype=dtype,
                            initializer=tf.constant_initializer(0.0),
                            collections=collections)
        return tf.nn.conv2d(x, w, stride_shape, pad) + b
def linear(x, size, name, initializer=None, bias_init=0):
    """Fully connected layer computing ``matmul(x, w) + b``.

    Creates the variables ``<name>/w`` (shape ``[x.shape[1], size]``, using
    ``initializer``) and ``<name>/b`` (shape ``[size]``, constant ``bias_init``).
    """
    input_width = x.get_shape()[1]
    weights = tf.get_variable(name + "/w", [input_width, size], initializer=initializer)
    bias = tf.get_variable(name + "/b", [size], initializer=tf.constant_initializer(bias_init))
    projected = tf.matmul(x, weights)
    return projected + bias
def categorical_sample(logits, d):
    """Sample one category per row of ``logits`` and return it one-hot encoded.

    ``d`` is the depth of the one-hot encoding.
    """
    # Shift each row by its maximum before sampling.
    row_max = tf.reduce_max(logits, [1], keep_dims=True)
    draws = tf.multinomial(logits - row_max, 1)
    chosen = tf.squeeze(draws, [1])
    return tf.one_hot(chosen, d)
class GRUPolicy(object):
    """
    This class builds up the network. There are 4 convolutional layers followed by a recurrent layer making use
    of GRU cells (http://arxiv.org/abs/1406.1078) inspired by louiehelm
    (https://github.com/louiehelm/universe-starter-agent/tree/gru).
    Every convolutional layer features 32 filters of the size 3x3 that are using a stride size of 2x2
    and is followed by an ELU activation. There are 256 GRU units in the RNN layer.
    Two linear heads on top of the GRU output produce the action logits and the value estimate.
    """

    def __init__(self, ob_space, ac_space):
        """Build the policy/value graph.

        Args:
            ob_space: shape of a single observation (converted with ``list()``).
            ac_space: number of discrete actions; used as the width of the
                logits and the depth of the one-hot sample.
        """
        self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))

        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
        # introduce a "fake" batch dimension of 1 after flatten so that we can do GRU over time dim
        x = tf.expand_dims(flatten(x), 1)
        # perform a dropout of 10% (keep probability 0.9)
        # NOTE(review): the graph has no train/eval switch, so this dropout is
        # also active when act() and value() are evaluated - confirm intended.
        x = tf.nn.dropout(x, 0.9)

        size = 256
        gru = rnn.GRUCell(size)

        h_init = np.zeros((1, size), np.float32)
        self.state_init = [h_init]
        h_in = tf.placeholder(tf.float32, [1, size])
        self.state_in = [h_in]

        # The single batch entry spans as many time steps as observations were
        # fed in. Using the hidden size here instead would zero the outputs and
        # stop the state updates for every step beyond the 256th.
        step_size = tf.shape(self.x)[:1]
        gru_outputs, gru_state = tf.nn.dynamic_rnn(
            gru, x, initial_state=h_in, sequence_length=step_size, time_major=True)
        x = tf.reshape(gru_outputs, [-1, size])

        self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01))
        self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
        self.state_out = [gru_state[:1]]
        # Only the sample for the first row is exposed.
        self.sample = categorical_sample(self.logits, ac_space)[0, :]
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)

    def get_initial_features(self):
        """Return the initial recurrent state: a list holding one zero array of shape (1, 256)."""
        return self.state_init

    def act(self, ob, h):
        """Run the network on one observation using the default session.

        Args:
            ob: a single observation; it is wrapped in a list before feeding.
            h: recurrent state fed to the GRU state placeholder.

        Returns:
            The list ``[sample, vf, state_out]`` as evaluated by ``sess.run``.
        """
        sess = tf.get_default_session()
        return sess.run([self.sample, self.vf] + self.state_out,
                        {self.x: [ob], self.state_in[0]: h})

    def value(self, ob, h):
        """Return the value estimate for one observation ``ob`` given recurrent state ``h``."""
        sess = tf.get_default_session()
        return sess.run(self.vf, {self.x: [ob], self.state_in[0]: h})[0]