diff --git a/src/xAI.Tests/GrokConversionTests.cs b/src/xAI.Tests/GrokConversionTests.cs new file mode 100644 index 0000000..34b7ce3 --- /dev/null +++ b/src/xAI.Tests/GrokConversionTests.cs @@ -0,0 +1,230 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Google.Protobuf.WellKnownTypes; +using Microsoft.Extensions.AI; +using OpenAI.Responses; +using xAI.Protocol; + +namespace xAI; + +public class GrokConversionTests +{ + [Fact] + public void AsTool_WithWebSearch() + { + var webSearch = new HostedWebSearchTool(); + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.WebSearch); + } + + [Fact] + public void AsTool_WithWebSearch_ThrowsIfAllowedAndExcluded() + { + var webSearch = new GrokSearchTool + { + AllowedDomains = ["Foo"], + ExcludedDomains = ["Bar"] + }; + + Assert.Throws(() => webSearch.AsProtocolTool()); + } + + [Fact] + public void AsTool_WithWebSearch_AllowedDomains() + { + var webSearch = new GrokSearchTool + { + AllowedDomains = ["foo.com", "bar.com"], + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.WebSearch); + Assert.Equal(["foo.com", "bar.com"], tool.WebSearch.AllowedDomains); + } + + [Fact] + public void AsTool_WithWebSearch_ExcludedDomains() + { + var webSearch = new GrokSearchTool + { + ExcludedDomains = ["foo.com", "bar.com"], + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.WebSearch); + Assert.Equal(["foo.com", "bar.com"], tool.WebSearch.ExcludedDomains); + } + + [Fact] + public void AsTool_WithWebSearch_ImageUnderstanding() + { + var webSearch = new GrokSearchTool + { + EnableImageUnderstanding = true + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.WebSearch); + Assert.True(tool.WebSearch.EnableImageUnderstanding); + } + + [Fact] + public void AsTool_WithXSearch_ThrowsIfAllowedAndExcluded() + { + var webSearch = new GrokXSearchTool + { + AllowedHandles = ["Foo"], + ExcludedHandles = ["Bar"] + }; + + Assert.Throws(() => webSearch.AsProtocolTool()); + } + + [Fact] + public void AsTool_WithXSearch_AllowedHandles() + { + var webSearch = new GrokXSearchTool + { + AllowedHandles = ["foo", "bar"], + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.XSearch); + Assert.Equal(["foo", "bar"], tool.XSearch.AllowedXHandles); + } + + [Fact] + public void AsTool_WithXSearch_ExcludedDomains() + { + var webSearch = new GrokXSearchTool + { + ExcludedHandles = ["foo", "bar"], + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.XSearch); + Assert.Equal(["foo", "bar"], tool.XSearch.ExcludedXHandles); + } + + [Fact] + public void AsTool_WithXSearch_ImageUnderstanding() + { + var webSearch = new GrokXSearchTool + { + EnableImageUnderstanding = true + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.XSearch); + Assert.True(tool.XSearch.EnableImageUnderstanding); + } + + [Fact] + public void AsTool_WithXSearch_VideoUnderstanding() + { + var webSearch = new GrokXSearchTool + { + EnableVideoUnderstanding = true + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.XSearch); + Assert.True(tool.XSearch.EnableVideoUnderstanding); + } + + [Fact] + public void AsTool_WithXSearch_FromTo() + { + var webSearch = new GrokXSearchTool + { + FromDate = DateOnly.FromDateTime(DateTime.UtcNow.Subtract(TimeSpan.FromDays(1))), + ToDate = DateOnly.FromDateTime(DateTime.UtcNow) + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.XSearch); + Assert.Equal(tool.XSearch.FromDate, Timestamp.FromDateTime(webSearch.FromDate.Value.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc))); + Assert.Equal(tool.XSearch.ToDate, Timestamp.FromDateTime(webSearch.ToDate.Value.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc))); + } + + [Fact] + public void AsTool_WithFunctionTool() + { + var functionTool = AIFunctionFactory.Create(() => "", "Name", "Description"); + + var tool = functionTool.AsProtocolTool(); + + Assert.NotNull(tool?.Function); + Assert.Equal("Name", tool.Function.Name); + Assert.Equal("Description", tool.Function.Description); + } + + [Fact] + public void AsTool_WithCodeExecution() + { + var codeTool = new HostedCodeInterpreterTool(); + + var tool = codeTool.AsProtocolTool(); + + Assert.NotNull(tool?.CodeExecution); + } + + [Fact] + public void AsTool_WithHostedFileSearchTool() + { + var collectionId = Guid.NewGuid().ToString(); + var instructions = "Return N/A if no results found"; + var fileSearch = new HostedFileSearchTool() + { + MaximumResultCount = 50, + Inputs = [new HostedVectorStoreContent(collectionId)] + }.WithInstructions(instructions); + + var tool = fileSearch.AsProtocolTool(); + + Assert.NotNull(tool?.CollectionsSearch); + Assert.Contains(collectionId, tool.CollectionsSearch.CollectionIds); + Assert.Equal(50, tool.CollectionsSearch.Limit); + Assert.Equal(instructions, tool.CollectionsSearch.Instructions); + } + + [Fact] + public void AsTool_WithHostedMcpTool() + { + var accessToken = Guid.NewGuid().ToString(); + var headers = new Dictionary + { + ["foo"] = "baz" + }; + var mcpTool = new HostedMcpServerTool("foo", "foo.com", new Dictionary + { + ["x-extra"] = "bar", + [nameof(MCP.ExtraHeaders)] = headers + }) + { + AllowedTools = ["list"], + AuthorizationToken = accessToken, + }; + + var tool = mcpTool.AsProtocolTool(); + + Assert.NotNull(tool?.Mcp); + Assert.Equal("foo", tool.Mcp.ServerLabel); + Assert.Equal("foo.com", tool.Mcp.ServerUrl); + Assert.Contains("list", tool.Mcp.AllowedToolNames); + Assert.Equal(accessToken, tool.Mcp.Authorization); + Assert.Contains(KeyValuePair.Create("x-extra", "bar"), tool.Mcp.ExtraHeaders); + Assert.Contains(KeyValuePair.Create("foo", "baz"), tool.Mcp.ExtraHeaders); + } +} diff --git a/src/xAI/Extensions/ChatExtensions.cs b/src/xAI/Extensions/ChatExtensions.cs index 1b2e874..967c516 100644 --- a/src/xAI/Extensions/ChatExtensions.cs +++ b/src/xAI/Extensions/ChatExtensions.cs @@ -1,8 +1,12 @@ -using Microsoft.Extensions.AI; +using System.ComponentModel; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Options; +using xAI.Protocol; namespace xAI; /// Extensions for . +[EditorBrowsable(EditorBrowsableState.Never)] public static partial class ChatOptionsExtensions { extension(ChatOptions options) @@ -14,4 +18,34 @@ public string? EndUserId set => (options.AdditionalProperties ??= [])["EndUserId"] = value; } } +} + +/// Grok-specific extensions for . +[EditorBrowsable(EditorBrowsableState.Never)] +public static partial class HostedFileSearchToolExtensions +{ + extension(HostedFileSearchTool tool) + { + /// + /// User-defined instructions to be included in the search query. Defaults to generic search + /// instructions used by the collections search backend if unset. + /// + public HostedFileSearchTool WithInstructions(string instructions) => new(new Dictionary + { + [nameof(CollectionsSearch.Instructions)] = Throw.IfNullOrEmpty(instructions) + }) + { + Inputs = tool.Inputs, + MaximumResultCount = tool.MaximumResultCount, + }; + } +} + +static partial class AIToolExtensions +{ + extension(AITool tool) + { + public T? GetProperty(string name) => + tool.AdditionalProperties?.TryGetValue(name, out var value) is true && value is T typed ? typed : default; + } } \ No newline at end of file diff --git a/src/xAI/GrokChatClient.cs b/src/xAI/GrokChatClient.cs index 20315e2..d396c65 100644 --- a/src/xAI/GrokChatClient.cs +++ b/src/xAI/GrokChatClient.cs @@ -328,82 +328,8 @@ codeResult.RawRepresentation is ToolCall codeToolCall && if (options?.Tools is not null) { - foreach (var tool in options.Tools) - { - if (tool is AIFunction functionTool) - { - var function = new Function - { - Name = functionTool.Name, - Description = functionTool.Description, - Parameters = JsonSerializer.Serialize(functionTool.JsonSchema) - }; - request.Tools.Add(new Tool { Function = function }); - } - else if (tool is HostedWebSearchTool webSearchTool) - { - if (webSearchTool is GrokXSearchTool xSearch) - { - var toolProto = new XSearch - { - EnableImageUnderstanding = xSearch.EnableImageUnderstanding, - EnableVideoUnderstanding = xSearch.EnableVideoUnderstanding, - }; - - if (xSearch.AllowedHandles is { } allowed) toolProto.AllowedXHandles.AddRange(allowed); - if (xSearch.ExcludedHandles is { } excluded) toolProto.ExcludedXHandles.AddRange(excluded); - if (xSearch.FromDate is { } from) toolProto.FromDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(from.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)); - if (xSearch.ToDate is { } to) toolProto.ToDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(to.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)); - - request.Tools.Add(new Tool { XSearch = toolProto }); - } - else if (webSearchTool is GrokSearchTool grokSearch) - { - var toolProto = new WebSearch - { - EnableImageUnderstanding = grokSearch.EnableImageUnderstanding, - }; - - if (grokSearch.AllowedDomains is { } allowed) toolProto.AllowedDomains.AddRange(allowed); - if (grokSearch.ExcludedDomains is { } excluded) toolProto.ExcludedDomains.AddRange(excluded); - - request.Tools.Add(new Tool { WebSearch = toolProto }); - } - else - { - request.Tools.Add(new Tool { WebSearch = new WebSearch() }); - } - } - else if (tool is HostedCodeInterpreterTool) - { - request.Tools.Add(new Tool { CodeExecution = new CodeExecution { } }); - } - else if (tool is HostedFileSearchTool fileSearch) - { - var toolProto = new CollectionsSearch(); - - if (fileSearch.Inputs?.OfType() is { } vectorStores) - toolProto.CollectionIds.AddRange(vectorStores.Select(x => x.VectorStoreId).Distinct()); - - if (fileSearch.MaximumResultCount is { } maxResults) - toolProto.Limit = maxResults; - - request.Tools.Add(new Tool { CollectionsSearch = toolProto }); - } - else if (tool is HostedMcpServerTool mcpTool) - { - request.Tools.Add(new Tool - { - Mcp = new MCP - { - Authorization = mcpTool.AuthorizationToken, - ServerLabel = mcpTool.ServerName, - ServerUrl = mcpTool.ServerAddress, - AllowedToolNames = { mcpTool.AllowedTools ?? Array.Empty() } - } - }); - } - } + foreach (var tool in options.Tools.Select(x => x.AsProtocolTool(options))) + if (tool is not null) request.Tools.Add(tool); } if (options?.ResponseFormat is ChatResponseFormatJson) diff --git a/src/xAI/GrokProtocolExtensions.cs b/src/xAI/GrokProtocolExtensions.cs new file mode 100644 index 0000000..43d52d6 --- /dev/null +++ b/src/xAI/GrokProtocolExtensions.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using xAI.Protocol; + +namespace xAI; + +/// Provides extension methods for working with xAI protocol types. +[EditorBrowsable(EditorBrowsableState.Never)] +public static class GrokProtocolExtensions +{ + /// Creates an xAI protocol from an . + /// The tool to convert. + /// An xAI protocol representing or if there is no mapping. + /// is . + public static Tool? AsProtocolTool(this AITool tool, ChatOptions? options = null) => ToProtocolTool(Throw.IfNull(tool), options); + + static Tool? ToProtocolTool(AITool tool, ChatOptions? options = null) + { + switch (tool) + { + case AIFunction functionTool: + return new Tool + { + Function = new Function + { + Name = functionTool.Name, + Description = functionTool.Description, + Parameters = JsonSerializer.Serialize(functionTool.JsonSchema) + } + }; + + case HostedWebSearchTool webSearchTool: + if (webSearchTool is GrokXSearchTool xSearchTool) + { + var xsearch = new XSearch + { + EnableImageUnderstanding = xSearchTool.EnableImageUnderstanding, + EnableVideoUnderstanding = xSearchTool.EnableVideoUnderstanding, + }; + + if (xSearchTool.AllowedHandles is { Count: > 0 } && + xSearchTool.ExcludedHandles is { Count: > 0 }) + throw new NotSupportedException($"Cannot use {nameof(GrokXSearchTool.AllowedHandles)} and {nameof(GrokXSearchTool.ExcludedHandles)} together in the same request."); + + if (xSearchTool.AllowedHandles is { } allowed) + xsearch.AllowedXHandles.AddRange(allowed); + if (xSearchTool.ExcludedHandles is { } excluded) + xsearch.ExcludedXHandles.AddRange(excluded); + if (xSearchTool.FromDate is { } from) + xsearch.FromDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(from.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)); + if (xSearchTool.ToDate is { } to) + xsearch.ToDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(to.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)); + + return new Tool { XSearch = xsearch }; + } + else if (webSearchTool is GrokSearchTool grokSearch) + { + var websearch = new WebSearch + { + EnableImageUnderstanding = grokSearch.EnableImageUnderstanding, + }; + + if (grokSearch.AllowedDomains is { Count: > 0 } && + grokSearch.ExcludedDomains is { Count: > 0 }) + throw new NotSupportedException($"Cannot use {nameof(GrokSearchTool.AllowedDomains)} and {nameof(GrokSearchTool.ExcludedDomains)} together in the same request."); + + if (grokSearch.AllowedDomains is { } allowed) + websearch.AllowedDomains.AddRange(allowed); + if (grokSearch.ExcludedDomains is { } excluded) + websearch.ExcludedDomains.AddRange(excluded); + + return new Tool { WebSearch = websearch }; + } + else + { + return new Tool { WebSearch = new WebSearch() }; + } + + case HostedCodeInterpreterTool: + return new Tool { CodeExecution = new CodeExecution { } }; + + case HostedFileSearchTool fileSearch: + var collectionTool = new CollectionsSearch(); + + if (fileSearch.Inputs?.OfType() is { } vectorStores) + collectionTool.CollectionIds.AddRange(vectorStores.Select(x => x.VectorStoreId).Distinct()); + + if (fileSearch.MaximumResultCount is { } maxResults) + collectionTool.Limit = maxResults; + if (fileSearch.GetProperty(nameof(CollectionsSearch.Instructions)) is { } instructions) + collectionTool.Instructions = instructions; + + return new Tool { CollectionsSearch = collectionTool }; + + case HostedMcpServerTool mcpTool: + var mcp = new MCP + { + Authorization = mcpTool.AuthorizationToken, + ServerLabel = mcpTool.ServerName, + ServerUrl = mcpTool.ServerAddress, + AllowedToolNames = { mcpTool.AllowedTools ?? Array.Empty() }, + }; + + // We can set an entire dictionary with a specific key + if (mcpTool.GetProperty>(nameof(MCP.ExtraHeaders)) is { } headers) + mcp.ExtraHeaders.Add(headers); + + // Or also the more intuitive mapping of additional properties directly. + foreach (var kv in mcpTool.AdditionalProperties) + if (kv.Value is string value) + mcp.ExtraHeaders.Add(kv.Key, value); + + return new Tool { Mcp = mcp }; + + default: + return null; + } + } +}