-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathRNN网络搭建
82 lines (74 loc) · 1.92 KB
/
RNN网络搭建
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#rnn网络单元
tf.nn.rnn_cell
tf.keras.layers.RNN
tf.nn.rnn_cell.BasicRNNCell#最基本的rnn单元
(
num_units,
activation=None,# defaults to tanh; tf.nn.relu is a common alternative that may work better on some tasks
reuse=None,
name=None,
dtype=None,
**kwargs
)
tf.nn.rnn_cell.RNNCell#抽象rnn单元
(
trainable=True,
name=None,
dtype=None,
**kwargs
)
tf.nn.rnn_cell.LSTMCell#LSTM单元
(
num_units,
use_peepholes=False,#是否使用窥视孔
cell_clip=None,# float; if set, the cell state is clipped to this value before the cell output activation
initializer=None,
num_proj=None,#整型,投影矩阵的输出维数
proj_clip=None,
num_unit_shards=None,# deprecated; use a variable_scope partitioner instead
num_proj_shards=None,# deprecated; use a variable_scope partitioner instead
forget_bias=1.0,
state_is_tuple=True,
activation=None,
reuse=None,
name=None,
dtype=None,
**kwargs
)
tf.nn.rnn_cell.GRUCell#GRU单元
(
num_units,
activation=None,
reuse=None,
kernel_initializer=None,
bias_initializer=None,
name=None,
dtype=None,
**kwargs
)
# Stacking several cells: one GRU cell per entry of hidden_sizes,
# wrapped into a single multi-layer cell.
layers = [
    tf.nn.rnn_cell.GRUCell(num_units=units)
    for units in hidden_sizes
]
cells = tf.nn.rnn_cell.MultiRNNCell(cells=layers)
#搭建rnn
tf.nn.dynamic_rnn
tf.nn.bidirectional_dynamic_rnn
# Overall recipe: choose layer sizes, build one GRU cell per layer,
# stack them, then unroll the stacked cell over the input sequence.
# TODO: fill in one hidden size per layer; MultiRNNCell presumably
# rejects an empty list - confirm against the TensorFlow docs.
hidden_sizes=[]
layers = [tf.nn.rnn_cell.GRUCell(size) for size in hidden_sizes]
cells = tf.nn.rnn_cell.MultiRNNCell(layers)
# Pass the stacked `cells` built on the previous line.
# `seq`, `length` and `initial_state` are not defined in this snippet;
# supply the input tensor, the per-example sequence lengths and the
# starting state before running it.
output, out_state = tf.nn.dynamic_rnn(
    cells,
    seq,
    sequence_length=length,
    initial_state=initial_state,
)
#代码
import tensorflow as tf
# Dimensions used by the example graph.
batch_size = 64
hidden_size = 64
max_time = 50

# Float32 input placeholder of shape (batch_size, max_time, 64).
input_data = tf.placeholder(
    dtype=tf.float32,
    shape=(batch_size, max_time, 64),
)

# Create the cell.
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=hidden_size)

# Initialise the starting state from the cell's zero_state.
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# 'state':   [batch_size, cell_state_size]
# 'outputs': [batch_size, max_time, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(
    cell=rnn_cell,
    inputs=input_data,
    initial_state=initial_state,
    dtype=tf.float32,
)