diff --git a/common/utils.go b/common/utils.go index d65d42a67..f9f3bc250 100644 --- a/common/utils.go +++ b/common/utils.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "time" + "unsafe" ) func OpenBrowser(url string) { @@ -142,7 +143,7 @@ func init() { } func GenerateKey() string { - rand.Seed(time.Now().UnixNano()) + //rand.Seed(time.Now().UnixNano()) key := make([]byte, 48) for i := 0; i < 16; i++ { key[i] = keyChars[rand.Intn(len(keyChars))] @@ -159,7 +160,7 @@ func GenerateKey() string { } func GetRandomString(length int) string { - rand.Seed(time.Now().UnixNano()) + //rand.Seed(time.Now().UnixNano()) key := make([]byte, length) for i := 0; i < length; i++ { key[i] = keyChars[rand.Intn(len(keyChars))] @@ -216,3 +217,10 @@ func StringsContains(strs []string, str string) bool { } return false } + +// []byte only read, panic on append +func StringToByteSlice(s string) []byte { + tmp1 := (*[2]uintptr)(unsafe.Pointer(&s)) + tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} + return *(*[]byte)(unsafe.Pointer(&tmp2)) +} diff --git a/controller/relay-openai.go b/controller/relay-openai.go index f06d8b776..33c5b8f42 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -12,7 +12,7 @@ import ( ) func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { - responseText := "" + var responseTextBuilder strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -29,6 +29,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O dataChan := make(chan string) stopChan := make(chan bool) go func() { + var streamItems []string for scanner.Scan() { data := scanner.Text() if len(data) < 6 { // ignore blank line or wrong format @@ -40,27 +41,33 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O dataChan <- data data = data[6:] if !strings.HasPrefix(data, "[DONE]") { - switch relayMode { - case RelayModeChatCompletions: - var streamResponse ChatCompletionsStreamResponse - err := json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - continue // just ignore the error - } - for _, choice := range streamResponse.Choices { - responseText += choice.Delta.Content - } - case RelayModeCompletions: - var streamResponse CompletionsStreamResponse - err := json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - continue - } - for _, choice := range streamResponse.Choices { - responseText += choice.Text - } + streamItems = append(streamItems, data) + } + } + streamResp := "[" + strings.Join(streamItems, ",") + "]" + switch relayMode { + case RelayModeChatCompletions: + var streamResponses []ChatCompletionsStreamResponseSimple + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return // just ignore the error + } + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Delta.Content) + } + } + case RelayModeCompletions: + var streamResponses []CompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return // just ignore the error + } + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) } } } @@ -85,7 +92,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } - return nil, responseText + return nil, responseTextBuilder.String() } func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { diff --git a/controller/relay.go b/controller/relay.go index 714910c36..761b2d8e6 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -189,6 +189,10 @@ type ChatCompletionsStreamResponse struct { Choices []ChatCompletionsStreamResponseChoice `json:"choices"` } +type ChatCompletionsStreamResponseSimple struct { + Choices []ChatCompletionsStreamResponseChoice `json:"choices"` +} + type CompletionsStreamResponse struct { Choices []struct { Text string `json:"text"` diff --git a/go.mod b/go.mod index d7ac8c136..2953d0d8c 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 - github.com/pkoukk/tiktoken-go v0.1.1 + github.com/pkoukk/tiktoken-go v0.1.6 github.com/samber/lo v1.38.1 github.com/shirou/gopsutil v3.21.11+incompatible github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 @@ -31,7 +31,7 @@ require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dlclark/regexp2 v1.8.1 // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect diff --git a/go.sum b/go.sum index 7e4cfa73c..9fa8e7193 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= -github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= @@ -124,8 +124,8 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo= -github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw= +github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= +github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= diff --git a/main.go b/main.go index d4a0f4b58..66e9029b1 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,8 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" + "log" + "net/http" "one-api/common" "one-api/controller" "one-api/middleware" @@ -13,6 +15,8 @@ import ( "one-api/router" "os" "strconv" + + _ "net/http/pprof" ) //go:embed web/build @@ -85,6 +89,9 @@ func main() { } if os.Getenv("ENABLE_PPROF") == "true" { + go func() { + log.Println(http.ListenAndServe("0.0.0.0:8005", nil)) + }() go common.Monitor() common.SysLog("pprof enabled") }