-
Notifications
You must be signed in to change notification settings - Fork 0
/
tools.py
95 lines (80 loc) · 2.71 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import pickle as p
from PIL import Image
import torch
import argparse
from config.defaults import update_config,_C as cfg
from trainer.build_trainer import build_trainer
import random
import os
import torch.backends.cudnn as cudnn
def load_CIFAR_batch(filename):
    """Load a single pickled CIFAR batch file.

    Args:
        filename: Path to a batch file (e.g. ``data_batch_1``).

    Returns:
        ``(X, Y)`` where ``X`` is the ``b'data'`` array reshaped to
        ``(num_images, 3, 32, 32)`` and ``Y`` is ``b'labels'`` converted to
        a NumPy array.

    NOTE: ``pickle`` can execute arbitrary code while loading -- only pass
    batch files that come from a trusted source.
    """
    with open(filename, 'rb') as f:
        # The batch is a pickled dict whose keys are bytes, hence
        # encoding='bytes' and the b'...' lookups below.
        datadict = p.load(f, encoding='bytes')
    X = datadict[b'data']
    Y = datadict[b'labels']
    # -1 infers the number of images instead of hard-coding 10000, so
    # batches of any size load correctly.
    X = X.reshape(-1, 3, 32, 32)
    Y = np.array(Y)
    print(Y.shape)
    return X, Y
def save_img(
    batch_path="/home/aa/xlhuang/Openset-LT-SSL/data/cifar10/cifar-10-batches-py/data_batch_1",
    out_dir="./data",
):
    """Export every image of one CIFAR batch as a PNG, grouped by label.

    Each image is written to ``<out_dir>/<label>/img<index>.png``.

    Args:
        batch_path: Batch file passed to ``load_CIFAR_batch``. Defaults to
            the path that was previously hard-coded here.
        out_dir: Root output directory; one sub-directory per label is
            created beneath it (``out_dir`` itself is created if missing).
    """
    imgX, imgY = load_CIFAR_batch(batch_path)
    for i, (chw, label) in enumerate(zip(imgX, imgY)):
        # chw has shape (3, 32, 32): build one single-band image per
        # leading-axis slice and merge the three bands into an RGB image.
        img = Image.merge("RGB", [Image.fromarray(channel) for channel in chw])
        label_dir = os.path.join(out_dir, str(label))
        # makedirs also creates missing parents (os.mkdir did not), and
        # exist_ok avoids the race of a separate exists() check.
        os.makedirs(label_dir, exist_ok=True)
        img.save(os.path.join(label_dir, "img" + str(i) + ".png"), "png")
    print("save successfully!")
def parse_args():
    """Parse the command line for a training run.

    Returns:
        argparse.Namespace with two attributes:
          * ``cfg``  -- path of the config file to use
            (default ``cfg/baseline_cifar10.yaml``);
          * ``opts`` -- every remaining command-line token, collected
            verbatim as config-option overrides.
    """
    cli = argparse.ArgumentParser(description="codes for BBN")
    cli.add_argument("--cfg", type=str, default="cfg/baseline_cifar10.yaml",
                     required=False, help="decide which cfg to use")
    # REMAINDER: everything not consumed by --cfg is gathered into opts.
    cli.add_argument("opts", nargs=argparse.REMAINDER, default=None,
                     help="modify config options using the command-line")
    return cli.parse_args()
# Fixed seed for reproducible runs. To draw a fresh seed per run instead:
#     seed = random.randint(1, 1000)
seed = 7
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
# does not change hash randomization of this process; it only affects child
# processes that start a new interpreter with this environment.
os.environ['PYTHONHASHSEED'] = str(seed)
# Seed Python's own RNG too -- `random` is imported but was never seeded,
# leaving any code that draws from it non-reproducible.
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
# Disable cuDNN autotuning and request deterministic kernels for repeatability.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(seed)

args = parse_args()
update_config(cfg, args)
IF = cfg.DATASET.IFS  # imbalance factors to sweep, e.g. 10, 50, 100
# OOD ratios to sweep, e.g. 0.0, 0.25, 0.5, 0.75, 1.0. A baseline that does
# not use mixup need not vary r (RandomSampler + ClassReversedSampler use no
# mixup).
ood_r = cfg.DATASET.OODRS
for if_ in IF:
    # Same distribution: labeled and unlabeled in-distribution data share
    # the same imbalance factor.
    for r in ood_r:
        # Presumably a yacs-style unlock -> override -> lock of the config.
        cfg.defrost()
        cfg.DATASET.DL.IMB_FACTOR_L = if_
        cfg.DATASET.DU.ID.IMB_FACTOR_UL = if_
        cfg.SEED = seed
        cfg.DATASET.DU.OOD.RATIO = r
        print("*************{} IF {} R {} begin *************".format(cfg.DATASET.NAME, if_, r))
        cfg.freeze()
        trainer = build_trainer(cfg)
        trainer.train()
        print("*************{} IF {} R {} end *************".format(cfg.DATASET.NAME, if_, r))
        # Drop the finished trainer before the next one is built; otherwise
        # both stay alive during the next build_trainer() call (the old one
        # is only released when `trainer` is rebound), presumably doubling
        # peak (GPU) memory.
        del trainer