-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcudnn.go
137 lines (112 loc) · 3.68 KB
/
cudnn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package gocudnn
/*
#include <cudnn.h>
#include <cuda.h>
*/
import "C"
import (
"errors"
"github.com/dereklstinson/cutil"
"github.com/dereklstinson/half"
)
var (
	// cudnndebugmode records whether DebugMode has been called. Once set to
	// true it is never reset by the code in this block.
	cudnndebugmode bool
)

// DebugMode turns on debug mode for these bindings only (it sets the
// package-level cudnndebugmode flag to true). It does not change any cuDNN
// library setting.
func DebugMode() {
	const enabled = true
	cudnndebugmode = enabled
}
const (
	// DimMax is the maximum number of tensor dimensions, taken from
	// CUDNN_DIM_MAX in the cudnn.h used at build time.
	DimMax = int32(C.CUDNN_DIM_MAX)

	// BnMinEpsilon is the minimum epsilon accepted for batch normalization,
	// taken from CUDNN_BN_MIN_EPSILON in the cudnn.h used at build time.
	// Older cuDNN releases defined it as 1e-5; newer ones define it as 0.
	BnMinEpsilon = float64(C.CUDNN_BN_MIN_EPSILON)
)
// cscalarbydatatype wraps num in a cutil.CScalar chosen by dtype:
//
//   - Double yields a cutil.CDouble.
//   - Float yields a cutil.CFloat.
//   - Int32, Int8, UInt8 and Half also yield a cutil.CFloat (num narrowed to
//     float32), not a scalar of the tensor's own element type.
//   - Any other DataType yields nil.
//
// NOTE(review): the float-for-integer/half mapping looks intended for
// alpha/beta style scaling factors; confirm against the callers.
func cscalarbydatatype(dtype DataType, num float64) cutil.CScalar {
	// flag is only used to call the DataType flag methods; the values they
	// return are what dtype is compared against, in the order listed.
	var flag DataType
	switch dtype {
	case flag.Double():
		return cutil.CDouble(num)
	case flag.Float():
		return cutil.CFloat(num)
	case flag.Int32(), flag.Int8(), flag.UInt8(), flag.Half():
		return cutil.CFloat(float32(num))
	}
	return nil
}
// cscalarbydatatypeforsettensor wraps num in a cutil.CScalar whose C type
// matches dtype itself (unlike cscalarbydatatype, integer and half types keep
// their own element type):
//
//   - Double -> cutil.CDouble
//   - Float  -> cutil.CFloat
//   - Int32  -> cutil.CInt
//   - Int8   -> cutil.CChar
//   - UInt8  -> cutil.CUChar
//   - Half   -> cutil.CHalf
//   - any other DataType -> nil
//
// For the integer types, num is saturated into the target range first, and
// NaN becomes 0. Go leaves out-of-range and NaN float-to-integer conversions
// implementation-dependent, so without this step the stored value could
// differ between platforms.
//
// NOTE(review): the name suggests the result is used as the fill value for a
// tensor set operation; confirm against the callers.
func cscalarbydatatypeforsettensor(dtype DataType, num float64) cutil.CScalar {
	// saturate clamps v into [lo, hi]. NaN fails every comparison and is
	// mapped to 0.
	saturate := func(v, lo, hi float64) float64 {
		switch {
		case v >= lo && v <= hi:
			return v
		case v < lo:
			return lo
		case v > hi:
			return hi
		default: // NaN
			return 0
		}
	}

	// x is only used to call the DataType flag methods whose results dtype
	// is compared against.
	var x DataType
	switch dtype {
	case x.Double():
		return cutil.CDouble(num)
	case x.Float():
		return cutil.CFloat(num)
	case x.Int32():
		return cutil.CInt(saturate(num, -2147483648, 2147483647))
	case x.Int8():
		// NOTE(review): assumes cutil.CChar is a signed 8-bit type. C char is
		// signed on amd64 but unsigned on arm64; confirm.
		return cutil.CChar(saturate(num, -128, 127))
	case x.UInt8():
		return cutil.CUChar(saturate(num, 0, 255))
	case x.Half():
		return cutil.CHalf(half.NewFloat16(float32(num)))
	default:
		return nil
	}
}
type (
	// RuntimeTag mirrors C.cudnnRuntimeTag_t. It is meant for checking at
	// runtime whether kernels ran correctly, and is intended for use with
	// batch normalization. QueryRuntimeError currently only accepts a nil tag.
	RuntimeTag C.cudnnRuntimeTag_t

	// ErrQueryMode mirrors C.cudnnErrQueryMode_t. Its methods select one of
	// the cuDNN error-query modes.
	ErrQueryMode C.cudnnErrQueryMode_t
)

// RawCode stores CUDNN_ERRQUERY_RAWCODE in e and returns that value.
func (e *ErrQueryMode) RawCode() ErrQueryMode {
	mode := ErrQueryMode(C.CUDNN_ERRQUERY_RAWCODE)
	*e = mode
	return mode
}

// NonBlocking stores CUDNN_ERRQUERY_NONBLOCKING in e and returns that value.
func (e *ErrQueryMode) NonBlocking() ErrQueryMode {
	mode := ErrQueryMode(C.CUDNN_ERRQUERY_NONBLOCKING)
	*e = mode
	return mode
}

// Blocking stores CUDNN_ERRQUERY_BLOCKING in e and returns that value.
func (e *ErrQueryMode) Blocking() ErrQueryMode {
	mode := ErrQueryMode(C.CUDNN_ERRQUERY_BLOCKING)
	*e = mode
	return mode
}

// c converts e to the C enum type expected by the cgo calls.
func (e ErrQueryMode) c() C.cudnnErrQueryMode_t {
	asC := C.cudnnErrQueryMode_t(e)
	return asC
}
// GetVersion returns the cuDNN library version, as reported by
// cudnnGetVersion, converted to uint.
func GetVersion() uint {
	version := C.cudnnGetVersion()
	return uint(version)
}
// GetCudaartVersion returns (it does not print) the CUDA runtime version
// reported by cudnnGetCudartVersion, converted to uint. The exported name
// keeps its historical spelling for compatibility.
func GetCudaartVersion() uint {
	runtimeVersion := C.cudnnGetCudartVersion()
	return uint(runtimeVersion)
}
// QueryRuntimeError wraps cudnnQueryRuntimeError; see the cuDNN API reference
// for the meaning of the statuses.
//
// It queries the runtime status recorded on handle using the given mode.
// Only a nil tag is supported. A non-nil tag returns the zero Status and an
// error without calling cuDNN.
//
// When handle.w is set, the C call runs inside handle.w.Work; otherwise it
// runs directly on the calling goroutine.
//
// The returned Status is the runtime status written by cuDNN. The error
// reports whether the query call itself failed.
func (handle *Handle) QueryRuntimeError(mode ErrQueryMode, tag *RuntimeTag) (Status, error) {
	var rstatus C.cudnnStatus_t
	if tag != nil {
		return Status(rstatus), errors.New("Tag flags not supported")
	}

	// query performs the cgo call once; it is shared by both execution paths.
	query := func() error {
		callStatus := C.cudnnQueryRuntimeError(handle.x, &rstatus, C.cudnnErrQueryMode_t(mode), nil)
		return Status(callStatus).error("QueryRuntimeError")
	}

	var err error
	if handle.w == nil {
		err = query()
	} else {
		err = handle.w.Work(query)
	}
	return Status(rstatus), err
}