From b3dedff5d1a3bb712060b3e341059953675c5402 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 16 Apr 2025 15:35:18 +0200 Subject: [PATCH 1/2] Dependency Injection for SqlServer --- .../SqlServerServiceCollectionExtensions.cs | 133 ++++++++++++++++++ .../SqlServerDependencyInjectionTests.cs | 67 +++++++++ .../DependencyInjectionTests.cs | 131 +++++++++++++++++ .../VectorDataIntegrationTests.csproj | 2 + 4 files changed, 333 insertions(+) create mode 100644 dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs new file mode 100644 index 000000000000..127c9cba5bf8 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// Extension methods to register instances on an . +/// +public static class SqlServerServiceCollectionExtensions +{ + /// + /// Registers a as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The connection string provider. + /// Options provider to further configure the vector store. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddSqlServerVectorStore(this IServiceCollection services, + Func connectionStringProvider, + Func? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + Verify.NotNull(services); + Verify.NotNull(connectionStringProvider); + + services.Add(new ServiceDescriptor(typeof(IVectorStore), sp => + { + var connectionString = connectionStringProvider(sp); + var options = optionsProvider?.Invoke(sp); + return new SqlServerVectorStore(connectionString, options); + }, lifetime)); + + return services; + } + + /// + /// Registers a keyed as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The key with which to associate the vector store. + /// The connection string provider. + /// Options provider to further configure the vector store. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddKeyedSqlServerVectorStore(this IServiceCollection services, + object serviceKey, + Func connectionStringProvider, + Func? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + Verify.NotNull(services); + Verify.NotNull(serviceKey); + Verify.NotNull(connectionStringProvider); + + services.Add(new ServiceDescriptor(typeof(IVectorStore), serviceKey, (sp, _) => + { + var connectionString = connectionStringProvider(sp); + var options = optionsProvider?.Invoke(sp); + return new SqlServerVectorStore(connectionString, options); + }, lifetime)); + + return services; + } + + /// + /// Registers a as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The name of the collection. + /// The connection string provider. + /// Options provider to further configure the collection. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddSqlServerVectorStoreCollection(this IServiceCollection services, + string collectionName, + Func connectionStringProvider, + Func>? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : notnull + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.NotNull(connectionStringProvider); + + services.Add(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), sp => + { + var connectionString = connectionStringProvider(sp); + var options = optionsProvider?.Invoke(sp); + return new SqlServerVectorStoreRecordCollection(connectionString, collectionName, options); + }, lifetime)); + + return services; + } + + /// + /// Registers a keyed as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The key with which to associate the collection. + /// The name of the collection. + /// The connection string provider. + /// Options provider to further configure the collection. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddKeyedSqlServerVectorStoreCollection(this IServiceCollection services, + object serviceKey, + string collectionName, + Func connectionStringProvider, + Func>? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : notnull + { + Verify.NotNull(services); + Verify.NotNull(serviceKey); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.NotNull(connectionStringProvider); + + services.Add(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), serviceKey, (sp, _) => + { + var connectionString = connectionStringProvider(sp); + var options = optionsProvider?.Invoke(sp); + return new SqlServerVectorStoreRecordCollection(connectionString, collectionName, options); + }, lifetime)); + + return services; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs new file mode 100644 index 000000000000..9f827ec41c20 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.SemanticKernel.Connectors.SqlServer; +using VectorDataSpecificationTests.DependencyInjection; +using VectorDataSpecificationTests.Models; +using Xunit; + +namespace SqlServerIntegrationTests.DependencyInjection; + +public class SqlServerDependencyInjectionTests : DependencyInjectionTests> +{ + protected override void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null) + => configuration.AddInMemoryCollection( + [ + new(CreateConfigKey("SqlServer", serviceKey, "ConnectionString"), "Server=localhost;Database=master;Integrated Security=True;"), + ]); + + protected override void RegisterVectorStore(IServiceCollection services, ServiceLifetime lifetime, object? serviceKey = null) + { + if (serviceKey is null) + { + services.AddSqlServerVectorStore( + sp => sp.GetRequiredService().GetRequiredSection("SqlServer:ConnectionString").Value!, + lifetime: lifetime); + } + else + { + services.AddKeyedSqlServerVectorStore( + serviceKey, + sp => sp.GetRequiredService().GetRequiredSection(CreateConfigKey("SqlServer", serviceKey, "ConnectionString")).Value!, + lifetime: lifetime); + } + } + + protected override void RegisterCollection(IServiceCollection services, ServiceLifetime lifetime, string collectionName = "name", object? serviceKey = null) + { + if (serviceKey is null) + { + services.AddSqlServerVectorStoreCollection>( + collectionName, + sp => sp.GetRequiredService().GetRequiredSection("SqlServer:ConnectionString").Value!, + lifetime: lifetime); + } + else + { + services.AddKeyedSqlServerVectorStoreCollection>( + serviceKey, + collectionName, + sp => sp.GetRequiredService().GetRequiredSection(CreateConfigKey("SqlServer", serviceKey, "ConnectionString")).Value!, + lifetime: lifetime); + } + } + + [Fact] + public void ConnectionStringProviderCantBeNull() + { + HostApplicationBuilder builder = this.CreateHostBuilder(); + + Assert.Throws(() => builder.Services.AddSqlServerVectorStore(connectionStringProvider: null!)); + Assert.Throws(() => builder.Services.AddKeyedSqlServerVectorStore(serviceKey: "notNull", connectionStringProvider: null!)); + Assert.Throws(() => builder.Services.AddSqlServerVectorStoreCollection>(collectionName: "notNull", connectionStringProvider: null!)); + Assert.Throws(() => builder.Services.AddKeyedSqlServerVectorStoreCollection>(serviceKey: "notNull", collectionName: "notNull", connectionStringProvider: null!)); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs new file mode 100644 index 000000000000..38d147fde9c0 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.VectorData; +using Xunit; + +namespace VectorDataSpecificationTests.DependencyInjection; + +public abstract class DependencyInjectionTests + where TKey : notnull + where TRecord : notnull +{ + protected abstract void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null); + + protected abstract void RegisterVectorStore(IServiceCollection services, ServiceLifetime lifetime, object? serviceKey = null); + + protected abstract void RegisterCollection(IServiceCollection services, ServiceLifetime lifetime, string collectionName = "name", object? serviceKey = null); + + [Fact] + public void ServiceCollectionCantBeNull() + { + Assert.Throws(() => this.RegisterVectorStore(null!, ServiceLifetime.Singleton, serviceKey: null)); + Assert.Throws(() => this.RegisterVectorStore(null!, ServiceLifetime.Singleton, serviceKey: "notNull")); + Assert.Throws(() => this.RegisterCollection(null!, ServiceLifetime.Singleton, serviceKey: null)); + Assert.Throws(() => this.RegisterCollection(null!, ServiceLifetime.Singleton, serviceKey: "notNull")); + } + + [Fact] + public void CollectionNameCantBeNullOrEmpty() + { + HostApplicationBuilder builder = this.CreateHostBuilder(); + + Assert.Throws(() => this.RegisterCollection(builder.Services, ServiceLifetime.Singleton, collectionName: null!, serviceKey: null)); + Assert.Throws(() => this.RegisterCollection(builder.Services, ServiceLifetime.Singleton, collectionName: null!, serviceKey: "notNull")); + Assert.Throws(() => this.RegisterCollection(builder.Services, ServiceLifetime.Singleton, collectionName: "", serviceKey: null)); + Assert.Throws(() => this.RegisterCollection(builder.Services, ServiceLifetime.Singleton, collectionName: "", serviceKey: "notNull")); + } + +#pragma warning disable CA1000 // Do not declare static members on generic types + public static IEnumerable LiftetimesAndKeys() +#pragma warning restore CA1000 // Do not declare static members on generic types + { + foreach (ServiceLifetime lifetime in new ServiceLifetime[] { ServiceLifetime.Scoped, ServiceLifetime.Singleton, ServiceLifetime.Transient }) + { + yield return new object?[] { lifetime, null }; + yield return new object?[] { lifetime, "key" }; + yield return new object?[] { lifetime, 8 }; + } + } + + [Theory] + [MemberData(nameof(LiftetimesAndKeys))] + public void CanRegisterVectorStore(ServiceLifetime lifetime, object? serviceKey) + { + HostApplicationBuilder builder = this.CreateHostBuilder(serviceKey); + + this.RegisterVectorStore(builder.Services, lifetime, serviceKey); + + using IHost host = builder.Build(); + Verify(host, lifetime, serviceKey); + } + + [Theory] + [MemberData(nameof(LiftetimesAndKeys))] + public void CanRegisterCollections(ServiceLifetime lifetime, object? serviceKey) + { + HostApplicationBuilder builder = this.CreateHostBuilder(serviceKey); + + this.RegisterCollection(builder.Services, lifetime, serviceKey: serviceKey); + + using IHost host = builder.Build(); + Verify>(host, lifetime, serviceKey); + } + + protected HostApplicationBuilder CreateHostBuilder(object? serviceKey = null) + { + HostApplicationBuilder builder = Host.CreateEmptyApplicationBuilder(settings: null); + + this.PopulateConfiguration(builder.Configuration, serviceKey); + + return builder; + } + + private static void Verify(IHost host, ServiceLifetime lifetime, object? serviceKey) + where TService : class + { + TService? serviceFromFirstScope, serviceFromSecondScope, secondServiceFromSecondScope; + + using (IServiceScope scope1 = host.Services.CreateScope()) + { + serviceFromFirstScope = Resolve(scope1.ServiceProvider, serviceKey); + } + + using (IServiceScope scope2 = host.Services.CreateScope()) + { + serviceFromSecondScope = Resolve(scope2.ServiceProvider, serviceKey); + + secondServiceFromSecondScope = Resolve(scope2.ServiceProvider, serviceKey); + } + + Assert.NotNull(serviceFromFirstScope); + Assert.NotNull(serviceFromSecondScope); + Assert.NotNull(secondServiceFromSecondScope); + + switch (lifetime) + { + case ServiceLifetime.Singleton: + Assert.Same(serviceFromFirstScope, serviceFromSecondScope); + Assert.Same(serviceFromSecondScope, secondServiceFromSecondScope); + break; + case ServiceLifetime.Scoped: + Assert.NotSame(serviceFromFirstScope, serviceFromSecondScope); + Assert.Same(serviceFromSecondScope, secondServiceFromSecondScope); + break; + case ServiceLifetime.Transient: + Assert.NotSame(serviceFromFirstScope, serviceFromSecondScope); + Assert.NotSame(serviceFromSecondScope, secondServiceFromSecondScope); + break; + } + } + + protected static string CreateConfigKey(string prefix, object? serviceKey, string suffix) + => serviceKey is null ? $"{prefix}:{suffix}" : $"{prefix}:{serviceKey}:{suffix}"; + + private static TService Resolve(IServiceProvider serviceProvider, object? serviceKey = null) where TService : notnull + => serviceKey is null + ? serviceProvider.GetRequiredService() + : serviceProvider.GetRequiredKeyedService(serviceKey); +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj index 5b14dc1e41c1..3d146ba7cfa1 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj @@ -12,6 +12,8 @@ + + From 1b6f134cbfec7b0a5af71282568871353851f63f Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Thu, 24 Apr 2025 17:04:49 +0200 Subject: [PATCH 2/2] apply the changes we have agreed to: - register concrete, provider-specific types, not just the abstractions - have overloads that accept raw connection string (and just options, not an option provider) --- .../SqlServerServiceCollectionExtensions.cs | 87 ++++++++++++++- .../SqlServerDependencyInjectionTests.cs | 59 +++++++++- .../DependencyInjectionTests.cs | 102 +++++++++++++++++- 3 files changed, 240 insertions(+), 8 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs index 127c9cba5bf8..c8f69045082e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs @@ -2,6 +2,7 @@ using System; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.SqlServer; @@ -27,13 +28,19 @@ public static IServiceCollection AddSqlServerVectorStore(this IServiceCollection Verify.NotNull(services); Verify.NotNull(connectionStringProvider); - services.Add(new ServiceDescriptor(typeof(IVectorStore), sp => + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStore), sp => { var connectionString = connectionStringProvider(sp); var options = optionsProvider?.Invoke(sp); return new SqlServerVectorStore(connectionString, options); }, lifetime)); + // We try to add the SqlServerVectorStore as an IVectorStore, + // but if it already exists, we don't override it. + // Sample scenario: one app using two different vector stores. + services.TryAdd(new ServiceDescriptor(typeof(IVectorStore), + static sp => sp.GetRequiredService(), lifetime)); + return services; } @@ -56,13 +63,16 @@ public static IServiceCollection AddKeyedSqlServerVectorStore(this IServiceColle Verify.NotNull(serviceKey); Verify.NotNull(connectionStringProvider); - services.Add(new ServiceDescriptor(typeof(IVectorStore), serviceKey, (sp, _) => + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStore), serviceKey, (sp, _) => { var connectionString = connectionStringProvider(sp); var options = optionsProvider?.Invoke(sp); return new SqlServerVectorStore(connectionString, options); }, lifetime)); + services.TryAdd(new ServiceDescriptor(typeof(IVectorStore), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService(key), lifetime)); + return services; } @@ -87,13 +97,16 @@ public static IServiceCollection AddSqlServerVectorStoreCollection), sp => + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection), sp => { var connectionString = connectionStringProvider(sp); var options = optionsProvider?.Invoke(sp); return new SqlServerVectorStoreRecordCollection(connectionString, collectionName, options); }, lifetime)); + services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), + static sp => sp.GetRequiredService>(), lifetime)); + return services; } @@ -121,13 +134,79 @@ public static IServiceCollection AddKeyedSqlServerVectorStoreCollection), serviceKey, (sp, _) => + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection), serviceKey, (sp, _) => { var connectionString = connectionStringProvider(sp); var options = optionsProvider?.Invoke(sp); return new SqlServerVectorStoreRecordCollection(connectionString, collectionName, options); }, lifetime)); + services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService>(key), lifetime)); + + return services; + } + + /// + /// Registers a as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The name of the collection. + /// The connection string. + /// Options to further configure the collection. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddSqlServerVectorStoreCollection(this IServiceCollection services, + string collectionName, + string connectionString, + SqlServerVectorStoreRecordCollectionOptions? options = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : notnull + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.NotNullOrWhiteSpace(connectionString); + + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection), + sp => new SqlServerVectorStoreRecordCollection(connectionString, collectionName, options), lifetime)); + + services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), + static sp => sp.GetRequiredService>(), lifetime)); + + return services; + } + + /// + /// Registers a keyed as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The key with which to associate the collection. + /// The name of the collection. + /// The connection string. + /// Options to further configure the collection. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddKeyedSqlServerVectorStoreCollection(this IServiceCollection services, + object serviceKey, + string collectionName, + string connectionString, + SqlServerVectorStoreRecordCollectionOptions? options = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : notnull + { + Verify.NotNull(services); + Verify.NotNull(serviceKey); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.NotNullOrWhiteSpace(connectionString); + + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection), serviceKey, + (sp, _) => new SqlServerVectorStoreRecordCollection(connectionString, collectionName, options), lifetime)); + + services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService>(key), lifetime)); + return services; } } diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs index 9f827ec41c20..399a5d134634 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs @@ -10,12 +10,15 @@ namespace SqlServerIntegrationTests.DependencyInjection; -public class SqlServerDependencyInjectionTests : DependencyInjectionTests> +public class SqlServerDependencyInjectionTests + : DependencyInjectionTests>, string, SimpleRecord> { + protected const string ConnectionString = "Server=localhost;Database=master;Integrated Security=True;"; + protected override void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null) => configuration.AddInMemoryCollection( [ - new(CreateConfigKey("SqlServer", serviceKey, "ConnectionString"), "Server=localhost;Database=master;Integrated Security=True;"), + new(CreateConfigKey("SqlServer", serviceKey, "ConnectionString"), ConnectionString), ]); protected override void RegisterVectorStore(IServiceCollection services, ServiceLifetime lifetime, object? serviceKey = null) @@ -64,4 +67,56 @@ public void ConnectionStringProviderCantBeNull() Assert.Throws(() => builder.Services.AddSqlServerVectorStoreCollection>(collectionName: "notNull", connectionStringProvider: null!)); Assert.Throws(() => builder.Services.AddKeyedSqlServerVectorStoreCollection>(serviceKey: "notNull", collectionName: "notNull", connectionStringProvider: null!)); } + + [Fact] + public void ConnectionStringCantBeNullOrEmpty() + { + HostApplicationBuilder builder = this.CreateHostBuilder(); + + Assert.Throws(() => builder.Services.AddSqlServerVectorStoreCollection>( + collectionName: "notNull", connectionString: null!)); + Assert.Throws(() => builder.Services.AddSqlServerVectorStoreCollection>( + collectionName: "notNull", connectionString: "")); + Assert.Throws(() => builder.Services.AddKeyedSqlServerVectorStoreCollection>( + serviceKey: "notNull", collectionName: "notNull", connectionString: null!)); + Assert.Throws(() => builder.Services.AddKeyedSqlServerVectorStoreCollection>( + serviceKey: "notNull", collectionName: "notNull", connectionString: "")); + } +} + +public class SqlServerDependencyInjectionTests_ConnectionStrings : SqlServerDependencyInjectionTests +{ + protected override void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null) + { + // do nothing, as in this scenario config should not be used at all + } + + protected override void RegisterCollection(IServiceCollection services, ServiceLifetime lifetime, string collectionName = "name", object? serviceKey = null) + { + if (serviceKey is null) + { + services.AddSqlServerVectorStoreCollection>( + collectionName, + connectionString: ConnectionString, + lifetime: lifetime); + } + else + { + services.AddKeyedSqlServerVectorStoreCollection>( + serviceKey, + collectionName, + connectionString: ConnectionString, + lifetime: lifetime); + } + } + + public override void CanRegisterVectorStore(ServiceLifetime lifetime, object? serviceKey) + { + // do nothing, we don't provide a test for this scenario with raw connection string + } + + public override void CanRegisterConcreteTypeVectorStoreAfterSomeAbstractionHasBeenRegistered(ServiceLifetime lifetime, object? serviceKey) + { + // do nothing, we don't provide a test for this scenario with raw connection string + } } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs index 38d147fde9c0..6f6563fd731a 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Linq.Expressions; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; @@ -8,7 +9,9 @@ namespace VectorDataSpecificationTests.DependencyInjection; -public abstract class DependencyInjectionTests +public abstract class DependencyInjectionTests + where TVectorStore : class, IVectorStore + where TCollection : class, IVectorStoreRecordCollection where TKey : notnull where TRecord : notnull { @@ -52,13 +55,16 @@ public void CollectionNameCantBeNullOrEmpty() [Theory] [MemberData(nameof(LiftetimesAndKeys))] - public void CanRegisterVectorStore(ServiceLifetime lifetime, object? serviceKey) + public virtual void CanRegisterVectorStore(ServiceLifetime lifetime, object? serviceKey) { HostApplicationBuilder builder = this.CreateHostBuilder(serviceKey); this.RegisterVectorStore(builder.Services, lifetime, serviceKey); using IHost host = builder.Build(); + // let's ensure that concrete types are registered + Verify(host, lifetime, serviceKey); + // and the abstraction too Verify(host, lifetime, serviceKey); } @@ -71,9 +77,58 @@ public void CanRegisterCollections(ServiceLifetime lifetime, object? serviceKey) this.RegisterCollection(builder.Services, lifetime, serviceKey: serviceKey); using IHost host = builder.Build(); + // let's ensure that concrete types are registered + Verify(host, lifetime, serviceKey); + // and the abstraction too Verify>(host, lifetime, serviceKey); } + [Theory] + [MemberData(nameof(LiftetimesAndKeys))] + public virtual void CanRegisterConcreteTypeVectorStoreAfterSomeAbstractionHasBeenRegistered(ServiceLifetime lifetime, object? serviceKey) + { + HostApplicationBuilder builder = this.CreateHostBuilder(serviceKey); + + // Users may be willing to register more than one IVectorStore implementation. + if (serviceKey is null) + { + builder.Services.Add(new ServiceDescriptor(typeof(IVectorStore), sp => new FakeVectorStore(), lifetime)); + } + else + { + builder.Services.Add(new ServiceDescriptor(typeof(IVectorStore), serviceKey, (sp, key) => new FakeVectorStore(), lifetime)); + } + + this.RegisterVectorStore(builder.Services, lifetime, serviceKey); + + using IHost host = builder.Build(); + // let's ensure that concrete types are registered + Verify(host, lifetime, serviceKey); + } + + [Theory] + [MemberData(nameof(LiftetimesAndKeys))] + public void CanRegisterConcreteTypeCollectionsAfterSomeAbstractionHasBeenRegistered(ServiceLifetime lifetime, object? serviceKey) + { + HostApplicationBuilder builder = this.CreateHostBuilder(serviceKey); + + // Users may be willing to register more than one IVectorStoreRecordCollection implementation. + if (serviceKey is null) + { + builder.Services.Add(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), sp => new FakeVectorStoreRecordCollection(), lifetime)); + } + else + { + builder.Services.Add(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), serviceKey, (sp, key) => new FakeVectorStoreRecordCollection(), lifetime)); + } + + this.RegisterCollection(builder.Services, lifetime, serviceKey: serviceKey); + + using IHost host = builder.Build(); + // let's ensure that concrete types are registered + Verify(host, lifetime, serviceKey); + } + protected HostApplicationBuilder CreateHostBuilder(object? serviceKey = null) { HostApplicationBuilder builder = Host.CreateEmptyApplicationBuilder(settings: null); @@ -128,4 +183,47 @@ private static TService Resolve(IServiceProvider serviceProvider, obje => serviceKey is null ? serviceProvider.GetRequiredService() : serviceProvider.GetRequiredKeyedService(serviceKey); + + private sealed class FakeVectorStore : IVectorStore + { + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey2 : notnull + where TRecord2 : notnull + => throw new NotImplementedException(); + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + public object? GetService(Type serviceType, object? serviceKey = null) + => throw new NotImplementedException(); + } + + private sealed class FakeVectorStoreRecordCollection : IVectorStoreRecordCollection + { + public string CollectionName => throw new NotImplementedException(); + + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public IAsyncEnumerable GetAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public IAsyncEnumerable GetAsync(Expression> filter, int top, GetFilteredRecordOptions? options = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public object? GetService(Type serviceType, object? serviceKey = null) => throw new NotImplementedException(); + + public Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + } }