main.py
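"""Entry point for training and evaluating the D3D lipreading model.

Selects the dataset (LRW or LRW-1000), builds the model, optimizer, and
data loaders from command-line arguments, then either runs the training
loop (with optional per-epoch validation/testing and checkpointing) or
performs a single evaluation-only pass.
"""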
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from args import args
from data.dataset import LipreadingDataset
from model.D3D import LipReading as Model
from util import reload_model, AdjustLR, trn_epoch, tst_epoch
if args.usecuda:
    torch.backends.cudnn.benchmark = True
    # Must be set before the first CUDA call so the visible-device mask
    # takes effect (CUDA initializes lazily).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

# Seed both the CPU and GPU RNGs for reproducibility.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# Dataset-specific settings: the number of word classes and the fixed
# frame-padding length used by the dataset loader.
if args.dataset == 'LRW-1000':
    num_classes = 1000
    args.padding = 60
if args.dataset == 'LRW':
    num_classes = 500
    args.padding = 29
model = Model(drop_rate=args.dp, num_classes=num_classes)
print(model)
reload_model(model, path=args.model_path)  # load pretrained weights, if any

# Wrap in DataParallel when more than one GPU id is given.
if len(args.gpus.split(',')) > 1:
    model = nn.DataParallel(model)
if args.usecuda:
    model = model.cuda()  # cudnn.benchmark was already enabled above

if args.opt.lower() == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
if args.opt.lower() == 'sgd':
    # Nesterov momentum requires a nonzero momentum value in PyTorch; the
    # original line omitted it, which raises a ValueError. 0.9 is a
    # conventional choice, not a value confirmed by the source.
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                          weight_decay=0.001, nesterov=True)
if not args.no_train:
    # AdjustLR is defined in util.py; the parameter names suggest a step
    # decay that waits `sleep_epochs` epochs and then halves the learning
    # rate every `half` epochs.
    scheduler = AdjustLR(optimizer, [args.lr], sleep_epochs=1, half=5)
    trn_index = os.path.join(args.index_root, 'trn_1000.txt')
    dataset = LipreadingDataset(data_root=args.data_root, index_root=trn_index,
                                padding=args.padding)
    trn_loader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            drop_last=False)
    trn_len = len(trn_loader)
if not args.no_val:
    val_index = os.path.join(args.index_root, 'val_1000.txt')
    dataset = LipreadingDataset(data_root=args.data_root, index_root=val_index,
                                padding=args.padding, augment=False)
    val_loader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                            drop_last=False)
    val_len = len(val_loader)

# The test loader is always built: it is used after every training epoch
# and for evaluation-only runs.
tst_index = os.path.join(args.index_root, 'tst_1000.txt')
dataset = LipreadingDataset(data_root=args.data_root, index_root=tst_index,
                            padding=args.padding, augment=False)
tst_loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        shuffle=False,
                        num_workers=args.num_workers,
                        drop_last=False)
tst_len = len(tst_loader)
if not args.no_train:
    for epoch in range(args.s_epoch, args.epoch):
        scheduler.step(epoch)
        trn_epoch(model=model, data_loader=trn_loader, optimizer=optimizer, epoch=epoch)
        if not args.no_val:
            val_acc = tst_epoch(model=model, data_loader=val_loader, epoch=epoch, stage='val')
        tst_acc = tst_epoch(model=model, data_loader=tst_loader, epoch=epoch, stage='tst')

        # Unwrap DataParallel before saving so the checkpoint keys carry
        # no 'module.' prefix.
        if hasattr(model, 'module'):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        # The accuracies are stored alongside the weights in the same dict;
        # pop them off again before calling load_state_dict. Guarding on
        # args.no_val avoids the NameError the original code hit when
        # validation was disabled but val_acc was still saved.
        if not args.no_val:
            state_dict['val_acc'] = val_acc
        state_dict['test_acc'] = tst_acc
        torch.save(state_dict, args.save_path + '/' + str(epoch + 1) + '_.pt')
else:
    # Evaluation-only run: a single validation/test pass, reported as epoch 0.
    if not args.no_val:
        val_acc = tst_epoch(model=model, data_loader=val_loader, epoch=0, stage='val')
    tst_acc = tst_epoch(model=model, data_loader=tst_loader, epoch=0, stage='tst')
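# Example invocation (hypothetical: the actual flag names come from the
# argparse setup in args.py, which is not shown here):
#
#   python main.py --dataset LRW-1000 --data_root /path/to/frames \
#       --index_root /path/to/index --gpus 0,1 --opt adam --lr 1e-4
#
# Reloading a checkpoint written above (a minimal sketch; the accuracy keys
# must be removed first, since they are stored next to the weights):
#
#   ckpt = torch.load('path/to/5_.pt', map_location='cpu')
#   ckpt.pop('val_acc', None)
#   ckpt.pop('test_acc', None)
#   model.load_state_dict(ckpt)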