-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathbinding.go
347 lines (289 loc) · 14.2 KB
/
binding.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
// Copyright (c) seasonjs. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.
package rwkv
import (
"errors"
"github.com/ebitengine/purego"
"unsafe"
)
// QuantizedFormat names a quantization format accepted by
// RwkvQuantizeModelFile. The string value is handed to the native library
// verbatim as the format name.
type QuantizedFormat string

// Quantization formats listed as available by RwkvQuantizeModelFile.
const (
	Q4_0 QuantizedFormat = "Q4_0"
	Q4_1 QuantizedFormat = "Q4_1"
	Q5_0 QuantizedFormat = "Q5_0"
	// Q5_1 must spell its own format name; it used to carry "Q5_0", which made
	// a Q5_1 request indistinguishable from a Q5_0 one.
	Q5_1 QuantizedFormat = "Q5_1"
	Q8_0 QuantizedFormat = "Q8_0"
)
// Exported symbol names looked up in the rwkv shared library by NewCRwkv.
const (
	// Error reporting.
	cRwkvSetPrintErrors = "rwkv_set_print_errors"
	cRwkvGetPrintErrors = "rwkv_get_print_errors"
	cRwkvGetLastError   = "rwkv_get_last_error"

	// Context lifecycle and placement.
	cRwkvInitFromFile     = "rwkv_init_from_file"
	cRwkvCloneContext     = "rwkv_clone_context"
	cRwkvGpuOffloadLayers = "rwkv_gpu_offload_layers"
	cRwkvFree             = "rwkv_free"

	// Inference and state handling.
	cRwkvEval         = "rwkv_eval"
	cRwkvEvalSequence = "rwkv_eval_sequence"
	cRwkvInitState    = "rwkv_init_state"

	// Model dimensions.
	cRwkvGetNVocab       = "rwkv_get_n_vocab"
	cRwkvGetNEmbedding   = "rwkv_get_n_embed"
	cRwkvGetNLayer       = "rwkv_get_n_layer"
	cRwkvGetStateLength  = "rwkv_get_state_len"
	cRwkvGetLogitsLength = "rwkv_get_logits_len"

	// Utilities.
	cRwkvQuantizeModelFile   = "rwkv_quantize_model_file"
	cRwkvGetSystemInfoString = "rwkv_get_system_info_string"
)
type (
	// GpuType is a string-typed identifier; nothing in the visible part of
	// this file uses it (presumably it names a GPU backend — verify against
	// its callers).
	GpuType string

	// RwkvCtx wraps the native context handle produced by RwkvInitFromFile or
	// RwkvCloneContext. A zero handle corresponds to a NULL native context.
	RwkvCtx struct {
		ctx uintptr
	}
)
// CRwkv is the Go-facing surface of the rwkv.cpp C API. The per-method notes
// below restate the contract of the native function each method fronts.
type CRwkv interface {
	// RwkvSetPrintErrors sets whether errors are automatically printed to
	// stderr. When printing is disabled, the caller is responsible for
	// retrieving the error (rwkv_get_last_error) after a failed operation.
	//   - ctx: the context to configure. With a NULL native context the
	//     setting covers model load (rwkv_init_from_file) and quantization
	//     (rwkv_quantize_model_file) errors, and becomes the default for new
	//     contexts.
	//   - enable: whether error messages should be printed automatically.
	RwkvSetPrintErrors(ctx *RwkvCtx, enable bool)

	// RwkvGetPrintErrors reports whether errors are automatically printed to
	// stderr.
	//   - ctx: the context to read the setting from, or NULL for the global
	//     setting.
	RwkvGetPrintErrors(ctx *RwkvCtx) bool

	// RwkvGetLastError retrieves and clears the error flags.
	//   - ctx: the context to retrieve the error for, or NULL for the global
	//     error.
	RwkvGetLastError(ctx *RwkvCtx) error

	// RwkvInitFromFile loads the model from a file and prepares it for
	// inference. The native call returns NULL on any error.
	//   - filePath: path to a model file in ggml format.
	//   - threads: number of threads to use; must be positive.
	RwkvInitFromFile(filePath string, threads uint32) *RwkvCtx

	// RwkvCloneContext creates a new context from an existing one, which
	// allows several rwkv_eval calls to run in parallel without loading the
	// same model more than once. Each context can have one eval running at a
	// time, and every context must be released with rwkv_free.
	//   - ctx: the context to clone.
	//   - threads: number of threads to use; must be positive.
	RwkvCloneContext(ctx *RwkvCtx, threads uint32) *RwkvCtx

	// RwkvGpuOffloadLayers offloads the given number of layers of the context
	// onto the GPU using cuBLAS, if it is enabled. If rwkv.cpp was compiled
	// without cuBLAS support, the native function is a no-op.
	RwkvGpuOffloadLayers(ctx *RwkvCtx, nGpuLayers uint32) error

	// RwkvEval evaluates the model for a single token. It is not thread-safe;
	// for parallel inference clone one context per thread with
	// rwkv_clone_context. A failed native call is reported through the
	// returned error.
	//   - token: next token index, in range 0 <= token < n_vocab.
	//   - stateIn: FP32 buffer of size rwkv_get_state_len(), or NULL on the
	//     first pass.
	//   - stateOut: FP32 buffer of size rwkv_get_state_len(); written to if
	//     non-NULL.
	//   - logitsOut: FP32 buffer of size rwkv_get_logits_len(); written to if
	//     non-NULL.
	RwkvEval(ctx *RwkvCtx, token uint32, stateIn, stateOut, logitsOut []float32) error

	// RwkvEvalSequence evaluates the model for a sequence of tokens. It uses
	// a faster algorithm than rwkv_eval when the state and logits are not
	// needed for every token, and is best used with batch sizes of 64 or so.
	// The first call for a given sequence length has to build a computation
	// graph; later calls with the same length reuse the cached graph. It is
	// not thread-safe; for parallel inference clone one context per thread
	// with rwkv_clone_context. A failed native call is reported through the
	// returned error.
	//   - token: the native parameter is a pointer to an array of tokens; if
	//     NULL, the graph is built and cached but not executed, which can be
	//     useful for initialization.
	//     NOTE(review): this Go signature carries a single uint32 rather than
	//     a token slice — confirm how that maps onto the native pointer.
	//   - sequenceLen: number of tokens to read from the array.
	//   - stateIn: FP32 buffer of size rwkv_get_state_len(), or NULL on the
	//     first pass.
	//   - stateOut: FP32 buffer of size rwkv_get_state_len(); written to if
	//     non-NULL.
	//   - logitsOut: FP32 buffer of size rwkv_get_logits_len(); written to if
	//     non-NULL.
	RwkvEvalSequence(ctx *RwkvCtx, token uint32, sequenceLen uint64, stateIn, stateOut, logitsOut []float32) error

	// RwkvGetNVocab returns the number of tokens in the model's vocabulary.
	// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from
	// World models (n_vocab = 65536).
	RwkvGetNVocab(ctx *RwkvCtx) uint64

	// RwkvGetNEmbedding returns the number of elements in the model's
	// embedding. Useful for reading individual fields of a model's hidden
	// state.
	RwkvGetNEmbedding(ctx *RwkvCtx) uint64

	// RwkvGetNLayer returns the number of layers in the model. Useful for
	// always offloading the entire model to GPU.
	RwkvGetNLayer(ctx *RwkvCtx) uint64

	// RwkvGetStateLength returns the number of float elements in a complete
	// state for the model. This is the number of elements to allocate for a
	// call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state.
	RwkvGetStateLength(ctx *RwkvCtx) uint64

	// RwkvGetLogitsLength returns the number of float elements in the logits
	// output of the model. This is currently always identical to n_vocab.
	RwkvGetLogitsLength(ctx *RwkvCtx) uint64

	// RwkvInitState initializes the given state so that passing it to
	// rwkv_eval or rwkv_eval_sequence is identical to passing NULL. Useful
	// when tracking the first call to those functions is annoying or
	// expensive. A state must be initialized for behavior to be defined:
	// passing a zeroed state to rwkv.cpp functions results in NaNs.
	//   - state: FP32 buffer of size rwkv_get_state_len() to initialize.
	RwkvInitState(ctx *RwkvCtx, state []float32)

	// RwkvFree frees all allocated memory and the context. It does not need
	// to be called on the same thread that created the context.
	RwkvFree(ctx *RwkvCtx) error

	// RwkvQuantizeModelFile quantizes an FP32 or FP16 model to one of the
	// quantized formats. A failed native call is reported through the
	// returned error; error messages would be printed to stderr.
	//   - in: path to a model file in ggml format, either FP32 or FP16.
	//   - out: path the quantized model is written to.
	//   - format: one of the available format names: Q4_0, Q4_1, Q5_0, Q5_1,
	//     Q8_0.
	RwkvQuantizeModelFile(ctx *RwkvCtx, in, out string, format QuantizedFormat) error

	// RwkvGetSystemInfoString returns the system information string.
	RwkvGetSystemInfoString() string
}
// CRwkvImpl holds the shared-library handle together with one bound function
// per native rwkv entry point. Instances are built by NewCRwkv, which fills
// every function field.
type CRwkvImpl struct {
	// libRwkv is the handle returned by openLibrary; RwkvFree closes the
	// library and resets this field to 0.
	libRwkv uintptr

	// Error reporting.
	cRwkvSetPrintErrors func(ctx uintptr, printErrors bool)
	cRwkvGetPrintErrors func(ctx uintptr) bool
	cRwkvGetLastError   func(ctx uintptr) uint32

	// Context creation and placement.
	cRwkvInitFromFile     func(modelFilePath string, nThreads uint32) uintptr
	cRwkvCloneContext     func(ctx uintptr, nThreads uint32) uintptr
	cRwkvGpuOffloadLayers func(ctx uintptr, nGpuLayers uint32) bool

	// Inference. Buffer arguments travel as raw addresses.
	// NOTE(review): the CRwkv documentation describes the sequence tokens as
	// a pointer to an array, while cRwkvEvalSequence declares that argument
	// as uint32 — confirm against rwkv.h.
	cRwkvEval         func(ctx uintptr, token uint32, stateIn, stateOut, logitsOut uintptr) bool
	cRwkvEvalSequence func(ctx uintptr, token uint32, sequenceLen uint64, stateIn, stateOut, logitsOut uintptr) bool

	// Model dimensions.
	cRwkvGetNVocab       func(ctx uintptr) uint64
	cRwkvGetNEmbedding   func(ctx uintptr) uint64
	cRwkvGetNLayer       func(ctx uintptr) uint64
	cRwkvGetStateLength  func(ctx uintptr) uint64
	cRwkvGetLogitsLength func(ctx uintptr) uint64

	// State initialisation and teardown.
	cRwkvInitState func(ctx uintptr, state uintptr)
	cRwkvFree      func(ctx uintptr)

	// Utilities.
	cRwkvQuantizeModelFile   func(modelFilePathIn, modelFilePathOut, formatName string) bool
	cRwkvGetSystemInfoString func() string
}
// NewCRwkv opens the shared library at libraryPath and binds every native
// entry point that CRwkvImpl exposes.
//
// If openLibrary fails, its error is returned unchanged together with a nil
// implementation. Symbol binding is delegated to purego.RegisterLibFunc,
// which returns no error value.
// NOTE(review): RegisterLibFunc is understood to panic when a symbol is
// missing, which would leave the library handle open — confirm.
func NewCRwkv(libraryPath string) (*CRwkvImpl, error) {
	handle, err := openLibrary(libraryPath)
	if err != nil {
		return nil, err
	}

	// Bind straight into the struct fields; each field already has the exact
	// function type expected for its symbol.
	impl := &CRwkvImpl{libRwkv: handle}

	purego.RegisterLibFunc(&impl.cRwkvSetPrintErrors, handle, cRwkvSetPrintErrors)
	purego.RegisterLibFunc(&impl.cRwkvGetPrintErrors, handle, cRwkvGetPrintErrors)
	purego.RegisterLibFunc(&impl.cRwkvGetLastError, handle, cRwkvGetLastError)

	purego.RegisterLibFunc(&impl.cRwkvInitFromFile, handle, cRwkvInitFromFile)
	purego.RegisterLibFunc(&impl.cRwkvCloneContext, handle, cRwkvCloneContext)
	purego.RegisterLibFunc(&impl.cRwkvGpuOffloadLayers, handle, cRwkvGpuOffloadLayers)

	purego.RegisterLibFunc(&impl.cRwkvEval, handle, cRwkvEval)
	purego.RegisterLibFunc(&impl.cRwkvEvalSequence, handle, cRwkvEvalSequence)

	purego.RegisterLibFunc(&impl.cRwkvGetNVocab, handle, cRwkvGetNVocab)
	purego.RegisterLibFunc(&impl.cRwkvGetNEmbedding, handle, cRwkvGetNEmbedding)
	purego.RegisterLibFunc(&impl.cRwkvGetNLayer, handle, cRwkvGetNLayer)
	purego.RegisterLibFunc(&impl.cRwkvGetStateLength, handle, cRwkvGetStateLength)
	purego.RegisterLibFunc(&impl.cRwkvGetLogitsLength, handle, cRwkvGetLogitsLength)

	purego.RegisterLibFunc(&impl.cRwkvInitState, handle, cRwkvInitState)
	purego.RegisterLibFunc(&impl.cRwkvFree, handle, cRwkvFree)

	purego.RegisterLibFunc(&impl.cRwkvQuantizeModelFile, handle, cRwkvQuantizeModelFile)
	purego.RegisterLibFunc(&impl.cRwkvGetSystemInfoString, handle, cRwkvGetSystemInfoString)

	return impl, nil
}
// RwkvSetPrintErrors forwards the print-errors switch to the native library.
//
// A nil ctx is passed to the native call as a NULL context, which the API
// documents as the global setting. Dereferencing a nil ctx used to panic,
// which made that documented mode unreachable from Go.
func (c *CRwkvImpl) RwkvSetPrintErrors(ctx *RwkvCtx, enable bool) {
	var handle uintptr // zero stands for the native NULL context
	if ctx != nil {
		handle = ctx.ctx
	}
	c.cRwkvSetPrintErrors(handle, enable)
}
// RwkvGetPrintErrors returns the native library's print-errors setting.
//
// A nil ctx is passed to the native call as a NULL context, which the API
// documents as the global setting; previously a nil ctx panicked.
func (c *CRwkvImpl) RwkvGetPrintErrors(ctx *RwkvCtx) bool {
	var handle uintptr // zero stands for the native NULL context
	if ctx != nil {
		handle = ctx.ctx
	}
	return c.cRwkvGetPrintErrors(handle)
}
// RwkvGetLastError fetches the native error code and converts it to a Go
// error.
//
// A nil ctx is passed to the native call as a NULL context, which the API
// documents as the global error; previously a nil ctx panicked. The result is
// nil when the code equals RwkvErrorNone, otherwise the code as a RwkvErrors
// value.
func (c *CRwkvImpl) RwkvGetLastError(ctx *RwkvCtx) error {
	var handle uintptr // zero stands for the native NULL context
	if ctx != nil {
		handle = ctx.ctx
	}

	err := RwkvErrors(c.cRwkvGetLastError(handle))
	if errors.Is(err, RwkvErrorNone) {
		return nil
	}
	return err
}
// RwkvInitFromFile asks the native library to load the model at filePath
// using the given number of threads and wraps the resulting handle.
//
// The returned pointer is never nil. The API documents a NULL native result
// on error, so a failed load shows up as a wrapper around a zero handle.
func (c *CRwkvImpl) RwkvInitFromFile(filePath string, threads uint32) *RwkvCtx {
	return &RwkvCtx{
		ctx: c.cRwkvInitFromFile(filePath, threads),
	}
}
// RwkvCloneContext asks the native library to clone ctx with the given thread
// count and wraps the resulting handle. The returned pointer is never nil;
// its handle is whatever the native call produced.
func (c *CRwkvImpl) RwkvCloneContext(ctx *RwkvCtx, threads uint32) *RwkvCtx {
	return &RwkvCtx{
		ctx: c.cRwkvCloneContext(ctx.ctx, threads),
	}
}
// RwkvGpuOffloadLayers asks the native library to offload nGpuLayers layers
// of ctx. When the native call reports success the result is nil; otherwise
// it is whatever RwkvGetLastError returns for ctx, which may itself be nil.
func (c *CRwkvImpl) RwkvGpuOffloadLayers(ctx *RwkvCtx, nGpuLayers uint32) error {
	if c.cRwkvGpuOffloadLayers(ctx.ctx, nGpuLayers) {
		return nil
	}
	return c.RwkvGetLastError(ctx)
}
// RwkvEval evaluates a single token through the native library.
//
// Each buffer is passed as the address of its first element. A nil or empty
// buffer is passed as 0 (native NULL): the API documents NULL stateIn as the
// first-pass input and says stateOut/logitsOut are only written when
// non-NULL. Indexing element 0 unconditionally used to panic in those cases.
//
// When the native call reports success the result is nil; otherwise it is
// whatever RwkvGetLastError returns for ctx.
func (c *CRwkvImpl) RwkvEval(ctx *RwkvCtx, token uint32, stateIn []float32, stateOut []float32, logitsOut []float32) error {
	// addr yields the address of buf's first element, or 0 when there is none.
	addr := func(buf []float32) uintptr {
		if len(buf) == 0 {
			return 0
		}
		return uintptr(unsafe.Pointer(&buf[0]))
	}

	if c.cRwkvEval(ctx.ctx, token, addr(stateIn), addr(stateOut), addr(logitsOut)) {
		return nil
	}
	return c.RwkvGetLastError(ctx)
}
// RwkvEvalSequence evaluates a token sequence through the native library.
//
// Each buffer is passed as the address of its first element. A nil or empty
// buffer is passed as 0 (native NULL): the API documents NULL stateIn as the
// first-pass input and says stateOut/logitsOut are only written when
// non-NULL. Indexing element 0 unconditionally used to panic in those cases.
//
// When the native call reports success the result is nil; otherwise it is
// whatever RwkvGetLastError returns for ctx.
//
// NOTE(review): the CRwkv documentation describes the tokens argument as a
// pointer to an array, but this binding forwards a single uint32 in that
// position. Correcting it needs a coordinated change to the bound function
// type and the public signature — confirm against rwkv.h.
func (c *CRwkvImpl) RwkvEvalSequence(ctx *RwkvCtx, token uint32, sequenceLen uint64, stateIn []float32, stateOut []float32, logitsOut []float32) error {
	// addr yields the address of buf's first element, or 0 when there is none.
	addr := func(buf []float32) uintptr {
		if len(buf) == 0 {
			return 0
		}
		return uintptr(unsafe.Pointer(&buf[0]))
	}

	if c.cRwkvEvalSequence(ctx.ctx, token, sequenceLen, addr(stateIn), addr(stateOut), addr(logitsOut)) {
		return nil
	}
	return c.RwkvGetLastError(ctx)
}
// RwkvGetNVocab returns the vocabulary size that the native library reports
// for ctx.
func (c *CRwkvImpl) RwkvGetNVocab(ctx *RwkvCtx) uint64 {
	handle := ctx.ctx
	return c.cRwkvGetNVocab(handle)
}
// RwkvGetNEmbedding returns the embedding size that the native library
// reports for ctx.
func (c *CRwkvImpl) RwkvGetNEmbedding(ctx *RwkvCtx) uint64 {
	handle := ctx.ctx
	return c.cRwkvGetNEmbedding(handle)
}
// RwkvGetNLayer returns the layer count that the native library reports for
// ctx.
func (c *CRwkvImpl) RwkvGetNLayer(ctx *RwkvCtx) uint64 {
	handle := ctx.ctx
	return c.cRwkvGetNLayer(handle)
}
// RwkvGetStateLength returns the state length, in float elements, that the
// native library reports for ctx.
func (c *CRwkvImpl) RwkvGetStateLength(ctx *RwkvCtx) uint64 {
	handle := ctx.ctx
	return c.cRwkvGetStateLength(handle)
}
// RwkvGetLogitsLength returns the logits length, in float elements, that the
// native library reports for ctx.
func (c *CRwkvImpl) RwkvGetLogitsLength(ctx *RwkvCtx) uint64 {
	handle := ctx.ctx
	return c.cRwkvGetLogitsLength(handle)
}
// RwkvInitState asks the native library to initialize state in place for ctx.
//
// An empty or nil state has no first element to point at, so the call is
// skipped instead of panicking on state[0]; there is nothing to initialize in
// that case. The caller remains responsible for sizing state to
// RwkvGetStateLength(ctx) elements, as the API requires.
func (c *CRwkvImpl) RwkvInitState(ctx *RwkvCtx, state []float32) {
	if len(state) == 0 {
		return
	}
	c.cRwkvInitState(ctx.ctx, uintptr(unsafe.Pointer(&state[0])))
}
// RwkvFree releases the native context held by ctx and then closes the shared
// library if it is still open.
//
// The native free is skipped when ctx is nil or its handle is already zero,
// and the handle is cleared once freed, so calling RwkvFree again with the
// same ctx cannot free it twice. Previously a nil ctx panicked and a repeated
// call passed the stale handle to the native free again.
//
// Because the library itself is closed here, the bound functions of this
// CRwkvImpl (and any other contexts created through it) must not be used
// after a successful call. If closeLibrary fails, its error is returned and
// libRwkv is left unchanged.
func (c *CRwkvImpl) RwkvFree(ctx *RwkvCtx) error {
	if ctx != nil && ctx.ctx != 0 {
		c.cRwkvFree(ctx.ctx)
		ctx.ctx = 0 // guard against a double free on repeated calls
	}

	if c.libRwkv != 0 {
		if err := closeLibrary(c.libRwkv); err != nil {
			return err
		}
	}
	c.libRwkv = 0
	return nil
}
// RwkvQuantizeModelFile asks the native library to quantize the model at in
// into out using the given format name.
//
// The native call takes no context. On failure the error code is looked up
// first with a NULL context (handle 0), which the API documents as the global
// error and as the scope covering quantization errors, and then, if ctx
// carries a non-zero handle, in that context as well. If neither lookup yields
// a code other than RwkvErrorNone, a generic error is returned so a failed
// quantization is never reported as success.
//
// Previously the failure path dereferenced ctx (panicking when it was nil)
// and consulted only the per-context error.
// NOTE(review): that quantization failures are stored in the global error is
// inferred from the CRwkv documentation — confirm against rwkv.h.
func (c *CRwkvImpl) RwkvQuantizeModelFile(ctx *RwkvCtx, in, out string, format QuantizedFormat) error {
	if c.cRwkvQuantizeModelFile(in, out, string(format)) {
		return nil
	}

	handles := []uintptr{0} // 0 is the native NULL context (global error)
	if ctx != nil && ctx.ctx != 0 {
		handles = append(handles, ctx.ctx)
	}
	for _, handle := range handles {
		if code := RwkvErrors(c.cRwkvGetLastError(handle)); !errors.Is(code, RwkvErrorNone) {
			return code
		}
	}
	return errors.New("rwkv: quantizing model file failed without a reported error code")
}
// RwkvGetSystemInfoString returns the system information string produced by
// the native library.
func (c *CRwkvImpl) RwkvGetSystemInfoString() string {
	info := c.cRwkvGetSystemInfoString()
	return info
}