diff --git a/relay/adaptor/cloudflare/adaptor.go b/relay/adaptor/cloudflare/adaptor.go index 6ff6b0d344..be2fb4ab42 100644 --- a/relay/adaptor/cloudflare/adaptor.go +++ b/relay/adaptor/cloudflare/adaptor.go @@ -10,6 +10,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" ) type Adaptor struct { @@ -28,7 +29,14 @@ func (a *Adaptor) Init(meta *meta.Meta) { } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil + switch meta.Mode { + case relaymode.ChatCompletions: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil + case relaymode.Embeddings: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil + default: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil + } } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { @@ -41,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } - return ConvertRequest(*request), nil + switch relayMode { + case relaymode.Completions: + return ConvertCompletionsRequest(*request), nil + case relaymode.ChatCompletions, relaymode.Embeddings: + return request, nil + default: + return nil, errors.New("not implemented") + } } func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { diff --git a/relay/adaptor/cloudflare/main.go b/relay/adaptor/cloudflare/main.go index c76520a28a..980a2891a5 100644 --- a/relay/adaptor/cloudflare/main.go +++ b/relay/adaptor/cloudflare/main.go @@ -3,11 +3,13 @@ package cloudflare import ( "bufio" "encoding/json" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/render" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -16,57 +18,23 @@ import ( "github.com/songquanpeng/one-api/relay/model" ) -func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { - var promptBuilder strings.Builder - for _, message := range textRequest.Messages { - promptBuilder.WriteString(message.StringContent()) - promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息 - } - +func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request { + p, _ := textRequest.Prompt.(string) return &Request{ + Prompt: p, MaxTokens: textRequest.MaxTokens, - Prompt: promptBuilder.String(), Stream: textRequest.Stream, Temperature: textRequest.Temperature, } } -func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse { - choice := openai.TextResponseChoice{ - Index: 0, - Message: model.Message{ - Role: "assistant", - Content: cloudflareResponse.Result.Response, - }, - FinishReason: "stop", - } - fullTextResponse := openai.TextResponse{ - Object: "chat.completion", - Created: helper.GetTimestamp(), - Choices: []openai.TextResponseChoice{choice}, - } - return &fullTextResponse -} - -func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { - var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = cloudflareResponse.Response - choice.Delta.Role = "assistant" - openaiResponse := openai.ChatCompletionsStreamResponse{ - Object: "chat.completion.chunk", - Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, - Created: helper.GetTimestamp(), - } - return &openaiResponse -} - func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) common.SetEventStreamHeaders(c) id := helper.GetResponseID(c) - responseModel := c.GetString("original_model") + responseModel := c.GetString(ctxkey.OriginalModel) var responseText string for scanner.Scan() { @@ -77,22 +45,22 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN data = strings.TrimPrefix(data, "data: ") data = strings.TrimSuffix(data, "\r") - var cloudflareResponse StreamResponse - err := json.Unmarshal([]byte(data), &cloudflareResponse) + if data == "[DONE]" { + break + } + + var response openai.ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data), &response) if err != nil { logger.SysError("error unmarshalling stream response: " + err.Error()) continue } - - response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) - if response == nil { - continue + for _, v := range response.Choices { + v.Delta.Role = "assistant" + responseText += v.Delta.StringContent() } - - responseText += cloudflareResponse.Response response.Id = id - response.Model = responseModel - + response.Model = modelName err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) @@ -123,22 +91,25 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - var cloudflareResponse Response - err = json.Unmarshal(responseBody, &cloudflareResponse) + var response openai.TextResponse + err = json.Unmarshal(responseBody, &response) if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } - fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse) - fullTextResponse.Model = modelName - usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens) - fullTextResponse.Usage = *usage - fullTextResponse.Id = helper.GetResponseID(c) - jsonResponse, err := json.Marshal(fullTextResponse) + response.Model = modelName + var responseText string + for _, v := range response.Choices { + responseText += v.Message.Content.(string) + } + usage := openai.ResponseText2Usage(responseText, modelName, promptTokens) + response.Usage = *usage + response.Id = helper.GetResponseID(c) + jsonResponse, err := json.Marshal(response) if err != nil { return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) + _, _ = c.Writer.Write(jsonResponse) return nil, usage } diff --git a/relay/adaptor/cloudflare/model.go b/relay/adaptor/cloudflare/model.go index 0664ecd169..0d3bafe098 100644 --- a/relay/adaptor/cloudflare/model.go +++ b/relay/adaptor/cloudflare/model.go @@ -1,25 +1,13 @@ package cloudflare -type Request struct { - Lora string `json:"lora,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Prompt string `json:"prompt,omitempty"` - Raw bool `json:"raw,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` -} - -type Result struct { - Response string `json:"response"` -} +import "github.com/songquanpeng/one-api/relay/model" -type Response struct { - Result Result `json:"result"` - Success bool `json:"success"` - Errors []string `json:"errors"` - Messages []string `json:"messages"` -} - -type StreamResponse struct { - Response string `json:"response"` +type Request struct { + Messages []model.Message `json:"messages,omitempty"` + Lora string `json:"lora,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Prompt string `json:"prompt,omitempty"` + Raw bool `json:"raw,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` }