Skip to content

Commit

Permalink
更新日志
Browse files Browse the repository at this point in the history
  • Loading branch information
fruitbars committed Jun 23, 2024
1 parent d8f2f67 commit 8c7afa8
Showing 1 changed file with 39 additions and 30 deletions.
69 changes: 39 additions & 30 deletions pkg/llm/baidu-qianfan/qianfan.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,35 @@ import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"go.uber.org/zap"
"io"
"log"
"net/http"
"simple-one-api/pkg/mylog"
"strings"
)

func QianFanCall(api_key, secret_key, model string, qfReq *QianFanRequest) (*QianFanResponse, error) {
log.Println(api_key, secret_key, model, qfReq)
mylog.Logger.Info("QianFanCall", zap.String("api_key", api_key), zap.String("secret_key", secret_key), zap.String("model", model), zap.Any("qfReq", qfReq))

accessToken := GetAccessToken(api_key, secret_key)
if accessToken == "" {
log.Println("Failed to get access token")
return nil, nil
err := errors.New("Failed to get access token")
mylog.Logger.Error(err.Error())
return nil, err
}

return SendChatRequest(accessToken, model, qfReq)
}

func QianFanCallSSE(api_key, secret_key, model string, qfReq *QianFanRequest, callback func(qfResp *QianFanResponse)) error {
log.Println(api_key, secret_key, model, qfReq)
mylog.Logger.Info("QianFanCall", zap.String("api_key", api_key), zap.String("secret_key", secret_key), zap.String("model", model), zap.Any("qfReq", qfReq))
accessToken := GetAccessToken(api_key, secret_key)
if accessToken == "" {
log.Println("Failed to get access token")
return nil
err := errors.New("Failed to get access token")
mylog.Logger.Error(err.Error())
return err
}

return SendChatRequestWithSSE(accessToken, model, qfReq, callback)
Expand All @@ -40,31 +45,32 @@ func SendChatRequestWithSSE(accessToken, model string, qfReq *QianFanRequest, ca

jsonData, err := json.Marshal(qfReq)
if err != nil {
log.Println(err)
mylog.Logger.Error(err.Error())
return err
}

log.Println(string(jsonData))
mylog.Logger.Info(string(jsonData))

client := &http.Client{}
req, err := http.NewRequest("POST", url, bytes.NewReader(jsonData))
if err != nil {
log.Println(err)
mylog.Logger.Error(err.Error())
return err
}
req.Header.Add("Content-Type", "application/json")

res, err := client.Do(req)
if err != nil {
log.Println(err)
mylog.Logger.Error(err.Error())
return err
}

log.Println(url, "request ok")
mylog.Logger.Debug(url)

defer res.Body.Close()

if res.StatusCode != http.StatusOK {
log.Println("received non-200 response code:", res.StatusCode)
mylog.Logger.Error("received non-200 response code:", zap.Int("StatusCode", res.StatusCode))
return fmt.Errorf("received non-200 response code: %d", res.StatusCode)
}

Expand All @@ -76,22 +82,22 @@ func SendChatRequestWithSSE(accessToken, model string, qfReq *QianFanRequest, ca
if strings.HasPrefix(line, "data:") || len(line) > 0 {
data := strings.TrimPrefix(line, "data:")

log.Println(string(data))
mylog.Logger.Debug(string(data))

var response QianFanResponse
err = json.Unmarshal([]byte(data), &response)
if err != nil {
log.Println(err)
mylog.Logger.Error(err.Error())

} else {
log.Println("execute callback")
mylog.Logger.Debug("execute callback")
callback(&response)
}
}
}

if err := scanner.Err(); err != nil {
log.Println(err)
mylog.Logger.Error(err.Error())
return err
}

Expand All @@ -104,43 +110,46 @@ func SendChatRequest(accessToken, model string, qfReq *QianFanRequest) (*QianFan

jsonData, err := json.Marshal(qfReq)
if err != nil {
log.Println(err)
mylog.Logger.Error(err.Error())
return nil, err
}

log.Println(string(jsonData))
mylog.Logger.Info(string(jsonData))

client := &http.Client{}
req, err := http.NewRequest("POST", url, bytes.NewReader(jsonData))
if err != nil {
log.Println(err)
mylog.Logger.Error(err.Error())
return nil, err
}
req.Header.Add("Content-Type", "application/json")

res, err := client.Do(req)
if err != nil {
log.Println(err)
mylog.Logger.Error(err.Error())
return nil, err
}
defer res.Body.Close()

body, err := io.ReadAll(res.Body)
if err != nil {
log.Println(err)
mylog.Logger.Error(err.Error())
return nil, err
}

log.Println(string(body))
if res.StatusCode != http.StatusOK {
mylog.Logger.Error("received non-200 response code:", zap.Int("StatusCode", res.StatusCode), zap.String("body", string(body)))
return nil, fmt.Errorf("received non-200 response code: %d", res.StatusCode)
}

var response QianFanResponse
err = json.Unmarshal(body, &response)
if err != nil {
log.Println(err)
mylog.Logger.Error(err.Error())
return nil, err
}

log.Println(response)
mylog.Logger.Info("", zap.Any("response", response))

return &response, nil
}
Expand All @@ -165,31 +174,31 @@ func GetAccessToken(api_key, secret_key string) string {
postData := fmt.Sprintf("grant_type=client_credentials&client_id=%s&client_secret=%s", api_key, secret_key)
resp, err := http.Post(url, "application/x-www-form-urlencoded", strings.NewReader(postData))
if err != nil {
log.Printf("Failed to send request for access token: %v", err)
mylog.Logger.Error(err.Error())
return ""
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("Failed to read response body: %v", err)
mylog.Logger.Error(err.Error())
return ""
}

// log.Printf("AccessToken response body: %s", body)

var accessTokenObj map[string]interface{}
if err := json.Unmarshal(body, &accessTokenObj); err != nil {
log.Printf("Failed to unmarshal response body: %v", err)
mylog.Logger.Error(err.Error())
return ""
}

if token, ok := accessTokenObj["access_token"].(string); ok {
return token
} else if errDesc, ok := accessTokenObj["error_description"].(string); ok {
log.Printf("Error in getting access token: %s", errDesc)
mylog.Logger.Error("Error in getting access token:", zap.String("errDesc", errDesc))
} else {
log.Printf("Unknown error in access token response")
mylog.Logger.Error("Unknown error in access token response")
}
return ""
}

0 comments on commit 8c7afa8

Please sign in to comment.