# main.py (forked from kang205/SASRec)
import argparse
import os
import sys
import time

import tensorflow as tf
from tqdm import tqdm

from sampler import WarpSampler
from model import Model
from util import *
def str2bool(s):
    """Convert the exact strings 'True' / 'False' to the matching bool.

    The match is case-sensitive and no other spelling is accepted, so the
    function can serve as a strict ``type=`` converter for argparse options.

    Args:
        s: the string to convert.

    Returns:
        True when ``s == 'True'``, False when ``s == 'False'``.

    Raises:
        ValueError: if ``s`` is neither 'True' nor 'False'.
    """
    if s not in {'False', 'True'}:
        raise ValueError('Not a valid boolean string')
    return s == 'True'
# Command-line interface: two required options, then the tunable settings,
# registered from a (flag, type, default) table in their original order.
parser = argparse.ArgumentParser()
for _required_flag in ('--dataset', '--train_dir'):
    parser.add_argument(_required_flag, required=True)
for _flag, _kind, _default in (
        ('--batch_size', int, 128),
        ('--lr', float, 0.001),
        ('--maxlen', int, 50),
        ('--hidden_units', int, 50),
        ('--num_blocks', int, 2),
        ('--num_epochs', int, 201),
        ('--num_heads', int, 1),
        ('--dropout_rate', float, 0.5),
        ('--l2_emb', float, 0.0),
):
    parser.add_argument(_flag, default=_default, type=_kind)
# Parsed once at start-up; the rest of the script reads settings from `args`.
args = parser.parse_args()
# The results directory is named "<dataset>_<train_dir>"; create it if it
# does not exist yet.
output_dir = args.dataset + '_' + args.train_dir
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

# Record the parsed arguments, one "name,value" pair per line sorted by
# name, so the run configuration is stored in the results directory.
with open(os.path.join(output_dir, 'args.txt'), 'w') as args_file:
    args_file.write('\n'.join(
        str(k) + ',' + str(v) for k, v in sorted(vars(args).items())))
# No explicit close() is needed: the with-statement closes the file.
# data_partition is presumably provided by the star import from util; its
# result is unpacked into the five values below and also passed whole to
# the evaluation helpers later on.
dataset = data_partition(args.dataset)
[user_train, user_valid, user_test, usernum, itemnum] = dataset

# Batches per epoch. Floor division keeps this an int on both Python 2 and
# Python 3 so that it can be handed to range() in the training loop.
num_batch = len(user_train) // args.batch_size

# Mean number of training entries per user, reported once at start-up.
cc = float(sum(len(user_train[u]) for u in user_train))
print('average sequence length: %.2f' % (cc / len(user_train)))
# Metrics log for this run. It stays open for the whole of training and is
# closed by the training section that follows.
f = open(os.path.join(args.dataset + '_' + args.train_dir, 'log.txt'), 'w')

# Session options: grow GPU memory on demand instead of reserving it all up
# front, and let TensorFlow place an op on another device when the requested
# one is unavailable.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
sess = tf.Session(config=config)

# NOTE(review): n_workers=3 presumably sets the number of background
# sampling workers -- confirm against sampler.WarpSampler.
sampler = WarpSampler(user_train, usernum, itemnum,
                      batch_size=args.batch_size, maxlen=args.maxlen, n_workers=3)
model = Model(usernum, itemnum, args)

# tf.initialize_all_variables() is deprecated; the supported TF 1.x
# replacement is tf.global_variables_initializer().
sess.run(tf.global_variables_initializer())
# T accumulates wall-clock training time in seconds. t0 is reset after every
# evaluation, so time spent evaluating is not counted.
T = 0.0
t0 = time.time()

try:
    for epoch in range(1, args.num_epochs + 1):
        # One epoch: num_batch optimisation steps on sampled batches.
        for step in tqdm(range(num_batch), total=num_batch, ncols=70, leave=False, unit='b'):
            u, seq, pos, neg = sampler.next_batch()
            auc, loss, _ = sess.run([model.auc, model.loss, model.train_op],
                                    {model.u: u, model.input_seq: seq, model.pos: pos, model.neg: neg,
                                     model.is_training: True})

        # Evaluate on every 20th epoch only.
        if epoch % 20 == 0:
            t1 = time.time() - t0
            T += t1
            # Progress marker without a trailing newline, written in a way
            # that works on both Python 2 and Python 3.
            sys.stdout.write('Evaluating ')
            sys.stdout.flush()
            # evaluate / evaluate_valid are presumably provided by the star
            # import from util. Index 0 of each result is reported as
            # NDCG@10 and index 1 as HR@10.
            t_test = evaluate(model, dataset, args, sess)
            t_valid = evaluate_valid(model, dataset, args, sess)
            print('')
            print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@10: %.4f, HR@10: %.4f)' % (
                epoch, T, t_valid[0], t_valid[1], t_test[0], t_test[1]))

            f.write(str(t_valid) + ' ' + str(t_test) + '\n')
            f.flush()
            t0 = time.time()
finally:
    # Always shut down the sampler and close the log, on success and on
    # failure alike. Any exception (including Ctrl-C) then propagates, so
    # the traceback is shown and the process exits with a non-zero status
    # instead of exiting silently.
    try:
        sampler.close()
    finally:
        f.close()

print("Done")