-
Notifications
You must be signed in to change notification settings - Fork 4.4k
.Net MEVD: Dependency Injection for SqlServer connector #11594
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| // Copyright (c) Microsoft. All rights reserved. | ||
|
|
||
| using System; | ||
| using Microsoft.Extensions.DependencyInjection; | ||
| using Microsoft.Extensions.DependencyInjection.Extensions; | ||
| using Microsoft.Extensions.VectorData; | ||
|
|
||
| namespace Microsoft.SemanticKernel.Connectors.SqlServer; | ||
|
|
||
| /// <summary> | ||
| /// Extension methods to register <see cref="SqlServerVectorStore"/> instances on an <see cref="IServiceCollection"/>. | ||
| /// </summary> | ||
| public static class SqlServerServiceCollectionExtensions | ||
| { | ||
| /// <summary> | ||
| /// Registers a <see cref="SqlServerVectorStore"/> as <see cref="IVectorStore"/>, with the specified connection string and service lifetime. | ||
| /// </summary> | ||
| /// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStore"/> on.</param> | ||
| /// <param name="connectionStringProvider">The connection string provider.</param> | ||
| /// <param name="optionsProvider">Options provider to further configure the vector store.</param> | ||
| /// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param> | ||
| /// <returns>The service collection.</returns> | ||
| public static IServiceCollection AddSqlServerVectorStore(this IServiceCollection services, | ||
| Func<IServiceProvider, string> connectionStringProvider, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure, but... I think the basic expectation any user would have is to be able to pass the connection string directly (without a function). I get the logic of accepting a Func (to construct a connection string based on config that's in DI, e.g. tenant ID as database name), but it's definitely a bit jarring/unexpected for the minimal demo to look like this: FWIW IIRC EF only accepts a string directly; AddNpgsqlDataSource is the same (though there's an issue for adding an overload that accepts an IServiceProvider Func. I'm not sure what they did in Aspire - might be worth checking. The easy answer here is obviously to just have two overloads - one with a string, one with a Func. The only concern here is that we risk going into DI overload explosion: we have:
So I'm not sure where this will go - let's maybe discuss this in parking lot soon.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @roji, is you concern around having many overloads born from complexity of use, or from maintaining so many overloads?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @roji @westey-m I've moved the conversation to the issue: #10549 (comment) PTAL |
||
| Func<IServiceProvider, SqlServerVectorStoreOptions>? optionsProvider = null, | ||
| ServiceLifetime lifetime = ServiceLifetime.Singleton) | ||
| { | ||
| Verify.NotNull(services); | ||
| Verify.NotNull(connectionStringProvider); | ||
|
|
||
| services.Add(new ServiceDescriptor(typeof(SqlServerVectorStore), sp => | ||
| { | ||
| var connectionString = connectionStringProvider(sp); | ||
| var options = optionsProvider?.Invoke(sp); | ||
| return new SqlServerVectorStore(connectionString, options); | ||
| }, lifetime)); | ||
|
|
||
| // We try to add the SqlServerVectorStore as an IVectorStore, | ||
| // but if it already exists, we don't override it. | ||
| // Sample scenario: one app using two different vector stores. | ||
| services.TryAdd(new ServiceDescriptor(typeof(IVectorStore), | ||
| static sp => sp.GetRequiredService<SqlServerVectorStore>(), lifetime)); | ||
|
|
||
| return services; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Registers a keyed <see cref="SqlServerVectorStore"/> as <see cref="IVectorStore"/>, with the specified connection string and service lifetime. | ||
| /// </summary> | ||
| /// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStore"/> on.</param> | ||
| /// <param name="serviceKey">The key with which to associate the vector store.</param> | ||
| /// <param name="connectionStringProvider">The connection string provider.</param> | ||
| /// <param name="optionsProvider">Options provider to further configure the vector store.</param> | ||
| /// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param> | ||
| /// <returns>The service collection.</returns> | ||
| public static IServiceCollection AddKeyedSqlServerVectorStore(this IServiceCollection services, | ||
| object serviceKey, | ||
| Func<IServiceProvider, string> connectionStringProvider, | ||
| Func<IServiceProvider, SqlServerVectorStoreOptions>? optionsProvider = null, | ||
| ServiceLifetime lifetime = ServiceLifetime.Singleton) | ||
| { | ||
| Verify.NotNull(services); | ||
| Verify.NotNull(serviceKey); | ||
| Verify.NotNull(connectionStringProvider); | ||
|
|
||
| services.Add(new ServiceDescriptor(typeof(SqlServerVectorStore), serviceKey, (sp, _) => | ||
| { | ||
| var connectionString = connectionStringProvider(sp); | ||
| var options = optionsProvider?.Invoke(sp); | ||
| return new SqlServerVectorStore(connectionString, options); | ||
| }, lifetime)); | ||
|
|
||
| services.TryAdd(new ServiceDescriptor(typeof(IVectorStore), serviceKey, | ||
| static (sp, key) => sp.GetRequiredKeyedService<SqlServerVectorStore>(key), lifetime)); | ||
|
|
||
| return services; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Registers a <see cref="SqlServerVectorStoreRecordCollection{TKey, TRecord}"/> as <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>, with the specified connection string and service lifetime. | ||
| /// </summary> | ||
| /// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> on.</param> | ||
| /// <param name="collectionName">The name of the collection.</param> | ||
| /// <param name="connectionStringProvider">The connection string provider.</param> | ||
| /// <param name="optionsProvider">Options provider to further configure the collection.</param> | ||
| /// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param> | ||
| /// <returns>The service collection.</returns> | ||
| public static IServiceCollection AddSqlServerVectorStoreCollection<TKey, TRecord>(this IServiceCollection services, | ||
| string collectionName, | ||
| Func<IServiceProvider, string> connectionStringProvider, | ||
| Func<IServiceProvider, SqlServerVectorStoreRecordCollectionOptions<TRecord>>? optionsProvider = null, | ||
| ServiceLifetime lifetime = ServiceLifetime.Singleton) | ||
| where TKey : notnull | ||
| where TRecord : notnull | ||
| { | ||
| Verify.NotNull(services); | ||
| Verify.NotNullOrWhiteSpace(collectionName); | ||
| Verify.NotNull(connectionStringProvider); | ||
|
|
||
| services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection<TKey, TRecord>), sp => | ||
| { | ||
| var connectionString = connectionStringProvider(sp); | ||
| var options = optionsProvider?.Invoke(sp); | ||
| return new SqlServerVectorStoreRecordCollection<TKey, TRecord>(connectionString, collectionName, options); | ||
| }, lifetime)); | ||
|
|
||
| services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection<TKey, TRecord>), | ||
| static sp => sp.GetRequiredService<SqlServerVectorStoreRecordCollection<TKey, TRecord>>(), lifetime)); | ||
|
|
||
| return services; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Registers a keyed <see cref="SqlServerVectorStoreRecordCollection{TKey, TRecord}"/> as <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>, with the specified connection string and service lifetime. | ||
| /// </summary> | ||
| /// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> on.</param> | ||
| /// <param name="serviceKey">The key with which to associate the collection.</param> | ||
| /// <param name="collectionName">The name of the collection.</param> | ||
| /// <param name="connectionStringProvider">The connection string provider.</param> | ||
| /// <param name="optionsProvider">Options provider to further configure the collection.</param> | ||
| /// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param> | ||
| /// <returns>The service collection.</returns> | ||
| public static IServiceCollection AddKeyedSqlServerVectorStoreCollection<TKey, TRecord>(this IServiceCollection services, | ||
| object serviceKey, | ||
| string collectionName, | ||
| Func<IServiceProvider, string> connectionStringProvider, | ||
| Func<IServiceProvider, SqlServerVectorStoreRecordCollectionOptions<TRecord>>? optionsProvider = null, | ||
| ServiceLifetime lifetime = ServiceLifetime.Singleton) | ||
| where TKey : notnull | ||
| where TRecord : notnull | ||
| { | ||
| Verify.NotNull(services); | ||
| Verify.NotNull(serviceKey); | ||
| Verify.NotNullOrWhiteSpace(collectionName); | ||
| Verify.NotNull(connectionStringProvider); | ||
|
|
||
| services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection<TKey, TRecord>), serviceKey, (sp, _) => | ||
| { | ||
| var connectionString = connectionStringProvider(sp); | ||
| var options = optionsProvider?.Invoke(sp); | ||
| return new SqlServerVectorStoreRecordCollection<TKey, TRecord>(connectionString, collectionName, options); | ||
| }, lifetime)); | ||
|
|
||
| services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection<TKey, TRecord>), serviceKey, | ||
| static (sp, key) => sp.GetRequiredKeyedService<SqlServerVectorStoreRecordCollection<TKey, TRecord>>(key), lifetime)); | ||
|
|
||
| return services; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Registers a <see cref="SqlServerVectorStoreRecordCollection{TKey, TRecord}"/> as <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>, with the specified connection string and service lifetime. | ||
| /// </summary> | ||
| /// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> on.</param> | ||
| /// <param name="collectionName">The name of the collection.</param> | ||
| /// <param name="connectionString">The connection string.</param> | ||
| /// <param name="options">Options to further configure the collection.</param> | ||
| /// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param> | ||
| /// <returns>The service collection.</returns> | ||
| public static IServiceCollection AddSqlServerVectorStoreCollection<TKey, TRecord>(this IServiceCollection services, | ||
| string collectionName, | ||
| string connectionString, | ||
| SqlServerVectorStoreRecordCollectionOptions<TRecord>? options = null, | ||
| ServiceLifetime lifetime = ServiceLifetime.Singleton) | ||
| where TKey : notnull | ||
| where TRecord : notnull | ||
| { | ||
| Verify.NotNull(services); | ||
| Verify.NotNullOrWhiteSpace(collectionName); | ||
| Verify.NotNullOrWhiteSpace(connectionString); | ||
|
|
||
| services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection<TKey, TRecord>), | ||
| sp => new SqlServerVectorStoreRecordCollection<TKey, TRecord>(connectionString, collectionName, options), lifetime)); | ||
|
|
||
| services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection<TKey, TRecord>), | ||
| static sp => sp.GetRequiredService<SqlServerVectorStoreRecordCollection<TKey, TRecord>>(), lifetime)); | ||
|
|
||
| return services; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Registers a keyed <see cref="SqlServerVectorStoreRecordCollection{TKey, TRecord}"/> as <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>, with the specified connection string and service lifetime. | ||
| /// </summary> | ||
| /// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> on.</param> | ||
| /// <param name="serviceKey">The key with which to associate the collection.</param> | ||
| /// <param name="collectionName">The name of the collection.</param> | ||
| /// <param name="connectionString">The connection string.</param> | ||
| /// <param name="options">Options to further configure the collection.</param> | ||
| /// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param> | ||
| /// <returns>The service collection.</returns> | ||
| public static IServiceCollection AddKeyedSqlServerVectorStoreCollection<TKey, TRecord>(this IServiceCollection services, | ||
| object serviceKey, | ||
| string collectionName, | ||
| string connectionString, | ||
| SqlServerVectorStoreRecordCollectionOptions<TRecord>? options = null, | ||
| ServiceLifetime lifetime = ServiceLifetime.Singleton) | ||
| where TKey : notnull | ||
| where TRecord : notnull | ||
| { | ||
| Verify.NotNull(services); | ||
| Verify.NotNull(serviceKey); | ||
| Verify.NotNullOrWhiteSpace(collectionName); | ||
| Verify.NotNullOrWhiteSpace(connectionString); | ||
|
|
||
| services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection<TKey, TRecord>), serviceKey, | ||
| (sp, _) => new SqlServerVectorStoreRecordCollection<TKey, TRecord>(connectionString, collectionName, options), lifetime)); | ||
|
|
||
| services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection<TKey, TRecord>), serviceKey, | ||
| static (sp, key) => sp.GetRequiredKeyedService<SqlServerVectorStoreRecordCollection<TKey, TRecord>>(key), lifetime)); | ||
|
|
||
| return services; | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| // Copyright (c) Microsoft. All rights reserved. | ||
|
|
||
| using Microsoft.Extensions.Configuration; | ||
| using Microsoft.Extensions.DependencyInjection; | ||
| using Microsoft.Extensions.Hosting; | ||
| using Microsoft.SemanticKernel.Connectors.SqlServer; | ||
| using VectorDataSpecificationTests.DependencyInjection; | ||
| using VectorDataSpecificationTests.Models; | ||
| using Xunit; | ||
|
|
||
| namespace SqlServerIntegrationTests.DependencyInjection; | ||
|
|
||
| public class SqlServerDependencyInjectionTests | ||
| : DependencyInjectionTests<SqlServerVectorStore, SqlServerVectorStoreRecordCollection<string, SimpleRecord<string>>, string, SimpleRecord<string>> | ||
| { | ||
| protected const string ConnectionString = "Server=localhost;Database=master;Integrated Security=True;"; | ||
|
|
||
| protected override void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null) | ||
| => configuration.AddInMemoryCollection( | ||
| [ | ||
| new(CreateConfigKey("SqlServer", serviceKey, "ConnectionString"), ConnectionString), | ||
| ]); | ||
|
|
||
| protected override void RegisterVectorStore(IServiceCollection services, ServiceLifetime lifetime, object? serviceKey = null) | ||
| { | ||
| if (serviceKey is null) | ||
| { | ||
| services.AddSqlServerVectorStore( | ||
| sp => sp.GetRequiredService<IConfiguration>().GetRequiredSection("SqlServer:ConnectionString").Value!, | ||
| lifetime: lifetime); | ||
| } | ||
| else | ||
| { | ||
| services.AddKeyedSqlServerVectorStore( | ||
| serviceKey, | ||
| sp => sp.GetRequiredService<IConfiguration>().GetRequiredSection(CreateConfigKey("SqlServer", serviceKey, "ConnectionString")).Value!, | ||
| lifetime: lifetime); | ||
| } | ||
| } | ||
|
|
||
| protected override void RegisterCollection(IServiceCollection services, ServiceLifetime lifetime, string collectionName = "name", object? serviceKey = null) | ||
| { | ||
| if (serviceKey is null) | ||
| { | ||
| services.AddSqlServerVectorStoreCollection<string, SimpleRecord<string>>( | ||
| collectionName, | ||
| sp => sp.GetRequiredService<IConfiguration>().GetRequiredSection("SqlServer:ConnectionString").Value!, | ||
| lifetime: lifetime); | ||
| } | ||
| else | ||
| { | ||
| services.AddKeyedSqlServerVectorStoreCollection<string, SimpleRecord<string>>( | ||
| serviceKey, | ||
| collectionName, | ||
| sp => sp.GetRequiredService<IConfiguration>().GetRequiredSection(CreateConfigKey("SqlServer", serviceKey, "ConnectionString")).Value!, | ||
| lifetime: lifetime); | ||
| } | ||
| } | ||
|
|
||
| [Fact] | ||
| public void ConnectionStringProviderCantBeNull() | ||
| { | ||
| HostApplicationBuilder builder = this.CreateHostBuilder(); | ||
|
|
||
| Assert.Throws<ArgumentNullException>(() => builder.Services.AddSqlServerVectorStore(connectionStringProvider: null!)); | ||
| Assert.Throws<ArgumentNullException>(() => builder.Services.AddKeyedSqlServerVectorStore(serviceKey: "notNull", connectionStringProvider: null!)); | ||
| Assert.Throws<ArgumentNullException>(() => builder.Services.AddSqlServerVectorStoreCollection<string, SimpleRecord<string>>(collectionName: "notNull", connectionStringProvider: null!)); | ||
| Assert.Throws<ArgumentNullException>(() => builder.Services.AddKeyedSqlServerVectorStoreCollection<string, SimpleRecord<string>>(serviceKey: "notNull", collectionName: "notNull", connectionStringProvider: null!)); | ||
| } | ||
|
|
||
| [Fact] | ||
| public void ConnectionStringCantBeNullOrEmpty() | ||
| { | ||
| HostApplicationBuilder builder = this.CreateHostBuilder(); | ||
|
|
||
| Assert.Throws<ArgumentNullException>(() => builder.Services.AddSqlServerVectorStoreCollection<string, SimpleRecord<string>>( | ||
| collectionName: "notNull", connectionString: null!)); | ||
| Assert.Throws<ArgumentException>(() => builder.Services.AddSqlServerVectorStoreCollection<string, SimpleRecord<string>>( | ||
| collectionName: "notNull", connectionString: "")); | ||
| Assert.Throws<ArgumentNullException>(() => builder.Services.AddKeyedSqlServerVectorStoreCollection<string, SimpleRecord<string>>( | ||
| serviceKey: "notNull", collectionName: "notNull", connectionString: null!)); | ||
| Assert.Throws<ArgumentException>(() => builder.Services.AddKeyedSqlServerVectorStoreCollection<string, SimpleRecord<string>>( | ||
| serviceKey: "notNull", collectionName: "notNull", connectionString: "")); | ||
| } | ||
| } | ||
|
|
||
| public class SqlServerDependencyInjectionTests_ConnectionStrings : SqlServerDependencyInjectionTests | ||
| { | ||
| protected override void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null) | ||
| { | ||
| // do nothing, as in this scenario config should not be used at all | ||
| } | ||
|
|
||
| protected override void RegisterCollection(IServiceCollection services, ServiceLifetime lifetime, string collectionName = "name", object? serviceKey = null) | ||
| { | ||
| if (serviceKey is null) | ||
| { | ||
| services.AddSqlServerVectorStoreCollection<string, SimpleRecord<string>>( | ||
| collectionName, | ||
| connectionString: ConnectionString, | ||
| lifetime: lifetime); | ||
| } | ||
| else | ||
| { | ||
| services.AddKeyedSqlServerVectorStoreCollection<string, SimpleRecord<string>>( | ||
| serviceKey, | ||
| collectionName, | ||
| connectionString: ConnectionString, | ||
| lifetime: lifetime); | ||
| } | ||
| } | ||
|
|
||
| public override void CanRegisterVectorStore(ServiceLifetime lifetime, object? serviceKey) | ||
| { | ||
| // do nothing, we don't provide a test for this scenario with raw connection string | ||
| } | ||
|
|
||
| public override void CanRegisterConcreteTypeVectorStoreAfterSomeAbstractionHasBeenRegistered(ServiceLifetime lifetime, object? serviceKey) | ||
| { | ||
| // do nothing, we don't provide a test for this scenario with raw connection string | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's common practice to put DI registration methods in a very general namespace so that they light up with an extra using, e.g. Microsoft.Extensions.VectorData (though I seem to remember that this was slightly controversial).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other connectors went one step further and have defined their extension methods in
Microsoft.SemanticKernel:semantic-kernel/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchServiceCollectionExtensions.cs
Line 15 in 9c9aec0
@westey-m @roji is that desired?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe another approach is to have them in the namespace of the thing that is being registered on, in this case that of ServiceCollection, which does make sense to me too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@westey-m good point. I can see that this is also what Aspire is doing:
https://github.com/dotnet/aspire/blob/20d79aedee6c99495e50de4d1cef5240362983be/src/Components/Aspire.Azure.Storage.Blobs/AspireBlobStorageExtensions.cs#L15
I am going to apply this suggestion.