-
Notifications
You must be signed in to change notification settings - Fork 8
/
Sort.lua
78 lines (71 loc) · 2.53 KB
/
Sort.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
local Sort, parent = torch.class('nn.Sort', 'nn.Module')
------------------------------------------------------------------------
--[[ Sort ]]--
-- Applies torch.sort along dimension dim to the input matrix.
-- Returns a table of {sortedIndices, sortedInputs}.
-- Used with BlockSparse.
------------------------------------------------------------------------

-- @param dim         dimension to sort along: 1 or 2 (default 2).
-- @param descending  sort in descending order when true (default false).
function Sort:__init(dim, descending)
   parent.__init(self)
   self.dim = dim or 2
   -- only matrices are supported (updateOutput asserts input:dim() == 2),
   -- so notDim would be wrong for any other dim: fail fast here instead
   assert(self.dim == 1 or self.dim == 2, "Sort: dim must be 1 or 2")
   self.notDim = (self.dim == 2) and 1 or 2
   self.descending = descending or false
   -- torch.sort writes indices into a LongTensor
   self.indice = torch.LongTensor()
   self._output = torch.Tensor()
   self._input = torch.Tensor()
   self.output = {self.indice, self._output}
   -- when true, sorting happens on host float tensors and results are
   -- copied back to the CUDA buffers (see Sort:type)
   self._cuda = false
end
-- Sorts the input matrix along self.dim.
-- Returns self.output = {indices, sortedValues}.
-- In CUDA mode the input is first copied to a host float tensor
-- (torch.sort is run on the host), then the results are copied back
-- into the CUDA output buffers set up by Sort:type.
function Sort:updateOutput(input)
   assert(input:dim() == 2, "Only works with matrices")
   local src = input
   if self._cuda then
      -- host-side staging copy of the CUDA input
      src = self._input:resize(input:size()):copy(input)
   end
   self._output:sort(self.indice, src, self.dim, self.descending)
   if self._cuda then
      -- mirror host results into the CUDA tensors referenced by self.output
      self._outputCuda:resize(self._output:size()):copy(self._output)
      self._indiceCuda:resize(self.indice:size()):copy(self.indice)
   end
   return self.output
end
-- Scatters the gradient of the sorted values back to the original
-- (unsorted) positions using the indices produced in updateOutput.
-- gradOutput is a table; only gradOutput[2] (gradient w.r.t. the sorted
-- values) is used — the indice output carries no gradient.
-- In CUDA mode the scatter runs on host float tensors (indexCopy with
-- the LongTensor indice), then the result is copied back to the device.
function Sort:updateGradInput(input, gradOutput)
   -- iterate over the dimension that was NOT sorted; each slice along it
   -- was sorted independently along self.dim
   local dim = self.notDim
   self.gradInput:resizeAs(input)
   gradOutput = gradOutput[2]
   if self._cuda then
      -- stage the CUDA gradient on the host
      self._gradOutputHost:resize(gradOutput:size(1), gradOutput:size(2))
      self._gradOutputHost:copy(gradOutput)
      self._gradInputHost:resizeAs(self._input)
      for i=1,input:size(dim) do
         self._gradInputHost:select(dim, i):indexCopy(1, self.indice:select(dim, i), self._gradOutputHost:select(dim, i))
      end
      self.gradInput:copy(self._gradInputHost)
   else
      for i=1,input:size(dim) do
         self.gradInput:select(dim, i):indexCopy(1, self.indice:select(dim, i), gradOutput:select(dim, i))
      end
   end
   return self.gradInput
end
-- Converts the module's tensors to the given type.
-- CUDA is special-cased: torch.sort runs on host float tensors, so the
-- internal buffers stay FloatTensor and extra CUDA tensors are allocated
-- to hold device-side copies of the outputs (see updateOutput).
-- Returns self, per the nn.Module:type convention, so chaining like
-- nn.Sort():cuda() keeps working.
function Sort:type(type)
   self.gradInput = self.gradInput:type(type)
   if (type ~= 'torch.CudaTensor') then
      self._output = self._output:type(type)
      self._input = self._input:type(type)
      self.output = {self.indice, self._output}
   else
      self._cuda = true
      -- sort on host: keep working buffers as float tensors
      self._output = self._output:float()
      self._input = self._input:float()
      -- device-side mirrors of the sort results
      self._outputCuda = torch.CudaTensor()
      self._indiceCuda = torch.CudaTensor()
      -- host staging buffers for the backward pass
      self._gradInputHost = torch.FloatTensor()
      self._gradOutputHost = torch.FloatTensor()
      self.output = {self._indiceCuda, self._outputCuda}
   end
   return self
end