torch_solver.py
import os
import sys

from benchopt import BaseSolver, safe_import_context
from benchopt.stopping_criterion import SufficientProgressCriterion
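
# Import errors raised inside `safe_import_context` are caught by benchopt,
# which then reports the solver as unavailable instead of crashing at import.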
with safe_import_context() as import_ctx:
    import joblib
    import torch
    from torchvision import transforms
    from tqdm import tqdm

    AugmentedDataset = import_ctx.import_from(
        'lightning_helper', 'AugmentedDataset'
    )


class TorchSolver(BaseSolver):
    """PyTorch base solver."""

    stopping_criterion = SufficientProgressCriterion(
        patience=60, strategy='callback'
    )
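
    # With strategy='callback', benchopt hands `run` a callback that monitors
    # progress; the solver stops after 60 consecutive evaluations without
    # sufficient improvement of the objective.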

    parameters = {
        'batch_size': [128],
        'data_aug': [False, True],
        'lr_schedule': [None, 'step', 'cosine'],
        'steps': [[1/2, 3/4]],
        'gamma': [0.1],
    }
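
    # benchopt runs the solver once for each combination of these values and
    # exposes each one as an attribute (e.g. `self.batch_size`).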

    def skip(
        self,
        model_init_fn,
        dataset,
        normalization,
        framework,
        symmetry,
        image_width,
    ):
        if framework != 'pytorch':
            return True, 'Not a torch dataset/objective'
        coupled_wd = getattr(self, 'coupled_weight_decay', 0.0)
        decoupled_wd = getattr(self, 'decoupled_weight_decay', 0.0)
        if coupled_wd and decoupled_wd:
            return True, 'Cannot use both decoupled and coupled weight decay'
        return False, None
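
    # Called by benchopt with the objective's outputs, once per parameter
    # combination, before `run`.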
    def set_objective(
        self,
        model_init_fn,
        dataset,
        normalization,
        framework,
        symmetry,
        image_width,
    ):
        self.dataset = dataset
        self.model_init_fn = model_init_fn
        self.normalization = normalization
        self.framework = framework
        self.symmetry = symmetry
        self.image_width = image_width
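
        # Standard augmentation pipeline: pad-and-crop, plus a horizontal
        # flip when the dataset declares horizontal symmetry.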
        if self.data_aug:
            data_aug_list = [
                transforms.RandomCrop(self.image_width, padding=4),
            ]
            if self.symmetry is not None and 'horizontal' in self.symmetry:
                data_aug_list.append(transforms.RandomHorizontalFlip())
            data_aug_transform = transforms.Compose(data_aug_list)
        else:
            data_aug_transform = None
        self.dataset = AugmentedDataset(
            self.dataset,
            data_aug_transform,
            self.normalization,
        )

        # TODO: num_workers should not be hard-coded. Finding a sensible way
        # to set this value is necessary here.
        system = os.environ.get('RUNNER_OS', sys.platform)
        is_mac = system in ['darwin', 'macOS']
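        # RUNNER_OS is set on GitHub Actions ('macOS' there vs. 'darwin' from
        # sys.platform); workers are disabled on macOS, presumably because
        # multiprocessing data loading is unreliable there.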
        num_workers = min(10, joblib.cpu_count()) if not is_mac else 0
        persistent_workers = num_workers > 0
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset, batch_size=self.batch_size,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            pin_memory=True, shuffle=True,
        )
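
    # `optimizer_klass` and `optimizer_kwargs` are not defined in this class;
    # concrete subclasses are expected to provide them, one per optimizer.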

    def set_lr_schedule_and_optimizer(self, model, max_epochs=200):
        optimizer = self.optimizer_klass(
            model.parameters(),
            **self.optimizer_kwargs,
        )
        if self.lr_schedule == 'step':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[int(max_epochs*s) for s in self.steps],
                gamma=self.gamma,
            )
        elif self.lr_schedule == 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=max_epochs,
            )
        else:
            # No schedule: a stand-in whose `step` does nothing, so `run`
            # can call `scheduler.step()` unconditionally.
            class NoOpScheduler:
                def step(self):
                    ...

            scheduler = NoOpScheduler()
        return optimizer, scheduler
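
    # A minimal concrete subclass (hypothetical sketch, not part of this
    # file) would only need to pick the optimizer:
    #
    #     class SGDSolver(TorchSolver):
    #         name = 'SGD-torch'
    #         optimizer_klass = torch.optim.SGD
    #         optimizer_kwargs = dict(lr=0.1, momentum=0.9)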

    @staticmethod
    def get_next(stop_val):
        return stop_val + 1
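
    # `get_next` makes the stopping criterion request an evaluation at every
    # increment of `stop_val`, i.e. after each epoch here.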

    def run(self, callback):
        # model weight initialization
        model = self.model_init_fn()
        # Match the batch placement below; a no-op if `model_init_fn`
        # already put the model on GPU.
        if torch.cuda.is_available():
            model = model.cuda()
        criterion = torch.nn.CrossEntropyLoss()

        # optimizer and lr schedule init
        max_epochs = callback.stopping_criterion.max_runs
        optimizer, lr_schedule = self.set_lr_schedule_and_optimizer(
            model,
            max_epochs,
        )

        # The first `callback` call is the initial evaluation; each later
        # call re-evaluates the model at the end of an epoch and decides
        # whether to continue.
        while callback(model):
            for X, y in tqdm(self.dataloader):
                if torch.cuda.is_available():
                    X, y = X.cuda(), y.cuda()
                optimizer.zero_grad()
                loss = criterion(model(X), y)
                loss.backward()
                optimizer.step()
            # The schedulers are epoch-based (milestones and T_max are in
            # epochs), so step once per epoch, not per batch.
            lr_schedule.step()

        self.model = model

    def get_result(self):
        return self.model