-
Notifications
You must be signed in to change notification settings - Fork 2
/
proxy_data.py
49 lines (38 loc) · 1.47 KB
/
proxy_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
from PIL import Image
import cv2
import os
import pandas as pd
class Proxy_Data():
    """Indexable, in-memory test set assembled from per-class sample groups.

    ``getTestData`` stacks the groups into ``TestData`` and builds a
    matching ``TestLabels`` array; ``obj[i]`` then returns the i-th sample
    as a PIL image (optionally passed through ``test_transform``) together
    with its label.
    """

    def __init__(self, test_transform=None):
        """Create an empty dataset.

        Args:
            test_transform: Optional callable applied to each PIL image
                returned by ``getTestItem``. Skipped when falsy.
        """
        super().__init__()
        self.test_transform = test_transform
        # Both stay plain empty lists until getTestData() fills them with
        # NumPy arrays.
        self.TestData = []
        self.TestLabels = []

    def concatenate(self, datas, labels):
        """Join per-group arrays along axis 0.

        Args:
            datas: Non-empty sequence of array-likes to stack along axis 0.
            labels: Non-empty sequence of label array-likes, one per group.

        Returns:
            Tuple ``(con_data, con_label)`` of NumPy arrays.

        Raises:
            ValueError: If either sequence is empty or the shapes of its
                items are incompatible (raised by ``np.concatenate``).
        """
        # A single call copies every element once; concatenating pairwise
        # in a loop would re-copy the growing result on each iteration.
        con_data = np.concatenate(datas, axis=0)
        con_label = np.concatenate(labels, axis=0)
        return con_data, con_label

    def getTestData(self, new_set, new_set_label):
        """Replace the stored test data with the given groups.

        Args:
            new_set: Sequence of sample groups; ``new_set[i]`` holds the
                samples that all share label ``new_set_label[i]``.
            new_set_label: Sequence of labels indexed in step with
                ``new_set``.

        When either argument is empty the dataset is reset to empty.
        """
        # len() rather than truthiness: the arguments may be NumPy arrays,
        # whose truth value is ambiguous.
        if len(new_set) == 0 or len(new_set_label) == 0:
            self.TestData, self.TestLabels = [], []
            return

        datas = list(new_set)
        # One label per sample: repeat each group's label len(group) times.
        labels = [
            np.full(len(group), new_set_label[i])
            for i, group in enumerate(datas)
        ]
        self.TestData, self.TestLabels = self.concatenate(datas, labels)

    def getTestItem(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The stored array is converted with ``Image.fromarray`` and, when a
        ``test_transform`` was supplied, passed through it.
        """
        img = Image.fromarray(self.TestData[index])
        target = self.TestLabels[index]
        if self.test_transform:
            img = self.test_transform(img)
        return img, target

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        Raises:
            IndexError: If no test data has been loaded.
        """
        # len() works for both the initial empty list and the ndarray set
        # by getTestData(); comparing an ndarray with ``[]`` is an
        # elementwise operation and must not be used as an emptiness test.
        if len(self.TestData) == 0:
            raise IndexError("Proxy_Data holds no test data")
        return self.getTestItem(index)

    def __len__(self):
        """Return the number of stored samples (0 when nothing is loaded)."""
        return len(self.TestData)