attention.py
import numpy as np
import tensorflow as tf

__all__ = [
    "positional_encoding", "Attention"
]
def positional_encoding(dim, sentence_length, dtype=tf.float32):
    """Build a sinusoidal positional encoding of shape [sentence_length, dim]."""
    encoded_vec = np.array([pos / np.power(10000, 2 * i / dim)
                            for pos in range(sentence_length) for i in range(dim)])
    encoded_vec[::2] = np.sin(encoded_vec[::2])    # sine on even indices
    encoded_vec[1::2] = np.cos(encoded_vec[1::2])  # cosine on odd indices

    return tf.convert_to_tensor(encoded_vec.reshape([sentence_length, dim]), dtype=dtype)
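
# A minimal usage sketch (illustrative only; the shapes below are assumptions,
# not part of the original module): the encoding is typically added to a batch
# of token embeddings before they enter the encoder.
#
#     embeddings = tf.placeholder(tf.float32, [None, 20, 100])    # [batch, seq_len, model_dim]
#     pos_enc = positional_encoding(dim=100, sentence_length=20)  # [seq_len, model_dim]
#     encoder_inputs = embeddings + pos_enc                       # broadcast over the batch axis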
class Attention:
    """Multi-head scaled dot-product attention (TensorFlow 1.x layers API)."""

    def __init__(self,
                 num_heads=1,
                 masked=False,
                 linear_key_dim=50,
                 linear_value_dim=50,
                 model_dim=100,
                 dropout=0.2):

        # Each head gets an equal slice of the projected key/value dimensions.
        assert linear_key_dim % num_heads == 0
        assert linear_value_dim % num_heads == 0

        self.num_heads = num_heads
        self.masked = masked
        self.linear_key_dim = linear_key_dim
        self.linear_value_dim = linear_value_dim
        self.model_dim = model_dim
        self.dropout = dropout
    def multi_head(self, q, k, v):
        q, k, v = self._linear_projection(q, k, v)
        qs, ks, vs = self._split_heads(q, k, v)
        outputs = self._scaled_dot_product(qs, ks, vs)
        output = self._concat_heads(outputs)
        output = tf.layers.dense(output, self.model_dim)

        # tf.nn.dropout in TF 1.x takes keep_prob, hence 1.0 - dropout rate.
        return tf.nn.dropout(output, 1.0 - self.dropout)
    def _linear_projection(self, q, k, v):
        q = tf.layers.dense(q, self.linear_key_dim, use_bias=False)
        k = tf.layers.dense(k, self.linear_key_dim, use_bias=False)
        v = tf.layers.dense(v, self.linear_value_dim, use_bias=False)
        return q, k, v
    def _split_heads(self, q, k, v):

        def split_last_dimension_then_transpose(tensor, num_heads, dim):
            t_shape = tensor.get_shape().as_list()
            tensor = tf.reshape(tensor, [-1] + t_shape[1:-1] + [num_heads, dim // num_heads])
            return tf.transpose(tensor, [0, 2, 1, 3])  # [batch_size, num_heads, max_seq_len, dim // num_heads]

        qs = split_last_dimension_then_transpose(q, self.num_heads, self.linear_key_dim)
        ks = split_last_dimension_then_transpose(k, self.num_heads, self.linear_key_dim)
        vs = split_last_dimension_then_transpose(v, self.num_heads, self.linear_value_dim)

        return qs, ks, vs
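
    # Shape trace (illustrative, assuming the default linear_key_dim=50 with
    # num_heads=5): a projected query of shape [batch, seq_len, 50] is reshaped
    # to [batch, seq_len, 5, 10] and transposed to [batch, 5, seq_len, 10], so
    # each head attends over its own 10-dimensional slice.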
    def _scaled_dot_product(self, qs, ks, vs):
        key_dim_per_head = self.linear_key_dim // self.num_heads

        # Attention logits, scaled by the square root of the per-head key dimension.
        o1 = tf.matmul(qs, ks, transpose_b=True)
        o2 = o1 / (key_dim_per_head**0.5)

        if self.masked:
            # Causal mask: position i may only attend to positions <= i.
            diag_vals = tf.ones_like(o2[0, 0, :, :])  # (query_len, key_len)
            tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (query_len, key_len)
            masks = tf.tile(tf.reshape(tril, [1, 1] + tril.get_shape().as_list()),
                            [tf.shape(o2)[0], tf.shape(o2)[1], 1, 1])
            paddings = tf.ones_like(masks) * -1e9
            o2 = tf.where(tf.equal(masks, 0), paddings, o2)

        o3 = tf.nn.softmax(o2)
        return tf.matmul(o3, vs)
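
    # Masking example (illustrative): for query_len = key_len = 3 the mask is
    #
    #     [[1, 0, 0],
    #      [1, 1, 0],
    #      [1, 1, 1]]
    #
    # Logits at the zero positions are replaced with -1e9, so after the softmax
    # those attention weights are effectively zero and each position attends
    # only to itself and earlier positions.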
    def _concat_heads(self, outputs):

        def transpose_then_concat_last_two_dimensions(tensor):
            tensor = tf.transpose(tensor, [0, 2, 1, 3])  # [batch_size, max_seq_len, num_heads, dim]
            t_shape = tensor.get_shape().as_list()
            num_heads, dim = t_shape[-2:]
            return tf.reshape(tensor, [-1] + t_shape[1:-2] + [num_heads * dim])

        return transpose_then_concat_last_two_dimensions(outputs)
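

# A minimal self-attention usage sketch. This block is not part of the original
# module; it assumes TensorFlow 1.x (tf.placeholder, tf.Session, tf.layers.dense),
# and the batch size, sequence length and dimensions are illustrative.
if __name__ == "__main__":
    batch_size, seq_len, model_dim = 2, 10, 100
    inputs = tf.placeholder(tf.float32, [None, seq_len, model_dim])

    attention = Attention(num_heads=5,
                          masked=True,
                          linear_key_dim=50,
                          linear_value_dim=50,
                          model_dim=model_dim,
                          dropout=0.2)

    # Self-attention: queries, keys and values all come from the same tensor.
    outputs = attention.multi_head(inputs, inputs, inputs)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = np.random.rand(batch_size, seq_len, model_dim).astype(np.float32)
        result = sess.run(outputs, feed_dict={inputs: batch})
        print(result.shape)  # expected: (2, 10, 100)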