Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions src/xAI.Tests/GrokConversionTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
using System;
using System.Collections.Generic;
using System.Text;
using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.AI;
using OpenAI.Responses;
using xAI.Protocol;

namespace xAI;

public class GrokConversionTests
{
[Fact]
public void AsTool_WithWebSearch()
{
var webSearch = new HostedWebSearchTool();

var tool = webSearch.AsProtocolTool();

Assert.NotNull(tool?.WebSearch);
}

[Fact]
public void AsTool_WithWebSearch_ThrowsIfAllowedAndExcluded()
{
var webSearch = new GrokSearchTool
{
AllowedDomains = ["Foo"],
ExcludedDomains = ["Bar"]
};

Assert.Throws<NotSupportedException>(() => webSearch.AsProtocolTool());
}

[Fact]
public void AsTool_WithWebSearch_AllowedDomains()
{
var webSearch = new GrokSearchTool
{
AllowedDomains = ["foo.com", "bar.com"],
};

var tool = webSearch.AsProtocolTool();

Assert.NotNull(tool?.WebSearch);
Assert.Equal(["foo.com", "bar.com"], tool.WebSearch.AllowedDomains);
}

[Fact]
public void AsTool_WithWebSearch_ExcludedDomains()
{
var webSearch = new GrokSearchTool
{
ExcludedDomains = ["foo.com", "bar.com"],
};

var tool = webSearch.AsProtocolTool();

Assert.NotNull(tool?.WebSearch);
Assert.Equal(["foo.com", "bar.com"], tool.WebSearch.ExcludedDomains);
}

[Fact]
public void AsTool_WithWebSearch_ImageUnderstanding()
{
var webSearch = new GrokSearchTool
{
EnableImageUnderstanding = true
};

var tool = webSearch.AsProtocolTool();

Assert.NotNull(tool?.WebSearch);
Assert.True(tool.WebSearch.EnableImageUnderstanding);
}

[Fact]
public void AsTool_WithXSearch_ThrowsIfAllowedAndExcluded()
{
var webSearch = new GrokXSearchTool
{
AllowedHandles = ["Foo"],
ExcludedHandles = ["Bar"]
};

Assert.Throws<NotSupportedException>(() => webSearch.AsProtocolTool());
}

[Fact]
public void AsTool_WithXSearch_AllowedHandles()
{
var webSearch = new GrokXSearchTool
{
AllowedHandles = ["foo", "bar"],
};

var tool = webSearch.AsProtocolTool();

Assert.NotNull(tool?.XSearch);
Assert.Equal(["foo", "bar"], tool.XSearch.AllowedXHandles);
}

[Fact]
public void AsTool_WithXSearch_ExcludedDomains()
{
var webSearch = new GrokXSearchTool
{
ExcludedHandles = ["foo", "bar"],
};

var tool = webSearch.AsProtocolTool();

Assert.NotNull(tool?.XSearch);
Assert.Equal(["foo", "bar"], tool.XSearch.ExcludedXHandles);
}

[Fact]
public void AsTool_WithXSearch_ImageUnderstanding()
{
var webSearch = new GrokXSearchTool
{
EnableImageUnderstanding = true
};

var tool = webSearch.AsProtocolTool();

Assert.NotNull(tool?.XSearch);
Assert.True(tool.XSearch.EnableImageUnderstanding);
}

[Fact]
public void AsTool_WithXSearch_VideoUnderstanding()
{
var webSearch = new GrokXSearchTool
{
EnableVideoUnderstanding = true
};

var tool = webSearch.AsProtocolTool();

Assert.NotNull(tool?.XSearch);
Assert.True(tool.XSearch.EnableVideoUnderstanding);
}

[Fact]
public void AsTool_WithXSearch_FromTo()
{
var webSearch = new GrokXSearchTool
{
FromDate = DateOnly.FromDateTime(DateTime.UtcNow.Subtract(TimeSpan.FromDays(1))),
ToDate = DateOnly.FromDateTime(DateTime.UtcNow)
};

var tool = webSearch.AsProtocolTool();

Assert.NotNull(tool?.XSearch);
Assert.Equal(tool.XSearch.FromDate, Timestamp.FromDateTime(webSearch.FromDate.Value.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)));
Assert.Equal(tool.XSearch.ToDate, Timestamp.FromDateTime(webSearch.ToDate.Value.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)));
}

[Fact]
public void AsTool_WithFunctionTool()
{
var functionTool = AIFunctionFactory.Create(() => "", "Name", "Description");

var tool = functionTool.AsProtocolTool();

Assert.NotNull(tool?.Function);
Assert.Equal("Name", tool.Function.Name);
Assert.Equal("Description", tool.Function.Description);
}

[Fact]
public void AsTool_WithCodeExecution()
{
var codeTool = new HostedCodeInterpreterTool();

var tool = codeTool.AsProtocolTool();

Assert.NotNull(tool?.CodeExecution);
}

[Fact]
public void AsTool_WithHostedFileSearchTool()
{
var collectionId = Guid.NewGuid().ToString();
var instructions = "Return N/A if no results found";
var fileSearch = new HostedFileSearchTool()
{
MaximumResultCount = 50,
Inputs = [new HostedVectorStoreContent(collectionId)]
}.WithInstructions(instructions);

var tool = fileSearch.AsProtocolTool();

Assert.NotNull(tool?.CollectionsSearch);
Assert.Contains(collectionId, tool.CollectionsSearch.CollectionIds);
Assert.Equal(50, tool.CollectionsSearch.Limit);
Assert.Equal(instructions, tool.CollectionsSearch.Instructions);
}

[Fact]
public void AsTool_WithHostedMcpTool()
{
var accessToken = Guid.NewGuid().ToString();
var headers = new Dictionary<string, string>
{
["foo"] = "baz"
};
var mcpTool = new HostedMcpServerTool("foo", "foo.com", new Dictionary<string, object?>
{
["x-extra"] = "bar",
[nameof(MCP.ExtraHeaders)] = headers
})
{
AllowedTools = ["list"],
AuthorizationToken = accessToken,
};

var tool = mcpTool.AsProtocolTool();

Assert.NotNull(tool?.Mcp);
Assert.Equal("foo", tool.Mcp.ServerLabel);
Assert.Equal("foo.com", tool.Mcp.ServerUrl);
Assert.Contains("list", tool.Mcp.AllowedToolNames);
Assert.Equal(accessToken, tool.Mcp.Authorization);
Assert.Contains(KeyValuePair.Create("x-extra", "bar"), tool.Mcp.ExtraHeaders);
Assert.Contains(KeyValuePair.Create("foo", "baz"), tool.Mcp.ExtraHeaders);
}
}
36 changes: 35 additions & 1 deletion src/xAI/Extensions/ChatExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
using Microsoft.Extensions.AI;
using System.ComponentModel;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;
using xAI.Protocol;

namespace xAI;

/// <summary>Extensions for <see cref="ChatOptions"/>.</summary>
[EditorBrowsable(EditorBrowsableState.Never)]
public static partial class ChatOptionsExtensions
{
extension(ChatOptions options)
Expand All @@ -14,4 +18,34 @@ public string? EndUserId
set => (options.AdditionalProperties ??= [])["EndUserId"] = value;
}
}
}

/// <summary>Grok-specific extensions for <see cref="HostedFileSearchTool"/>.</summary>
[EditorBrowsable(EditorBrowsableState.Never)]
public static partial class HostedFileSearchToolExtensions
{
extension(HostedFileSearchTool tool)
{
/// <summary>
/// User-defined instructions to be included in the search query. Defaults to generic search
/// instructions used by the collections search backend if unset.
/// </summary>
public HostedFileSearchTool WithInstructions(string instructions) => new(new Dictionary<string, object?>
{
[nameof(CollectionsSearch.Instructions)] = Throw.IfNullOrEmpty(instructions)
})
{
Inputs = tool.Inputs,
MaximumResultCount = tool.MaximumResultCount,
};
}
}

static partial class AIToolExtensions
{
extension(AITool tool)
{
public T? GetProperty<T>(string name) =>
tool.AdditionalProperties?.TryGetValue(name, out var value) is true && value is T typed ? typed : default;
}
}
78 changes: 2 additions & 76 deletions src/xAI/GrokChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -328,82 +328,8 @@ codeResult.RawRepresentation is ToolCall codeToolCall &&

if (options?.Tools is not null)
{
foreach (var tool in options.Tools)
{
if (tool is AIFunction functionTool)
{
var function = new Function
{
Name = functionTool.Name,
Description = functionTool.Description,
Parameters = JsonSerializer.Serialize(functionTool.JsonSchema)
};
request.Tools.Add(new Tool { Function = function });
}
else if (tool is HostedWebSearchTool webSearchTool)
{
if (webSearchTool is GrokXSearchTool xSearch)
{
var toolProto = new XSearch
{
EnableImageUnderstanding = xSearch.EnableImageUnderstanding,
EnableVideoUnderstanding = xSearch.EnableVideoUnderstanding,
};

if (xSearch.AllowedHandles is { } allowed) toolProto.AllowedXHandles.AddRange(allowed);
if (xSearch.ExcludedHandles is { } excluded) toolProto.ExcludedXHandles.AddRange(excluded);
if (xSearch.FromDate is { } from) toolProto.FromDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(from.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc));
if (xSearch.ToDate is { } to) toolProto.ToDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(to.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc));

request.Tools.Add(new Tool { XSearch = toolProto });
}
else if (webSearchTool is GrokSearchTool grokSearch)
{
var toolProto = new WebSearch
{
EnableImageUnderstanding = grokSearch.EnableImageUnderstanding,
};

if (grokSearch.AllowedDomains is { } allowed) toolProto.AllowedDomains.AddRange(allowed);
if (grokSearch.ExcludedDomains is { } excluded) toolProto.ExcludedDomains.AddRange(excluded);

request.Tools.Add(new Tool { WebSearch = toolProto });
}
else
{
request.Tools.Add(new Tool { WebSearch = new WebSearch() });
}
}
else if (tool is HostedCodeInterpreterTool)
{
request.Tools.Add(new Tool { CodeExecution = new CodeExecution { } });
}
else if (tool is HostedFileSearchTool fileSearch)
{
var toolProto = new CollectionsSearch();

if (fileSearch.Inputs?.OfType<HostedVectorStoreContent>() is { } vectorStores)
toolProto.CollectionIds.AddRange(vectorStores.Select(x => x.VectorStoreId).Distinct());

if (fileSearch.MaximumResultCount is { } maxResults)
toolProto.Limit = maxResults;

request.Tools.Add(new Tool { CollectionsSearch = toolProto });
}
else if (tool is HostedMcpServerTool mcpTool)
{
request.Tools.Add(new Tool
{
Mcp = new MCP
{
Authorization = mcpTool.AuthorizationToken,
ServerLabel = mcpTool.ServerName,
ServerUrl = mcpTool.ServerAddress,
AllowedToolNames = { mcpTool.AllowedTools ?? Array.Empty<string>() }
}
});
}
}
foreach (var tool in options.Tools.Select(x => x.AsProtocolTool(options)))
if (tool is not null) request.Tools.Add(tool);
}

if (options?.ResponseFormat is ChatResponseFormatJson)
Expand Down
Loading