This repository has been archived by the owner on Jul 19, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 82
/
trainer.py
99 lines (83 loc) · 3.04 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from __future__ import division
import torch
from torch.autograd import Variable
from torch.utils import data
# from resnet import FCN
from upsample import FCN
# from gcn import FCN
from datasets import VOCDataSet
from loss import CrossEntropy2d, CrossEntropyLoss2d
from visualize import LinePlotter
from transform import ReLabel, ToLabel, ToSP, Scale
from torchvision.transforms import Compose, CenterCrop, Normalize, ToTensor
import tqdm
from PIL import Image
import numpy as np
# Preprocessing applied to each RGB input image. Scale is the project's own
# transform, called here with a (256, 256) size and bilinear interpolation;
# the result is converted to a tensor and normalised per channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics
# -- confirm against the backbone's pretraining.
_input_steps = [
    Scale((256, 256), Image.BILINEAR),
    ToTensor(),
    Normalize([.485, .456, .406], [.229, .224, .225]),
]
input_transform = Compose(_input_steps)

# Preprocessing applied to each label map. Nearest-neighbour interpolation is
# requested so that resizing cannot invent intermediate class ids. ToSP,
# ToLabel and ReLabel are project transforms applied in this order.
# NOTE(review): ReLabel(255, 21) presumably remaps label 255 to class 21 --
# confirm in transform.py.
_target_steps = [
    Scale((256, 256), Image.NEAREST),
    ToSP(256),
    ToLabel(),
    ReLabel(255, 21),
]
target_transform = Compose(_target_steps)
# Training data: the project's VOCDataSet rooted at ./data, with the image and
# label transforms defined above, served in shuffled batches of 16.
_train_set = VOCDataSet(
    "./data",
    img_transform=input_transform,
    label_transform=target_transform,
)
trainloader = data.DataLoader(
    _train_set,
    batch_size=16,
    shuffle=True,
    pin_memory=True,
)
# Build the network with 22 output channels. The model is always created so
# that `model` is bound even on a CPU-only machine (previously it was only
# assigned inside the CUDA branch, leaving the name undefined otherwise).
model = FCN(22)
if torch.cuda.is_available():
    # Multi-GPU wrapper; only applied when at least one GPU is present.
    # NOTE: DataParallel keeps the wrapped network under `.module`, so
    # checkpoints saved from this branch carry a "module." key prefix.
    model = torch.nn.DataParallel(model)
    model.cuda()
# Training schedule.
epoches = 80
# NOTE(review): 92 is presumably the number of iterations per epoch -- confirm.
# max_iters is not read by any live code in this script.
max_iters = epoches * 92

# SGD hyperparameters.
lr = 1e-4
momentum = 0.9
weight_decay = 2e-5

# Per-class loss weights for 22 classes: every class counts equally except
# the last one (index 21), whose weight is zero so it adds nothing to the loss.
weight = torch.cat((torch.ones(21), torch.zeros(1)))
# Class-weighted loss from the project's loss module. The weight tensor is
# moved to the GPU only when one exists; calling .cuda() unconditionally
# would crash on CPU-only machines.
criterion = CrossEntropyLoss2d(
    weight.cuda() if torch.cuda.is_available() else weight
)

# Plain SGD with momentum and L2 weight decay over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
)

# Project helper used below to plot the per-epoch training loss.
ploter = LinePlotter()

# Put the network into training mode before the optimisation loop starts.
model.train()
# Training loop: run SGD over `trainloader` for `epoches` epochs, report and
# plot the mean batch loss after every epoch, divide the learning rate by 10
# and write a numbered checkpoint every 20 epochs, then save the final weights.
use_cuda = torch.cuda.is_available()

for epoch in range(epoches):
    # Weights for the individual loss terms of a batch. During the first 40
    # epochs the last of the six terms dominates (0.5); afterwards the first
    # one does. zip() below stops at the shorter sequence, so only the first
    # len(loss_weight) loss terms contribute.
    if epoch < 40:
        loss_weight = [0.1, 0.1, 0.1, 0.1, 0.1, 0.5]
    else:
        loss_weight = [0.5, 0.1, 0.1, 0.1, 0.1, 0.1]

    running_loss = 0.0
    num_batches = 0
    # tqdm wraps the loader itself (not enumerate(...)) so it can read the
    # loader's length and show a real progress total. Counting from 1 makes
    # `num_batches` the number of batches processed once the loop ends.
    for num_batches, (images, labels_group) in enumerate(
            tqdm.tqdm(trainloader), 1):
        optimizer.zero_grad()

        losses = []
        for image, labels in zip(images, labels_group):
            # Move inputs AND targets to the GPU only when one is available,
            # so the CPU path no longer fails on an unconditional .cuda().
            if use_cuda:
                image = image.cuda()
                labels = [label.cuda() for label in labels]
            outputs = model(Variable(image))
            # The model returns several outputs; each is paired with the
            # label at the same position.
            for output, label in zip(outputs, labels):
                losses.append(criterion(output, Variable(label)))

        loss = sum(w * term for w, term in zip(loss_weight, losses))
        loss.backward()
        optimizer.step()

        # `.item()` exists on PyTorch >= 0.4; older releases expose the
        # scalar through `.data[0]`, which newer releases reject.
        if hasattr(loss, "item"):
            running_loss += loss.item()
        else:
            running_loss += loss.data[0]

    # Average over the number of batches. The previous code divided by the
    # last zero-based index, which over-reported the mean and raised
    # ZeroDivisionError when an epoch had a single batch.
    mean_loss = running_loss / max(num_batches, 1)
    print("Epoch [%d] Loss: %.4f" % (epoch+1, mean_loss))
    ploter.plot("loss", "train", epoch+1, mean_loss)

    if (epoch+1) % 20 == 0:
        # Step decay: divide the learning rate by 10 every 20 epochs.
        # NOTE(review): building a fresh optimizer also discards the SGD
        # momentum buffers -- confirm this reset is intended.
        lr /= 10
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
        torch.save(model.state_dict(), "./pth/fcn-deconv-%d.pth" % (epoch+1))

# Final weights after the last epoch.
torch.save(model.state_dict(), "./pth/fcn-deconv.pth")