train2.py
import os
import argparse
import torch
import pickle
import torch.utils.data as utils
import torch.optim as optim
import time
import numpy as np
from utils.graph import Graph
from model.mesh_deformation_network import GraphNetwork
from utils.pool import FeaturePooling
from utils.metrics import loss_function
from dataset.dataset import CustomDatasetFolder
import neptune.new as neptune
# Args
parser = argparse.ArgumentParser(description='Pixel2Mesh training script')
parser.add_argument('--data', type=str, default=None, metavar='D',
                    help="folder where data is located.")
parser.add_argument('--epochs', type=int, default=5, metavar='E',
                    help='number of epochs to train (default: 5)')
parser.add_argument('--lr', type=float, default=3e-5, metavar='LR',
                    help='learning rate (default: 3e-5)')
parser.add_argument('--log_step', type=int, default=100, metavar='LS',
                    help='how many batches to wait before logging training status (default: 100)')
parser.add_argument('--saving_step', type=int, default=1000, metavar='S',
                    help='how many batches to wait before saving the model (default: 1000)')
parser.add_argument('--experiment', type=str, default='./model/', metavar='E',
                    help='folder where the model and optimizer are saved.')
parser.add_argument('--load_model', type=str, default=None, metavar='M',
                    help='model file to load to continue training.')
parser.add_argument('--load_optimizer', type=str, default=None, metavar='O',
                    help='optimizer file to load to continue training.')
parser.add_argument('--transformer_model', type=str, default='google/vit-huge-patch14-224-in21k',
                    help='the name of the ViT to be used')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
args = parser.parse_args()
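# Example invocation (the data path is a placeholder; point --data at your own folder of .dat samples):
#   python train2.py --data ./data --epochs 5 --lr 3e-5 --batch_size 1 --experiment ./model/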
# Cuda
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Neptune logging
run = neptune.init(
    project="marcomameli1992/Pix2Mesh-PT3D",
    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJkZWJkNDEyYS01NjI0LTRjMDAtODI5Yi0wMzI4NWU5NDc0ZmMifQ==",
)  # your credentials
params = {"learning_rate": args.lr, "optimizer": "Adam"}
run["parameters"] = params
# Model
if args.load_model is not None:  # Continue training
    state_dict = torch.load(args.load_model, map_location=device)
    model_gcn = GraphNetwork()
    model_gcn.load_state_dict(state_dict)
else:
    model_gcn = GraphNetwork()
# Optimizer
if args.load_optimizer is not None:
    state_dict_opt = torch.load(args.load_optimizer, map_location=device)
    optimizer = optim.Adam(model_gcn.parameters(), lr=args.lr)
    optimizer.load_state_dict(state_dict_opt)
else:
    optimizer = optim.Adam(model_gcn.parameters(), lr=args.lr)
model_gcn.train()
os.makedirs(args.experiment, exist_ok=True)
# Graph
graph = Graph("./ellipsoid/init_info.pickle")
# Data Loader
folder = CustomDatasetFolder(args.data, extensions=["dat"])
train_loader = torch.utils.data.DataLoader(folder, batch_size=args.batch_size, shuffle=True)
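# Each .dat sample loaded here is expected to unpack as (image, ground-truth points,
# ground-truth normals); see the unpacking at the top of the training loop below.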
# Param
nb_epochs = args.epochs
log_step = args.log_step
saving_step = args.saving_step
curr_loss = 0
# To GPU
if use_cuda:
    print('Using GPU')
    model_gcn.cuda()
    # Move any loaded optimizer state onto the GPU as well
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()
else:
    print('Using CPU')
print("Trainable param:", model_gcn.get_nb_trainable_params())
parameters = {'batch_size': args.batch_size, 'epochs': args.epochs, 'lr': args.lr, 'log_step': args.log_step,
              'saving_step': args.saving_step, 'experiment': args.experiment, 'load_model': args.load_model,
              'load_optimizer': args.load_optimizer, 'transformer_model': args.transformer_model,
              'trainable_param': model_gcn.get_nb_trainable_params()}
run["parameters"] = parameters
# Train
for epoch in range(1, nb_epochs + 1):
    for n, data in enumerate(train_loader):
        im, gt_points, gt_normals = data
        if use_cuda:
            im = im.cuda()
            gt_points = gt_points.cuda()
            gt_normals = gt_normals.cuda()
        # Forward
        graph.reset()
        optimizer.zero_grad()
        pool = FeaturePooling(im)
        pred_points = model_gcn(graph, pool)
        # Loss
        loss = loss_function(pred_points, gt_points.squeeze(),
                             gt_normals.squeeze(), graph)
        run["train/loss"].log(loss.item())
        # Backward
        loss.backward()
        optimizer.step()
        curr_loss += loss.item()  # accumulate a float, not the graph-attached tensor
        # Log
        if (n + 1) % log_step == 0:
            print("Epoch", epoch)
            print("Batch", n + 1)
            print(" Loss:", curr_loss / log_step)
            curr_loss = 0
        # Save
        if (n + 1) % saving_step == 0:
            model_file = args.experiment + "model_" + str(n + 1) + ".pth"
            optimizer_file = args.experiment + "optimizer_" + str(n + 1) + ".pth"
            torch.save(model_gcn.state_dict(), model_file)
            torch.save(optimizer.state_dict(), optimizer_file)
            print("Saved model to " + model_file)
run.stop()
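# The checkpoints written above can be passed back via --load_model and --load_optimizer
# to resume training from the last saved state.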