Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: support Groq #107

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ const (
ChannelTypeDeepseek = 28
ChannelTypeMoonshot = 29
ChannelTypeMistral = 30
ChannelTypeGroq = 31
)

var ChannelBaseURLs = []string{
Expand Down Expand Up @@ -231,6 +232,7 @@ var ChannelBaseURLs = []string{
"https://api.deepseek.com", //28
"https://api.moonshot.cn", //29
"https://api.mistral.ai", //30
"https://api.groq.com/openai", //30
}

const (
Expand Down
8 changes: 8 additions & 0 deletions common/model-ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ func init() {
"mistral-medium-latest": {[]float64{1.35, 4.05}, ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens
"mistral-large-latest": {[]float64{4, 12}, ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens
"mistral-embed": {[]float64{0.05, 0.05}, ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens

// $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens
"llama2-70b-4096": {[]float64{0.35, 0.4}, ChannelTypeGroq},
// $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens
"llama2-7b-2048": {[]float64{0.05, 0.05}, ChannelTypeGroq},
"gemma-7b-it": {[]float64{0.05, 0.05}, ChannelTypeGroq},
// $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens
"mixtral-8x7b-32768": {[]float64{0.135, 0.135}, ChannelTypeGroq},
}

ModelRatio = make(map[string][]float64)
Expand Down
1 change: 1 addition & 0 deletions controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func init() {
common.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral",
common.ChannelTypeGroq: "Groq",
}
}

Expand Down
35 changes: 35 additions & 0 deletions providers/groq/base.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package groq

import (
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/providers/openai"
)

// 定义供应商工厂
type GroqProviderFactory struct{}

// 创建 GroqProvider
func (f GroqProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &GroqProvider{
OpenAIProvider: openai.OpenAIProvider{
BaseProvider: base.BaseProvider{
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, openai.RequestErrorHandle),
},
},
}
}

func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://api.groq.com/openai",
ChatCompletions: "/v1/chat/completions",
}
}

type GroqProvider struct {
openai.OpenAIProvider
}
82 changes: 82 additions & 0 deletions providers/groq/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package groq

import (
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/providers/openai"
"one-api/types"
)

func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
p.getChatRequestBody(request)

req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()

response := &openai.OpenAIProviderChatResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errWithCode
}

// 检测是否错误
openaiErr := openai.ErrorHandle(&response.OpenAIErrorResponse)
if openaiErr != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *openaiErr,
StatusCode: http.StatusBadRequest,
}
return nil, errWithCode
}

*p.Usage = *response.Usage

return &response.ChatCompletionResponse, nil
}

func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
p.getChatRequestBody(request)
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()

// 发送请求
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}

chatHandler := openai.OpenAIStreamHandler{
Usage: p.Usage,
ModelName: request.Model,
}

return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
}

// 获取聊天请求体
func (p *GroqProvider) getChatRequestBody(request *types.ChatCompletionRequest) {
if request.Tools != nil {
request.Tools = nil
}

if request.ToolChoice != nil {
request.ToolChoice = nil
}

if request.ResponseFormat != nil {
request.ResponseFormat = nil
}

if request.N > 1 {
request.N = 1
}

}
2 changes: 2 additions & 0 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"one-api/providers/closeai"
"one-api/providers/deepseek"
"one-api/providers/gemini"
"one-api/providers/groq"
"one-api/providers/minimax"
"one-api/providers/mistral"
"one-api/providers/openai"
Expand Down Expand Up @@ -60,6 +61,7 @@ func init() {
providerFactories[common.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{}
providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{}
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}

}

Expand Down
3 changes: 2 additions & 1 deletion types/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ type ChatCompletionRequest struct {
Seed *int `json:"seed,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
LogitBias any `json:"logit_bias,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
LogProbs *bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
User string `json:"user,omitempty"`
Functions []*ChatCompletionFunction `json:"functions,omitempty"`
Expand Down Expand Up @@ -172,6 +172,7 @@ type ChatCompletionTool struct {
type ChatCompletionChoice struct {
Index int `json:"index"`
Message ChatCompletionMessage `json:"message"`
LogProbs any `json:"logprobs,omitempty"`
FinishReason any `json:"finish_reason,omitempty"`
ContentFilterResults any `json:"content_filter_results,omitempty"`
FinishDetails any `json:"finish_details,omitempty"`
Expand Down
6 changes: 6 additions & 0 deletions web/src/constants/ChannelConstants.js
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ export const CHANNEL_OPTIONS = {
value: 30,
color: 'orange'
},
31: {
key: 31,
text: 'Groq',
value: 31,
color: 'primary'
},
24: {
key: 24,
text: 'Azure Speech',
Expand Down
7 changes: 7 additions & 0 deletions web/src/views/Channel/type/Config.js
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,13 @@ const typeConfig = {
test_model: 'open-mistral-7b'
},
modelGroup: 'Mistral'
},
31: {
input: {
models: ['llama2-7b-2048', 'llama2-70b-4096', 'mixtral-8x7b-32768', 'gemma-7b-it'],
test_model: 'llama2-7b-2048'
},
modelGroup: 'Groq'
}
};

Expand Down
Loading