-
Notifications
You must be signed in to change notification settings - Fork 4
/
mask_generate.py
62 lines (48 loc) · 1.79 KB
/
mask_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import pickle
import os
import numpy as np
# Directory that holds the input data_all.pkl and receives the generated
# visited_and_mask_matrix_<max_length>.pkl.
filepath = './data/yelp'
# Number of user rows to allocate in the output matrices; user ids are used
# directly as row indices. NOTE(review): 0 is a placeholder here — confirm how
# the real user count is meant to be supplied.
num_user = 0
# Maximum number of visited items kept per user (the column count of both
# output matrices); also embedded in the output filename.
max_length = 30
def generate_visited_dict(data):
    """Group item ids by user id.

    Each row of ``data`` is read as ``(user_id, item_id, ...)``: column 0 is
    the user id and column 1 the item id. ``data`` must expose ``.shape`` and
    support 2-D ``[row, col]`` indexing (e.g. a NumPy array).

    Returns:
        dict mapping each user id to the list of its item ids, in row order
        (repeated items are kept).
    """
    visited = {}
    for row in range(data.shape[0]):
        user_id, item_id = data[row, 0], data[row, 1]
        visited.setdefault(user_id, []).append(item_id)
    return visited
def generate_pop_dict(data):
    """Count how many rows of ``data`` mention each item id.

    The item id is read from column 1 of every row; ``data`` must expose
    ``.shape`` and support 2-D ``[row, col]`` indexing (e.g. a NumPy array).

    Returns:
        dict mapping item id -> number of rows in which it appears.
    """
    counts = {}
    for row in range(data.shape[0]):
        item_id = data[row, 1]
        counts[item_id] = counts.get(item_id, 0) + 1
    return counts
def generate_visited_matrix(data_dict, pop_dict, num_user, max_length):
    """Build per-user matrices of the most popular visited items plus a padding mask.

    For each user ``uid`` in ``data_dict``, the user's visited items are ranked
    by ``pop_dict`` (highest count first; ties keep their original order since
    ``sorted`` is stable). The first ``max_length`` of them are written into
    row ``uid`` of ``visited_matrix``.

    Args:
        data_dict: mapping user id -> list of visited item ids. User ids are
            used directly as row indices, so each must be a valid row index
            for ``num_user`` rows.
        pop_dict: mapping item id -> popularity count. It must contain every
            item that appears in ``data_dict``, otherwise ``KeyError`` is raised.
        num_user: number of rows of the output matrices.
        max_length: number of columns (items kept per user).

    Returns:
        ``(visited_matrix, mask_matrix)``: two float arrays of shape
        ``(num_user, max_length)``.
        - ``visited_matrix`` holds item ids, with 0 in padded slots.
        - ``mask_matrix`` is 1 for padded slots and 0 for slots that hold a
          real item.
        - Rows of users that are absent from ``data_dict`` are entirely padding.
    """
    visited_matrix = np.zeros((num_user, max_length))
    # Every slot starts as padding (1) so users with no history are fully
    # masked. Previously such rows stayed 0, which made them look like real
    # visits of item 0. Slots that receive a real item are cleared below.
    mask_matrix = np.ones((num_user, max_length))
    for uid, items in data_dict.items():
        ranked = sorted(items, key=lambda item: pop_dict[item], reverse=True)
        top_items = ranked[:max_length]
        n_items = len(top_items)
        visited_matrix[uid, :n_items] = top_items
        mask_matrix[uid, :n_items] = 0
    return visited_matrix, mask_matrix
# Build the visited/mask matrices from data_all[0] and save them to
# visited_and_mask_matrix_<max_length>.pkl under `filepath`.
# NOTE(review): data_all[0] is presumably an (N, >=2) array of
# (user_id, item_id) rows, likely the training split — confirm with the
# producer of data_all.pkl.
# pickle.load can execute arbitrary code from the file: only load trusted data.
with open(os.path.join(filepath, 'data_all.pkl'), 'rb') as f:
    data_all = pickle.load(f)
data_dict = generate_visited_dict(data_all[0])
pop_dict = generate_pop_dict(data_all[0])
# User ids index matrix rows directly, so the row count must exceed the
# largest user id. A non-positive module-level `num_user` (it is 0 as shipped)
# would make every row write raise IndexError. In that case, infer the count
# from the data.
n_users = num_user if num_user > 0 else int(max(data_dict, default=-1)) + 1
visited_matrix, mask_matrix = generate_visited_matrix(data_dict, pop_dict, n_users, max_length)
visited_and_mask_matrix = (visited_matrix, mask_matrix)
# Use a context manager so the output file is flushed and closed deterministically.
with open(os.path.join(filepath, 'visited_and_mask_matrix_%d.pkl' % max_length), 'wb') as f:
    pickle.dump(visited_and_mask_matrix, f)