diff --git a/samples/embeddings/python/.funcignore b/samples/embeddings/python/.funcignore new file mode 100644 index 00000000..9966315f --- /dev/null +++ b/samples/embeddings/python/.funcignore @@ -0,0 +1,8 @@ +.git* +.vscode +__azurite_db*__.json +__blobstorage__ +__queuestorage__ +local.settings.json +test +.venv \ No newline at end of file diff --git a/samples/rag-aisearch/csharp-ooproc/host.json b/samples/rag-aisearch/csharp-ooproc/host.json index 590b2e78..e5ad6d35 100644 --- a/samples/rag-aisearch/csharp-ooproc/host.json +++ b/samples/rag-aisearch/csharp-ooproc/host.json @@ -9,6 +9,7 @@ "openai": { "searchProvider": { "type": "azureAiSearch", + "aiSearchConnectionNamePrefix": "AISearch", "isSemanticSearchEnabled": true, "useSemanticCaptions": true, "vectorSearchDimensions": 1536 diff --git a/samples/rag-aisearch/python/.funcignore b/samples/rag-aisearch/python/.funcignore new file mode 100644 index 00000000..9966315f --- /dev/null +++ b/samples/rag-aisearch/python/.funcignore @@ -0,0 +1,8 @@ +.git* +.vscode +__azurite_db*__.json +__blobstorage__ +__queuestorage__ +local.settings.json +test +.venv \ No newline at end of file diff --git a/samples/rag-cosmosdb/python/.funcignore b/samples/rag-cosmosdb/python/.funcignore new file mode 100644 index 00000000..9966315f --- /dev/null +++ b/samples/rag-cosmosdb/python/.funcignore @@ -0,0 +1,8 @@ +.git* +.vscode +__azurite_db*__.json +__blobstorage__ +__queuestorage__ +local.settings.json +test +.venv \ No newline at end of file diff --git a/samples/rag-kusto/python/.funcignore b/samples/rag-kusto/python/.funcignore new file mode 100644 index 00000000..9966315f --- /dev/null +++ b/samples/rag-kusto/python/.funcignore @@ -0,0 +1,8 @@ +.git* +.vscode +__azurite_db*__.json +__blobstorage__ +__queuestorage__ +local.settings.json +test +.venv \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchConfigOptions.cs b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchConfigOptions.cs index 7337478f..32f62249 100644 --- a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchConfigOptions.cs +++ b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchConfigOptions.cs @@ -14,5 +14,5 @@ public class AzureAISearchConfigOptions public int VectorSearchDimensions { get; set; } = 1536; - public string? SearchAPIKeySetting { get; set; } + public string? SearchConnectionNamePrefix { get; set; } } diff --git a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs index 3d52cbe6..56a4d50e 100644 --- a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs +++ b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs @@ -1,13 +1,15 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Collections.Concurrent; using Azure; -using Azure.Identity; +using Azure.Core; using Azure.Search.Documents; using Azure.Search.Documents.Indexes; using Azure.Search.Documents.Indexes.Models; using Azure.Search.Documents.Models; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; +using Microsoft.Extensions.Azure; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -16,12 +18,17 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.AzureAISearch; sealed class AzureAISearchProvider : ISearchProvider { + readonly ConcurrentDictionary searchClients = new(); // value is client, endpoint, indexName + readonly ConcurrentDictionary searchIndexClients = new(); // value is client, endpoint + readonly ConcurrentDictionary tokenCredentials = new(); // sectionNamePrefix as key and token credential as value + readonly IConfiguration configuration; readonly ILogger logger; + readonly AzureComponentFactory azureComponentFactory; readonly bool isSemanticSearchEnabled = false; readonly bool useSemanticCaptions = false; readonly int vectorSearchDimensions = 1536; - readonly string searchAPIKeySetting = "SearchAPIKey"; + readonly string searchConnectionNamePrefix = "AISearch"; const string defaultSearchIndexName = "openai-index"; const string vectorSearchConfigName = "openai-vector-config"; const string vectorSearchProfile = "openai-vector-profile"; @@ -34,9 +41,10 @@ sealed class AzureAISearchProvider : ISearchProvider /// The configuration. /// The logger factory. /// Throws ArgumentNullException if logger factory is null. - public AzureAISearchProvider(IConfiguration configuration, ILoggerFactory loggerFactory, IOptions azureAiSearchConfigOptions) + public AzureAISearchProvider(IConfiguration configuration, ILoggerFactory loggerFactory, IOptions azureAiSearchConfigOptions, AzureComponentFactory azureComponentFactory) { this.configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); + this.azureComponentFactory = azureComponentFactory ?? throw new ArgumentNullException(nameof(azureComponentFactory)); if (loggerFactory == null) { @@ -45,7 +53,7 @@ public AzureAISearchProvider(IConfiguration configuration, ILoggerFactory logger this.isSemanticSearchEnabled = azureAiSearchConfigOptions.Value.IsSemanticSearchEnabled; this.useSemanticCaptions = azureAiSearchConfigOptions.Value.UseSemanticCaptions; - this.searchAPIKeySetting = azureAiSearchConfigOptions.Value.SearchAPIKeySetting ?? this.searchAPIKeySetting; + this.searchConnectionNamePrefix = azureAiSearchConfigOptions.Value.SearchConnectionNamePrefix ?? this.searchConnectionNamePrefix; int value = azureAiSearchConfigOptions.Value.VectorSearchDimensions; if (value < 2 || value > 3072) { @@ -69,10 +77,9 @@ public async Task AddDocumentAsync(SearchableDocument document, CancellationToke { throw new ArgumentNullException(nameof(document.ConnectionInfo)); } - string endpoint = this.configuration.GetValue(document.ConnectionInfo.ConnectionName); - SearchIndexClient searchIndexClient = this.GetSearchIndexClient(endpoint); - SearchClient searchClient = this.GetSearchClient(endpoint, document.ConnectionInfo.CollectionName ?? defaultSearchIndexName); + SearchIndexClient searchIndexClient = this.GetSearchIndexClient(document.ConnectionInfo); + SearchClient searchClient = this.GetSearchClient(document.ConnectionInfo); await this.CreateIndexIfDoesntExist(searchIndexClient, document.ConnectionInfo.CollectionName ?? defaultSearchIndexName, cancellationToken); @@ -98,8 +105,7 @@ public async Task SearchAsync(SearchRequest request) throw new ArgumentNullException(nameof(request.ConnectionInfo)); } - string endpoint = this.configuration.GetValue(request.ConnectionInfo.ConnectionName); - SearchClient searchClient = this.GetSearchClient(endpoint, request.ConnectionInfo.CollectionName ?? defaultSearchIndexName); + SearchClient searchClient = this.GetSearchClient(request.ConnectionInfo); SearchOptions searchOptions = this.isSemanticSearchEnabled ? new SearchOptions @@ -269,32 +275,46 @@ async Task IndexDocumentsBatchAsync(SearchClient searchClient, IndexDocumentsBat succeeded); } - SearchIndexClient GetSearchIndexClient(string endpoint) + SearchIndexClient GetSearchIndexClient(ConnectionInfo connectionInfo) { - string? key = this.configuration.GetValue(this.searchAPIKeySetting); - if (string.IsNullOrEmpty(key)) - { - return new SearchIndexClient(new Uri(endpoint), new DefaultAzureCredential()); - } - else - { - return new SearchIndexClient(new Uri(endpoint), new AzureKeyCredential(key)); - } + (SearchIndexClient searchIndexClient, string endpoint) = + this.searchIndexClients.GetOrAdd( + connectionInfo.ConnectionName, + name => + { + string endpoint = this.configuration.GetValue(connectionInfo.ConnectionName); + return (new SearchIndexClient(new Uri(endpoint), this.GetSearchTokenCredential()), endpoint); + }); + + return searchIndexClient; + } - SearchClient GetSearchClient(string endpoint, string searchIndexName) + SearchClient GetSearchClient(ConnectionInfo connectionInfo) { - string? key = this.configuration.GetValue(this.searchAPIKeySetting); - SearchClient searchClient; - if (string.IsNullOrEmpty(key)) - { - searchClient = new SearchClient(new Uri(endpoint), searchIndexName, new DefaultAzureCredential()); - } - else - { - searchClient = new SearchClient(new Uri(endpoint), searchIndexName, new AzureKeyCredential(key)); - } + (SearchClient searchClient, string endpoint, string searchIndexName) = + this.searchClients.GetOrAdd( + connectionInfo.ConnectionName, + name => + { + string endpoint = this.configuration.GetValue(connectionInfo.ConnectionName); + string searchIndexName = connectionInfo.CollectionName ?? defaultSearchIndexName; + searchClient = new SearchClient(new Uri(endpoint), searchIndexName, this.GetSearchTokenCredential()); + return (searchClient, endpoint, searchIndexName); + }); return searchClient; } + + TokenCredential GetSearchTokenCredential() + { + IConfigurationSection searchConnectionConfigSection = this.configuration.GetSection(this.searchConnectionNamePrefix); + TokenCredential tokenCredential = this.tokenCredentials.GetOrAdd( + this.searchConnectionNamePrefix, + name => + { + return this.azureComponentFactory.CreateTokenCredential(searchConnectionConfigSection); + }); + return tokenCredential; + } } diff --git a/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md index 70a13ac6..c9ce6812 100644 --- a/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md +++ b/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## v0.4.0 - Unreleased + +### Breaking + +- Managed identity support and consistency established with other Azure Functions extensions +- + ## v0.3.0 - 2024/10/08 ### Changed diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs index e31e8a3b..83e8b94e 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs @@ -8,6 +8,10 @@ using Microsoft.Extensions.Azure; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using OpenAI; +using OpenAI.Assistants; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; @@ -27,8 +31,10 @@ record InternalChatState(string Id, AssistantStateEntity Metadata, List const int FunctionCallBatchLimit = 50; + + readonly TableClient tableClient; + readonly TableServiceClient tableServiceClient; const string DefaultChatStorage = "AzureWebJobsStorage"; - readonly OpenAIClient openAIClient; readonly IAssistantSkillInvoker skillInvoker; readonly ILogger logger; readonly AzureComponentFactory azureComponentFactory; @@ -37,7 +43,7 @@ record InternalChatState(string Id, AssistantStateEntity Metadata, List openAiConfigOptions, AzureComponentFactory azureComponentFactory, IConfiguration configuration, IAssistantSkillInvoker skillInvoker, @@ -49,7 +55,6 @@ public DefaultAssistantService( } this.skillInvoker = skillInvoker ?? throw new ArgumentNullException(nameof(skillInvoker)); - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); this.logger = loggerFactory.CreateLogger(); this.azureComponentFactory = azureComponentFactory ?? throw new ArgumentNullException(nameof(azureComponentFactory)); @@ -117,7 +122,7 @@ async Task DeleteBatch() partitionKey: request.Id, messageIndex: 1, // 1-based index content: request.Instructions, - role: ChatRole.System); + role: ChatMessageRole.System); batch.Add(new TableTransactionAction(TableTransactionActionType.Add, chatMessageEntity)); } @@ -211,43 +216,44 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib partitionKey: attribute.Id, messageIndex: ++chatState.Metadata.TotalMessages, content: attribute.UserMessage, - role: ChatRole.User); + role: ChatMessageRole.User); chatState.Messages.Add(chatMessageEntity); // Add the chat message to the batch batch.Add(new TableTransactionAction(TableTransactionActionType.Add, chatMessageEntity)); string deploymentName = attribute.Model ?? OpenAIModels.DefaultChatModel; - IList? functions = this.skillInvoker.GetFunctionsDefinitions(); + IList? functions = this.skillInvoker.GetFunctionsDefinitions(); + AzureOpenAIClient azureOpenAIClient = new AzureOpenAIClient(); + OpenAIClientOptions clientOptions = new OpenAIClientOptions(); + ChatClient chatClient = new ChatClient(deploymentName, clientOptions); // We loop if the model returns function calls. Otherwise, we break after receiving a response. while (true) { // Get the next response from the LLM - ChatCompletionsOptions chatRequest = new(deploymentName, ToOpenAIChatRequestMessages(chatState.Messages)); + ChatCompletionOptions chatRequest = new (deploymentName, ToOpenAIChatRequestMessages(chatState.Messages)); if (functions is not null) { - foreach (ChatCompletionsFunctionToolDefinition fn in functions) + foreach (FunctionToolDefinition fn in functions) { chatRequest.Tools.Add(fn); } } - Response response = await this.openAIClient.GetChatCompletionsAsync( - chatRequest, - cancellationToken); + var respone = chatClient.CompleteChatAsync(chatState.Messages, cancellationToken); // We don't normally expect more than one message, but just in case we get multiple messages, // return all of them separated by two newlines. string replyMessage = string.Join( Environment.NewLine + Environment.NewLine, - response.Value.Choices.Select(choice => choice.Message.Content)); + response.Value.Content); if (!string.IsNullOrWhiteSpace(replyMessage)) { this.logger.LogInformation( "[{Id}] Got LLM response consisting of {Count} tokens: {Text}", attribute.Id, - response.Value.Usage.CompletionTokens, + response.Value.Usage.TotalTokens, replyMessage); // Add the user message as a new Chat message entity @@ -255,7 +261,7 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib partitionKey: attribute.Id, messageIndex: ++chatState.Metadata.TotalMessages, content: replyMessage, - role: ChatRole.Assistant); + role: ChatMessageRole.Assistant); chatState.Messages.Add(replyFromAssistantEntity); // Add the reply from assistant chat message to the batch @@ -271,9 +277,8 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib chatState.Metadata.TotalTokens = response.Value.Usage.TotalTokens; // Check for function calls (which are described in the API as tools) - List functionCalls = response.Value.Choices - .SelectMany(c => c.Message.ToolCalls) - .OfType() + List functionCalls = response.Value.ToolCalls + .OfType() .ToList(); if (functionCalls.Count == 0) { @@ -301,14 +306,14 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib // Invoke the function calls and add the responses to the chat history. List> tasks = new(capacity: functionCalls.Count); - foreach (ChatCompletionsFunctionToolCall call in functionCalls) + foreach (FunctionToolDefinition call in functionCalls) { // CONSIDER: Call these in parallel this.logger.LogInformation( "[{Id}] Calling function '{Name}' with arguments: {Args}", attribute.Id, - call.Name, - call.Arguments); + call.FunctionName, + call.Parameters); string? functionResult; try @@ -320,7 +325,7 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib this.logger.LogInformation( "[{id}] Function '{Name}' returned the following content: {Content}", attribute.Id, - call.Name, + call.FunctionName, functionResult); } catch (Exception ex) @@ -329,7 +334,7 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib ex, "[{id}] Function '{Name}' failed with an unhandled exception", attribute.Id, - call.Name); + call.FunctionName); // CONSIDER: Automatic retries? functionResult = "The function call failed. Let the user know and ask if they'd like you to try again"; @@ -347,8 +352,8 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib partitionKey: attribute.Id, messageIndex: ++chatState.Metadata.TotalMessages, content: functionResult, - role: ChatRole.Function, - name: call.Name); + role: ChatMessageRole.Function, + name: call.FunctionName); chatState.Messages.Add(functionResultEntity); batch.Add(new TableTransactionAction(TableTransactionActionType.Add, functionResultEntity)); @@ -365,7 +370,7 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib // return the latest assistant message in the chat state List filteredChatMessages = chatState.Messages - .Where(msg => msg.CreatedAt > timeFilter && msg.Role == ChatRole.Assistant) + .Where(msg => msg.CreatedAt > timeFilter && msg.Role == ChatMessageRole.Assistant.ToString()) .ToList(); this.logger.LogInformation( @@ -381,7 +386,7 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib chatState.Metadata.LastUpdatedAt, chatState.Metadata.TotalMessages, chatState.Metadata.TotalTokens, - filteredChatMessages.Select(msg => new ChatMessage(msg.Content, msg.Role, msg.Name)).ToList()); + filteredChatMessages.Select(msg => new Models.ChatMessage(msg.Content, msg.Role, msg.Name)).ToList()); return state; } @@ -420,26 +425,26 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib return new InternalChatState(id, assistantStateEntity, chatMessageList); } - static IEnumerable ToOpenAIChatRequestMessages(IEnumerable entities) + static IEnumerable ToOpenAIChatRequestMessages(IEnumerable entities) { foreach (ChatMessageTableEntity entity in entities) { switch (entity.Role.ToLowerInvariant()) { case "user": - yield return new ChatRequestUserMessage(entity.Content); + yield return new UserChatMessage(entity.Content); break; case "assistant": - yield return new ChatRequestAssistantMessage(entity.Content); + yield return new AssistantChatMessage(entity.Content); break; case "system": - yield return new ChatRequestSystemMessage(entity.Content); + yield return new SystemChatMessage(entity.Content); break; case "function": - yield return new ChatRequestFunctionMessage(entity.Name, entity.Content); + yield return new FunctionChatMessage(entity.Name, entity.Content); break; case "tool": - yield return new ChatRequestToolMessage(entity.Content, toolCallId: entity.Name); + yield return new ToolChatMessage(toolCallId: entity.Name, entity.Content); break; default: throw new InvalidOperationException($"Unknown chat role '{entity.Role}'"); diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs b/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs index d3e539ba..43874583 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs @@ -4,17 +4,17 @@ using System.Reflection; using System.Runtime.ExceptionServices; using System.Text; -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Host.Executors; using Microsoft.Extensions.Logging; using Newtonsoft.Json; +using OpenAI.Assistants; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; public interface IAssistantSkillInvoker { - IList? GetFunctionsDefinitions(); - Task InvokeAsync(ChatCompletionsFunctionToolCall call, CancellationToken cancellationToken); + IList? GetFunctionsDefinitions(); + Task InvokeAsync(FunctionToolDefinition call, CancellationToken cancellationToken); } class SkillInvocationContext @@ -70,14 +70,14 @@ internal void UnregisterSkill(string name) this.skills.Remove(name); } - IList? IAssistantSkillInvoker.GetFunctionsDefinitions() + IList? IAssistantSkillInvoker.GetFunctionsDefinitions() { if (this.skills.Count == 0) { return null; } - List functions = new(capacity: this.skills.Count); + List functions = new(capacity: this.skills.Count); foreach (Skill skill in this.skills.Values) { // The parameters can be defined in the attribute JSON or can be inferred from @@ -85,9 +85,9 @@ internal void UnregisterSkill(string name) string parametersJson = skill.Attribute.ParameterDescriptionJson ?? JsonConvert.SerializeObject(GetParameterDefinition(skill)); - functions.Add(new ChatCompletionsFunctionToolDefinition + functions.Add(new FunctionToolDefinition { - Name = skill.Name, + FunctionName = skill.Name, Description = skill.Attribute.FunctionDescription, Parameters = BinaryData.FromBytes(Encoding.UTF8.GetBytes(parametersJson)), }); @@ -140,7 +140,7 @@ static Dictionary GetParameterDefinition(Skill skill) } async Task IAssistantSkillInvoker.InvokeAsync( - ChatCompletionsFunctionToolCall call, + FunctionToolDefinition call, CancellationToken cancellationToken) { if (call is null) @@ -148,14 +148,14 @@ static Dictionary GetParameterDefinition(Skill skill) throw new ArgumentNullException(nameof(call)); } - if (call.Name is null) + if (call.FunctionName is null) { throw new ArgumentException("The function call must have a name", nameof(call)); } - if (!this.skills.TryGetValue(call.Name, out Skill? skill)) + if (!this.skills.TryGetValue(call.FunctionName, out Skill? skill)) { - throw new InvalidOperationException($"No skill registered with name '{call.Name}'"); + throw new InvalidOperationException($"No skill registered with name '{call.FunctionName}'"); } SkillInvocationContext skillInvocationContext = new(call.Arguments); @@ -170,7 +170,7 @@ static Dictionary GetParameterDefinition(Skill skill) InvokeHandler = async userCodeInvoker => { // Invoke the function and attempt to get the result. - this.logger.LogInformation("Invoking user-code function '{Name}'", call.Name); + this.logger.LogInformation("Invoking user-code function '{Name}'", call.FunctionName); Task invokeTask = userCodeInvoker.Invoke(); if (invokeTask is Task invokeTaskWithResult) { diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs index 2035a189..2be6b3b8 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs @@ -47,7 +47,7 @@ async Task ConvertCoreAsync( EmbeddingsAttribute attribute, CancellationToken cancellationToken) { - OpenAISDK.EmbeddingsOptions request = await EmbeddingsHelper.BuildRequest(attribute.MaxOverlap, attribute.MaxChunkLength, attribute.Model, attribute.InputType, attribute.Input); + Embeddings request = await EmbeddingsHelper.BuildRequest(attribute.MaxOverlap, attribute.MaxChunkLength, attribute.Model, attribute.InputType, attribute.Input); this.logger.LogInformation("Sending OpenAI embeddings request: {request}", request); Response response = await this.openAIClient.GetEmbeddingsAsync(request, cancellationToken); this.logger.LogInformation("Received OpenAI embeddings response: {response}", response); diff --git a/src/WebJobs.Extensions.OpenAI/Models/ChatMessage.cs b/src/WebJobs.Extensions.OpenAI/Models/ChatMessage.cs deleted file mode 100644 index b52e0976..00000000 --- a/src/WebJobs.Extensions.OpenAI/Models/ChatMessage.cs +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Newtonsoft.Json; - -namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; - -/// -/// Chat Message Entity which contains the content of the message, the role of the chat agent, and the name of the calling function if applicable. -/// -[JsonObject(MemberSerialization.OptIn)] -public class ChatMessage -{ - /// - /// Initializes a new instance of the class. - /// - /// The content of the message. - /// The role of the chat agent. - public ChatMessage(string content, string role, string? name) - { - this.Content = content; - this.Role = role; - this.Name = name; - } - - /// - /// Gets or sets the content of the message. - /// - [JsonProperty("content")] - public string Content { get; set; } - - /// - /// Gets or sets the role of the chat agent. - /// - [JsonProperty("role")] - public string Role { get; set; } - - /// - /// Gets or sets the name of the calling function if applicable. - /// - [JsonProperty("name")] - public string? Name { get; set; } -} diff --git a/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs b/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs index 022350d9..5a453d45 100644 --- a/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs +++ b/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs @@ -4,6 +4,7 @@ using Azure; using Azure.AI.OpenAI; using Azure.Data.Tables; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; @@ -19,7 +20,7 @@ public ChatMessageTableEntity( string partitionKey, int messageIndex, string content, - ChatRole role, + ChatMessageRole role, string? name = null) { this.PartitionKey = partitionKey; diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs b/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs new file mode 100644 index 00000000..4d6f98e3 --- /dev/null +++ b/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; +class OpenAIClientFactory +{ + public OpenAIClientFactory( + IConfiguration configuration, + ILogger logger) + { + + } +} diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs b/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs index 3b49603b..2b8bc633 100644 --- a/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs +++ b/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs @@ -31,18 +31,23 @@ public static IWebJobsBuilder AddOpenAIBindings(this IWebJobsBuilder builder) throw new ArgumentNullException(nameof(builder)); } + // Add AzureComponentFactory to the services + builder.Services.AddAzureClientsCore(); + // Register the client for Azure Open AI - Uri? azureOpenAIEndpoint = GetAzureOpenAIEndpoint(); string? openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); - string? azureOpenAIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_KEY"); - if (azureOpenAIEndpoint != null && !string.IsNullOrEmpty(azureOpenAIKey)) - { - RegisterAzureOpenAIClient(builder.Services, azureOpenAIEndpoint, azureOpenAIKey); - } - else if (azureOpenAIEndpoint != null) + IConfigurationRoot configuration = new ConfigurationBuilder() + .AddEnvironmentVariables() + .Build(); + + IConfigurationSection azureOpenAIConfigSection = configuration.GetSection("AZURE_OPENAI"); + if (azureOpenAIConfigSection.Exists()) { - RegisterAzureOpenAIADAuthClient(builder.Services, azureOpenAIEndpoint); + builder.Services.AddAzureClients(clientBuilder => + { + clientBuilder.AddOpenAIClient(azureOpenAIConfigSection); + }); } else if (!string.IsNullOrEmpty(openAIKey)) { @@ -50,7 +55,7 @@ public static IWebJobsBuilder AddOpenAIBindings(this IWebJobsBuilder builder) } else { - throw new InvalidOperationException("Must set AZURE_OPENAI_ENDPOINT or OPENAI_API_KEY environment variables."); + throw new InvalidOperationException("Must set AZUREOPENAI configuration section (with Endpoint, Key or Credentials) or OPENAI_API_KEY environment variables."); } // Register the WebJobs extension, which enables the bindings. @@ -80,30 +85,6 @@ public static IWebJobsBuilder AddOpenAIBindings(this IWebJobsBuilder builder) return builder; } - static Uri? GetAzureOpenAIEndpoint() - { - if (Uri.TryCreate(Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"), UriKind.Absolute, out var uri)) - { - return uri; - } - - return null; - } - - static void RegisterAzureOpenAIClient(IServiceCollection services, Uri azureOpenAIEndpoint, string azureOpenAIKey) - { - services.AddAzureClients(clientBuilder => - { - clientBuilder.AddOpenAIClient(azureOpenAIEndpoint, new AzureKeyCredential(azureOpenAIKey)); - }); - } - - static void RegisterAzureOpenAIADAuthClient(IServiceCollection services, Uri azureOpenAIEndpoint) - { - var managedIdentityClient = new OpenAIClient(azureOpenAIEndpoint, new DefaultAzureCredential()); - services.AddSingleton(managedIdentityClient); - } - static void RegisterOpenAIClient(IServiceCollection services, string openAIKey) { var openAIClient = new OpenAIClient(openAIKey); diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs index a68d5e70..59a7bf45 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs @@ -8,7 +8,7 @@ using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using OpenAISDK = Azure.AI.OpenAI; +using OpenAISDK = OpenAI; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; @@ -30,12 +30,10 @@ class SemanticSearchConverter : }; public SemanticSearchConverter( - OpenAISDK.OpenAIClient openAIClient, ILoggerFactory loggerFactory, IEnumerable searchProviders, IOptions openAiConfigOptions) { - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); openAiConfigOptions.Value.SearchProvider.TryGetValue("type", out object value); @@ -95,13 +93,13 @@ async Task ConvertHelperAsync( } // Call the chat API with the new combined prompt to get a response back - OpenAISDK.ChatCompletionsOptions chatCompletionsOptions = new() + OpenAISDK.Chat.ChatCompletionOptions chatCompletionsOptions = new() { DeploymentName = attribute.ChatModel, Messages = { - new OpenAISDK.ChatRequestSystemMessage(promptBuilder.ToString()), - new OpenAISDK.ChatRequestUserMessage(attribute.Query), + new OpenAISDK.Chat.SystemChatMessage(promptBuilder.ToString()), + new OpenAISDK.Chat.UserChatMessage(attribute.Query), } }; diff --git a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj index b32d3f60..02621e86 100644 --- a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj +++ b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj @@ -5,9 +5,8 @@ - + -