diff --git a/api_chat_completions.go b/api_chat_completions.go index 9259e63..e4249a2 100644 --- a/api_chat_completions.go +++ b/api_chat_completions.go @@ -39,20 +39,25 @@ type ChatCompletionsMessage struct { } type ChatCompletionsRequest struct { - Messages []*ChatCompletionsMessage `json:"messages"` - Model ChatCompletionsModelID `json:"model"` - MaxTokens int `json:"max_tokens"` - Temperature float64 `json:"temperature"` - TopP float64 `json:"top_p"` - N int `json:"n"` - PresencePenalty float64 `json:"presence_penalty"` - FrequencyPenalty float64 `json:"frequency_penalty"` - Stop []string `json:"stop"` - Stream bool `json:"stream"` + Messages []*ChatCompletionsMessage `json:"messages"` + Model ChatCompletionsModelID `json:"model"` + MaxTokens int `json:"max_tokens"` + Temperature float64 `json:"temperature"` + TopP float64 `json:"top_p"` + N int `json:"n"` + PresencePenalty float64 `json:"presence_penalty"` + FrequencyPenalty float64 `json:"frequency_penalty"` + ResponseFormat *ChatCompletionsRequestResponseFormat `json:"response_format"` + Stop []string `json:"stop"` + Stream bool `json:"stream"` // When you use a tool, you need to define it Tools []*ChatCompletionsTool `json:"tools,omitempty"` } +type ChatCompletionsRequestResponseFormat struct { + Type ChatCompletionsResponseFormatType `json:"type"` +} + type ChatCompletionsResponse struct { ID string `json:"id"` Object string `json:"object"` @@ -162,7 +167,7 @@ func (c *ChatCompletionsStreamResponse) Receive() <-chan *ChatCompletionsStreamR for { line, err := reader.ReadBytes('\n') rr := ChatCompletionsStreamResponseReceive{} - //slog.Debug("next line", string(line)) + // slog.Debug("next line", string(line)) if err != nil { if err == io.EOF { c.sendWithFinish(receiveCh) @@ -175,7 +180,7 @@ func (c *ChatCompletionsStreamResponse) Receive() <-chan *ChatCompletionsStreamR prefix := []byte("data: ") if !bytes.HasPrefix(line, prefix) { - //slog.Debug("no hava prefix,continue", slog.String("line", string(line))) + // slog.Debug("no hava prefix,continue", slog.String("line", string(line))) continue } diff --git a/chat_completions_builder.go b/chat_completions_builder.go index b99b52f..08a058a 100644 --- a/chat_completions_builder.go +++ b/chat_completions_builder.go @@ -21,6 +21,7 @@ type IChatCompletionsBuilder interface { SetTool(tool *ChatCompletionsTool) IChatCompletionsBuilder SetTools(tools []*ChatCompletionsTool) IChatCompletionsBuilder SetContextCacheContent(content *ContextCacheContent) IChatCompletionsBuilder + SetResponseFormat(format ChatCompletionsResponseFormatType) IChatCompletionsBuilder ToRequest() *ChatCompletionsRequest } @@ -214,6 +215,13 @@ func (c *chatCompletionsBuilder) SetContextCacheContent(content *ContextCacheCon return c } +func (c *chatCompletionsBuilder) SetResponseFormat(format ChatCompletionsResponseFormatType) IChatCompletionsBuilder { + c.req.ResponseFormat = &ChatCompletionsRequestResponseFormat{ + Type: format, + } + return c +} + // ToRequest returns the ChatCompletionsRequest func (c *chatCompletionsBuilder) ToRequest() *ChatCompletionsRequest { return c.req diff --git a/chat_completions_builder_test.go b/chat_completions_builder_test.go index cd9b053..d406273 100644 --- a/chat_completions_builder_test.go +++ b/chat_completions_builder_test.go @@ -3,8 +3,9 @@ package moonshot_test import ( "testing" - "github.com/northes/go-moonshot" "github.com/stretchr/testify/require" + + "github.com/northes/go-moonshot" ) func TestNewChatCompletionsBuilder(t *testing.T) { @@ -21,7 +22,7 @@ func TestNewChatCompletionsBuilder(t *testing.T) { functionName2 = "function2" ) - wantedReq := &moonshot.ChatCompletionsRequest{ + var wantedReq = &moonshot.ChatCompletionsRequest{ Messages: []*moonshot.ChatCompletionsMessage{ { Role: moonshot.RoleContextCache, @@ -51,8 +52,11 @@ func TestNewChatCompletionsBuilder(t *testing.T) { N: 1, PresencePenalty: 1.2, FrequencyPenalty: 1.5, - Stop: []string{"结束"}, - Stream: true, + ResponseFormat: &moonshot.ChatCompletionsRequestResponseFormat{ + Type: moonshot.ChatCompletionsResponseFormatJSONObject, + }, + Stop: []string{"结束"}, + Stream: true, Tools: []*moonshot.ChatCompletionsTool{{ Type: moonshot.ChatCompletionsToolTypeFunction, Function: &moonshot.ChatCompletionsToolFunction{ @@ -91,6 +95,7 @@ func TestNewChatCompletionsBuilder(t *testing.T) { SetN(1). SetPresencePenalty(1.2). SetFrequencyPenalty(1.5). + SetResponseFormat(moonshot.ChatCompletionsResponseFormatJSONObject). SetStop([]string{"结束"}). SetStream(true). SetTool(&moonshot.ChatCompletionsTool{ @@ -126,4 +131,7 @@ func TestNewChatCompletionsBuilder(t *testing.T) { builder2.SetPresencePenalty(2) tt.NotEqual(wantedReq, builder2.ToRequest()) + + builder2.SetResponseFormat(moonshot.ChatCompletionsResponseFormatText) + tt.NotEqual(wantedReq, builder2.ToRequest()) } diff --git a/enum_chat_completions.go b/enum_chat_completions.go index 21fff69..fd34756 100644 --- a/enum_chat_completions.go +++ b/enum_chat_completions.go @@ -67,3 +67,14 @@ const ( func (c ChatCompletionsParametersType) String() string { return string(c) } + +type ChatCompletionsResponseFormatType string + +const ( + ChatCompletionsResponseFormatJSONObject ChatCompletionsResponseFormatType = "json_object" + ChatCompletionsResponseFormatText ChatCompletionsResponseFormatType = "text" +) + +func (c ChatCompletionsResponseFormatType) String() string { + return string(c) +} diff --git a/enum_chat_completions_test.go b/enum_chat_completions_test.go index b31232f..32d953a 100644 --- a/enum_chat_completions_test.go +++ b/enum_chat_completions_test.go @@ -3,8 +3,9 @@ package moonshot_test import ( "testing" - "github.com/northes/go-moonshot" "github.com/stretchr/testify/require" + + "github.com/northes/go-moonshot" ) func TestEnumChatCompletions(t *testing.T) { @@ -25,4 +26,7 @@ func TestEnumChatCompletions(t *testing.T) { tt.EqualValues(moonshot.ChatCompletionsToolTypeFunction, moonshot.ChatCompletionsToolTypeFunction.String()) tt.EqualValues(moonshot.ChatCompletionsParametersTypeObject, moonshot.ChatCompletionsParametersTypeObject.String()) + + tt.EqualValues(moonshot.ChatCompletionsResponseFormatJSONObject, moonshot.ChatCompletionsResponseFormatJSONObject.String()) + tt.EqualValues(moonshot.ChatCompletionsResponseFormatText, moonshot.ChatCompletionsResponseFormatText.String()) }