-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcudnnSoftMax.go
255 lines (230 loc) · 5.86 KB
/
cudnnSoftMax.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
package gocudnn
/*
#include <cudnn.h>
*/
import "C"
import (
"fmt"
"unsafe"
"github.com/dereklstinson/cutil"
)
//SoftMaxD holds the soft max flags and soft max funcs
type SoftMaxD struct {
	algo C.cudnnSoftmaxAlgorithm_t // softmax algorithm flag (fast/accurate/log); assigned via Set
	mode C.cudnnSoftmaxMode_t      // softmax mode flag (instance/channel); assigned via Set
}
//CreateSoftMaxDescriptor creates a gocudnn softmax descriptor. It is not part of cudnn, but I wanted to make the library
//A little more stream lined after using it for a while.
//The returned descriptor is zero-valued; call Set to choose the algorithm and mode.
func CreateSoftMaxDescriptor() *SoftMaxD {
	var d SoftMaxD
	return &d
}
//Set sets the soft max algos.
//It stores the cudnn enum values for algo and mode on s and always returns a nil error.
func (s *SoftMaxD) Set(algo SoftMaxAlgorithm, mode SoftMaxMode) error {
	s.algo, s.mode = algo.c(), mode.c()
	return nil
}
//Get gets the softmax descriptor values.
//err is always nil.
func (s *SoftMaxD) Get() (algo SoftMaxAlgorithm, mode SoftMaxMode, err error) {
	algo = SoftMaxAlgorithm(s.algo)
	mode = SoftMaxMode(s.mode)
	return algo, mode, nil
}
//String satisfies the fmt.Stringer interface for debugging output.
func (s *SoftMaxD) String() string {
	// Format string kept exactly as before so printed output is unchanged.
	const layout = "SoftMaxD{\n%v,\n%v,\n}\n"
	return fmt.Sprintf(layout, s.algo, s.mode)
}
/* Softmax functions: All of the form "output = alpha * Op(inputs) + beta * output" */
//Forward performs forward softmax: y = alpha * softmax(x) + beta * y
//(per the "output = alpha * Op(inputs) + beta * output" convention of the softmax functions).
//
//Input/Output: y
func (s *SoftMaxD) Forward(
	handle *Handle,
	alpha float64,
	xD *TensorD, x cutil.Mem,
	beta float64,
	yD *TensorD, y cutil.Mem) error {
	// cudnn expects the scaling factors in the tensor's data type.
	a := cscalarbydatatype(xD.dtype, alpha)
	b := cscalarbydatatype(yD.dtype, beta)
	// Define the cgo call once instead of duplicating it in both branches.
	fwd := func() error {
		return Status(C.cudnnSoftmaxForward(
			handle.x,
			s.algo,
			s.mode,
			a.CPtr(),
			xD.descriptor, x.Ptr(),
			b.CPtr(),
			yD.descriptor, y.Ptr(),
		)).error("(s *SoftMaxD) Forward")
	}
	// When the handle has a worker, route the call through it
	// (presumably to serialize cuda calls onto one thread — see Handle).
	if handle.w != nil {
		return handle.w.Work(fwd)
	}
	return fwd()
}
//Backward performs the backward softmax: dx = alpha * softmaxBackward(y, dy) + beta * dx
//(per the "output = alpha * Op(inputs) + beta * output" convention of the softmax functions).
//
//Input/Output: dx
func (s *SoftMaxD) Backward(
	handle *Handle,
	alpha float64,
	yD *TensorD, y cutil.Mem,
	dyD *TensorD, dy cutil.Mem,
	beta float64,
	dxD *TensorD, dx cutil.Mem,
) error {
	// cudnn expects the scaling factors in the tensor's data type.
	a := cscalarbydatatype(yD.dtype, alpha)
	b := cscalarbydatatype(dxD.dtype, beta)
	// Define the cgo call once instead of duplicating it in both branches.
	bwd := func() error {
		return Status(C.cudnnSoftmaxBackward(
			handle.x,
			s.algo,
			s.mode,
			a.CPtr(),
			yD.descriptor, y.Ptr(),
			dyD.descriptor, dy.Ptr(),
			b.CPtr(),
			dxD.descriptor, dx.Ptr(),
		)).error("(s *SoftMaxD) Backward")
	}
	// When the handle has a worker, route the call through it
	// (presumably to serialize cuda calls onto one thread — see Handle).
	if handle.w != nil {
		return handle.w.Work(bwd)
	}
	return bwd()
}
//ForwardUS is like Forward but uses unsafe.Pointer instead of cutil.Mem
//
//Input/Output: y
func (s *SoftMaxD) ForwardUS(
	handle *Handle,
	alpha float64,
	xD *TensorD, x unsafe.Pointer,
	beta float64,
	yD *TensorD, y unsafe.Pointer) error {
	// cudnn expects the scaling factors in the tensor's data type.
	a := cscalarbydatatype(xD.dtype, alpha)
	b := cscalarbydatatype(yD.dtype, beta)
	// Define the cgo call once instead of duplicating it in both branches.
	fwd := func() error {
		return Status(C.cudnnSoftmaxForward(
			handle.x,
			s.algo,
			s.mode,
			a.CPtr(),
			xD.descriptor, x,
			b.CPtr(),
			yD.descriptor, y,
		)).error("(s *SoftMaxD) ForwardUS")
	}
	// When the handle has a worker, route the call through it
	// (presumably to serialize cuda calls onto one thread — see Handle).
	if handle.w != nil {
		return handle.w.Work(fwd)
	}
	return fwd()
}
//BackwardUS is like Backward but uses unsafe.Pointer instead of cutil.Mem
//
//Input/Output: dx
func (s *SoftMaxD) BackwardUS(
	handle *Handle,
	alpha float64,
	yD *TensorD, y unsafe.Pointer,
	dyD *TensorD, dy unsafe.Pointer,
	beta float64,
	dxD *TensorD, dx unsafe.Pointer,
) error {
	// cudnn expects the scaling factors in the tensor's data type.
	a := cscalarbydatatype(yD.dtype, alpha)
	b := cscalarbydatatype(dxD.dtype, beta)
	// Define the cgo call once instead of duplicating it in both branches.
	bwd := func() error {
		return Status(C.cudnnSoftmaxBackward(
			handle.x,
			s.algo,
			s.mode,
			a.CPtr(),
			yD.descriptor, y,
			dyD.descriptor, dy,
			b.CPtr(),
			dxD.descriptor, dx,
		)).error("(s *SoftMaxD) BackwardUS")
	}
	// When the handle has a worker, route the call through it
	// (presumably to serialize cuda calls onto one thread — see Handle).
	if handle.w != nil {
		return handle.w.Work(bwd)
	}
	return bwd()
}
//SoftMaxAlgorithm is used for flags and are exposed through its methods.
//It wraps cudnnSoftmaxAlgorithm_t; use Fast, Accurate, or Log to obtain valid values.
type SoftMaxAlgorithm C.cudnnSoftmaxAlgorithm_t
//Fast changes s to and returns SoftMaxAlgorithm(C.CUDNN_SOFTMAX_FAST)
func (s *SoftMaxAlgorithm) Fast() SoftMaxAlgorithm {
	v := SoftMaxAlgorithm(C.CUDNN_SOFTMAX_FAST)
	*s = v
	return v
}
//Accurate changes s to and returns SoftMaxAlgorithm(C.CUDNN_SOFTMAX_ACCURATE)
func (s *SoftMaxAlgorithm) Accurate() SoftMaxAlgorithm {
	v := SoftMaxAlgorithm(C.CUDNN_SOFTMAX_ACCURATE)
	*s = v
	return v
}
//Log changes s to and returns SoftMaxAlgorithm(C.CUDNN_SOFTMAX_LOG)
func (s *SoftMaxAlgorithm) Log() SoftMaxAlgorithm {
	v := SoftMaxAlgorithm(C.CUDNN_SOFTMAX_LOG)
	*s = v
	return v
}
func (s SoftMaxAlgorithm) c() C.cudnnSoftmaxAlgorithm_t { return C.cudnnSoftmaxAlgorithm_t(s) }
//String satisfies the fmt.Stringer interface, returning a human readable
//name for the algorithm flag held in s.
func (s SoftMaxAlgorithm) String() string {
	// Compare directly against the cudnn constants. The previous version
	// called the mutating flag-setter methods (f.Fast(), ...) inside the
	// case expressions, which worked only by side effect and was confusing.
	var x string
	switch s {
	case SoftMaxAlgorithm(C.CUDNN_SOFTMAX_FAST):
		x = "Fast"
	case SoftMaxAlgorithm(C.CUDNN_SOFTMAX_ACCURATE):
		x = "Accurate"
	case SoftMaxAlgorithm(C.CUDNN_SOFTMAX_LOG):
		x = "Log"
	default:
		x = "Unsupported Flag"
	}
	return "SoftMaxAlgorithm: " + x
}
//SoftMaxMode is used for softmaxmode flags and are exposed through its methods.
//It wraps cudnnSoftmaxMode_t; use Instance or Channel to obtain valid values.
type SoftMaxMode C.cudnnSoftmaxMode_t
//Instance changes s to SoftMaxMode(C.CUDNN_SOFTMAX_MODE_INSTANCE) and returns changed value
func (s *SoftMaxMode) Instance() SoftMaxMode {
	v := SoftMaxMode(C.CUDNN_SOFTMAX_MODE_INSTANCE)
	*s = v
	return v
}
//Channel changes s to SoftMaxMode(C.CUDNN_SOFTMAX_MODE_CHANNEL) and returns changed value
func (s *SoftMaxMode) Channel() SoftMaxMode {
	v := SoftMaxMode(C.CUDNN_SOFTMAX_MODE_CHANNEL)
	*s = v
	return v
}
func (s SoftMaxMode) c() C.cudnnSoftmaxMode_t { return C.cudnnSoftmaxMode_t(s) }
//String satisfies the fmt.Stringer interface, returning a human readable
//name for the mode flag held in s.
func (s SoftMaxMode) String() string {
	// Compare directly against the cudnn constants. The previous version
	// called the mutating flag-setter methods (f.Channel(), ...) inside the
	// case expressions, which worked only by side effect and was confusing.
	var x string
	switch s {
	case SoftMaxMode(C.CUDNN_SOFTMAX_MODE_CHANNEL):
		x = "Channel"
	case SoftMaxMode(C.CUDNN_SOFTMAX_MODE_INSTANCE):
		x = "Instance"
	default:
		x = "Unsupported Flag"
	}
	return "SoftMaxMode: " + x
}