-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathMultiImageSequence.lua
332 lines (276 loc) · 11.9 KB
/
MultiImageSequence.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
------------------------------------------------------------------------
--[[ MultiImageSequence ]]--
-- input : a sequence of images
-- targets : a sequence of images
-- Example dataset :
-- input is a sequence of video frames where targets are binary masks
--
-- Directory is organized as :
-- [datapath]/[seqid]/[input|target][1,2,3,...,T].jpg
-- So basically, the datapath contains a folder for each sequence.
-- Each sequence has T input and target images where T can vary.
-- The target of input[t].jpg is target[t].jpg.
--
-- The returned inputs and targets will be separated by mask tokens ([ ]):
-- [ ] target11, target12, target13, ..., target1T [ ] target21, ...
-- [ ] input11, input12, input13, ..., input1T [ ] input21, ...
--
-- The mask tokens [ ] represent images with nothing but zeros.
-- For large datasets use Lua5.2 instead of LuaJIT to avoid mem errors.
------------------------------------------------------------------------
-- load the dataload package environment so the class registers under the dl namespace
local dl = require 'dataload._env'
-- MultiImageSequence subclasses dl.DataLoader; `parent` gives access to the superclass methods
local MultiImageSequence, parent = torch.class('dl.MultiImageSequence', 'dl.DataLoader', dl)
-- Constructor.
-- datapath   : string; directory containing one sub-directory per sequence
-- batchsize  : number; number of sequences per batch
-- loadsize   : {inputloadsize, targetloadsize}, each a {c, h, w} table — size images are loaded at
-- samplesize : {c, h, w} for the *inputs* only; the target samplesize is derived
--              proportionally w.r.t. the input/target loadsizes (defaults to loadsize)
-- samplefunc : function(self, input, target, inputpath, targetpath, tracker), or the
--              name of an existing method ("sampleDefault", "sampleTrain")
-- verbose    : boolean; print progress messages (default true)
function MultiImageSequence:__init(datapath, batchsize, loadsize, samplesize, samplefunc, verbose)
-- 1. post-init arguments (modify these attributes after construction if needed)
-- samples a random uniform crop location every time-step (instead of once per sequence)
self.cropeverystep = false
-- random-uniformly samples a loadsize between samplesize and loadsize (this effectively scales the cropped location)
self.varyloadsize = false
-- varies load size every step instead of once per sequence
self.scaleeverystep = false
-- each new sequence is chosen random uniformly
self.randseq = false
-- modify this to use different patterns for input and target files
self.inputpattern = 'input%d.jpg'
self.targetpattern = 'target%d.jpg'
-- 2. arguments
-- path containing a folder for each sequence
self.datapath = datapath
assert(torch.type(self.datapath) == 'string')
-- number of sequences per batch
self.batchsize = batchsize
assert(torch.type(self.batchsize) == 'number')
-- size to load the images to, initially
self.loadsize = loadsize
assert(torch.type(self.loadsize) == 'table')
assert(torch.type(self.loadsize[1]) == 'table', 'Missing inputs loadsize')
assert(torch.type(self.loadsize[2]) == 'table', 'Missing targets loadsize')
-- consistent sample size to resize the images.
local inputsamplesize = samplesize or self.loadsize
assert(torch.type(inputsamplesize) == 'table')
assert(torch.type(inputsamplesize[1]) == 'number', 'Provide samplesize for inputs only (target samplesize will be proportional)')
assert(inputsamplesize[2] <= self.loadsize[1][2])
assert(inputsamplesize[3] <= self.loadsize[1][3])
-- make target samplesize proportional to input samplesize w.r.t. loadsize
local h = math.min(self.loadsize[2][2], math.max(1, torch.round(inputsamplesize[2]*self.loadsize[2][2]/self.loadsize[1][2])))
-- bugfix: the width must be clamped by the target load *width* (loadsize[2][3]),
-- not the target load height (loadsize[2][2]) as it was before
local w = math.min(self.loadsize[2][3], math.max(1, torch.round(inputsamplesize[3]*self.loadsize[2][3]/self.loadsize[1][3])))
local targetsamplesize = {self.loadsize[2][1], h, w}
self.samplesize = {inputsamplesize, targetsamplesize}
-- function f(self, dst, path) used to create a sample(s) from
-- an image path. Stores them in dst. Strings "sampleDefault"
-- "sampleTrain" or "sampleTest" can also be provided as they
-- refer to existing functions
self.samplefunc = samplefunc or 'sampleDefault'
-- display verbose messages (explicit nil-check so that `false` is honored)
self.verbose = verbose == nil and true or verbose
self:reset()
end
-- Builds (or loads from cachepath) the index of sequences and frames:
-- self.seqdirs = list of {seqdirname, nframe}; self.nframe = total frames;
-- self.seqlen = number of steps in an epoch (frames plus one mask step per
-- sequence, divided among batchsize rows).
-- cachepath : optional path to an ascii torch cache file
-- overwrite : when true, rebuild the index even if the cache exists
function MultiImageSequence:buildIndex(cachepath, overwrite)
if cachepath and (not overwrite) and paths.filep(cachepath) then
if self.verbose then
print("loading cached index")
end
-- copy every cached field (seqdirs, nframe, seqlen) onto self
local cache = torch.load(cachepath, 'ascii')
for k,v in pairs(cache) do
self[k] = v
end
else
-- will need this package later to load images (faster than image package)
require 'graphicsmagick'
local _ = require 'moses'
if self.verbose then
print(string.format("Building index. Counting number of frames"))
end
-- index files
local a = torch.Timer()
-- NOTE(review): the 4th argument 'target*' looks like it should restrict/exclude
-- indexing to one of the two file families, yet the match below extracts "input"
-- files — confirm against paths.indexdir semantics that this pairing is intended
local files = paths.indexdir(self.datapath, nil, nil, 'target*')
local seqdirs = {}
for i=1,files:size() do
local filepath = files:filename(i)
-- extract the sequence directory name and the frame index from
-- paths shaped like [datapath]/[seqdir]/input[idx].[ext]
local seqdir, idx = filepath:match("/([^/]+)/input([%d]*)[.][^/]+$")
if seqdir then
local seq = seqdirs[seqdir]
if not seq then
seq = {}
seqdirs[seqdir] = seq
end
-- NOTE(review): ([%d]*) also matches an empty digit run (e.g. "input.jpg"),
-- which would make tonumber(idx) nil and error here — verify filenames always carry an index
seq[tonumber(idx)] = true
end
end
self.seqdirs = {}
self.nframe = 0
-- #seq assumes frames are numbered contiguously 1..T with no gaps;
-- a missing frame index would make #seq unspecified (Lua length on a table with holes)
for seqdir, seq in pairs(seqdirs) do
if #seq > 0 then
table.insert(self.seqdirs, {seqdir, #seq})
self.nframe = self.nframe + #seq
end
end
-- +#seqdirs because of masks between sequences
self.seqlen = torch.ceil((self.nframe + #self.seqdirs)/self.batchsize)
assert(#self.seqdirs > 0)
if cachepath then
local cache = {seqdirs=self.seqdirs, nframe=self.nframe, seqlen=self.seqlen}
torch.save(cachepath, cache, "ascii")
end
end
if self.verbose then
print(string.format("Found %d sequences with a total of %d frames", #self.seqdirs, self.nframe))
end
end
-- Resets iteration state: clears the per-batch-row trackers and restarts
-- the round-robin sequence assignment at the first sequence.
function MultiImageSequence:reset()
parent.reset(self)
local trackers = {}
trackers.nextseq = 1
self.trackers = trackers
end
-- Returns the number of steps (seqlen) in one epoch.
-- Lazily builds the index on the first call.
function MultiImageSequence:size()
if self.seqdirs == nil then
self:buildIndex()
end
return self.seqlen
end
-- Returns a fresh copy of the input image sample size {c, h, w}.
-- Only excludedim == 1 (exclude the batch dimension) is supported.
function MultiImageSequence:isize(excludedim)
if excludedim == nil then
excludedim = 1
end
assert(excludedim == 1)
local size = {}
for i, dim in ipairs(self.samplesize[1]) do
size[i] = dim
end
return size
end
-- Returns a fresh copy of the target image sample size {c, h, w}.
-- Only excludedim == 1 (exclude the batch dimension) is supported.
function MultiImageSequence:tsize(excludedim)
if excludedim == nil then
excludedim = 1
end
assert(excludedim == 1)
local size = {}
for i, dim in ipairs(self.samplesize[2]) do
size[i] = dim
end
return size
end
-- Fills (and returns) inputs/targets with steps [start, stop] of the epoch,
-- one independent stream of sequences per batch row. Sequences are separated
-- by an all-zero "mask" step on both inputs and targets.
-- inputs : seqlen x batchsize x c x h x w
-- targets : seqlen x batchsize x c x h x w
function MultiImageSequence:sub(start, stop, inputs, targets, samplefunc)
if not self.seqdirs then
self:buildIndex()
end
local seqlen = stop - start + 1
inputs = inputs or torch.FloatTensor()
-- zero() so that mask steps (and any unfilled tail) are all-zero images
inputs:resize(seqlen, self.batchsize, unpack(self.samplesize[1])):zero()
targets = targets or inputs.new()
targets:resize(seqlen, self.batchsize, unpack(self.samplesize[2])):zero()
-- samplefunc is a function that generates a sample input and target
-- given their commensurate paths
local samplefunc = samplefunc or self.samplefunc
if torch.type(samplefunc) == 'string' then
samplefunc = self[samplefunc]
end
assert(torch.type(samplefunc) == 'function')
for i=1,self.batchsize do
local input = inputs:select(2,i)
local target = targets:select(2,i)
-- per-row tracker persists across sub() calls: remembers which sequence
-- this row is in (seqid) and the next frame to read (idx)
local tracker = self.trackers[i] or {}
self.trackers[i] = tracker
-- note: this local shadows the `start` argument; it indexes within [1, seqlen]
local start = 1
while start <= seqlen do
if not tracker.seqid then
tracker.idx = 1
-- each sequence is separated by a zero input and target.
-- this should make the model forget between sequences
-- (use with AbstractRecurrent:maskZero())
input[start]:fill(0)
target[start]:fill(0)
start = start + 1
if self.randseq then
tracker.seqid = math.random(1,#self.seqdirs)
else
-- round-robin over sequences, wrapping back to the first
tracker.seqid = self.trackers.nextseq
self.trackers.nextseq = self.trackers.nextseq + 1
if self.trackers.nextseq > #self.seqdirs then
self.trackers.nextseq = 1
end
end
end
if start <= seqlen then
local seqdir, nframe = unpack(self.seqdirs[tracker.seqid])
local seqpath = paths.concat(self.datapath, seqdir)
local size = 0
-- copy frames tracker.idx..nframe (or until the sub-sequence is full)
for i=tracker.idx,nframe do
local inputpath = paths.concat(seqpath, string.format(self.inputpattern, i))
local targetpath = paths.concat(seqpath, string.format(self.targetpattern, i))
if i == nframe then
-- move on to next sequence
tracker.seqid = nil
end
size = size + 1
samplefunc(self, input[start + size - 1], target[start + size - 1], inputpath, targetpath, tracker)
tracker.idx = tracker.idx + 1
if start + size - 1 == seqlen then
break
end
end
start = start + size
end
end
-- every step of this row must have been filled exactly
assert(start-1 == seqlen)
end
self:collectgarbage()
return inputs, targets
end
-- Loads the image at `path` as a graphicsmagick Image.
-- idx     : 1 for input images, 2 for target images (selects loadsize/samplesize row)
-- tracker : per-batch-row state; caches the sampled load dimensions (lW/lH)
--           and the random scale shared between input and target
function MultiImageSequence:loadImage(path, idx, tracker)
-- https://github.com/clementfarabet/graphicsmagick#gmimage
local gm = require 'graphicsmagick'
local lW, lH
if self.varyloadsize then
-- re-sample the load dimensions every step, or only at the first frame of a sequence
-- NOTE(review): if the first call for a sequence ever happens with tracker.idx > 1,
-- tracker.lW/lH would still be nil below — confirm callers always start at idx == 1
if self.scaleeverystep or tracker.idx == 1 then
tracker.lW = tracker.lW or {}
tracker.lH = tracker.lH or {}
lW, lH = self.loadsize[idx][3], self.loadsize[idx][2]
local sW, sH = self.samplesize[idx][3], self.samplesize[idx][2]
-- sample a loadsize between samplesize and loadsize (same scale for input and target)
-- (the scale is drawn when idx == 1, i.e. for the input, and reused for the target)
tracker.scale = idx == 1 and math.random() or tracker.scale
tracker.lW[idx] = torch.round(sW + tracker.scale*(lW-sW))
tracker.lH[idx] = torch.round(sH + tracker.scale*(lH-sH))
end
lW, lH = tracker.lW[idx], tracker.lH[idx]
else
lW, lH = self.loadsize[idx][3], self.loadsize[idx][2]
end
assert(lW and lH)
-- load image with size hints
local input = gm.Image():load(path, lW, lH)
-- resize by imposing the smallest dimension (while keeping aspect ratio)
input:size(nil, math.min(lW, lH))
return input
end
-- Default sampler: load each image and rescale it to the sample size
-- (no jitter). Fills `input` and `target` in place and returns them.
function MultiImageSequence:sampleDefault(input, target, inputpath, targetpath, tracker)
-- shared load/convert/rescale routine for both the input (idx=1) and target (idx=2)
local function loadAndScale(dst, path, idx)
local img = self:loadImage(path, idx, tracker)
local cs = self.samplesize[idx][1] == 1 and 'I' or 'RGB'
local src = img:toTensor('float',cs,'DHW', true)
dst:resize(src:size(1), self.samplesize[idx][2], self.samplesize[idx][3])
image.scale(dst, src)
end
loadAndScale(input, inputpath, 1)
loadAndScale(target, targetpath, 2)
return input, target
end
-- Training sampler: load each image and take a random crop of the sample size.
-- The crop location (as fractions cH/cW) is drawn once per sequence, or every
-- step when self.cropeverystep is set, and the same fractions are applied
-- proportionally to both the input and the target.
function MultiImageSequence:sampleTrain(input, target, inputpath, targetpath, tracker)
local input_ = self:loadImage(inputpath, 1, tracker)
local target_ = self:loadImage(targetpath, 2, tracker)
-- do random crop once per sequence (unless self.cropeverystep)
if tracker.idx == 1 or self.cropeverystep then
tracker.cH = math.random()
tracker.cW = math.random()
end
assert(tracker.cH and tracker.cW)
-- shared crop/convert/copy routine for both the input (idx=1) and target (idx=2)
local function cropInto(dst, img, idx)
local iW, iH = img:size()
local oW, oH = self.samplesize[idx][3], self.samplesize[idx][2]
local y1 = math.ceil(tracker.cH*(iH-oH))
local x1 = math.ceil(tracker.cW*(iW-oW))
local cropped = img:crop(oW, oH, x1, y1)
local cs = self.samplesize[idx][1] == 1 and 'I' or 'RGB'
dst:copy(cropped:toTensor('float',cs,'DHW', true))
end
cropInto(input, input_, 1)
cropInto(target, target_, 2)
return input, target
end