diff --git a/README.md b/README.md index df97efd..6b8013d 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ A Go implementation of Vercel's AI SDK [Data Stream Protocol](https://sdk.vercel.ai/docs/ai-sdk-ui/stream-protocol#data-stream-example). -- Supports OpenAI, Google, and Anthropic (with Bedrock support) +- Supports OpenAI, Google, Anthropic (with Bedrock support), and Grok - Examples for integrating `useChat` - Chain tool usage in Go, just like `maxSteps` @@ -50,8 +50,9 @@ Run tests with `go test`. Start the `useChat` demo with: export OPENAI_API_KEY= export ANTHROPIC_API_KEY= export GOOGLE_API_KEY= +export GROK_API_KEY= cd demo bun i bun dev -``` \ No newline at end of file +``` diff --git a/go.mod b/go.mod index ad03a56..dcf1341 100644 --- a/go.mod +++ b/go.mod @@ -1,44 +1,31 @@ module github.com/kylecarbs/aisdk-go -go 1.23.7 +go 1.23.0 require ( github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3 github.com/google/uuid v1.6.0 + github.com/hamguy/go_grok v1.0.0 github.com/openai/openai-go v0.1.0-beta.6 github.com/stretchr/testify v1.10.0 - google.golang.org/genai v0.7.0 + google.golang.org/genai v0.6.0 ) require ( - cloud.google.com/go v0.120.0 // indirect - cloud.google.com/go/auth v0.15.0 // indirect - cloud.google.com/go/compute/metadata v0.6.0 // indirect + cloud.google.com/go v0.121.2 // indirect + cloud.google.com/go/compute/metadata v0.7.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-logr/logr v1.4.2 // indirect - github.com/go-logr/stdr v1.2.2 // indirect github.com/google/go-cmp v0.7.0 // indirect - github.com/google/s2a-go v0.1.9 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.5 // indirect - github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/gorilla/websocket v1.5.3 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect - go.opentelemetry.io/otel v1.35.0 // indirect - go.opentelemetry.io/otel/metric v1.35.0 // indirect - go.opentelemetry.io/otel/trace v1.35.0 // indirect - golang.org/x/crypto v0.36.0 // indirect - golang.org/x/net v0.37.0 // indirect - golang.org/x/sys v0.31.0 // indirect - golang.org/x/text v0.23.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 // indirect - google.golang.org/grpc v1.71.0 // indirect - google.golang.org/protobuf v1.36.6 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.33.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index b83768c..12f1528 100644 --- a/go.sum +++ b/go.sum @@ -1,42 +1,33 @@ -cloud.google.com/go v0.120.0 h1:wc6bgG9DHyKqF5/vQvX1CiZrtHnxJjBlKUyF9nP6meA= -cloud.google.com/go v0.120.0/go.mod h1:/beW32s8/pGRuj4IILWQNd4uuebeT4dkOhKmkfit64Q= -cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps= -cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8= -cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= -cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go v0.121.2 h1:v2qQpN6Dx9x2NmwrqlesOt3Ys4ol5/lFZ6Mg1B7OJCg= +cloud.google.com/go v0.121.2/go.mod h1:nRFlrHq39MNVWu+zESP2PosMWA0ryJw8KUBZ2iZpxbw= +cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU= +cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo= github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3 h1:b5t1ZJMvV/l99y4jbz7kRFdUp3BSDkI8EhSlHczivtw= github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= -github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= -github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= -github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= -github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.5 h1:VgzTY2jogw3xt39CusEnFJWm7rlsq5yL5q9XdLOuP5g= -github.com/googleapis/enterprise-certificate-proxy v0.3.5/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= -github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= -github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hamguy/go_grok v1.0.0 h1:s96C3dwYLJXKTelGNjd8NAlqbOYk5xwZlvNy0OkowRA= +github.com/hamguy/go_grok v1.0.0/go.mod h1:u07rqUwNjXhyGv6N/z+9sQAJGdTcvej6iBSB50dpFJY= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/openai/openai-go v0.1.0-beta.6 h1:JquYDpprfrGnlKvQQg+apy9dQ8R9mIrm+wNvAPp6jCQ= github.com/openai/openai-go v0.1.0-beta.6/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -51,38 +42,12 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 h1:CV7UdSGJt/Ao6Gp4CXckLxVRRsRgDHoI8XjbL3PDl8s= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0/go.mod h1:FRmFuRJfag1IZ2dPkHnEoSFVgTVPUd2qf5Vi69hLb8I= -go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= -go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= -go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= -go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= -go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= -go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= -go.opentelemetry.io/otel/sdk/metric v1.34.0 h1:5CeK9ujjbFVL5c1PhLuStg1wxA7vQv7ce1EK0Gyvahk= -go.opentelemetry.io/otel/sdk/metric v1.34.0/go.mod h1:jQ/r8Ze28zRKoNRdkjCZxfs6YvBTG1+YIqyFVFYec5w= -go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= -go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= -google.golang.org/genai v0.7.0 h1:TINBYXnP+K+D8b16LfVyb6XR3kdtieXy6nJsGoEXcBc= -google.golang.org/genai v0.7.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 h1:e0AIkUUhxyBKh6ssZNrAMeqhA7RKUj42346d1y02i2g= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= -google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= -google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +google.golang.org/genai v0.6.0 h1:S9eDmXHPPqiWrKO2G7ydTNQ70fG1y1+ttR6zsFCPJd0= +google.golang.org/genai v0.6.0/go.mod h1:yPyKKBezIg2rqZziLhHQ5CD62HWr7sLDLc2PDzdrNVs= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/grok.go b/grok.go new file mode 100644 index 0000000..7f63901 --- /dev/null +++ b/grok.go @@ -0,0 +1,178 @@ +package aisdk + +import ( + "encoding/base64" + "encoding/json" + "fmt" + + "github.com/hamguy/go_grok/pkg/xai" +) + +func ToolsToGrok(tools []Tool) ([]map[string]interface{}, error) { + if len(tools) == 0 { + return nil, nil + } + + var grokTools []map[string]interface{} + for _, tool := range tools { + grokTool := map[string]interface{}{ + "name": tool.Name, + "description": tool.Description, + } + + if tool.Schema.Properties != nil { + schemaParams := map[string]interface{}{ + "type": "object", + "properties": tool.Schema.Properties, + } + if len(tool.Schema.Required) > 0 { + schemaParams["required"] = tool.Schema.Required + } + grokTool["parameters"] = schemaParams + } + + grokTools = append(grokTools, grokTool) + } + + return grokTools, nil +} + +func MessagesToGrok(messages []Message) ([]map[string]interface{}, error) { + if len(messages) == 0 { + return nil, nil + } + + var grokMessages []map[string]interface{} + for _, message := range messages { + var content interface{} + + if message.Content != "" { + content = message.Content + } else if len(message.Parts) > 0 { + var parts []map[string]interface{} + for _, part := range message.Parts { + switch part.Type { + case PartTypeText: + if part.Text != "" { + content = part.Text + } + case PartTypeFile: + if part.Data != nil && len(part.Data) > 0 { + imgPart := map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": fmt.Sprintf("data:%s;base64,%s", part.MimeType, base64.StdEncoding.EncodeToString(part.Data)), + }, + } + parts = append(parts, imgPart) + } + case PartTypeToolInvocation: + continue + } + } + + if len(parts) > 0 { + content = parts + } + } + + if content == nil { + continue + } + + grokMessage := map[string]interface{}{ + "role": message.Role, + "content": content, + } + + for _, part := range message.Parts { + if part.Type == PartTypeToolInvocation && part.ToolInvocation != nil && part.ToolInvocation.State == ToolInvocationStateResult { + grokMessage["tool_calls"] = []map[string]interface{}{ + { + "id": part.ToolInvocation.ToolCallID, + "type": "function", + "function": map[string]interface{}{ + "name": part.ToolInvocation.ToolName, + "arguments": part.ToolInvocation.Args, + }, + }, + } + } + } + + grokMessages = append(grokMessages, grokMessage) + } + + return grokMessages, nil +} + +func GrokToDataStream(streamChan <-chan *xai.ChatCompletionResponse) DataStream { + return func(yield func(DataStreamPart, error) bool) { + var finalReason FinishReason = FinishReasonUnknown + + for chunk := range streamChan { + if chunk == nil { + continue + } + + if len(chunk.Choices) == 0 { + continue + } + + choice := chunk.Choices[0] + + if choice.FinishReason != "" { + switch choice.FinishReason { + case "stop": + finalReason = FinishReasonStop + case "length": + finalReason = FinishReasonLength + case "tool_calls": + finalReason = FinishReasonToolCalls + } + + if !yield(FinishStepStreamPart{ + IsContinued: false, + FinishReason: finalReason, + }, nil) { + return + } + continue + } + + if choice.Delta != nil && choice.Delta.Content != "" { + if !yield(TextStreamPart{Content: choice.Delta.Content}, nil) { + return + } + } + + if choice.Delta != nil && choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0 { + for _, toolCall := range choice.Delta.ToolCalls { + toolName := toolCall.Function.Name + var args map[string]any + if toolCall.Function.Arguments != nil { + if argsStr, ok := toolCall.Function.Arguments.(string); ok { + if err := json.Unmarshal([]byte(argsStr), &args); err != nil { + args = map[string]any{"text": argsStr} + } + } + } + + if toolName != "" || len(args) > 0 { + if !yield(ToolCallStreamPart{ + ToolCallID: toolCall.ID, + ToolName: toolName, + Args: args, + }, nil) { + return + } + } + } + } + } + + yield(FinishMessageStreamPart{ + FinishReason: finalReason, + }, nil) + } +} diff --git a/grok_test.go b/grok_test.go new file mode 100644 index 0000000..22c6483 --- /dev/null +++ b/grok_test.go @@ -0,0 +1,140 @@ +package aisdk_test + +import ( + "os" + "testing" + + "github.com/hamguy/go_grok/pkg/xai" + "github.com/kylecarbs/aisdk-go" + "github.com/stretchr/testify/require" +) + +func TestGrokToDataStream(t *testing.T) { + t.Parallel() + + mockResponse := &xai.ChatCompletionResponse{ + Choices: []xai.Choice{ + { + Delta: &xai.Message{ + Content: "Hello, world!", + }, + }, + }, + } + + streamChan := make(chan *xai.ChatCompletionResponse, 1) + streamChan <- mockResponse + close(streamChan) + + var acc aisdk.DataStreamAccumulator + stream := aisdk.GrokToDataStream(streamChan) + stream = stream.WithAccumulator(&acc) + + for _, err := range stream { + require.NoError(t, err) + } + + messages := acc.Messages() + require.Len(t, messages, 1) + require.Equal(t, "assistant", messages[0].Role) + require.Equal(t, "Hello, world!", messages[0].Content) +} + +func TestMessagesToGrok(t *testing.T) { + t.Parallel() + + messages := []aisdk.Message{ + { + Role: "system", + Content: "You are a helpful assistant.", + }, + { + Role: "user", + Content: "Hello, how are you?", + }, + } + + grokMessages, err := aisdk.MessagesToGrok(messages) + require.NoError(t, err) + require.Len(t, grokMessages, 2) + require.Equal(t, "system", grokMessages[0]["role"]) + require.Equal(t, "You are a helpful assistant.", grokMessages[0]["content"]) + require.Equal(t, "user", grokMessages[1]["role"]) + require.Equal(t, "Hello, how are you?", grokMessages[1]["content"]) +} + +func TestToolsToGrok(t *testing.T) { + t.Parallel() + + tools := []aisdk.Tool{ + { + Name: "get_weather", + Description: "Get the weather for a location", + Schema: aisdk.Schema{ + Required: []string{"location"}, + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The location to get weather for", + }, + }, + }, + }, + } + + grokTools, err := aisdk.ToolsToGrok(tools) + require.NoError(t, err) + require.Len(t, grokTools, 1) + require.Equal(t, "get_weather", grokTools[0]["name"]) + require.Equal(t, "Get the weather for a location", grokTools[0]["description"]) + require.NotNil(t, grokTools[0]["parameters"]) +} + +func TestMessagesToGrok_Live(t *testing.T) { + t.Parallel() + + apiKey := os.Getenv("GROK_API_KEY") + if apiKey == "" { + t.Skip("GROK_API_KEY is not set") + } + + messages := []aisdk.Message{ + { + Role: "system", + Content: "You are a helpful assistant.", + }, + { + Role: "user", + Content: "Hello, how are you?", + }, + } + + grokMessages, err := aisdk.MessagesToGrok(messages) + require.NoError(t, err) + require.Len(t, grokMessages, 2) + + client := xai.NewClient(apiKey, "grok-2-1212") + + streamTrue := true + req := &xai.ChatCompletionRequest{ + Model: "grok-2-1212", + Messages: grokMessages, + Stream: &streamTrue, + } + + streamChan, err := client.CreateChatCompletionStream(req) + require.NoError(t, err) + + var acc aisdk.DataStreamAccumulator + stream := aisdk.GrokToDataStream(streamChan.Stream) + stream = stream.WithAccumulator(&acc) + + for _, err := range stream { + require.NoError(t, err) + } + + accMessages := acc.Messages() + require.Len(t, accMessages, 1) + require.Equal(t, "assistant", accMessages[0].Role) + require.NotEmpty(t, accMessages[0].Content) +}