diff --git a/src/Client/Core/DependencyInjection/DefaultDurableTaskClientProvider.cs b/src/Client/Core/DependencyInjection/DefaultDurableTaskClientProvider.cs index 8ac73bd1..d8105958 100644 --- a/src/Client/Core/DependencyInjection/DefaultDurableTaskClientProvider.cs +++ b/src/Client/Core/DependencyInjection/DefaultDurableTaskClientProvider.cs @@ -10,13 +10,13 @@ namespace Microsoft.DurableTask.Client; /// class DefaultDurableTaskClientProvider : IDurableTaskClientProvider { - readonly IEnumerable clients; + readonly IEnumerable clients; /// /// Initializes a new instance of the class. /// /// The set of clients. - public DefaultDurableTaskClientProvider(IEnumerable clients) + public DefaultDurableTaskClientProvider(IEnumerable clients) { this.clients = clients; } @@ -25,7 +25,7 @@ public DefaultDurableTaskClientProvider(IEnumerable clients) public DurableTaskClient GetClient(string? name = null) { name ??= Options.DefaultName; - DurableTaskClient? client = this.clients.FirstOrDefault( + ClientContainer? client = this.clients.FirstOrDefault( x => string.Equals(name, x.Name, StringComparison.Ordinal)); // options are case sensitive. if (client is null) @@ -35,6 +35,31 @@ public DurableTaskClient GetClient(string? name = null) nameof(name), name, $"The value of this argument must be in the set of available clients: [{names}]."); } - return client; + return client.Client; + } + + /// + /// Container for holding a client in memory. + /// + internal class ClientContainer + { + /// + /// Initializes a new instance of the class. + /// + /// The client. + public ClientContainer(DurableTaskClient client) + { + this.Client = Check.NotNull(client); + } + + /// + /// Gets the client name. + /// + public string Name => this.Client.Name; + + /// + /// Gets the client. + /// + public DurableTaskClient Client { get; } } } diff --git a/src/Client/Core/DependencyInjection/ServiceCollectionExtensions.cs b/src/Client/Core/DependencyInjection/ServiceCollectionExtensions.cs index 5a453ada..db1eca95 100644 --- a/src/Client/Core/DependencyInjection/ServiceCollectionExtensions.cs +++ b/src/Client/Core/DependencyInjection/ServiceCollectionExtensions.cs @@ -43,7 +43,16 @@ public static IServiceCollection AddDurableTaskClient( // The added toggle logic is because we cannot use TryAddEnumerable logic as // we would have to dynamically compile a lambda to have it work correctly. ConfigureDurableOptions(services, name); - services.AddSingleton(sp => builder.Build(sp)); + + // We do not want to register DurableTaskClient type directly so we can keep a max of 1 DurableTaskClients + // registered, allowing for direct-DI of the default client. + services.AddSingleton(sp => new DefaultDurableTaskClientProvider.ClientContainer(builder.Build(sp))); + + if (name == Options.DefaultName) + { + // If we have the default options name here, we will inject this client directly. + builder.RegisterDirectly(); + } } return services; diff --git a/test/Client/Core.Tests/DependencyInjection/DefaultDurableTaskClientProviderTests.cs b/test/Client/Core.Tests/DependencyInjection/DefaultDurableTaskClientProviderTests.cs index ba6a7250..7aec3aac 100644 --- a/test/Client/Core.Tests/DependencyInjection/DefaultDurableTaskClientProviderTests.cs +++ b/test/Client/Core.Tests/DependencyInjection/DefaultDurableTaskClientProviderTests.cs @@ -40,12 +40,12 @@ public void GetClient_Found_Returns(params string[] clients) client.Name.Should().Be("client1"); } - static List CreateClients(params string[] names) + static List CreateClients(params string[] names) { return names.Select(n => { Mock client = new(n, new DurableTaskClientOptions()); - return client.Object; + return new DefaultDurableTaskClientProvider.ClientContainer(client.Object); }).ToList(); } } diff --git a/test/Client/Core.Tests/DependencyInjection/ServiceCollectionExtensionsTests.cs b/test/Client/Core.Tests/DependencyInjection/ServiceCollectionExtensionsTests.cs index 54efd672..7347d61f 100644 --- a/test/Client/Core.Tests/DependencyInjection/ServiceCollectionExtensionsTests.cs +++ b/test/Client/Core.Tests/DependencyInjection/ServiceCollectionExtensionsTests.cs @@ -28,6 +28,19 @@ public void AddDurableTaskClient_HostedServiceAdded() services.AddDurableTaskClient(builder => { }); services.Should().ContainSingle( x => x.ServiceType == typeof(IDurableTaskClientProvider) && x.Lifetime == ServiceLifetime.Singleton); + services.Should().ContainSingle( + x => x.ServiceType == typeof(DurableTaskClient) && x.Lifetime == ServiceLifetime.Singleton); + } + + [Fact] + public void AddDurableTaskClient_Named_HostedServiceAdded() + { + ServiceCollection services = new(); + services.AddDurableTaskClient("named", builder => { }); + services.Should().ContainSingle( + x => x.ServiceType == typeof(IDurableTaskClientProvider) && x.Lifetime == ServiceLifetime.Singleton); + services.Should().NotContain( + x => x.ServiceType == typeof(DurableTaskClient) && x.Lifetime == ServiceLifetime.Singleton); } [Fact]