-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathMnistLoadTool.py
102 lines (90 loc) · 3.95 KB
/
MnistLoadTool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# coding: utf-8
"""
@author: Inki
@contact: inki.yinji@qq.com
@version: Created on 2021-05-11, last modified on 2021-05-12.
@source data: None required.
"""
import numpy as np
from FunctionTool import mnist_bag_loader, print_progress_bar
class MnistLoader:
    """Build a multi-instance learning (MIL) bag data set from MNIST.

    Positive bags mix instances of class ``po_label`` with instances of other
    classes; negative ("other") bags contain only instances whose label is not
    ``po_label``.  After construction the result is stored in
    ``self.bag_space``, an object array of shape ``(num_bags, 2)`` whose rows
    are ``[instances, bag_label]``:

    * ``instances`` is a 2-D array with one row per instance; its last column
      is the instance label (1 for ``po_label``, 0 otherwise);
    * ``bag_label`` is a ``(1, 1)`` array holding 1 (positive) or 0 (negative).

    Positive bags come first, followed by negative bags.
    """

    def __init__(self, po_label=0, bag_size=(10, 50), po_range=(2, 8), bag_num=(100, 100), seed=None,
                 mnist_path=None):
        """
        :param po_label: The label of the positive class; must be one of $0, 1, \\dots, 9$.
        :param bag_size: The size of each bag --> (min, max).
        :param po_range: The number of positive instances in a positive bag --> (min, max).
        :param bag_num: The number of positive and negative bags --> (num_positive, num_negative).
        :param seed: The seed used for sampling (applied to NumPy's global RNG).
            Note: for fair, reproducible experiments you should set this parameter.
        :param mnist_path: The path where the MNIST data set is saved.
        """
        self.po_label = po_label
        self.bag_size = bag_size
        self.po_range = po_range
        self.bag_num = bag_num
        self.seed = seed
        self.mnist_path = mnist_path
        self.__init_mnist_loader()

    def __init_mnist_loader(self):
        """Load both MNIST splits, index instances by class, and generate all bags."""
        print("Loading mnist...")
        self.bag_space = []
        # NOTE: this seeds NumPy's *global* RNG, so later np.random use by the
        # caller is affected as well.
        if self.seed is not None:
            np.random.seed(self.seed)
        train_data, train_label = self.__load_data(True)
        test_data, test_label = self.__load_data(False)
        self.data_space = np.array(train_data + test_data)
        self.label_space = np.array(train_label + test_label)
        self.po_idx = np.where(self.label_space == self.po_label)[0]
        self.ot_idx = np.where(self.label_space != self.po_label)[0]
        self.__generate_po_bag(self.bag_num[0])
        self.__generate_ot_bag(self.bag_num[1])
        # Fill the (num_bags, 2) object array cell by cell: letting np.array()
        # infer the layout of ragged bags raises ValueError on NumPy >= 1.24.
        packed = np.empty((len(self.bag_space), 2), dtype=object)
        for row, (instances, bag_label) in enumerate(self.bag_space):
            packed[row, 0] = instances
            packed[row, 1] = bag_label
        self.bag_space = packed

    def __load_data(self, train):
        """Read one MNIST split through ``mnist_bag_loader``.

        :param train: True for the training split, False for the test split.
        :return: ``(ret_data, ret_label)`` -- a list of flattened samples (each a
            Python list) and a list of int labels, in loader order.
        """
        flag = "train" if train else "test"
        print("Loading MNIST %s data..." % flag)
        data_loader = mnist_bag_loader(train, self.mnist_path)
        num_data = len(data_loader)
        ret_data, ret_label = [], []
        for i, (data, label) in enumerate(data_loader):
            print_progress_bar(i, num_data)
            # NOTE(review): assumes each batch holds a single sample (label[0]).
            ret_data.append(data.reshape(-1).numpy().tolist())
            ret_label.append(int(label.numpy()[0]))
        print()
        return ret_data, ret_label

    def _sample_instances(self, pool, count, label):
        """Draw ``count`` rows of ``self.data_space`` (with replacement) from ``pool``.

        :param pool: Array of candidate row indices into ``self.data_space``.
        :param count: Number of instances to draw (may be 0).
        :param label: Instance label appended as the last column.
        :return: Array of shape ``(count, num_features + 1)``.

        One ``np.random.choice`` call is made per instance, in draw order, so a
        fixed seed keeps producing the same bags.
        """
        picked = np.asarray([np.random.choice(pool) for _ in range(count)], dtype=int)
        rows = self.data_space[picked]
        return np.column_stack((rows, np.full(count, label)))

    @staticmethod
    def _pack_bag(instances, label):
        """Return a length-2 object array ``[instances, [[label]]]``.

        Built cell by cell: ``np.array([instances, np.array([[label]])])`` is a
        ragged construction that raises ValueError on modern NumPy (and
        mis-broadcasts single-row bags even with ``dtype=object``).
        """
        pair = np.empty(2, dtype=object)
        pair[0] = instances
        pair[1] = np.array([[label]])
        return pair

    def __generate_po_bag(self, bag_num):
        """Append ``bag_num`` positive bags to ``self.bag_space``.

        Each bag draws between ``po_range[0]`` and ``po_range[1]`` positive
        instances, then between ``bag_size[0] - po_range[0]`` and
        ``bag_size[1] - po_range[1]`` other instances, so its total size stays
        within ``bag_size``.  Positive instances precede the others.
        """
        print("Generating positive bag...")
        for i in range(bag_num):
            print_progress_bar(i, bag_num)
            num_po = np.random.randint(self.po_range[0], self.po_range[1] + 1)
            po_part = self._sample_instances(self.po_idx, num_po, 1)
            # NOTE: randint raises ValueError when
            # bag_size[1] - po_range[1] < bag_size[0] - po_range[0].
            num_ot = np.random.randint(self.bag_size[0] - self.po_range[0],
                                       self.bag_size[1] - self.po_range[1] + 1)
            ot_part = self._sample_instances(self.ot_idx, num_ot, 0)
            self.bag_space.append(self._pack_bag(np.vstack((po_part, ot_part)), 1))
        print()

    def __generate_ot_bag(self, bag_num):
        """Append ``bag_num`` negative bags (no ``po_label`` instances) to ``self.bag_space``.

        Each bag holds between ``bag_size[0]`` and ``bag_size[1]`` instances.
        """
        print("Generate other class bag...")
        for i in range(bag_num):
            print_progress_bar(i, bag_num)
            num_ot = np.random.randint(self.bag_size[0], self.bag_size[1] + 1)
            instances = self._sample_instances(self.ot_idx, num_ot, 0)
            self.bag_space.append(self._pack_bag(instances, 0))
        print()
if __name__ == "__main__":
    # Smoke run: build the default bag set with class 0 as the positive class.
    loader = MnistLoader(po_label=0, seed=1)