-
Notifications
You must be signed in to change notification settings - Fork 27
/
data_utils.py
78 lines (62 loc) · 2.09 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch.utils.data as data
import config
def load_all(test_num=100):
    """Load the training pairs, the test candidates and the interaction matrix.

    All files are read once up front so that no file I/O is needed per epoch.

    Args:
        test_num: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A 5-tuple ``(train_data, test_data, user_num, item_num, train_mat)``:

        * ``train_data`` -- list of ``[user, item]`` pairs taken from the
          first two tab-separated columns of ``config.train_rating``.
        * ``test_data`` -- list of ``[user, item]`` pairs built from
          ``config.test_negative``. For every line the first pair comes from
          the leading ``(user, item)`` field and one further pair is added
          for each remaining item id on that line.
        * ``user_num`` / ``item_num`` -- largest user / item id found in the
          training file, plus one.
        * ``train_mat`` -- ``scipy.sparse.dok_matrix`` of shape
          ``(user_num, item_num)`` holding ``1.0`` for every training pair.
    """
    # Function-scope import keeps this block self-contained.
    import ast

    train_data = pd.read_csv(
        config.train_rating,
        sep='\t', header=None, names=['user', 'item'],
        usecols=[0, 1], dtype={0: np.int32, 1: np.int32})

    # Ids are used directly as matrix indices, hence "max id + 1".
    user_num = train_data['user'].max() + 1
    item_num = train_data['item'].max() + 1

    train_data = train_data.values.tolist()

    # Load the observed (user, item) pairs into a DOK matrix.
    train_mat = sp.dok_matrix((user_num, item_num), dtype=np.float32)
    for user, item in train_data:
        train_mat[user, item] = 1.0

    test_data = []
    with open(config.test_negative, 'r') as fd:
        for line in fd:
            fields = line.strip().split('\t')
            if fields == ['']:
                # Blank line (e.g. a trailing newline at EOF): nothing to parse.
                continue
            # SECURITY: the first field used to be parsed with eval(), which
            # executes arbitrary code found in the data file. literal_eval
            # accepts only Python literals such as "(12, 34)".
            pair = ast.literal_eval(fields[0])
            u = pair[0]
            test_data.append([u, pair[1]])
            test_data.extend([u, int(item)] for item in fields[1:])
    return train_data, test_data, user_num, item_num, train_mat
class BPRData(data.Dataset):
    """Dataset yielding ``(user, item_i, item_j)`` triples.

    In training mode ``item_j`` is an item sampled by :meth:`ng_sample` such
    that ``(user, item_j)`` is not contained in ``train_mat``. Outside
    training mode no sampling happens and ``item_j`` simply repeats
    ``item_i``.

    Note that the third element is only useful when training; it is
    therefore produced in :meth:`ng_sample` rather than stored in
    ``features``.
    """

    def __init__(self, features,
                 num_item, train_mat=None, num_ng=0, is_training=None):
        """Store the raw pairs and the sampling configuration.

        Args:
            features: Sequence of ``[user, item]`` pairs.
            num_item: Exclusive upper bound for sampled item ids.
            train_mat: Container supporting ``(user, item) in train_mat``;
                only consulted by :meth:`ng_sample`.
            num_ng: Number of items sampled per pair in :meth:`ng_sample`.
            is_training: Truthy for training mode, falsy otherwise.
        """
        super().__init__()
        self.features = features
        self.num_item = num_item
        self.train_mat = train_mat
        self.num_ng = num_ng
        self.is_training = is_training
        # Filled by ng_sample(); None means "not sampled yet".
        self.features_fill = None

    def ng_sample(self):
        """Rebuild ``features_fill`` with ``num_ng`` sampled triples per pair.

        Each draw is repeated until the sampled item is not paired with the
        user in ``train_mat``; there is no retry limit, so a user paired with
        every item would keep this loop spinning.

        Raises:
            AssertionError: If the dataset is not in training mode.
        """
        # Raised explicitly (instead of `assert`) so the guard still runs
        # under `python -O`; exception type and message are unchanged.
        if not self.is_training:
            raise AssertionError('no need to sampling when testing')

        sampled = []
        for pair in self.features:
            u, i = pair[0], pair[1]
            for _ in range(self.num_ng):
                j = np.random.randint(self.num_item)
                while (u, j) in self.train_mat:
                    j = np.random.randint(self.num_item)
                sampled.append([u, i, j])
        self.features_fill = sampled

    def __len__(self):
        """Return ``num_ng * len(features)`` when training, else ``len(features)``."""
        if self.is_training:
            return self.num_ng * len(self.features)
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(user, item_i, item_j)`` triple at position ``idx``.

        Raises:
            RuntimeError: In training mode when :meth:`ng_sample` has not
                been called yet.
        """
        if not self.is_training:
            row = self.features[idx]
            # No sampled item outside training: repeat item_i as item_j.
            return row[0], row[1], row[1]

        if self.features_fill is None:
            raise RuntimeError(
                'ng_sample() must be called before indexing a training '
                'BPRData dataset')
        row = self.features_fill[idx]
        return row[0], row[1], row[2]