Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow more input types to functions, fix tests #377

Merged
merged 23 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,23 @@ type ChatCompletionRequest struct {
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
LogitBias map[string]int `json:"logit_bias,omitempty"`
User string `json:"user,omitempty"`
Functions []*FunctionDefine `json:"functions,omitempty"`
FunctionCall string `json:"function_call,omitempty"`
Functions []FunctionDefinition `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
}

type FunctionDefine struct {
type FunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
// it's required in function call
Parameters *FunctionParams `json:"parameters"`
// Parameters is an object describing the function.
// You can pass a raw byte array describing the schema,
// or you can pass in a struct which serializes to the proper JSONSchema.
// The JSONSchemaDefinition struct is provided for convenience, but you should
// consider another specialized library for more complex schemas.
Parameters any `json:"parameters"`
}

type FunctionParams struct {
// the Type must be JSONSchemaTypeObject
Type JSONSchemaType `json:"type"`
Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
// Deprecated: use FunctionDefinition instead.
type FunctionDefine = FunctionDefinition

type JSONSchemaType string

Expand All @@ -83,22 +83,26 @@ const (
JSONSchemaTypeBoolean JSONSchemaType = "boolean"
)

// JSONSchemaDefine is a struct for JSON Schema.
type JSONSchemaDefine struct {
// JSONSchemaDefinition is a struct for JSON Schema.
// It is fairly limited and you may have better luck using a third-party library.
type JSONSchemaDefinition struct {
// Type is a type of JSON Schema.
Type JSONSchemaType `json:"type,omitempty"`
// Description is a description of JSON Schema.
Description string `json:"description,omitempty"`
// Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString.
Enum []string `json:"enum,omitempty"`
// Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject.
Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"`
Properties map[string]JSONSchemaDefinition `json:"properties,omitempty"`
// Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject.
Required []string `json:"required,omitempty"`
// Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray.
Items *JSONSchemaDefine `json:"items,omitempty"`
Items *JSONSchemaDefinition `json:"items,omitempty"`
}

// Deprecated: use JSONSchemaDefinition instead.
type JSONSchemaDefine = JSONSchemaDefinition

type FinishReason string

const (
Expand Down
157 changes: 154 additions & 3 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,130 @@ func TestChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error")
}

// TestChatCompletionsFunctions tests including a function call.
func TestChatCompletionsFunctions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
t.Run("bytes", func(t *testing.T) {
//nolint:lll
msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`)
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Functions: []FunctionDefine{{
Name: "test",
Parameters: &msg,
}},
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
t.Run("struct", func(t *testing.T) {
type testMessage struct {
Count int `json:"count"`
Words []string `json:"words"`
}
msg := testMessage{
Count: 2,
Words: []string{"hello", "world"},
}
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Functions: []FunctionDefinition{{
Name: "test",
Parameters: &msg,
}},
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
t.Run("JSONSchemaDefine", func(t *testing.T) {
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Functions: []FunctionDefinition{{
Name: "test",
Parameters: &JSONSchemaDefinition{
Type: JSONSchemaTypeObject,
Properties: map[string]JSONSchemaDefinition{
"count": {
Type: JSONSchemaTypeNumber,
Description: "total number of words in sentence",
},
"words": {
Type: JSONSchemaTypeArray,
Description: "list of words in sentence",
Items: &JSONSchemaDefinition{
Type: JSONSchemaTypeString,
},
},
"enumTest": {
Type: JSONSchemaTypeString,
Enum: []string{"hello", "world"},
},
},
},
}},
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
t.Run("JSONSchemaDefineWithFunctionDefine", func(t *testing.T) {
// this is a compatibility check
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Functions: []FunctionDefine{{
Name: "test",
Parameters: &JSONSchemaDefine{
Type: JSONSchemaTypeObject,
Properties: map[string]JSONSchemaDefine{
"count": {
Type: JSONSchemaTypeNumber,
Description: "total number of words in sentence",
},
"words": {
Type: JSONSchemaTypeArray,
Description: "list of words in sentence",
Items: &JSONSchemaDefine{
Type: JSONSchemaTypeString,
},
},
"enumTest": {
Type: JSONSchemaTypeString,
Enum: []string{"hello", "world"},
},
},
},
}},
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
}

func TestAzureChatCompletions(t *testing.T) {
client, server, teardown := setupAzureTestServer()
defer teardown()
Expand Down Expand Up @@ -109,7 +233,34 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
Model: completionReq.Model,
}
// create completions
for i := 0; i < completionReq.N; i++ {
n := completionReq.N
if n == 0 {
n = 1
}
for i := 0; i < n; i++ {
// if there are functions, include them
if len(completionReq.Functions) > 0 {
var fcb []byte
b := completionReq.Functions[0].Parameters
fcb, err = json.Marshal(b)
if err != nil {
http.Error(w, "could not marshal function parameters", http.StatusInternalServerError)
return
}

res.Choices = append(res.Choices, ChatCompletionChoice{
Message: ChatCompletionMessage{
Role: ChatMessageRoleFunction,
// this is valid json so it should be fine
FunctionCall: &FunctionCall{
Name: completionReq.Functions[0].Name,
Arguments: string(fcb),
},
},
Index: i,
})
continue
}
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)

Expand All @@ -121,8 +272,8 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
Index: i,
})
}
inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N
completionTokens := completionReq.MaxTokens * completionReq.N
inputTokens := numTokens(completionReq.Messages[0].Content) * n
completionTokens := completionReq.MaxTokens * n
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
Expand Down
10 changes: 7 additions & 3 deletions completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
Model: completionReq.Model,
}
// create completions
for i := 0; i < completionReq.N; i++ {
n := completionReq.N
if n == 0 {
n = 1
}
for i := 0; i < n; i++ {
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)
if completionReq.Echo {
Expand All @@ -94,8 +98,8 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
Index: i,
})
}
inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N
completionTokens := completionReq.MaxTokens * completionReq.N
inputTokens := numTokens(completionReq.Prompt.(string)) * n
completionTokens := completionReq.MaxTokens * n
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
Expand Down