distill.py
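# Knowledge-distillation training script: a pretrained teacher network
# supervises a smaller student on depth/distance prediction over AirSim or
# KITTI data.
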
import argparse

import tensorflow as tf

from data.parameterized_parallel_dataset import Dataset, DataLoader
from model.loss import L2DepthLoss, L2NormRMSE
from model.model import Kitti3DPredictor, BackboneSharedParameterizedNet
from solver.optimizer import OptimizerFactory

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(description='Select between small or big data',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data-size', type=str, choices=['big', 'small', 'real', 'kitti'], default='small')
parser.add_argument('-m', '--training-mode', type=str, choices=['normal', 'parameterized', 'shared', 'kitti'], default='parameterized')
parser.add_argument('-b', '--batch-size', type=int, default=32)
parser.add_argument('-j', '--jobs', type=int, default=8)
args = parser.parse_args()

# input_shape = (180, 320)
input_shape = (180, 595)  # (height, width) of the single-channel input images

##############################
# Define data and dataloader #
##############################
if args.data_size == 'big':
    train_path = "./train_new.csv"
    val_path = "./val_new.csv"
    img_directory = "/media/data/teamAI/phuc/phuc/airsim/data"
elif args.data_size == 'kitti':
    train_path = "/media/data/teamAI/minh/kitti_out/kitti_train.csv"
    val_path = "/media/data/teamAI/minh/kitti_out/kitti_val.csv"
    img_directory = "/media/data/teamAI/minh/kitti_out/semantic-0.4"
else:
    train_path = "./train588_50_new.csv"
    val_path = "./val588_50_new.csv"
    img_directory = "/media/data/teamAI/phuc/phuc/airsim/50imperpose/full/"

train_dataset = Dataset(train_path, img_directory, input_shape)
val_dataset = Dataset(val_path, img_directory, input_shape)
train_loader = DataLoader(train_dataset, input_shape=input_shape, batch_size=args.batch_size, num_parallel_calls=args.jobs)
val_loader = DataLoader(val_dataset, input_shape=input_shape, batch_size=args.batch_size, num_parallel_calls=args.jobs)
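# Dataset presumably pairs each image under img_directory with the rows of its
# CSV split; the DataLoader batches them using args.jobs parallel calls.
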
################
# Define model #
################
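# The teacher is restored from a pretrained checkpoint; the Distiller below is
# expected to update only the student while the teacher provides soft targets.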
teacher = tf.keras.models.load_model('/media/data/teamAI/minh/ivsr_weights/training_kitti1507/cp-25.cpkt')
teacher.build(input_shape=(None, input_shape[0], input_shape[1], 1))
# teacher.name = 'teacher'
teacher.summary()

if args.training_mode == 'shared':
    student = BackboneSharedParameterizedNet(num_ext_conv=1)
elif args.data_size == 'kitti':
    student = Kitti3DPredictor(num_ext_conv=3)
else:
    raise ValueError(f'no student model for training mode {args.training_mode!r} '
                     f'with data size {args.data_size!r}')
student.build(input_shape=(None, input_shape[0], input_shape[1], 1))
# inputs = tf.keras.Input(shape=(input_shape[0], input_shape[1], 1))
# _ = net.call(inputs)
student.summary()

########################
# Define loss function #
########################
USE_MSE = True  # toggle between plain MSE and the repo's custom L2 losses
if USE_MSE:
    dist_loss_fn = tf.keras.losses.MeanSquaredError()
    depth_loss_fn = tf.keras.losses.MeanSquaredError()
else:
    dist_loss_fn = L2NormRMSE()
    depth_loss_fn = L2DepthLoss()
distill_loss_fn = tf.keras.losses.KLDivergence()
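# The KL term is presumably applied to temperature-softened teacher/student
# outputs inside Distiller, Hinton-style (see the sketch after the optimizer).
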
####################
# Define optimizer #
####################
factory = OptimizerFactory(lr=1e-3, use_scheduler=False)
optimizer = factory.get_optimizer()
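
# A minimal sketch of how a distiller of this shape typically combines its
# losses per batch (an assumption for orientation only; the real logic lives
# in solver/distiller.py and may differ):
#
#   hard_loss    = depth_loss_fn(y_true, student_pred)
#   soft_teacher = tf.nn.softmax(teacher_pred / temperature)
#   soft_student = tf.nn.softmax(student_pred / temperature)
#   soft_loss    = distill_loss_fn(soft_teacher, soft_student)
#   total_loss   = alpha * hard_loss + (1 - alpha) * soft_loss
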
#####################
# Trainer and train #
#####################
from solver.distiller import Distiller

trainer = Distiller(train_loader, val_loader=val_loader,
                    teacher=teacher, student=student,
                    distance_loss_fn=dist_loss_fn, depth_loss_fn=depth_loss_fn,
                    distillation_loss_fn=distill_loss_fn,
                    optimizer=optimizer,
                    log_path='/media/data/teamAI/minh/ivsr-logs/training_kitti1707_distill.txt',
                    savepath='/media/data/teamAI/minh/ivsr_weights/training_kitti1707_distill',
                    use_mse=USE_MSE,
                    alpha=0.1,       # presumably weights student loss vs. distillation loss
                    temperature=10)  # presumably softens outputs for the KL term
_ = trainer.train(30, save_checkpoint=True, early_stop=True)
# trainer.save_model()
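
# Example invocation (the paths above are machine-specific and must exist):
#   python distill.py --data-size kitti --training-mode kitti --batch-size 32 --jobs 8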