-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_voting.py
125 lines (104 loc) · 4.61 KB
/
eval_voting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import argparse
from PIL import Image
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from collections import Counter
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from utils.core50_data_loader import CORE50
from utils.toolkit import accuracy_binary, accuracy_domain, accuracy_core50
def setup_parser():
    """Build the command-line parser for the voting evaluation script.

    Options:
        --resume:   path of the saved model to load (default: empty string).
        --dataroot: root directory of the evaluation data.
        --datatype: dataset layout to read; restricted to the layouts the
                    script can actually load and score, so a typo fails
                    immediately instead of producing an empty test set.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    supported_datatypes = ('deepfake', 'domainnet', 'core50')
    parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorithms.')
    parser.add_argument('--resume', type=str, default='', help='resume model')
    parser.add_argument('--dataroot', type=str, default='/home/wangyabin/workspace/DeepFake_Data/CL_data/', help='data path')
    parser.add_argument('--datatype', type=str, default='core50', choices=supported_datatypes, help='data type')
    return parser
class DummyDataset(Dataset):
    """Evaluation split of the deepfake, DomainNet or CORe50 data.

    Each item is ``(idx, image_tensor, label)``. The image is loaded as RGB,
    resized to 256, center-cropped to 224 and normalized with the standard
    ImageNet mean/std.

    Args:
        data_path: root directory of the chosen dataset.
        data_type: one of ``"deepfake"``, ``"domainnet"`` or ``"core50"``.

    Raises:
        ValueError: if ``data_type`` is not one of the supported values.
    """

    def __init__(self, data_path, data_type):
        if data_type == "deepfake":
            images, labels = self._load_deepfake(data_path)
        elif data_type == "domainnet":
            # Exposed as instance attributes, as before.
            self.data_root = data_path
            self.image_list_root = self.data_root
            self.domain_names = ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"]
            images, labels = self._load_domainnet(self.image_list_root, self.data_root, self.domain_names)
        elif data_type == "core50":
            self.dataset_generator = CORE50(root=data_path, scenario="ni")
            images, labels = self.dataset_generator.get_test_set()
            # get_test_set() labels expose .tolist() (presumably a numpy array) — convert to a list.
            labels = labels.tolist()
        else:
            # Previously fell through silently and produced an empty dataset.
            raise ValueError(
                f"Unsupported data_type {data_type!r}; expected 'deepfake', 'domainnet' or 'core50'."
            )

        self.trsf = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        assert len(images) == len(labels), 'Data size error!'
        self.images = images
        self.labels = labels

    @staticmethod
    def _load_deepfake(data_path):
        """Collect (paths, labels) from each deepfake subset's ``val`` split.

        Subset ``i`` contributes label ``2*i`` for images under ``0_real`` and
        ``2*i + 1`` for images under ``1_fake``.
        """
        subsets = ["gaugan", "biggan", "wild", "whichfaceisreal", "san"]
        # Per-subset flag: truthy means the val split is further divided into class folders.
        multiclass = [0, 0, 0, 0, 0]
        images, labels = [], []
        for subset_idx, (name, is_multiclass) in enumerate(zip(subsets, multiclass)):
            root_ = os.path.join(data_path, name, 'val')
            # '' makes os.path.join skip the class level entirely.
            sub_classes = sorted(os.listdir(root_)) if is_multiclass else ['']
            for cls in sub_classes:
                for offset, folder in enumerate(('0_real', '1_fake')):
                    folder_path = os.path.join(root_, cls, folder)
                    # Sorted so the dataset order is reproducible across filesystems.
                    for imgname in sorted(os.listdir(folder_path)):
                        images.append(os.path.join(folder_path, imgname))
                        labels.append(offset + 2 * subset_idx)
        return images, labels

    @staticmethod
    def _load_domainnet(image_list_root, data_root, domain_names, num_classes=345):
        """Read ``<domain>_test.txt`` image lists for every domain.

        Each non-blank line is ``<relative path> <label> ...``. Labels are
        shifted by ``num_classes`` per domain index so every domain occupies a
        disjoint label range.
        """
        images, labels = [], []
        for taskid, domain in enumerate(domain_names):
            list_path = os.path.join(image_list_root, f"{domain}_test.txt")
            # Context manager closes the list file (the original leaked the handle).
            with open(list_path) as f:
                for line in f:
                    fields = line.split()
                    if not fields:  # tolerate blank / trailing lines
                        continue
                    images.append(os.path.join(data_root, fields[0]))
                    labels.append(int(fields[1]) + taskid * num_classes)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.trsf(self.pil_loader(self.images[idx]))
        label = self.labels[idx]
        return idx, image, label

    def pil_loader(self, path):
        """Open the image at ``path`` and return it converted to RGB."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')
# Evaluation entry point: load a saved model, predict the test set once per
# task key, and majority-vote the per-key predictions.
args = setup_parser().parse_args()

# Pick the scoring function up front so an unsupported datatype fails before inference.
accuracy_fns = {
    'deepfake': accuracy_binary,
    'domainnet': accuracy_domain,
    'core50': accuracy_core50,
}
if args.datatype not in accuracy_fns:
    raise ValueError(f"Unsupported datatype {args.datatype!r}; expected one of {sorted(accuracy_fns)}.")
accuracy_fn = accuracy_fns[args.datatype]

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
# On torch>=2.6, loading a full pickled model may also require weights_only=False.
model = torch.load(args.resume, map_location=device)
model = model.to(device)
# Put dropout / batch-norm layers in inference mode; torch.no_grad() alone does not.
# Assumes the loaded object is an nn.Module — TODO confirm for all checkpoints.
model.eval()

test_dataset = DummyDataset(args.dataroot, args.datatype)
# y_pred and y_true stay paired batch by batch, so shuffling gains nothing for evaluation.
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=8)

num_keys = len(model.all_keys)
y_pred, y_true = [], []
with torch.no_grad():
    for _idx, inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # One prediction per task key: every sample is routed through key `key_idx`.
        preds = []
        for key_idx in range(num_keys):
            selection = torch.full_like(targets, key_idx)
            outputs = model.interface(inputs, selection)
            preds.append(outputs.max(1)[1])
        # (num_keys, batch) -> (batch, num_keys); the mode over keys is the voted class.
        votes = torch.stack(preds).T
        predicts = torch.mode(votes, dim=1, keepdim=False)[0]
        y_pred.append(predicts.cpu().numpy())
        y_true.append(targets.cpu().numpy())

if not y_pred:
    raise RuntimeError(f"No test samples were found under {args.dataroot!r}.")
y_pred = np.concatenate(y_pred)
y_true = np.concatenate(y_true)

print(accuracy_fn(y_pred, y_true))