diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs new file mode 100644 index 000000000000..c8f69045082e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerServiceCollectionExtensions.cs @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.SqlServer; + +/// +/// Extension methods to register instances on an . +/// +public static class SqlServerServiceCollectionExtensions +{ + /// + /// Registers a as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The connection string provider. + /// Options provider to further configure the vector store. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddSqlServerVectorStore(this IServiceCollection services, + Func connectionStringProvider, + Func? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + Verify.NotNull(services); + Verify.NotNull(connectionStringProvider); + + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStore), sp => + { + var connectionString = connectionStringProvider(sp); + var options = optionsProvider?.Invoke(sp); + return new SqlServerVectorStore(connectionString, options); + }, lifetime)); + + // We try to add the SqlServerVectorStore as an IVectorStore, + // but if it already exists, we don't override it. + // Sample scenario: one app using two different vector stores. + services.TryAdd(new ServiceDescriptor(typeof(IVectorStore), + static sp => sp.GetRequiredService(), lifetime)); + + return services; + } + + /// + /// Registers a keyed as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The key with which to associate the vector store. + /// The connection string provider. + /// Options provider to further configure the vector store. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddKeyedSqlServerVectorStore(this IServiceCollection services, + object serviceKey, + Func connectionStringProvider, + Func? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + Verify.NotNull(services); + Verify.NotNull(serviceKey); + Verify.NotNull(connectionStringProvider); + + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStore), serviceKey, (sp, _) => + { + var connectionString = connectionStringProvider(sp); + var options = optionsProvider?.Invoke(sp); + return new SqlServerVectorStore(connectionString, options); + }, lifetime)); + + services.TryAdd(new ServiceDescriptor(typeof(IVectorStore), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService(key), lifetime)); + + return services; + } + + /// + /// Registers a as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The name of the collection. + /// The connection string provider. + /// Options provider to further configure the collection. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddSqlServerVectorStoreCollection(this IServiceCollection services, + string collectionName, + Func connectionStringProvider, + Func>? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : notnull + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.NotNull(connectionStringProvider); + + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection), sp => + { + var connectionString = connectionStringProvider(sp); + var options = optionsProvider?.Invoke(sp); + return new SqlServerVectorStoreRecordCollection(connectionString, collectionName, options); + }, lifetime)); + + services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), + static sp => sp.GetRequiredService>(), lifetime)); + + return services; + } + + /// + /// Registers a keyed as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The key with which to associate the collection. + /// The name of the collection. + /// The connection string provider. + /// Options provider to further configure the collection. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddKeyedSqlServerVectorStoreCollection(this IServiceCollection services, + object serviceKey, + string collectionName, + Func connectionStringProvider, + Func>? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : notnull + { + Verify.NotNull(services); + Verify.NotNull(serviceKey); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.NotNull(connectionStringProvider); + + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection), serviceKey, (sp, _) => + { + var connectionString = connectionStringProvider(sp); + var options = optionsProvider?.Invoke(sp); + return new SqlServerVectorStoreRecordCollection(connectionString, collectionName, options); + }, lifetime)); + + services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService>(key), lifetime)); + + return services; + } + + /// + /// Registers a as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The name of the collection. + /// The connection string. + /// Options to further configure the collection. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddSqlServerVectorStoreCollection(this IServiceCollection services, + string collectionName, + string connectionString, + SqlServerVectorStoreRecordCollectionOptions? options = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : notnull + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.NotNullOrWhiteSpace(connectionString); + + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection), + sp => new SqlServerVectorStoreRecordCollection(connectionString, collectionName, options), lifetime)); + + services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), + static sp => sp.GetRequiredService>(), lifetime)); + + return services; + } + + /// + /// Registers a keyed as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The key with which to associate the collection. + /// The name of the collection. + /// The connection string. + /// Options to further configure the collection. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddKeyedSqlServerVectorStoreCollection(this IServiceCollection services, + object serviceKey, + string collectionName, + string connectionString, + SqlServerVectorStoreRecordCollectionOptions? options = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : notnull + { + Verify.NotNull(services); + Verify.NotNull(serviceKey); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.NotNullOrWhiteSpace(connectionString); + + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection), serviceKey, + (sp, _) => new SqlServerVectorStoreRecordCollection(connectionString, collectionName, options), lifetime)); + + services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService>(key), lifetime)); + + return services; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs new file mode 100644 index 000000000000..399a5d134634 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/DependencyInjection/SqlServerDependencyInjectionTests.cs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.SemanticKernel.Connectors.SqlServer; +using VectorDataSpecificationTests.DependencyInjection; +using VectorDataSpecificationTests.Models; +using Xunit; + +namespace SqlServerIntegrationTests.DependencyInjection; + +public class SqlServerDependencyInjectionTests + : DependencyInjectionTests>, string, SimpleRecord> +{ + protected const string ConnectionString = "Server=localhost;Database=master;Integrated Security=True;"; + + protected override void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null) + => configuration.AddInMemoryCollection( + [ + new(CreateConfigKey("SqlServer", serviceKey, "ConnectionString"), ConnectionString), + ]); + + protected override void RegisterVectorStore(IServiceCollection services, ServiceLifetime lifetime, object? serviceKey = null) + { + if (serviceKey is null) + { + services.AddSqlServerVectorStore( + sp => sp.GetRequiredService().GetRequiredSection("SqlServer:ConnectionString").Value!, + lifetime: lifetime); + } + else + { + services.AddKeyedSqlServerVectorStore( + serviceKey, + sp => sp.GetRequiredService().GetRequiredSection(CreateConfigKey("SqlServer", serviceKey, "ConnectionString")).Value!, + lifetime: lifetime); + } + } + + protected override void RegisterCollection(IServiceCollection services, ServiceLifetime lifetime, string collectionName = "name", object? serviceKey = null) + { + if (serviceKey is null) + { + services.AddSqlServerVectorStoreCollection>( + collectionName, + sp => sp.GetRequiredService().GetRequiredSection("SqlServer:ConnectionString").Value!, + lifetime: lifetime); + } + else + { + services.AddKeyedSqlServerVectorStoreCollection>( + serviceKey, + collectionName, + sp => sp.GetRequiredService().GetRequiredSection(CreateConfigKey("SqlServer", serviceKey, "ConnectionString")).Value!, + lifetime: lifetime); + } + } + + [Fact] + public void ConnectionStringProviderCantBeNull() + { + HostApplicationBuilder builder = this.CreateHostBuilder(); + + Assert.Throws(() => builder.Services.AddSqlServerVectorStore(connectionStringProvider: null!)); + Assert.Throws(() => builder.Services.AddKeyedSqlServerVectorStore(serviceKey: "notNull", connectionStringProvider: null!)); + Assert.Throws(() => builder.Services.AddSqlServerVectorStoreCollection>(collectionName: "notNull", connectionStringProvider: null!)); + Assert.Throws(() => builder.Services.AddKeyedSqlServerVectorStoreCollection>(serviceKey: "notNull", collectionName: "notNull", connectionStringProvider: null!)); + } + + [Fact] + public void ConnectionStringCantBeNullOrEmpty() + { + HostApplicationBuilder builder = this.CreateHostBuilder(); + + Assert.Throws(() => builder.Services.AddSqlServerVectorStoreCollection>( + collectionName: "notNull", connectionString: null!)); + Assert.Throws(() => builder.Services.AddSqlServerVectorStoreCollection>( + collectionName: "notNull", connectionString: "")); + Assert.Throws(() => builder.Services.AddKeyedSqlServerVectorStoreCollection>( + serviceKey: "notNull", collectionName: "notNull", connectionString: null!)); + Assert.Throws(() => builder.Services.AddKeyedSqlServerVectorStoreCollection>( + serviceKey: "notNull", collectionName: "notNull", connectionString: "")); + } +} + +public class SqlServerDependencyInjectionTests_ConnectionStrings : SqlServerDependencyInjectionTests +{ + protected override void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null) + { + // do nothing, as in this scenario config should not be used at all + } + + protected override void RegisterCollection(IServiceCollection services, ServiceLifetime lifetime, string collectionName = "name", object? serviceKey = null) + { + if (serviceKey is null) + { + services.AddSqlServerVectorStoreCollection>( + collectionName, + connectionString: ConnectionString, + lifetime: lifetime); + } + else + { + services.AddKeyedSqlServerVectorStoreCollection>( + serviceKey, + collectionName, + connectionString: ConnectionString, + lifetime: lifetime); + } + } + + public override void CanRegisterVectorStore(ServiceLifetime lifetime, object? serviceKey) + { + // do nothing, we don't provide a test for this scenario with raw connection string + } + + public override void CanRegisterConcreteTypeVectorStoreAfterSomeAbstractionHasBeenRegistered(ServiceLifetime lifetime, object? serviceKey) + { + // do nothing, we don't provide a test for this scenario with raw connection string + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs new file mode 100644 index 000000000000..6933fc0389c0 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/DependencyInjection/DependencyInjectionTests.cs @@ -0,0 +1,231 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq.Expressions; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.VectorData; +using Xunit; + +namespace VectorDataSpecificationTests.DependencyInjection; + +public abstract class DependencyInjectionTests + where TVectorStore : class, IVectorStore + where TCollection : class, IVectorStoreRecordCollection + where TKey : notnull + where TRecord : notnull +{ + protected abstract void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null); + + protected abstract void RegisterVectorStore(IServiceCollection services, ServiceLifetime lifetime, object? serviceKey = null); + + protected abstract void RegisterCollection(IServiceCollection services, ServiceLifetime lifetime, string collectionName = "name", object? serviceKey = null); + + [Fact] + public void ServiceCollectionCantBeNull() + { + Assert.Throws(() => this.RegisterVectorStore(null!, ServiceLifetime.Singleton, serviceKey: null)); + Assert.Throws(() => this.RegisterVectorStore(null!, ServiceLifetime.Singleton, serviceKey: "notNull")); + Assert.Throws(() => this.RegisterCollection(null!, ServiceLifetime.Singleton, serviceKey: null)); + Assert.Throws(() => this.RegisterCollection(null!, ServiceLifetime.Singleton, serviceKey: "notNull")); + } + + [Fact] + public void CollectionNameCantBeNullOrEmpty() + { + HostApplicationBuilder builder = this.CreateHostBuilder(); + + Assert.Throws(() => this.RegisterCollection(builder.Services, ServiceLifetime.Singleton, collectionName: null!, serviceKey: null)); + Assert.Throws(() => this.RegisterCollection(builder.Services, ServiceLifetime.Singleton, collectionName: null!, serviceKey: "notNull")); + Assert.Throws(() => this.RegisterCollection(builder.Services, ServiceLifetime.Singleton, collectionName: "", serviceKey: null)); + Assert.Throws(() => this.RegisterCollection(builder.Services, ServiceLifetime.Singleton, collectionName: "", serviceKey: "notNull")); + } + +#pragma warning disable CA1000 // Do not declare static members on generic types + public static IEnumerable LiftetimesAndKeys() +#pragma warning restore CA1000 // Do not declare static members on generic types + { + foreach (ServiceLifetime lifetime in new ServiceLifetime[] { ServiceLifetime.Scoped, ServiceLifetime.Singleton, ServiceLifetime.Transient }) + { + yield return new object?[] { lifetime, null }; + yield return new object?[] { lifetime, "key" }; + yield return new object?[] { lifetime, 8 }; + } + } + + [Theory] + [MemberData(nameof(LiftetimesAndKeys))] + public virtual void CanRegisterVectorStore(ServiceLifetime lifetime, object? serviceKey) + { + HostApplicationBuilder builder = this.CreateHostBuilder(serviceKey); + + this.RegisterVectorStore(builder.Services, lifetime, serviceKey); + + using IHost host = builder.Build(); + // let's ensure that concrete types are registered + Verify(host, lifetime, serviceKey); + // and the abstraction too + Verify(host, lifetime, serviceKey); + } + + [Theory] + [MemberData(nameof(LiftetimesAndKeys))] + public void CanRegisterCollections(ServiceLifetime lifetime, object? serviceKey) + { + HostApplicationBuilder builder = this.CreateHostBuilder(serviceKey); + + this.RegisterCollection(builder.Services, lifetime, serviceKey: serviceKey); + + using IHost host = builder.Build(); + // let's ensure that concrete types are registered + Verify(host, lifetime, serviceKey); + // and the abstraction too + Verify>(host, lifetime, serviceKey); + } + + [Theory] + [MemberData(nameof(LiftetimesAndKeys))] + public virtual void CanRegisterConcreteTypeVectorStoreAfterSomeAbstractionHasBeenRegistered(ServiceLifetime lifetime, object? serviceKey) + { + HostApplicationBuilder builder = this.CreateHostBuilder(serviceKey); + + // Users may be willing to register more than one IVectorStore implementation. + if (serviceKey is null) + { + builder.Services.Add(new ServiceDescriptor(typeof(IVectorStore), sp => new FakeVectorStore(), lifetime)); + } + else + { + builder.Services.Add(new ServiceDescriptor(typeof(IVectorStore), serviceKey, (sp, key) => new FakeVectorStore(), lifetime)); + } + + this.RegisterVectorStore(builder.Services, lifetime, serviceKey); + + using IHost host = builder.Build(); + // let's ensure that concrete types are registered + Verify(host, lifetime, serviceKey); + } + + [Theory] + [MemberData(nameof(LiftetimesAndKeys))] + public void CanRegisterConcreteTypeCollectionsAfterSomeAbstractionHasBeenRegistered(ServiceLifetime lifetime, object? serviceKey) + { + HostApplicationBuilder builder = this.CreateHostBuilder(serviceKey); + + // Users may be willing to register more than one IVectorStoreRecordCollection implementation. + if (serviceKey is null) + { + builder.Services.Add(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), sp => new FakeVectorStoreRecordCollection(), lifetime)); + } + else + { + builder.Services.Add(new ServiceDescriptor(typeof(IVectorStoreRecordCollection), serviceKey, (sp, key) => new FakeVectorStoreRecordCollection(), lifetime)); + } + + this.RegisterCollection(builder.Services, lifetime, serviceKey: serviceKey); + + using IHost host = builder.Build(); + // let's ensure that concrete types are registered + Verify(host, lifetime, serviceKey); + } + + protected HostApplicationBuilder CreateHostBuilder(object? serviceKey = null) + { + HostApplicationBuilder builder = Host.CreateEmptyApplicationBuilder(settings: null); + + this.PopulateConfiguration(builder.Configuration, serviceKey); + + return builder; + } + + private static void Verify(IHost host, ServiceLifetime lifetime, object? serviceKey) + where TService : class + { + TService? serviceFromFirstScope, serviceFromSecondScope, secondServiceFromSecondScope; + + using (IServiceScope scope1 = host.Services.CreateScope()) + { + serviceFromFirstScope = Resolve(scope1.ServiceProvider, serviceKey); + } + + using (IServiceScope scope2 = host.Services.CreateScope()) + { + serviceFromSecondScope = Resolve(scope2.ServiceProvider, serviceKey); + + secondServiceFromSecondScope = Resolve(scope2.ServiceProvider, serviceKey); + } + + Assert.NotNull(serviceFromFirstScope); + Assert.NotNull(serviceFromSecondScope); + Assert.NotNull(secondServiceFromSecondScope); + + switch (lifetime) + { + case ServiceLifetime.Singleton: + Assert.Same(serviceFromFirstScope, serviceFromSecondScope); + Assert.Same(serviceFromSecondScope, secondServiceFromSecondScope); + break; + case ServiceLifetime.Scoped: + Assert.NotSame(serviceFromFirstScope, serviceFromSecondScope); + Assert.Same(serviceFromSecondScope, secondServiceFromSecondScope); + break; + case ServiceLifetime.Transient: + Assert.NotSame(serviceFromFirstScope, serviceFromSecondScope); + Assert.NotSame(serviceFromSecondScope, secondServiceFromSecondScope); + break; + } + } + + protected static string CreateConfigKey(string prefix, object? serviceKey, string suffix) + => serviceKey is null ? $"{prefix}:{suffix}" : $"{prefix}:{serviceKey}:{suffix}"; + + private static TService Resolve(IServiceProvider serviceProvider, object? serviceKey = null) where TService : notnull + => serviceKey is null + ? serviceProvider.GetRequiredService() + : serviceProvider.GetRequiredKeyedService(serviceKey); + + private sealed class FakeVectorStore : IVectorStore + { + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey2 : notnull + where TRecord2 : notnull + => throw new NotImplementedException(); + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + public object? GetService(Type serviceType, object? serviceKey = null) + => throw new NotImplementedException(); + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + } + + private sealed class FakeVectorStoreRecordCollection : IVectorStoreRecordCollection + { + public string Name => throw new NotImplementedException(); + + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public IAsyncEnumerable GetAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public IAsyncEnumerable GetAsync(Expression> filter, int top, GetFilteredRecordOptions? options = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public object? GetService(Type serviceType, object? serviceKey = null) => throw new NotImplementedException(); + + public Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public Task> UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + + public IAsyncEnumerable> VectorizedSearchAsync(TVector vector, int top, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj index 5b14dc1e41c1..3d146ba7cfa1 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj @@ -12,6 +12,8 @@ + +