-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathdataset.py
68 lines (60 loc) · 2.77 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from dgl.data import FraudYelpDataset, FraudAmazonDataset
from dgl.data.utils import load_graphs, save_graphs
import dgl
import numpy as np
import torch
class Dataset:
    """Load one of the supported fraud/anomaly-detection graphs.

    After construction ``self.name`` holds ``name`` and ``self.graph`` holds
    the loaded DGL graph, with ``ndata['label']`` cast to ``long`` (trailing
    singleton dimension squeezed) and ``ndata['feature']`` cast to ``float``.
    The graph is also printed.

    Args:
        name: ``'tfinance'`` or ``'tsocial'`` (read with ``load_graphs`` from
            ``data_dir``), or ``'yelp'`` / ``'amazon'`` (DGL built-in fraud
            datasets). Any other value prints ``'no such dataset'`` and exits
            with status 1.
        homo: ``'yelp'`` / ``'amazon'`` only. Convert to a homogeneous graph
            (keeping feature, label and mask node data) and add self-loops.
        anomaly_alpha: ``'tfinance'`` only. If truthy, randomly chosen normal
            nodes are relabelled as anomalies, each getting the feature row
            of a random original anomaly. Relabelling continues until about
            ``anomaly_alpha * num_nodes`` nodes are anomalous. Uses the
            global state of the ``random`` module.
        anomaly_std: ``'tfinance'`` only. If truthy, features are z-scored
            per column and anomalous rows are multiplied by ``anomaly_std``.
            Applied before ``anomaly_alpha`` when both are given.
        data_dir: Directory containing the serialized ``tfinance`` /
            ``tsocial`` graphs. Defaults to ``'dataset'``.
    """

    def __init__(self, name='tfinance', homo=True, anomaly_alpha=None, anomaly_std=None,
                 *, data_dir='dataset'):
        self.name = name
        if name == 'tfinance':
            graph = load_graphs(f'{data_dir}/tfinance')[0][0]
            # Raw tfinance labels are indexed as two columns: [:, 0] = normal,
            # [:, 1] = anomaly. Keep that form for the perturbations below and
            # store class ids (argmax over dim 1) on the graph.
            onehot = graph.ndata['label']
            feat = graph.ndata['feature']
            label = onehot.argmax(1)
            # Load once and apply both perturbations to the same data, so
            # anomaly_std and anomaly_alpha compose instead of one
            # overwriting the other.
            if anomaly_std:
                feat = self._scale_anomaly_features(feat, onehot, anomaly_std)
            if anomaly_alpha:
                feat, label = self._inject_anomalies(feat, onehot, anomaly_alpha)
            # Write both back explicitly, so the injected labels actually
            # reach the graph.
            graph.ndata['feature'] = feat
            graph.ndata['label'] = label
        elif name == 'tsocial':
            graph = load_graphs(f'{data_dir}/tsocial')[0][0]
        elif name in ('yelp', 'amazon'):
            dataset = FraudYelpDataset() if name == 'yelp' else FraudAmazonDataset()
            graph = dataset[0]
            if homo:
                graph = dgl.to_homogeneous(
                    graph, ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask'])
                graph = dgl.add_self_loop(graph)
        else:
            print('no such dataset')
            # Same effect as exit(1), without depending on the site-provided builtin.
            raise SystemExit(1)

        graph.ndata['label'] = graph.ndata['label'].long().squeeze(-1)
        graph.ndata['feature'] = graph.ndata['feature'].float()
        print(graph)
        self.graph = graph

    @staticmethod
    def _scale_anomaly_features(feat, onehot_label, anomaly_std):
        """Z-score ``feat`` per column, then multiply anomalous rows by ``anomaly_std``.

        Args:
            feat: 2-D CPU feature tensor (one row per node); left unmodified.
            onehot_label: Label tensor whose column 1 is non-zero for anomalies.
            anomaly_std: Multiplier applied to the anomalous rows.

        Returns:
            A new tensor with the scaled features.
        """
        x = feat.numpy()
        std = np.std(x, 0)
        # Constant columns have zero std; map them to 0 instead of NaN (0/0).
        std[std == 0] = 1
        x = (x - np.average(x, 0)) / std
        anomaly_id = onehot_label[:, 1].nonzero().squeeze(1).numpy()
        x[anomaly_id] = anomaly_std * x[anomaly_id]
        return torch.tensor(x)

    @staticmethod
    def _inject_anomalies(feat, onehot_label, anomaly_alpha):
        """Relabel random normal nodes as anomalies up to ~``anomaly_alpha`` of all nodes.

        Each newly relabelled node gets a copy of the feature row of a
        randomly chosen original anomaly. Inputs are not modified.

        Args:
            feat: 2-D feature tensor (one row per node).
            onehot_label: Label tensor; column 0 is non-zero for normal nodes,
                column 1 for anomalies.
            anomaly_alpha: Target fraction of anomalous nodes.

        Returns:
            ``(feat, label)``: a perturbed copy of the features and a 1-D
            tensor of class ids (argmax over dim 1, with injected nodes set to 1).
        """
        import random

        anomaly_id = onehot_label[:, 1].nonzero().squeeze(1).tolist()
        normal_id = onehot_label[:, 0].nonzero().squeeze(1).tolist()
        label = onehot_label.argmax(1)
        feat = feat.clone()
        # A target at or below the current anomaly count means nothing to
        # inject. Clamp rather than pass a negative size to random.sample.
        n_new = max(0, int(anomaly_alpha * len(label) - len(anomaly_id)))
        for idx in random.sample(normal_id, n_new):
            aid = random.choice(anomaly_id)
            feat[idx] = feat[aid]
            label[idx] = 1
        return feat, label