
Commit 51f65f0 (parent 760fcc1)

Added SWA to test

3 files changed: +22 -2 lines

README.md (1 addition, 1 deletion)

````diff
@@ -77,7 +77,7 @@ mlflow run experiments/ --experiment-name=CIFAR10 -P dataset=CIFAR10 -P network=
 ```
 export MLFLOW_TRACKING_URI=$OUTPUT_PATH/mlruns
 
-mlflow run experiments/ --experiment-name=CIFAR10 -P dataset=CIFAR10 -P network=wideresnet -P params="data_path=../input/cifar10;num_epochs=6250;learning_rate=0.03;batch_size=64;TSA_proba_min=0.1;unlabelled_batch_size=320;"
+mlflow run experiments/ --experiment-name=CIFAR10 -P dataset=CIFAR10 -P network=wideresnet -P params="data_path=../input/cifar10;num_epochs=6250;learning_rate=0.03;batch_size=64;TSA_proba_min=0.1;unlabelled_batch_size=320;num_warmup_steps=20000"
 ```
 
 Unfortunately, I can not reproduce paper's result with 5.3 test error.
````
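
The only functional change here is the new `num_warmup_steps=20000` experiment parameter, which presumably controls a learning-rate warmup phase. As a point of reference, the sketch below shows one common way a linear warmup over `num_warmup_steps` iterations can be combined with cosine decay via `LambdaLR`; the helper `make_warmup_cosine_scheduler` and the `total_steps` value are illustrative assumptions, not this repository's actual scheduler (which uses `CosineAnnealingLR`).

```python
import math

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


def make_warmup_cosine_scheduler(optimizer, num_warmup_steps, total_steps):
    # Hypothetical helper: linear ramp for `num_warmup_steps`, then cosine decay to 0.
    def lr_lambda(step):
        if step < num_warmup_steps:
            # Scale the base LR linearly from 0 up to its configured value.
            return step / max(1, num_warmup_steps)
        progress = (step - num_warmup_steps) / max(1, total_steps - num_warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda)


model = nn.Linear(10, 10)  # toy model, stands in for the WideResNet
optimizer = optim.SGD(model.parameters(), lr=0.03, momentum=0.9, nesterov=True)
scheduler = make_warmup_cosine_scheduler(optimizer, num_warmup_steps=20000, total_steps=400000)
# Call scheduler.step() once per training iteration, after optimizer.step().
```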

code/main.py (20 additions, 1 deletion)

```diff
@@ -16,6 +16,8 @@
 import torch.optim as optim
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
+import torchcontrib
+
 import ignite
 from ignite.engine import Events, Engine, create_supervised_evaluator
 from ignite.metrics import Accuracy, Loss, RunningAverage
@@ -53,7 +55,11 @@ def run(output_path, config):
                           momentum=config['momentum'],
                           weight_decay=config['weight_decay'],
                           nesterov=True)
-
+
+    with_SWA = config['with_SWA']
+    if with_SWA:
+        optimizer = torchcontrib.optim.SWA(optimizer)
+
     criterion = nn.CrossEntropyLoss().to(device)
     if config['consistency_criterion'] == "MSE":
         consistency_criterion = nn.MSELoss()
@@ -177,6 +183,17 @@ def log_learning_rate(engine):
         lr = optimizer.param_groups[0]['lr']
         mlflow.log_metric("learning rate", lr, step=step)
 
+    if with_SWA:
+        @trainer.on(Events.COMPLETED)
+        def swap_swa_sgd(engine):
+            optimizer.swap_swa_sgd()
+            optimizer.bn_update(train_labelled_loader, model)
+
+        @trainer.on(Events.EPOCH_COMPLETED)
+        def update_swa(engine):
+            if engine.state.epoch - 1 > int(num_epochs * 0.75):
+                optimizer.update_swa()
+
     metric_names = [
         'supervised batch loss',
         'consistency batch loss',
@@ -301,6 +318,8 @@ def mlflow_val_metrics_logging(engine, tag):
     "TSA_proba_max": 1.0,
 
     "no_UDA": False,  # disable UDA training
+
+    "with_SWA": False,
 }
 
 # Override config:
```
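
In torchcontrib's manual SWA mode (the form used above, where `SWA(optimizer)` is constructed without `swa_start`/`swa_freq`/`swa_lr`), the caller is responsible for calling `update_swa()` whenever the running weight average should be updated, `swap_swa_sgd()` once at the end to move the averaged weights into the model, and `bn_update()` to recompute BatchNorm running statistics under those averaged weights. Below is a minimal self-contained sketch of that flow outside Ignite; the toy model, dataset, and `num_epochs` value are stand-ins for the real WideResNet and CIFAR-10 setup.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchcontrib
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the real model and labelled loader (illustrative only).
model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset, batch_size=16)

base_optimizer = optim.SGD(model.parameters(), lr=0.03, momentum=0.9, nesterov=True)
optimizer = torchcontrib.optim.SWA(base_optimizer)  # manual mode: no swa_start/swa_freq/swa_lr
criterion = nn.CrossEntropyLoss()

num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    # Accumulate the weight average only over the last quarter of training,
    # mirroring the `engine.state.epoch - 1 > int(num_epochs * 0.75)` condition above.
    if epoch > int(num_epochs * 0.75):
        optimizer.update_swa()

optimizer.swap_swa_sgd()            # copy the averaged weights into the model
optimizer.bn_update(loader, model)  # recompute BatchNorm running stats for the averaged weights
```

Wrapping the base optimizer this way leaves the regular SGD updates untouched; SWA only maintains a separate running average that is swapped in at the very end, which is why the extra `bn_update` pass over the labelled data is needed before evaluating the averaged model.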

experiments/conda.yaml (1 addition, 0 deletions)

```diff
@@ -14,3 +14,4 @@ dependencies:
 - tensorboardX
 - nested_dict
 - git+https://github.com/pytorch/ignite.git
+- torchcontrib
```
