-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
86 lines (72 loc) · 3.29 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import torch
import torch.optim as optim
import torch.utils.data as data
from tensorboard_logger import Logger
from utils import FolderDataset, train, ImagePool
from networks import Discriminator, Generator
import numpy as np
# ---------------------------------------------------------------------------
# CycleGAN-style training script: trains two generators (G_AB, G_BA) and two
# discriminators (D_A, D_B) on unpaired image folders A and B, logs per-network
# losses to TensorBoard, and checkpoints every network after every epoch.
# ---------------------------------------------------------------------------

starting_epoch = 1          # set > 1 to resume from the previous epoch's checkpoints
nb_epochs = 100
experiment_name = "train1"

# Folder where events are stored for Tensorboard
train_logger = Logger("events/" + experiment_name + "/train")

# Create experiment folder for checkpoints (no-op if it already exists)
model_dir = os.path.join("models", experiment_name)
os.makedirs(model_dir, exist_ok=True)


def _checkpoint_path(net_name, epoch):
    """Return the .pth checkpoint path for network `net_name` at `epoch`."""
    return os.path.join(model_dir, "{}_{}.pth".format(net_name, epoch))


# GPU
with torch.cuda.device(0):
    nb_filters = 32
    input_size = 128
    act = True                # passed through to Generator; presumably toggles an activation — TODO confirm
    batch_norm = True
    max_len = None            # None -> use the whole dataset
    max_steps = np.inf        # cap on training steps per epoch (inf = no cap)
    lr_0 = 0.0002
    lr_1 = lr_0 / 100         # NOTE(review): every optimizer uses the reduced rate lr_1, never lr_0

    dataset = FolderDataset(
        root_dir_A="datasets/maps/trainA",
        root_dir_B="datasets/maps/trainB",
        input_size=input_size,
        max_len=max_len,
    )
    loader = data.DataLoader(dataset, batch_size=1, shuffle=True)
    print("Dataset loaded !")

    D_B = Discriminator(nb_filters, 5, input_size).cuda()
    D_A = Discriminator(nb_filters, 5, input_size).cuda()
    G_AB = Generator(nb_filters, act, batch_norm).cuda()
    G_BA = Generator(nb_filters, act, batch_norm).cuda()

    # Keep (name, network) pairs together so loading, logging and saving can
    # never drift out of sync with each other. This order also matches the
    # order of the losses returned by train().
    named_nets = [("D_B", D_B), ("D_A", D_A), ("G_AB", G_AB), ("G_BA", G_BA)]

    # Resume: restore every network from the epoch preceding starting_epoch.
    if starting_epoch > 1:
        for net_name, net in named_nets:
            net.load_state_dict(torch.load(_checkpoint_path(net_name, starting_epoch - 1)))

    networks = [D_B, D_A, G_AB, G_BA]
    optimizer_D_B = optim.Adam(D_B.parameters(), lr=lr_1)
    optimizer_D_A = optim.Adam(D_A.parameters(), lr=lr_1)
    optimizer_G_AB = optim.Adam(G_AB.parameters(), lr=lr_1)
    optimizer_G_BA = optim.Adam(G_BA.parameters(), lr=lr_1)
    # NOTE: optimizer order (generators first) intentionally differs from the
    # network order above — train() presumably expects this ordering.
    optimizers = [optimizer_G_AB, optimizer_G_BA, optimizer_D_B, optimizer_D_A]

    # Image pools of previously generated fakes for discriminator updates.
    pool_A = ImagePool(pool_size=50)
    pool_B = ImagePool(pool_size=50)
    pools = [pool_A, pool_B]

    for epoch in range(starting_epoch, nb_epochs + 1):
        # losses come back in the order D_B, D_A, G_AB, G_BA (same as named_nets)
        losses = train(epoch, loader, networks, optimizers, pools,
                       max_steps=max_steps, verbose=False)

        for (net_name, _), loss in zip(named_nets, losses):
            train_logger.log_value(net_name + ' loss', loss, epoch)

        total_loss = sum(losses)
        print("\nLoss at epoch n.{} : {}".format(epoch, total_loss))

        # Checkpoint every network so training can resume from any epoch.
        for net_name, net in named_nets:
            torch.save(net.state_dict(), _checkpoint_path(net_name, epoch))