Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.VectorData;

namespace Microsoft.SemanticKernel.Connectors.SqlServer;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's common practice to put DI registration methods in a very general namespace so that they light up with an extra using, e.g. Microsoft.Extensions.VectorData (though I seem to remember that this was slightly controversial).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other connectors went one step further and have defined their extension methods in Microsoft.SemanticKernel:

@westey-m @roji is that desired?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe another approach is to have them in the namespace of the thing that is being registered on, in this case that of ServiceCollection, which does make sense to me too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


/// <summary>
/// Extension methods to register <see cref="SqlServerVectorStore"/> instances on an <see cref="IServiceCollection"/>.
/// </summary>
public static class SqlServerServiceCollectionExtensions
{
/// <summary>
/// Registers a <see cref="SqlServerVectorStore"/> as <see cref="IVectorStore"/>, with the specified connection string and service lifetime.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStore"/> on.</param>
/// <param name="connectionStringProvider">The connection string provider.</param>
/// <param name="optionsProvider">Options provider to further configure the vector store.</param>
/// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddSqlServerVectorStore(this IServiceCollection services,
Func<IServiceProvider, string> connectionStringProvider,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, but... I think the basic expectation any user would have is to be able to pass the connection string directly (without a function). I get the logic of accepting a Func (to construct a connection string based on config that's in DI, e.g. tenant ID as database name), but it's definitely a bit jarring/unexpected for the minimal demo to look like this: AddSqlServerVectorStore(_ => "blablabla").

FWIW IIRC EF only accepts a string directly; AddNpgsqlDataSource is the same (though there's an issue for adding an overload that accepts an IServiceProvider Func. I'm not sure what they did in Aspire - might be worth checking.

The easy answer here is obviously to just have two overloads - one with a string, one with a Func. The only concern here is that we risk going into DI overload explosion: we have:

  1. String vs. Func. For PG there's also NpgsqlDataSource (and this is also planned for SQL Server, Sqlite...).
  2. Regular vs. Keyed
  3. IVectorStore vs. IVectorStoreRecordCollection
  4. We may end up needing POCO vs. dynamic mapping - the former will be not safe for trimming. After Build we'll also introduce another POCO one which accepts source-generated stuff, for trimming-safe POCO mapping (so 3 overall)

So I'm not sure where this will go - let's maybe discuss this in parking lot soon.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@roji, is you concern around having many overloads born from complexity of use, or from maintaining so many overloads?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@roji @westey-m I've moved the conversation to the issue: #10549 (comment) PTAL

Func<IServiceProvider, SqlServerVectorStoreOptions>? optionsProvider = null,
ServiceLifetime lifetime = ServiceLifetime.Singleton)
{
Verify.NotNull(services);
Verify.NotNull(connectionStringProvider);

services.Add(new ServiceDescriptor(typeof(SqlServerVectorStore), sp =>
{
var connectionString = connectionStringProvider(sp);
var options = optionsProvider?.Invoke(sp);
return new SqlServerVectorStore(connectionString, options);
}, lifetime));

// We try to add the SqlServerVectorStore as an IVectorStore,
// but if it already exists, we don't override it.
// Sample scenario: one app using two different vector stores.
services.TryAdd(new ServiceDescriptor(typeof(IVectorStore),
static sp => sp.GetRequiredService<SqlServerVectorStore>(), lifetime));

return services;
}

/// <summary>
/// Registers a keyed <see cref="SqlServerVectorStore"/> as <see cref="IVectorStore"/>, with the specified connection string and service lifetime.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStore"/> on.</param>
/// <param name="serviceKey">The key with which to associate the vector store.</param>
/// <param name="connectionStringProvider">The connection string provider.</param>
/// <param name="optionsProvider">Options provider to further configure the vector store.</param>
/// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddKeyedSqlServerVectorStore(this IServiceCollection services,
object serviceKey,
Func<IServiceProvider, string> connectionStringProvider,
Func<IServiceProvider, SqlServerVectorStoreOptions>? optionsProvider = null,
ServiceLifetime lifetime = ServiceLifetime.Singleton)
{
Verify.NotNull(services);
Verify.NotNull(serviceKey);
Verify.NotNull(connectionStringProvider);

services.Add(new ServiceDescriptor(typeof(SqlServerVectorStore), serviceKey, (sp, _) =>
{
var connectionString = connectionStringProvider(sp);
var options = optionsProvider?.Invoke(sp);
return new SqlServerVectorStore(connectionString, options);
}, lifetime));

services.TryAdd(new ServiceDescriptor(typeof(IVectorStore), serviceKey,
static (sp, key) => sp.GetRequiredKeyedService<SqlServerVectorStore>(key), lifetime));

return services;
}

/// <summary>
/// Registers a <see cref="SqlServerVectorStoreRecordCollection{TKey, TRecord}"/> as <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>, with the specified connection string and service lifetime.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> on.</param>
/// <param name="collectionName">The name of the collection.</param>
/// <param name="connectionStringProvider">The connection string provider.</param>
/// <param name="optionsProvider">Options provider to further configure the collection.</param>
/// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddSqlServerVectorStoreCollection<TKey, TRecord>(this IServiceCollection services,
string collectionName,
Func<IServiceProvider, string> connectionStringProvider,
Func<IServiceProvider, SqlServerVectorStoreRecordCollectionOptions<TRecord>>? optionsProvider = null,
ServiceLifetime lifetime = ServiceLifetime.Singleton)
where TKey : notnull
where TRecord : notnull
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(collectionName);
Verify.NotNull(connectionStringProvider);

services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection<TKey, TRecord>), sp =>
{
var connectionString = connectionStringProvider(sp);
var options = optionsProvider?.Invoke(sp);
return new SqlServerVectorStoreRecordCollection<TKey, TRecord>(connectionString, collectionName, options);
}, lifetime));

services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection<TKey, TRecord>),
static sp => sp.GetRequiredService<SqlServerVectorStoreRecordCollection<TKey, TRecord>>(), lifetime));

return services;
}

/// <summary>
/// Registers a keyed <see cref="SqlServerVectorStoreRecordCollection{TKey, TRecord}"/> as <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>, with the specified connection string and service lifetime.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> on.</param>
/// <param name="serviceKey">The key with which to associate the collection.</param>
/// <param name="collectionName">The name of the collection.</param>
/// <param name="connectionStringProvider">The connection string provider.</param>
/// <param name="optionsProvider">Options provider to further configure the collection.</param>
/// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddKeyedSqlServerVectorStoreCollection<TKey, TRecord>(this IServiceCollection services,
object serviceKey,
string collectionName,
Func<IServiceProvider, string> connectionStringProvider,
Func<IServiceProvider, SqlServerVectorStoreRecordCollectionOptions<TRecord>>? optionsProvider = null,
ServiceLifetime lifetime = ServiceLifetime.Singleton)
where TKey : notnull
where TRecord : notnull
{
Verify.NotNull(services);
Verify.NotNull(serviceKey);
Verify.NotNullOrWhiteSpace(collectionName);
Verify.NotNull(connectionStringProvider);

services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection<TKey, TRecord>), serviceKey, (sp, _) =>
{
var connectionString = connectionStringProvider(sp);
var options = optionsProvider?.Invoke(sp);
return new SqlServerVectorStoreRecordCollection<TKey, TRecord>(connectionString, collectionName, options);
}, lifetime));

services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection<TKey, TRecord>), serviceKey,
static (sp, key) => sp.GetRequiredKeyedService<SqlServerVectorStoreRecordCollection<TKey, TRecord>>(key), lifetime));

return services;
}

/// <summary>
/// Registers a <see cref="SqlServerVectorStoreRecordCollection{TKey, TRecord}"/> as <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>, with the specified connection string and service lifetime.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> on.</param>
/// <param name="collectionName">The name of the collection.</param>
/// <param name="connectionString">The connection string.</param>
/// <param name="options">Options to further configure the collection.</param>
/// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddSqlServerVectorStoreCollection<TKey, TRecord>(this IServiceCollection services,
string collectionName,
string connectionString,
SqlServerVectorStoreRecordCollectionOptions<TRecord>? options = null,
ServiceLifetime lifetime = ServiceLifetime.Singleton)
where TKey : notnull
where TRecord : notnull
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(collectionName);
Verify.NotNullOrWhiteSpace(connectionString);

services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection<TKey, TRecord>),
sp => new SqlServerVectorStoreRecordCollection<TKey, TRecord>(connectionString, collectionName, options), lifetime));

services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection<TKey, TRecord>),
static sp => sp.GetRequiredService<SqlServerVectorStoreRecordCollection<TKey, TRecord>>(), lifetime));

return services;
}

/// <summary>
/// Registers a keyed <see cref="SqlServerVectorStoreRecordCollection{TKey, TRecord}"/> as <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>, with the specified connection string and service lifetime.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> to register the <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> on.</param>
/// <param name="serviceKey">The key with which to associate the collection.</param>
/// <param name="collectionName">The name of the collection.</param>
/// <param name="connectionString">The connection string.</param>
/// <param name="options">Options to further configure the collection.</param>
/// <param name="lifetime">The service lifetime for the store. Defaults to <see cref="ServiceLifetime.Singleton"/>.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddKeyedSqlServerVectorStoreCollection<TKey, TRecord>(this IServiceCollection services,
object serviceKey,
string collectionName,
string connectionString,
SqlServerVectorStoreRecordCollectionOptions<TRecord>? options = null,
ServiceLifetime lifetime = ServiceLifetime.Singleton)
where TKey : notnull
where TRecord : notnull
{
Verify.NotNull(services);
Verify.NotNull(serviceKey);
Verify.NotNullOrWhiteSpace(collectionName);
Verify.NotNullOrWhiteSpace(connectionString);

services.Add(new ServiceDescriptor(typeof(SqlServerVectorStoreRecordCollection<TKey, TRecord>), serviceKey,
(sp, _) => new SqlServerVectorStoreRecordCollection<TKey, TRecord>(connectionString, collectionName, options), lifetime));

services.TryAdd(new ServiceDescriptor(typeof(IVectorStoreRecordCollection<TKey, TRecord>), serviceKey,
static (sp, key) => sp.GetRequiredKeyedService<SqlServerVectorStoreRecordCollection<TKey, TRecord>>(key), lifetime));

return services;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.SemanticKernel.Connectors.SqlServer;
using VectorDataSpecificationTests.DependencyInjection;
using VectorDataSpecificationTests.Models;
using Xunit;

namespace SqlServerIntegrationTests.DependencyInjection;

public class SqlServerDependencyInjectionTests
: DependencyInjectionTests<SqlServerVectorStore, SqlServerVectorStoreRecordCollection<string, SimpleRecord<string>>, string, SimpleRecord<string>>
{
protected const string ConnectionString = "Server=localhost;Database=master;Integrated Security=True;";

protected override void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null)
=> configuration.AddInMemoryCollection(
[
new(CreateConfigKey("SqlServer", serviceKey, "ConnectionString"), ConnectionString),
]);

protected override void RegisterVectorStore(IServiceCollection services, ServiceLifetime lifetime, object? serviceKey = null)
{
if (serviceKey is null)
{
services.AddSqlServerVectorStore(
sp => sp.GetRequiredService<IConfiguration>().GetRequiredSection("SqlServer:ConnectionString").Value!,
lifetime: lifetime);
}
else
{
services.AddKeyedSqlServerVectorStore(
serviceKey,
sp => sp.GetRequiredService<IConfiguration>().GetRequiredSection(CreateConfigKey("SqlServer", serviceKey, "ConnectionString")).Value!,
lifetime: lifetime);
}
}

protected override void RegisterCollection(IServiceCollection services, ServiceLifetime lifetime, string collectionName = "name", object? serviceKey = null)
{
if (serviceKey is null)
{
services.AddSqlServerVectorStoreCollection<string, SimpleRecord<string>>(
collectionName,
sp => sp.GetRequiredService<IConfiguration>().GetRequiredSection("SqlServer:ConnectionString").Value!,
lifetime: lifetime);
}
else
{
services.AddKeyedSqlServerVectorStoreCollection<string, SimpleRecord<string>>(
serviceKey,
collectionName,
sp => sp.GetRequiredService<IConfiguration>().GetRequiredSection(CreateConfigKey("SqlServer", serviceKey, "ConnectionString")).Value!,
lifetime: lifetime);
}
}

[Fact]
public void ConnectionStringProviderCantBeNull()
{
HostApplicationBuilder builder = this.CreateHostBuilder();

Assert.Throws<ArgumentNullException>(() => builder.Services.AddSqlServerVectorStore(connectionStringProvider: null!));
Assert.Throws<ArgumentNullException>(() => builder.Services.AddKeyedSqlServerVectorStore(serviceKey: "notNull", connectionStringProvider: null!));
Assert.Throws<ArgumentNullException>(() => builder.Services.AddSqlServerVectorStoreCollection<string, SimpleRecord<string>>(collectionName: "notNull", connectionStringProvider: null!));
Assert.Throws<ArgumentNullException>(() => builder.Services.AddKeyedSqlServerVectorStoreCollection<string, SimpleRecord<string>>(serviceKey: "notNull", collectionName: "notNull", connectionStringProvider: null!));
}

[Fact]
public void ConnectionStringCantBeNullOrEmpty()
{
HostApplicationBuilder builder = this.CreateHostBuilder();

Assert.Throws<ArgumentNullException>(() => builder.Services.AddSqlServerVectorStoreCollection<string, SimpleRecord<string>>(
collectionName: "notNull", connectionString: null!));
Assert.Throws<ArgumentException>(() => builder.Services.AddSqlServerVectorStoreCollection<string, SimpleRecord<string>>(
collectionName: "notNull", connectionString: ""));
Assert.Throws<ArgumentNullException>(() => builder.Services.AddKeyedSqlServerVectorStoreCollection<string, SimpleRecord<string>>(
serviceKey: "notNull", collectionName: "notNull", connectionString: null!));
Assert.Throws<ArgumentException>(() => builder.Services.AddKeyedSqlServerVectorStoreCollection<string, SimpleRecord<string>>(
serviceKey: "notNull", collectionName: "notNull", connectionString: ""));
}
}

public class SqlServerDependencyInjectionTests_ConnectionStrings : SqlServerDependencyInjectionTests
{
protected override void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null)
{
// do nothing, as in this scenario config should not be used at all
}

protected override void RegisterCollection(IServiceCollection services, ServiceLifetime lifetime, string collectionName = "name", object? serviceKey = null)
{
if (serviceKey is null)
{
services.AddSqlServerVectorStoreCollection<string, SimpleRecord<string>>(
collectionName,
connectionString: ConnectionString,
lifetime: lifetime);
}
else
{
services.AddKeyedSqlServerVectorStoreCollection<string, SimpleRecord<string>>(
serviceKey,
collectionName,
connectionString: ConnectionString,
lifetime: lifetime);
}
}

public override void CanRegisterVectorStore(ServiceLifetime lifetime, object? serviceKey)
{
// do nothing, we don't provide a test for this scenario with raw connection string
}

public override void CanRegisterConcreteTypeVectorStoreAfterSomeAbstractionHasBeenRegistered(ServiceLifetime lifetime, object? serviceKey)
{
// do nothing, we don't provide a test for this scenario with raw connection string
}
}
Loading
Loading