diff --git a/README.md b/README.md index df97efd..e421225 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # aisdk-go -[![GitHub Release](https://img.shields.io/github/v/release/kylecarbs/aisdk-go?color=6b9ded&sort=semver)](https://github.com/kylecarbs/aisdk-go/releases) -[![GoDoc](https://godoc.org/github.com/kylecarbs/aisdk-go?status.svg)](https://godoc.org/github.com/kylecarbs/aisdk-go) -[![CI Status](https://github.com/kylecarbs/aisdk-go/workflows/ci/badge.svg)](https://github.com/kylecarbs/aisdk-go/actions) +[![GitHub Release](https://img.shields.io/github/v/release/coder/aisdk-go?color=6b9ded&sort=semver)](https://github.com/coder/aisdk-go/releases) +[![GoDoc](https://godoc.org/github.com/coder/aisdk-go?status.svg)](https://godoc.org/github.com/coder/aisdk-go) +[![CI Status](https://github.com/coder/aisdk-go/workflows/ci/badge.svg)](https://github.com/coder/aisdk-go/actions) > [!WARNING] > This library is super new and may change a lot. diff --git a/anthropic.go b/anthropic.go index 53aaf7b..d42a5fe 100644 --- a/anthropic.go +++ b/anthropic.go @@ -75,7 +75,7 @@ func MessagesToAnthropic(messages []Message) ([]anthropic.MessageParam, []anthro switch part.Type { case PartTypeText: content = append(content, anthropic.ContentBlockParamUnion{ - OfRequestTextBlock: &anthropic.TextBlockParam{ + OfText: &anthropic.TextBlockParam{ Text: part.Text, }, }) @@ -88,7 +88,7 @@ func MessagesToAnthropic(messages []Message) ([]anthropic.MessageParam, []anthro return nil, nil, fmt.Errorf("marshalling tool input for call %s: %w", part.ToolInvocation.ToolCallID, err) } content = append(content, anthropic.ContentBlockParamUnion{ - OfRequestToolUseBlock: &anthropic.ToolUseBlockParam{ + OfToolUse: &anthropic.ToolUseBlockParam{ ID: part.ToolInvocation.ToolCallID, Input: json.RawMessage(argsJSON), Name: part.ToolInvocation.ToolName, @@ -115,13 +115,13 @@ func MessagesToAnthropic(messages []Message) ([]anthropic.MessageParam, []anthro switch resultPart.Type { case PartTypeText: resultContent = append(resultContent, anthropic.ToolResultBlockParamContentUnion{ - OfRequestTextBlock: &anthropic.TextBlockParam{Text: resultPart.Text}, + OfText: &anthropic.TextBlockParam{Text: resultPart.Text}, }) case PartTypeFile: resultContent = append(resultContent, anthropic.ToolResultBlockParamContentUnion{ - OfRequestImageBlock: &anthropic.ImageBlockParam{ + OfImage: &anthropic.ImageBlockParam{ Source: anthropic.ImageBlockParamSourceUnion{ - OfBase64ImageSource: &anthropic.Base64ImageSourceParam{ + OfBase64: &anthropic.Base64ImageSourceParam{ Data: base64.StdEncoding.EncodeToString(resultPart.Data), MediaType: anthropic.Base64ImageSourceMediaType(resultPart.MimeType), }, @@ -136,7 +136,7 @@ func MessagesToAnthropic(messages []Message) ([]anthropic.MessageParam, []anthro Role: anthropic.MessageParamRoleUser, Content: []anthropic.ContentBlockParamUnion{ { - OfRequestToolResultBlock: &anthropic.ToolResultBlockParam{ + OfToolResult: &anthropic.ToolResultBlockParam{ ToolUseID: part.ToolInvocation.ToolCallID, Content: resultContent, }, @@ -152,13 +152,13 @@ func MessagesToAnthropic(messages []Message) ([]anthropic.MessageParam, []anthro switch part.Type { case PartTypeText: content = append(content, anthropic.ContentBlockParamUnion{ - OfRequestTextBlock: &anthropic.TextBlockParam{Text: part.Text}, + OfText: &anthropic.TextBlockParam{Text: part.Text}, }) case PartTypeFile: content = append(content, anthropic.ContentBlockParamUnion{ - OfRequestImageBlock: &anthropic.ImageBlockParam{ + OfImage: &anthropic.ImageBlockParam{ Source: anthropic.ImageBlockParamSourceUnion{ - OfBase64ImageSource: &anthropic.Base64ImageSourceParam{ + OfBase64: &anthropic.Base64ImageSourceParam{ Data: base64.StdEncoding.EncodeToString(part.Data), MediaType: anthropic.Base64ImageSourceMediaType(part.MimeType), }, @@ -181,9 +181,9 @@ func MessagesToAnthropic(messages []Message) ([]anthropic.MessageParam, []anthro return nil, nil, fmt.Errorf("invalid attachment URL: %s", attachment.URL) } content = append(content, anthropic.ContentBlockParamUnion{ - OfRequestImageBlock: &anthropic.ImageBlockParam{ + OfImage: &anthropic.ImageBlockParam{ Source: anthropic.ImageBlockParamSourceUnion{ - OfBase64ImageSource: &anthropic.Base64ImageSourceParam{ + OfBase64: &anthropic.Base64ImageSourceParam{ Data: parts[1], MediaType: anthropic.Base64ImageSourceMediaType(attachment.ContentType), }, diff --git a/anthropic_test.go b/anthropic_test.go index a349cab..1885d6d 100644 --- a/anthropic_test.go +++ b/anthropic_test.go @@ -12,7 +12,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" - "github.com/kylecarbs/aisdk-go" + "github.com/coder/aisdk-go" "github.com/stretchr/testify/require" ) @@ -75,7 +75,7 @@ data: {"type":"message_stop" }` var acc aisdk.DataStreamAccumulator stream := aisdk.AnthropicToDataStream(typedStream) - stream = stream.WithToolCalling(func(toolCall aisdk.ToolCall) aisdk.ToolCallResult { + stream = stream.WithToolCalling(func(toolCall aisdk.ToolCall) any { return map[string]any{"message": "Message printed to the console"} }) stream = stream.WithAccumulator(&acc) @@ -123,12 +123,12 @@ data: {"type":"message_stop" }` require.Len(t, assistantMsg.Content, 2) // Text block + ToolUse block // Check Text Content Block - textBlock := assistantMsg.Content[0].OfRequestTextBlock + textBlock := assistantMsg.Content[0].OfText require.NotNil(t, textBlock) require.Equal(t, "I'll help you print 'hello world' to the console using the print function.", textBlock.Text) // Check Tool Use Content Block - toolUseBlock := assistantMsg.Content[1].OfRequestToolUseBlock + toolUseBlock := assistantMsg.Content[1].OfToolUse require.NotNil(t, toolUseBlock) require.Equal(t, "toolu_01RA76iwg1LbKuDjJnc6ym45", toolUseBlock.ID) require.Equal(t, "print", toolUseBlock.Name) @@ -139,12 +139,12 @@ data: {"type":"message_stop" }` require.Equal(t, anthropic.MessageParamRoleUser, userMsg.Role) require.Len(t, userMsg.Content, 1) // ToolResult block - toolResultBlock := userMsg.Content[0].OfRequestToolResultBlock + toolResultBlock := userMsg.Content[0].OfToolResult require.NotNil(t, toolResultBlock) require.Equal(t, "toolu_01RA76iwg1LbKuDjJnc6ym45", toolResultBlock.ToolUseID) require.Len(t, toolResultBlock.Content, 1) - require.NotNil(t, toolResultBlock.Content[0].OfRequestTextBlock) - require.JSONEq(t, `{"message":"Message printed to the console"}`, toolResultBlock.Content[0].OfRequestTextBlock.Text) + require.NotNil(t, toolResultBlock.Content[0].OfText) + require.JSONEq(t, `{"message":"Message printed to the console"}`, toolResultBlock.Content[0].OfText.Text) // --- Second conversion check (using expectedMessages) --- // This part should remain the same, as it also expects 2 messages now. @@ -163,12 +163,12 @@ data: {"type":"message_stop" }` require.Equal(t, anthropic.MessageParamRoleUser, userMsgWithResult.Role) require.Len(t, userMsgWithResult.Content, 1) // ToolResult block - toolResultBlockWithResult := userMsgWithResult.Content[0].OfRequestToolResultBlock + toolResultBlockWithResult := userMsgWithResult.Content[0].OfToolResult require.NotNil(t, toolResultBlockWithResult) require.Equal(t, "toolu_01RA76iwg1LbKuDjJnc6ym45", toolResultBlockWithResult.ToolUseID) require.Len(t, toolResultBlockWithResult.Content, 1) - require.NotNil(t, toolResultBlockWithResult.Content[0].OfRequestTextBlock) - require.JSONEq(t, `{"message":"Message printed to the console"}`, toolResultBlockWithResult.Content[0].OfRequestTextBlock.Text) + require.NotNil(t, toolResultBlockWithResult.Content[0].OfText) + require.JSONEq(t, `{"message":"Message printed to the console"}`, toolResultBlockWithResult.Content[0].OfText.Text) } func TestMessagesToAnthropic_Live(t *testing.T) { @@ -181,15 +181,25 @@ func TestMessagesToAnthropic_Live(t *testing.T) { client := anthropic.NewClient(option.WithAPIKey(apiKey)) // Ensure messages are converted correctly. + prompt := "use the 'print' tool to print 'Hello, world!' and then show the result" messages, systemPrompts, err := aisdk.MessagesToAnthropic([]aisdk.Message{ { - Role: "system", - Content: "use the 'print' tool to print 'Hello, world!' and then show the result", + Role: "system", + Parts: []aisdk.Part{ + {Text: "You are a helpful assistant.", Type: aisdk.PartTypeText}, + }, }, { - Role: "user", Content: "Go ahead.", + Role: "user", Parts: []aisdk.Part{ + {Text: prompt, Type: aisdk.PartTypeText}, + }, }, }) + require.Len(t, messages, 1) + require.Len(t, systemPrompts, 1) + require.Len(t, messages[0].Content, 1) + require.NotNil(t, messages[0].Content[0].OfText) + require.Equal(t, messages[0].Content[0].OfText.Text, prompt) require.NoError(t, err) stream := client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{ diff --git a/demo/server.go b/demo/server.go index 8543438..ab4f1b3 100644 --- a/demo/server.go +++ b/demo/server.go @@ -11,7 +11,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" anthropicoption "github.com/anthropics/anthropic-sdk-go/option" - "github.com/kylecarbs/aisdk-go" + "github.com/coder/aisdk-go" "github.com/openai/openai-go" openaioption "github.com/openai/openai-go/option" "google.golang.org/genai" @@ -63,7 +63,7 @@ func run(ctx context.Context) error { return } - handleToolCall := func(toolCall aisdk.ToolCall) aisdk.ToolCallResult { + handleToolCall := func(toolCall aisdk.ToolCall) any { return map[string]string{ "message": "It worked!", } @@ -115,10 +115,10 @@ func run(ctx context.Context) error { thinking := anthropic.ThinkingConfigParamUnion{} if req.Thinking { - thinking = anthropic.ThinkingConfigParamOfThinkingConfigEnabled(2048) + thinking = anthropic.ThinkingConfigParamOfEnabled(2048) } stream = aisdk.AnthropicToDataStream(anthropicClient.Messages.NewStreaming(ctx, anthropic.MessageNewParams{ - Model: req.Model, + Model: anthropic.Model(req.Model), Messages: messages, System: system, MaxTokens: 4096, diff --git a/go.mod b/go.mod index ad03a56..e6cdb24 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,13 @@ -module github.com/kylecarbs/aisdk-go +module github.com/coder/aisdk-go go 1.23.7 require ( - github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3 + github.com/anthropics/anthropic-sdk-go v1.4.0 github.com/google/uuid v1.6.0 - github.com/openai/openai-go v0.1.0-beta.6 + github.com/openai/openai-go v1.3.0 github.com/stretchr/testify v1.10.0 - google.golang.org/genai v0.7.0 + google.golang.org/genai v1.10.0 ) require ( diff --git a/go.sum b/go.sum index b83768c..936896b 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps= cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8= cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= -github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3 h1:b5t1ZJMvV/l99y4jbz7kRFdUp3BSDkI8EhSlHczivtw= -github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c= +github.com/anthropics/anthropic-sdk-go v1.4.0 h1:fU1jKxYbQdQDiEXCxeW5XZRIOwKevn/PMg8Ay1nnUx0= +github.com/anthropics/anthropic-sdk-go v1.4.0/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -33,8 +33,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/openai/openai-go v0.1.0-beta.6 h1:JquYDpprfrGnlKvQQg+apy9dQ8R9mIrm+wNvAPp6jCQ= -github.com/openai/openai-go v0.1.0-beta.6/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= +github.com/openai/openai-go v1.3.0 h1:lBpvgXxGHUufk9DNTguval40y2oK0GHZwgWQyUtjPIQ= +github.com/openai/openai-go v1.3.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -75,8 +75,8 @@ golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= -google.golang.org/genai v0.7.0 h1:TINBYXnP+K+D8b16LfVyb6XR3kdtieXy6nJsGoEXcBc= -google.golang.org/genai v0.7.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY= +google.golang.org/genai v1.10.0 h1:ETP0Yksn5KUSEn5+ihMOnP3IqjZ+7Z4i0LjJslEXatI= +google.golang.org/genai v1.10.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY= google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 h1:e0AIkUUhxyBKh6ssZNrAMeqhA7RKUj42346d1y02i2g= google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= diff --git a/google.go b/google.go index 4cea6f8..cfb0888 100644 --- a/google.go +++ b/google.go @@ -345,12 +345,12 @@ func GoogleToDataStream(stream iter.Seq2[*genai.GenerateContentResponse, error]) // Extract final usage data if available if lastResp != nil && lastResp.UsageMetadata != nil { - if lastResp.UsageMetadata.PromptTokenCount != nil { - promptTokens := int64(*lastResp.UsageMetadata.PromptTokenCount) + if lastResp.UsageMetadata.PromptTokenCount > 0 { + promptTokens := int64(lastResp.UsageMetadata.PromptTokenCount) finalUsage.PromptTokens = &promptTokens } - if lastResp.UsageMetadata.CandidatesTokenCount != nil { - completionTokens := int64(*lastResp.UsageMetadata.CandidatesTokenCount) + if lastResp.UsageMetadata.CandidatesTokenCount > 0 { + completionTokens := int64(lastResp.UsageMetadata.CandidatesTokenCount) finalUsage.CompletionTokens = &completionTokens } } diff --git a/google_test.go b/google_test.go new file mode 100644 index 0000000..2a56e10 --- /dev/null +++ b/google_test.go @@ -0,0 +1,155 @@ +package aisdk_test + +import ( + "context" + "encoding/json" + "os" + "strings" + "testing" + + "github.com/coder/aisdk-go" + "github.com/stretchr/testify/require" + "google.golang.org/genai" +) + +func TestGoogleToDataStream(t *testing.T) { + t.Parallel() + + // googleResponses are hardcoded responses from the Google AI Stream endpoint. + googleResponses := `data: {"candidates": [{"content": {"parts": [{"text": "Here"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 10,"totalTokenCount": 10,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 10}]},"modelVersion": "gemini-2.0-flash","responseId": "ibRCaOfLGfuQ1PIPqNma8Aw"} + +data: {"candidates": [{"content": {"parts": [{"text": " you go:\n\nEnglish: potato\nSpanish: patata\nFrench: pom"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 10,"totalTokenCount": 10,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 10}]},"modelVersion": "gemini-2.0-flash","responseId": "ibRCaOfLGfuQ1PIPqNma8Aw"} + +data: {"candidates": [{"content": {"parts": [{"text": "me de terre\nGerman: Kartoffel\nItalian: patata\nJapanese: ジャ"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 10,"totalTokenCount": 10,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 10}]},"modelVersion": "gemini-2.0-flash","responseId": "ibRCaOfLGfuQ1PIPqNma8Aw"} + +data: {"candidates": [{"content": {"parts": [{"text": "ガイモ (jagaimo)\nRussian: картофель (kartofel')\n"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 10,"candidatesTokenCount": 49,"totalTokenCount": 59,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 10}],"candidatesTokensDetails": [{"modality": "TEXT","tokenCount": 49}]},"modelVersion": "gemini-2.0-flash","responseId": "ibRCaOfLGfuQ1PIPqNma8Aw"}` + + // Parse the SSE format manually since Google doesn't expose its SSE parsing :( + // See https://github.com/googleapis/go-genai/blob/v1.10.0/api_client.go#L44 + lines := strings.Split(googleResponses, "\n") + var responses []*genai.GenerateContentResponse + + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "data: ") { + jsonData := strings.TrimPrefix(line, "data: ") + if len(jsonData) == 0 { + continue // Skip empty data lines + } + var resp genai.GenerateContentResponse + err := json.Unmarshal([]byte(jsonData), &resp) + require.NoError(t, err) + responses = append(responses, &resp) + } + } + + // Create an iterator from the parsed responses + mockStream := func(yield func(*genai.GenerateContentResponse, error) bool) { + for _, resp := range responses { + if !yield(resp, nil) { + return + } + } + } + + var acc aisdk.DataStreamAccumulator + stream := aisdk.GoogleToDataStream(mockStream) + stream = stream.WithToolCalling(func(toolCall aisdk.ToolCall) any { + return map[string]any{"message": "Message printed to the console"} + }) + stream = stream.WithAccumulator(&acc) + for _, err := range stream { + require.NoError(t, err) + } + + // Verify that we got messages and they have the expected content + messages := acc.Messages() + require.Len(t, messages, 1) + + expectedContent := "Here you go:\n\nEnglish: potato\nSpanish: patata\nFrench: pomme de terre\nGerman: Kartoffel\nItalian: patata\nJapanese: ジャガイモ (jagaimo)\nRussian: картофель (kartofel')\n" + + msg := messages[0] + require.Equal(t, "assistant", msg.Role) + require.Equal(t, expectedContent, msg.Content) + require.Len(t, msg.Parts, 2) // step-start, text (accumulated) + + // Check step start part + require.Equal(t, aisdk.PartTypeStepStart, msg.Parts[0].Type) + + // Check text part (accumulated across all chunks) + require.Equal(t, aisdk.PartTypeText, msg.Parts[1].Type) + require.Equal(t, expectedContent, msg.Parts[1].Text) + + // Test conversion back to Google format + googleContents, err := aisdk.MessagesToGoogle(messages) + require.NoError(t, err) + + // We expect one content block with just text + require.Len(t, googleContents, 1) + + // Check the content (assistant message with text only) + content := googleContents[0] + require.Equal(t, "model", content.Role) + require.Len(t, content.Parts, 1) // just text part + + // Check text part (accumulated text) + require.Equal(t, expectedContent, content.Parts[0].Text) +} + +func TestMessagesToGoogle_Live(t *testing.T) { + t.Parallel() + apiKey := os.Getenv("GOOGLE_API_KEY") + if apiKey == "" { + t.Skip("GOOGLE_API_KEY is not set") + } + + // Create a Google AI client + ctx := context.Background() + client, err := genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: apiKey, + Backend: genai.BackendGeminiAPI, + }) + require.NoError(t, err) + + prompt := "use the 'print' tool to print 'Hello, world!' and then show the result" + // Test messages with a simple request + messages := []aisdk.Message{ + { + Role: "user", + Parts: []aisdk.Part{ + {Text: prompt, Type: aisdk.PartTypeText}, + }, + }, + } + + // Convert messages to Google format + contents, err := aisdk.MessagesToGoogle(messages) + require.NoError(t, err) + require.Len(t, contents, 1) + require.Len(t, contents[0].Parts, 1) + require.Equal(t, contents[0].Parts[0].Text, prompt) + + _, err = genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: os.Getenv("GOOGLE_API_KEY"), + Backend: genai.BackendGeminiAPI, + }) + require.NoError(t, err) + + stream := client.Models.GenerateContentStream( + ctx, + "gemini-2.0-flash", + contents, + nil, + ) + + dataStream := aisdk.GoogleToDataStream(stream) + var streamErr error + dataStream(func(part aisdk.DataStreamPart, err error) bool { + if err != nil { + streamErr = err + return false + } + return true + }) + require.NoError(t, streamErr) +} diff --git a/openai.go b/openai.go index d649618..910408b 100644 --- a/openai.go +++ b/openai.go @@ -251,11 +251,11 @@ func OpenAIToDataStream(stream *ssestream.Stream[openai.ChatCompletionChunk]) Da finishReason = FinishReasonStop } - if lastChunk.Usage.JSON.CompletionTokens.IsPresent() { + if lastChunk.Usage.JSON.CompletionTokens.Valid() { tokens := int64(lastChunk.Usage.CompletionTokens) completionTokens = &tokens } - if lastChunk.Usage.JSON.PromptTokens.IsPresent() { + if lastChunk.Usage.JSON.PromptTokens.Valid() { tokens := int64(lastChunk.Usage.PromptTokens) promptTokens = &tokens } diff --git a/openai_test.go b/openai_test.go index abbf74b..4612330 100644 --- a/openai_test.go +++ b/openai_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - "github.com/kylecarbs/aisdk-go" + "github.com/coder/aisdk-go" "github.com/openai/openai-go" "github.com/openai/openai-go/option" @@ -66,7 +66,7 @@ data: [DONE]` // 3. Pass the typed stream to OpenAIToDataStream and accumulate results var acc aisdk.DataStreamAccumulator stream := aisdk.OpenAIToDataStream(typedStream) - stream = stream.WithToolCalling(func(toolCall aisdk.ToolCall) aisdk.ToolCallResult { + stream = stream.WithToolCalling(func(toolCall aisdk.ToolCall) any { return map[string]any{"message": "Message printed to the console"} }) stream = stream.WithAccumulator(&acc) // Accumulator is attached here @@ -141,12 +141,24 @@ func TestMessagesToOpenAI_Live(t *testing.T) { client := openai.NewClient(option.WithAPIKey(apiKey)) // Ensure messages are converted correctly. + prompt := "use the 'print' tool to print 'Hello, world!' and then show the result" messages, err := aisdk.MessagesToOpenAI([]aisdk.Message{ { Role: "system", - Content: "use the 'print' tool to print 'Hello, world!' and then show the result", + Content: "You are a helpful assistant.", + }, + { + Role: "user", + Parts: []aisdk.Part{ + {Type: aisdk.PartTypeText, Text: prompt}, + }, }, }) + require.Len(t, messages, 2) + require.NotNil(t, messages[1].OfUser) + require.Len(t, messages[1].OfUser.Content.OfArrayOfContentParts, 1) + require.NotNil(t, messages[1].OfUser.Content.OfArrayOfContentParts[0].OfText) + require.Equal(t, messages[1].OfUser.Content.OfArrayOfContentParts[0].OfText.Text, prompt) require.NoError(t, err) stream := client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{ @@ -156,9 +168,13 @@ func TestMessagesToOpenAI_Live(t *testing.T) { require.NoError(t, err) dataStream := aisdk.OpenAIToDataStream(stream) - for _, err := range dataStream { + var streamErr error + dataStream(func(part aisdk.DataStreamPart, err error) bool { if err != nil { - t.Fatal(err) + streamErr = err + return false } - } + return true + }) + require.NoError(t, streamErr) } diff --git a/stream.go b/stream.go index ea5302d..9c5991a 100644 --- a/stream.go +++ b/stream.go @@ -19,7 +19,7 @@ type Chat struct { type DataStream iter.Seq2[DataStreamPart, error] // WithToolCalling passes tool calls to the handleToolCall function. -func (s DataStream) WithToolCalling(handleToolCall func(toolCall ToolCall) ToolCallResult) DataStream { +func (s DataStream) WithToolCalling(handleToolCall func(toolCall ToolCall) any) DataStream { return func(yield func(DataStreamPart, error) bool) { // Track partial tool calls by ID partialToolCalls := make(map[string]struct { @@ -338,8 +338,8 @@ func (p ToolCallStreamPart) Format() (string, error) { // ToolResultStreamPart corresponds to TYPE_ID 'a'. type ToolResultStreamPart struct { - ToolCallID string `json:"toolCallId"` - Result ToolCallResult `json:"result"` + ToolCallID string `json:"toolCallId"` + Result any `json:"result"` } func (p ToolResultStreamPart) TypeID() byte { return 'a' } @@ -498,7 +498,7 @@ type ToolInvocation struct { ToolCallID string `json:"toolCallId"` ToolName string `json:"toolName"` Args any `json:"args"` - Result ToolCallResult `json:"result,omitempty"` + Result any `json:"result,omitempty"` } func WriteDataStreamHeaders(w http.ResponseWriter) { @@ -770,7 +770,7 @@ func (a *DataStreamAccumulator) Usage() Usage { return a.usage } -func toolResultToParts(result ToolCallResult) ([]Part, error) { +func toolResultToParts(result any) ([]Part, error) { switch r := result.(type) { case []Part: return r, nil diff --git a/stream_test.go b/stream_test.go index 75c6859..772611c 100644 --- a/stream_test.go +++ b/stream_test.go @@ -3,7 +3,7 @@ package aisdk_test import ( "testing" - "github.com/kylecarbs/aisdk-go" + "github.com/coder/aisdk-go" "github.com/stretchr/testify/require" )