-
Notifications
You must be signed in to change notification settings - Fork 964
/
MaskedSelect.lua
71 lines (64 loc) · 2.45 KB
/
MaskedSelect.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
local unpack = unpack or table.unpack
local MaskedSelect, parent = torch.class('nn.MaskedSelect', 'nn.Module')
--[[ Sets the provided mask value for the module. ]]
function MaskedSelect:__init()
parent.__init(self)
self._maskIndices = torch.LongTensor()
self._maskIndexBuffer = torch.LongTensor()
self._maskIndexBufferCPU = torch.FloatTensor()
self._gradBuffer = torch.Tensor()
self._gradMask = torch.ByteTensor()
end
--[[ Performs maskedSelect operation. ]]
function MaskedSelect:updateOutput(input)
local input, mask = unpack(input)
self.output:maskedSelect(input, mask)
return self.output
end
--[[ Reverse maps unmasked gradOutput back to gradInput. ]]
function MaskedSelect:updateGradInput(input, gradOutput)
local input, mask = unpack(input)
if input:type() == 'torch.CudaTensor' then
self._maskIndexBufferCPU:range(1, mask:nElement()):resize(mask:size())
self._maskIndexBuffer:resize(
self._maskIndexBufferCPU:size()):copy(self._maskIndexBufferCPU)
else
self._maskIndexBuffer:range(1, mask:nElement()):resize(mask:size())
end
self._maskIndices:maskedSelect(self._maskIndexBuffer, mask)
self._gradBuffer:resize(input:nElement()):zero()
self._gradBuffer:scatter(1, self._maskIndices, gradOutput)
self._gradBuffer:resize(input:size())
self.gradInput = {self._gradBuffer,
self._gradMask:resize(mask:size()):fill(0)}
return self.gradInput
end
function MaskedSelect:type(type, tensorCache)
if not type then
return self._type
end
self._gradBuffer = self._gradBuffer:type(type)
self.gradInput = self.gradInput:type(type)
self.output = self.output:type(type)
-- These casts apply when switching between cuda/non-cuda types
if type ~= 'torch.CudaTensor' then
self._maskIndexBuffer = self._maskIndexBuffer:long()
self._maskIndices = self._maskIndices:long()
self._gradMask = self._gradMask:byte()
elseif type == 'torch.CudaTensor' then
self._maskIndexBuffer = self._maskIndexBuffer:cuda()
self._maskIndices = self._maskIndices:cuda()
self._gradMask = self._gradMask:cuda()
end
self._type = type
return self
end
function MaskedSelect:clearState()
return nn.utils.clear(self, {'output',
'gradInput',
'_maskIndexBuffer',
'_maskIndexBufferCPU',
'_maskIndices',
'_gradBuffer',
'_gradMask'})
end