diff --git a/internal/extproc/translator/openai_awsbedrock.go b/internal/extproc/translator/openai_awsbedrock.go
index 5e562cb7..7113fde4 100644
--- a/internal/extproc/translator/openai_awsbedrock.go
+++ b/internal/extproc/translator/openai_awsbedrock.go
@@ -482,6 +482,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseError(respHeaders
 	} else {
 		var buf []byte
 		buf, err = io.ReadAll(body)
+		fmt.Printf("\nprinting body from ResponseError:\n %v\n", string(buf))
 		if err != nil {
 			return nil, nil, fmt.Errorf("failed to read error body: %w", err)
 		}
@@ -549,6 +550,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(respHeaders
 			mut.Body = append(mut.Body, oaiEventBytes...)
 			mut.Body = append(mut.Body, []byte("\n\n")...)
 		}
+		fmt.Printf("\nprinting mut.Body %v", string(mut.Body))

 		if endOfStream {
 			mut.Body = append(mut.Body, []byte("data: [DONE]\n")...)
@@ -560,6 +562,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(respHeaders
 	if err = json.NewDecoder(body).Decode(&bedrockResp); err != nil {
 		return nil, nil, tokenUsage, fmt.Errorf("failed to unmarshal body: %w", err)
 	}
+	fmt.Printf("\nbedrock output message from converse: %v\n", len(bedrockResp.Output.Message.Content))

 	openAIResp := openai.ChatCompletionResponse{
 		Object: "chat.completion",
@@ -579,6 +582,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(respHeaders
 		}
 	}
 	for i, output := range bedrockResp.Output.Message.Content {
+		fmt.Printf("\nbedrock output message: i=%v, text=%v, toolResult=%v, toolUse=%v\n", i, output.Text, output.ToolResult, output.ToolUse) // output.Text may be nil, so print the pointer rather than dereferencing it.
 		choice := openai.ChatCompletionResponseChoice{
 			Index: (int64)(i),
 			Message: openai.ChatCompletionResponseChoiceMessage{
@@ -590,6 +594,8 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(respHeaders
 		if toolCall := o.bedrockToolUseToOpenAICalls(output.ToolUse); toolCall != nil {
 			choice.Message.ToolCalls = []openai.ChatCompletionMessageToolCallParam{*toolCall}
 		}
+		// TODO: merge the content blocks into a single choice message instead of
+		// appending one choice per block (one possible approach is sketched after this diff).
 		openAIResp.Choices = append(openAIResp.Choices, choice)
 	}
diff --git a/tests/extproc/real_providers_test.go b/tests/extproc/real_providers_test.go
index 32eae0c2..beb26049 100644
--- a/tests/extproc/real_providers_test.go
+++ b/tests/extproc/real_providers_test.go
@@ -11,13 +11,14 @@ import (
 	"bufio"
 	"bytes"
 	"cmp"
+	"context"
 	"encoding/json"
 	"fmt"
 	"os"
 	"testing"
 	"time"

-	"github.com/openai/openai-go"
+	openai "github.com/openai/openai-go"
 	"github.com/openai/openai-go/option"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -136,100 +137,165 @@ func TestWithRealProviders(t *testing.T) {
 		}, 30*time.Second, 2*time.Second)
 	})

-	t.Run("streaming", func(t *testing.T) {
+	//t.Run("streaming", func(t *testing.T) {
+	//	client := openai.NewClient(option.WithBaseURL(listenerAddress + "/v1/"))
+	//	for _, tc := range []realProvidersTestCase{
+	//		{name: "openai", modelName: "gpt-4o-mini", required: requiredCredentialOpenAI},
+	//		{name: "aws-bedrock", modelName: "us.meta.llama3-2-1b-instruct-v1:0", required: requiredCredentialAWS},
+	//	} {
+	//		t.Run(tc.name, func(t *testing.T) {
+	//			cc.maybeSkip(t, tc.required)
+	//			require.Eventually(t, func() bool {
+	//				stream := client.Chat.Completions.NewStreaming(t.Context(), openai.ChatCompletionNewParams{
+	//					Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
+	//						openai.UserMessage("Say this is a test"),
+	//					}),
+	//					Model: openai.F(tc.modelName),
+	//				})
+	//				defer func() {
+	//					_ = stream.Close()
+	//				}()
+	//
+	//				acc := openai.ChatCompletionAccumulator{}
+	//
+	//				for stream.Next() {
+	//					chunk := stream.Current()
+	//					if !acc.AddChunk(chunk) {
+	//						t.Log("error adding chunk")
+	//						return false
+	//					}
+	//				}
+	//
+	//				if err := stream.Err(); err != nil {
+	//					t.Logf("error: %v", err)
+	//					return false
+	//				}
+	//
+	//				nonEmptyCompletion := false
+	//				for _, choice := range acc.Choices {
+	//					t.Logf("choice: %s", choice.Message.Content)
+	//					if choice.Message.Content != "" {
+	//						nonEmptyCompletion = true
+	//					}
+	//				}
+	//				if !nonEmptyCompletion {
+	//					// Log the whole response for debugging.
+	//					t.Logf("response: %+v", acc)
+	//				}
+	//				return nonEmptyCompletion
+	//			}, 30*time.Second, 2*time.Second)
+	//		})
+	//	}
+	//})
+
+	t.Run("Bedrock uses tool in response", func(t *testing.T) {
+		cc.maybeSkip(t, requiredCredentialAWS) // Skip when AWS credentials are absent, as the replaced test did.
+		fmt.Println("starting tool test")
 		client := openai.NewClient(option.WithBaseURL(listenerAddress + "/v1/"))
-		for _, tc := range []realProvidersTestCase{
-			{name: "openai", modelName: "gpt-4o-mini", required: requiredCredentialOpenAI},
-			{name: "aws-bedrock", modelName: "us.meta.llama3-2-1b-instruct-v1:0", required: requiredCredentialAWS},
+		fmt.Println("after client")
+		for _, tc := range []struct {
+			testCaseName,
+			modelName string
+		}{
+			{testCaseName: "aws-bedrock", modelName: "us.anthropic.claude-3-5-sonnet-20240620-v1:0"}, // This will go to "aws-bedrock" using the credentials file.
 		} {
-			t.Run(tc.name, func(t *testing.T) {
-				cc.maybeSkip(t, tc.required)
+			t.Run(tc.modelName, func(t *testing.T) {
+				fmt.Println("inside run")
 				require.Eventually(t, func() bool {
-					stream := client.Chat.Completions.NewStreaming(t.Context(), openai.ChatCompletionNewParams{
+					// Step 1: Initial request that should trigger a tool call.
+					question := "What is the weather in New York City?"
+					params := openai.ChatCompletionNewParams{
 						Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
-							openai.UserMessage("Say this is a test"),
+							openai.UserMessage(question),
+						}),
+						Tools: openai.F([]openai.ChatCompletionToolParam{
+							{
+								Type: openai.F(openai.ChatCompletionToolTypeFunction),
+								Function: openai.F(openai.FunctionDefinitionParam{
+									Name:        openai.String("get_weather"),
+									Description: openai.String("Get weather at the given location"),
+									Parameters: openai.F(openai.FunctionParameters{
+										"type": "object",
+										"properties": map[string]interface{}{
+											"location": map[string]string{
+												"type": "string",
+											},
+										},
+										"required": []string{"location"},
+									}),
+								}),
+							},
 						}),
+						// TODO: check if we should seed.
+						Seed:  openai.Int(0),
 						Model: openai.F(tc.modelName),
-					})
-					defer func() {
-						_ = stream.Close()
-					}()
-
-					acc := openai.ChatCompletionAccumulator{}
-
-					for stream.Next() {
-						chunk := stream.Current()
-						if !acc.AddChunk(chunk) {
-							t.Log("error adding chunk")
-							return false
-						}
 					}
-
-					if err := stream.Err(); err != nil {
+					fmt.Println("after params set")
+					completion, err := client.Chat.Completions.New(context.Background(), params)
+					fmt.Println("after completion")
+					if err != nil {
 						t.Logf("error: %v", err)
 						return false
 					}
-
-					nonEmptyCompletion := false
-					for _, choice := range acc.Choices {
-						t.Logf("choice: %s", choice.Message.Content)
-						if choice.Message.Content != "" {
-							nonEmptyCompletion = true
+					// Step 2: Verify the tool call.
+					// TODO: remove the debug logging after test debugging is done.
+					returnsToolCall := false
+					for _, choice := range completion.Choices {
+						t.Logf("choice content: %s", choice.Message.Content)
+						t.Logf("finish reason: %s", choice.FinishReason)
+						t.Logf("choice toolcall: %v", choice.Message.ToolCalls)
+						if choice.FinishReason == openai.ChatCompletionChoicesFinishReasonToolCalls {
+							returnsToolCall = true
 						}
 					}
-					if !nonEmptyCompletion {
-						// Log the whole response for debugging.
-						t.Logf("response: %+v", acc)
+					if !returnsToolCall {
+						t.Logf("Tool call not returned")
+						return false
+					}
+					toolCalls := completion.Choices[0].Message.ToolCalls
+					if len(toolCalls) == 0 {
+						t.Logf("Expected tool call from completion result but got none")
+						return false
+					}
+					// Step 3: Simulate the tool returning a response, add the tool response to the params, and check the second response.
+					params.Messages.Value = append(params.Messages.Value, completion.Choices[0].Message)
+					getWeatherCalled := false
+					for _, toolCall := range toolCalls {
+						if toolCall.Function.Name == "get_weather" {
+							getWeatherCalled = true
+							// Extract the location from the function call arguments.
+							var args map[string]interface{}
+							if argErr := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); argErr != nil {
+								t.Logf("error unmarshalling tool call arguments: %v", argErr) // Log and retry instead of panicking inside require.Eventually.
+								return false
+							}
+							location, _ := args["location"].(string) // Comma-ok form avoids a panic on an unexpected argument type.
+							if location != "New York City" {
+								t.Logf("Expected location to be New York City but got %s", location)
+							}
+							// Simulate getting weather data.
+							weatherData := "Sunny, 25°C"
+							params.Messages.Value = append(params.Messages.Value, openai.ToolMessage(toolCall.ID, weatherData))
+						}
+					}
+					if !getWeatherCalled {
+						t.Logf("get_weather tool not specified in chat completion response")
+						return false
 					}
-					return nonEmptyCompletion
-				}, 30*time.Second, 2*time.Second)
-			})
-		}
-	})
-	t.Run("Bedrock calls tool get_weather function", func(t *testing.T) {
-		cc.maybeSkip(t, requiredCredentialAWS)
-		client := openai.NewClient(option.WithBaseURL(listenerAddress + "/v1/"))
-		require.Eventually(t, func() bool {
-			chatCompletion, err := client.Chat.Completions.New(t.Context(), openai.ChatCompletionNewParams{
-				Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage("What is the weather like in Paris today?"),
-				}),
-				Tools: openai.F([]openai.ChatCompletionToolParam{
-					{
-						Type: openai.F(openai.ChatCompletionToolTypeFunction),
-						Function: openai.F(openai.FunctionDefinitionParam{
-							Name:        openai.String("get_weather"),
-							Description: openai.String("Get weather at the given location"),
-							Parameters: openai.F(openai.FunctionParameters{
-								"type": "object",
-								"properties": map[string]interface{}{
-									"location": map[string]string{
-										"type": "string",
-									},
-								},
-								"required": []string{"location"},
-							}),
-						}),
-					},
-				}),
openai.F("us.anthropic.claude-3-5-sonnet-20240620-v1:0"), + secondChatCompletion, err := client.Chat.Completions.New(context.Background(), params) + if err != nil { + t.Logf("error during second response: %v", err) + return false + } + + // Step 4: Verify that the second response is correct + completionResult := secondChatCompletion.Choices[0].Message.Content + t.Logf("content of completion response using tool: %s", secondChatCompletion.Choices[0].Message.Content) + return completionResult == "The weather in Paris is currently sunny and 25°C." + }, 500*time.Second, 200*time.Second) }) - if err != nil { - t.Logf("error: %v", err) - return false - } - returnsToolCall := false - for _, choice := range chatCompletion.Choices { - t.Logf("choice content: %s", choice.Message.Content) - t.Logf("finish reason: %s", choice.FinishReason) - t.Logf("choice toolcall: %v", choice.Message.ToolCalls) - if choice.FinishReason == openai.ChatCompletionChoicesFinishReasonToolCalls { - returnsToolCall = true - } - } - return returnsToolCall - }, 30*time.Second, 2*time.Second) + } }) // Models are served by the extproc filter as a direct response so this can run even if the