feat: cloudflare support native openai api #1596

Merged · 1 commit · Jul 6, 2024
19 changes: 17 additions & 2 deletions relay/adaptor/cloudflare/adaptor.go
@@ -10,6 +10,7 @@
 	"github.com/songquanpeng/one-api/relay/adaptor"
 	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/relaymode"
 )
 
 type Adaptor struct {
@@ -28,7 +29,14 @@
 }
 
 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
-	return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
+	switch meta.Mode {
+	case relaymode.ChatCompletions:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil
+	case relaymode.Embeddings:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil
+	default:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
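For reference, these are the three URL shapes the new `GetRequestURL` can produce. A minimal sketch with placeholder values; in the adaptor, `meta.BaseURL`, `meta.Config.UserID`, and `meta.ActualModelName` come from the channel configuration:

```go
package main

import "fmt"

func main() {
	// Placeholder values standing in for meta.BaseURL, meta.Config.UserID,
	// and meta.ActualModelName.
	baseURL := "https://api.cloudflare.com"
	userID := "0123456789abcdef"
	model := "@cf/meta/llama-3-8b-instruct"

	// relaymode.ChatCompletions → OpenAI-compatible chat endpoint
	fmt.Printf("%s/client/v4/accounts/%s/ai/v1/chat/completions\n", baseURL, userID)
	// relaymode.Embeddings → OpenAI-compatible embeddings endpoint
	fmt.Printf("%s/client/v4/accounts/%s/ai/v1/embeddings\n", baseURL, userID)
	// any other mode → legacy per-model Workers AI route
	fmt.Printf("%s/client/v4/accounts/%s/ai/run/%s\n", baseURL, userID, model)
}
```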
@@ -41,7 +49,14 @@
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return ConvertRequest(*request), nil
+	switch relayMode {
+	case relaymode.Completions:
+		return ConvertCompletionsRequest(*request), nil
+	case relaymode.ChatCompletions, relaymode.Embeddings:
+		return request, nil
+	default:
+		return nil, errors.New("not implemented")
+	}
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
87 changes: 29 additions & 58 deletions relay/adaptor/cloudflare/main.go
@@ -3,11 +3,13 @@
 import (
 	"bufio"
 	"encoding/json"
-	"github.com/songquanpeng/one-api/common/render"
 	"io"
 	"net/http"
 	"strings"
 
+	"github.com/songquanpeng/one-api/common/ctxkey"
+	"github.com/songquanpeng/one-api/common/render"
+
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/helper"
@@ -16,57 +18,23 @@
 	"github.com/songquanpeng/one-api/relay/model"
 )
 
-func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
-	var promptBuilder strings.Builder
-	for _, message := range textRequest.Messages {
-		promptBuilder.WriteString(message.StringContent())
-		promptBuilder.WriteString("\n") // append a newline to separate each message
-	}
-
+func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request {
+	p, _ := textRequest.Prompt.(string)
 	return &Request{
+		Prompt:      p,
 		MaxTokens:   textRequest.MaxTokens,
-		Prompt:      promptBuilder.String(),
 		Stream:      textRequest.Stream,
 		Temperature: textRequest.Temperature,
 	}
 }
 
-func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
-	choice := openai.TextResponseChoice{
-		Index: 0,
-		Message: model.Message{
-			Role:    "assistant",
-			Content: cloudflareResponse.Result.Response,
-		},
-		FinishReason: "stop",
-	}
-	fullTextResponse := openai.TextResponse{
-		Object:  "chat.completion",
-		Created: helper.GetTimestamp(),
-		Choices: []openai.TextResponseChoice{choice},
-	}
-	return &fullTextResponse
-}
-
-func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
-	var choice openai.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = cloudflareResponse.Response
-	choice.Delta.Role = "assistant"
-	openaiResponse := openai.ChatCompletionsStreamResponse{
-		Object:  "chat.completion.chunk",
-		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
-		Created: helper.GetTimestamp(),
-	}
-	return &openaiResponse
-}
-
 func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
 
 	common.SetEventStreamHeaders(c)
 	id := helper.GetResponseID(c)
-	responseModel := c.GetString("original_model")
+	responseModel := c.GetString(ctxkey.OriginalModel)
 	var responseText string
 
 	for scanner.Scan() {
@@ -77,22 +45,22 @@
 		data = strings.TrimPrefix(data, "data: ")
 		data = strings.TrimSuffix(data, "\r")
 
-		var cloudflareResponse StreamResponse
-		err := json.Unmarshal([]byte(data), &cloudflareResponse)
+		if data == "[DONE]" {
+			break
+		}
+
+		var response openai.ChatCompletionsStreamResponse
+		err := json.Unmarshal([]byte(data), &response)
 		if err != nil {
 			logger.SysError("error unmarshalling stream response: " + err.Error())
 			continue
 		}
 
-		response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
-		if response == nil {
-			continue
+		for _, v := range response.Choices {
+			v.Delta.Role = "assistant"
+			responseText += v.Delta.StringContent()
 		}
 
-		responseText += cloudflareResponse.Response
 		response.Id = id
-		response.Model = responseModel
+		response.Model = modelName
 
 		err = render.ObjectData(c, response)
 		if err != nil {
 			logger.SysError(err.Error())
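The rewritten loop consumes OpenAI-format SSE chunks directly instead of Cloudflare's old `{"response": "..."}` events. A self-contained sketch of the same parsing flow, with a pared-down struct standing in for `openai.ChatCompletionsStreamResponse`:

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// chunk keeps only the fields the stream loop reads from
// openai.ChatCompletionsStreamResponse.
type chunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

func main() {
	// Two OpenAI-style SSE lines followed by the [DONE] sentinel.
	stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n" +
		"data: {\"choices\":[{\"delta\":{\"content\":\"lo\"}}]}\n" +
		"data: [DONE]\n"

	scanner := bufio.NewScanner(strings.NewReader(stream))
	var responseText string
	for scanner.Scan() {
		data := strings.TrimPrefix(scanner.Text(), "data: ")
		if data == "[DONE]" {
			break // the OpenAI-compatible stream ends like OpenAI's
		}
		var c chunk
		if err := json.Unmarshal([]byte(data), &c); err != nil {
			continue // skip malformed chunks, as the handler does
		}
		for _, v := range c.Choices {
			responseText += v.Delta.Content
		}
	}
	fmt.Println(responseText) // "Hello"
}
```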
@@ -123,22 +91,25 @@
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	var cloudflareResponse Response
-	err = json.Unmarshal(responseBody, &cloudflareResponse)
+	var response openai.TextResponse
+	err = json.Unmarshal(responseBody, &response)
 	if err != nil {
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse)
-	fullTextResponse.Model = modelName
-	usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens)
-	fullTextResponse.Usage = *usage
-	fullTextResponse.Id = helper.GetResponseID(c)
-	jsonResponse, err := json.Marshal(fullTextResponse)
+	response.Model = modelName
+	var responseText string
+	for _, v := range response.Choices {
+		responseText += v.Message.Content.(string)
+	}
+	usage := openai.ResponseText2Usage(responseText, modelName, promptTokens)
+	response.Usage = *usage
+	response.Id = helper.GetResponseID(c)
+	jsonResponse, err := json.Marshal(response)
 	if err != nil {
 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
+	_, _ = c.Writer.Write(jsonResponse)
 	return nil, usage
 }
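One caveat in `Handler`: `v.Message.Content.(string)` is an unchecked type assertion on an `any` field, so a non-string content value (e.g. structured multimodal parts) would panic here. A hypothetical defensive variant, not part of this PR, would use the comma-ok form:

```go
package main

import "fmt"

// message mirrors the shape handled in Handler: Content is an any,
// so a bare .(string) assertion panics on non-string payloads.
type message struct {
	Content any
}

func main() {
	choices := []message{{Content: "hello"}, {Content: 42}} // 42: non-string input
	var responseText string
	for _, m := range choices {
		if s, ok := m.Content.(string); ok { // comma-ok avoids the panic
			responseText += s
		}
	}
	fmt.Println(responseText) // "hello"; the non-string value is skipped
}
```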
30 changes: 9 additions & 21 deletions relay/adaptor/cloudflare/model.go
@@ -1,25 +1,13 @@
 package cloudflare
 
-type Request struct {
-	Lora        string  `json:"lora,omitempty"`
-	MaxTokens   int     `json:"max_tokens,omitempty"`
-	Prompt      string  `json:"prompt,omitempty"`
-	Raw         bool    `json:"raw,omitempty"`
-	Stream      bool    `json:"stream,omitempty"`
-	Temperature float64 `json:"temperature,omitempty"`
-}
-
-type Result struct {
-	Response string `json:"response"`
-}
+import "github.com/songquanpeng/one-api/relay/model"
 
-type Response struct {
-	Result   Result   `json:"result"`
-	Success  bool     `json:"success"`
-	Errors   []string `json:"errors"`
-	Messages []string `json:"messages"`
-}
-
-type StreamResponse struct {
-	Response string `json:"response"`
+type Request struct {
+	Messages    []model.Message `json:"messages,omitempty"`
+	Lora        string          `json:"lora,omitempty"`
+	MaxTokens   int             `json:"max_tokens,omitempty"`
+	Prompt      string          `json:"prompt,omitempty"`
+	Raw         bool            `json:"raw,omitempty"`
+	Stream      bool            `json:"stream,omitempty"`
+	Temperature float64         `json:"temperature,omitempty"`
 }
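After this change the single `Request` type carries both the chat-style `messages` field and the completion-style `prompt`, all tagged `omitempty`, so whichever shape is unused stays off the wire. A quick sketch of the resulting JSON, with a pared-down stand-in for `model.Message`:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// msg is a simplified stand-in for model.Message; the real type
// carries more fields, but role/content are what matter here.
type msg struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type request struct {
	Messages    []msg   `json:"messages,omitempty"`
	MaxTokens   int     `json:"max_tokens,omitempty"`
	Prompt      string  `json:"prompt,omitempty"`
	Stream      bool    `json:"stream,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
}

func main() {
	// A chat-style request: prompt, stream, and temperature are zero-valued,
	// so omitempty drops them from the payload.
	b, _ := json.Marshal(request{
		Messages:  []msg{{Role: "user", Content: "Hi"}},
		MaxTokens: 256,
	})
	fmt.Println(string(b))
	// {"messages":[{"role":"user","content":"Hi"}],"max_tokens":256}
}
```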