Skip to content

Commit

Permalink
feat(chat): add param response_format
Browse files Browse the repository at this point in the history
Closes #22
  • Loading branch information
northes committed Aug 2, 2024
1 parent ef6c6e3 commit f00fd08
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 17 deletions.
29 changes: 17 additions & 12 deletions api_chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,25 @@ type ChatCompletionsMessage struct {
}

type ChatCompletionsRequest struct {
Messages []*ChatCompletionsMessage `json:"messages"`
Model ChatCompletionsModelID `json:"model"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
N int `json:"n"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
Stop []string `json:"stop"`
Stream bool `json:"stream"`
Messages []*ChatCompletionsMessage `json:"messages"`
Model ChatCompletionsModelID `json:"model"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
N int `json:"n"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
ResponseFormat *ChatCompletionsRequestResponseFormat `json:"response_format"`
Stop []string `json:"stop"`
Stream bool `json:"stream"`
// When you use a tool, you need to define it
Tools []*ChatCompletionsTool `json:"tools,omitempty"`
}

type ChatCompletionsRequestResponseFormat struct {
Type ChatCompletionsResponseFormatType `json:"type"`
}

type ChatCompletionsResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Expand Down Expand Up @@ -162,7 +167,7 @@ func (c *ChatCompletionsStreamResponse) Receive() <-chan *ChatCompletionsStreamR
for {
line, err := reader.ReadBytes('\n')
rr := ChatCompletionsStreamResponseReceive{}
//slog.Debug("next line", string(line))
// slog.Debug("next line", string(line))
if err != nil {
if err == io.EOF {
c.sendWithFinish(receiveCh)
Expand All @@ -175,7 +180,7 @@ func (c *ChatCompletionsStreamResponse) Receive() <-chan *ChatCompletionsStreamR
prefix := []byte("data: ")

if !bytes.HasPrefix(line, prefix) {
//slog.Debug("no hava prefix,continue", slog.String("line", string(line)))
// slog.Debug("no hava prefix,continue", slog.String("line", string(line)))
continue
}

Expand Down
8 changes: 8 additions & 0 deletions chat_completions_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type IChatCompletionsBuilder interface {
SetTool(tool *ChatCompletionsTool) IChatCompletionsBuilder
SetTools(tools []*ChatCompletionsTool) IChatCompletionsBuilder
SetContextCacheContent(content *ContextCacheContent) IChatCompletionsBuilder
SetResponseFormat(format ChatCompletionsResponseFormatType) IChatCompletionsBuilder

ToRequest() *ChatCompletionsRequest
}
Expand Down Expand Up @@ -214,6 +215,13 @@ func (c *chatCompletionsBuilder) SetContextCacheContent(content *ContextCacheCon
return c
}

func (c *chatCompletionsBuilder) SetResponseFormat(format ChatCompletionsResponseFormatType) IChatCompletionsBuilder {
c.req.ResponseFormat = &ChatCompletionsRequestResponseFormat{
Type: format,
}
return c
}

// ToRequest returns the ChatCompletionsRequest
func (c *chatCompletionsBuilder) ToRequest() *ChatCompletionsRequest {
return c.req
Expand Down
16 changes: 12 additions & 4 deletions chat_completions_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package moonshot_test
import (
"testing"

"github.com/northes/go-moonshot"
"github.com/stretchr/testify/require"

"github.com/northes/go-moonshot"
)

func TestNewChatCompletionsBuilder(t *testing.T) {
Expand All @@ -21,7 +22,7 @@ func TestNewChatCompletionsBuilder(t *testing.T) {
functionName2 = "function2"
)

wantedReq := &moonshot.ChatCompletionsRequest{
var wantedReq = &moonshot.ChatCompletionsRequest{
Messages: []*moonshot.ChatCompletionsMessage{
{
Role: moonshot.RoleContextCache,
Expand Down Expand Up @@ -51,8 +52,11 @@ func TestNewChatCompletionsBuilder(t *testing.T) {
N: 1,
PresencePenalty: 1.2,
FrequencyPenalty: 1.5,
Stop: []string{"结束"},
Stream: true,
ResponseFormat: &moonshot.ChatCompletionsRequestResponseFormat{
Type: moonshot.ChatCompletionsResponseFormatJSONObject,
},
Stop: []string{"结束"},
Stream: true,
Tools: []*moonshot.ChatCompletionsTool{{
Type: moonshot.ChatCompletionsToolTypeFunction,
Function: &moonshot.ChatCompletionsToolFunction{
Expand Down Expand Up @@ -91,6 +95,7 @@ func TestNewChatCompletionsBuilder(t *testing.T) {
SetN(1).
SetPresencePenalty(1.2).
SetFrequencyPenalty(1.5).
SetResponseFormat(moonshot.ChatCompletionsResponseFormatJSONObject).
SetStop([]string{"结束"}).
SetStream(true).
SetTool(&moonshot.ChatCompletionsTool{
Expand Down Expand Up @@ -126,4 +131,7 @@ func TestNewChatCompletionsBuilder(t *testing.T) {

builder2.SetPresencePenalty(2)
tt.NotEqual(wantedReq, builder2.ToRequest())

builder2.SetResponseFormat(moonshot.ChatCompletionsResponseFormatText)
tt.NotEqual(wantedReq, builder2.ToRequest())
}
11 changes: 11 additions & 0 deletions enum_chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,14 @@ const (
func (c ChatCompletionsParametersType) String() string {
return string(c)
}

type ChatCompletionsResponseFormatType string

const (
ChatCompletionsResponseFormatJSONObject ChatCompletionsResponseFormatType = "json_object"
ChatCompletionsResponseFormatText ChatCompletionsResponseFormatType = "text"
)

func (c ChatCompletionsResponseFormatType) String() string {
return string(c)
}
6 changes: 5 additions & 1 deletion enum_chat_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package moonshot_test
import (
"testing"

"github.com/northes/go-moonshot"
"github.com/stretchr/testify/require"

"github.com/northes/go-moonshot"
)

func TestEnumChatCompletions(t *testing.T) {
Expand All @@ -25,4 +26,7 @@ func TestEnumChatCompletions(t *testing.T) {
tt.EqualValues(moonshot.ChatCompletionsToolTypeFunction, moonshot.ChatCompletionsToolTypeFunction.String())

tt.EqualValues(moonshot.ChatCompletionsParametersTypeObject, moonshot.ChatCompletionsParametersTypeObject.String())

tt.EqualValues(moonshot.ChatCompletionsResponseFormatJSONObject, moonshot.ChatCompletionsResponseFormatJSONObject.String())
tt.EqualValues(moonshot.ChatCompletionsResponseFormatText, moonshot.ChatCompletionsResponseFormatText.String())
}

0 comments on commit f00fd08

Please sign in to comment.