# model.py -- SASRec (Self-Attentive Sequential Recommendation), TensorFlow 1.x.
# Forked from kang205/SASRec.
import tensorflow as tf  # TF 1.x graph-mode API (placeholders, variable scopes)

from modules import embedding, feedforward, multihead_attention, normalize

class Model:
    """SASRec: causal self-attention over item sequences (TensorFlow 1.x)."""

    def __init__(self, usernum, itemnum, args, reuse=None):
        self.is_training = tf.placeholder(tf.bool, shape=())
        # User ids are fed for interface parity but never used downstream:
        # the model is purely sequence-based. (None,) pins the placeholder
        # to rank 1; the original (None) meant "shape unspecified".
        self.u = tf.placeholder(tf.int32, shape=(None,))
        self.input_seq = tf.placeholder(tf.int32, shape=(None, args.maxlen))
        self.pos = tf.placeholder(tf.int32, shape=(None, args.maxlen))
        self.neg = tf.placeholder(tf.int32, shape=(None, args.maxlen))
        pos = self.pos
        neg = self.neg
        # 1.0 where input_seq holds a real item, 0.0 at padding (item id 0).
        mask = tf.expand_dims(tf.to_float(tf.not_equal(self.input_seq, 0)), -1)

        with tf.variable_scope("SASRec", reuse=reuse):
            # Item embedding table, shared with the prediction layer below;
            # zero_pad=True keeps the row for the padding id 0 at zero.
            self.seq, item_emb_table = embedding(self.input_seq,
                                                 vocab_size=itemnum + 1,
                                                 num_units=args.hidden_units,
                                                 zero_pad=True,
                                                 scale=True,
                                                 l2_reg=args.l2_emb,
                                                 scope="input_embeddings",
                                                 with_t=True,
                                                 reuse=reuse)
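            # Shapes: self.seq is (batch, maxlen, hidden_units);
            # item_emb_table is (itemnum + 1, hidden_units), row 0 = padding.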

            # Positional encoding: an embedding indexed by position
            # 0..maxlen-1, tiled across the batch.
            t, pos_emb_table = embedding(
                tf.tile(tf.expand_dims(tf.range(tf.shape(self.input_seq)[1]), 0),
                        [tf.shape(self.input_seq)[0], 1]),
                vocab_size=args.maxlen,
                num_units=args.hidden_units,
                zero_pad=False,
                scale=False,
                l2_reg=args.l2_emb,
                scope="dec_pos",
                reuse=reuse,
                with_t=True)
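            # These positional embeddings are *learned* (with optional L2
            # regularization via args.l2_emb), not the fixed sinusoids of the
            # original Transformer.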
            self.seq += t

            # Dropout on the summed item + position embeddings.
            self.seq = tf.layers.dropout(self.seq,
                                         rate=args.dropout_rate,
                                         training=tf.convert_to_tensor(self.is_training))
            # Re-zero the padding positions after dropout.
            self.seq *= mask

            # Stacked self-attention blocks.
            for i in range(args.num_blocks):
                with tf.variable_scope("num_blocks_%d" % i):
                    # Self-attention: causality=True masks future positions,
                    # so step t attends only to steps <= t.
                    self.seq = multihead_attention(queries=normalize(self.seq),
                                                   keys=self.seq,
                                                   num_units=args.hidden_units,
                                                   num_heads=args.num_heads,
                                                   dropout_rate=args.dropout_rate,
                                                   is_training=self.is_training,
                                                   causality=True,
                                                   scope="self_attention")

                    # Point-wise feed-forward network.
                    self.seq = feedforward(normalize(self.seq),
                                           num_units=[args.hidden_units, args.hidden_units],
                                           dropout_rate=args.dropout_rate,
                                           is_training=self.is_training)
                    self.seq *= mask

            self.seq = normalize(self.seq)
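            # Note: normalize() is applied to each sub-layer's *input*
            # (pre-layer norm); in the reference modules.py the residual
            # connections are added inside multihead_attention / feedforward.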

        # Flatten (batch, maxlen) -> (batch * maxlen,) so that every position
        # in every sequence becomes one training example.
        pos = tf.reshape(pos, [tf.shape(self.input_seq)[0] * args.maxlen])
        neg = tf.reshape(neg, [tf.shape(self.input_seq)[0] * args.maxlen])
        pos_emb = tf.nn.embedding_lookup(item_emb_table, pos)
        neg_emb = tf.nn.embedding_lookup(item_emb_table, neg)
        seq_emb = tf.reshape(self.seq, [tf.shape(self.input_seq)[0] * args.maxlen, args.hidden_units])

        # Test-time logits over 101 candidate items (see note below).
        self.test_item = tf.placeholder(tf.int32, shape=(101,))
        test_item_emb = tf.nn.embedding_lookup(item_emb_table, self.test_item)
        self.test_logits = tf.matmul(seq_emb, tf.transpose(test_item_emb))
        self.test_logits = tf.reshape(self.test_logits,
                                      [tf.shape(self.input_seq)[0], args.maxlen, 101])
        self.test_logits = self.test_logits[:, -1, :]  # score only the last position
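
        # The hard-coded 101 presumably matches the repo's evaluation
        # protocol: rank the ground-truth next item against 100 sampled
        # negatives. A sketch of the caller's side (names hypothetical):
        #
        #   item_idx = [true_next_item] + random.sample(unseen_items, 100)
        #   scores = model.predict(sess, [u], [seq], item_idx)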

        # Prediction layer: dot product between the sequence representation
        # at each position and the positive / negative item embeddings.
        self.pos_logits = tf.reduce_sum(pos_emb * seq_emb, -1)
        self.neg_logits = tf.reduce_sum(neg_emb * seq_emb, -1)

        # Ignore padding items (id 0) when averaging the loss.
        istarget = tf.reshape(tf.to_float(tf.not_equal(pos, 0)),
                              [tf.shape(self.input_seq)[0] * args.maxlen])
        # Binary cross-entropy over (positive, negative) pairs.
        self.loss = tf.reduce_sum(
            - tf.log(tf.sigmoid(self.pos_logits) + 1e-24) * istarget -
            tf.log(1 - tf.sigmoid(self.neg_logits) + 1e-24) * istarget
        ) / tf.reduce_sum(istarget)
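        # Per valid position t:
        #   loss_t = -log(sigmoid(r_pos)) - log(1 - sigmoid(r_neg))
        # i.e. the observed next item should outscore one sampled negative;
        # the 1e-24 term guards against log(0).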
        # Add the L2 regularization terms collected by the embedding layers.
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        self.loss += sum(reg_losses)

        tf.summary.scalar('loss', self.loss)
        # AUC over (positive, negative) logit pairs at valid positions.
        self.auc = tf.reduce_sum(
            ((tf.sign(self.pos_logits - self.neg_logits) + 1) / 2) * istarget
        ) / tf.reduce_sum(istarget)
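        # (sign(pos - neg) + 1) / 2 is 1 when the positive item outscores the
        # sampled negative, 0 when it loses, and 0.5 on ties, so this is a
        # one-negative-per-position estimate of AUC.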

        if reuse is None:
            # First construction: training graph (optimizer + summaries).
            tf.summary.scalar('auc', self.auc)
            self.global_step = tf.Variable(0, name='global_step', trainable=False)
            self.optimizer = tf.train.AdamOptimizer(learning_rate=args.lr, beta2=0.98)
            self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)
        else:
            # Reused construction: evaluation-only summaries.
            tf.summary.scalar('test_auc', self.auc)

        self.merged = tf.summary.merge_all()

    def predict(self, sess, u, seq, item_idx):
        """Return (batch, 101) scores for the candidates in `item_idx`."""
        return sess.run(self.test_logits,
                        {self.u: u, self.input_seq: seq,
                         self.test_item: item_idx, self.is_training: False})
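
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original file). It builds the
# graph with made-up hyperparameters and runs a single training step on
# random data; the field names below mirror the argparse flags this class
# reads from `args`, and it assumes the reference modules.py is importable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse
    import numpy as np

    args = argparse.Namespace(
        maxlen=50,          # padded sequence length
        hidden_units=64,    # embedding / model dimension
        num_blocks=2,       # self-attention blocks
        num_heads=1,        # attention heads per block
        dropout_rate=0.2,
        l2_emb=0.0,         # L2 penalty on the embedding tables
        lr=0.001,
    )
    usernum, itemnum, batch_size = 1000, 5000, 32

    model = Model(usernum, itemnum, args)

    # Synthetic batch: item ids live in [1, itemnum]; 0 is the padding id.
    rng = np.random.RandomState(0)
    feed = {
        model.u: rng.randint(1, usernum + 1, size=batch_size),
        model.input_seq: rng.randint(1, itemnum + 1, size=(batch_size, args.maxlen)),
        model.pos: rng.randint(1, itemnum + 1, size=(batch_size, args.maxlen)),
        model.neg: rng.randint(1, itemnum + 1, size=(batch_size, args.maxlen)),
        model.is_training: True,
    }

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        loss, _ = sess.run([model.loss, model.train_op], feed)
        print("one-step loss: %.4f" % loss)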