Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate identity auth implementation #101

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions samples/embeddings/python/.funcignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.git*
.vscode
__azurite_db*__.json
__blobstorage__
__queuestorage__
local.settings.json
test
.venv
1 change: 1 addition & 0 deletions samples/rag-aisearch/csharp-ooproc/host.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"openai": {
"searchProvider": {
"type": "azureAiSearch",
"aiSearchConnectionNamePrefix": "AISearch",
"isSemanticSearchEnabled": true,
"useSemanticCaptions": true,
"vectorSearchDimensions": 1536
Expand Down
8 changes: 8 additions & 0 deletions samples/rag-aisearch/python/.funcignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.git*
.vscode
__azurite_db*__.json
__blobstorage__
__queuestorage__
local.settings.json
test
.venv
8 changes: 8 additions & 0 deletions samples/rag-cosmosdb/python/.funcignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.git*
.vscode
__azurite_db*__.json
__blobstorage__
__queuestorage__
local.settings.json
test
.venv
8 changes: 8 additions & 0 deletions samples/rag-kusto/python/.funcignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.git*
.vscode
__azurite_db*__.json
__blobstorage__
__queuestorage__
local.settings.json
test
.venv
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ public class AzureAISearchConfigOptions

public int VectorSearchDimensions { get; set; } = 1536;

public string? SearchAPIKeySetting { get; set; }
public string? SearchConnectionNamePrefix { get; set; }
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Collections.Concurrent;
using Azure;
using Azure.Identity;
using Azure.Core;
using Azure.Search.Documents;
using Azure.Search.Documents.Indexes;
using Azure.Search.Documents.Indexes.Models;
using Azure.Search.Documents.Models;
using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search;
using Microsoft.Extensions.Azure;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
Expand All @@ -16,12 +18,17 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.AzureAISearch;

sealed class AzureAISearchProvider : ISearchProvider
{
readonly ConcurrentDictionary<string, (SearchClient, string, string)> searchClients = new(); // value is client, endpoint, indexName
readonly ConcurrentDictionary<string, (SearchIndexClient, string)> searchIndexClients = new(); // value is client, endpoint
readonly ConcurrentDictionary<string, TokenCredential> tokenCredentials = new(); // sectionNamePrefix as key and token credential as value

readonly IConfiguration configuration;
readonly ILogger logger;
readonly AzureComponentFactory azureComponentFactory;
readonly bool isSemanticSearchEnabled = false;
readonly bool useSemanticCaptions = false;
readonly int vectorSearchDimensions = 1536;
readonly string searchAPIKeySetting = "SearchAPIKey";
readonly string searchConnectionNamePrefix = "AISearch";
const string defaultSearchIndexName = "openai-index";
const string vectorSearchConfigName = "openai-vector-config";
const string vectorSearchProfile = "openai-vector-profile";
Expand All @@ -34,9 +41,10 @@ sealed class AzureAISearchProvider : ISearchProvider
/// <param name="configuration">The configuration.</param>
/// <param name="loggerFactory">The logger factory.</param>
/// <exception cref="ArgumentNullException">Throws ArgumentNullException if logger factory is null.</exception>
public AzureAISearchProvider(IConfiguration configuration, ILoggerFactory loggerFactory, IOptions<AzureAISearchConfigOptions> azureAiSearchConfigOptions)
public AzureAISearchProvider(IConfiguration configuration, ILoggerFactory loggerFactory, IOptions<AzureAISearchConfigOptions> azureAiSearchConfigOptions, AzureComponentFactory azureComponentFactory)
{
this.configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
this.azureComponentFactory = azureComponentFactory ?? throw new ArgumentNullException(nameof(azureComponentFactory));

if (loggerFactory == null)
{
Expand All @@ -45,7 +53,7 @@ public AzureAISearchProvider(IConfiguration configuration, ILoggerFactory logger

this.isSemanticSearchEnabled = azureAiSearchConfigOptions.Value.IsSemanticSearchEnabled;
this.useSemanticCaptions = azureAiSearchConfigOptions.Value.UseSemanticCaptions;
this.searchAPIKeySetting = azureAiSearchConfigOptions.Value.SearchAPIKeySetting ?? this.searchAPIKeySetting;
this.searchConnectionNamePrefix = azureAiSearchConfigOptions.Value.SearchConnectionNamePrefix ?? this.searchConnectionNamePrefix;
int value = azureAiSearchConfigOptions.Value.VectorSearchDimensions;
if (value < 2 || value > 3072)
{
Expand All @@ -69,10 +77,9 @@ public async Task AddDocumentAsync(SearchableDocument document, CancellationToke
{
throw new ArgumentNullException(nameof(document.ConnectionInfo));
}
string endpoint = this.configuration.GetValue<string>(document.ConnectionInfo.ConnectionName);

SearchIndexClient searchIndexClient = this.GetSearchIndexClient(endpoint);
SearchClient searchClient = this.GetSearchClient(endpoint, document.ConnectionInfo.CollectionName ?? defaultSearchIndexName);
SearchIndexClient searchIndexClient = this.GetSearchIndexClient(document.ConnectionInfo);
SearchClient searchClient = this.GetSearchClient(document.ConnectionInfo);

await this.CreateIndexIfDoesntExist(searchIndexClient, document.ConnectionInfo.CollectionName ?? defaultSearchIndexName, cancellationToken);

Expand All @@ -98,8 +105,7 @@ public async Task<SearchResponse> SearchAsync(SearchRequest request)
throw new ArgumentNullException(nameof(request.ConnectionInfo));
}

string endpoint = this.configuration.GetValue<string>(request.ConnectionInfo.ConnectionName);
SearchClient searchClient = this.GetSearchClient(endpoint, request.ConnectionInfo.CollectionName ?? defaultSearchIndexName);
SearchClient searchClient = this.GetSearchClient(request.ConnectionInfo);

SearchOptions searchOptions = this.isSemanticSearchEnabled
? new SearchOptions
Expand Down Expand Up @@ -269,32 +275,46 @@ async Task IndexDocumentsBatchAsync(SearchClient searchClient, IndexDocumentsBat
succeeded);
}

SearchIndexClient GetSearchIndexClient(string endpoint)
SearchIndexClient GetSearchIndexClient(ConnectionInfo connectionInfo)
{
string? key = this.configuration.GetValue<string>(this.searchAPIKeySetting);
if (string.IsNullOrEmpty(key))
{
return new SearchIndexClient(new Uri(endpoint), new DefaultAzureCredential());
}
else
{
return new SearchIndexClient(new Uri(endpoint), new AzureKeyCredential(key));
}
(SearchIndexClient searchIndexClient, string endpoint) =
this.searchIndexClients.GetOrAdd(
connectionInfo.ConnectionName,
name =>
{
string endpoint = this.configuration.GetValue<string>(connectionInfo.ConnectionName);
return (new SearchIndexClient(new Uri(endpoint), this.GetSearchTokenCredential()), endpoint);
});

return searchIndexClient;

}

SearchClient GetSearchClient(string endpoint, string searchIndexName)
SearchClient GetSearchClient(ConnectionInfo connectionInfo)
{
string? key = this.configuration.GetValue<string>(this.searchAPIKeySetting);
SearchClient searchClient;
if (string.IsNullOrEmpty(key))
{
searchClient = new SearchClient(new Uri(endpoint), searchIndexName, new DefaultAzureCredential());
}
else
{
searchClient = new SearchClient(new Uri(endpoint), searchIndexName, new AzureKeyCredential(key));
}
(SearchClient searchClient, string endpoint, string searchIndexName) =
this.searchClients.GetOrAdd(
connectionInfo.ConnectionName,
name =>
{
string endpoint = this.configuration.GetValue<string>(connectionInfo.ConnectionName);
string searchIndexName = connectionInfo.CollectionName ?? defaultSearchIndexName;
searchClient = new SearchClient(new Uri(endpoint), searchIndexName, this.GetSearchTokenCredential());
return (searchClient, endpoint, searchIndexName);
});

return searchClient;
}

TokenCredential GetSearchTokenCredential()
{
IConfigurationSection searchConnectionConfigSection = this.configuration.GetSection(this.searchConnectionNamePrefix);
TokenCredential tokenCredential = this.tokenCredentials.GetOrAdd(
this.searchConnectionNamePrefix,
name =>
{
return this.azureComponentFactory.CreateTokenCredential(searchConnectionConfigSection);
});
return tokenCredential;
}
}
7 changes: 7 additions & 0 deletions src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v0.4.0 - Unreleased

### Breaking

- Managed identity support and consistency established with other Azure Functions extensions
-

## v0.3.0 - 2024/10/08

### Changed
Expand Down
Loading
Loading