From 284d24154bf6750b0aab9c48e69d24917000f92f Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Tue, 30 Dec 2025 23:15:41 -0300 Subject: [PATCH] Bring in M.E.AI integration from Devlooped.Extensions.AI This will now be the core of the xAI package. --- .netconfig | 7 +- readme.md | 189 +++- src/xAI.Protocol/readme.md | 2 +- src/xAI.Protocol/xAI.Protocol.csproj | 3 +- src/xAI.Tests/ChatClientTests.cs | 509 ++++++++++ src/xAI.Tests/{ => Extensions}/Attributes.cs | 0 src/xAI.Tests/Extensions/CallHelpers.cs | 46 + src/xAI.Tests/Extensions/Configuration.cs | 31 + src/xAI.Tests/Extensions/Logging.cs | 40 + src/xAI.Tests/xAI.Tests.csproj | 34 +- src/xAI/Extensions/ChatExtensions.cs | 17 + src/xAI/Extensions/Throw.cs | 992 +++++++++++++++++++ src/xAI/GrokChatClient.cs | 465 +++++++++ src/xAI/GrokChatOptions.cs | 32 + src/xAI/GrokClient.cs | 47 + src/xAI/GrokClientExtensions.cs | 13 + src/xAI/GrokClientOptions.cs | 16 + src/xAI/GrokSearchTool.cs | 23 + src/xAI/GrokXSearch.cs | 24 + src/xAI/HostedToolCallContent.cs | 12 + src/xAI/HostedToolResultContent.cs | 17 + src/xAI/readme.md | 8 + src/xAI/xAI.csproj | 32 + xAI.slnx | 1 + 24 files changed, 2546 insertions(+), 14 deletions(-) create mode 100644 src/xAI.Tests/ChatClientTests.cs rename src/xAI.Tests/{ => Extensions}/Attributes.cs (100%) create mode 100644 src/xAI.Tests/Extensions/CallHelpers.cs create mode 100644 src/xAI.Tests/Extensions/Configuration.cs create mode 100644 src/xAI.Tests/Extensions/Logging.cs create mode 100644 src/xAI/Extensions/ChatExtensions.cs create mode 100644 src/xAI/Extensions/Throw.cs create mode 100644 src/xAI/GrokChatClient.cs create mode 100644 src/xAI/GrokChatOptions.cs create mode 100644 src/xAI/GrokClient.cs create mode 100644 src/xAI/GrokClientExtensions.cs create mode 100644 src/xAI/GrokClientOptions.cs create mode 100644 src/xAI/GrokSearchTool.cs create mode 100644 src/xAI/GrokXSearch.cs create mode 100644 src/xAI/HostedToolCallContent.cs create mode 100644 src/xAI/HostedToolResultContent.cs create mode 100644 src/xAI/readme.md create mode 100644 src/xAI/xAI.csproj diff --git a/.netconfig b/.netconfig index 6719ea3..0f5396d 100644 --- a/.netconfig +++ b/.netconfig @@ -180,7 +180,7 @@ sha = 407aa2d9319f5db12964540810b446fecc22d419 etag = 0dca55f20a72d3279554837f4eba867a1de37fe0f4a7535c2d9bc43867361cc5 weak -[file "src/Tests/Attributes.cs"] +[file "src/xAI.Tests/Extensions/Attributes.cs"] url = https://github.com/devlooped/catbag/blob/main/Xunit/Attributes.cs sha = 40914971d4d6b42d6f8a90923b131136f7e609a5 etag = c77e7b435ce1df06fb60a3b0e15a0833d8e45d4d19f366c6184140ebb4814b1a @@ -190,3 +190,8 @@ sha = 666a2a7c315f72199c418f11482a950fc69a8901 etag = 91ea15c07bfd784036c6ca931f5b2df7e9767b8367146d96c79caef09d63899f weak +[file "src/xAI/Extensions/Throw.cs"] + url = https://github.com/devlooped/catbag/blob/main/System/Throw.cs + sha = 3012d56be7554c483e5c5d277144c063969cada9 + etag = 43c81c6c6dcdf5baee40a9e3edc5e871e473e6c954c901b82bb87a3a48888ea0 + weak diff --git a/readme.md b/readme.md index b6b9ed9..56dd496 100644 --- a/readme.md +++ b/readme.md @@ -1,4 +1,4 @@ -![Icon](assets/icon.png) xAI .NET SDK +![Icon](assets/icon.png) .NET SDK ============ [![Version](https://img.shields.io/nuget/vpre/xAI.svg?color=royalblue)](https://www.nuget.org/packages/xAI) @@ -6,7 +6,8 @@ [![EULA](https://img.shields.io/badge/EULA-OSMF-blue?labelColor=black&color=C9FF30)](osmfeula.txt) [![OSS](https://img.shields.io/github/license/devlooped/oss.svg?color=blue)](license.txt) -xAI .NET SDK based on the official gRPC API reference from xAI +xAI .NET SDK based on the official gRPC API reference from xAI with integration for +Microsoft.Extensions.AI and Microsoft.Agents.AI. ## Open Source Maintenance Fee @@ -21,10 +22,188 @@ OSMF tier. A single fee covers all of [Devlooped packages](https://www.nuget.org - + +xAI/Grok integration for Microsoft.Extensions.AI `IChatClient` with full support for all +[agentic tools](https://docs.x.ai/docs/guides/tools/overview): + +```csharp +var grok = new GrokClient(Environment.GetEnvironmentVariable("XAI_API_KEY")!) + .AsIChatClient("grok-4.1-fast"); +``` +## Web Search + +```csharp +var messages = new Chat() +{ + { "system", "You are an AI assistant that knows how to search the web." }, + { "user", "What's Tesla stock worth today? Search X and the news for latest info." }, +}; + +var grok = new GrokClient(Environment.GetEnvironmentVariable("XAI_API_KEY")!).AsIChatClient("grok-4.1-fast"); + +var options = new ChatOptions +{ + Tools = [new HostedWebSearchTool()] // 👈 compatible with OpenAI +}; + +var response = await grok.GetResponseAsync(messages, options); +``` + +In addition to basic web search as shown above, Grok supports more +[advanced search](https://docs.x.ai/docs/guides/tools/search-tools) scenarios, +which can be opted-in by using Grok-specific types: + +```csharp +var grok = new GrokChatClient(Environment.GetEnvironmentVariable("XAI_API_KEY")!) + .AsIChatClient("grok-4.1-fast"); +var response = await grok.GetResponseAsync( + "What are the latest product news by Tesla?", + new ChatOptions + { + Tools = [new GrokSearchTool() + { + AllowedDomains = [ "ir.tesla.com" ] + }] + }); +``` + +You can alternatively set `ExcludedDomains` instead, and enable image +understanding with `EnableImageUndestanding`. Learn more about these filters +at [web search parameters](https://docs.x.ai/docs/guides/tools/search-tools#web-search-parameters). + +## X Search + +In addition to web search, Grok also supports searching on X (formerly Twitter): + +```csharp +var response = await grok.GetResponseAsync( + "What's the latest on Optimus?", + new ChatOptions + { + Tools = [new GrokXSearchTool + { + // AllowedHandles = [...], + // ExcludedHandles = [...], + // EnableImageUnderstanding = true, + // EnableVideoUnderstanding = true, + // FromDate = ..., + // ToDate = ..., + }] + }); +``` + +Learn more about available filters at [X search parameters](https://docs.x.ai/docs/guides/tools/search-tools#x-search-parameters). + +You can combine both web and X search in the same request by adding both tools. + +## Code Execution + +The code execution tool enables Grok to write and execute Python code in real-time, +dramatically expanding its capabilities beyond text generation. This powerful feature +allows Grok to perform precise calculations, complex data analysis, statistical +computations, and solve mathematical problems that would be impossible through text alone. + +This is Grok's equivalent of the OpenAI code interpreter, and is configured the same way: + +```csharp +var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast"); +var response = await grok.GetResponseAsync( + "Calculate the compound interest for $10,000 at 5% annually for 10 years", + new ChatOptions + { + Tools = [new HostedCodeInterpreterTool()] + }); + +var text = response.Text; +Assert.Contains("$6,288.95", text); +``` + +If you want to access the output from the code execution, you can add that as an +include in the options: + +```csharp +var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast"); +var options = new GrokChatOptions +{ + Include = { IncludeOption.CodeExecutionCallOutput }, + Tools = [new HostedCodeInterpreterTool()] +}; + +var response = await grok.GetResponseAsync( + "Calculate the compound interest for $10,000 at 5% annually for 10 years", + options); + +var content = response.Messages + .SelectMany(x => x.Contents) + .OfType() + .First(); + +foreach (AIContent output in content.Outputs) + // process outputs from code interpreter +``` + +Learn more about the [code execution tool](https://docs.x.ai/docs/guides/tools/code-execution-tool). + +## Collection Search + +If you maintain a [collection](https://docs.x.ai/docs/key-information/collections), +Grok can perform semantic search on it: + +```csharp +var options = new ChatOptions +{ + Tools = [new HostedFileSearchTool { + Inputs = [new HostedVectorStoreContent("[collection_id]")] + }] +}; +``` + +Learn more about [collection search](https://docs.x.ai/docs/guides/tools/collections-search-tool). + +## Remote MCP + +Remote MCP Tools allow Grok to connect to external MCP (Model Context Protocol) servers. +This example sets up the GitHub MCP server so queries about releases (limited specifically +in this case): + +```csharp +var options = new ChatOptions +{ + Tools = [new HostedMcpServerTool("GitHub", "https://api.githubcopilot.com/mcp/") { + AuthorizationToken = Configuration["GITHUB_TOKEN"]!, + AllowedTools = ["list_releases"], + }] +}; +``` + +Just like with code execution, you can opt-in to surfacing the MCP outputs in +the response: + +```csharp +var options = new GrokChatOptions +{ + // Exposes McpServerToolResultContent in responses + Include = { IncludeOption.McpCallOutput }, + Tools = [new HostedMcpServerTool("GitHub", "https://api.githubcopilot.com/mcp/") { + AuthorizationToken = Configuration["GITHUB_TOKEN"]!, + AllowedTools = ["list_releases"], + }] +}; + +``` + +Learn more about [Remote MCP tools](https://docs.x.ai/docs/guides/tools/remote-mcp-tools). + + +# xAI.Protocol + +[![Version](https://img.shields.io/nuget/vpre/xAI.Protocol.svg?color=royalblue)](https://www.nuget.org/packages/xAI.Protocol) +[![Downloads](https://img.shields.io/nuget/dt/xAI.Protocol.svg?color=green)](https://www.nuget.org/packages/xAI.Protocol) + + ## Usage -This project provides a .NET client for the gRPC API of xAI with full support for all services +The xAI.Protocol package provides a .NET client for the gRPC API of xAI with full support for all services documented in the [official API reference](https://docs.x.ai/docs/grpc-reference) and corresponding [proto files](https://github.com/xai-org/xai-proto/tree/main/proto/xai/api/v1). @@ -54,7 +233,7 @@ ensuring it remains up-to-date with any changes or additions made to the API as See for example the [introduction of tool output and citations](https://github.com/devlooped/GrokClient/pull/3). - + # Sponsors diff --git a/src/xAI.Protocol/readme.md b/src/xAI.Protocol/readme.md index 7be4a46..7dd9088 100644 --- a/src/xAI.Protocol/readme.md +++ b/src/xAI.Protocol/readme.md @@ -4,7 +4,7 @@ Grok client based on the official gRPC API reference from xAI + - \ No newline at end of file diff --git a/src/xAI.Protocol/xAI.Protocol.csproj b/src/xAI.Protocol/xAI.Protocol.csproj index 1b00e59..23822d7 100644 --- a/src/xAI.Protocol/xAI.Protocol.csproj +++ b/src/xAI.Protocol/xAI.Protocol.csproj @@ -13,16 +13,15 @@ + - - diff --git a/src/xAI.Tests/ChatClientTests.cs b/src/xAI.Tests/ChatClientTests.cs new file mode 100644 index 0000000..92b48ff --- /dev/null +++ b/src/xAI.Tests/ChatClientTests.cs @@ -0,0 +1,509 @@ +using System.Text.Json; +using System.Text.Json.Nodes; +using Azure; +using Devlooped.Extensions.AI; +using Microsoft.Extensions.AI; +using Moq; +using Tests.Client.Helpers; +using xAI; +using xAI.Protocol; +using static ConfigurationExtensions; +using Chat = Devlooped.Extensions.AI.Chat; +using OpenAIClientOptions = OpenAI.OpenAIClientOptions; + +namespace xAI.Tests; + +public class ChatClientTests(ITestOutputHelper output) +{ + [SecretsFact("XAI_API_KEY")] + public async Task GrokInvokesTools() + { + var messages = new Chat() + { + { "system", "You are a bot that invokes the tool get_date when asked for the date." }, + { "user", "What day is today?" }, + }; + + var chat = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4") + .AsBuilder() + .UseLogging(output.AsLoggerFactory()) + .Build(); + + var options = new GrokChatOptions + { + ModelId = "grok-4-fast-non-reasoning", + Tools = [AIFunctionFactory.Create(() => DateTimeOffset.Now.ToString("O"), "get_date")], + AdditionalProperties = new() + { + { "foo", "bar" } + } + }; + + var response = await chat.GetResponseAsync(messages, options); + var getdate = response.Messages + .SelectMany(x => x.Contents.OfType()) + .Any(x => x.Name == "get_date"); + + Assert.True(getdate); + // NOTE: the chat client was requested as grok-3 but the chat options wanted a + // different model and the grok client honors that choice. + Assert.Equal(options.ModelId, response.ModelId); + } + + [SecretsFact("XAI_API_KEY")] + public async Task GrokInvokesToolAndSearch() + { + var messages = new Chat() + { + { "system", "You use Nasdaq for stocks news and prices." }, + { "user", "What's Tesla stock worth today?" }, + }; + + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4") + .AsBuilder() + .UseFunctionInvocation() + .UseLogging(output.AsLoggerFactory()) + .Build(); + + var getDateCalls = 0; + var options = new GrokChatOptions + { + ModelId = "grok-4-1-fast-non-reasoning", + Search = GrokSearch.Web, + Tools = [AIFunctionFactory.Create(() => + { + getDateCalls++; + return DateTimeOffset.Now.ToString("O"); + }, "get_date", "Gets the current date")], + }; + + var response = await grok.GetResponseAsync(messages, options); + + // The get_date result shows up as a tool role + Assert.Contains(response.Messages, x => x.Role == ChatRole.Tool); + + // Citations include nasdaq.com at least as a web search source + var urls = response.Messages + .SelectMany(x => x.Contents) + .SelectMany(x => x.Annotations?.OfType() ?? []) + .Where(x => x.Url is not null) + .Select(x => x.Url!) + .ToList(); + + Assert.Equal(1, getDateCalls); + Assert.Contains(urls, x => x.Host.EndsWith("nasdaq.com")); + Assert.Contains(urls, x => x.PathAndQuery.Contains("/TSLA")); + Assert.Equal(options.ModelId, response.ModelId); + + var calls = response.Messages + .SelectMany(x => x.Contents.OfType()) + .Select(x => x.RawRepresentation as xAI.Protocol.ToolCall) + .Where(x => x is not null) + .ToList(); + + Assert.NotEmpty(calls); + Assert.Contains(calls, x => x?.Type == xAI.Protocol.ToolCallType.WebSearchTool); + } + + [SecretsFact("XAI_API_KEY")] + public async Task GrokInvokesSpecificSearchUrl() + { + var messages = new Chat() + { + { "system", "Sos un asistente del Cerro Catedral, usas la funcionalidad de Live Search en el sitio oficial." }, + { "system", $"Hoy es {DateTime.Now.ToString("o")}" }, + { "user", "Que calidad de nieve hay hoy?" }, + }; + + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-1-fast-non-reasoning"); + + var options = new ChatOptions + { + Tools = [new GrokSearchTool() + { + AllowedDomains = [ "catedralaltapatagonia.com" ] + }] + }; + + var response = await grok.GetResponseAsync(messages, options); + var text = response.Text; + + var citations = response.Messages + .SelectMany(x => x.Contents) + .SelectMany(x => x.Annotations ?? []) + .OfType() + .Where(x => x.Url != null) + .Select(x => x.Url!.AbsoluteUri) + .ToList(); + + Assert.Contains("https://partediario.catedralaltapatagonia.com/partediario/", citations); + } + + [SecretsFact("XAI_API_KEY")] + public async Task GrokInvokesHostedSearchTool() + { + var messages = new Chat() + { + { "system", "You are an AI assistant that knows how to search the web." }, + { "user", "What's Tesla stock worth today? Search X, Yahoo and the news for latest info." }, + }; + + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast"); + + var options = new GrokChatOptions + { + Include = { IncludeOption.WebSearchCallOutput }, + Tools = [new HostedWebSearchTool()] + }; + + var response = await grok.GetResponseAsync(messages, options); + var text = response.Text; + + Assert.Contains("TSLA", text); + Assert.NotNull(response.ModelId); + + var urls = response.Messages + .SelectMany(x => x.Contents) + .SelectMany(x => x.Annotations?.OfType() ?? []) + .Where(x => x.Url is not null) + .Select(x => x.Url!) + .ToList(); + + Assert.Contains(urls, x => x.Host == "finance.yahoo.com"); + Assert.Contains(urls, x => x.PathAndQuery.Contains("/TSLA")); + } + + [SecretsFact("XAI_API_KEY")] + public async Task GrokInvokesGrokSearchToolIncludesDomain() + { + var messages = new Chat() + { + { "system", "You are an AI assistant that knows how to search the web." }, + { "user", "What is the latest news about Microsoft?" }, + }; + + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast"); + + var options = new ChatOptions + { + Tools = [new GrokSearchTool + { + AllowedDomains = ["microsoft.com", "news.microsoft.com"], + }] + }; + + var response = await grok.GetResponseAsync(messages, options); + + Assert.NotNull(response.Text); + Assert.Contains("Microsoft", response.Text); + + var urls = response.Messages + .SelectMany(x => x.Contents) + .SelectMany(x => x.Annotations?.OfType() ?? []) + .Where(x => x.Url is not null) + .Select(x => x.Url!) + .ToList(); + + foreach (var url in urls) + { + output.WriteLine(url.ToString()); + } + + Assert.All(urls, x => x.Host.EndsWith(".microsoft.com")); + } + + [SecretsFact("XAI_API_KEY")] + public async Task GrokInvokesGrokSearchToolExcludesDomain() + { + var messages = new Chat() + { + { "system", "You are an AI assistant that knows how to search the web." }, + { "user", "What is the latest news about Microsoft?" }, + }; + + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast"); + + var options = new ChatOptions + { + Tools = [new GrokSearchTool + { + ExcludedDomains = ["blogs.microsoft.com"] + }] + }; + + var response = await grok.GetResponseAsync(messages, options); + + Assert.NotNull(response.Text); + Assert.Contains("Microsoft", response.Text); + + var urls = response.Messages + .SelectMany(x => x.Contents) + .SelectMany(x => x.Annotations?.OfType() ?? []) + .Where(x => x.Url is not null) + .Select(x => x.Url!) + .ToList(); + + foreach (var url in urls) + { + output.WriteLine(url.ToString()); + } + + Assert.DoesNotContain(urls, x => x.Host == "blogs.microsoft.com"); + } + + [SecretsFact("XAI_API_KEY")] + public async Task GrokInvokesHostedCodeExecution() + { + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast"); + + var response = await grok.GetResponseAsync( + "Calculate the compound interest for $10,000 at 5% annually for 10 years", + new ChatOptions + { + Tools = [new HostedCodeInterpreterTool()] + }); + + var text = response.Text; + + Assert.Contains("$6,288.95", text); + Assert.NotEmpty(response.Messages + .SelectMany(x => x.Contents) + .OfType()); + + // result content is not available by default + Assert.Empty(response.Messages + .SelectMany(x => x.Contents) + .OfType()); + } + + [SecretsFact("XAI_API_KEY")] + public async Task GrokInvokesHostedCodeExecutionWithOutput() + { + var messages = new Chat() + { + { "user", "Calculate the compound interest for $10,000 at 5% annually for 10 years" }, + }; + + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast"); + + var options = new GrokChatOptions + { + Include = { IncludeOption.CodeExecutionCallOutput }, + Tools = [new HostedCodeInterpreterTool()] + }; + + var response = await grok.GetResponseAsync(messages, options); + + Assert.Contains("$6,288.95", response.Text); + Assert.NotEmpty(response.Messages + .SelectMany(x => x.Contents) + .OfType()); + + // result content opted-in is found + Assert.NotEmpty(response.Messages + .SelectMany(x => x.Contents) + .OfType()); + } + + [SecretsFact("XAI_API_KEY")] + public async Task GrokInvokesHostedCollectionSearch() + { + var messages = new Chat() + { + { "user", "¿Cuál es el monto exacto del rango de la multa por inasistencia injustificada a la audiencia señalada por el juez en el proceso sucesorio, según lo establecido en el Artículo 691 del Código Procesal Civil y Comercial de la Nación (Ley 17.454)?" }, + }; + + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast"); + + var options = new GrokChatOptions + { + Include = { IncludeOption.CollectionsSearchCallOutput }, + Tools = [new HostedFileSearchTool { + Inputs = [new HostedVectorStoreContent("collection_91559d9b-a55d-42fe-b2ad-ecf8904d9049")] + }] + }; + + var response = await grok.GetResponseAsync(messages, options); + var text = response.Text; + + Assert.Contains("11,74", text); + Assert.Contains(response.Messages + .SelectMany(x => x.Contents) + .OfType() + .Select(x => x.RawRepresentation as xAI.Protocol.ToolCall), + x => x?.Type == xAI.Protocol.ToolCallType.CollectionsSearchTool); + } + + [SecretsFact("XAI_API_KEY", "GITHUB_TOKEN")] + public async Task GrokInvokesHostedMcp() + { + var messages = new Chat() + { + { "user", "When was GrokClient v1.0.0 released on the devlooped/GrokClient repo? Respond with just the date, in YYYY-MM-DD format." }, + }; + + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast"); + + var options = new ChatOptions + { + Tools = [new HostedMcpServerTool("GitHub", "https://api.githubcopilot.com/mcp/") { + AuthorizationToken = Configuration["GITHUB_TOKEN"]!, + AllowedTools = ["list_releases"], + }] + }; + + var response = await grok.GetResponseAsync(messages, options); + var text = response.Text; + + Assert.Equal("2025-11-29", text); + var call = Assert.Single(response.Messages + .SelectMany(x => x.Contents) + .OfType()); + + Assert.Equal("GitHub.list_releases", call.ToolName); + } + + [SecretsFact("XAI_API_KEY", "GITHUB_TOKEN")] + public async Task GrokInvokesHostedMcpWithOutput() + { + var messages = new Chat() + { + { "user", "When was GrokClient v1.0.0 released on the devlooped/GrokClient repo? Respond with just the date, in YYYY-MM-DD format." }, + }; + + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-fast"); + + var options = new GrokChatOptions + { + Include = { IncludeOption.McpCallOutput }, + Tools = [new HostedMcpServerTool("GitHub", "https://api.githubcopilot.com/mcp/") { + AuthorizationToken = Configuration["GITHUB_TOKEN"]!, + AllowedTools = ["list_releases"], + }] + }; + + var response = await grok.GetResponseAsync(messages, options); + + // Can include result of MCP tool + var output = Assert.Single(response.Messages + .SelectMany(x => x.Contents) + .OfType()); + + Assert.NotNull(output.Output); + Assert.Single(output.Output); + var json = Assert.Single(output.Output!.OfType()).Text; + var tags = JsonSerializer.Deserialize>(json, new JsonSerializerOptions(JsonSerializerDefaults.Web) + { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower + }); + + Assert.NotNull(tags); + Assert.Contains(tags, x => x.TagName == "v1.0.0"); + } + + record Release(string TagName, DateTimeOffset CreatedAt); + + [SecretsFact("XAI_API_KEY", "GITHUB_TOKEN")] + public async Task GrokStreamsUpdatesFromAllTools() + { + var messages = new Chat() + { + { "user", + """ + What's the oldest stable version released on the devlooped/GrokClient repo on GitHub?, + what is the current price of Tesla stock, + and what is the current date? Respond with the following JSON: + { + "today": "[get_date result]", + "release": "[first stable release of devlooped/GrokClient, using GitHub MCP tool]", + "price": [$TSLA price using web search tool] + } + """ + }, + }; + + var grok = new GrokClient(Configuration["XAI_API_KEY"]!) + .AsIChatClient("grok-4-fast") + .AsBuilder() + .UseFunctionInvocation() + .UseLogging(output.AsLoggerFactory()) + .Build(); + + var getDateCalls = 0; + var options = new GrokChatOptions + { + Include = { IncludeOption.McpCallOutput }, + Tools = + [ + new HostedWebSearchTool(), + new HostedMcpServerTool("GitHub", "https://api.githubcopilot.com/mcp/") { + AuthorizationToken = Configuration["GITHUB_TOKEN"]!, + AllowedTools = ["list_releases", "get_release_by_tag"], + }, + AIFunctionFactory.Create(() => { + getDateCalls++; + return DateTimeOffset.Now.ToString("O"); + }, "get_date", "Gets the current date") + ] + }; + + var updates = await grok.GetStreamingResponseAsync(messages, options).ToListAsync(); + var response = updates.ToChatResponse(); + var typed = JsonSerializer.Deserialize(response.Messages.Last().Text, new JsonSerializerOptions(JsonSerializerDefaults.Web)); + + Assert.NotNull(typed); + + Assert.NotEmpty(response.Messages + .SelectMany(x => x.Contents) + .OfType()); + + Assert.Contains(response.Messages + .SelectMany(x => x.Contents) + .OfType() + .Select(x => x.RawRepresentation as xAI.Protocol.ToolCall), + x => x?.Type == xAI.Protocol.ToolCallType.WebSearchTool); + + Assert.Equal(1, getDateCalls); + + Assert.Equal(DateOnly.FromDateTime(DateTime.Today), typed.Today); + Assert.EndsWith("1.0.0", typed.Release); + Assert.True(typed.Price > 100); + } + + [Fact] + public async Task GrokCustomFactoryInvokedFromOptions() + { + var invoked = false; + var client = new Mock(MockBehavior.Strict); + client.Setup(x => x.GetCompletionAsync(It.IsAny(), null, null, CancellationToken.None)) + .Returns(CallHelpers.CreateAsyncUnaryCall(new GetChatCompletionResponse + { + Outputs = + { + new CompletionOutput + { + Message = new CompletionMessage + { + Content = "Hey Cazzulino!" + } + } + } + })); + + var grok = new GrokChatClient(client.Object, "grok-4-1-fast"); + var response = await grok.GetResponseAsync("Hi, my internet alias is kzu. Lookup my real full name online.", + new GrokChatOptions + { + RawRepresentationFactory = (client) => + { + invoked = true; + return new GetCompletionsRequest(); + } + }); + + Assert.True(invoked); + Assert.Equal("Hey Cazzulino!", response.Text); + } + + record Response(DateOnly Today, string Release, decimal Price); +} diff --git a/src/xAI.Tests/Attributes.cs b/src/xAI.Tests/Extensions/Attributes.cs similarity index 100% rename from src/xAI.Tests/Attributes.cs rename to src/xAI.Tests/Extensions/Attributes.cs diff --git a/src/xAI.Tests/Extensions/CallHelpers.cs b/src/xAI.Tests/Extensions/CallHelpers.cs new file mode 100644 index 0000000..77b8050 --- /dev/null +++ b/src/xAI.Tests/Extensions/CallHelpers.cs @@ -0,0 +1,46 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Grpc.Core; + +namespace Tests.Client.Helpers +{ + static class CallHelpers + { + public static AsyncUnaryCall CreateAsyncUnaryCall(TResponse response) + { + return new AsyncUnaryCall( + Task.FromResult(response), + Task.FromResult(new Metadata()), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + } + + public static AsyncUnaryCall CreateAsyncUnaryCall(StatusCode statusCode) + { + var status = new Status(statusCode, string.Empty); + return new AsyncUnaryCall( + Task.FromException(new RpcException(status)), + Task.FromResult(new Metadata()), + () => status, + () => new Metadata(), + () => { }); + } + } +} diff --git a/src/xAI.Tests/Extensions/Configuration.cs b/src/xAI.Tests/Extensions/Configuration.cs new file mode 100644 index 0000000..8c6c76b --- /dev/null +++ b/src/xAI.Tests/Extensions/Configuration.cs @@ -0,0 +1,31 @@ +using System.Reflection; +using System.Text; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +public static class ConfigurationExtensions +{ + public static IConfiguration Configuration { get; } = new ConfigurationBuilder() + .AddEnvironmentVariables() + .AddUserSecrets(Assembly.GetExecutingAssembly()) + .Build(); + + public static TOptions GetOptions(this IConfiguration configuration, string name) + where TOptions : class, new() + => new ServiceCollection() + .Configure(configuration.GetSection(name)) + .BuildServiceProvider() + .GetRequiredService>() + .Value; + + public static TOptions GetOptions(this IConfiguration configuration) + where TOptions : class, new() + { + var name = typeof(TOptions).Name; + if (name.EndsWith("Options")) + return configuration.GetOptions(name[..^7]); + + return configuration.GetOptions(name); + } +} diff --git a/src/xAI.Tests/Extensions/Logging.cs b/src/xAI.Tests/Extensions/Logging.cs new file mode 100644 index 0000000..4bfb81c --- /dev/null +++ b/src/xAI.Tests/Extensions/Logging.cs @@ -0,0 +1,40 @@ +using Microsoft.Extensions.Logging; + +public static class LoggerFactoryExtensions +{ + public static ILoggerFactory AsLoggerFactory(this ITestOutputHelper output) => new TestLoggerFactory(output); + + public static ILoggingBuilder AddTestOutput(this ILoggingBuilder builder, ITestOutputHelper output) + => builder.AddProvider(new TestLoggerProider(output)); + + class TestLoggerProider(ITestOutputHelper output) : ILoggerProvider + { + readonly ILoggerFactory factory = new TestLoggerFactory(output); + + public ILogger CreateLogger(string categoryName) => factory.CreateLogger(categoryName); + + public void Dispose() { } + } + + class TestLoggerFactory(ITestOutputHelper output) : ILoggerFactory + { + public ILogger CreateLogger(string categoryName) => new TestOutputLogger(output, categoryName); + public void AddProvider(ILoggerProvider provider) { } + public void Dispose() { } + + // create ilogger implementation over testoutputhelper + public class TestOutputLogger(ITestOutputHelper output, string categoryName) : ILogger + { + public IDisposable? BeginScope(TState state) where TState : notnull => null!; + + public bool IsEnabled(LogLevel logLevel) => true; + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + if (formatter == null) throw new ArgumentNullException(nameof(formatter)); + if (state == null) throw new ArgumentNullException(nameof(state)); + output.WriteLine($"{logLevel}: {categoryName}: {formatter(state, exception)}"); + } + } + } +} \ No newline at end of file diff --git a/src/xAI.Tests/xAI.Tests.csproj b/src/xAI.Tests/xAI.Tests.csproj index 3844031..172a822 100644 --- a/src/xAI.Tests/xAI.Tests.csproj +++ b/src/xAI.Tests/xAI.Tests.csproj @@ -5,25 +5,49 @@ enable enable false + MEAI001;xAI001;$(NoWarn) - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/xAI/Extensions/ChatExtensions.cs b/src/xAI/Extensions/ChatExtensions.cs new file mode 100644 index 0000000..1b2e874 --- /dev/null +++ b/src/xAI/Extensions/ChatExtensions.cs @@ -0,0 +1,17 @@ +using Microsoft.Extensions.AI; + +namespace xAI; + +/// Extensions for . +public static partial class ChatOptionsExtensions +{ + extension(ChatOptions options) + { + /// Gets or sets the end user ID for the chat session. + public string? EndUserId + { + get => (options.AdditionalProperties ??= []).TryGetValue("EndUserId", out var value) ? value as string : null; + set => (options.AdditionalProperties ??= [])["EndUserId"] = value; + } + } +} \ No newline at end of file diff --git a/src/xAI/Extensions/Throw.cs b/src/xAI/Extensions/Throw.cs new file mode 100644 index 0000000..eea3e12 --- /dev/null +++ b/src/xAI/Extensions/Throw.cs @@ -0,0 +1,992 @@ +// +#region License +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// Adapted from https://github.com/dotnet/extensions/blob/main/src/Shared/Throw/Throw.cs +#endregion + +#nullable enable +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +#pragma warning disable CA1716 +namespace System; +#pragma warning restore CA1716 + +/// +/// Defines static methods used to throw exceptions. +/// +/// +/// The main purpose is to reduce code size, improve performance, and standardize exception +/// messages. +/// +[SuppressMessage("Minor Code Smell", "S4136:Method overloads should be grouped together", Justification = "Doesn't work with the region layout")] +[SuppressMessage("Minor Code Smell", "S2333:Partial is gratuitous in this context", Justification = "Some projects add additional partial parts.")] +[SuppressMessage("Design", "CA1716", Justification = "Not part of an API")] + +#if !SHARED_PROJECT +[ExcludeFromCodeCoverage] +#endif + +static partial class Throw +{ + #region For Object + + /// + /// Throws an if the specified argument is . + /// + /// Argument type to be checked for . + /// Object to be checked for . + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static T IfNull([NotNull] T argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument is null) + { + ArgumentNullException(paramName); + } + + return argument; + } + + /// + /// Throws an if the specified argument is , + /// or if the specified member is . + /// + /// Argument type to be checked for . + /// Member type to be checked for . + /// Argument to be checked for . + /// Object member to be checked for . + /// The name of the parameter being checked. + /// The name of the member. + /// The original value of . + /// + /// + /// Throws.IfNullOrMemberNull(myObject, myObject?.MyProperty) + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static TMember IfNullOrMemberNull( + [NotNull] TParameter argument, + [NotNull] TMember member, + [CallerArgumentExpression(nameof(argument))] string paramName = "", + [CallerArgumentExpression(nameof(member))] string memberName = "") + { + if (argument is null) + { + ArgumentNullException(paramName); + } + + if (member is null) + { + ArgumentException(paramName, $"Member {memberName} of {paramName} is null"); + } + + return member; + } + + /// + /// Throws an if the specified member is . + /// + /// Argument type. + /// Member type to be checked for . + /// Argument to which member belongs. + /// Object member to be checked for . + /// The name of the parameter being checked. + /// The name of the member. + /// The original value of . + /// + /// + /// Throws.IfMemberNull(myObject, myObject.MyProperty) + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + [SuppressMessage("Style", "IDE0060:Remove unused parameter", Justification = "Analyzer isn't seeing the reference to 'argument' in the attribute")] + public static TMember IfMemberNull( + TParameter argument, + [NotNull] TMember member, + [CallerArgumentExpression(nameof(argument))] string paramName = "", + [CallerArgumentExpression(nameof(member))] string memberName = "") + where TParameter : notnull + { + if (member is null) + { + ArgumentException(paramName, $"Member {memberName} of {paramName} is null"); + } + + return member; + } + + #endregion + + #region For String + + /// + /// Throws either an or an + /// if the specified string is or whitespace respectively. + /// + /// String to be checked for or whitespace. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static string IfNullOrWhitespace([NotNull] string? argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { +#if !NETCOREAPP3_1_OR_GREATER + if (argument == null) + { + ArgumentNullException(paramName); + } +#endif + + if (string.IsNullOrWhiteSpace(argument)) + { + if (argument == null) + { + ArgumentNullException(paramName); + } + else + { + ArgumentException(paramName, "Argument is whitespace"); + } + } + + return argument; + } + + /// + /// Throws an if the string is , + /// or if it is empty. + /// + /// String to be checked for or empty. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static string IfNullOrEmpty([NotNull] string? argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { +#if !NETCOREAPP3_1_OR_GREATER + if (argument == null) + { + ArgumentNullException(paramName); + } +#endif + + if (string.IsNullOrEmpty(argument)) + { + if (argument == null) + { + ArgumentNullException(paramName); + } + else + { + ArgumentException(paramName, "Argument is an empty string"); + } + } + + return argument; + } + + #endregion + + #region For Buffer + + /// + /// Throws an if the argument's buffer size is less than the required buffer size. + /// + /// The actual buffer size. + /// The required buffer size. + /// The name of the parameter to be checked. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void IfBufferTooSmall(int bufferSize, int requiredSize, string paramName = "") + { + if (bufferSize < requiredSize) + { + ArgumentException(paramName, $"Buffer too small, needed a size of {requiredSize} but got {bufferSize}"); + } + } + + #endregion + + #region For Enums + + /// + /// Throws an if the enum value is not valid. + /// + /// The argument to evaluate. + /// The name of the parameter being checked. + /// The type of the enumeration. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static T IfOutOfRange(T argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + where T : struct, Enum + { +#if NET5_0_OR_GREATER + if (!Enum.IsDefined(argument)) +#else + if (!Enum.IsDefined(typeof(T), argument)) +#endif + { + ArgumentOutOfRangeException(paramName, $"{argument} is an invalid value for enum type {typeof(T)}"); + } + + return argument; + } + + #endregion + + #region For Collections + + /// + /// Throws an if the collection is , + /// or if it is empty. + /// + /// The collection to evaluate. + /// The name of the parameter being checked. + /// The type of objects in the collection. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + + // The method has actually 100% coverage, but due to a bug in the code coverage tool, + // a lower number is reported. Therefore, we temporarily exclude this method + // from the coverage measurements. Once the bug in the code coverage tool is fixed, + // the exclusion attribute can be removed. + [ExcludeFromCodeCoverage] + public static IEnumerable IfNullOrEmpty([NotNull] IEnumerable? argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument == null) + { + ArgumentNullException(paramName); + } + else + { + switch (argument) + { + case ICollection collection: + if (collection.Count == 0) + { + ArgumentException(paramName, "Collection is empty"); + } + + break; + case IReadOnlyCollection readOnlyCollection: + if (readOnlyCollection.Count == 0) + { + ArgumentException(paramName, "Collection is empty"); + } + + break; + default: + using (IEnumerator enumerator = argument.GetEnumerator()) + { + if (!enumerator.MoveNext()) + { + ArgumentException(paramName, "Collection is empty"); + } + } + + break; + } + } + + return argument; + } + + #endregion + + #region Exceptions + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentNullException(string paramName) + => throw new ArgumentNullException(paramName); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. + /// A message that describes the error. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentNullException(string paramName, string? message) + => throw new ArgumentNullException(paramName, message); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentOutOfRangeException(string paramName) + => throw new ArgumentOutOfRangeException(paramName); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. + /// A message that describes the error. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentOutOfRangeException(string paramName, string? message) + => throw new ArgumentOutOfRangeException(paramName, message); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. + /// The value of the argument that caused this exception. + /// A message that describes the error. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentOutOfRangeException(string paramName, object? actualValue, string? message) + => throw new ArgumentOutOfRangeException(paramName, actualValue, message); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. + /// A message that describes the error. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentException(string paramName, string? message) + => throw new ArgumentException(message, paramName); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. + /// A message that describes the error. + /// The exception that is the cause of the current exception. + /// + /// If the is not a , the current exception is raised in a catch + /// block that handles the inner exception. + /// +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentException(string paramName, string? message, Exception? innerException) + => throw new ArgumentException(message, paramName, innerException); + + /// + /// Throws an . + /// + /// A message that describes the error. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void InvalidOperationException(string message) + => throw new InvalidOperationException(message); + + /// + /// Throws an . + /// + /// A message that describes the error. + /// The exception that is the cause of the current exception. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void InvalidOperationException(string message, Exception? innerException) + => throw new InvalidOperationException(message, innerException); + + #endregion + + #region For Integer + + /// + /// Throws an if the specified number is less than min. + /// + /// Number to be expected being less than min. + /// The number that must be less than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfLessThan(int argument, int min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater than max. + /// + /// Number to be expected being greater than max. + /// The number that must be greater than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfGreaterThan(int argument, int max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is less or equal than min. + /// + /// Number to be expected being less or equal than min. + /// The number that must be less or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfLessThanOrEqual(int argument, int min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument <= min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less or equal than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater or equal than max. + /// + /// Number to be expected being greater or equal than max. + /// The number that must be greater or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfGreaterThanOrEqual(int argument, int max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument >= max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater or equal than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is not in the specified range. + /// + /// Number to be expected being greater or equal than max. + /// The lower bound of the allowed range of argument values. + /// The upper bound of the allowed range of argument values. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfOutOfRange(int argument, int min, int max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min || argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument not in the range [{min}..{max}]"); + } + + return argument; + } + + /// + /// Throws an if the specified number is equal to 0. + /// + /// Number to be expected being not equal to zero. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfZero(int argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument == 0) + { + ArgumentOutOfRangeException(paramName, "Argument is zero"); + } + + return argument; + } + + #endregion + + #region For Unsigned Integer + + /// + /// Throws an if the specified number is less than min. + /// + /// Number to be expected being less than min. + /// The number that must be less than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfLessThan(uint argument, uint min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater than max. + /// + /// Number to be expected being greater than max. + /// The number that must be greater than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfGreaterThan(uint argument, uint max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is less or equal than min. + /// + /// Number to be expected being less or equal than min. + /// The number that must be less or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfLessThanOrEqual(uint argument, uint min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument <= min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less or equal than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater or equal than max. + /// + /// Number to be expected being greater or equal than max. + /// The number that must be greater or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfGreaterThanOrEqual(uint argument, uint max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument >= max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater or equal than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is not in the specified range. + /// + /// Number to be expected being greater or equal than max. + /// The lower bound of the allowed range of argument values. + /// The upper bound of the allowed range of argument values. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfOutOfRange(uint argument, uint min, uint max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min || argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument not in the range [{min}..{max}]"); + } + + return argument; + } + + /// + /// Throws an if the specified number is equal to 0. + /// + /// Number to be expected being not equal to zero. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfZero(uint argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument == 0U) + { + ArgumentOutOfRangeException(paramName, "Argument is zero"); + } + + return argument; + } + + #endregion + + #region For Long + + /// + /// Throws an if the specified number is less than min. + /// + /// Number to be expected being less than min. + /// The number that must be less than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfLessThan(long argument, long min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater than max. + /// + /// Number to be expected being greater than max. + /// The number that must be greater than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfGreaterThan(long argument, long max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is less or equal than min. + /// + /// Number to be expected being less or equal than min. + /// The number that must be less or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfLessThanOrEqual(long argument, long min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument <= min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less or equal than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater or equal than max. + /// + /// Number to be expected being greater or equal than max. + /// The number that must be greater or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfGreaterThanOrEqual(long argument, long max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument >= max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater or equal than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is not in the specified range. + /// + /// Number to be expected being greater or equal than max. + /// The lower bound of the allowed range of argument values. + /// The upper bound of the allowed range of argument values. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfOutOfRange(long argument, long min, long max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min || argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument not in the range [{min}..{max}]"); + } + + return argument; + } + + /// + /// Throws an if the specified number is equal to 0. + /// + /// Number to be expected being not equal to zero. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfZero(long argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument == 0L) + { + ArgumentOutOfRangeException(paramName, "Argument is zero"); + } + + return argument; + } + + #endregion + + #region For Unsigned Long + + /// + /// Throws an if the specified number is less than min. + /// + /// Number to be expected being less than min. + /// The number that must be less than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfLessThan(ulong argument, ulong min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater than max. + /// + /// Number to be expected being greater than max. + /// The number that must be greater than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfGreaterThan(ulong argument, ulong max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is less or equal than min. + /// + /// Number to be expected being less or equal than min. + /// The number that must be less or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfLessThanOrEqual(ulong argument, ulong min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument <= min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less or equal than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater or equal than max. + /// + /// Number to be expected being greater or equal than max. + /// The number that must be greater or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfGreaterThanOrEqual(ulong argument, ulong max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument >= max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater or equal than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is not in the specified range. + /// + /// Number to be expected being greater or equal than max. + /// The lower bound of the allowed range of argument values. + /// The upper bound of the allowed range of argument values. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfOutOfRange(ulong argument, ulong min, ulong max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min || argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument not in the range [{min}..{max}]"); + } + + return argument; + } + + /// + /// Throws an if the specified number is equal to 0. + /// + /// Number to be expected being not equal to zero. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfZero(ulong argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument == 0UL) + { + ArgumentOutOfRangeException(paramName, "Argument is zero"); + } + + return argument; + } + + #endregion + + #region For Double + + /// + /// Throws an if the specified number is less than min. + /// + /// Number to be expected being less than min. + /// The number that must be less than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfLessThan(double argument, double min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + // strange conditional needed in order to handle NaN values correctly +#pragma warning disable S1940 // Boolean checks should not be inverted + if (!(argument >= min)) +#pragma warning restore S1940 // Boolean checks should not be inverted + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater than max. + /// + /// Number to be expected being greater than max. + /// The number that must be greater than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfGreaterThan(double argument, double max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + // strange conditional needed in order to handle NaN values correctly +#pragma warning disable S1940 // Boolean checks should not be inverted + if (!(argument <= max)) +#pragma warning restore S1940 // Boolean checks should not be inverted + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is less or equal than min. + /// + /// Number to be expected being less or equal than min. + /// The number that must be less or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfLessThanOrEqual(double argument, double min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + // strange conditional needed in order to handle NaN values correctly +#pragma warning disable S1940 // Boolean checks should not be inverted + if (!(argument > min)) +#pragma warning restore S1940 // Boolean checks should not be inverted + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less or equal than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater or equal than max. + /// + /// Number to be expected being greater or equal than max. + /// The number that must be greater or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfGreaterThanOrEqual(double argument, double max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + // strange conditional needed in order to handle NaN values correctly +#pragma warning disable S1940 // Boolean checks should not be inverted + if (!(argument < max)) +#pragma warning restore S1940 // Boolean checks should not be inverted + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater or equal than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is not in the specified range. + /// + /// Number to be expected being greater or equal than max. + /// The lower bound of the allowed range of argument values. + /// The upper bound of the allowed range of argument values. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfOutOfRange(double argument, double min, double max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + // strange conditional needed in order to handle NaN values correctly + if (!(min <= argument && argument <= max)) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument not in the range [{min}..{max}]"); + } + + return argument; + } + + /// + /// Throws an if the specified number is equal to 0. + /// + /// Number to be expected being not equal to zero. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfZero(double argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { +#pragma warning disable S1244 // Floating point numbers should not be tested for equality + if (argument == 0.0) +#pragma warning restore S1244 // Floating point numbers should not be tested for equality + { + ArgumentOutOfRangeException(paramName, "Argument is zero"); + } + + return argument; + } + + #endregion +} diff --git a/src/xAI/GrokChatClient.cs b/src/xAI/GrokChatClient.cs new file mode 100644 index 0000000..20315e2 --- /dev/null +++ b/src/xAI/GrokChatClient.cs @@ -0,0 +1,465 @@ +using System.Text.Json; +using Grpc.Core; +using Grpc.Net.Client; +using Microsoft.Extensions.AI; +using xAI.Protocol; +using static xAI.Protocol.Chat; + +namespace xAI; + +class GrokChatClient : IChatClient +{ + readonly ChatClientMetadata metadata; + readonly ChatClient client; + readonly string defaultModelId; + readonly GrokClientOptions clientOptions; + + internal GrokChatClient(GrpcChannel channel, GrokClientOptions clientOptions, string defaultModelId) + : this(new ChatClient(channel), clientOptions, defaultModelId) + { } + + /// + /// Test constructor. + /// + internal GrokChatClient(ChatClient client, string defaultModelId) + : this(client, new(), defaultModelId) + { } + + GrokChatClient(ChatClient client, GrokClientOptions clientOptions, string defaultModelId) + { + this.client = client; + this.clientOptions = clientOptions; + this.defaultModelId = defaultModelId; + metadata = new ChatClientMetadata("xai", clientOptions.Endpoint, defaultModelId); + } + + public async Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + var request = MapToRequest(messages, options); + var response = await client.GetCompletionAsync(request, cancellationToken: cancellationToken); + var lastOutput = response.Outputs.OrderByDescending(x => x.Index).FirstOrDefault(); + + if (lastOutput == null) + { + return new ChatResponse() + { + ResponseId = response.Id, + ModelId = response.Model, + CreatedAt = response.Created.ToDateTimeOffset(), + Usage = MapToUsage(response.Usage), + }; + } + + var message = new ChatMessage(MapRole(lastOutput.Message.Role), default(string)); + var citations = response.Citations?.Distinct().Select(MapCitation).ToList(); + + foreach (var output in response.Outputs.OrderBy(x => x.Index)) + { + if (output.Message.Content is { Length: > 0 } text) + { + // Special-case output from tools + if (output.Message.Role == MessageRole.RoleTool && + output.Message.ToolCalls.Count == 1 && + output.Message.ToolCalls[0] is { } toolCall) + { + if (toolCall.Type == ToolCallType.McpTool) + { + message.Contents.Add(new McpServerToolCallContent(toolCall.Id, toolCall.Function.Name, null) + { + RawRepresentation = toolCall + }); + message.Contents.Add(new McpServerToolResultContent(toolCall.Id) + { + RawRepresentation = toolCall, + Output = [new TextContent(text)] + }); + continue; + } + else if (toolCall.Type == ToolCallType.CodeExecutionTool) + { + message.Contents.Add(new CodeInterpreterToolCallContent() + { + CallId = toolCall.Id, + RawRepresentation = toolCall + }); + message.Contents.Add(new CodeInterpreterToolResultContent() + { + CallId = toolCall.Id, + RawRepresentation = toolCall, + Outputs = [new TextContent(text)] + }); + continue; + } + } + + var content = new TextContent(text) { Annotations = citations }; + + foreach (var citation in output.Message.Citations) + (content.Annotations ??= []).Add(MapInlineCitation(citation)); + + message.Contents.Add(content); + } + + foreach (var toolCall in output.Message.ToolCalls) + message.Contents.Add(MapToolCall(toolCall)); + } + + return new ChatResponse(message) + { + ResponseId = response.Id, + ModelId = response.Model, + CreatedAt = response.Created?.ToDateTimeOffset(), + FinishReason = lastOutput != null ? MapFinishReason(lastOutput.FinishReason) : null, + Usage = MapToUsage(response.Usage), + }; + } + + AIContent MapToolCall(ToolCall toolCall) => toolCall.Type switch + { + ToolCallType.ClientSideTool => new FunctionCallContent( + toolCall.Id, + toolCall.Function.Name, + !string.IsNullOrEmpty(toolCall.Function.Arguments) + ? JsonSerializer.Deserialize>(toolCall.Function.Arguments) + : null) + { + RawRepresentation = toolCall + }, + ToolCallType.McpTool => new McpServerToolCallContent(toolCall.Id, toolCall.Function.Name, null) + { + RawRepresentation = toolCall + }, + ToolCallType.CodeExecutionTool => new CodeInterpreterToolCallContent() + { + CallId = toolCall.Id, + RawRepresentation = toolCall + }, + _ => new HostedToolCallContent() + { + CallId = toolCall.Id, + RawRepresentation = toolCall + } + }; + + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + return CompleteChatStreamingCore(messages, options, cancellationToken); + + async IAsyncEnumerable CompleteChatStreamingCore(IEnumerable messages, ChatOptions? options, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken) + { + var request = MapToRequest(messages, options); + var call = client.GetCompletionChunk(request, cancellationToken: cancellationToken); + + await foreach (var chunk in call.ResponseStream.ReadAllAsync(cancellationToken)) + { + var output = chunk.Outputs[0]; + var text = output.Delta.Content is { Length: > 0 } delta ? delta : null; + + // Use positional arguments for ChatResponseUpdate + var update = new ChatResponseUpdate(MapRole(output.Delta.Role), text) + { + ResponseId = chunk.Id, + ModelId = chunk.Model, + CreatedAt = chunk.Created?.ToDateTimeOffset(), + FinishReason = output.FinishReason != FinishReason.ReasonInvalid ? MapFinishReason(output.FinishReason) : null, + }; + + if (chunk.Citations is { Count: > 0 } citations) + { + var textContent = update.Contents.OfType().FirstOrDefault(); + if (textContent == null) + { + textContent = new TextContent(string.Empty); + update.Contents.Add(textContent); + } + + foreach (var citation in citations.Distinct()) + (textContent.Annotations ??= []).Add(MapCitation(citation)); + } + + foreach (var toolCall in output.Delta.ToolCalls) + update.Contents.Add(MapToolCall(toolCall)); + + if (update.Contents.Any()) + yield return update; + } + } + } + + static CitationAnnotation MapInlineCitation(InlineCitation citation) => citation.CitationCase switch + { + InlineCitation.CitationOneofCase.WebCitation => new CitationAnnotation { Url = new(citation.WebCitation.Url) }, + InlineCitation.CitationOneofCase.XCitation => new CitationAnnotation { Url = new(citation.XCitation.Url) }, + InlineCitation.CitationOneofCase.CollectionsCitation => new CitationAnnotation + { + FileId = citation.CollectionsCitation.FileId, + Snippet = citation.CollectionsCitation.ChunkContent, + ToolName = "file_search", + }, + _ => new CitationAnnotation() + }; + + static CitationAnnotation MapCitation(string citation) + { + var url = new Uri(citation); + if (url.Scheme != "collections") + return new CitationAnnotation { Url = url }; + + // Special-case collection citations so we get better metadata + var collection = url.Host; + var file = url.AbsolutePath[7..]; + return new CitationAnnotation + { + ToolName = "collections_search", + FileId = file, + AdditionalProperties = new AdditionalPropertiesDictionary + { + { "collection_id", collection } + } + }; + } + + GetCompletionsRequest MapToRequest(IEnumerable messages, ChatOptions? options) + { + var request = options?.RawRepresentationFactory?.Invoke(this) as GetCompletionsRequest ?? new GetCompletionsRequest() + { + // By default always include citations in the final output if available + Include = { IncludeOption.InlineCitations }, + Model = options?.ModelId ?? defaultModelId, + }; + + if (string.IsNullOrEmpty(request.Model)) + request.Model = options?.ModelId ?? defaultModelId; + + if ((options?.EndUserId ?? clientOptions.EndUserId) is { } user) request.User = user; + if (options?.MaxOutputTokens is { } maxTokens) request.MaxTokens = maxTokens; + if (options?.Temperature is { } temperature) request.Temperature = temperature; + if (options?.TopP is { } topP) request.TopP = topP; + if (options?.FrequencyPenalty is { } frequencyPenalty) request.FrequencyPenalty = frequencyPenalty; + if (options?.PresencePenalty is { } presencePenalty) request.PresencePenalty = presencePenalty; + + foreach (var message in messages) + { + var gmsg = new Message { Role = MapRole(message.Role) }; + + foreach (var content in message.Contents) + { + if (content is TextContent textContent && !string.IsNullOrEmpty(textContent.Text)) + { + gmsg.Content.Add(new Content { Text = textContent.Text }); + } + else if (content.RawRepresentation is ToolCall toolCall) + { + gmsg.ToolCalls.Add(toolCall); + } + else if (content is FunctionCallContent functionCall) + { + gmsg.ToolCalls.Add(new ToolCall + { + Id = functionCall.CallId, + Type = ToolCallType.ClientSideTool, + Function = new FunctionCall + { + Name = functionCall.Name, + Arguments = JsonSerializer.Serialize(functionCall.Arguments) + } + }); + } + else if (content is FunctionResultContent resultContent) + { + request.Messages.Add(new Message + { + Role = MessageRole.RoleTool, + Content = { new Content { Text = JsonSerializer.Serialize(resultContent.Result) ?? "null" } } + }); + } + else if (content is McpServerToolResultContent mcpResult && + mcpResult.RawRepresentation is ToolCall mcpToolCall && + mcpResult.Output is { Count: 1 } && + mcpResult.Output[0] is TextContent mcpText) + { + request.Messages.Add(new Message + { + Role = MessageRole.RoleTool, + ToolCalls = { mcpToolCall }, + Content = { new Content { Text = mcpText.Text } } + }); + } + else if (content is CodeInterpreterToolResultContent codeResult && + codeResult.RawRepresentation is ToolCall codeToolCall && + codeResult.Outputs is { Count: 1 } && + codeResult.Outputs[0] is TextContent codeText) + { + request.Messages.Add(new Message + { + Role = MessageRole.RoleTool, + ToolCalls = { codeToolCall }, + Content = { new Content { Text = codeText.Text } } + }); + } + } + + if (gmsg.Content.Count == 0 && gmsg.ToolCalls.Count == 0) + continue; + + // If we have only tool calls and no content, the gRPC enpoint fails, so add an empty one. + if (gmsg.Content.Count == 0) + gmsg.Content.Add(new Content()); + + request.Messages.Add(gmsg); + } + + IList includes = [IncludeOption.InlineCitations]; + if (options is GrokChatOptions grokOptions) + { + // NOTE: overrides our default include for inline citations, potentially. + request.Include.Clear(); + request.Include.AddRange(grokOptions.Include); + + if (grokOptions.Search.HasFlag(GrokSearch.X)) + { + (options.Tools ??= []).Insert(0, new GrokXSearchTool()); + } + else if (grokOptions.Search.HasFlag(GrokSearch.Web)) + { + (options.Tools ??= []).Insert(0, new GrokSearchTool()); + } + } + + if (options?.Tools is not null) + { + foreach (var tool in options.Tools) + { + if (tool is AIFunction functionTool) + { + var function = new Function + { + Name = functionTool.Name, + Description = functionTool.Description, + Parameters = JsonSerializer.Serialize(functionTool.JsonSchema) + }; + request.Tools.Add(new Tool { Function = function }); + } + else if (tool is HostedWebSearchTool webSearchTool) + { + if (webSearchTool is GrokXSearchTool xSearch) + { + var toolProto = new XSearch + { + EnableImageUnderstanding = xSearch.EnableImageUnderstanding, + EnableVideoUnderstanding = xSearch.EnableVideoUnderstanding, + }; + + if (xSearch.AllowedHandles is { } allowed) toolProto.AllowedXHandles.AddRange(allowed); + if (xSearch.ExcludedHandles is { } excluded) toolProto.ExcludedXHandles.AddRange(excluded); + if (xSearch.FromDate is { } from) toolProto.FromDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(from.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)); + if (xSearch.ToDate is { } to) toolProto.ToDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(to.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)); + + request.Tools.Add(new Tool { XSearch = toolProto }); + } + else if (webSearchTool is GrokSearchTool grokSearch) + { + var toolProto = new WebSearch + { + EnableImageUnderstanding = grokSearch.EnableImageUnderstanding, + }; + + if (grokSearch.AllowedDomains is { } allowed) toolProto.AllowedDomains.AddRange(allowed); + if (grokSearch.ExcludedDomains is { } excluded) toolProto.ExcludedDomains.AddRange(excluded); + + request.Tools.Add(new Tool { WebSearch = toolProto }); + } + else + { + request.Tools.Add(new Tool { WebSearch = new WebSearch() }); + } + } + else if (tool is HostedCodeInterpreterTool) + { + request.Tools.Add(new Tool { CodeExecution = new CodeExecution { } }); + } + else if (tool is HostedFileSearchTool fileSearch) + { + var toolProto = new CollectionsSearch(); + + if (fileSearch.Inputs?.OfType() is { } vectorStores) + toolProto.CollectionIds.AddRange(vectorStores.Select(x => x.VectorStoreId).Distinct()); + + if (fileSearch.MaximumResultCount is { } maxResults) + toolProto.Limit = maxResults; + + request.Tools.Add(new Tool { CollectionsSearch = toolProto }); + } + else if (tool is HostedMcpServerTool mcpTool) + { + request.Tools.Add(new Tool + { + Mcp = new MCP + { + Authorization = mcpTool.AuthorizationToken, + ServerLabel = mcpTool.ServerName, + ServerUrl = mcpTool.ServerAddress, + AllowedToolNames = { mcpTool.AllowedTools ?? Array.Empty() } + } + }); + } + } + } + + if (options?.ResponseFormat is ChatResponseFormatJson) + { + request.ResponseFormat = new ResponseFormat + { + FormatType = FormatType.JsonObject + }; + } + + return request; + } + + static MessageRole MapRole(ChatRole role) => role switch + { + _ when role == ChatRole.System => MessageRole.RoleSystem, + _ when role == ChatRole.User => MessageRole.RoleUser, + _ when role == ChatRole.Assistant => MessageRole.RoleAssistant, + _ when role == ChatRole.Tool => MessageRole.RoleTool, + _ => MessageRole.RoleUser + }; + + static ChatRole MapRole(MessageRole role) => role switch + { + MessageRole.RoleSystem => ChatRole.System, + MessageRole.RoleUser => ChatRole.User, + MessageRole.RoleAssistant => ChatRole.Assistant, + MessageRole.RoleTool => ChatRole.Tool, + _ => ChatRole.Assistant + }; + + static ChatFinishReason? MapFinishReason(FinishReason finishReason) => finishReason switch + { + FinishReason.ReasonStop => ChatFinishReason.Stop, + FinishReason.ReasonMaxLen => ChatFinishReason.Length, + FinishReason.ReasonToolCalls => ChatFinishReason.ToolCalls, + FinishReason.ReasonMaxContext => ChatFinishReason.Length, + FinishReason.ReasonTimeLimit => ChatFinishReason.Length, + _ => null + }; + + static UsageDetails? MapToUsage(SamplingUsage usage) => usage == null ? null : new() + { + InputTokenCount = usage.PromptTokens, + OutputTokenCount = usage.CompletionTokens, + TotalTokenCount = usage.TotalTokens + }; + + /// + public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch + { + Type t when t == typeof(ChatClientMetadata) => metadata, + Type t when t == typeof(GrokChatClient) => this, + _ => null + }; + + /// + public void Dispose() { } +} diff --git a/src/xAI/GrokChatOptions.cs b/src/xAI/GrokChatOptions.cs new file mode 100644 index 0000000..29b8737 --- /dev/null +++ b/src/xAI/GrokChatOptions.cs @@ -0,0 +1,32 @@ +using System.ComponentModel; +using Microsoft.Extensions.AI; +using xAI.Protocol; + +namespace xAI; + +/// Customizes Grok's agentic search tools. +/// See https://docs.x.ai/docs/guides/tools/search-tools. +[Flags] +public enum GrokSearch +{ + /// Disables agentic search capabilities. + None = 0, + /// Enables all available agentic search capabilities. + All = Web | X, + /// Allows the agent to search the web and browse pages. + Web = 1, + /// Allows the agent to perform keyword search, semantic search, user search, and thread fetch on X. + X = 2, +} + +/// Grok-specific chat options that extend the base . +public class GrokChatOptions : ChatOptions +{ + /// Configures Grok's agentic search capabilities. + /// See https://docs.x.ai/docs/guides/tools/search-tools. + public GrokSearch Search { get; set; } = GrokSearch.None; + + /// Additional outputs to include in responses. + /// Defaults to including . + public IList Include { get; set; } = [IncludeOption.InlineCitations]; +} diff --git a/src/xAI/GrokClient.cs b/src/xAI/GrokClient.cs new file mode 100644 index 0000000..3a9a4c5 --- /dev/null +++ b/src/xAI/GrokClient.cs @@ -0,0 +1,47 @@ +using System.Collections.Concurrent; +using System.Net.Http.Headers; +using Grpc.Net.Client; + +namespace xAI; + +/// Client for interacting with the Grok service. +/// The API key used for authentication. +/// The options used to configure the client. +public class GrokClient(string apiKey, GrokClientOptions options) +{ + static ConcurrentDictionary<(Uri, string), GrpcChannel> channels = []; + + /// Initializes a new instance of the class with default options. + public GrokClient(string apiKey) : this(apiKey, new GrokClientOptions()) { } + + /// Gets the API key used for authentication. + public string ApiKey { get; } = apiKey; + + /// Gets or sets the endpoint for the service. + public Uri Endpoint { get; set; } = options.Endpoint; + + /// Gets the options used to configure the client. + public GrokClientOptions Options { get; } = options; + + internal GrpcChannel Channel => channels.GetOrAdd((Endpoint, ApiKey), key => + { + var handler = new AuthenticationHeaderHandler(ApiKey) + { + InnerHandler = Options.ChannelOptions?.HttpHandler ?? new HttpClientHandler() + }; + + var options = Options.ChannelOptions ?? new GrpcChannelOptions(); + options.HttpHandler = handler; + + return GrpcChannel.ForAddress(Endpoint, options); + }); + + class AuthenticationHeaderHandler(string apiKey) : DelegatingHandler + { + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); + return base.SendAsync(request, cancellationToken); + } + } +} diff --git a/src/xAI/GrokClientExtensions.cs b/src/xAI/GrokClientExtensions.cs new file mode 100644 index 0000000..34e3867 --- /dev/null +++ b/src/xAI/GrokClientExtensions.cs @@ -0,0 +1,13 @@ +using System.ComponentModel; +using Microsoft.Extensions.AI; + +namespace xAI; + +/// Provides extension methods for . +[EditorBrowsable(EditorBrowsableState.Never)] +public static class GrokClientExtensions +{ + /// Creates a new from the specified using the given model as the default. + public static IChatClient AsIChatClient(this GrokClient client, string defaultModelId) + => new GrokChatClient(client.Channel, client.Options, defaultModelId); +} diff --git a/src/xAI/GrokClientOptions.cs b/src/xAI/GrokClientOptions.cs new file mode 100644 index 0000000..e6e149e --- /dev/null +++ b/src/xAI/GrokClientOptions.cs @@ -0,0 +1,16 @@ +using Grpc.Net.Client; + +namespace xAI; + +/// Options for configuring the . +public class GrokClientOptions +{ + /// Gets or sets the service endpoint. + public Uri Endpoint { get; set; } = new("https://api.x.ai"); + + /// Gets or sets the gRPC channel options. + public GrpcChannelOptions? ChannelOptions { get; set; } + + /// Gets or sets the end user ID for the chat session. + public string? EndUserId { get; set; } +} diff --git a/src/xAI/GrokSearchTool.cs b/src/xAI/GrokSearchTool.cs new file mode 100644 index 0000000..7bccab7 --- /dev/null +++ b/src/xAI/GrokSearchTool.cs @@ -0,0 +1,23 @@ +using Microsoft.Extensions.AI; + +namespace xAI; + +/// Configures Grok's agentic search tool. +/// See https://docs.x.ai/docs/guides/tools/search-tools +public class GrokSearchTool : HostedWebSearchTool +{ + /// + public override string Name => "web_search"; + + /// + public override string Description => "Performs agentic web search"; + + /// Use to make the web search only perform the search and web browsing on web pages that fall within the specified domains. Can include a maximum of five domains. + public IList? AllowedDomains { get; set; } + + /// Use to prevent the model from including the specified domains in any web search tool invocations and from browsing any pages on those domains. Can include a maximum of five domains. + public IList? ExcludedDomains { get; set; } + + /// See https://docs.x.ai/docs/guides/tools/search-tools#enable-image-understanding + public bool EnableImageUnderstanding { get; set; } +} \ No newline at end of file diff --git a/src/xAI/GrokXSearch.cs b/src/xAI/GrokXSearch.cs new file mode 100644 index 0000000..ba0e60a --- /dev/null +++ b/src/xAI/GrokXSearch.cs @@ -0,0 +1,24 @@ +using System.Text.Json.Serialization; +using Microsoft.Extensions.AI; + +namespace xAI; + +/// Configures Grok's agentic search tool for X. +/// See https://docs.x.ai/docs/guides/tools/search-tools#x-search-parameters +public class GrokXSearchTool : HostedWebSearchTool +{ + /// See https://docs.x.ai/docs/guides/tools/search-tools#only-consider-x-posts-from-specific-handles + [JsonPropertyName("allowed_x_handles")] + public IList? AllowedHandles { get; set; } + /// See https://docs.x.ai/docs/guides/tools/search-tools#exclude-x-posts-from-specific-handles + [JsonPropertyName("excluded_x_handles")] + public IList? ExcludedHandles { get; set; } + /// See https://docs.x.ai/docs/guides/tools/search-tools#date-range + public DateOnly? FromDate { get; set; } + /// See https://docs.x.ai/docs/guides/tools/search-tools#date-range + public DateOnly? ToDate { get; set; } + /// See https://docs.x.ai/docs/guides/tools/search-tools#enable-image-understanding-1 + public bool EnableImageUnderstanding { get; set; } + /// See https://docs.x.ai/docs/guides/tools/search-tools#enable-video-understanding + public bool EnableVideoUnderstanding { get; set; } +} \ No newline at end of file diff --git a/src/xAI/HostedToolCallContent.cs b/src/xAI/HostedToolCallContent.cs new file mode 100644 index 0000000..450fd23 --- /dev/null +++ b/src/xAI/HostedToolCallContent.cs @@ -0,0 +1,12 @@ +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; + +namespace xAI; + +/// Represents a hosted tool agentic call. +[Experimental("xAI001")] +public class HostedToolCallContent : AIContent +{ + /// Gets or sets the tool call ID. + public virtual string? CallId { get; set; } +} diff --git a/src/xAI/HostedToolResultContent.cs b/src/xAI/HostedToolResultContent.cs new file mode 100644 index 0000000..4c8694a --- /dev/null +++ b/src/xAI/HostedToolResultContent.cs @@ -0,0 +1,17 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; + +namespace xAI; + +/// Represents a hosted tool agentic call. +[DebuggerDisplay("{DebuggerDisplay,nq}")] +[Experimental("xAI001")] +public class HostedToolResultContent : AIContent +{ + /// Gets or sets the tool call ID. + public virtual string? CallId { get; set; } + + /// Gets or sets the resulting contents from the tool. + public virtual IList? Outputs { get; set; } +} \ No newline at end of file diff --git a/src/xAI/readme.md b/src/xAI/readme.md new file mode 100644 index 0000000..14c1dec --- /dev/null +++ b/src/xAI/readme.md @@ -0,0 +1,8 @@ +[![EULA](https://img.shields.io/badge/EULA-OSMF-blue?labelColor=black&color=C9FF30)](osmfeula.txt) +[![OSS](https://img.shields.io/github/license/devlooped/oss.svg?color=blue)](license.txt) +[![GitHub](https://img.shields.io/badge/-source-181717.svg?logo=GitHub)](https://github.com/devlooped/AI) + + + + + \ No newline at end of file diff --git a/src/xAI/xAI.csproj b/src/xAI/xAI.csproj new file mode 100644 index 0000000..54fd735 --- /dev/null +++ b/src/xAI/xAI.csproj @@ -0,0 +1,32 @@ + + + + net8.0;net10.0 + xAI + xAI + xAI/Grok integration for Microsoft.Extensions.AI + + OSMFEULA.txt + true + true + MEAI001;xAI001;$(NoWarn) + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/xAI.slnx b/xAI.slnx index 2cb093e..5a3b5a0 100644 --- a/xAI.slnx +++ b/xAI.slnx @@ -1,4 +1,5 @@ +