-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathschedules.py
85 lines (81 loc) · 2.98 KB
/
schedules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def schedule_MNIST(args):
    """Build the default training schedule for MNIST.

    Reads ``args.batch_size``, ``args.method`` and ``args.acet`` and returns
    a new dict of schedule settings.  The ``out_start_epoch`` and
    ``eps_start_epoch`` entries default to 10000 and are only lowered for
    the methods / flags handled below.
    """
    schedule = {
        'epochs': 420,
        'optimizer': 'SGDM',
        # The base learning rate is divided by the batch size.
        'start_lr': 0.005/args.batch_size,
        'lr_decay_epochs': [50, 100, 200, 300, 350],
        'lr_decay_factor': 0.2,
        'kappa': 0.3,
        'out_start_epoch': 10000,
        'kappa_epoch_ramp': 0,
        'eps': 0.3,
        'eps_start_epoch': 10000,
        'eps_epoch_ramp': 0,
    }
    method = args.method

    # method == 'plain' keeps every default above unchanged.
    if method in {'OE', 'CEDA', 'GOOD'}:
        schedule.update(out_start_epoch=2, kappa_epoch_ramp=100)
    if method == 'GOOD':
        schedule.update(eps_start_epoch=10, eps_epoch_ramp=120)
    if args.acet:  # might need adjustment if not run with --method plain
        # Applied last, so these values win over the method-specific ones.
        schedule.update(
            acet_n=40,
            out_start_epoch=2,
            kappa_epoch_ramp=0,
            eps_start_epoch=6,
            eps_epoch_ramp=0,
        )
    return schedule
def schedule_CIFAR10(args):
    """Build the default training schedule for CIFAR10.

    Reads ``args.batch_size``, ``args.method`` and ``args.acet``;
    ``args.pretrained`` is only read when ``args.method`` is ``'GOOD'``.

    Returns:
        A new dict of schedule settings.

    Raises:
        AssertionError: if ``args.method`` is ``'GOOD'`` and
            ``args.pretrained`` is falsy.
    """
    schedule = {}
    schedule['epochs'] = 420
    schedule['optimizer'] = 'ADAM'
    # The base learning rate is divided by the batch size.
    schedule['start_lr'] = 0.1/args.batch_size
    schedule['lr_decay_epochs'] = [30, 100]
    schedule['lr_decay_factor'] = 0.2
    schedule['kappa'] = 1.0
    schedule['out_start_epoch'] = 10000
    schedule['kappa_epoch_ramp'] = 0
    schedule['eps'] = 0.01
    schedule['eps_start_epoch'] = 10000
    schedule['eps_epoch_ramp'] = 0

    # args.method == 'plain' keeps every default above unchanged.
    if args.method in {'OE', 'CEDA'}:
        schedule['out_start_epoch'] = 60
        schedule['kappa_epoch_ramp'] = 300
    if args.method in {'GOOD'}:
        # Explicit raise instead of a bare `assert`, so the check is not
        # stripped when Python runs with -O.  The exception type is kept.
        if not args.pretrained:
            raise AssertionError(
                "schedule_CIFAR10: method 'GOOD' requires args.pretrained"
            )
        schedule['epochs'] = 900
        schedule['start_lr'] = 1.28e-2/args.batch_size
        schedule['lr_decay_epochs'] = [450, 750, 850]
        schedule['out_start_epoch'] = -2
        schedule['kappa_epoch_ramp'] = 300
        schedule['eps_start_epoch'] = 4
        schedule['eps_epoch_ramp'] = 200
    if args.acet:  # might need adjustment if not run with --method plain
        # Applied last, so these values win over the method-specific ones.
        schedule['acet_n'] = 40
        schedule['out_start_epoch'] = 2
        schedule['kappa_epoch_ramp'] = 0
        schedule['eps_start_epoch'] = 6
        schedule['eps_epoch_ramp'] = 0
    return schedule
def default_schedule(args):
    """Return the default schedule for the dataset named by ``args.dset_in_name``.

    ``'MNIST'`` and ``'CIFAR10'`` dispatch to their own schedule builders.
    ``'SVHN'`` starts from the CIFAR10 schedule and overrides ``eps`` (and
    ``start_lr`` when ``args.pretrained`` is falsy).

    Raises:
        ValueError: if ``args.dset_in_name`` is not one of the datasets
            above.  Previously this case fell through and returned ``None``,
            which only failed later with an unrelated error.
    """
    dataset = args.dset_in_name
    if dataset == 'MNIST':
        return schedule_MNIST(args)
    if dataset == 'CIFAR10':
        return schedule_CIFAR10(args)
    if dataset == 'SVHN':
        schedule = schedule_CIFAR10(args)
        if not args.pretrained:
            schedule['start_lr'] = 0.01/args.batch_size
        # NOTE(review): the eps override is applied whether or not
        # args.pretrained is set -- confirm this nesting is intended.
        schedule['eps'] = 0.03
        return schedule
    raise ValueError(
        "no default schedule for dset_in_name=%r; "
        "expected 'MNIST', 'CIFAR10' or 'SVHN'" % (dataset,)
    )
def customized_schedule(schedule, args):
    """Overwrite entries of ``schedule`` with the non-``None`` values on ``args``.

    For each of ``start_lr``, ``optimizer``, ``epochs``, ``kappa`` and
    ``eps``, the attribute of the same name is read from ``args``; a value
    other than ``None`` replaces the corresponding schedule entry.  The dict
    is modified in place and also returned.
    """
    for key in ('start_lr', 'optimizer', 'epochs', 'kappa', 'eps'):
        override = getattr(args, key)
        if override is None:
            # Nothing requested for this setting: keep the existing entry.
            continue
        schedule[key] = override
    return schedule
def schedule(args):
    """Return the dataset's default schedule with the overrides from ``args`` applied."""
    defaults = default_schedule(args)
    return customized_schedule(defaults, args)