@@ -16,6 +16,8 @@
 import torch.optim as optim
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
+import torchcontrib
+
 import ignite
 from ignite.engine import Events, Engine, create_supervised_evaluator
 from ignite.metrics import Accuracy, Loss, RunningAverage
@@ -53,7 +55,11 @@ def run(output_path, config):
                           momentum=config['momentum'],
                           weight_decay=config['weight_decay'],
                           nesterov=True)
-
+
+    with_SWA = config['with_SWA']
+    if with_SWA:
+        optimizer = torchcontrib.optim.SWA(optimizer)
+
     criterion = nn.CrossEntropyLoss().to(device)
     if config['consistency_criterion'] == "MSE":
         consistency_criterion = nn.MSELoss()
@@ -177,6 +183,17 @@ def log_learning_rate(engine):
         lr = optimizer.param_groups[0]['lr']
         mlflow.log_metric("learning rate", lr, step=step)
 
+    if with_SWA:
+        @trainer.on(Events.COMPLETED)
+        def swap_swa_sgd(engine):
+            optimizer.swap_swa_sgd()
+            optimizer.bn_update(train_labelled_loader, model)
+
+        @trainer.on(Events.EPOCH_COMPLETED)
+        def update_swa(engine):
+            if engine.state.epoch - 1 > int(num_epochs * 0.75):
+                optimizer.update_swa()
+
     metric_names = [
         'supervised batch loss',
         'consistency batch loss',
@@ -301,6 +318,8 @@ def mlflow_val_metrics_logging(engine, tag):
         "TSA_proba_max": 1.0,
 
         "no_UDA": False,  # disable UDA training
+
+        "with_SWA": False,
     }
 
     # Override config:
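
The new handlers use torchcontrib's SWA wrapper in manual mode: update_swa() accumulates a running average of the weights after each of the last ~25% of epochs, and at Events.COMPLETED the averaged weights are swapped into the model with swap_swa_sgd(), followed by bn_update() to recompute BatchNorm running statistics on the labelled training loader. Below is a minimal, self-contained sketch of the same workflow outside of ignite; the model, data, and hyperparameters are placeholders and not taken from this repository.

import torch
import torch.nn as nn
import torchcontrib

# Placeholder model and toy data, only to make the sketch runnable.
model = nn.Sequential(nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(10)]
criterion = nn.CrossEntropyLoss()

base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# No swa_start/swa_freq/swa_lr: the wrapper runs in manual mode, so the
# averaging schedule is controlled by explicit update_swa() calls.
optimizer = torchcontrib.optim.SWA(base_optimizer)

num_epochs = 20
for epoch in range(num_epochs):  # 0-based here; ignite epochs are 1-based
    model.train()
    for x, y in loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
    # Accumulate the weight average over roughly the last quarter of training,
    # mirroring the "engine.state.epoch - 1 > int(num_epochs * 0.75)" check above.
    if epoch > int(num_epochs * 0.75):
        optimizer.update_swa()

# After training: copy the averaged weights into the model, then recompute
# BatchNorm running statistics with one pass over the training data.
optimizer.swap_swa_sgd()
optimizer.bn_update(loader, model)

Since swap_swa_sgd() exchanges the model weights with the stored average, calling it a second time restores the original (non-averaged) weights if they are needed again.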