e2e: add agent-tool test #285

Draft - wants to merge 13 commits into base: main
2 changes: 2 additions & 0 deletions internal/extproc/translator/openai_awsbedrock.go
@@ -588,6 +588,8 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(respHeaders
if toolCall := o.bedrockToolUseToOpenAICalls(output.ToolUse); toolCall != nil {
choice.Message.ToolCalls = []openai.ChatCompletionMessageToolCallParam{*toolCall}
}
// TODO: merge the choice message into the existing choice
// (update in place, don't append).
openAIResp.Choices = append(openAIResp.Choices, choice)
}
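
A possible shape for the TODO above: fold a block's tool calls into the already-emitted choice instead of appending a second choice. The sketch below uses simplified stand-in types rather than the translator's real response structs, so read it as an assumption about the intended merge, not the PR's implementation.

package main

import "fmt"

// Simplified stand-ins for the translator's response types (assumed shapes;
// only the fields the merge logic needs).
type toolCall struct{ Name, Arguments string }

type message struct {
	Content   string
	ToolCalls []toolCall
}

type choice struct{ Message message }

// mergeChoice folds a newly built choice's tool calls (and content, when the
// existing choice has none) into the last choice, appending a fresh choice
// only when the response has none yet - i.e. "update, don't append".
func mergeChoice(choices []choice, next choice) []choice {
	if len(choices) == 0 {
		return append(choices, next)
	}
	last := &choices[len(choices)-1]
	last.Message.ToolCalls = append(last.Message.ToolCalls, next.Message.ToolCalls...)
	if last.Message.Content == "" {
		last.Message.Content = next.Message.Content
	}
	return choices
}

func main() {
	choices := []choice{{Message: message{Content: "Let me check the weather."}}}
	choices = mergeChoice(choices, choice{Message: message{
		ToolCalls: []toolCall{{Name: "get_weather", Arguments: `{"location":"New York City"}`}},
	}})
	fmt.Printf("%d choice(s), %d tool call(s)\n", len(choices), len(choices[0].Message.ToolCalls))
}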

124 changes: 98 additions & 26 deletions tests/extproc/real_providers_test.go
@@ -181,34 +181,106 @@
})
}
})

t.Run("Bedrock calls tool get_weather function", func(t *testing.T) {
cc.maybeSkip(t, requiredCredentialAWS)

t.Run("Bedrock uses tool in response", func(t *testing.T) {
client := openai.NewClient(option.WithBaseURL(listenerAddress + "/v1/"))
require.Eventually(t, func() bool {
chatCompletion, err := client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage("What is the weather like in Paris today?"),
}),
Tools: openai.F([]openai.ChatCompletionToolParam{
{
Type: openai.F(openai.ChatCompletionToolTypeFunction),
Function: openai.F(openai.FunctionDefinitionParam{
Name: openai.String("get_weather"),
Description: openai.String("Get weather at the given location"),
Parameters: openai.F(openai.FunctionParameters{
"type": "object",
"properties": map[string]interface{}{
"location": map[string]string{
"type": "string",
},
},
"required": []string{"location"},
}),
for _, tc := range []struct {
testCaseName,
modelName string
}{
{testCaseName: "aws-bedrock", modelName: "us.anthropic.claude-3-5-sonnet-20240620-v1:0"}, // This will go to "aws-bedrock" using credentials file.
} {
t.Run(tc.modelName, func(t *testing.T) {
require.Eventually(t, func() bool {
// Step 1: Initial tool call request
question := "What is the weather in New York City?"
params := openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage(question),
}),
},
}),
Model: openai.F("us.anthropic.claude-3-5-sonnet-20240620-v1:0"),
Tools: openai.F([]openai.ChatCompletionToolParam{
{
Type: openai.F(openai.ChatCompletionToolTypeFunction),
Function: openai.F(openai.FunctionDefinitionParam{
Name: openai.String("get_weather"),
Description: openai.String("Get weather at the given location"),
Parameters: openai.F(openai.FunctionParameters{
"type": "object",
"properties": map[string]interface{}{
"location": map[string]string{
"type": "string",
},
},
"required": []string{"location"},
}),
}),
},
}),
// TODO: check if we should seed.
Seed: openai.Int(0),
Model: openai.F(tc.modelName),
}
completion, err := client.Chat.Completions.New(context.Background(), params)
if err != nil {
t.Logf("error: %v", err)
return false
}
// Step 2: Verify tool call
// TODO: remove after test done
returnsToolCall := false
for _, choice := range completion.Choices {
t.Logf("choice content: %s", choice.Message.Content)
t.Logf("finish reason: %s", choice.FinishReason)
t.Logf("choice toolcall: %v", choice.Message.ToolCalls)
if choice.FinishReason == openai.ChatCompletionChoicesFinishReasonToolCalls {
returnsToolCall = true
}
}
if returnsToolCall == false {
t.Logf("Tool call not returned")
return false
}
toolCalls := completion.Choices[0].Message.ToolCalls
if len(toolCalls) == 0 {
t.Logf("Expected tool call from completion result but got none")
return false
}
// Step 3: Simulate the tool returning a response, add the tool response to the params, and check the second response
params.Messages.Value = append(params.Messages.Value, completion.Choices[0].Message)
getWeatherCalled := false
for _, toolCall := range toolCalls {
if toolCall.Function.Name == "get_weather" {
getWeatherCalled = true
// Extract the location from the function call arguments
var args map[string]interface{}
if argErr := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); argErr != nil {
panic(argErr)
}
location := args["location"].(string)
if location != "New York City" {
t.Logf("Expected location to be New York City but got %s", location)
}
// Simulate getting weather data
weatherData := "Sunny, 25°C"
params.Messages.Value = append(params.Messages.Value, openai.ToolMessage(toolCall.ID, weatherData))
}
}
if getWeatherCalled == false {
t.Logf("get_weather tool not specified in chat completion response")
return false
}

secondChatCompletion, err := client.Chat.Completions.New(context.Background(), params)
if err != nil {
t.Logf("error during second response: %v", err)
return false
}

// Step 4: Verify that the second response is correct
completionResult := secondChatCompletion.Choices[0].Message.Content
t.Logf("content of completion response using tool: %s", secondChatCompletion.Choices[0].Message.Content)
return completionResult == "The weather in Paris is currently sunny and 25°C."
}, 30*time.Second, 2*time.Second)
})
if err != nil {
t.Logf("error: %v", err)
Expand All @@ -224,7 +296,7 @@
}
}
return returnsToolCall
}, 30*time.Second, 2*time.Second)

})
}

@@ -252,17 +324,17 @@
}

// maybeSkip skips the test if the required credentials are not set.
func (c credentialsContext) maybeSkip(t *testing.T, required requiredCredential) {

if required&requiredCredentialOpenAI != 0 && !c.openAIValid {
t.Skip("skipping test as OpenAI API key is not set in TEST_OPENAI_API_KEY")
}
if required&requiredCredentialAWS != 0 && !c.awsValid {
t.Skip("skipping test as AWS credentials are not set in TEST_AWS_ACCESS_KEY_ID and TEST_AWS_SECRET_ACCESS_KEY")
}
}
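
Assuming the requiredCredential values are bit flags (the checks above use bitwise AND), a test that needs both providers would presumably combine them, e.g.:

// Hypothetical usage: skip unless both OpenAI and AWS credentials are set.
cc.maybeSkip(t, requiredCredentialOpenAI|requiredCredentialAWS)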

// requireNewCredentialsContext creates a new credential context for the tests from the environment variables.
func requireNewCredentialsContext(t *testing.T) (ctx credentialsContext) {

// Set up credential file for OpenAI.
openAIAPIKey := os.Getenv("TEST_OPENAI_API_KEY")

@@ -297,4 +369,4 @@
openAIAPIKeyFilePath: openAIAPIKeyFilePath,
awsFilePath: awsFilePath,
}
}
