diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantChatClient.cs new file mode 100644 index 00000000000..c3aab83da61 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantChatClient.cs @@ -0,0 +1,446 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; +using OpenAI.Assistants; + +#pragma warning disable CA1031 // Do not catch general exception types +#pragma warning disable SA1005 // Single line comments should begin with single space +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable S125 // Sections of code should not be commented out +#pragma warning disable S907 // "goto" statement should not be used +#pragma warning disable S1067 // Expressions should not be too complex +#pragma warning disable S1751 // Loops with at most one iteration should be refactored +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields +#pragma warning disable S4456 // Parameter validation in yielding methods should be wrapped +#pragma warning disable S4457 // Parameter validation in "async"/"await" methods should be wrapped + +namespace Microsoft.Extensions.AI; + +/// Represents an for an Azure.AI.Agents.Persistent . +[Experimental("OPENAI001")] +internal sealed partial class OpenAIAssistantChatClient : IChatClient +{ + /// The underlying . 
+ private readonly AssistantClient _client; + + /// Metadata for the client. + private readonly ChatClientMetadata _metadata; + + /// The ID of the agent to use. + private readonly string _assistantId; + + /// The thread ID to use if none is supplied in . + private readonly string? _defaultThreadId; + + /// Initializes a new instance of the class for the specified . + public OpenAIAssistantChatClient(AssistantClient assistantClient, string assistantId, string? defaultThreadId) + { + _client = Throw.IfNull(assistantClient); + _assistantId = Throw.IfNullOrWhitespace(assistantId); + + _defaultThreadId = defaultThreadId; + + // https://github.com/openai/openai-dotnet/issues/215 + // The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages + // implement the abstractions directly rather than providing adapters on top of the public APIs, + // the package can provide such implementations separate from what's exposed in the public API. + Uri providerUrl = typeof(AssistantClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(assistantClient) as Uri ?? OpenAIResponseChatClient.DefaultOpenAIEndpoint; + + _metadata = new("openai", providerUrl); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) => + serviceType is null ? throw new ArgumentNullException(nameof(serviceType)) : + serviceKey is not null ? null : + serviceType == typeof(ChatClientMetadata) ? _metadata : + serviceType == typeof(AssistantClient) ? _client : + serviceType.IsInstanceOfType(this) ? this : + null; + + /// + public Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + GetStreamingResponseAsync(messages, options, cancellationToken).ToChatResponseAsync(cancellationToken); + + /// + public async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(messages); + + // Extract necessary state from messages and options. + (RunCreationOptions runOptions, List? toolResults) = CreateRunOptions(messages, options); + + // Get the thread ID. + string? threadId = options?.ConversationId ?? _defaultThreadId; + if (threadId is null && toolResults is not null) + { + Throw.ArgumentException(nameof(messages), "No thread ID was provided, but chat messages includes tool results."); + } + + // Get any active run ID for this thread. This is necessary in case a thread has been left with an + // active run, in which all attempts other than submitting tools will fail. We thus need to cancel + // any active run on the thread. + ThreadRun? threadRun = null; + if (threadId is not null) + { + await foreach (var run in _client.GetRunsAsync( + threadId, + new RunCollectionOptions { Order = RunCollectionOrder.Descending, PageSizeLimit = 1 }, + cancellationToken: cancellationToken).ConfigureAwait(false)) + { + if (run.Status != RunStatus.Completed && run.Status != RunStatus.Cancelled && run.Status != RunStatus.Failed && run.Status != RunStatus.Expired) + { + threadRun = run; + } + + break; + } + } + + // Submit the request. + IAsyncEnumerable updates; + if (threadRun is not null && + ConvertFunctionResultsToToolOutput(toolResults, out List? toolOutputs) is { } toolRunId && + toolRunId == threadRun.Id) + { + // There's an active run and we have tool results to submit, so submit the results and continue streaming. + // This is going to ignore any additional messages in the run options, as we are only submitting tool outputs, + // but there doesn't appear to be a way to submit additional messages, and having such additional messages is rare. 
+ updates = _client.SubmitToolOutputsToRunStreamingAsync(threadRun.ThreadId, threadRun.Id, toolOutputs, cancellationToken); + } + else + { + if (threadId is null) + { + // No thread ID was provided, so create a new thread. + ThreadCreationOptions threadCreationOptions = new(); + foreach (var message in runOptions.AdditionalMessages) + { + threadCreationOptions.InitialMessages.Add(message); + } + + runOptions.AdditionalMessages.Clear(); + + var thread = await _client.CreateThreadAsync(threadCreationOptions, cancellationToken).ConfigureAwait(false); + threadId = thread.Value.Id; + } + else if (threadRun is not null) + { + // There was an active run; we need to cancel it before starting a new run. + _ = await _client.CancelRunAsync(threadId, threadRun.Id, cancellationToken).ConfigureAwait(false); + threadRun = null; + } + + // Now create a new run and stream the results. + updates = _client.CreateRunStreamingAsync( + threadId: threadId, + _assistantId, + runOptions, + cancellationToken); + } + + // Process each update. + string? 
responseId = null; + await foreach (var update in updates.ConfigureAwait(false)) + { + switch (update) + { + case ThreadUpdate tu: + threadId ??= tu.Value.Id; + goto default; + + case RunUpdate ru: + threadId ??= ru.Value.ThreadId; + responseId ??= ru.Value.Id; + + ChatResponseUpdate ruUpdate = new() + { + AuthorName = _assistantId, + ConversationId = threadId, + CreatedAt = ru.Value.CreatedAt, + MessageId = responseId, + ModelId = ru.Value.Model, + RawRepresentation = ru, + ResponseId = responseId, + Role = ChatRole.Assistant, + }; + + if (ru.Value.Usage is { } usage) + { + ruUpdate.Contents.Add(new UsageContent(new() + { + InputTokenCount = usage.InputTokenCount, + OutputTokenCount = usage.OutputTokenCount, + TotalTokenCount = usage.TotalTokenCount, + })); + } + + if (ru is RequiredActionUpdate rau && rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName) + { + ruUpdate.Contents.Add( + new FunctionCallContent( + JsonSerializer.Serialize([ru.Value.Id, toolCallId], AssistantJsonContext.Default.StringArray), + functionName, + JsonSerializer.Deserialize(rau.FunctionArguments, AssistantJsonContext.Default.IDictionaryStringObject)!)); + } + + yield return ruUpdate; + break; + + case MessageContentUpdate mcu: + yield return new(mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, mcu.Text) + { + AuthorName = _assistantId, + ConversationId = threadId, + MessageId = responseId, + RawRepresentation = mcu, + ResponseId = responseId, + }; + break; + + default: + yield return new ChatResponseUpdate + { + AuthorName = _assistantId, + ConversationId = threadId, + MessageId = responseId, + RawRepresentation = update, + ResponseId = responseId, + Role = ChatRole.Assistant, + }; + break; + } + } + } + + /// + void IDisposable.Dispose() + { + // nop + } + + /// + /// Creates the to use for the request and extracts any function result contents + /// that need to be submitted as tool results. + /// + private (RunCreationOptions RunOptions, List? 
ToolResults) CreateRunOptions( + IEnumerable messages, ChatOptions? options) + { + // Create the options instance to populate, either a fresh or using one the caller provides. + RunCreationOptions runOptions = + options?.RawRepresentationFactory?.Invoke(this) as RunCreationOptions ?? + new(); + + // Populate the run options from the ChatOptions, if provided. + if (options is not null) + { + runOptions.MaxOutputTokenCount ??= options.MaxOutputTokens; + runOptions.ModelOverride ??= options.ModelId; + runOptions.NucleusSamplingFactor ??= options.TopP; + runOptions.Temperature ??= options.Temperature; + runOptions.AllowParallelToolCalls ??= options.AllowMultipleToolCalls; + + if (options.Tools is { Count: > 0 } tools) + { + // The caller can provide tools in the supplied ThreadAndRunOptions. Augment it with any supplied via ChatOptions.Tools. + foreach (AITool tool in tools) + { + switch (tool) + { + case AIFunction aiFunction: + bool? strict = aiFunction.AdditionalProperties.TryGetValue(nameof(strict), out var strictValue) && strictValue is bool strictBool ? + strictBool : + null; + runOptions.ToolsOverride.Add(new FunctionToolDefinition(aiFunction.Name) + { + Description = aiFunction.Description, + Parameters = BinaryData.FromBytes(JsonSerializer.SerializeToUtf8Bytes(aiFunction.JsonSchema, AssistantJsonContext.Default.JsonElement)), + StrictParameterSchemaEnabled = strict, + }); + break; + + case HostedCodeInterpreterTool: + runOptions.ToolsOverride.Add(new CodeInterpreterToolDefinition()); + break; + } + } + } + + // Store the tool mode, if relevant. 
+ if (runOptions.ToolConstraint is null) + { + switch (options.ToolMode) + { + case NoneChatToolMode: + runOptions.ToolConstraint = ToolConstraint.None; + break; + + case null: + case AutoChatToolMode: + runOptions.ToolConstraint = ToolConstraint.Auto; + break; + + case RequiredChatToolMode required when required.RequiredFunctionName is { } functionName: + runOptions.ToolConstraint = new ToolConstraint(ToolDefinition.CreateFunction(functionName)); + break; + + case RequiredChatToolMode required: + runOptions.ToolConstraint = ToolConstraint.Required; + break; + } + } + + // Store the response format, if relevant. + if (runOptions.ResponseFormat is null) + { + switch (options.ResponseFormat) + { + case ChatResponseFormatText: + runOptions.ResponseFormat = AssistantResponseFormat.CreateTextFormat(); + break; + + case ChatResponseFormatJson jsonFormat when jsonFormat.Schema is not null: + runOptions.ResponseFormat = AssistantResponseFormat.CreateJsonSchemaFormat( + jsonFormat.SchemaName, + BinaryData.FromBytes(JsonSerializer.SerializeToUtf8Bytes(jsonFormat.Schema, AssistantJsonContext.Default.JsonElement)), + jsonFormat.SchemaDescription); + break; + + case ChatResponseFormatJson jsonFormat: + runOptions.ResponseFormat = AssistantResponseFormat.CreateJsonObjectFormat(); + break; + } + } + } + + // Process ChatMessages. + StringBuilder? instructions = null; + List? functionResults = null; + foreach (var chatMessage in messages) + { + List messageContents = []; + + // Assistants doesn't support system/developer messages directly. It does support transient per-request instructions, + // so we can use the system/developer messages to build up a set of instructions that will be passed to the assistant + // as part of this request. However, in doing so, on a subsequent request that information will be lost, as there's no + // way to store per-thread instructions in the OpenAI Assistants API. 
We don't want to convert these to user messages, + // however, as that would then expose the system/developer messages in a way that might make the model more likely + // to include that information in its responses. System messages should ideally be instead done as instructions to + // the assistant when the assistant is created. + if (chatMessage.Role == ChatRole.System || + chatMessage.Role == OpenAIResponseChatClient.ChatRoleDeveloper) + { + instructions ??= new(); + foreach (var textContent in chatMessage.Contents.OfType()) + { + _ = instructions.Append(textContent); + } + + continue; + } + + foreach (AIContent content in chatMessage.Contents) + { + switch (content) + { + case TextContent text: + messageContents.Add(MessageContent.FromText(text.Text)); + break; + + case UriContent image when image.HasTopLevelMediaType("image"): + messageContents.Add(MessageContent.FromImageUri(image.Uri)); + break; + + // Assistants doesn't support data URIs. + //case DataContent image when image.HasTopLevelMediaType("image"): + // messageContents.Add(MessageContent.FromImageUri(new Uri(image.Uri))); + // break; + + case FunctionResultContent result: + (functionResults ??= []).Add(result); + break; + + case AIContent when content.RawRepresentation is MessageContent rawRep: + messageContents.Add(rawRep); + break; + } + } + + if (messageContents.Count > 0) + { + runOptions.AdditionalMessages.Add(new ThreadInitializationMessage( + chatMessage.Role == ChatRole.Assistant ? MessageRole.Assistant : MessageRole.User, + messageContents)); + } + } + + if (instructions is not null) + { + runOptions.AdditionalInstructions = instructions.ToString(); + } + + return (runOptions, functionResults); + } + + /// Convert instances to instances. + /// The tool results to process. + /// The generated list of tool outputs, if any could be created. + /// The run ID associated with the corresponding function call requests. + private static string? ConvertFunctionResultsToToolOutput(List? 
toolResults, out List? toolOutputs) + { + string? runId = null; + toolOutputs = null; + if (toolResults?.Count > 0) + { + foreach (var frc in toolResults) + { + // When creating the FunctionCallContext, we created it with a CallId == [runId, callId]. + // We need to extract the run ID and ensure that the ToolOutput we send back to Azure + // is only the call ID. + string[]? runAndCallIDs; + try + { + runAndCallIDs = JsonSerializer.Deserialize(frc.CallId, AssistantJsonContext.Default.StringArray); + } + catch + { + continue; + } + + if (runAndCallIDs is null || + runAndCallIDs.Length != 2 || + string.IsNullOrWhiteSpace(runAndCallIDs[0]) || // run ID + string.IsNullOrWhiteSpace(runAndCallIDs[1]) || // call ID + (runId is not null && runId != runAndCallIDs[0])) + { + continue; + } + + runId = runAndCallIDs[0]; + (toolOutputs ??= []).Add(new(runAndCallIDs[1], frc.Result?.ToString() ?? string.Empty)); + } + } + + return runId; + } + + [JsonSerializable(typeof(JsonElement))] + [JsonSerializable(typeof(string[]))] + [JsonSerializable(typeof(IDictionary))] + private sealed partial class AssistantJsonContext : JsonSerializerContext; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index cd46dae53b0..f97ebd492a7 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -34,9 +34,6 @@ internal sealed partial class OpenAIChatClient : IChatClient MoveDefaultKeywordToDescription = true, }); - /// Gets the default OpenAI endpoint. - private static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1"); - /// Metadata about the client. 
private readonly ChatClientMetadata _metadata; @@ -57,7 +54,7 @@ public OpenAIChatClient(ChatClient chatClient) // implement the abstractions directly rather than providing adapters on top of the public APIs, // the package can provide such implementations separate from what's exposed in the public API. Uri providerUrl = typeof(ChatClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) - ?.GetValue(chatClient) as Uri ?? DefaultOpenAIEndpoint; + ?.GetValue(chatClient) as Uri ?? OpenAIResponseChatClient.DefaultOpenAIEndpoint; string? model = typeof(ChatClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) ?.GetValue(chatClient) as string; @@ -113,8 +110,6 @@ void IDisposable.Dispose() // Nothing to dispose. Implementation required for the IChatClient interface. } - private static ChatRole ChatRoleDeveloper { get; } = new ChatRole("developer"); - /// Converts an Extensions chat message enumerable to an OpenAI chat message enumerable. private static IEnumerable ToOpenAIChatMessages(IEnumerable inputs, JsonSerializerOptions options) { @@ -125,12 +120,12 @@ void IDisposable.Dispose() { if (input.Role == ChatRole.System || input.Role == ChatRole.User || - input.Role == ChatRoleDeveloper) + input.Role == OpenAIResponseChatClient.ChatRoleDeveloper) { var parts = ToOpenAIChatContent(input.Contents); yield return input.Role == ChatRole.System ? new SystemChatMessage(parts) { ParticipantName = input.AuthorName } : - input.Role == ChatRoleDeveloper ? new DeveloperChatMessage(parts) { ParticipantName = input.AuthorName } : + input.Role == OpenAIResponseChatClient.ChatRoleDeveloper ? 
new DeveloperChatMessage(parts) { ParticipantName = input.AuthorName } : new UserChatMessage(parts) { ParticipantName = input.AuthorName }; } else if (input.Role == ChatRole.Tool) @@ -622,7 +617,7 @@ private static ChatRole FromOpenAIChatRole(ChatMessageRole role) => ChatMessageRole.User => ChatRole.User, ChatMessageRole.Assistant => ChatRole.Assistant, ChatMessageRole.Tool => ChatRole.Tool, - ChatMessageRole.Developer => ChatRoleDeveloper, + ChatMessageRole.Developer => OpenAIResponseChatClient.ChatRoleDeveloper, _ => new ChatRole(role.ToString()), }; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs index 81d2fe55a03..ea43b7e5e31 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs @@ -3,6 +3,7 @@ using System.Diagnostics.CodeAnalysis; using OpenAI; +using OpenAI.Assistants; using OpenAI.Audio; using OpenAI.Chat; using OpenAI.Embeddings; @@ -25,6 +26,19 @@ public static IChatClient AsIChatClient(this ChatClient chatClient) => public static IChatClient AsIChatClient(this OpenAIResponseClient responseClient) => new OpenAIResponseChatClient(responseClient); + /// Gets an for use with this . + /// The instance to be accessed as an . + /// The unique identifier of the assistant with which to interact. + /// + /// An optional existing thread identifier for the chat session. This serves as a default, and may be overridden per call to + /// or via the + /// property. If no thread ID is provided via either mechanism, a new thread will be created for the request. + /// + /// An instance configured to interact with the specified agent and thread. + [Experimental("OPENAI001")] + public static IChatClient AsIChatClient(this AssistantClient assistantClient, string assistantId, string? 
threadId = null) => + new OpenAIAssistantChatClient(assistantClient, assistantId, threadId); + /// Gets an for use with this . /// The client. /// An that can be used to transcribe audio via the . diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs index d3caee286be..34e6977e1f7 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs @@ -27,10 +27,10 @@ namespace Microsoft.Extensions.AI; internal sealed partial class OpenAIResponseChatClient : IChatClient { /// Gets the default OpenAI endpoint. - private static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1"); + internal static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1"); - /// A for "developer". - private static readonly ChatRole _chatRoleDeveloper = new("developer"); + /// Gets a for "developer". + internal static ChatRole ChatRoleDeveloper { get; } = new ChatRole("developer"); /// Metadata about the client. private readonly ChatClientMetadata _metadata; @@ -88,7 +88,7 @@ public async Task GetResponseAsync( // Convert and return the results. ChatResponse response = new() { - ConversationId = openAIResponse.Id, + ConversationId = openAIOptions.StoredOutputEnabled is false ? null : openAIResponse.Id, CreatedAt = openAIResponse.CreatedAt, FinishReason = ToFinishReason(openAIResponse.IncompleteStatusDetails?.Reason), Messages = [new(ChatRole.Assistant, [])], @@ -167,6 +167,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( // Make the call to the OpenAIResponseClient and process the streaming results. DateTimeOffset? createdAt = null; string? responseId = null; + string? conversationId = null; string? modelId = null; string? lastMessageId = null; ChatRole? 
lastRole = null; @@ -179,18 +180,19 @@ public async IAsyncEnumerable GetStreamingResponseAsync( case StreamingResponseCreatedUpdate createdUpdate: createdAt = createdUpdate.Response.CreatedAt; responseId = createdUpdate.Response.Id; + conversationId = openAIOptions.StoredOutputEnabled is false ? null : responseId; modelId = createdUpdate.Response.Model; goto default; case StreamingResponseCompletedUpdate completedUpdate: yield return new() { + Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [], + ConversationId = conversationId, + CreatedAt = createdAt, FinishReason = ToFinishReason(completedUpdate.Response?.IncompleteStatusDetails?.Reason) ?? (functionCallInfos is not null ? ChatFinishReason.ToolCalls : ChatFinishReason.Stop), - Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [], - ConversationId = responseId, - CreatedAt = createdAt, MessageId = lastMessageId, ModelId = modelId, RawRepresentation = streamingUpdate, @@ -223,7 +225,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( lastRole = ToChatRole(messageItem?.Role); yield return new ChatResponseUpdate(lastRole, outputTextDeltaUpdate.Delta) { - ConversationId = responseId, + ConversationId = conversationId, CreatedAt = createdAt, MessageId = lastMessageId, ModelId = modelId, @@ -258,7 +260,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( lastRole = ChatRole.Assistant; yield return new ChatResponseUpdate(lastRole, [fci]) { - ConversationId = responseId, + ConversationId = conversationId, CreatedAt = createdAt, MessageId = lastMessageId, ModelId = modelId, @@ -275,7 +277,6 @@ public async IAsyncEnumerable GetStreamingResponseAsync( case StreamingResponseErrorUpdate errorUpdate: yield return new ChatResponseUpdate { - ConversationId = responseId, Contents = [ new ErrorContent(errorUpdate.Message) @@ -284,6 +285,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( Details = 
errorUpdate.Param, } ], + ConversationId = conversationId, CreatedAt = createdAt, MessageId = lastMessageId, ModelId = modelId, @@ -296,21 +298,21 @@ public async IAsyncEnumerable GetStreamingResponseAsync( case StreamingResponseRefusalDoneUpdate refusalDone: yield return new ChatResponseUpdate { + Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }], + ConversationId = conversationId, CreatedAt = createdAt, MessageId = lastMessageId, ModelId = modelId, RawRepresentation = streamingUpdate, ResponseId = responseId, Role = lastRole, - ConversationId = responseId, - Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }], }; break; default: yield return new ChatResponseUpdate { - ConversationId = responseId, + ConversationId = conversationId, CreatedAt = createdAt, MessageId = lastMessageId, ModelId = modelId, @@ -334,7 +336,7 @@ private static ChatRole ToChatRole(MessageRole? role) => role switch { MessageRole.System => ChatRole.System, - MessageRole.Developer => _chatRoleDeveloper, + MessageRole.Developer => ChatRoleDeveloper, MessageRole.User => ChatRole.User, _ => ChatRole.Assistant, }; @@ -452,7 +454,7 @@ private static IEnumerable ToOpenAIResponseItems( foreach (ChatMessage input in inputs) { if (input.Role == ChatRole.System || - input.Role == _chatRoleDeveloper) + input.Role == ChatRoleDeveloper) { string text = input.Text; if (!string.IsNullOrWhiteSpace(text)) diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index f28374e7d79..f34f930f3ea 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -618,9 +618,9 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchange 
// Second time, the calls to the LLM don't happen, but the function is called again var secondResponse = await chatClient.GetResponseAsync([message]); - Assert.Equal(response.Text, secondResponse.Text); Assert.Equal(2, functionCallCount); Assert.Equal(FunctionInvokingChatClientSetsConversationId ? 3 : 2, llmCallCount!.CallCount); + Assert.Equal(response.Text, secondResponse.Text); } public virtual bool FunctionInvokingChatClientSetsConversationId => false; diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs index ea87408da38..5a7bf0b246e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs @@ -52,7 +52,7 @@ private static BinaryEmbedding QuantizeToBinary(Embedding embedding) { if (vector[i] > 0) { - result[i / 8] = true; + result[i] = true; } } diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIAssistantChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIAssistantChatClientIntegrationTests.cs new file mode 100644 index 00000000000..e616d5fb87b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIAssistantChatClientIntegrationTests.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. 
+#pragma warning disable CA1822 // Mark members as static +#pragma warning disable CA2000 // Dispose objects before losing scope +#pragma warning disable S1135 // Track uses of "TODO" tags +#pragma warning disable xUnit1013 // Public method should be marked as test + +using System; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using OpenAI.Assistants; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class OpenAIAssistantChatClientIntegrationTests : ChatClientIntegrationTests +{ + protected override IChatClient? CreateChatClient() + { + var openAIClient = IntegrationTestHelpers.GetOpenAIClient(); + if (openAIClient is null) + { + return null; + } + + AssistantClient ac = openAIClient.GetAssistantClient(); + var assistant = + ac.GetAssistants().FirstOrDefault() ?? + ac.CreateAssistant("gpt-4o-mini"); + + return ac.AsIChatClient(assistant.Id); + } + + public override bool FunctionInvokingChatClientSetsConversationId => true; + + // These tests aren't written in a way that works well with threads. + public override Task Caching_AfterFunctionInvocation_FunctionOutputChangedAsync() => Task.CompletedTask; + public override Task Caching_AfterFunctionInvocation_FunctionOutputUnchangedAsync() => Task.CompletedTask; + + // Assistants doesn't support data URIs. + public override Task MultiModal_DescribeImage() => Task.CompletedTask; + public override Task MultiModal_DescribePdf() => Task.CompletedTask; + + // [Fact] // uncomment and run to clear out _all_ threads in your OpenAI account + public async Task DeleteAllThreads() + { + using HttpClient client = new(new HttpClientHandler + { + AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate, + }); + + // These values need to be filled in. The bearer token needs to be sniffed from a browser + // session interacting with the dashboard (e.g. 
use F12 networking tools to look at request headers + // made to "https://api.openai.com/v1/threads?limit=10" after clicking on Assistants | Threads in the + // OpenAI portal dashboard). + client.DefaultRequestHeaders.Add("authorization", $"Bearer sess-ENTERYOURSESSIONTOKEN"); + client.DefaultRequestHeaders.Add("openai-organization", "org-ENTERYOURORGID"); + client.DefaultRequestHeaders.Add("openai-project", "proj_ENTERYOURPROJECTID"); + + AssistantClient ac = new AssistantClient(Environment.GetEnvironmentVariable("AI:OpenAI:ApiKey")!); + while (true) + { + string listing = await client.GetStringAsync("https://api.openai.com/v1/threads?limit=100"); + + var matches = Regex.Matches(listing, @"thread_\w+"); + if (matches.Count == 0) + { + break; + } + + foreach (Match m in matches) + { + var dr = await ac.DeleteThreadAsync(m.Value); + Assert.True(dr.Value.Deleted); + } + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIAssistantChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIAssistantChatClientTests.cs new file mode 100644 index 00000000000..6d3a02a08ec --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIAssistantChatClientTests.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ClientModel; +using Azure.AI.OpenAI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using OpenAI; +using OpenAI.Assistants; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long +#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. 
+ +namespace Microsoft.Extensions.AI; + +public class OpenAIAssistantChatClientTests +{ + [Fact] + public void AsIChatClient_InvalidArgs_Throws() + { + Assert.Throws("assistantClient", () => ((AssistantClient)null!).AsIChatClient("assistantId")); + Assert.Throws("assistantId", () => new AssistantClient("ignored").AsIChatClient(null!)); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void AsIChatClient_OpenAIClient_ProducesExpectedMetadata(bool useAzureOpenAI) + { + Uri endpoint = new("http://localhost/some/endpoint"); + + var client = useAzureOpenAI ? + new AzureOpenAIClient(endpoint, new ApiKeyCredential("key")) : + new OpenAIClient(new ApiKeyCredential("key"), new OpenAIClientOptions { Endpoint = endpoint }); + + IChatClient[] clients = + [ + client.GetAssistantClient().AsIChatClient("assistantId"), + client.GetAssistantClient().AsIChatClient("assistantId", "threadId"), + ]; + + foreach (var chatClient in clients) + { + var metadata = chatClient.GetService(); + Assert.Equal("openai", metadata?.ProviderName); + Assert.Equal(endpoint, metadata?.ProviderUri); + } + } + + [Fact] + public void GetService_AssistantClient_SuccessfullyReturnsUnderlyingClient() + { + AssistantClient assistantClient = new OpenAIClient("key").GetAssistantClient(); + IChatClient chatClient = assistantClient.AsIChatClient("assistantId"); + + Assert.Same(assistantClient, chatClient.GetService()); + + Assert.Null(chatClient.GetService()); + + using IChatClient pipeline = chatClient + .AsBuilder() + .UseFunctionInvocation() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Build(); + + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + + Assert.Same(assistantClient, pipeline.GetService()); + Assert.IsType(pipeline.GetService()); + } +} diff --git 
a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIResponseClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIResponseClientIntegrationTests.cs index 2c1d6cdc80e..f8e835bdb81 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIResponseClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIResponseClientIntegrationTests.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Threading.Tasks; + namespace Microsoft.Extensions.AI; public class OpenAIResponseClientIntegrationTests : ChatClientIntegrationTests @@ -11,4 +13,7 @@ public class OpenAIResponseClientIntegrationTests : ChatClientIntegrationTests .AsIChatClient(); public override bool FunctionInvokingChatClientSetsConversationId => true; + + // Test structure doesn't make sense with Responses. + public override Task Caching_AfterFunctionInvocation_FunctionOutputUnchangedAsync() => Task.CompletedTask; }