From ccf2bfdce35c34f858f2aa966dcebb5e60f95736 Mon Sep 17 00:00:00 2001 From: Whit Waldo Date: Wed, 11 Dec 2024 13:42:23 -0600 Subject: [PATCH] Conversation builder consistency changes (#1423) * Corrected several unit tests Signed-off-by: Whit Waldo * Updated extension name for consistency Signed-off-by: Whit Waldo * Updated registration name for consistency Signed-off-by: Whit Waldo --------- Signed-off-by: Whit Waldo --- examples/AI/ConversationalAI/Program.cs | 2 +- .../DaprAiConversationBuilderExtensions.cs | 2 +- ...DaprAiConversationBuilderExtensionsTest.cs | 121 +++++++++++++++++- ...aprJobsServiceCollectionExtensionsTests.cs | 10 +- 4 files changed, 121 insertions(+), 14 deletions(-) diff --git a/examples/AI/ConversationalAI/Program.cs b/examples/AI/ConversationalAI/Program.cs index bd3dc906a..6315db87a 100644 --- a/examples/AI/ConversationalAI/Program.cs +++ b/examples/AI/ConversationalAI/Program.cs @@ -3,7 +3,7 @@ var builder = WebApplication.CreateBuilder(args); -builder.Services.AddDaprAiConversation(); +builder.Services.AddDaprConversationClient(); var app = builder.Build(); diff --git a/src/Dapr.AI/Conversation/Extensions/DaprAiConversationBuilderExtensions.cs b/src/Dapr.AI/Conversation/Extensions/DaprAiConversationBuilderExtensions.cs index 902fd82a3..2f049a906 100644 --- a/src/Dapr.AI/Conversation/Extensions/DaprAiConversationBuilderExtensions.cs +++ b/src/Dapr.AI/Conversation/Extensions/DaprAiConversationBuilderExtensions.cs @@ -26,7 +26,7 @@ public static class DaprAiConversationBuilderExtensions /// Registers the necessary functionality for the Dapr AI conversation functionality. /// /// - public static IDaprAiConversationBuilder AddDaprAiConversation(this IServiceCollection services, Action? configure = null, ServiceLifetime lifetime = ServiceLifetime.Singleton) + public static IDaprAiConversationBuilder AddDaprConversationClient(this IServiceCollection services, Action? configure = null, ServiceLifetime lifetime = ServiceLifetime.Singleton) { ArgumentNullException.ThrowIfNull(services, nameof(services)); diff --git a/test/Dapr.AI.Test/Conversation/Extensions/DaprAiConversationBuilderExtensionsTest.cs b/test/Dapr.AI.Test/Conversation/Extensions/DaprAiConversationBuilderExtensionsTest.cs index 95a8e1e8c..2ee321895 100644 --- a/test/Dapr.AI.Test/Conversation/Extensions/DaprAiConversationBuilderExtensionsTest.cs +++ b/test/Dapr.AI.Test/Conversation/Extensions/DaprAiConversationBuilderExtensionsTest.cs @@ -13,7 +13,9 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Net.Http; +using System.Threading.Tasks; using Dapr.AI.Conversation; using Dapr.AI.Conversation.Extensions; using Microsoft.Extensions.Configuration; @@ -34,7 +36,7 @@ public void AddDaprConversationClient_FromIConfiguration() var services = new ServiceCollection(); services.AddSingleton(configuration); - services.AddDaprAiConversation(); + services.AddDaprConversationClient(); var app = services.BuildServiceProvider(); @@ -45,18 +47,66 @@ public void AddDaprConversationClient_FromIConfiguration() } [Fact] - public void AddDaprAiConversation_WithoutConfigure_ShouldAddServices() + public void AddDaprConversationClient_RegistersDaprClientOnlyOnce() { var services = new ServiceCollection(); - var builder = services.AddDaprAiConversation(); + + var clientBuilder = new Action((sp, builder) => + { + builder.UseDaprApiToken("abc"); + }); + + services.AddDaprConversationClient(); //Sets a default API token value of an empty string + services.AddDaprConversationClient(clientBuilder); //Sets the API token value + + var serviceProvider = services.BuildServiceProvider(); + var daprConversationClient = serviceProvider.GetService(); + + Assert.NotNull(daprConversationClient!.HttpClient); + Assert.False(daprConversationClient.HttpClient.DefaultRequestHeaders.TryGetValues("dapr-api-token", out var _)); + } + + [Fact] + public void AddDaprConversationClient_RegistersUsingDependencyFromIServiceProvider() + { + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddDaprConversationClient((provider, builder) => + { + var configProvider = provider.GetRequiredService(); + var apiToken = configProvider.GetApiTokenValue(); + builder.UseDaprApiToken(apiToken); + }); + + var serviceProvider = services.BuildServiceProvider(); + var client = serviceProvider.GetRequiredService(); + + //Validate it's set on the GrpcClient - note that it doesn't get set on the HttpClient + Assert.NotNull(client); + Assert.NotNull(client.DaprApiToken); + Assert.Equal("abcdef", client.DaprApiToken); + Assert.NotNull(client.HttpClient); + + if (!client.HttpClient.DefaultRequestHeaders.TryGetValues("dapr-api-token", out var daprApiToken)) + { + Assert.Fail(); + } + Assert.Equal("abcdef", daprApiToken.FirstOrDefault()); + } + + [Fact] + public void AddDaprConversationClient_WithoutConfigure_ShouldAddServices() + { + var services = new ServiceCollection(); + var builder = services.AddDaprConversationClient(); Assert.NotNull(builder); } [Fact] - public void AddDaprAiConversation_RegistersIHttpClientFactory() + public void AddDaprConversationClient_RegistersIHttpClientFactory() { var services = new ServiceCollection(); - services.AddDaprAiConversation(); + services.AddDaprConversationClient(); var serviceProvider = services.BuildServiceProvider(); var httpClientFactory = serviceProvider.GetService(); @@ -67,9 +117,66 @@ public void AddDaprAiConversation_RegistersIHttpClientFactory() } [Fact] - public void AddDaprAiConversation_NullServices_ShouldThrowException() + public void AddDaprConversationClient_NullServices_ShouldThrowException() { IServiceCollection services = null; - Assert.Throws(() => services.AddDaprAiConversation()); + Assert.Throws(() => services.AddDaprConversationClient()); + } + + [Fact] + public void AddDaprConversationClient_ShouldRegisterSingleton_WhenLifetimeIsSingleton() + { + var services = new ServiceCollection(); + + services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Singleton); + var serviceProvider = services.BuildServiceProvider(); + + var daprConversationClient1 = serviceProvider.GetService(); + var daprConversationClient2 = serviceProvider.GetService(); + + Assert.NotNull(daprConversationClient1); + Assert.NotNull(daprConversationClient2); + + Assert.Same(daprConversationClient1, daprConversationClient2); + } + + [Fact] + public async Task AddDaprConversationClient_ShouldRegisterScoped_WhenLifetimeIsScoped() + { + var services = new ServiceCollection(); + + services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Scoped); + var serviceProvider = services.BuildServiceProvider(); + + await using var scope1 = serviceProvider.CreateAsyncScope(); + var daprConversationClient1 = scope1.ServiceProvider.GetService(); + + await using var scope2 = serviceProvider.CreateAsyncScope(); + var daprConversationClient2 = scope2.ServiceProvider.GetService(); + + Assert.NotNull(daprConversationClient1); + Assert.NotNull(daprConversationClient2); + Assert.NotSame(daprConversationClient1, daprConversationClient2); + } + + [Fact] + public void AddDaprConversationClient_ShouldRegisterTransient_WhenLifetimeIsTransient() + { + var services = new ServiceCollection(); + + services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Transient); + var serviceProvider = services.BuildServiceProvider(); + + var daprConversationClient1 = serviceProvider.GetService(); + var daprConversationClient2 = serviceProvider.GetService(); + + Assert.NotNull(daprConversationClient1); + Assert.NotNull(daprConversationClient2); + Assert.NotSame(daprConversationClient1, daprConversationClient2); + } + + private class TestSecretRetriever + { + public string GetApiTokenValue() => "abcdef"; } } diff --git a/test/Dapr.Jobs.Test/Extensions/DaprJobsServiceCollectionExtensionsTests.cs b/test/Dapr.Jobs.Test/Extensions/DaprJobsServiceCollectionExtensionsTests.cs index bd5e4acd0..3b2c5f990 100644 --- a/test/Dapr.Jobs.Test/Extensions/DaprJobsServiceCollectionExtensionsTests.cs +++ b/test/Dapr.Jobs.Test/Extensions/DaprJobsServiceCollectionExtensionsTests.cs @@ -89,7 +89,7 @@ public void AddDaprJobsClient_RegistersUsingDependencyFromIServiceProvider() services.AddDaprJobsClient((provider, builder) => { var configProvider = provider.GetRequiredService(); - var apiToken = TestSecretRetriever.GetApiTokenValue(); + var apiToken = configProvider.GetApiTokenValue(); builder.UseDaprApiToken(apiToken); }); @@ -114,7 +114,7 @@ public void RegisterJobsClient_ShouldRegisterSingleton_WhenLifetimeIsSingleton() { var services = new ServiceCollection(); - services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Singleton); + services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Singleton); var serviceProvider = services.BuildServiceProvider(); var daprJobsClient1 = serviceProvider.GetService(); @@ -131,7 +131,7 @@ public async Task RegisterJobsClient_ShouldRegisterScoped_WhenLifetimeIsScoped() { var services = new ServiceCollection(); - services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Scoped); + services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Scoped); var serviceProvider = services.BuildServiceProvider(); await using var scope1 = serviceProvider.CreateAsyncScope(); @@ -150,7 +150,7 @@ public void RegisterJobsClient_ShouldRegisterTransient_WhenLifetimeIsTransient() { var services = new ServiceCollection(); - services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Transient); + services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Transient); var serviceProvider = services.BuildServiceProvider(); var daprJobsClient1 = serviceProvider.GetService(); @@ -163,6 +163,6 @@ public void RegisterJobsClient_ShouldRegisterTransient_WhenLifetimeIsTransient() private class TestSecretRetriever { - public static string GetApiTokenValue() => "abcdef"; + public string GetApiTokenValue() => "abcdef"; } }