Skip to content

Commit 531c7b6

Browse files
committed
Add ReduceLROnPlateau
1 parent 7d87a63 commit 531c7b6

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

main_bayesian.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77
import numpy as np
8-
from torch.optim import Adam
8+
from torch.optim import Adam, lr_scheduler
99
from torch.nn import functional as F
1010

1111
import data
@@ -119,13 +119,14 @@ def run(dataset, net_type):
119119

120120
criterion = metrics.ELBO(len(trainset)).to(device)
121121
optimizer = Adam(net.parameters(), lr=lr_start)
122+
lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6, verbose=True)
122123
valid_loss_max = np.Inf
123124
for epoch in range(n_epochs): # loop over the dataset multiple times
124125
cfg.curr_epoch_no = epoch
125-
utils.adjust_learning_rate(optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr_start))
126126

127127
train_loss, train_acc, train_kl = train_model(net, optimizer, criterion, train_loader, num_ens=train_ens, beta_type=beta_type)
128128
valid_loss, valid_acc = validate_model(net, criterion, valid_loader, num_ens=valid_ens)
129+
lr_sched.step(valid_loss)
129130

130131
print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
131132
epoch, train_loss, train_acc, valid_loss, valid_acc, train_kl))

main_frequentist.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
import numpy as np
88
import torch.nn as nn
9-
from torch.optim import Adam
9+
from torch.optim import Adam, lr_scheduler
1010

1111
import data
1212
import utils
@@ -82,12 +82,13 @@ def run(dataset, net_type):
8282

8383
criterion = nn.CrossEntropyLoss()
8484
optimizer = Adam(net.parameters(), lr=lr)
85+
lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6, verbose=True)
8586
valid_loss_min = np.Inf
8687
for epoch in range(1, n_epochs+1):
87-
utils.adjust_learning_rate(optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr))
8888

8989
train_loss, train_acc = train_model(net, optimizer, criterion, train_loader)
9090
valid_loss, valid_acc = validate_model(net, criterion, valid_loader)
91+
lr_sched.step(valid_loss)
9192

9293
train_loss = train_loss/len(train_loader.dataset)
9394
valid_loss = valid_loss/len(valid_loader.dataset)

metrics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ def forward(self, input, target, kl, beta):
1414
return F.nll_loss(input, target, reduction='mean') * self.train_size + beta * kl
1515

1616

17-
def lr_linear(epoch_num, decay_start, total_epochs, start_value):
18-
if epoch_num < decay_start:
19-
return start_value
20-
return start_value*float(total_epochs-epoch_num)/float(total_epochs-decay_start)
17+
# def lr_linear(epoch_num, decay_start, total_epochs, start_value):
18+
# if epoch_num < decay_start:
19+
# return start_value
20+
# return start_value*float(total_epochs-epoch_num)/float(total_epochs-decay_start)
2121

2222

2323
def acc(outputs, targets):

0 commit comments

Comments (0)