This repository has been archived by the owner on Oct 26, 2023. It is now read-only.
forked from NouamaneTazi/bloomz.cpp
-
Notifications
You must be signed in to change notification settings - Fork 1
/
bloomz.go
58 lines (47 loc) · 1.42 KB
/
bloomz.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package bloomz
// #cgo LDFLAGS: -lbloomz -lm -lstdc++ -L./
// #cgo darwin LDFLAGS: -framework Accelerate
// #cgo darwin CXXFLAGS: -std=c++11
// #include <stdlib.h>
// #include <bloomz.h>
import "C"

import (
	"fmt"
	"strings"
	"unsafe"
)
// Bloomz is a handle to a BLOOM model loaded through the bloomz C library.
// Create one with New and release its native memory with Free.
type Bloomz struct {
	// state is the opaque pointer obtained from bloomz_allocate_state and
	// handed back to every bloomz_* call that operates on the model.
	state unsafe.Pointer
}
// New loads the model stored at the path model and returns a handle to it.
// The context size and f16-memory setting are taken from opts and forwarded
// to bloomz_bootstrap. Call Free on the returned handle when it is no longer
// needed.
//
// It returns an error if bloomz_bootstrap reports a non-zero status.
func New(model string, opts ...ModelOption) (*Bloomz, error) {
	mo := NewModelOptions(opts...)

	state := C.bloomz_allocate_state()

	// C.CString copies the path into C-heap memory that Go never collects;
	// release it once bootstrap has returned.
	modelPath := C.CString(model)
	defer C.free(unsafe.Pointer(modelPath))

	if result := C.bloomz_bootstrap(modelPath, state, C.int(mo.ContextSize), C.bool(mo.F16Memory)); result != 0 {
		// NOTE(review): state from bloomz_allocate_state is not released on
		// this path. Whether bloomz_free_model is safe to call on a state whose
		// bootstrap failed cannot be determined from here; confirm in the C
		// sources before adding a free.
		return nil, fmt.Errorf("failed loading model")
	}
	return &Bloomz{state: state}, nil
}
// Free releases the native model state via bloomz_free_model.
//
// The handle is cleared afterwards, so calling Free again (or on a nil or
// zero-value *Bloomz) is a no-op instead of freeing the same pointer twice.
// Do not call Predict after Free.
func (l *Bloomz) Free() {
	if l == nil || l.state == nil {
		return
	}
	C.bloomz_free_model(l.state)
	l.state = nil
}
const (
	// defaultPredictTokens is the token limit Predict passes to the C side
	// when the caller does not set one (or sets a non-positive one).
	defaultPredictTokens = 99999999

	// predictBytesPerToken is the per-token byte allowance used to size the
	// output buffer handed to bloomz_predict.
	// NOTE(review): this is an estimate, not a proven bound; a decoded token
	// could in principle be longer. The robust fix is for bloomz_predict to
	// take the buffer length.
	predictBytesPerToken = 32
)

// predictBufferSize returns the number of bytes to allocate for
// bloomz_predict's output: room for the echoed prompt, the generated text for
// a positive token limit, and the trailing NUL.
func predictBufferSize(prompt string, tokens int) int {
	generated := defaultPredictTokens
	switch {
	case tokens <= defaultPredictTokens/predictBytesPerToken:
		generated = tokens * predictBytesPerToken
	case tokens > defaultPredictTokens:
		// Never allocate less than one byte per requested token.
		generated = tokens
	}
	return len(prompt) + generated + 1
}

// Predict runs the model on text and returns the generated continuation.
//
// The C output appears to echo the prompt, so a leading space, then text
// itself, then a leading newline are stripped from the result.
//
// Seed, Threads, Tokens, TopK, TopP, Temperature, Penalty and Repeat from opts
// are forwarded to bloomz_allocate_params. A Tokens value of zero or less is
// replaced by defaultPredictTokens.
//
// It returns an error if the handle holds no model state (never loaded or
// already freed), or if bloomz_predict reports a non-zero status.
func (l *Bloomz) Predict(text string, opts ...PredictOption) (string, error) {
	if l == nil || l.state == nil {
		return "", fmt.Errorf("model is not loaded")
	}

	po := NewPredictOptions(opts...)
	if po.Tokens <= 0 {
		po.Tokens = defaultPredictTokens
	}

	// C.CString allocates C-heap memory; free it on every return path.
	// Deferred before the params cleanup, so it runs after it (LIFO).
	input := C.CString(text)
	defer C.free(unsafe.Pointer(input))

	params := C.bloomz_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
		C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat))
	// Freed on every return path, including inference failure.
	defer C.bloomz_free_params(params)

	// bloomz_predict takes no buffer length, so the buffer must hold the
	// echoed prompt plus the generated text plus the NUL terminator.
	// NOTE(review): with the default token limit this is roughly 100 MB per
	// call, the same as before this change.
	out := make([]byte, predictBufferSize(text, po.Tokens))
	if ret := C.bloomz_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0]))); ret != 0 {
		return "", fmt.Errorf("inference failed")
	}

	res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
	res = strings.TrimPrefix(res, " ")
	res = strings.TrimPrefix(res, text)
	res = strings.TrimPrefix(res, "\n")
	return res, nil
}