-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainmain.lua
130 lines (118 loc) · 4.34 KB
/
trainmain.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
-- Entry script: pull in the shared setup, loss and training modules, then
-- run whichever training pass the command-line options select.
print('==> executing all')

-- Per the original header, 'init' parses the command-line arguments and
-- sets up globals used by this script, such as:
--   model : the neural-network model
--   opt   : the parsed command-line options
-- NOTE(review): the original header also said the module path is extended so
-- the script can run from other folders; no package.path change is visible in
-- this file, so presumably that happens inside 'init' -- confirm.
require('init')
require('loss')
require('train')
-- dofile('test.lua')
----------------------------------------------------------------------
print('==> training!')
--- Train over a single feature file (opt.featfile), one chunk at a time.
-- Each chunk returned by readDataFeat is stored in the global `trainData`
-- (kept global because train() presumably reads it -- confirm) and trained
-- on in a random sample order; an empty chunk (size 0) ends the loop.
-- Afterwards the accumulated confusion-matrix statistics are printed.
-- Raises an error (with the OS reason) if opt.featfile cannot be opened.
function mainfeat()
  local trainfeatfile = opt.featfile
  -- io.open returns nil plus a message on failure; assert turns that into a
  -- clear error instead of handing nil to readDataFeat.
  local featfile = assert(io.open(trainfeatfile, 'r'))
  while true do
    trainData = readDataFeat(featfile)
    if trainData:size() <= 0 then
      break
    end
    -- Random permutation of the chunk's sample indices.
    local shuffleddata = torch.randperm(trainData:size())
    train(shuffleddata)
    collectgarbage()
  end
  featfile:close()

  print('==> final results')
  confusion:updateValids()
  print('average row correct: ' .. (confusion.averageValid*100) .. '%')
  print('average rowUcol correct (VOC measure): ' .. (confusion.averageUnionValid*100) .. '%')
  print('global correct: ' .. (confusion.totalValid*100) .. '%')
end
--- Train over the file list in opt.scpfile, then optionally cross-validate
-- over the file list in opt.cvscpfile.
-- Labels are loaded from opt.labelfile. When opt.globalnorm is non-empty,
-- the means/variances returned by readglobalnorm are passed to readDataScp2
-- for BOTH the training and the cross-validation data, so both sets receive
-- the same feature normalization.
-- Chunks are stored in the globals `trainData` / `cvData` (kept global
-- because train() / crossValidate() presumably read them -- confirm).
-- A nil chunk from readDataScp2 ends the corresponding loop.
-- Raises an error (with the OS reason) if a list file cannot be opened.
function mainscp()
  -- Print the accuracy figures accumulated in the global confusion matrix.
  local function report()
    confusion:updateValids()
    print('average row correct: ' .. (confusion.averageValid*100) .. '%')
    print('average rowUcol correct (VOC measure): ' .. (confusion.averageUnionValid*100) .. '%')
    print('global correct: ' .. (confusion.totalValid*100) .. '%')
  end

  readLabel(opt.labelfile)

  local means, variances
  if opt.globalnorm ~= '' then
    means, variances = readglobalnorm(opt.globalnorm)
  end

  if opt.scpfile ~= '' then
    local trainfeatfilelist = opt.scpfile
    local listfile = assert(io.open(trainfeatfilelist, 'r'))
    while true do
      trainData = readDataScp2(listfile, opt.filenum, means, variances)
      if trainData == nil then
        break
      end
      -- Random permutation of the chunk's sample indices.
      local shuffleddata = torch.randperm(trainData:size())
      train(shuffleddata)
      collectgarbage()
    end
    listfile:close()
    print('==> final results')
    report()
  end

  -- cross validation
  if opt.cvscpfile ~= '' then
    print("\n==> cross validation")
    -- Start CV statistics from a clean confusion matrix.
    confusion:zero()
    local cvfeatfilelist = opt.cvscpfile
    local listfile = assert(io.open(cvfeatfilelist, 'r'))
    while true do
      -- Same normalization statistics as training; the original call omitted
      -- means/variances, so CV features presumably went un-normalized when
      -- -globalnorm was set.
      cvData = readDataScp2(listfile, opt.filenum, means, variances)
      if cvData == nil then
        break
      end
      crossValidate()
      collectgarbage()
    end
    listfile:close()
    report()
  end
end
-- Dispatch on the parsed options. A file list (training and/or
-- cross-validation) takes priority over a single feature file; if no data
-- source is given at all, abort with an error.
if (opt.scpfile == '') and (opt.featfile == '') and (opt.cvscpfile == '') then
  error("Please specify a file containing the data with -scpfile or Please specify a file containing the data with -fbankfile")
elseif opt.scpfile ~= '' or opt.cvscpfile ~= '' then
  mainscp()
elseif opt.featfile ~= '' then
  -- Probe that the feature file can be opened, then close the probe handle
  -- right away instead of leaving it open until the garbage collector
  -- finalizes it.
  local probe = io.open(opt.featfile, "rb")
  if probe == nil then
    error(string.format("Given feature file %s cannot be found!", opt.featfile))
  end
  probe:close()
  mainfeat()
end
-- if not opt.featfile then
-- error("Please specify a file containing the data with -fbankfile")
-- return
-- elseif io.open(opt.featfile,"rb") == nil then
-- error(string.format("Given feature file %s cannot be found!",opt.featfile))
-- return
-- end
-- local trainfeatfile = opt.featfile
-- local featfile = io.open(trainfeatfile, 'r')
-- while (true) do
-- trainData = readData(featfile)
-- if (trainData:size()>0) then
-- local shuffleddata = torch.randperm(trainData:size())
-- train(shuffleddata)
-- else
-- break
-- end
-- collectgarbage()
-- end
-- featfile:close()
-- confusion:updateValids()
-- print('average row correct: ' .. (confusion.averageValid*100) .. '%')
-- print('average rowUcol correct (VOC measure): ' .. (confusion.averageUnionValid*100) .. '%')
-- print('global correct: ' .. (confusion.totalValid*100) .. '%')