-
Notifications
You must be signed in to change notification settings - Fork 5
/
search_alpha.py
128 lines (106 loc) · 4.76 KB
/
search_alpha.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
'''
# @ Author: Yichao Cai
# @ Create Time: 2024-02-07 13:40:23
# @ Description:
'''
import os
import numpy as np
import argparse
import torch
import clip
import os.path as osp
import torch.utils.data as tud
from main.networks import DisentangledNetwork
from main.evaluate import eval_zero_shot
from utils.misc import Args, set_manual_seed, load_property
from utils.data_utils import MultiEnvDataset
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
# ---- Command-line interface ----
parser = argparse.ArgumentParser()
parser.add_argument("config", help="The path of config file for evaluation.")
args = parser.parse_args()

# Load evaluation configs and pick the compute device.
configs = Args(args.config)
configs.set_device("cuda" if torch.cuda.is_available() else "cpu")

# Manual seed for reproducibility (skipped when the config leaves it unset/falsy).
if configs.manual_seed:
    set_manual_seed(configs.manual_seed)

# Class-name lists per dataset, loaded from every YAML file under
# data/classes/ (e.g. "PACS.yaml" -> classes["PACS"]).
classes = {}
class_names_path = r"data/classes/"
for class_file in os.listdir(class_names_path):
    stem, ext = osp.splitext(class_file)
    # Filter by extension instead of assuming a fixed ".yaml" suffix length:
    # a stray file (README, .DS_Store) or a ".yml" name would otherwise
    # corrupt the dataset key or crash load_property.
    if ext.lower() in (".yaml", ".yml"):
        classes[stem] = load_property(osp.join(class_names_path, class_file))

# Frozen CLIP backbone; its text-projection width is the input dimension of
# the disentangling network built later.
clip_model, preprocess = clip.load(configs.clip_name, device=configs.device)
clip_dim = clip_model.text_projection.shape[1]  # CLIP's representation dimension
def one_eval_get_zspc(scale, dataset="PACS"):
    """Evaluate zero-shot accuracies for one residual `scale` value.

    Builds a DisentangledNetwork with the given `scale`, loads the first
    ``.pth`` checkpoint found under ``configs.ckpt_path``, and averages the
    zero-shot accuracies over every evaluation environment of `dataset`.

    Args:
        scale: scale hyper-parameter forwarded to DisentangledNetwork.
        dataset: name of the dataset to evaluate (key into `classes` and
            `configs.eval_sets`); defaults to "PACS", which was hard-coded
            in the original implementation.

    Returns:
        Tuple ``(mean_acc_c, mean_acc_cp, mean_acc_pc)`` of accuracies
        averaged across the dataset's evaluation environments.

    Raises:
        FileNotFoundError: if no ``.pth`` checkpoint exists in
            ``configs.ckpt_path``.
    """
    # NOTE: no `global` statements needed — module-level names are readable
    # without them; only rebinding would require `global`.
    network = DisentangledNetwork(in_dim=clip_dim, latent_dim=configs.latent_dim,
                                  out_dim=configs.out_dim,
                                  activation=configs.activation,
                                  which_network=configs.which_network,
                                  repeat=configs.repeat, scale=scale)
    network = network.to(configs.device)

    # Use the first checkpoint in the directory; fail loudly when none exists
    # instead of crashing later with an opaque osp.join(..., None) TypeError.
    ckpt_name = next((f for f in os.listdir(configs.ckpt_path)
                      if f.endswith(".pth")), None)
    if ckpt_name is None:
        raise FileNotFoundError(f"No .pth checkpoint found in {configs.ckpt_path}")

    state_dict = torch.load(osp.join(configs.ckpt_path, ckpt_name),
                            map_location=configs.device)
    network.load_state_dict(state_dict=state_dict, strict=True)
    network.eval()

    cls_nums = len(classes[dataset])
    mean_acc_c = []
    mean_acc_pc = []
    mean_acc_cp = []
    for env in configs.eval_sets[dataset]:
        evalset = MultiEnvDataset(osp.join("data/datasets", dataset),
                                  test_env=env, transform=preprocess)
        # shuffle=False: sample order cannot affect an accuracy aggregate,
        # so shuffling an evaluation loader is pure overhead.
        loader = tud.DataLoader(evalset, batch_size=128, shuffle=False, num_workers=24)
        _, acc_zs = eval_zero_shot(clip_model, network, None, loader, evalset.prompts,
                                   device=configs.device, adver=None, cls_nums=cls_nums,
                                   eval_clip=False)
        mean_acc_c.append(acc_zs[0])
        mean_acc_cp.append(acc_zs[1])
        mean_acc_pc.append(acc_zs[2])
    return np.mean(mean_acc_c), np.mean(mean_acc_cp), np.mean(mean_acc_pc)
def binary_search(scale_left, scale_right, next_infer="all", acc_optimals=None):
    """Perform one step of the interval search over the residual scale.

    Evaluates whichever endpoint(s) `next_infer` marks as stale ("all",
    "left", or "right"); the other endpoint reuses the cached accuracies in
    `acc_optimals`. The endpoint with the higher acc_pc is kept as the
    current optimum and the losing endpoint is moved to the midpoint.

    Returns:
        Updated ``(scale_left, scale_right, next_infer, acc_optimals)`` for
        the next step, where `acc_optimals` is ``(acc_c, acc_cp, acc_pc)``
        of the winning endpoint.
    """
    if next_infer == "all":
        accs_left = one_eval_get_zspc(scale_left)
        accs_right = one_eval_get_zspc(scale_right)
    elif next_infer == "left":
        accs_left = one_eval_get_zspc(scale_left)
        accs_right = acc_optimals
    else:
        accs_left = acc_optimals
        accs_right = one_eval_get_zspc(scale_right)

    midpoint = 0.5 * (scale_right + scale_left)
    # Selection metric is acc_pc (index 2), matching the grid search below.
    if accs_left[2] < accs_right[2]:
        scale_optimal = scale_right
        acc_optimals = (accs_right[0], accs_right[1], accs_right[2])
        scale_left, next_infer = midpoint, "left"
    else:
        scale_optimal = scale_left
        acc_optimals = (accs_left[0], accs_left[1], accs_left[2])
        scale_right, next_infer = midpoint, "right"

    print(f"\n scale: {scale_optimal:.3f}, acc_c: {acc_optimals[0]:.2f}, acc_cp: {acc_optimals[1]:.2f}, acc_pc: {acc_optimals[2]:.2f}")
    print(f"next scale left: {scale_left:.3f}, next scale right: {scale_right:.3f}\n")
    return scale_left, scale_right, next_infer, acc_optimals
# ---- Coarse grid search over the residual scale (log10-spaced) ----
# osp.join() with a single argument is a no-op; print the path directly.
print(configs.ckpt_path)
# NOTE(review): the grid skips 10**0.5 (gap between 1 and 10) — presumably an
# oversight; kept as-is to preserve the evaluated points.
scale_list = [10.0 ** p for p in (-1.5, -1.0, -0.5, 0.0, 1.0, 1.5, 2.0)]
acc_list = []
for scale in scale_list:
    acc_c, acc_cp, acc_pc = one_eval_get_zspc(scale)
    print(f"\n============\n scale: {scale:.4f}, acc_c: {acc_c:.2f}, acc_cp: {acc_cp:.2f}, acc_pc: {acc_pc:.2f}")
    # Selection metric is acc_pc, consistent with binary_search above.
    acc_list.append(acc_pc)
index = int(np.argmax(acc_list))
print(index)
# A bare index is hard to use; also report the winning scale and its accuracy.
print(f"best scale: {scale_list[index]:.4f}, acc_pc: {acc_list[index]:.2f}")