-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
72 lines (49 loc) · 1.68 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from abc import ABC
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from data_loader import load_data
class KGDataset(Dataset, ABC):
    """Map-style dataset over knowledge-graph rows.

    :meth:`collate_fn` reads the first three elements of each row as
    ``entity_id``, ``relation_id`` and ``object_id``. ``data`` must support
    ``len()`` and integer indexing, and each indexed row must provide
    ``.tolist()`` (e.g. a NumPy array row or a torch tensor row).
    """

    def __init__(self, data):
        self.data = data
        # Length is computed once here; later size changes to ``data`` are
        # not reflected by ``len(dataset)``.
        self.length = len(self.data)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        """Return row ``item`` of ``data`` converted to a plain Python list."""
        return self.data[item].tolist()

    @staticmethod
    def collate_fn(batch):
        """Turn a batch of rows into one int64 tensor per column.

        Args:
            batch: Sequence of rows; only the first three elements of each
                row are used.

        Returns:
            Dict with keys ``'entity_id'``, ``'relation_id'`` and
            ``'object_id'``, each mapping to a 1-D ``torch.int64`` tensor of
            length ``len(batch)``.
        """
        # Coerce every column with int(), as RatDataset.collate_fn does.
        # Previously only entity_id was coerced, so e.g. string values in the
        # relation/object columns made torch.tensor(..., dtype=int64) fail.
        entity_id = [int(b[0]) for b in batch]
        relation_id = [int(b[1]) for b in batch]
        object_id = [int(b[2]) for b in batch]
        return {
            'entity_id': torch.tensor(entity_id, dtype=torch.int64),
            'relation_id': torch.tensor(relation_id, dtype=torch.int64),
            'object_id': torch.tensor(object_id, dtype=torch.int64),
        }
class RatDataset(Dataset, ABC):
    """Map-style dataset over rating rows.

    :meth:`collate_fn` reads the first three elements of each row as
    ``user_id``, ``entity_id`` and ``score``. ``data`` must support ``len()``
    and integer indexing, and each indexed row must provide ``.tolist()``.
    """

    def __init__(self, data):
        self.data = data
        # Snapshot of the size at construction time.
        self.length = len(self.data)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        """Fetch one row and hand it back as a Python list."""
        row = self.data[item]
        return row.tolist()

    @staticmethod
    def collate_fn(batch):
        """Build int64 tensors from the first three columns of ``batch``.

        Returns:
            Dict mapping ``'user_id'``, ``'entity_id'`` and ``'score'`` to
            1-D ``torch.int64`` tensors, each holding ``int()`` of the
            corresponding column for every row.
        """
        # Columns are materialised in order 0, 1, 2 before any tensor is built.
        users, entities, scores = (
            [int(record[col]) for record in batch] for col in range(3)
        )
        as_long = torch.int64
        return {
            'user_id': torch.tensor(users, dtype=as_long),
            'entity_id': torch.tensor(entities, dtype=as_long),
            'score': torch.tensor(scores, dtype=as_long),
        }
def main():
    """Load the project data and wrap its knowledge-graph part in a KGDataset.

    Returns:
        The constructed :class:`KGDataset`. It is returned so the function is
        also usable from an interactive session; the script entry point below
        ignores it.
    """
    data = load_data()
    # Positional layout of the load_data() result as this script names it
    # (NOTE(review): confirm against data_loader.load_data):
    #   0-3: n_user, n_item, n_entity, n_relation
    #   4-6: train_data, eval_data, test_data
    #   7:   kg
    # Only the KG is used here, so the other entries are not bound to names.
    kg = data[7]
    dataset = KGDataset(kg)
    return dataset


if __name__ == '__main__':
    main()