From 43bc48e1da6fcadf24f4a70fdcc1fcc5ee4620e1 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Thu, 23 Jan 2025 13:51:01 +0100 Subject: [PATCH 1/8] Implement LINQ-based vector search filtering Closes #10156 Does most of #10194 --- .github/_typos.toml | 1 + dotnet/Directory.Packages.props | 8 +- dotnet/SK-dotnet.sln | 104 +++++ dotnet/SK-dotnet.sln.DotSettings | 1 + ...extEmbeddingVectorStoreRecordCollection.cs | 6 +- .../MappingVectorStoreRecordCollection.cs | 268 ++++++------ .../QdrantFactory.cs | 33 +- ...torStore_VectorSearch_MultiStore_Common.cs | 3 +- .../Memory/VectorStore_VectorSearch_Simple.cs | 3 +- .../Concepts/Search/VectorStore_TextSearch.cs | 2 +- .../Step2_Vector_Search.cs | 2 + .../Step4_NonStringKey_VectorStore.cs | 389 +++++++++--------- ...VectorStoreCollectionSearchMappingTests.cs | 12 +- ...ISearchVectorStoreRecordCollectionTests.cs | 2 + ...VectorStoreCollectionSearchMappingTests.cs | 2 + ...MongoDBVectorStoreRecordCollectionTests.cs | 3 +- ...LVectorStoreCollectionQueryBuilderTests.cs | 14 +- ...DBNoSQLVectorStoreRecordCollectionTests.cs | 2 +- ...nMemoryVectorStoreRecordCollectionTests.cs | 12 +- .../AzureAISearchFilterTranslator.cs | 349 ++++++++++++++++ ...earchVectorStoreCollectionSearchMapping.cs | 6 +- ...zureAISearchVectorStoreRecordCollection.cs | 47 ++- .../AzureCosmosDBMongoDBFilterTranslator.cs | 258 ++++++++++++ ...ngoDBVectorStoreCollectionSearchMapping.cs | 2 + ...mosDBMongoDBVectorStoreRecordCollection.cs | 22 +- .../AzureCosmosDBNoSQLConstants.cs | 2 +- .../AzureCosmosDBNoSQLFilter.cs | 15 - ...BNoSQLVectorStoreCollectionQueryBuilder.cs | 80 ++-- ...osmosDBNoSQLVectorStoreRecordCollection.cs | 6 +- .../AzureCosmosDBNoSqlFilterTranslator.cs | 284 +++++++++++++ ...emoryVectorStoreCollectionSearchMapping.cs | 14 +- .../InMemoryVectorStoreRecordCollection.cs | 20 +- .../MongoDBFilterTranslator.cs | 258 ++++++++++++ ...ngoDBVectorStoreCollectionSearchMapping.cs | 8 +- .../MongoDBVectorStoreRecordCollection.cs | 20 +- ...econeVectorStoreCollectionSearchMapping.cs | 2 + .../PineconeVectorStoreRecordCollection.cs | 7 +- ...PostgresVectorStoreCollectionSqlBuilder.cs | 11 +- .../IPostgresVectorStoreDbClient.cs | 15 +- .../PostgresFilterTranslator.cs | 332 +++++++++++++++ ...PostgresVectorStoreCollectionSqlBuilder.cs | 134 +++--- .../PostgresVectorStoreDbClient.cs | 14 +- .../PostgresVectorStoreRecordCollection.cs | 11 +- ...ostgresVectorStoreRecordPropertyMapping.cs | 2 +- .../QdrantFilterTranslator.cs | 382 +++++++++++++++++ ...drantVectorStoreCollectionSearchMapping.cs | 10 +- .../QdrantVectorStoreRecordCollection.cs | 15 +- .../RedisFilterTranslator.cs | 229 +++++++++++ ...RedisHashSetVectorStoreRecordCollection.cs | 4 +- .../RedisJsonVectorStoreRecordCollection.cs | 4 +- ...RedisVectorStoreCollectionSearchMapping.cs | 26 +- .../SqliteFilterTranslator.cs | 360 ++++++++++++++++ ...liteVectorStoreCollectionCommandBuilder.cs | 24 +- .../SqliteVectorStoreRecordCollection.cs | 38 +- .../WeaviateFilterTranslator.cs | 259 ++++++++++++ .../WeaviateVectorStoreRecordCollection.cs | 7 +- ...VectorStoreRecordCollectionQueryBuilder.cs | 31 +- ...VectorStoreCollectionSearchMappingTests.cs | 36 +- ...MongoDBVectorStoreRecordCollectionTests.cs | 3 +- ...resVectorStoreCollectionSqlBuilderTests.cs | 53 --- ...VectorStoreCollectionSearchMappingTests.cs | 8 +- .../QdrantVectorStoreRecordCollectionTests.cs | 2 + ...HashSetVectorStoreRecordCollectionTests.cs | 2 + ...disJsonVectorStoreRecordCollectionTests.cs | 2 + ...VectorStoreCollectionSearchMappingTests.cs | 24 +- ...rStoreRecordCollectionQueryBuilderTests.cs | 15 +- ...eaviateVectorStoreRecordCollectionTests.cs | 7 +- .../AnyTagEqualToFilterClause.cs | 3 + .../FilterClauses/EqualToFilterClause.cs | 3 + .../FilterClauses/FilterClause.cs | 3 + .../VectorSearch/IVectorizableTextSearch.cs | 2 +- .../VectorSearch/IVectorizedSearch.cs | 2 +- .../VectorSearch/VectorSearchFilter.cs | 1 + .../VectorSearch/VectorSearchOptions.cs | 11 +- ...ISearchVectorStoreRecordCollectionTests.cs | 4 +- ...MongoDBVectorStoreRecordCollectionTests.cs | 2 + ...DBNoSQLVectorStoreRecordCollectionTests.cs | 2 + ...MongoDBVectorStoreRecordCollectionTests.cs | 4 +- ...ineconeVectorStoreRecordCollectionTests.cs | 4 +- ...ostgresVectorStoreRecordCollectionTests.cs | 2 + .../QdrantVectorStoreRecordCollectionTests.cs | 4 +- ...HashSetVectorStoreRecordCollectionTests.cs | 10 +- ...disJsonVectorStoreRecordCollectionTests.cs | 10 +- .../SqliteVectorStoreRecordCollectionTests.cs | 2 + ...eaviateVectorStoreRecordCollectionTests.cs | 2 + .../Data/BaseVectorStoreTextSearchTests.cs | 2 +- .../src/Diagnostics/UnreachableException.cs | 50 +++ .../src/System/IndexRange.cs | 288 +++++++++++++ .../Plugins.Web/Bing/BingTextSearch.cs | 3 + .../Plugins.Web/Google/GoogleTextSearch.cs | 2 + .../Data/TextSearch/TextSearchFilter.cs | 2 + .../Search/MockVectorizableTextSearch.cs | 2 +- .../Data/TextSearch/VectorStoreTextSearch.cs | 4 +- .../VolatileVectorStoreRecordCollection.cs | 9 +- .../Data/VectorStoreTextSearchTestBase.cs | 2 +- ...olatileVectorStoreRecordCollectionTests.cs | 10 +- .../AzureAISearchIntegrationTests.csproj | 31 ++ .../Filter/AzureAISearchBasicFilterTests.cs | 13 + .../Filter/AzureAISearchFilterFixture.cs | 18 + .../Properties/AssemblyAttributes.cs | 3 + .../Support/AzureAISearchTestEnvironment.cs | 28 ++ .../Support/AzureAISearchTestStore.cs | 45 ++ .../AzureAISearchUrlRequiredAttribute.cs | 19 + .../CosmosMongoDBIntegrationTests.csproj | 29 ++ .../Filter/CosmosMongoFilterFixture.cs | 15 + .../Filter/CosmosMongoFiltersNotSupported.cs | 24 ++ .../Properties/AssemblyAttributes.cs | 3 + ...CosmosConnectionStringRequiredAttribute.cs | 20 + .../Support/CosmosMongoDBTestEnvironment.cs | 25 ++ .../Support/CosmosMongoDBTestStore.cs | 45 ++ .../CosmosNoSQLIntegrationTests.csproj | 29 ++ .../Filter/CosmosNoSQLBasicFilterTests.cs | 8 + .../Filter/CosmosNoSQLFilterFixture.cs | 12 + .../Properties/AssemblyAttributes.cs | 3 + ...CosmosConnectionStringRequiredAttribute.cs | 19 + .../Support/CosmosNoSQLTestEnvironment.cs | 25 ++ .../Support/CosmosNoSQLTestStore.cs | 60 +++ .../Directory.Build.props | 15 + .../Filter/InMemoryBasicFilterTests.cs | 8 + .../Filter/InMemoryFilterFixture.cs | 12 + .../InMemoryIntegrationTests.csproj | 26 ++ .../Support/InMemoryTestStore.cs | 27 ++ .../Filter/MongoDBBasicFilterTests.cs | 59 +++ .../Filter/MongoDBFilterFixture.cs | 12 + .../MongoDBIntegrationTests.csproj | 27 ++ .../Support/MongoDBTestStore.cs | 51 +++ .../Filter/PostgresBasicFilterTests.cs | 32 ++ .../Filter/PostgresFilterFixture.cs | 12 + .../PostgresIntegrationTests.csproj | 27 ++ .../Support/PostgresTestStore.cs | 68 +++ .../Filter/QdrantBasicFilterTests.cs | 8 + .../Filter/QdrantFilterFixture.cs | 15 + .../QdrantIntegrationTests.csproj | 27 ++ .../Support/QdrantTestStore.cs | 40 ++ .../Support/TestContainer/QdrantBuilder.cs | 56 +++ .../TestContainer/QdrantConfiguration.cs | 53 +++ .../Support/TestContainer/QdrantContainer.cs | 7 + .../Filter/RedisBasicFilterTests.cs | 51 +++ .../Filter/RedisFilterFixture.cs | 20 + .../RedisIntegrationTests.csproj | 27 ++ .../Support/RedisTestStore.cs | 42 ++ .../Filter/SqliteBasicFilterTests.cs | 45 ++ .../Filter/SqliteFilterFixture.cs | 22 + .../Properties/AssemblyAttributes.cs | 3 + .../SqliteIntegrationTests.csproj | 26 ++ .../Support/SqliteTestEnvironment.cs | 56 +++ .../Support/SqliteTestStore.cs | 47 +++ .../Support/SqliteVecRequiredAttribute.cs | 19 + .../Filter/BasicFilterTestsBase.cs | 283 +++++++++++++ .../Filter/FilterFixtureBase.cs | 184 +++++++++ .../Support/TestStore.cs | 52 +++ .../VectorDataIntegrationTests.csproj | 21 + .../Xunit/ConditionalFactAttribute.cs | 10 + .../Xunit/ConditionalFactDiscoverer.cs | 23 ++ .../Xunit/ConditionalFactTestCase.cs | 39 ++ .../Xunit/ConditionalTheoryAttribute.cs | 10 + .../Xunit/ITestCondition.cs | 10 + .../Xunit/XunitTestCaseExtensions.cs | 51 +++ .../Filter/WeaviateBasicFilterTests.cs | 62 +++ .../Filter/WeaviateFilterFixture.cs | 14 + .../Support/TestContainer/WeaviateBuilder.cs | 48 +++ .../TestContainer/WeaviateConfiguration.cs | 53 +++ .../TestContainer/WeaviateContainer.cs | 7 + .../Support/WeaviateTestStore.cs | 37 ++ .../WeaviateIntegrationTests.csproj | 27 ++ 165 files changed, 6403 insertions(+), 714 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs delete mode 100644 dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLFilter.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs create mode 100644 dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs create mode 100644 dotnet/src/InternalUtilities/src/System/IndexRange.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj create mode 100644 dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Properties/AssemblyAttributes.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFilterFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFiltersNotSupported.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Properties/AssemblyAttributes.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLFilterFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Properties/AssemblyAttributes.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestEnvironment.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/Directory.Build.props create mode 100644 dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryFilterFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj create mode 100644 dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBFilterFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj create mode 100644 dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresFilterFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj create mode 100644 dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantFilterFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj create mode 100644 dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantBuilder.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantConfiguration.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantContainer.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj create mode 100644 dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteFilterFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Properties/AssemblyAttributes.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj create mode 100644 dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestEnvironment.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteVecRequiredAttribute.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactAttribute.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactDiscoverer.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactTestCase.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ITestCondition.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/XunitTestCaseExtensions.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateFilterFixture.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateConfiguration.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateContainer.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj diff --git a/.github/_typos.toml b/.github/_typos.toml index 457e6bca4c2c..d9a2dcb7a2e4 100644 --- a/.github/_typos.toml +++ b/.github/_typos.toml @@ -39,6 +39,7 @@ prompty = "prompty" # prompty is a format name. ist = "ist" # German language dall = "dall" # OpenAI model name pn = "pn" # Kiota parameter +nin = "nin" # MongoDB "not in" operator [default.extend-identifiers] ags = "ags" # Azure Graph Service diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index e93dc3df49a2..fcad75436cb8 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -115,11 +115,15 @@ - + - + + + + + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 0a711f84f5f3..6b4dae547138 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -117,6 +117,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Diagnostics", "Diagnostics" src\InternalUtilities\src\Diagnostics\RequiresUnreferencedCodeAttribute.cs = src\InternalUtilities\src\Diagnostics\RequiresUnreferencedCodeAttribute.cs src\InternalUtilities\src\Diagnostics\UnconditionalSuppressMessageAttribute.cs = src\InternalUtilities\src\Diagnostics\UnconditionalSuppressMessageAttribute.cs src\InternalUtilities\src\Diagnostics\Verify.cs = src\InternalUtilities\src\Diagnostics\Verify.cs + src\InternalUtilities\src\Diagnostics\UnreachableException.cs = src\InternalUtilities\src\Diagnostics\UnreachableException.cs EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Linq", "Linq", "{B00AD427-0047-4850-BEF9-BA8237EA9D8B}" @@ -140,6 +141,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "System", "System", "{3CDE10 src\InternalUtilities\src\System\InternalTypeConverter.cs = src\InternalUtilities\src\System\InternalTypeConverter.cs src\InternalUtilities\src\System\NonNullCollection.cs = src\InternalUtilities\src\System\NonNullCollection.cs src\InternalUtilities\src\System\TypeConverterFactory.cs = src\InternalUtilities\src\System\TypeConverterFactory.cs + src\InternalUtilities\src\System\IndexRange.cs = src\InternalUtilities\src\System\IndexRange.cs EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Type", "Type", "{E85EA4D0-BB7E-4DFD-882F-A76EB8C0B8FF}" @@ -439,6 +441,30 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "sk-chatgpt-azure-function", EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "kernel-functions-generator", "samples\Demos\CreateChatGptPlugin\MathPlugin\kernel-functions-generator\kernel-functions-generator.csproj", "{78785CB1-66CF-4895-D7E5-A440DD84BE86}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "VectorDataIntegrationTests", "VectorDataIntegrationTests", "{4F381919-F1BE-47D8-8558-3187ED04A84F}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "QdrantIntegrationTests", "src\VectorDataIntegrationTests\QdrantIntegrationTests\QdrantIntegrationTests.csproj", "{27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "VectorDataIntegrationTests", "src\VectorDataIntegrationTests\VectorDataIntegrationTests\VectorDataIntegrationTests.csproj", "{B29A972F-A774-4140-AECF-6B577C476627}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RedisIntegrationTests", "src\VectorDataIntegrationTests\RedisIntegrationTests\RedisIntegrationTests.csproj", "{F7EA82A4-A626-4316-AA47-EAC3A0E85870}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "PostgresIntegrationTests", "src\VectorDataIntegrationTests\PostgresIntegrationTests\PostgresIntegrationTests.csproj", "{3148FF01-38C7-4BEB-8CEC-9323EC7C593B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "InMemoryIntegrationTests", "src\VectorDataIntegrationTests\InMemoryIntegrationTests\InMemoryIntegrationTests.csproj", "{F5126690-0FD1-4777-9EDF-B3F5B7B3730B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "CosmosNoSQLIntegrationTests", "src\VectorDataIntegrationTests\CosmosNoSQLIntegrationTests\CosmosNoSQLIntegrationTests.csproj", "{E200425C-E501-430C-8A8B-BC0088BD94DB}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SqliteIntegrationTests", "src\VectorDataIntegrationTests\SqliteIntegrationTests\SqliteIntegrationTests.csproj", "{709B3933-5286-4139-8D83-8C7AA5746FAE}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "WeaviateIntegrationTests", "src\VectorDataIntegrationTests\WeaviateIntegrationTests\WeaviateIntegrationTests.csproj", "{E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MongoDBIntegrationTests", "src\VectorDataIntegrationTests\MongoDBIntegrationTests\MongoDBIntegrationTests.csproj", "{A0E65043-6B00-4836-850F-000A52238914}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "CosmosMongoDBIntegrationTests", "src\VectorDataIntegrationTests\CosmosMongoDBIntegrationTests\CosmosMongoDBIntegrationTests.csproj", "{11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AzureAISearchIntegrationTests", "src\VectorDataIntegrationTests\AzureAISearchIntegrationTests\AzureAISearchIntegrationTests.csproj", "{06181F0F-A375-43AE-B45F-73CBCFC30C14}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -1172,6 +1198,72 @@ Global {78785CB1-66CF-4895-D7E5-A440DD84BE86}.Publish|Any CPU.Build.0 = Debug|Any CPU {78785CB1-66CF-4895-D7E5-A440DD84BE86}.Release|Any CPU.ActiveCfg = Release|Any CPU {78785CB1-66CF-4895-D7E5-A440DD84BE86}.Release|Any CPU.Build.0 = Release|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Debug|Any CPU.Build.0 = Debug|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Publish|Any CPU.Build.0 = Debug|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Release|Any CPU.ActiveCfg = Release|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Release|Any CPU.Build.0 = Release|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Publish|Any CPU.Build.0 = Debug|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Release|Any CPU.Build.0 = Release|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Publish|Any CPU.Build.0 = Debug|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Release|Any CPU.Build.0 = Release|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Publish|Any CPU.Build.0 = Debug|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Release|Any CPU.Build.0 = Release|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Publish|Any CPU.Build.0 = Debug|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Release|Any CPU.Build.0 = Release|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Publish|Any CPU.Build.0 = Debug|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Release|Any CPU.Build.0 = Release|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Publish|Any CPU.Build.0 = Debug|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Release|Any CPU.Build.0 = Release|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Publish|Any CPU.Build.0 = Debug|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Release|Any CPU.Build.0 = Release|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Publish|Any CPU.Build.0 = Debug|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Release|Any CPU.Build.0 = Release|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Publish|Any CPU.Build.0 = Debug|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Release|Any CPU.Build.0 = Release|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Debug|Any CPU.Build.0 = Debug|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Publish|Any CPU.Build.0 = Debug|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Release|Any CPU.ActiveCfg = Release|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1333,6 +1425,18 @@ Global {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {2EB6E4C2-606D-B638-2E08-49EA2061C428} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} {78785CB1-66CF-4895-D7E5-A440DD84BE86} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} + {4F381919-F1BE-47D8-8558-3187ED04A84F} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {B29A972F-A774-4140-AECF-6B577C476627} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {F7EA82A4-A626-4316-AA47-EAC3A0E85870} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {E200425C-E501-430C-8A8B-BC0088BD94DB} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {709B3933-5286-4139-8D83-8C7AA5746FAE} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {A0E65043-6B00-4836-850F-000A52238914} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {06181F0F-A375-43AE-B45F-73CBCFC30C14} = {4F381919-F1BE-47D8-8558-3187ED04A84F} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/SK-dotnet.sln.DotSettings b/dotnet/SK-dotnet.sln.DotSettings index d8964e230315..f5eec1700bcd 100644 --- a/dotnet/SK-dotnet.sln.DotSettings +++ b/dotnet/SK-dotnet.sln.DotSettings @@ -217,6 +217,7 @@ public void It$SOMENAME$() True True True + True True True True diff --git a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs b/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs index 5c1c4b05c56f..000cb1ebba07 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs @@ -8,7 +8,7 @@ namespace Memory.VectorStoreEmbeddingGeneration; /// -/// Decorator for a that generates embeddings for records on upsert and when using . +/// Decorator for a that generates embeddings for records on upsert and when using . /// /// /// This class is part of the sample. @@ -120,13 +120,13 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record } /// - public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { return this._decoratedVectorStoreRecordCollection.VectorizedSearchAsync(vector, options, cancellationToken); } /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { var embeddingValue = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); return await this.VectorizedSearchAsync(embeddingValue, options, cancellationToken).ConfigureAwait(false); diff --git a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs index 076be09c9ca5..5d9dca826e28 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs @@ -1,134 +1,138 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Runtime.CompilerServices; -using Microsoft.Extensions.VectorData; - -namespace Memory.VectorStoreLangchainInterop; - -/// -/// Decorator class that allows conversion of keys and records between public and internal representations. -/// -/// -/// This class is useful if a vector store implementation exposes keys or records in a way that is not -/// suitable for the user of the vector store. E.g. let's say that the vector store supports Guid keys -/// but you want to work with string keys that contain Guids. This class allows you to map between the -/// public string Guids and the internal Guids. -/// -/// The type of the key that the user of this class will use. -/// The type of the key that the internal collection exposes. -/// The type of the record that the user of this class will use. -/// The type of the record that the internal collection exposes. -internal sealed class MappingVectorStoreRecordCollection : IVectorStoreRecordCollection - where TPublicKey : notnull - where TInternalKey : notnull -{ - private readonly IVectorStoreRecordCollection _collection; - private readonly Func _publicToInternalKeyMapper; - private readonly Func _internalToPublicKeyMapper; - private readonly Func _publicToInternalRecordMapper; - private readonly Func _internalToPublicRecordMapper; - - public MappingVectorStoreRecordCollection( - IVectorStoreRecordCollection collection, - Func publicToInternalKeyMapper, - Func internalToPublicKeyMapper, - Func publicToInternalRecordMapper, - Func internalToPublicRecordMapper) - { - this._collection = collection; - this._publicToInternalKeyMapper = publicToInternalKeyMapper; - this._internalToPublicKeyMapper = internalToPublicKeyMapper; - this._publicToInternalRecordMapper = publicToInternalRecordMapper; - this._internalToPublicRecordMapper = internalToPublicRecordMapper; - } - - /// - public string CollectionName => this._collection.CollectionName; - - /// - public Task CollectionExistsAsync(CancellationToken cancellationToken = default) - { - return this._collection.CollectionExistsAsync(cancellationToken); - } - - /// - public Task CreateCollectionAsync(CancellationToken cancellationToken = default) - { - return this._collection.CreateCollectionAsync(cancellationToken); - } - - /// - public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) - { - return this._collection.CreateCollectionIfNotExistsAsync(cancellationToken); - } - - /// - public Task DeleteAsync(TPublicKey key, CancellationToken cancellationToken = default) - { - return this._collection.DeleteAsync(this._publicToInternalKeyMapper(key), cancellationToken); - } - - /// - public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) - { - return this._collection.DeleteBatchAsync(keys.Select(this._publicToInternalKeyMapper), cancellationToken); - } - - /// - public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) - { - return this._collection.DeleteCollectionAsync(cancellationToken); - } - - /// - public async Task GetAsync(TPublicKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - var internalRecord = await this._collection.GetAsync(this._publicToInternalKeyMapper(key), options, cancellationToken).ConfigureAwait(false); - if (internalRecord == null) - { - return default; - } - - return this._internalToPublicRecordMapper(internalRecord); - } - - /// - public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - var internalRecords = this._collection.GetBatchAsync(keys.Select(this._publicToInternalKeyMapper), options, cancellationToken); - return internalRecords.Select(this._internalToPublicRecordMapper); - } - - /// - public async Task UpsertAsync(TPublicRecord record, CancellationToken cancellationToken = default) - { - var internalRecord = this._publicToInternalRecordMapper(record); - var internalKey = await this._collection.UpsertAsync(internalRecord, cancellationToken).ConfigureAwait(false); - return this._internalToPublicKeyMapper(internalKey); - } - - /// - public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - var internalRecords = records.Select(this._publicToInternalRecordMapper); - var internalKeys = this._collection.UpsertBatchAsync(internalRecords, cancellationToken); - await foreach (var internalKey in internalKeys.ConfigureAwait(false)) - { - yield return this._internalToPublicKeyMapper(internalKey); - } - } - - /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) - { - var searchResults = await this._collection.VectorizedSearchAsync(vector, options, cancellationToken).ConfigureAwait(false); - var publicResultRecords = searchResults.Results.Select(result => new VectorSearchResult(this._internalToPublicRecordMapper(result.Record), result.Score)); - - return new VectorSearchResults(publicResultRecords) - { - TotalCount = searchResults.TotalCount, - Metadata = searchResults.Metadata, - }; - } -} +// TODO: Commented out as part of implementing LINQ-based filtering, since MappingVectorStoreRecordCollection is no longer easy/feasible. +// TODO: The user provides an expression tree accepting a TPublicRecord, but we require an expression tree accepting a TInternalRecord. +// TODO: This is something that the user must provide, and is quite advanced. + +// using System.Runtime.CompilerServices; +// using Microsoft.Extensions.VectorData; +// +// namespace Memory.VectorStoreLangchainInterop; +// +// /// +// /// Decorator class that allows conversion of keys and records between public and internal representations. +// /// +// /// +// /// This class is useful if a vector store implementation exposes keys or records in a way that is not +// /// suitable for the user of the vector store. E.g. let's say that the vector store supports Guid keys +// /// but you want to work with string keys that contain Guids. This class allows you to map between the +// /// public string Guids and the internal Guids. +// /// +// /// The type of the key that the user of this class will use. +// /// The type of the key that the internal collection exposes. +// /// The type of the record that the user of this class will use. +// /// The type of the record that the internal collection exposes. +// internal sealed class MappingVectorStoreRecordCollection : IVectorStoreRecordCollection +// where TPublicKey : notnull +// where TInternalKey : notnull +// { +// private readonly IVectorStoreRecordCollection _collection; +// private readonly Func _publicToInternalKeyMapper; +// private readonly Func _internalToPublicKeyMapper; +// private readonly Func _publicToInternalRecordMapper; +// private readonly Func _internalToPublicRecordMapper; +// +// public MappingVectorStoreRecordCollection( +// IVectorStoreRecordCollection collection, +// Func publicToInternalKeyMapper, +// Func internalToPublicKeyMapper, +// Func publicToInternalRecordMapper, +// Func internalToPublicRecordMapper) +// { +// this._collection = collection; +// this._publicToInternalKeyMapper = publicToInternalKeyMapper; +// this._internalToPublicKeyMapper = internalToPublicKeyMapper; +// this._publicToInternalRecordMapper = publicToInternalRecordMapper; +// this._internalToPublicRecordMapper = internalToPublicRecordMapper; +// } +// +// /// +// public string CollectionName => this._collection.CollectionName; +// +// /// +// public Task CollectionExistsAsync(CancellationToken cancellationToken = default) +// { +// return this._collection.CollectionExistsAsync(cancellationToken); +// } +// +// /// +// public Task CreateCollectionAsync(CancellationToken cancellationToken = default) +// { +// return this._collection.CreateCollectionAsync(cancellationToken); +// } +// +// /// +// public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) +// { +// return this._collection.CreateCollectionIfNotExistsAsync(cancellationToken); +// } +// +// /// +// public Task DeleteAsync(TPublicKey key, CancellationToken cancellationToken = default) +// { +// return this._collection.DeleteAsync(this._publicToInternalKeyMapper(key), cancellationToken); +// } +// +// /// +// public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) +// { +// return this._collection.DeleteBatchAsync(keys.Select(this._publicToInternalKeyMapper), cancellationToken); +// } +// +// /// +// public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) +// { +// return this._collection.DeleteCollectionAsync(cancellationToken); +// } +// +// /// +// public async Task GetAsync(TPublicKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) +// { +// var internalRecord = await this._collection.GetAsync(this._publicToInternalKeyMapper(key), options, cancellationToken).ConfigureAwait(false); +// if (internalRecord == null) +// { +// return default; +// } +// +// return this._internalToPublicRecordMapper(internalRecord); +// } +// +// /// +// public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) +// { +// var internalRecords = this._collection.GetBatchAsync(keys.Select(this._publicToInternalKeyMapper), options, cancellationToken); +// return internalRecords.Select(this._internalToPublicRecordMapper); +// } +// +// /// +// public async Task UpsertAsync(TPublicRecord record, CancellationToken cancellationToken = default) +// { +// var internalRecord = this._publicToInternalRecordMapper(record); +// var internalKey = await this._collection.UpsertAsync(internalRecord, cancellationToken).ConfigureAwait(false); +// return this._internalToPublicKeyMapper(internalKey); +// } +// +// /// +// public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) +// { +// var internalRecords = records.Select(this._publicToInternalRecordMapper); +// var internalKeys = this._collection.UpsertBatchAsync(internalRecords, cancellationToken); +// await foreach (var internalKey in internalKeys.ConfigureAwait(false)) +// { +// yield return this._internalToPublicKeyMapper(internalKey); +// } +// } +// +// /// +// public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) +// { +// var searchResults = await this._collection.VectorizedSearchAsync(vector, options, cancellationToken).ConfigureAwait(false); +// var publicResultRecords = searchResults.Results.Select(result => new VectorSearchResult(this._internalToPublicRecordMapper(result.Record), result.Score)); +// +// return new VectorSearchResults(publicResultRecords) +// { +// TotalCount = searchResults.TotalCount, +// Metadata = searchResults.Metadata, +// }; +// } +// } diff --git a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs index a21a2245a1c4..d0f63727b471 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs @@ -76,22 +76,23 @@ public IVectorStoreRecordCollection CreateVectorStoreRecordCollec return (collection as IVectorStoreRecordCollection)!; } - // If the user asked for a string key, we can add a decorator which converts back and forth between string and guid. - // The string that the user provides will still need to contain a valid guid, since the Langchain created collection - // uses guid keys. - // Supporting string keys like this is useful since it means you can work with the collection in the same way as with - // collections from other vector stores that support string keys. - if (typeof(TKey) == typeof(string) && typeof(TRecord) == typeof(LangchainDocument)) - { - var stringKeyCollection = new MappingVectorStoreRecordCollection, LangchainDocument>( - collection, - p => Guid.Parse(p), - i => i.ToString("D"), - p => new LangchainDocument { Key = Guid.Parse(p.Key), Content = p.Content, Source = p.Source, Embedding = p.Embedding }, - i => new LangchainDocument { Key = i.Key.ToString("D"), Content = i.Content, Source = i.Source, Embedding = i.Embedding }); - - return (stringKeyCollection as IVectorStoreRecordCollection)!; - } + // TODO: See note on MappingVectorStoreRecordCollection + // // If the user asked for a string key, we can add a decorator which converts back and forth between string and guid. + // // The string that the user provides will still need to contain a valid guid, since the Langchain created collection + // // uses guid keys. + // // Supporting string keys like this is useful since it means you can work with the collection in the same way as with + // // collections from other vector stores that support string keys. + // if (typeof(TKey) == typeof(string) && typeof(TRecord) == typeof(LangchainDocument)) + // { + // var stringKeyCollection = new MappingVectorStoreRecordCollection, LangchainDocument>( + // collection, + // p => Guid.Parse(p), + // i => i.ToString("D"), + // p => new LangchainDocument { Key = Guid.Parse(p.Key), Content = p.Content, Source = p.Source, Embedding = p.Embedding }, + // i => new LangchainDocument { Key = i.Key.ToString("D"), Content = i.Content, Source = i.Source, Embedding = i.Embedding }); + // + // return (stringKeyCollection as IVectorStoreRecordCollection)!; + // } throw new NotSupportedException("This VectorStore is only usable with Guid keys and LangchainDocument record types or string keys and LangchainDocument record types"); } diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs index c5160ac8739c..ff492ca58304 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs @@ -70,8 +70,7 @@ public async Task IngestDataAndSearchAsync(string collectionName, Func.Category), "External Definitions"); - searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 3, Filter = filter }); + searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 3, NewFilter = g => g.Category == "External Definitions" }); resultRecords = await searchResult.Results.ToListAsync(); output.WriteLine("Search string: " + searchString); diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs index a7eceb4046a9..5119881c3bda 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs @@ -70,8 +70,7 @@ public async Task ExampleAsync() // Search the collection using a vector search with pre-filtering. searchString = "What is Retrieval Augmented Generation"; searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - var filter = new VectorSearchFilter().EqualTo(nameof(Glossary.Category), "External Definitions"); - searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 3, Filter = filter }); + searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 3, NewFilter = g => g.Category == "External Definitions" }); resultRecords = await searchResult.Results.ToListAsync(); Console.WriteLine("Search string: " + searchString); diff --git a/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs b/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs index df52982104b8..f6a3d4ab6356 100644 --- a/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs +++ b/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs @@ -144,7 +144,7 @@ internal static async Task> CreateCo private sealed class VectorizedSearchWrapper(IVectorizedSearch vectorizedSearch, ITextEmbeddingGenerationService textEmbeddingGeneration) : IVectorizableTextSearch { /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { var vectorizedQuery = await textEmbeddingGeneration!.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs b/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs index 19c7cee676e8..7cf1363e3351 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs +++ b/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs @@ -53,6 +53,7 @@ internal static async Task> SearchVectorStoreAsync( return searchResultItems.First(); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Do a more complex vector search with pre-filtering. /// @@ -79,6 +80,7 @@ public async Task SearchAnInMemoryVectorStoreWithFilteringAsync() Console.WriteLine(searchResultItems.First().Record.Definition); Console.WriteLine(searchResultItems.First().Score); } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete private async Task> GetVectorStoreCollectionWithDataAsync() { diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs b/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs index 7303ddc9801a..9ca726f1fb97 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs +++ b/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs @@ -1,195 +1,198 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Runtime.CompilerServices; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Qdrant; -using Qdrant.Client; - -namespace GettingStartedWithVectorStores; - -/// -/// Example that shows that you can switch between different vector stores with the same code, in this case -/// with a vector store that doesn't use string keys. -/// This sample demonstrates one possible approach, however it is also possible to use generics -/// in the common code to achieve code reuse. -/// -public class Step4_NonStringKey_VectorStore(ITestOutputHelper output, VectorStoresFixture fixture) : BaseTest(output), IClassFixture -{ - /// - /// Here we are going to use the same code that we used in and - /// but now with an . - /// Qdrant uses Guid or ulong as the key type, but the common code works with a string key. The string keys of the records created - /// in contain numbers though, so it's possible for us to convert them to ulong. - /// In this example, we'll demonstrate how to do that. - /// - /// This example requires a Qdrant server up and running. To run a Qdrant server in a Docker container, use the following command: - /// docker run -d --name qdrant -p 6333:6333 -p 6334:6334 qdrant/qdrant:latest - /// - [Fact] - public async Task UseAQdrantVectorStoreAsync() - { - // Construct a Qdrant vector store collection. - var collection = new QdrantVectorStoreRecordCollection(new QdrantClient("localhost"), "skglossary"); - - // Wrap the collection using a decorator that allows us to expose a version that uses string keys, but internally - // we convert to and from ulong. - var stringKeyCollection = new MappingVectorStoreRecordCollection( - collection, - p => ulong.Parse(p), - i => i.ToString(), - p => new UlongGlossary { Key = ulong.Parse(p.Key), Category = p.Category, Term = p.Term, Definition = p.Definition, DefinitionEmbedding = p.DefinitionEmbedding }, - i => new Glossary { Key = i.Key.ToString("D"), Category = i.Category, Term = i.Term, Definition = i.Definition, DefinitionEmbedding = i.DefinitionEmbedding }); - - // Ingest data into the collection using the same code as we used in Step1 with the InMemory Vector Store. - await Step1_Ingest_Data.IngestDataIntoVectorStoreAsync(stringKeyCollection, fixture.TextEmbeddingGenerationService); - - // Search the vector store using the same code as we used in Step2 with the InMemory Vector Store. - var searchResultItem = await Step2_Vector_Search.SearchVectorStoreAsync( - stringKeyCollection, - "What is an Application Programming Interface?", - fixture.TextEmbeddingGenerationService); - - // Write the search result with its score to the console. - Console.WriteLine(searchResultItem.Record.Definition); - Console.WriteLine(searchResultItem.Score); - } - - /// - /// Data model that uses a ulong as the key type instead of a string. - /// - private sealed class UlongGlossary - { - [VectorStoreRecordKey] - public ulong Key { get; set; } - - [VectorStoreRecordData(IsFilterable = true)] - public string Category { get; set; } - - [VectorStoreRecordData] - public string Term { get; set; } - - [VectorStoreRecordData] - public string Definition { get; set; } - - [VectorStoreRecordVector(Dimensions: 1536)] - public ReadOnlyMemory DefinitionEmbedding { get; set; } - } - - /// - /// Simple decorator class that allows conversion of keys and records from one type to another. - /// - private sealed class MappingVectorStoreRecordCollection : IVectorStoreRecordCollection - where TPublicKey : notnull - where TInternalKey : notnull - { - private readonly IVectorStoreRecordCollection _collection; - private readonly Func _publicToInternalKeyMapper; - private readonly Func _internalToPublicKeyMapper; - private readonly Func _publicToInternalRecordMapper; - private readonly Func _internalToPublicRecordMapper; - - public MappingVectorStoreRecordCollection( - IVectorStoreRecordCollection collection, - Func publicToInternalKeyMapper, - Func internalToPublicKeyMapper, - Func publicToInternalRecordMapper, - Func internalToPublicRecordMapper) - { - this._collection = collection; - this._publicToInternalKeyMapper = publicToInternalKeyMapper; - this._internalToPublicKeyMapper = internalToPublicKeyMapper; - this._publicToInternalRecordMapper = publicToInternalRecordMapper; - this._internalToPublicRecordMapper = internalToPublicRecordMapper; - } - - /// - public string CollectionName => this._collection.CollectionName; - - /// - public Task CollectionExistsAsync(CancellationToken cancellationToken = default) - { - return this._collection.CollectionExistsAsync(cancellationToken); - } - - /// - public Task CreateCollectionAsync(CancellationToken cancellationToken = default) - { - return this._collection.CreateCollectionAsync(cancellationToken); - } - - /// - public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) - { - return this._collection.CreateCollectionIfNotExistsAsync(cancellationToken); - } - - /// - public Task DeleteAsync(TPublicKey key, CancellationToken cancellationToken = default) - { - return this._collection.DeleteAsync(this._publicToInternalKeyMapper(key), cancellationToken); - } - - /// - public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) - { - return this._collection.DeleteBatchAsync(keys.Select(this._publicToInternalKeyMapper), cancellationToken); - } - - /// - public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) - { - return this._collection.DeleteCollectionAsync(cancellationToken); - } - - /// - public async Task GetAsync(TPublicKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - var internalRecord = await this._collection.GetAsync(this._publicToInternalKeyMapper(key), options, cancellationToken).ConfigureAwait(false); - if (internalRecord == null) - { - return default; - } - - return this._internalToPublicRecordMapper(internalRecord); - } - - /// - public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - var internalRecords = this._collection.GetBatchAsync(keys.Select(this._publicToInternalKeyMapper), options, cancellationToken); - return internalRecords.Select(this._internalToPublicRecordMapper); - } - - /// - public async Task UpsertAsync(TPublicRecord record, CancellationToken cancellationToken = default) - { - var internalRecord = this._publicToInternalRecordMapper(record); - var internalKey = await this._collection.UpsertAsync(internalRecord, cancellationToken).ConfigureAwait(false); - return this._internalToPublicKeyMapper(internalKey); - } - - /// - public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - var internalRecords = records.Select(this._publicToInternalRecordMapper); - var internalKeys = this._collection.UpsertBatchAsync(internalRecords, cancellationToken); - await foreach (var internalKey in internalKeys.ConfigureAwait(false)) - { - yield return this._internalToPublicKeyMapper(internalKey); - } - } - - /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) - { - var searchResults = await this._collection.VectorizedSearchAsync(vector, options, cancellationToken).ConfigureAwait(false); - var publicResultRecords = searchResults.Results.Select(result => new VectorSearchResult(this._internalToPublicRecordMapper(result.Record), result.Score)); - - return new VectorSearchResults(publicResultRecords) - { - TotalCount = searchResults.TotalCount, - Metadata = searchResults.Metadata, - }; - } - } -} +// TODO: See note in MappingVectorStoreRecordCollection + +// using System.Runtime.CompilerServices; +// using Microsoft.Extensions.VectorData; +// using Microsoft.SemanticKernel.Connectors.Qdrant; +// using Qdrant.Client; +// +// namespace GettingStartedWithVectorStores; +// +// +// /// +// /// Example that shows that you can switch between different vector stores with the same code, in this case +// /// with a vector store that doesn't use string keys. +// /// This sample demonstrates one possible approach, however it is also possible to use generics +// /// in the common code to achieve code reuse. +// /// +// public class Step4_NonStringKey_VectorStore(ITestOutputHelper output, VectorStoresFixture fixture) : BaseTest(output), IClassFixture +// { +// /// +// /// Here we are going to use the same code that we used in and +// /// but now with an . +// /// Qdrant uses Guid or ulong as the key type, but the common code works with a string key. The string keys of the records created +// /// in contain numbers though, so it's possible for us to convert them to ulong. +// /// In this example, we'll demonstrate how to do that. +// /// +// /// This example requires a Qdrant server up and running. To run a Qdrant server in a Docker container, use the following command: +// /// docker run -d --name qdrant -p 6333:6333 -p 6334:6334 qdrant/qdrant:latest +// /// +// [Fact] +// public async Task UseAQdrantVectorStoreAsync() +// { +// // Construct a Qdrant vector store collection. +// var collection = new QdrantVectorStoreRecordCollection(new QdrantClient("localhost"), "skglossary"); +// +// // Wrap the collection using a decorator that allows us to expose a version that uses string keys, but internally +// // we convert to and from ulong. +// var stringKeyCollection = new MappingVectorStoreRecordCollection( +// collection, +// p => ulong.Parse(p), +// i => i.ToString(), +// p => new UlongGlossary { Key = ulong.Parse(p.Key), Category = p.Category, Term = p.Term, Definition = p.Definition, DefinitionEmbedding = p.DefinitionEmbedding }, +// i => new Glossary { Key = i.Key.ToString("D"), Category = i.Category, Term = i.Term, Definition = i.Definition, DefinitionEmbedding = i.DefinitionEmbedding }); +// +// // Ingest data into the collection using the same code as we used in Step1 with the InMemory Vector Store. +// await Step1_Ingest_Data.IngestDataIntoVectorStoreAsync(stringKeyCollection, fixture.TextEmbeddingGenerationService); +// +// // Search the vector store using the same code as we used in Step2 with the InMemory Vector Store. +// var searchResultItem = await Step2_Vector_Search.SearchVectorStoreAsync( +// stringKeyCollection, +// "What is an Application Programming Interface?", +// fixture.TextEmbeddingGenerationService); +// +// // Write the search result with its score to the console. +// Console.WriteLine(searchResultItem.Record.Definition); +// Console.WriteLine(searchResultItem.Score); +// } +// +// /// +// /// Data model that uses a ulong as the key type instead of a string. +// /// +// private sealed class UlongGlossary +// { +// [VectorStoreRecordKey] +// public ulong Key { get; set; } +// +// [VectorStoreRecordData(IsFilterable = true)] +// public string Category { get; set; } +// +// [VectorStoreRecordData] +// public string Term { get; set; } +// +// [VectorStoreRecordData] +// public string Definition { get; set; } +// +// [VectorStoreRecordVector(Dimensions: 1536)] +// public ReadOnlyMemory DefinitionEmbedding { get; set; } +// } +// +// /// +// /// Simple decorator class that allows conversion of keys and records from one type to another. +// /// +// private sealed class MappingVectorStoreRecordCollection : IVectorStoreRecordCollection +// where TPublicKey : notnull +// where TInternalKey : notnull +// { +// private readonly IVectorStoreRecordCollection _collection; +// private readonly Func _publicToInternalKeyMapper; +// private readonly Func _internalToPublicKeyMapper; +// private readonly Func _publicToInternalRecordMapper; +// private readonly Func _internalToPublicRecordMapper; +// +// public MappingVectorStoreRecordCollection( +// IVectorStoreRecordCollection collection, +// Func publicToInternalKeyMapper, +// Func internalToPublicKeyMapper, +// Func publicToInternalRecordMapper, +// Func internalToPublicRecordMapper) +// { +// this._collection = collection; +// this._publicToInternalKeyMapper = publicToInternalKeyMapper; +// this._internalToPublicKeyMapper = internalToPublicKeyMapper; +// this._publicToInternalRecordMapper = publicToInternalRecordMapper; +// this._internalToPublicRecordMapper = internalToPublicRecordMapper; +// } +// +// /// +// public string CollectionName => this._collection.CollectionName; +// +// /// +// public Task CollectionExistsAsync(CancellationToken cancellationToken = default) +// { +// return this._collection.CollectionExistsAsync(cancellationToken); +// } +// +// /// +// public Task CreateCollectionAsync(CancellationToken cancellationToken = default) +// { +// return this._collection.CreateCollectionAsync(cancellationToken); +// } +// +// /// +// public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) +// { +// return this._collection.CreateCollectionIfNotExistsAsync(cancellationToken); +// } +// +// /// +// public Task DeleteAsync(TPublicKey key, CancellationToken cancellationToken = default) +// { +// return this._collection.DeleteAsync(this._publicToInternalKeyMapper(key), cancellationToken); +// } +// +// /// +// public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) +// { +// return this._collection.DeleteBatchAsync(keys.Select(this._publicToInternalKeyMapper), cancellationToken); +// } +// +// /// +// public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) +// { +// return this._collection.DeleteCollectionAsync(cancellationToken); +// } +// +// /// +// public async Task GetAsync(TPublicKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) +// { +// var internalRecord = await this._collection.GetAsync(this._publicToInternalKeyMapper(key), options, cancellationToken).ConfigureAwait(false); +// if (internalRecord == null) +// { +// return default; +// } +// +// return this._internalToPublicRecordMapper(internalRecord); +// } +// +// /// +// public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) +// { +// var internalRecords = this._collection.GetBatchAsync(keys.Select(this._publicToInternalKeyMapper), options, cancellationToken); +// return internalRecords.Select(this._internalToPublicRecordMapper); +// } +// +// /// +// public async Task UpsertAsync(TPublicRecord record, CancellationToken cancellationToken = default) +// { +// var internalRecord = this._publicToInternalRecordMapper(record); +// var internalKey = await this._collection.UpsertAsync(internalRecord, cancellationToken).ConfigureAwait(false); +// return this._internalToPublicKeyMapper(internalKey); +// } +// +// /// +// public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) +// { +// var internalRecords = records.Select(this._publicToInternalRecordMapper); +// var internalKeys = this._collection.UpsertBatchAsync(internalRecords, cancellationToken); +// await foreach (var internalKey in internalKeys.ConfigureAwait(false)) +// { +// yield return this._internalToPublicKeyMapper(internalKey); +// } +// } +// +// /// +// public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) +// { +// var searchResults = await this._collection.VectorizedSearchAsync(vector, options, cancellationToken).ConfigureAwait(false); +// var publicResultRecords = searchResults.Results.Select(result => new VectorSearchResult(this._internalToPublicRecordMapper(result.Record), result.Score)); +// +// return new VectorSearchResults(publicResultRecords) +// { +// TotalCount = searchResults.TotalCount, +// Metadata = searchResults.Metadata, +// }; +// } +// } +// } diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionSearchMappingTests.cs index ae121f93bd0e..13216b9ec8be 100644 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionSearchMappingTests.cs @@ -8,6 +8,8 @@ namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -21,7 +23,7 @@ public void BuildFilterStringBuildsCorrectEqualityStringForEachFilterType(string var filter = new VectorSearchFilter().EqualTo(fieldName, fieldValue!); // Act. - var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter, new Dictionary { { fieldName, "storage_" + fieldName } }); + var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(filter, new Dictionary { { fieldName, "storage_" + fieldName } }); // Assert. Assert.Equal(expected, actual); @@ -34,7 +36,7 @@ public void BuildFilterStringBuildsCorrectTagContainsString() var filter = new VectorSearchFilter().AnyTagEqualTo("Tags", "mytag"); // Act. - var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter, new Dictionary { { "Tags", "storage_tags" } }); + var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(filter, new Dictionary { { "Tags", "storage_tags" } }); // Assert. Assert.Equal("storage_tags/any(t: t eq 'mytag')", actual); @@ -47,7 +49,7 @@ public void BuildFilterStringCombinesFilterOptions() var filter = new VectorSearchFilter().EqualTo("intField", 5).AnyTagEqualTo("Tags", "mytag"); // Act. - var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter, new Dictionary { { "Tags", "storage_tags" }, { "intField", "storage_intField" } }); + var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(filter, new Dictionary { { "Tags", "storage_tags" }, { "intField", "storage_intField" } }); // Assert. Assert.Equal("storage_intField eq 5 and storage_tags/any(t: t eq 'mytag')", actual); @@ -57,8 +59,8 @@ public void BuildFilterStringCombinesFilterOptions() public void BuildFilterStringThrowsForUnknownPropertyName() { // Act and assert. - Assert.Throws(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(new VectorSearchFilter().EqualTo("unknown", "value"), new Dictionary())); - Assert.Throws(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(new VectorSearchFilter().AnyTagEqualTo("unknown", "value"), new Dictionary())); + Assert.Throws(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(new VectorSearchFilter().EqualTo("unknown", "value"), new Dictionary())); + Assert.Throws(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(new VectorSearchFilter().AnyTagEqualTo("unknown", "value"), new Dictionary())); } public static IEnumerable DataTypeMappingOptions() diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs index 467207b29ace..eb240f91d9aa 100644 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs @@ -20,6 +20,8 @@ namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs index 6e061892d2b9..9dee844e61d2 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs @@ -9,6 +9,8 @@ namespace SemanticKernel.Connectors.AzureCosmosDBMongoDB.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Unit tests for class. /// diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs index 99815a1cee63..ab2fa157b212 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs @@ -13,6 +13,7 @@ using MongoDB.Driver; using Moq; using Xunit; +using MEVD = Microsoft.Extensions.VectorData; namespace SemanticKernel.Connectors.AzureCosmosDBMongoDB.UnitTests; @@ -643,7 +644,7 @@ public async Task VectorizedSearchThrowsExceptionWithNonExistentVectorPropertyNa this._mockMongoDatabase.Object, "collection"); - var options = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; + var options = new MEVD.VectorSearchOptions { VectorPropertyName = "non-existent-property" }; // Act & Assert await Assert.ThrowsAsync(async () => await (await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f]), options)).Results.FirstOrDefaultAsync()); diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs index 094028e516ab..f1ab2fc75f16 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs @@ -9,6 +9,8 @@ namespace SemanticKernel.Connectors.AzureCosmosDBNoSQL.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Unit tests for class. /// @@ -35,7 +37,7 @@ public void BuildSearchQueryByDefaultReturnsValidQueryDefinition() .EqualTo("TestProperty2", "test-value-2") .AnyTagEqualTo("TestProperty3", "test-value-3"); - var searchOptions = new VectorSearchOptions { Filter = filter, Skip = 5, Top = 10 }; + var searchOptions = new VectorSearchOptions { Filter = filter, Skip = 5, Top = 10 }; // Act var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery( @@ -84,7 +86,7 @@ public void BuildSearchQueryWithoutOffsetReturnsQueryDefinitionWithTopParameter( .EqualTo("TestProperty2", "test-value-2") .AnyTagEqualTo("TestProperty3", "test-value-3"); - var searchOptions = new VectorSearchOptions { Filter = filter, Top = 10 }; + var searchOptions = new VectorSearchOptions { Filter = filter, Top = 10 }; // Act var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery( @@ -129,7 +131,7 @@ public void BuildSearchQueryWithInvalidFilterThrowsException() var filter = new VectorSearchFilter().EqualTo("non-existent-property", "test-value-2"); - var searchOptions = new VectorSearchOptions { Filter = filter, Skip = 5, Top = 10 }; + var searchOptions = new VectorSearchOptions { Filter = filter, Skip = 5, Top = 10 }; // Act & Assert Assert.Throws(() => @@ -150,7 +152,7 @@ public void BuildSearchQueryWithoutFilterDoesNotContainWhereClause() var vectorPropertyName = "test_property_1"; var fields = this._storagePropertyNames.Values.ToList(); - var searchOptions = new VectorSearchOptions { Skip = 5, Top = 10 }; + var searchOptions = new VectorSearchOptions { Skip = 5, Top = 10 }; // Act var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery( @@ -211,4 +213,8 @@ public void BuildSelectQueryByDefaultReturnsValidQueryDefinition() Assert.Equal("@pk0", queryParameters[1].Name); Assert.Equal("partition_key", queryParameters[1].Value); } + +#pragma warning disable CA1812 // An internal class that is apparently never instantiated. If so, remove the code from the assembly. + private sealed class DummyType; +#pragma warning restore CA1812 } diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs index d8718eb2f2b5..24e4a2083f0b 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs @@ -612,7 +612,7 @@ public async Task VectorizedSearchWithNonExistentVectorPropertyNameThrowsExcepti this._mockDatabase.Object, "collection"); - var searchOptions = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; + var searchOptions = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; // Act & Assert await Assert.ThrowsAsync(async () => diff --git a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs index 1cf974a77c84..bbf5c9611e32 100644 --- a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs @@ -293,7 +293,7 @@ public async Task CanSearchWithVectorAsync(bool useDefinition, TKey testKe // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true }, + new() { IncludeVectors = true }, this._testCancellationToken); // Assert @@ -309,6 +309,7 @@ public async Task CanSearchWithVectorAsync(bool useDefinition, TKey testKe Assert.Equal(-1, actualResults[1].Score); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete [Theory] [InlineData(true, TestRecordKey1, TestRecordKey2, "Equality")] [InlineData(true, TestRecordIntKey1, TestRecordIntKey2, "Equality")] @@ -337,7 +338,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, TK var filter = filterType == "Equality" ? new VectorSearchFilter().EqualTo("Data", $"data {testKey2}") : new VectorSearchFilter().AnyTagEqualTo("Tags", $"tag {testKey2}"); var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true, Filter = filter, IncludeTotalCount = true }, + new() { IncludeVectors = true, Filter = filter, IncludeTotalCount = true }, this._testCancellationToken); // Assert @@ -349,6 +350,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, TK Assert.Equal($"data {testKey2}", actualResults[0].Record.Data); Assert.Equal(-1, actualResults[0].Score); } +#pragma warning restore CS0618 // Type or member is obsolete [Theory] [InlineData(DistanceFunction.CosineSimilarity, 1, -1)] @@ -389,7 +391,7 @@ public async Task CanSearchWithDifferentDistanceFunctionsAsync(string distanceFu // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true }, + new() { IncludeVectors = true }, this._testCancellationToken); // Assert @@ -430,7 +432,7 @@ public async Task CanSearchManyRecordsAsync(bool useDefinition) // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true, Top = 10, Skip = 10, IncludeTotalCount = true }, + new() { IncludeVectors = true, Top = 10, Skip = 10, IncludeTotalCount = true }, this._testCancellationToken); // Assert @@ -506,7 +508,7 @@ public async Task ItCanSearchUsingTheGenericDataModelAsync(TKey testKey1, // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory([1, 1, 1, 1]), - new VectorSearchOptions { IncludeVectors = true, VectorPropertyName = "Vector" }, + new() { IncludeVectors = true, VectorPropertyName = "Vector" }, this._testCancellationToken); // Assert diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs new file mode 100644 index 000000000000..b87183cce8c1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs @@ -0,0 +1,349 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal class AzureAISearchFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + private readonly StringBuilder _filter = new(); + + internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this._filter.Clear(); + this.Translate(lambdaExpression.Body); + return this._filter.ToString(); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant); + return; + + case MemberExpression member: + this.TranslateMember(member); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + private void TranslateBinary(BinaryExpression binary) + { + this._filter.Append('('); + this.Translate(binary.Left); + + this._filter.Append(binary.NodeType switch + { + ExpressionType.Equal => " eq ", + ExpressionType.NotEqual => " ne ", + + ExpressionType.GreaterThan => " gt ", + ExpressionType.GreaterThanOrEqual => " ge ", + ExpressionType.LessThan => " lt ", + ExpressionType.LessThanOrEqual => " le ", + + ExpressionType.AndAlso => " and ", + ExpressionType.OrElse => " or ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right); + this._filter.Append(')'); + } + + private void TranslateConstant(ConstantExpression constant) + => this.GenerateLiteral(constant.Value); + + private void GenerateLiteral(object? value) + { + // TODO: Nullable + switch (value) + { + case byte b: + this._filter.Append(b); + return; + case short s: + this._filter.Append(s); + return; + case int i: + this._filter.Append(i); + return; + case long l: + this._filter.Append(l); + return; + + case string s: + this._filter.Append('\'').Append(s.Replace("'", "''")).Append('\''); // TODO: escaping + return; + case bool b: + this._filter.Append(b ? "true" : "false"); + return; + case Guid g: + this._filter.Append('\'').Append(g.ToString()).Append('\''); + return; + + case DateTime: + case DateTimeOffset: + throw new NotImplementedException(); + + case Array: + throw new NotImplementedException(); + + case null: + this._filter.Append("null"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression) + { + switch (memberExpression) + { + case var _ when this.TryGetField(memberExpression, out var column): + this._filter.Append(column); // TODO: Escape + return; + + // Identify captured lambda variables, inline them as constants + case var _ when TryGetCapturedValue(memberExpression, out var capturedValue): + this.GenerateLiteral(capturedValue); + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over array field (r => r.Strings.Contains("foo")) + case var _ when this.TryGetField(source, out _): + this.Translate(source); + this._filter.Append("/any(t: t eq "); + this.Translate(item); + this._filter.Append(')'); + return; + + // Contains over inline enumerable + case NewArrayExpression newArray: + var elements = new object?[newArray.Expressions.Count]; + + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + { + throw new NotSupportedException("Invalid element in array"); + } + + elements[i] = elementValue; + } + + ProcessInlineEnumerable(elements, item); + return; + + // Contains over captured enumerable (we inline) + case var _ when TryGetConstant(source, out var constantEnumerable) + && constantEnumerable is IEnumerable enumerable and not string: + ProcessInlineEnumerable(enumerable, item); + return; + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + + void ProcessInlineEnumerable(IEnumerable elements, Expression item) + { + if (item.Type != typeof(string)) + { + throw new NotSupportedException("Contains over non-string arrays is not supported"); + } + + // The default delimiter for search.in() is comma or space. + // If any element contains a comma or space, we switch to using pipe as the delimiter. + // If any contains a pipe, we throw (for now). + var delimiter = ", "; + if (elements.Cast().Any(s => s.Contains(' ') || s.Contains(','))) + { + if (elements.Cast().Any(s => s.Contains('|'))) + { + throw new NotSupportedException(""); + } + + delimiter = "|"; + } + + this._filter.Append("search.in("); + this.Translate(item); + this._filter.Append(", '"); + + var isFirst = true; + foreach (var element in elements.Cast()) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._filter.Append(delimiter); + } + + this._filter.Append(element.Replace("'", "''")); + } + + this._filter.Append('\''); + + if (delimiter != ", ") + { + this._filter + .Append(", '") + .Append(delimiter) + .Append('\''); + } + + this._filter.Append(')'); + } + } + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + case ExpressionType.Not: + // Special handling for !(a == b) and !(a != b) + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._filter.Append("(not "); + this.Translate(unary.Operand); + this._filter.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetField(Expression expression, [NotNullWhen(true)] out string? field) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out field)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + field = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, out object? capturedValue) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + capturedValue = fieldInfo.GetValue(constant.Value); + return true; + } + + capturedValue = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + case var _ when TryGetCapturedValue(expression, out var capturedValue): + constantValue = capturedValue; + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs index ced35f244c5e..732b6aeae42c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs @@ -12,6 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; /// internal static class AzureAISearchVectorStoreCollectionSearchMapping { +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Build an OData filter string from the provided . /// @@ -19,10 +20,10 @@ internal static class AzureAISearchVectorStoreCollectionSearchMapping /// A mapping of data model property names to the names under which they are stored. /// The OData filter string. /// Thrown when a provided filter value is not supported. - public static string BuildFilterString(VectorSearchFilter? basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) + public static string BuildLegacyFilterString(VectorSearchFilter basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) { var filterString = string.Empty; - if (basicVectorSearchFilter?.FilterClauses is not null) + if (basicVectorSearchFilter.FilterClauses is not null) { // Map Equality clauses. var filterStrings = basicVectorSearchFilter?.FilterClauses.OfType().Select(x => @@ -60,6 +61,7 @@ public static string BuildFilterString(VectorSearchFilter? basicVectorSearchFilt return filterString; } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete /// /// Gets the name of the name under which the property with the given name is stored. diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs index bdf25bd2b8a4..c3b338b816ad 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Nodes; @@ -14,7 +15,7 @@ using Azure.Search.Documents.Indexes.Models; using Azure.Search.Documents.Models; using Microsoft.Extensions.VectorData; -using VectorData = Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; @@ -66,7 +67,7 @@ public sealed class AzureAISearchVectorStoreRecordCollection : IVectorS ]; /// The default options for vector search. - private static readonly VectorData.VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// Azure AI Search client that can be used to manage the list of indices in an Azure AI Search Service. private readonly SearchIndexClient _searchIndexClient; @@ -314,7 +315,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable reco } /// - public Task> VectorizedSearchAsync(TVector vector, VectorData.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -335,7 +336,17 @@ public Task> VectorizedSearchAsync(TVector // Configure search settings. var vectorQueries = new List(); vectorQueries.Add(new VectorizedQuery(floatVector) { KNearestNeighborsCount = internalOptions.Top, Fields = { vectorFieldName } }); - var filterString = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(internalOptions.Filter, this._propertyReader.JsonPropertyNamesMap); + +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + // Build filter object. + var filter = internalOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(legacyFilter, this._propertyReader.JsonPropertyNamesMap), + { NewFilter: Expression> newFilter } => new AzureAISearchFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + _ => null + }; +#pragma warning restore CS0618 // Build search options. var searchOptions = new SearchOptions @@ -343,9 +354,14 @@ public Task> VectorizedSearchAsync(TVector VectorSearch = new(), Size = internalOptions.Top, Skip = internalOptions.Skip, - Filter = filterString, IncludeTotalCount = internalOptions.IncludeTotalCount, }; + + if (filter is not null) + { + searchOptions.Filter = filter; + } + searchOptions.VectorSearch.Queries.AddRange(vectorQueries); // Filter out vector fields if requested. @@ -359,7 +375,7 @@ public Task> VectorizedSearchAsync(TVector } /// - public Task> VectorizableTextSearchAsync(string searchText, VectorData.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(searchText); @@ -375,7 +391,17 @@ public Task> VectorizableTextSearchAsync(string sea // Configure search settings. var vectorQueries = new List(); vectorQueries.Add(new VectorizableTextQuery(searchText) { KNearestNeighborsCount = internalOptions.Top, Fields = { vectorFieldName } }); - var filterString = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(internalOptions.Filter, this._propertyReader.JsonPropertyNamesMap); + +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + // Build filter object. + var filter = internalOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(legacyFilter, this._propertyReader.JsonPropertyNamesMap), + { NewFilter: Expression> newFilter } => new AzureAISearchFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + _ => null + }; +#pragma warning restore CS0618 // Build search options. var searchOptions = new SearchOptions @@ -383,9 +409,14 @@ public Task> VectorizableTextSearchAsync(string sea VectorSearch = new(), Size = internalOptions.Top, Skip = internalOptions.Skip, - Filter = filterString, IncludeTotalCount = internalOptions.IncludeTotalCount, }; + + if (filter is not null) + { + searchOptions.Filter = filter; + } + searchOptions.VectorSearch.Queries.AddRange(vectorQueries); // Filter out vector fields if requested. diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs new file mode 100644 index 000000000000..6c0b4e44e23b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs @@ -0,0 +1,258 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using MongoDB.Bson; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +// MongoDB query reference: https://www.mongodb.com/docs/manual/reference/operator/query +// Information specific to vector search pre-filter: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter +internal class AzureCosmosDBMongoDBFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + internal BsonDocument Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + return this.Translate(lambdaExpression.Body); + } + + private BsonDocument Translate(Expression? node) + => node switch + { + BinaryExpression + { + NodeType: ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual + or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + } binary + => this.TranslateEqualityComparison(binary), + + BinaryExpression { NodeType: ExpressionType.AndAlso or ExpressionType.OrElse } andOr + => this.TranslateAndOr(andOr), + UnaryExpression { NodeType: ExpressionType.Not } not + => this.TranslateNot(not), + + // MemberExpression is generally handled within e.g. TranslateEqualityComparison; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) + MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _) + => this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))), + + MethodCallExpression methodCall => this.TranslateMethodCall(methodCall), + + _ => throw new NotSupportedException("The following NodeType is unsupported: " + node?.NodeType) + }; + + private BsonDocument TranslateEqualityComparison(BinaryExpression binary) + { + if ((this.TryTranslateFieldAccess(binary.Left, out var storagePropertyName) && TryGetConstant(binary.Right, out var value)) + || (this.TryTranslateFieldAccess(binary.Right, out storagePropertyName) && TryGetConstant(binary.Left, out value))) + { + if (value is null) + { + throw new NotSupportedException("MongogDB does not support null checks in vector search pre-filters"); + } + + // Short form of equality (instead of $eq) + if (binary.NodeType is ExpressionType.Equal) + { + return new BsonDocument { [storagePropertyName] = BsonValue.Create(value) }; + } + + var filterOperator = binary.NodeType switch + { + ExpressionType.NotEqual => "$ne", + ExpressionType.GreaterThan => "$gt", + ExpressionType.GreaterThanOrEqual => "$gte", + ExpressionType.LessThan => "$lt", + ExpressionType.LessThanOrEqual => "$lte", + + _ => throw new UnreachableException() + }; + + return new BsonDocument { [storagePropertyName] = new BsonDocument { [filterOperator] = BsonValue.Create(value) } }; + } + + throw new NotSupportedException("Invalid equality/comparison"); + } + + private BsonDocument TranslateAndOr(BinaryExpression andOr) + { + var mongoOperator = andOr.NodeType switch + { + ExpressionType.AndAlso => "$and", + ExpressionType.OrElse => "$or", + _ => throw new UnreachableException() + }; + + var (left, right) = (this.Translate(andOr.Left), this.Translate(andOr.Right)); + + var nestedLeft = left.ElementCount == 1 && left.Elements.First() is var leftElement && leftElement.Name == mongoOperator ? (BsonArray)leftElement.Value : null; + var nestedRight = right.ElementCount == 1 && right.Elements.First() is var rightElement && rightElement.Name == mongoOperator ? (BsonArray)rightElement.Value : null; + + switch ((nestedLeft, nestedRight)) + { + case (not null, not null): + nestedLeft.AddRange(nestedRight); + return left; + case (not null, null): + nestedLeft.Add(right); + return left; + case (null, not null): + nestedRight.Insert(0, left); + return right; + case (null, null): + return new BsonDocument { [mongoOperator] = new BsonArray([left, right]) }; + } + } + + private BsonDocument TranslateNot(UnaryExpression not) + { + switch (not.Operand) + { + // Special handling for !(a == b) and !(a != b) + case BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary: + return this.TranslateEqualityComparison( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + + // Not over bool field (Filter => r => !r.Bool) + case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + return this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(false))); + } + + var operand = this.Translate(not.Operand); + + // Identify NOT over $in, transform to $nin (https://www.mongodb.com/docs/manual/reference/operator/query/nin/#mongodb-query-op.-nin) + if (operand.ElementCount == 1 && operand.Elements.First() is { Name: var fieldName, Value: BsonDocument nested } && + nested.ElementCount == 1 && nested.Elements.First() is { Name: "$in", Value: BsonArray values }) + { + return new BsonDocument { [fieldName] = new BsonDocument { ["$nin"] = values } }; + } + + throw new NotSupportedException("MongogDB does not support the NOT operator in vector search pre-filters"); + } + + private BsonDocument TranslateMethodCall(MethodCallExpression methodCall) + => methodCall switch + { + // Enumerable.Contains() + { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable) + => this.TranslateContains(source, item), + + // List.Contains() + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>) => this.TranslateContains(source, item), + + _ => throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}") + }; + + private BsonDocument TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryTranslateFieldAccess(source, out _): + throw new NotSupportedException("MongoDB does not support Contains within array fields ($elemMatch) in vector search pre-filters"); + + // Contains over inline enumerable + case NewArrayExpression newArray: + var elements = new object?[newArray.Expressions.Count]; + + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + { + throw new NotSupportedException("Invalid element in array"); + } + + elements[i] = elementValue; + } + + return ProcessInlineEnumerable(elements, item); + + // Contains over captured enumerable (we inline) + case var _ when TryGetConstant(source, out var constantEnumerable) + && constantEnumerable is IEnumerable enumerable and not string: + return ProcessInlineEnumerable(enumerable, item); + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + + BsonDocument ProcessInlineEnumerable(IEnumerable elements, Expression item) + { + if (!this.TryTranslateFieldAccess(item, out var storagePropertyName)) + { + throw new NotSupportedException("Unsupported item type in Contains"); + } + + return new BsonDocument + { + [storagePropertyName] = new BsonDocument + { + ["$in"] = new BsonArray(from object? element in elements select BsonValue.Create(element)) + } + }; + } + } + + private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + { + if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + storagePropertyName = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + // This identifies compiler-generated closure types which contain captured variables. + case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): + constantValue = fieldInfo.GetValue(constant.Value); + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs index 6e41eb7f3cb9..32377244112c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs @@ -20,6 +20,7 @@ internal static class AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping /// Returns distance function specified on vector property or default . public static string GetVectorPropertyDistanceFunction(string? distanceFunction) => !string.IsNullOrWhiteSpace(distanceFunction) ? distanceFunction! : MongoDBConstants.DefaultDistanceFunction; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Build Azure CosmosDB MongoDB filter from the provided . /// @@ -86,6 +87,7 @@ internal static class AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping return filter; } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete /// Returns search part of the search query for index kind. public static BsonDocument GetSearchQueryForHnswIndex( diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs index d54a184e5771..a5d355150da3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Reflection; using System.Runtime.CompilerServices; using System.Threading; @@ -12,6 +13,7 @@ using MongoDB.Bson; using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver; +using MEVD = Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -33,7 +35,7 @@ public sealed class AzureCosmosDBMongoDBVectorStoreRecordCollection : I private const string DocumentPropertyName = "document"; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly MEVD.VectorSearchOptions s_defaultVectorSearchOptions = new(); /// that can be used to manage the collections in Azure CosmosDB MongoDB. private readonly IMongoDatabase _mongoDatabase; @@ -244,7 +246,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable reco /// public async Task> VectorizedSearchAsync( TVector vector, - VectorSearchOptions? options = null, + MEVD.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -270,9 +272,17 @@ public async Task> VectorizedSearchAsync( var vectorPropertyName = this._storagePropertyNames[vectorProperty.DataModelPropertyName]; - var filter = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter( - searchOptions.Filter, - this._storagePropertyNames); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + var filter = searchOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter( + legacyFilter, + this._storagePropertyNames), + { NewFilter: Expression> newFilter } => new AzureCosmosDBMongoDBFilterTranslator().Translate(newFilter, this._storagePropertyNames), + _ => null + }; +#pragma warning restore CS0618 // Constructing a query to fetch "skip + top" total items // to perform skip logic locally, since skip option is not part of API. @@ -371,7 +381,7 @@ private async Task> FindAsync(FilterDefinition> EnumerateAndMapSearchResultsAsync( IAsyncCursor cursor, - VectorSearchOptions searchOptions, + MEVD.VectorSearchOptions searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken) { const string OperationName = "Aggregate"; diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs index 87aeee36355e..6dbb0d440b45 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs @@ -13,5 +13,5 @@ internal static class AzureCosmosDBNoSQLConstants /// Variable name for table in Azure CosmosDB NoSQL queries. /// Can be any string. Example: "SELECT x.Name FROM x". /// - internal const string TableQueryVariableName = "x"; + internal const char ContainerAlias = 'x'; } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLFilter.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLFilter.cs deleted file mode 100644 index 8cf6636c73e7..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLFilter.cs +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; - -namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; - -/// -/// Contains properties required to build query with filtering conditions. -/// -internal sealed class AzureCosmosDBNoSQLFilter -{ - public List? WhereClauseArguments { get; set; } - - public Dictionary? QueryParameters { get; set; } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs index a66eb5bfb719..1b0e7dcb8a7f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Text; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.VectorData; @@ -21,13 +22,13 @@ internal static class AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder /// /// Builds to get items from Azure CosmosDB NoSQL using vector search. /// - public static QueryDefinition BuildSearchQuery( + public static QueryDefinition BuildSearchQuery( TVector vector, List fields, Dictionary storagePropertyNames, string vectorPropertyName, string scorePropertyName, - VectorSearchOptions searchOptions) + VectorSearchOptions searchOptions) { Verify.NotNull(vector); @@ -36,7 +37,7 @@ public static QueryDefinition BuildSearchQuery( const string LimitVariableName = "@limit"; const string TopVariableName = "@top"; - var tableVariableName = AzureCosmosDBNoSQLConstants.TableQueryVariableName; + var tableVariableName = AzureCosmosDBNoSQLConstants.ContainerAlias; var fieldsArgument = fields.Select(field => $"{tableVariableName}.{field}"); var vectorDistanceArgument = $"VectorDistance({tableVariableName}.{vectorPropertyName}, {VectorVariableName})"; @@ -44,19 +45,22 @@ public static QueryDefinition BuildSearchQuery( var selectClauseArguments = string.Join(SelectClauseDelimiter, [.. fieldsArgument, vectorDistanceArgumentWithAlias]); - var filter = BuildSearchFilter(searchOptions.Filter, storagePropertyNames); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + // Build filter object. + var (whereClause, filterParameters) = searchOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => BuildSearchFilter(legacyFilter, storagePropertyNames), + { NewFilter: Expression> newFilter } => new AzureCosmosDBNoSqlFilterTranslator().Translate(newFilter, storagePropertyNames), + _ => (null, []) + }; +#pragma warning restore CS0618 // VectorSearchFilter is obsolete - var filterQueryParameters = filter?.QueryParameters; - var filterWhereClauseArguments = filter?.WhereClauseArguments; - var queryParameters = new Dictionary + var queryParameters = new Dictionary { [VectorVariableName] = vector }; - var whereClause = filterWhereClauseArguments is { Count: > 0 } ? - $"WHERE {string.Join(AndConditionDelimiter, filterWhereClauseArguments)}" : - string.Empty; - // If Offset is not configured, use Top parameter instead of Limit/Offset // since it's more optimized. var topArgument = searchOptions.Skip == 0 ? $"TOP {TopVariableName} " : string.Empty; @@ -66,9 +70,9 @@ public static QueryDefinition BuildSearchQuery( builder.AppendLine($"SELECT {topArgument}{selectClauseArguments}"); builder.AppendLine($"FROM {tableVariableName}"); - if (filterWhereClauseArguments is { Count: > 0 }) + if (whereClause is not null) { - builder.AppendLine($"WHERE {string.Join(AndConditionDelimiter, filterWhereClauseArguments)}"); + builder.Append("WHERE ").AppendLine(whereClause); } builder.AppendLine($"ORDER BY {vectorDistanceArgument}"); @@ -86,9 +90,9 @@ public static QueryDefinition BuildSearchQuery( var queryDefinition = new QueryDefinition(builder.ToString()); - if (filterQueryParameters is { Count: > 0 }) + if (filterParameters is { Count: > 0 }) { - queryParameters = queryParameters.Union(filterQueryParameters).ToDictionary(k => k.Key, v => v.Value); + queryParameters = queryParameters.Union(filterParameters).ToDictionary(k => k.Key, v => v.Value); } foreach (var queryParameter in queryParameters) @@ -113,7 +117,7 @@ public static QueryDefinition BuildSelectQuery( const string RecordKeyVariableName = "@rk"; const string PartitionKeyVariableName = "@pk"; - var tableVariableName = AzureCosmosDBNoSQLConstants.TableQueryVariableName; + var tableVariableName = AzureCosmosDBNoSQLConstants.ContainerAlias; var selectClauseArguments = string.Join(SelectClauseDelimiter, fields.Select(field => $"{tableVariableName}.{field}")); @@ -123,10 +127,11 @@ public static QueryDefinition BuildSelectQuery( $"({tableVariableName}.{keyStoragePropertyName} = {RecordKeyVariableName}{index} {AndConditionDelimiter} " + $"{tableVariableName}.{partitionKeyStoragePropertyName} = {PartitionKeyVariableName}{index})")); - var query = - $"SELECT {selectClauseArguments} " + - $"FROM {tableVariableName} " + - $"WHERE {whereClauseArguments} "; + var query = $""" + SELECT {selectClauseArguments} + FROM {tableVariableName} + WHERE {whereClauseArguments} + """; var queryDefinition = new QueryDefinition(query); @@ -147,44 +152,43 @@ public static QueryDefinition BuildSelectQuery( #region private - private static AzureCosmosDBNoSQLFilter? BuildSearchFilter( - VectorSearchFilter? filter, +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + private static (string WhereClause, Dictionary Parameters) BuildSearchFilter( + VectorSearchFilter filter, Dictionary storagePropertyNames) { const string EqualOperator = "="; const string ArrayContainsOperator = "ARRAY_CONTAINS"; const string ConditionValueVariableName = "@cv"; - var tableVariableName = AzureCosmosDBNoSQLConstants.TableQueryVariableName; + var tableVariableName = AzureCosmosDBNoSQLConstants.ContainerAlias; - var filterClauses = filter?.FilterClauses.ToList(); - - if (filterClauses is not { Count: > 0 }) - { - return null; - } + var filterClauses = filter.FilterClauses.ToList(); - var whereClauseArguments = new List(); - var queryParameters = new Dictionary(); + var whereClauseBuilder = new StringBuilder(); + var queryParameters = new Dictionary(); for (var i = 0; i < filterClauses.Count; i++) { + if (i > 0) + { + whereClauseBuilder.Append(" AND "); + } var filterClause = filterClauses[i]; string queryParameterName = $"{ConditionValueVariableName}{i}"; object queryParameterValue; - string whereClauseArgument; if (filterClause is EqualToFilterClause equalToFilterClause) { var propertyName = GetStoragePropertyName(equalToFilterClause.FieldName, storagePropertyNames); - whereClauseArgument = $"{tableVariableName}.{propertyName} {EqualOperator} {queryParameterName}"; + whereClauseBuilder.Append($"{tableVariableName}.{propertyName} {EqualOperator} {queryParameterName}"); queryParameterValue = equalToFilterClause.Value; } else if (filterClause is AnyTagEqualToFilterClause anyTagEqualToFilterClause) { var propertyName = GetStoragePropertyName(anyTagEqualToFilterClause.FieldName, storagePropertyNames); - whereClauseArgument = $"{ArrayContainsOperator}({tableVariableName}.{propertyName}, {queryParameterName})"; + whereClauseBuilder.Append($"{ArrayContainsOperator}({tableVariableName}.{propertyName}, {queryParameterName})"); queryParameterValue = anyTagEqualToFilterClause.Value; } else @@ -196,16 +200,12 @@ public static QueryDefinition BuildSelectQuery( nameof(AnyTagEqualToFilterClause)])}"); } - whereClauseArguments.Add(whereClauseArgument); queryParameters.Add(queryParameterName, queryParameterValue); } - return new AzureCosmosDBNoSQLFilter - { - WhereClauseArguments = whereClauseArguments, - QueryParameters = queryParameters, - }; + return (whereClauseBuilder.ToString(), queryParameters); } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete private static string GetStoragePropertyName(string propertyName, Dictionary storagePropertyNames) { diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs index 6ab9222d2a14..53463cb943b4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs @@ -69,7 +69,7 @@ public sealed class AzureCosmosDBNoSQLVectorStoreRecordCollection : ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// that can be used to manage the collections in Azure CosmosDB NoSQL. private readonly Database _database; @@ -355,7 +355,7 @@ async IAsyncEnumerable IVectorStoreRecordCollect /// public Task> VectorizedSearchAsync( TVector vector, - VectorSearchOptions? options = null, + VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "VectorizedSearch"; @@ -679,7 +679,7 @@ private async IAsyncEnumerable> MapSearchResultsAsyn IAsyncEnumerable jsonObjects, string scorePropertyName, string operationName, - VectorSearchOptions searchOptions, + VectorSearchOptions searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken) { await foreach (var jsonObject in jsonObjects.ConfigureAwait(false)) diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs new file mode 100644 index 000000000000..30019b97a1e1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs @@ -0,0 +1,284 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; + +internal class AzureCosmosDBNoSqlFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + private readonly Dictionary _parameters = new(); + + private readonly StringBuilder _sql = new(); + + internal (string WhereClause, Dictionary Parameters) Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + this._storagePropertyNames = storagePropertyNames; + + this._parameters.Clear(); + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this._sql.Clear(); + this.Translate(lambdaExpression.Body); + return (this._sql.ToString(), this._parameters); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant); + return; + + case MemberExpression member: + this.TranslateMember(member); + return; + + case NewArrayExpression newArray: + this.TranslateNewArray(newArray); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + private void TranslateBinary(BinaryExpression binary) + { + this._sql.Append('('); + this.Translate(binary.Left); + + this._sql.Append(binary.NodeType switch + { + ExpressionType.Equal => " = ", + ExpressionType.NotEqual => " <> ", + + ExpressionType.GreaterThan => " > ", + ExpressionType.GreaterThanOrEqual => " >= ", + ExpressionType.LessThan => " < ", + ExpressionType.LessThanOrEqual => " <= ", + + ExpressionType.AndAlso => " AND ", + ExpressionType.OrElse => " OR ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right); + this._sql.Append(')'); + } + + private void TranslateConstant(ConstantExpression constant) + { + // TODO: Nullable + switch (constant.Value) + { + case byte b: + this._sql.Append(b); + return; + case short s: + this._sql.Append(s); + return; + case int i: + this._sql.Append(i); + return; + case long l: + this._sql.Append(l); + return; + + case string s: + this._sql.Append('"').Append(s.Replace(@"\", @"\\").Replace("\"", "\\\"")).Append('"'); + return; + case bool b: + this._sql.Append(b ? "true" : "false"); + return; + case Guid g: + this._sql.Append('"').Append(g.ToString()).Append('"'); + return; + + case DateTime: + case DateTimeOffset: + throw new NotImplementedException(); + + case Array: + throw new NotImplementedException(); + + case null: + this._sql.Append("null"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + constant.Value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression) + { + switch (memberExpression) + { + case var _ when this.TryGetPropertyAccess(memberExpression, out var column): + this._sql.Append(AzureCosmosDBNoSQLConstants.ContainerAlias).Append("[\"").Append(column).Append("\"]"); + return; + + // Identify captured lambda variables, translate to Cosmos parameters (@foo, @bar...) + case var _ when TryGetCapturedValue(memberExpression, out var name, out var value): + // Duplicate parameter name, create a new parameter with a different name + // TODO: Share the same parameter when it references the same captured value + if (this._parameters.ContainsKey(name)) + { + var baseName = name; + var i = 0; + do + { + name = baseName + (i++); + } while (this._parameters.ContainsKey(name)); + } + + name = '@' + name; + this._parameters.Add(name, value); + this._sql.Append(name); + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + private void TranslateNewArray(NewArrayExpression newArray) + { + this._sql.Append('['); + + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (i > 0) + { + this._sql.Append(", "); + } + + this.Translate(newArray.Expressions[i]); + } + + this._sql.Append(']'); + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + this._sql.Append("ARRAY_CONTAINS("); + this.Translate(source); + this._sql.Append(", "); + this.Translate(item); + this._sql.Append(')'); + } + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + // Special handling for !(a == b) and !(a != b) + case ExpressionType.Not: + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._sql.Append("(NOT "); + this.Translate(unary.Operand); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetPropertyAccess(Expression expression, [NotNullWhen(true)] out string? column) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + column = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? value) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + name = fieldInfo.Name; + value = fieldInfo.GetValue(constant.Value); + return true; + } + + name = null; + value = null; + return false; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs index 7ecea345cb85..6b33671cef9f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs @@ -88,6 +88,7 @@ public static float ConvertScore(float score, string? distanceFunction) } } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Filter the provided records using the provided filter definition. /// @@ -95,15 +96,15 @@ public static float ConvertScore(float score, string? distanceFunction) /// The records to filter. /// The filtered records. /// Thrown when an unsupported filter clause is encountered. - public static IEnumerable FilterRecords(VectorSearchFilter? filter, IEnumerable records) + public static IEnumerable FilterRecords(VectorSearchFilter filter, IEnumerable records) { - if (filter == null) - { - return records; - } - return records.Where(record => { + if (record is null) + { + return false; + } + var result = true; // Run each filter clause against the record, and AND the results together. @@ -197,6 +198,7 @@ private static bool CheckAnyTagEqualTo(object record, AnyTagEqualToFilterClause return false; } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete /// /// Get the property info for the provided property name on the record. diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs index a2fe21e0cfc6..03fe957cca07 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs @@ -4,6 +4,7 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -29,7 +30,7 @@ public sealed class InMemoryVectorStoreRecordCollection : IVector ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// Internal storage for all of the record collections. private readonly ConcurrentDictionary> _internalCollections; @@ -210,7 +211,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record /// #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - Need to satisfy the interface which returns IAsyncEnumerable - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) #pragma warning restore CS1998 { Verify.NotNull(vector); @@ -234,13 +235,22 @@ public async Task> VectorizedSearchAsync(T throw new InvalidOperationException($"The collection does not have a vector field named '{internalOptions.VectorPropertyName}', so vector search is not possible."); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete // Filter records using the provided filter before doing the vector comparison. - var filteredRecords = InMemoryVectorStoreCollectionSearchMapping.FilterRecords(internalOptions.Filter, this.GetCollectionDictionary().Values); + var allValues = this.GetCollectionDictionary().Values.Cast(); + var filteredRecords = internalOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => InMemoryVectorStoreCollectionSearchMapping.FilterRecords(legacyFilter, allValues), + { NewFilter: Expression> newFilter } => allValues.AsQueryable().Where(newFilter), + _ => allValues + }; +#pragma warning restore CS0618 // VectorSearchFilter is obsolete // Compare each vector in the filtered results with the provided vector. - var results = filteredRecords.Select((record) => + var results = filteredRecords.Select(record => { - var vectorObject = this._vectorResolver(vectorPropertyName!, (TRecord)record); + var vectorObject = this._vectorResolver(vectorPropertyName!, record); if (vectorObject is not ReadOnlyMemory dbVector) { return null; diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs new file mode 100644 index 000000000000..202908de1c0b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs @@ -0,0 +1,258 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using MongoDB.Bson; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +// MongoDB query reference: https://www.mongodb.com/docs/manual/reference/operator/query +// Information specific to vector search pre-filter: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter +internal class MongoDBFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + internal BsonDocument Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + return this.Translate(lambdaExpression.Body); + } + + private BsonDocument Translate(Expression? node) + => node switch + { + BinaryExpression + { + NodeType: ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual + or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + } binary + => this.TranslateEqualityComparison(binary), + + BinaryExpression { NodeType: ExpressionType.AndAlso or ExpressionType.OrElse } andOr + => this.TranslateAndOr(andOr), + UnaryExpression { NodeType: ExpressionType.Not } not + => this.TranslateNot(not), + + // MemberExpression is generally handled within e.g. TranslateEqualityComparison; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) + MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _) + => this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))), + + MethodCallExpression methodCall => this.TranslateMethodCall(methodCall), + + _ => throw new NotSupportedException("The following NodeType is unsupported: " + node?.NodeType) + }; + + private BsonDocument TranslateEqualityComparison(BinaryExpression binary) + { + if ((this.TryTranslateFieldAccess(binary.Left, out var storagePropertyName) && TryGetConstant(binary.Right, out var value)) + || (this.TryTranslateFieldAccess(binary.Right, out storagePropertyName) && TryGetConstant(binary.Left, out value))) + { + if (value is null) + { + throw new NotSupportedException("MongogDB does not support null checks in vector search pre-filters"); + } + + // Short form of equality (instead of $eq) + if (binary.NodeType is ExpressionType.Equal) + { + return new BsonDocument { [storagePropertyName] = BsonValue.Create(value) }; + } + + var filterOperator = binary.NodeType switch + { + ExpressionType.NotEqual => "$ne", + ExpressionType.GreaterThan => "$gt", + ExpressionType.GreaterThanOrEqual => "$gte", + ExpressionType.LessThan => "$lt", + ExpressionType.LessThanOrEqual => "$lte", + + _ => throw new UnreachableException() + }; + + return new BsonDocument { [storagePropertyName] = new BsonDocument { [filterOperator] = BsonValue.Create(value) } }; + } + + throw new NotSupportedException("Invalid equality/comparison"); + } + + private BsonDocument TranslateAndOr(BinaryExpression andOr) + { + var mongoOperator = andOr.NodeType switch + { + ExpressionType.AndAlso => "$and", + ExpressionType.OrElse => "$or", + _ => throw new UnreachableException() + }; + + var (left, right) = (this.Translate(andOr.Left), this.Translate(andOr.Right)); + + var nestedLeft = left.ElementCount == 1 && left.Elements.First() is var leftElement && leftElement.Name == mongoOperator ? (BsonArray)leftElement.Value : null; + var nestedRight = right.ElementCount == 1 && right.Elements.First() is var rightElement && rightElement.Name == mongoOperator ? (BsonArray)rightElement.Value : null; + + switch ((nestedLeft, nestedRight)) + { + case (not null, not null): + nestedLeft.AddRange(nestedRight); + return left; + case (not null, null): + nestedLeft.Add(right); + return left; + case (null, not null): + nestedRight.Insert(0, left); + return right; + case (null, null): + return new BsonDocument { [mongoOperator] = new BsonArray([left, right]) }; + } + } + + private BsonDocument TranslateNot(UnaryExpression not) + { + switch (not.Operand) + { + // Special handling for !(a == b) and !(a != b) + case BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary: + return this.TranslateEqualityComparison( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + + // Not over bool field (Filter => r => !r.Bool) + case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + return this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(false))); + } + + var operand = this.Translate(not.Operand); + + // Identify NOT over $in, transform to $nin (https://www.mongodb.com/docs/manual/reference/operator/query/nin/#mongodb-query-op.-nin) + if (operand.ElementCount == 1 && operand.Elements.First() is { Name: var fieldName, Value: BsonDocument nested } && + nested.ElementCount == 1 && nested.Elements.First() is { Name: "$in", Value: BsonArray values }) + { + return new BsonDocument { [fieldName] = new BsonDocument { ["$nin"] = values } }; + } + + throw new NotSupportedException("MongogDB does not support the NOT operator in vector search pre-filters"); + } + + private BsonDocument TranslateMethodCall(MethodCallExpression methodCall) + => methodCall switch + { + // Enumerable.Contains() + { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable) + => this.TranslateContains(source, item), + + // List.Contains() + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>) => this.TranslateContains(source, item), + + _ => throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}") + }; + + private BsonDocument TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryTranslateFieldAccess(source, out _): + throw new NotSupportedException("MongoDB does not support Contains within array fields ($elemMatch) in vector search pre-filters"); + + // Contains over inline enumerable + case NewArrayExpression newArray: + var elements = new object?[newArray.Expressions.Count]; + + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + { + throw new NotSupportedException("Invalid element in array"); + } + + elements[i] = elementValue; + } + + return ProcessInlineEnumerable(elements, item); + + // Contains over captured enumerable (we inline) + case var _ when TryGetConstant(source, out var constantEnumerable) + && constantEnumerable is IEnumerable enumerable and not string: + return ProcessInlineEnumerable(enumerable, item); + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + + BsonDocument ProcessInlineEnumerable(IEnumerable elements, Expression item) + { + if (!this.TryTranslateFieldAccess(item, out var storagePropertyName)) + { + throw new NotSupportedException("Unsupported item type in Contains"); + } + + return new BsonDocument + { + [storagePropertyName] = new BsonDocument + { + ["$in"] = new BsonArray(from object? element in elements select BsonValue.Create(element)) + } + }; + } + } + + private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + { + if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + storagePropertyName = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + // This identifies compiler-generated closure types which contain captured variables. + case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): + constantValue = fieldInfo.GetValue(constant.Value); + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs index 931b668f535d..de47f6723b23 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs @@ -16,6 +16,7 @@ internal static class MongoDBVectorStoreCollectionSearchMapping /// Returns distance function specified on vector property or default . public static string GetVectorPropertyDistanceFunction(string? distanceFunction) => !string.IsNullOrWhiteSpace(distanceFunction) ? distanceFunction! : MongoDBConstants.DefaultDistanceFunction; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Build MongoDB filter from the provided . /// @@ -23,13 +24,13 @@ internal static class MongoDBVectorStoreCollectionSearchMapping /// A dictionary that maps from a property name to the storage name. /// Thrown when the provided filter type is unsupported. /// Thrown when property name specified in filter doesn't exist. - public static BsonDocument? BuildFilter( - VectorSearchFilter? vectorSearchFilter, + public static BsonDocument? BuildLegacyFilter( + VectorSearchFilter vectorSearchFilter, Dictionary storagePropertyNames) { const string EqualOperator = "$eq"; - var filterClauses = vectorSearchFilter?.FilterClauses.ToList(); + var filterClauses = vectorSearchFilter.FilterClauses.ToList(); if (filterClauses is not { Count: > 0 }) { @@ -82,6 +83,7 @@ internal static class MongoDBVectorStoreCollectionSearchMapping return filter; } +#pragma warning restore CS0618 /// Returns search part of the search query. public static BsonDocument GetSearchQuery( diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs index 353b3534dab9..25fc14e8196e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Reflection; using System.Runtime.CompilerServices; using System.Threading; @@ -11,6 +12,7 @@ using MongoDB.Bson; using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver; +using MEVD = Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.MongoDB; @@ -32,7 +34,7 @@ public sealed class MongoDBVectorStoreRecordCollection : IVectorStoreRe private const string DocumentPropertyName = "document"; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly MEVD.VectorSearchOptions s_defaultVectorSearchOptions = new(); /// that can be used to manage the collections in MongoDB. private readonly IMongoDatabase _mongoDatabase; @@ -247,7 +249,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable reco /// public async Task> VectorizedSearchAsync( TVector vector, - VectorSearchOptions? options = null, + MEVD.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -273,9 +275,15 @@ public async Task> VectorizedSearchAsync( var vectorPropertyName = this._storagePropertyNames[vectorProperty.DataModelPropertyName]; - var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter( - searchOptions.Filter, - this._storagePropertyNames); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + var filter = searchOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(legacyFilter, this._storagePropertyNames), + { NewFilter: Expression> newFilter } => new MongoDBFilterTranslator().Translate(newFilter, this._storagePropertyNames), + _ => null + }; +#pragma warning restore CS0618 // Constructing a query to fetch "skip + top" total items // to perform skip logic locally, since skip option is not part of API. @@ -383,7 +391,7 @@ private async Task> FindAsync(FilterDefinition> EnumerateAndMapSearchResultsAsync( IAsyncCursor cursor, - VectorSearchOptions searchOptions, + MEVD.VectorSearchOptions searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken) { const string OperationName = "Aggregate"; diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs index e02e18807d9c..5b3d511c6b08 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs @@ -12,6 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// internal static class PineconeVectorStoreCollectionSearchMapping { +#pragma warning disable CS0618 // FilterClause is obsolete /// /// Build a Pinecone from a set of filter clauses. /// @@ -59,4 +60,5 @@ public static MetadataMap BuildSearchFilter(IEnumerable? filterCla return metadataMap; } +#pragma warning restore CS0618 // FilterClause is obsolete } diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs index 8a956f53f635..8e1e8cf7aaf1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs @@ -32,7 +32,7 @@ public sealed class PineconeVectorStoreRecordCollection : IVectorStoreR private const string GetOperationName = "Get"; private const string QueryOperationName = "Query"; - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); private readonly Sdk.PineconeClient _pineconeClient; private readonly PineconeVectorStoreRecordCollectionOptions _options; @@ -246,7 +246,7 @@ await this.RunOperationAsync( } /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -259,9 +259,12 @@ public async Task> VectorizedSearchAsync(T // Resolve options and build filter clause. var internalOptions = options ?? s_defaultVectorSearchOptions; var mapperOptions = new StorageToDataModelMapperOptions { IncludeVectors = options?.IncludeVectors ?? false }; + +#pragma warning disable CS0618 // FilterClause is obsolete var filter = PineconeVectorStoreCollectionSearchMapping.BuildSearchFilter( internalOptions.Filter?.FilterClauses, this._propertyReader.StoragePropertyNamesMap); +#pragma warning restore CS0618 // Get the current index. var indexNamespace = this.GetIndexNamespace(); diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index d130d2f13b44..3c864cc6537f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using System.Linq.Expressions; using Microsoft.Extensions.VectorData; using Pgvector; @@ -124,13 +126,16 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder /// /// The schema of the table. /// The name of the table. - /// The properties of the table. + /// The property reader. /// The property which the vectors to compare are stored in. /// The vector to match. - /// The filter conditions for the query. + /// The filter conditions for the query. + /// The filter conditions for the query. /// The number of records to skip. /// Specifies whether to include vectors in the result. /// The maximum number of records to return. /// The built SQL command info. - PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, VectorSearchFilter? filter, int? skip, bool includeVectors, int limit); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, VectorSearchFilter? legacyFilter, Expression>? newFilter, int? skip, bool includeVectors, int limit); +#pragma warning restore CS0618 // VectorSearchFilter is obsolete } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs index 59aa9829c568..3fb62b667a92 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using System.Linq.Expressions; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; @@ -118,15 +120,18 @@ internal interface IPostgresVectorStoreDbClient /// Gets the nearest matches to the . /// /// The name assigned to a table of entries. - /// The properties to retrieve. - /// The property which the vectors to compare are stored in. + /// The property reader. + /// The vector property. /// The to compare the table's vector with. /// The maximum number of similarity results to return. - /// Optional conditions to filter the results. + /// Optional conditions to filter the results. + /// Optional conditions to filter the results. /// The number of entries to skip. /// If true, the vectors will be returned in the entries. /// The to monitor for cancellation requests. The default is . /// An asynchronous stream of objects that the nearest matches to the . - IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync(string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, - VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, CancellationToken cancellationToken = default); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync(string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, + VectorSearchFilter? legacyFilter = default, Expression>? newFilter = default, int? skip = default, bool includeVectors = false, CancellationToken cancellationToken = default); +#pragma warning restore CS0618 // VectorSearchFilter is obsolete } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs new file mode 100644 index 000000000000..c1cf9f3633b9 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs @@ -0,0 +1,332 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal class PostgresFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + private readonly List _parameterValues = new(); + private int _parameterIndex; + + private readonly StringBuilder _sql = new(); + + internal (string Clause, List Parameters) Translate( + IReadOnlyDictionary storagePropertyNames, + LambdaExpression lambdaExpression, + int startParamIndex) + { + this._storagePropertyNames = storagePropertyNames; + + this._parameterIndex = startParamIndex; + this._parameterValues.Clear(); + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this._sql.Clear(); + this._sql.Append("WHERE "); + this.Translate(lambdaExpression.Body); + return (this._sql.ToString(), this._parameterValues); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant); + return; + + case MemberExpression member: + this.TranslateMember(member); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + private void TranslateBinary(BinaryExpression binary) + { + // Special handling for null comparisons + switch (binary.NodeType) + { + case ExpressionType.Equal when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NOT NULL)"); + return; + + case ExpressionType.Equal when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NOT NULL)"); + return; + } + + this._sql.Append('('); + this.Translate(binary.Left); + + this._sql.Append(binary.NodeType switch + { + ExpressionType.Equal => " = ", + ExpressionType.NotEqual => " <> ", + + ExpressionType.GreaterThan => " > ", + ExpressionType.GreaterThanOrEqual => " >= ", + ExpressionType.LessThan => " < ", + ExpressionType.LessThanOrEqual => " <= ", + + ExpressionType.AndAlso => " AND ", + ExpressionType.OrElse => " OR ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right); + this._sql.Append(')'); + + static bool IsNull(Expression expression) + => expression is ConstantExpression { Value: null } + || (TryGetCapturedValue(expression, out var capturedValue) && capturedValue is null); + } + + private void TranslateConstant(ConstantExpression constant) + { + // TODO: Nullable + switch (constant.Value) + { + case byte b: + this._sql.Append(b); + return; + case short s: + this._sql.Append(s); + return; + case int i: + this._sql.Append(i); + return; + case long l: + this._sql.Append(l); + return; + + case string s: + this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); + return; + case bool b: + this._sql.Append(b ? "TRUE" : "FALSE"); + return; + case Guid g: + this._sql.Append('\'').Append(g.ToString()).Append('\''); + return; + + case DateTime: + case DateTimeOffset: + throw new NotImplementedException(); + + case Array: + throw new NotImplementedException(); + + case null: + this._sql.Append("NULL"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + constant.Value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression) + { + switch (memberExpression) + { + case var _ when this.TryGetColumn(memberExpression, out var column): + this._sql.Append('"').Append(column).Append('"'); + return; + + // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) + case var _ when TryGetCapturedValue(memberExpression, out var capturedValue): + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (capturedValue is null) + { + this._sql.Append("NULL"); + } + else + { + this._parameterValues.Add(capturedValue); + this._sql.Append('$').Append(this._parameterIndex++); + } + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryGetColumn(source, out _): + this.Translate(source); + this._sql.Append(" @> ARRAY["); + this.Translate(item); + this._sql.Append(']'); + return; + + // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) + case NewArrayExpression newArray: + this.Translate(item); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in newArray.Expressions) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.Translate(element); + } + + this._sql.Append(')'); + return; + + // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) + case var _ when TryGetCapturedValue(source, out _): + this.Translate(item); + this._sql.Append(" = ANY ("); + this.Translate(source); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + } + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + case ExpressionType.Not: + // Special handling for !(a == b) and !(a != b) + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._sql.Append("(NOT "); + this.Translate(unary.Operand); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + column = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, out object? capturedValue) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + capturedValue = fieldInfo.GetValue(constant.Value); + return true; + } + + capturedValue = null; + return false; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index d68412d31b7d..364c564703e4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Text; using Microsoft.Extensions.VectorData; using Npgsql; @@ -20,12 +21,13 @@ internal class PostgresVectorStoreCollectionSqlBuilder : IPostgresVectorStoreCol public PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName) { return new PostgresSqlCommandInfo( - commandText: @" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = $1 - AND table_type = 'BASE TABLE' - AND table_name = $2", + commandText: """ +SELECT table_name +FROM information_schema.tables +WHERE table_schema = $1 + AND table_type = 'BASE TABLE' + AND table_name = $2 +""", parameters: [ new NpgsqlParameter() { Value = schema }, new NpgsqlParameter() { Value = tableName } @@ -37,11 +39,11 @@ FROM information_schema.tables public PostgresSqlCommandInfo BuildGetTablesCommand(string schema) { return new PostgresSqlCommandInfo( - commandText: @" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = $1 - AND table_type = 'BASE TABLE'", + commandText: """ +SELECT table_name +FROM information_schema.tables +WHERE table_schema = $1 AND table_type = 'BASE TABLE' +""", parameters: [new NpgsqlParameter() { Value = schema }] ); } @@ -167,11 +169,12 @@ public PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName var valuesParams = string.Join(", ", columns.Select((k, i) => $"${i + 1}")); var columnsWithIndex = columns.Select((k, i) => (col: k, idx: i)); var updateColumnsWithParams = string.Join(", ", columnsWithIndex.Where(c => c.col != keyColumn).Select(c => $"\"{c.col}\"=${c.idx + 1}")); - var commandText = $@" - INSERT INTO {schema}.""{tableName}"" ({columnNames}) - VALUES({valuesParams}) - ON CONFLICT (""{keyColumn}"") - DO UPDATE SET {updateColumnsWithParams};"; + var commandText = $""" +INSERT INTO {schema}."{tableName}" ({columnNames}) +VALUES ({valuesParams}) +ON CONFLICT ("{keyColumn}") +DO UPDATE SET {updateColumnsWithParams}; +"""; return new PostgresSqlCommandInfo(commandText) { @@ -204,11 +207,12 @@ public PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tabl var updateSetClause = string.Join(", ", columns.Where(c => c != keyColumn).Select(c => $"\"{c}\" = EXCLUDED.\"{c}\"")); // Generate the SQL command - var commandText = $@" - INSERT INTO {schema}.""{tableName}"" ({columnNames}) - VALUES {valuesRows} - ON CONFLICT (""{keyColumn}"") - DO UPDATE SET {updateSetClause}; "; + var commandText = $""" +INSERT INTO {schema}."{tableName}" ({columnNames}) +VALUES {valuesRows} +ON CONFLICT ("{keyColumn}") +DO UPDATE SET {updateSetClause}; +"""; // Generate the parameters var parameters = new List(); @@ -262,10 +266,11 @@ public PostgresSqlCommandInfo BuildGetCommand(string schema, string tableN var queryColumnList = string.Join(", ", queryColumns); return new PostgresSqlCommandInfo( - commandText: $@" - SELECT {queryColumnList} - FROM {schema}.""{tableName}"" - WHERE ""{keyColumn}"" = ${1};", + commandText: $""" +SELECT {queryColumnList} +FROM {schema}."{tableName}" +WHERE "{keyColumn}" = ${1}; +""", parameters: [new NpgsqlParameter() { Value = key }] ); } @@ -294,10 +299,11 @@ public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string t var keyParams = string.Join(", ", keys.Select((k, i) => $"${i + 1}")); // Generate the SQL command - var commandText = $@" - SELECT {columnNames} - FROM {schema}.""{tableName}"" - WHERE ""{keyColumn}"" = ANY($1);"; + var commandText = $""" +SELECT {columnNames} +FROM {schema}."{tableName}" +WHERE "{keyColumn}" = ANY($1); +"""; return new PostgresSqlCommandInfo(commandText) { @@ -309,9 +315,10 @@ public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string t public PostgresSqlCommandInfo BuildDeleteCommand(string schema, string tableName, string keyColumn, TKey key) { return new PostgresSqlCommandInfo( - commandText: $@" - DELETE FROM {schema}.""{tableName}"" - WHERE ""{keyColumn}"" = ${1};", + commandText: $""" +DELETE FROM {schema}."{tableName}" +WHERE "{keyColumn}" = ${1}; +""", parameters: [new NpgsqlParameter() { Value = key }] ); } @@ -333,9 +340,10 @@ public PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, strin } } - var commandText = $@" - DELETE FROM {schema}.""{tableName}"" - WHERE ""{keyColumn}"" = ANY($1);"; + var commandText = $""" +DELETE FROM {schema}."{tableName}" +WHERE "{keyColumn}" = ANY($1); +"""; return new PostgresSqlCommandInfo(commandText) { @@ -343,13 +351,14 @@ DELETE FROM {schema}.""{tableName}"" }; } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// - public PostgresSqlCommandInfo BuildGetNearestMatchCommand( - string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, - VectorSearchFilter? filter, int? skip, bool includeVectors, int limit) + public PostgresSqlCommandInfo BuildGetNearestMatchCommand( + string schema, string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, + VectorSearchFilter? legacyFilter, Expression>? newFilter, int? skip, bool includeVectors, int limit) { var columns = string.Join(" ,", - properties + propertyReader.RecordDefinition.Properties .Select(property => property.StoragePropertyName ?? property.DataModelPropertyName) .Select(column => $"\"{column}\"") ); @@ -367,14 +376,24 @@ public PostgresSqlCommandInfo BuildGetNearestMatchCommand( }; var vectorColumn = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + // Start where clause params at 2, vector takes param 1. - var where = GenerateWhereClause(schema, tableName, properties, filter, startParamIndex: 2); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + var (where, parameters) = (oldFilter: legacyFilter, newFilter) switch + { + (not null, not null) => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + (not null, null) => GenerateLegacyFilterWhereClause(schema, tableName, propertyReader.RecordDefinition.Properties, legacyFilter, startParamIndex: 2), + (null, not null) => new PostgresFilterTranslator().Translate(propertyReader.StoragePropertyNamesMap, newFilter, startParamIndex: 2), + _ => (Clause: string.Empty, Parameters: []) + }; +#pragma warning restore CS0618 // VectorSearchFilter is obsolete - var commandText = $@" - SELECT {columns}, ""{vectorColumn}"" {distanceOp} $1 AS ""{PostgresConstants.DistanceColumnName}"" - FROM {schema}.""{tableName}"" {where.Clause} - ORDER BY {PostgresConstants.DistanceColumnName} - LIMIT {limit}"; + var commandText = $""" +SELECT {columns}, "{vectorColumn}" {distanceOp} $1 AS "{PostgresConstants.DistanceColumnName}" +FROM {schema}."{tableName}" {where} +ORDER BY {PostgresConstants.DistanceColumnName} +LIMIT {limit} +"""; if (skip.HasValue) { commandText += $" OFFSET {skip.Value}"; } @@ -383,9 +402,10 @@ ORDER BY {PostgresConstants.DistanceColumnName} // Instead we'll wrap the query in a subquery and modify the distance in the outer query. if (vectorProperty.DistanceFunction == DistanceFunction.CosineSimilarity) { - commandText = $@" - SELECT {columns}, 1 - ""{PostgresConstants.DistanceColumnName}"" AS ""{PostgresConstants.DistanceColumnName}"" - FROM ({commandText}) AS subquery"; + commandText = $""" +SELECT {columns}, 1 - "{PostgresConstants.DistanceColumnName}" AS "{PostgresConstants.DistanceColumnName}" +FROM ({commandText}) AS subquery +"""; } // For inner product, we need to take -1 * inner product. @@ -393,28 +413,27 @@ ORDER BY {PostgresConstants.DistanceColumnName} // Instead we'll wrap the query in a subquery and modify the distance in the outer query. if (vectorProperty.DistanceFunction == DistanceFunction.DotProductSimilarity) { - commandText = $@" - SELECT {columns}, -1 * ""{PostgresConstants.DistanceColumnName}"" AS ""{PostgresConstants.DistanceColumnName}"" - FROM ({commandText}) AS subquery"; + commandText = $""" +SELECT {columns}, -1 * "{PostgresConstants.DistanceColumnName}" AS "{PostgresConstants.DistanceColumnName}" +FROM ({commandText}) AS subquery +"""; } return new PostgresSqlCommandInfo(commandText) { - Parameters = [new NpgsqlParameter() { Value = vectorValue }, .. where.Parameters.Select(p => new NpgsqlParameter() { Value = p })] + Parameters = [new NpgsqlParameter { Value = vectorValue }, .. parameters.Select(p => new NpgsqlParameter { Value = p })] }; } - - internal static (string Clause, List Parameters) GenerateWhereClause(string schema, string tableName, IReadOnlyList properties, VectorSearchFilter? filter, int startParamIndex) +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + internal static (string Clause, List Parameters) GenerateLegacyFilterWhereClause(string schema, string tableName, IReadOnlyList properties, VectorSearchFilter legacyFilter, int startParamIndex) { - if (filter == null) { return (string.Empty, new List()); } - var whereClause = new StringBuilder("WHERE "); var filterClauses = new List(); var parameters = new List(); var paramIndex = startParamIndex; - foreach (var filterClause in filter.FilterClauses) + foreach (var filterClause in legacyFilter.FilterClauses) { if (filterClause is EqualToFilterClause equalTo) { @@ -450,4 +469,5 @@ internal static (string Clause, List Parameters) GenerateWhereClause(str whereClause.Append(string.Join(" AND ", filterClauses)); return (whereClause.ToString(), parameters); } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 5ef18cc88fdf..b97b24708b25 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -172,21 +174,23 @@ public async Task DeleteAsync(string tableName, string keyColumn, TKey key } /// - public async IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync( - string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, - VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + public async IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync( + string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, + VectorSearchFilter? legacyFilter = default, Expression>? newFilter = default, int? skip = default, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) +#pragma warning restore CS0618 // VectorSearchFilter is obsolete { NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { - var commandInfo = this._sqlBuilder.BuildGetNearestMatchCommand(this._schema, tableName, properties, vectorProperty, vectorValue, filter, skip, includeVectors, limit); + var commandInfo = this._sqlBuilder.BuildGetNearestMatchCommand(this._schema, tableName, propertyReader, vectorProperty, vectorValue, legacyFilter, newFilter, skip, includeVectors, limit); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { var distance = dataReader.GetDouble(dataReader.GetOrdinal(PostgresConstants.DistanceColumnName)); - yield return (Row: this.GetRecord(dataReader, properties, includeVectors), Distance: distance); + yield return (Row: this.GetRecord(dataReader, propertyReader.RecordDefinition.Properties, includeVectors), Distance: distance); } } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index de4a432ea48c..fd85896a46d4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -37,7 +37,7 @@ public sealed class PostgresVectorStoreRecordCollection : IVector private readonly IVectorStoreRecordMapper> _mapper; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// /// Initializes a new instance of the class. @@ -250,7 +250,7 @@ public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellat } /// - public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "VectorizedSearch"; @@ -261,7 +261,7 @@ public Task> VectorizedSearchAsync(TVector if (!PostgresConstants.SupportedVectorTypes.Contains(vectorType)) { throw new NotSupportedException( - $"The provided vector type {vectorType.FullName} is not supported by the SQLite connector. " + + $"The provided vector type {vectorType.FullName} is not supported by the PostgreSQL connector. " + $"Supported types are: {string.Join(", ", PostgresConstants.SupportedVectorTypes.Select(l => l.FullName))}"); } @@ -285,11 +285,14 @@ public Task> VectorizedSearchAsync(TVector { var results = this._client.GetNearestMatchesAsync( this.CollectionName, - this._propertyReader.RecordDefinition.Properties, + this._propertyReader, vectorProperty, pgVector, searchOptions.Top, +#pragma warning disable CS0618 // VectorSearchFilter is obsolete searchOptions.Filter, +#pragma warning restore CS0618 // VectorSearchFilter is obsolete + searchOptions.NewFilter, searchOptions.Skip, searchOptions.IncludeVectors, cancellationToken) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index 0b36f2003bf5..5e8509236e31 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -143,7 +143,7 @@ public static (string PgType, bool IsNullable) GetPostgresTypeName(Type property // Handle enumerables if (VectorStoreRecordPropertyVerification.IsSupportedEnumerableType(propertyType)) { - Type elementType = propertyType.GetGenericArguments()[0]; + Type elementType = propertyType.IsArray ? propertyType.GetElementType()! : propertyType.GetGenericArguments()[0]; var underlyingPgType = GetPostgresTypeName(elementType); return (underlyingPgType.PgType + "[]", true); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs new file mode 100644 index 000000000000..a918883aa054 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs @@ -0,0 +1,382 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using Google.Protobuf.Collections; +using Qdrant.Client.Grpc; +using Range = Qdrant.Client.Grpc.Range; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +internal class QdrantFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + internal Filter Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + return this.Translate(lambdaExpression.Body); + } + + private Filter Translate(Expression? node) + => node switch + { + BinaryExpression { NodeType: ExpressionType.Equal } equal => this.TranslateEqual(equal.Left, equal.Right), + BinaryExpression { NodeType: ExpressionType.NotEqual } notEqual => this.TranslateEqual(notEqual.Left, notEqual.Right, negated: true), + + BinaryExpression + { + NodeType: ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + } comparison + => this.TranslateComparison(comparison), + + BinaryExpression { NodeType: ExpressionType.AndAlso } andAlso => this.TranslateAndAlso(andAlso.Left, andAlso.Right), + BinaryExpression { NodeType: ExpressionType.OrElse } orElse => this.TranslateOrElse(orElse.Left, orElse.Right), + UnaryExpression { NodeType: ExpressionType.Not } not => this.TranslateNot(not.Operand), + + // MemberExpression is generally handled within e.g. TranslateEqual; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) + MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _) + => this.TranslateEqual(member, Expression.Constant(true)), + + MethodCallExpression methodCall => this.TranslateMethodCall(methodCall), + + _ => throw new NotSupportedException("Qdrant does not support the following NodeType in filters: " + node?.NodeType) + }; + + private Filter TranslateEqual(Expression left, Expression right, bool negated = false) + { + return TryProcessEqual(left, right, out var result) + ? result + : TryProcessEqual(right, left, out result) + ? result + : throw new NotSupportedException("Equality expression not supported by Qdrant"); + + bool TryProcessEqual(Expression first, Expression second, [NotNullWhen(true)] out Filter? result) + { + // TODO: Nullable + if (this.TryTranslateFieldAccess(first, out var storagePropertyName) + && TryGetConstant(second, out var constantValue)) + { + var condition = constantValue is null + ? new Condition { IsNull = new() { Key = storagePropertyName } } + : new Condition + { + Field = new FieldCondition + { + Key = storagePropertyName, + Match = constantValue switch + { + string stringValue => new Match { Keyword = stringValue }, + int intValue => new Match { Integer = intValue }, + long longValue => new Match { Integer = longValue }, + bool boolValue => new Match { Boolean = boolValue }, + + _ => throw new InvalidOperationException($"Unsupported filter value type '{constantValue.GetType().Name}'.") + } + } + }; + + result = new Filter(); + if (negated) + { + result.MustNot.Add(condition); + } + else + { + result.Must.Add(condition); + } + return true; + } + + result = null; + return false; + } + } + + private Filter TranslateComparison(BinaryExpression comparison) + { + return TryProcessComparison(comparison.Left, comparison.Right, out var result) + ? result + : TryProcessComparison(comparison.Right, comparison.Left, out result) + ? result + : throw new NotSupportedException("Comparison expression not supported by Qdrant"); + + bool TryProcessComparison(Expression first, Expression second, [NotNullWhen(true)] out Filter? result) + { + // TODO: Nullable + if (this.TryTranslateFieldAccess(first, out var storagePropertyName) + && TryGetConstant(second, out var constantValue)) + { + double doubleConstantValue = constantValue switch + { + double d => d, + int i => i, + long l => l, + _ => throw new NotSupportedException($"Can't perform comparison on type '{constantValue?.GetType().Name}', which isn't convertible to double") + }; + + result = new Filter(); + result.Must.Add(new Condition + { + Field = new FieldCondition + { + Key = storagePropertyName, + Range = comparison.NodeType switch + { + ExpressionType.GreaterThan => new Range { Gt = doubleConstantValue }, + ExpressionType.GreaterThanOrEqual => new Range { Gte = doubleConstantValue }, + ExpressionType.LessThan => new Range { Lt = doubleConstantValue }, + ExpressionType.LessThanOrEqual => new Range { Lte = doubleConstantValue }, + + _ => throw new InvalidOperationException("Unreachable") + } + } + }); + return true; + } + + result = null; + return false; + } + } + + #region Logical operators + + private Filter TranslateAndAlso(Expression left, Expression right) + { + var leftFilter = this.Translate(left); + var rightFilter = this.Translate(right); + + // As long as there are only AND conditions (Must or MustNot), we can simply combine both filters into a single flat one. + // The moment there's a Should, things become a bit more complicated: + // 1. If a side contains both a Should and a Must/MustNot, it must be pushed down. + // 2. Otherwise, if the left's Should is empty, and the right side is only Should, we can just copy the right Should into the left's. + // 3. Finally, if both sides have a Should, we push down the right side and put the result in the left's Must. + if (leftFilter.Should.Count > 0 && (leftFilter.Must.Count > 0 || leftFilter.MustNot.Count > 0)) + { + leftFilter = new Filter { Must = { new Condition { Filter = leftFilter } } }; + } + + if (rightFilter.Should.Count > 0 && (rightFilter.Must.Count > 0 || rightFilter.MustNot.Count > 0)) + { + rightFilter = new Filter { Must = { new Condition { Filter = rightFilter } } }; + } + + if (rightFilter.Should.Count > 0) + { + if (leftFilter.Should.Count == 0) + { + leftFilter.Should.AddRange(rightFilter.Should); + } + else + { + rightFilter = new Filter { Must = { new Condition { Filter = rightFilter } } }; + } + } + + leftFilter.Must.AddRange(rightFilter.Must); + leftFilter.MustNot.AddRange(rightFilter.MustNot); + + return leftFilter; + } + + private Filter TranslateOrElse(Expression left, Expression right) + { + var leftFilter = this.Translate(left); + var rightFilter = this.Translate(right); + + var result = new Filter(); + result.Should.AddRange(GetShouldConditions(leftFilter)); + result.Should.AddRange(GetShouldConditions(rightFilter)); + return result; + + static RepeatedField GetShouldConditions(Filter filter) + => filter switch + { + { Must.Count: 0, MustNot.Count: 0 } => filter.Should, + { Must.Count: 1, MustNot.Count: 0, Should.Count: 0 } => [filter.Must[0]], + { Must.Count: 0, MustNot.Count: 1, Should.Count: 0 } => [filter.MustNot[0]], + + _ => [new Condition { Filter = filter }] + }; + } + + private Filter TranslateNot(Expression expression) + { + // Special handling for !(a == b) and !(a != b) + if (expression is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + return this.TranslateEqual(binary.Left, binary.Right, negated: binary.NodeType is ExpressionType.Equal); + } + + var filter = this.Translate(expression); + + switch (filter) + { + case { Must.Count: 1, MustNot.Count: 0, Should.Count: 0 }: + filter.MustNot.Add(filter.Must[0]); + filter.Must.RemoveAt(0); + return filter; + + case { Must.Count: 0, MustNot.Count: 1, Should.Count: 0 }: + filter.Must.Add(filter.MustNot[0]); + filter.MustNot.RemoveAt(0); + return filter; + + case { Must.Count: 0, MustNot.Count: 0, Should.Count: > 0 }: + filter.MustNot.AddRange(filter.Should); + filter.Should.Clear(); + return filter; + + default: + return new Filter { MustNot = { new Condition { Filter = filter } } }; + } + } + + #endregion Logical operators + + private Filter TranslateMethodCall(MethodCallExpression methodCall) + => methodCall switch + { + // Enumerable.Contains() + { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable) + => this.TranslateContains(source, item), + + // List.Contains() + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>) + => this.TranslateContains(source, item), + + _ => throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}") + }; + + private Filter TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over field enumerable + case var _ when this.TryTranslateFieldAccess(source, out _): + // Oddly, in Qdrant, tag list contains is handled using a Match condition, just like equality. + return this.TranslateEqual(source, item); + + // Contains over inline enumerable + case NewArrayExpression newArray: + var elements = new object?[newArray.Expressions.Count]; + + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + { + throw new NotSupportedException("Invalid element in array"); + } + + elements[i] = elementValue; + } + + return ProcessInlineEnumerable(elements, item); + + // Contains over captured enumerable (we inline) + case var _ when TryGetConstant(source, out var constantEnumerable) + && constantEnumerable is IEnumerable enumerable and not string: + return ProcessInlineEnumerable(enumerable, item); + + default: + throw new NotSupportedException("Unsupported Contains"); + } + + Filter ProcessInlineEnumerable(IEnumerable elements, Expression item) + { + if (!this.TryTranslateFieldAccess(item, out var storagePropertyName)) + { + throw new NotSupportedException("Unsupported item type in Contains"); + } + + if (item.Type == typeof(string)) + { + var strings = new RepeatedStrings(); + + foreach (var value in elements) + { + strings.Strings.Add(value is string or null + ? (string?)value + : throw new ArgumentException("Non-string element in string Contains array")); + } + + return new Filter { Must = { new Condition { Field = new FieldCondition { Key = storagePropertyName, Match = new Match { Keywords = strings } } } } }; + } + + if (item.Type == typeof(int)) + { + var ints = new RepeatedIntegers(); + + foreach (var value in elements) + { + ints.Integers.Add(value is int intValue + ? intValue + : throw new ArgumentException("Non-int element in string Contains array")); + } + + return new Filter { Must = { new Condition { Field = new FieldCondition { Key = storagePropertyName, Match = new Match { Integers = ints } } } } }; + } + + throw new NotSupportedException("Contains only supported over array of ints or strings"); + } + } + + private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + { + if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + storagePropertyName = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + // This identifies compiler-generated closure types which contain captured variables. + case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): + constantValue = fieldInfo.GetValue(constant.Value); + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs index f2b9c91179e9..ec14ef585dfb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs @@ -12,6 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// internal static class QdrantVectorStoreCollectionSearchMapping { +#pragma warning disable CS0618 // Type or member is obsolete /// /// Build a Qdrant from the provided . /// @@ -19,16 +20,10 @@ internal static class QdrantVectorStoreCollectionSearchMapping /// A mapping of data model property names to the names under which they are stored. /// The Qdrant . /// Thrown when the provided filter contains unsupported types, values or unknown properties. - public static Filter BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) + public static Filter BuildFromLegacyFilter(VectorSearchFilter basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) { var filter = new Filter(); - // Return an empty filter if no filter clauses are provided. - if (basicVectorSearchFilter?.FilterClauses is null) - { - return filter; - } - foreach (var filterClause in basicVectorSearchFilter.FilterClauses) { string fieldName; @@ -72,6 +67,7 @@ public static Filter BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IR return filter; } +#pragma warning restore CS0618 // Type or member is obsolete /// /// Map the given to a . diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs index 7dd77b76baff..e51ae549818a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -29,7 +30,7 @@ public sealed class QdrantVectorStoreRecordCollection : IVectorStoreRec ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// The name of this database for telemetry purposes. private const string DatabaseName = "Qdrant"; @@ -457,7 +458,7 @@ private async IAsyncEnumerable GetBatchByPointIdAsync( } /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -473,8 +474,16 @@ public async Task> VectorizedSearchAsync(T var internalOptions = options ?? s_defaultVectorSearchOptions; +#pragma warning disable CS0618 // Type or member is obsolete // Build filter object. - var filter = QdrantVectorStoreCollectionSearchMapping.BuildFilter(internalOptions.Filter, this._propertyReader.StoragePropertyNamesMap); + var filter = internalOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(legacyFilter, this._propertyReader.StoragePropertyNamesMap), + { NewFilter: Expression> newFilter } => new QdrantFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + _ => new Filter() + }; +#pragma warning restore CS0618 // Type or member is obsolete // Specify the vector name if named vectors are used. string? vectorName = null; diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs new file mode 100644 index 000000000000..12a28b050c15 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +internal class RedisFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + private readonly StringBuilder _filter = new(); + + internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this._filter.Clear(); + this.Translate(lambdaExpression.Body); + return this._filter.ToString(); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression + { + NodeType: ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual + or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + } binary: + this.TranslateEqualityComparison(binary); + return; + + case BinaryExpression { NodeType: ExpressionType.AndAlso } andAlso: + // https://redis.io/docs/latest/develop/interact/search-and-query/query/combined/#and + this._filter.Append('('); + this.Translate(andAlso.Left); + this._filter.Append(' '); + this.Translate(andAlso.Right); + this._filter.Append(')'); + return; + + case BinaryExpression { NodeType: ExpressionType.OrElse } orElse: + // https://redis.io/docs/latest/develop/interact/search-and-query/query/combined/#or + this._filter.Append('('); + this.Translate(orElse.Left); + this._filter.Append(" | "); + this.Translate(orElse.Right); + this._filter.Append(')'); + return; + + case UnaryExpression { NodeType: ExpressionType.Not } not: + this.TranslateNot(not.Operand); + return; + + // MemberExpression is generally handled within e.g. TranslateEqual; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) + case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + { + this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))); + return; + } + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + default: + throw new NotSupportedException("Redis does not support the following NodeType in filters: " + node?.NodeType); + } + } + + private void TranslateEqualityComparison(BinaryExpression binary) + { + if (!TryProcessEqualityComparison(binary.Left, binary.Right) && !TryProcessEqualityComparison(binary.Right, binary.Left)) + { + throw new NotSupportedException("Binary expression not supported by Redis"); + } + + bool TryProcessEqualityComparison(Expression first, Expression second) + { + // TODO: Nullable + if (this.TryTranslateFieldAccess(first, out var storagePropertyName) + && TryGetConstant(second, out var constantValue)) + { + // Numeric negation has a special syntax (!=), for the rest we nest in a NOT + if (binary.NodeType is ExpressionType.NotEqual && constantValue is not int or long or float or double) + { + this.TranslateNot(Expression.Equal(first, second)); + return true; + } + + // https://redis.io/docs/latest/develop/interact/search-and-query/query/exact-match + this._filter.Append('@').Append(storagePropertyName); + + this._filter.Append( + binary.NodeType switch + { + ExpressionType.Equal when constantValue is int or long or float or double => $" == {constantValue}", + ExpressionType.Equal when constantValue is string stringValue +#if NETSTANDARD2_0 + => $$""":{"{{stringValue.Replace("\"", "\"\"")}}"}""", +#else + => $$""":{"{{stringValue.Replace("\"", "\\\"", StringComparison.Ordinal)}}"}""", +#endif + ExpressionType.Equal when constantValue is null => throw new NotSupportedException("Null value type not supported"), // TODO + + ExpressionType.NotEqual when constantValue is int or long or float or double => $" != {constantValue}", + ExpressionType.NotEqual => throw new InvalidOperationException("Unreachable"), // Handled above + + ExpressionType.GreaterThan => $" > {constantValue}", + ExpressionType.GreaterThanOrEqual => $" >= {constantValue}", + ExpressionType.LessThan => $" < {constantValue}", + ExpressionType.LessThanOrEqual => $" <= {constantValue}", + + _ => throw new InvalidOperationException("Unsupported equality/comparison") + }); + + return true; + } + + return false; + } + } + + private void TranslateNot(Expression expression) + { + // https://redis.io/docs/latest/develop/interact/search-and-query/query/combined/#not + this._filter.Append("(-"); + this.Translate(expression); + this._filter.Append(')'); + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + // Contains over tag field + if (this.TryTranslateFieldAccess(source, out var storagePropertyName) + && TryGetConstant(item, out var itemConstant) + && itemConstant is string stringConstant) + { + this._filter + .Append('@') + .Append(storagePropertyName) + .Append(":{") + .Append(stringConstant) + .Append('}'); + return; + } + + throw new NotSupportedException("Contains supported only over tag field"); + } + + private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + { + if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + storagePropertyName = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + // This identifies compiler-generated closure types which contain captured variables. + case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): + constantValue = fieldInfo.GetValue(constant.Value); + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs index 41971c5adb86..fb565cad17d8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs @@ -61,7 +61,7 @@ public sealed class RedisHashSetVectorStoreRecordCollection : IVectorSt ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// The Redis database to read/write records from. private readonly IDatabase _database; @@ -328,7 +328,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable reco } /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs index f8afa3ed875e..0d5f74d0821a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs @@ -44,7 +44,7 @@ public sealed class RedisJsonVectorStoreRecordCollection : IVectorStore ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// The Redis database to read/write records from. private readonly IDatabase _database; @@ -374,7 +374,7 @@ await this.RunOperationAsync( } /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs index d6603ca1634c..ea78a9e798c0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.InteropServices; using Microsoft.Extensions.VectorData; using NRedisStack.Search; @@ -50,14 +51,24 @@ public static byte[] ValidateVectorAndConvertToBytes(TVector vector, st /// The name of the first vector property in the data model. /// The set of fields to limit the results to. Null for all. /// The . - public static Query BuildQuery(byte[] vectorBytes, VectorSearchOptions options, IReadOnlyDictionary storagePropertyNames, string firstVectorPropertyName, string[]? selectFields) + public static Query BuildQuery(byte[] vectorBytes, VectorSearchOptions options, IReadOnlyDictionary storagePropertyNames, string firstVectorPropertyName, string[]? selectFields) { // Resolve options. var vectorPropertyName = ResolveVectorFieldName(options.VectorPropertyName, storagePropertyNames, firstVectorPropertyName); // Build search query. var redisLimit = options.Top + options.Skip; - var filter = RedisVectorStoreCollectionSearchMapping.BuildFilter(options.Filter, storagePropertyNames); + +#pragma warning disable CS0618 // Type or member is obsolete + var filter = options switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => BuildLegacyFilter(legacyFilter, storagePropertyNames), + { NewFilter: Expression> newFilter } => new RedisFilterTranslator().Translate(newFilter, storagePropertyNames), + _ => "*" + }; +#pragma warning restore CS0618 // Type or member is obsolete + var query = new Query($"{filter}=>[KNN {redisLimit} @{vectorPropertyName} $embedding AS vector_score]") .AddParam("embedding", vectorBytes) .SetSortBy("vector_score") @@ -80,13 +91,9 @@ public static Query BuildQuery(byte[] vectorBytes, VectorSearchOptions options, /// A mapping of data model property names to the names under which they are stored. /// The Redis filter string. /// Thrown when a provided filter value is not supported. - public static string BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) +#pragma warning disable CS0618 // Type or member is obsolete + public static string BuildLegacyFilter(VectorSearchFilter basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) { - if (basicVectorSearchFilter == null) - { - return "*"; - } - var filterClauses = basicVectorSearchFilter.FilterClauses.Select(clause => { if (clause is EqualToFilterClause equalityFilterClause) @@ -116,6 +123,7 @@ public static string BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IR return $"({string.Join(" ", filterClauses)})"; } +#pragma warning restore CS0618 // Type or member is obsolete /// /// Resolve the distance function to use for a search by checking the distance function of the vector property specified in options @@ -126,7 +134,7 @@ public static string BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IR /// The first vector property in the record. /// The distance function for the vector we want to search. /// Thrown when a user asked for a vector property that doesn't exist on the record. - public static string ResolveDistanceFunction(VectorSearchOptions options, IReadOnlyList vectorProperties, VectorStoreRecordVectorProperty firstVectorProperty) + public static string ResolveDistanceFunction(VectorSearchOptions options, IReadOnlyList vectorProperties, VectorStoreRecordVectorProperty firstVectorProperty) { if (options.VectorPropertyName == null || vectorProperties.Count == 1) { diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs new file mode 100644 index 000000000000..65e6e3d4dce2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs @@ -0,0 +1,360 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.Sqlite; + +internal class SqliteFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + private readonly Dictionary _parameters = new(); + + private readonly StringBuilder _sql = new(); + + internal (string Clause, Dictionary) Translate(IReadOnlyDictionary storagePropertyNames, LambdaExpression lambdaExpression) + { + this._storagePropertyNames = storagePropertyNames; + + this._parameters.Clear(); + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this._sql.Clear(); + this.Translate(lambdaExpression.Body); + return (this._sql.ToString(), this._parameters); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant); + return; + + case MemberExpression member: + this.TranslateMember(member); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + private void TranslateBinary(BinaryExpression binary) + { + // Special handling for null comparisons + switch (binary.NodeType) + { + case ExpressionType.Equal when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NOT NULL)"); + return; + + case ExpressionType.Equal when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NOT NULL)"); + return; + } + + this._sql.Append('('); + this.Translate(binary.Left); + + this._sql.Append(binary.NodeType switch + { + ExpressionType.Equal => " = ", + ExpressionType.NotEqual => " <> ", + + ExpressionType.GreaterThan => " > ", + ExpressionType.GreaterThanOrEqual => " >= ", + ExpressionType.LessThan => " < ", + ExpressionType.LessThanOrEqual => " <= ", + + ExpressionType.AndAlso => " AND ", + ExpressionType.OrElse => " OR ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right); + this._sql.Append(')'); + + static bool IsNull(Expression expression) + => expression is ConstantExpression { Value: null } + || (TryGetCapturedValue(expression, out _, out var capturedValue) && capturedValue is null); + } + + private void TranslateConstant(ConstantExpression constant) + => this.GenerateLiteral(constant.Value); + + private void GenerateLiteral(object? value) + { + // TODO: Nullable + switch (value) + { + case byte b: + this._sql.Append(b); + return; + case short s: + this._sql.Append(s); + return; + case int i: + this._sql.Append(i); + return; + case long l: + this._sql.Append(l); + return; + + case string s: + this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); + return; + case bool b: + this._sql.Append(b ? "TRUE" : "FALSE"); + return; + case Guid g: + this._sql.Append('\'').Append(g.ToString()).Append('\''); + return; + + case DateTime: + case DateTimeOffset: + throw new NotImplementedException(); + + case Array: + throw new NotImplementedException(); + + case null: + this._sql.Append("NULL"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression) + { + switch (memberExpression) + { + case var _ when this.TryGetColumn(memberExpression, out var column): + this._sql.Append('"').Append(column).Append('"'); + return; + + // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) + case var _ when TryGetCapturedValue(memberExpression, out var name, out var value): + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (value is null) + { + this._sql.Append("NULL"); + } + else + { + // Duplicate parameter name, create a new parameter with a different name + // TODO: Share the same parameter when it references the same captured value + if (this._parameters.ContainsKey(name)) + { + var baseName = name; + var i = 0; + do + { + name = baseName + (i++); + } while (this._parameters.ContainsKey(name)); + } + + this._parameters.Add(name, value); + this._sql.Append('@').Append(name); + } + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + switch (source) + { + // TODO: support Contains over array fields (#10343) + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryGetColumn(source, out _): + goto default; + + // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) + case NewArrayExpression newArray: + { + this.Translate(item); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in newArray.Expressions) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.Translate(element); + } + + this._sql.Append(')'); + return; + } + + // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) + case var _ when TryGetCapturedValue(source, out _, out var value) && value is IEnumerable elements: + { + this.Translate(item); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in elements) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.GenerateLiteral(element); + } + + this._sql.Append(')'); + return; + } + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + } + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + case ExpressionType.Not: + // Special handling for !(a == b) and !(a != b) + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._sql.Append("(NOT "); + this.Translate(unary.Operand); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + column = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? value) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + name = fieldInfo.Name; + value = fieldInfo.GetValue(constant.Value); + return true; + } + + name = null; + value = null; + return false; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs index 028a838487d1..802f468e15c3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs @@ -159,6 +159,8 @@ public DbCommand BuildSelectLeftJoinCommand( IReadOnlyList leftTablePropertyNames, IReadOnlyList rightTablePropertyNames, List conditions, + string? extraWhereFilter = null, + Dictionary? extraParameters = null, string? orderByPropertyName = null) { var builder = new StringBuilder(); @@ -169,7 +171,7 @@ .. leftTablePropertyNames.Select(property => $"{leftTable}.{property}"), .. rightTablePropertyNames.Select(property => $"{rightTable}.{property}"), ]; - var (command, whereClause) = this.GetCommandWithWhereClause(conditions); + var (command, whereClause) = this.GetCommandWithWhereClause(conditions, extraWhereFilter, extraParameters); builder.AppendLine($"SELECT {string.Join(", ", propertyNames)}"); builder.AppendLine($"FROM {leftTable} "); @@ -238,7 +240,10 @@ private static string GetColumnDefinition(SqliteColumn column) return string.Join(" ", columnDefinitionParts); } - private (DbCommand Command, string WhereClause) GetCommandWithWhereClause(List conditions) + private (DbCommand Command, string WhereClause) GetCommandWithWhereClause( + List conditions, + string? extraWhereFilter = null, + Dictionary? extraParameters = null) { const string WhereClauseOperator = " AND "; @@ -263,6 +268,21 @@ private static string GetColumnDefinition(SqliteColumn column) var whereClause = string.Join(WhereClauseOperator, whereClauseParts); + if (extraWhereFilter is not null) + { + if (conditions.Count > 0) + { + whereClause += " AND "; + } + + whereClause += extraWhereFilter; + + foreach (var p in extraParameters!) + { + command.Parameters.Add(new SqliteParameter(p.Key, p.Value)); + } + } + return (command, whereClause); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs index 08c976abf43f..8ae095dd3bf0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs @@ -34,7 +34,7 @@ public sealed class SqliteVectorStoreRecordCollection : private readonly IVectorStoreRecordMapper> _mapper; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// Command builder for queries in SQLite database. private readonly SqliteVectorStoreCollectionCommandBuilder _commandBuilder; @@ -154,7 +154,7 @@ public async Task DeleteCollectionAsync(CancellationToken cancellationToken = de } /// - public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { const string LimitPropertyName = "k"; @@ -189,15 +189,35 @@ public Task> VectorizedSearchAsync(TVector new SqliteWhereEqualsCondition(LimitPropertyName, limit) }; - var filterConditions = this.GetFilterConditions(searchOptions.Filter, this._dataTableName); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + string? extraWhereFilter = null; + Dictionary? extraParameters = null; - if (filterConditions is { Count: > 0 }) + if (searchOptions.Filter is not null) { - conditions.AddRange(filterConditions); + if (searchOptions.NewFilter is not null) + { + throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"); + } + + // Old filter, we translate it to a list of SqliteWhereCondition, and merge these into the conditions we already have + var filterConditions = this.GetFilterConditions(searchOptions.Filter, this._dataTableName); + + if (filterConditions is { Count: > 0 }) + { + conditions.AddRange(filterConditions); + } + } + else if (searchOptions.NewFilter is not null) + { + (extraWhereFilter, extraParameters) = new SqliteFilterTranslator().Translate(this._propertyReader.StoragePropertyNamesMap, searchOptions.NewFilter); } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete var vectorSearchResults = new VectorSearchResults(this.EnumerateAndMapSearchResultsAsync( conditions, + extraWhereFilter, + extraParameters, searchOptions, cancellationToken)); @@ -288,7 +308,9 @@ public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancell private async IAsyncEnumerable> EnumerateAndMapSearchResultsAsync( List conditions, - VectorSearchOptions searchOptions, + string? extraWhereFilter, + Dictionary? extraParameters, + VectorSearchOptions searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken) { const string OperationName = "VectorizedSearch"; @@ -311,6 +333,8 @@ private async IAsyncEnumerable> EnumerateAndMapSearc leftTableProperties, this._dataTableStoragePropertyNames.Value, conditions, + extraWhereFilter, + extraParameters, DistancePropertyName); using var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); @@ -670,6 +694,7 @@ private async Task RunOperationAsync(string operationName, Func> o return new SqliteVectorStoreRecordMapper(this._propertyReader); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete private List? GetFilterConditions(VectorSearchFilter? filter, string? tableName = null) { var filterClauses = filter?.FilterClauses.ToList(); @@ -706,6 +731,7 @@ private async Task RunOperationAsync(string operationName, Func> o return conditions; } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete /// /// Gets vector table name. diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs new file mode 100644 index 000000000000..8bd7780929b7 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs @@ -0,0 +1,259 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; + +namespace Microsoft.SemanticKernel.Connectors.Weaviate; + +// https://weaviate.io/developers/weaviate/api/graphql/filters#filter-structure +internal class WeaviateFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + private readonly StringBuilder _filter = new(); + + internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this._filter.Clear(); + this.Translate(lambdaExpression.Body); + return this._filter.ToString(); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression + { + NodeType: ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual + or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + } binary: + this.TranslateEqualityComparison(binary); + return; + + case BinaryExpression { NodeType: ExpressionType.AndAlso } andAlso: + this._filter.Append("{ operator: And, operands: ["); + this.Translate(andAlso.Left); + this._filter.Append(", "); + this.Translate(andAlso.Right); + this._filter.Append("] }"); + return; + + case BinaryExpression { NodeType: ExpressionType.OrElse } orElse: + this._filter.Append("{ operator: Or, operands: ["); + this.Translate(orElse.Left); + this._filter.Append(", "); + this.Translate(orElse.Right); + this._filter.Append("] }"); + return; + + case UnaryExpression { NodeType: ExpressionType.Not } not: + { + switch (not.Operand) + { + // Special handling for !(a == b) and !(a != b) + case BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary: + this.TranslateEqualityComparison( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + + // Not over bool field (Filter => r => !r.Bool) + case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(false))); + return; + + default: + throw new NotSupportedException("Weaviate does not support the NOT operator (see https://github.com/weaviate/weaviate/issues/3683)"); + } + } + + // MemberExpression is generally handled within e.g. TranslateEqual; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) + case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + default: + throw new NotSupportedException("The following NodeType is unsupported: " + node?.NodeType); + } + } + + private void TranslateEqualityComparison(BinaryExpression binary) + { + if ((this.TryTranslateFieldAccess(binary.Left, out var storagePropertyName) && TryGetConstant(binary.Right, out var value)) + || (this.TryTranslateFieldAccess(binary.Right, out storagePropertyName) && TryGetConstant(binary.Left, out value))) + { + // { path: ["intPropName"], operator: Equal, ValueInt: 8 } + this._filter + .Append("{ path: [\"") + .Append(JsonEncodedText.Encode(storagePropertyName)) + .Append("\"], operator: "); + + // Special handling for null comparisons + if (value is null) + { + if (binary.NodeType is ExpressionType.Equal or ExpressionType.NotEqual) + { + this._filter + .Append("IsNull, valueBoolean: ") + .Append(binary.NodeType is ExpressionType.Equal ? "true" : "false") + .Append(" }"); + return; + } + + throw new NotSupportedException("null value supported only with equality/inequality checks"); + } + + // Operator + this._filter.Append(binary.NodeType switch + { + ExpressionType.Equal => "Equal", + ExpressionType.NotEqual => "NotEqual", + + ExpressionType.GreaterThan => "GreaterThan", + ExpressionType.GreaterThanOrEqual => "GreaterThanEqual", + ExpressionType.LessThan => "LessThan", + ExpressionType.LessThanOrEqual => "LessThanEqual", + + _ => throw new UnreachableException() + }); + + this._filter.Append(", "); + + // FieldType + var type = value.GetType(); + if (Nullable.GetUnderlyingType(type) is Type underlying) + { + type = underlying; + } + + this._filter.Append(value.GetType() switch + { + Type t when t == typeof(int) || t == typeof(long) || t == typeof(short) || t == typeof(byte) => "valueInt", + Type t when t == typeof(bool) => "valueBoolean", + Type t when t == typeof(string) || t == typeof(Guid) => "valueText", + Type t when t == typeof(float) || t == typeof(double) || t == typeof(decimal) => "valueNumber", + Type t when t == typeof(DateTimeOffset) => "valueDate", + + _ => throw new NotSupportedException($"Unsupported value type {type.FullName} in filter.") + }); + + this._filter.Append(": "); + + // Value + this._filter.Append(JsonSerializer.Serialize(value)); + + this._filter.Append('}'); + + return; + } + + throw new NotSupportedException("Invalid equality/comparison"); + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + // Contains over array + // { path: ["stringArrayPropName"], operator: ContainsAny, valueText: ["foo"] } + if (this.TryTranslateFieldAccess(source, out var storagePropertyName) + && TryGetConstant(item, out var itemConstant) + && itemConstant is string stringConstant) + { + this._filter + .Append("{ path: [\"") + .Append(JsonEncodedText.Encode(storagePropertyName)) + .Append("\"], operator: ContainsAny, valueText: [") + .Append(JsonEncodedText.Encode(stringConstant)) + .Append("]}"); + return; + } + + throw new NotSupportedException("Contains supported only over tag field"); + } + + private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + { + if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + storagePropertyName = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + // This identifies compiler-generated closure types which contain captured variables. + case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): + constantValue = fieldInfo.GetValue(constant.Value); + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs index a4ba633535a7..d03e2cf83a2e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs @@ -9,7 +9,6 @@ using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Nodes; -using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; @@ -75,7 +74,7 @@ public sealed class WeaviateVectorStoreRecordCollection : IVectorStoreR private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + // DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, Converters = { new WeaviateDateTimeOffsetConverter(), @@ -84,7 +83,7 @@ public sealed class WeaviateVectorStoreRecordCollection : IVectorStoreR }; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// that is used to interact with Weaviate API. private readonly HttpClient _httpClient; @@ -335,7 +334,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record /// public async Task> VectorizedSearchAsync( TVector vector, - VectorSearchOptions? options = null, + VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "VectorSearch"; diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs index 397af63763a6..e665e7e85e08 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Text.Json; using Microsoft.Extensions.VectorData; @@ -17,13 +18,13 @@ internal static class WeaviateVectorStoreRecordCollectionQueryBuilder /// Builds Weaviate search query. /// More information here: . /// - public static string BuildSearchQuery( + public static string BuildSearchQuery( TVector vector, string collectionName, string vectorPropertyName, string keyPropertyName, JsonSerializerOptions jsonSerializerOptions, - VectorSearchOptions searchOptions, + VectorSearchOptions searchOptions, IReadOnlyDictionary storagePropertyNames, IReadOnlyList vectorPropertyStorageNames, IReadOnlyList dataPropertyStorageNames) @@ -32,11 +33,19 @@ public static string BuildSearchQuery( $"vectors {{ {string.Join(" ", vectorPropertyStorageNames)} }}" : string.Empty; - var filter = BuildFilter( - searchOptions.Filter, - jsonSerializerOptions, - keyPropertyName, - storagePropertyNames); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + var filter = searchOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => BuildLegacyFilter( + legacyFilter, + jsonSerializerOptions, + keyPropertyName, + storagePropertyNames), + { NewFilter: Expression> newFilter } => new WeaviateFilterTranslator().Translate(newFilter, storagePropertyNames), + _ => null + }; +#pragma warning restore CS0618 var vectorArray = JsonSerializer.Serialize(vector, jsonSerializerOptions); @@ -46,7 +55,7 @@ public static string BuildSearchQuery( {{collectionName}} ( limit: {{searchOptions.Top}} offset: {{searchOptions.Skip}} - {{filter}} + {{(filter is null ? "" : "where: " + filter)}} nearVector: { targetVectors: ["{{vectorPropertyName}}"] vector: {{vectorArray}} @@ -66,11 +75,12 @@ public static string BuildSearchQuery( #region private +#pragma warning disable CS0618 // Type or member is obsolete /// /// Builds filter for Weaviate search query. /// More information here: . /// - private static string BuildFilter( + private static string BuildLegacyFilter( VectorSearchFilter? vectorSearchFilter, JsonSerializerOptions jsonSerializerOptions, string keyPropertyName, @@ -134,8 +144,9 @@ private static string BuildFilter( operands.Add(operand); } - return $$"""where: { operator: And, operands: [{{string.Join(", ", operands)}}] }"""; + return $$"""{ operator: And, operands: [{{string.Join(", ", operands)}}] }"""; } +#pragma warning restore CS0618 // Type or member is obsolete /// /// Gets filter value type. diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs index 8242333ecea5..cea02dee086c 100644 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs @@ -9,6 +9,8 @@ namespace SemanticKernel.Connectors.MongoDB.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Unit tests for class. /// @@ -20,32 +22,6 @@ public sealed class MongoDBVectorStoreCollectionSearchMappingTests ["Property2"] = "property_2", }; - [Fact] - public void BuildFilterWithNullVectorSearchFilterReturnsNull() - { - // Arrange - VectorSearchFilter? vectorSearchFilter = null; - - // Act - var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); - - // Assert - Assert.Null(filter); - } - - [Fact] - public void BuildFilterWithoutFilterClausesReturnsNull() - { - // Arrange - VectorSearchFilter vectorSearchFilter = new(); - - // Act - var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); - - // Assert - Assert.Null(filter); - } - [Fact] public void BuildFilterThrowsExceptionWithUnsupportedFilterClause() { @@ -53,7 +29,7 @@ public void BuildFilterThrowsExceptionWithUnsupportedFilterClause() var vectorSearchFilter = new VectorSearchFilter().AnyTagEqualTo("NonExistentProperty", "TestValue"); // Act & Assert - Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames)); } [Fact] @@ -63,7 +39,7 @@ public void BuildFilterThrowsExceptionWithNonExistentPropertyName() var vectorSearchFilter = new VectorSearchFilter().EqualTo("NonExistentProperty", "TestValue"); // Act & Assert - Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames)); } [Fact] @@ -75,7 +51,7 @@ public void BuildFilterThrowsExceptionWithMultipleFilterClausesOfSameType() .EqualTo("Property1", "TestValue2"); // Act & Assert - Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames)); } [Fact] @@ -86,7 +62,7 @@ public void BuilderFilterByDefaultReturnsValidFilter() var vectorSearchFilter = new VectorSearchFilter().EqualTo("Property1", "TestValue1"); // Act - var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); + var filter = MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames); Assert.Equal(filter.ToJson(), expectedFilter.ToJson()); } diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs index 26a9b9fb00b7..7fa33bbd9967 100644 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs @@ -13,6 +13,7 @@ using MongoDB.Driver; using Moq; using Xunit; +using MEVD = Microsoft.Extensions.VectorData; namespace SemanticKernel.Connectors.MongoDB.UnitTests; @@ -639,7 +640,7 @@ public async Task VectorizedSearchThrowsExceptionWithNonExistentVectorPropertyNa this._mockMongoDatabase.Object, "collection"); - var options = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; + var options = new MEVD.VectorSearchOptions { VectorPropertyName = "non-existent-property" }; // Act & Assert await Assert.ThrowsAsync(async () => await (await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f]), options)).Results.FirstOrDefaultAsync()); diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs index 675843a78c18..e1958f934c5d 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -366,57 +366,4 @@ public void TestBuildDeleteBatchCommand() // Output this._output.WriteLine(cmdInfo.CommandText); } - - [Fact] - public void TestBuildGetNearestMatchCommand() - { - // Arrange - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - - var vectorProperty = new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) - { - Dimensions = 10, - IndexKind = "hnsw", - }; - - var recordDefinition = new VectorStoreRecordDefinition() - { - Properties = [ - new VectorStoreRecordKeyProperty("id", typeof(long)), - new VectorStoreRecordDataProperty("name", typeof(string)), - new VectorStoreRecordDataProperty("code", typeof(int)), - new VectorStoreRecordDataProperty("rating", typeof(float?)), - new VectorStoreRecordDataProperty("description", typeof(string)), - new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), - new VectorStoreRecordDataProperty("tags", typeof(List)), - vectorProperty, - new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) - { - Dimensions = 10, - IndexKind = "hnsw", - } - ] - }; - - var vector = new Vector(s_vector); - - // Act - var cmdInfo = builder.BuildGetNearestMatchCommand("public", "testcollection", - properties: recordDefinition.Properties, - vectorProperty: vectorProperty, - vectorValue: vector, - filter: null, - skip: null, - includeVectors: true, - limit: 10); - - // Assert - Assert.Contains("SELECT", cmdInfo.CommandText); - Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); - Assert.Contains("ORDER BY", cmdInfo.CommandText); - Assert.Contains("LIMIT 10", cmdInfo.CommandText); - - // Output - this._output.WriteLine(cmdInfo.CommandText); - } } diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs index 623f997a4ed2..afd5e545030a 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs @@ -10,6 +10,8 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -35,7 +37,7 @@ public void BuildFilterMapsEqualityClause(string type) var filter = new VectorSearchFilter().EqualTo("FieldName", expected); // Act. - var actual = QdrantVectorStoreCollectionSearchMapping.BuildFilter(filter, new Dictionary() { { "FieldName", "storage_FieldName" } }); + var actual = QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, new Dictionary() { { "FieldName", "storage_FieldName" } }); // Assert. Assert.Single(actual.Must); @@ -69,7 +71,7 @@ public void BuildFilterMapsTagContainsClause() var filter = new VectorSearchFilter().AnyTagEqualTo("FieldName", "Value"); // Act. - var actual = QdrantVectorStoreCollectionSearchMapping.BuildFilter(filter, new Dictionary() { { "FieldName", "storage_FieldName" } }); + var actual = QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, new Dictionary() { { "FieldName", "storage_FieldName" } }); // Assert. Assert.Single(actual.Must); @@ -84,7 +86,7 @@ public void BuildFilterThrowsForUnknownFieldName() var filter = new VectorSearchFilter().EqualTo("FieldName", "Value"); // Act and Assert. - Assert.Throws(() => QdrantVectorStoreCollectionSearchMapping.BuildFilter(filter, new Dictionary())); + Assert.Throws(() => QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, new Dictionary())); } [Fact] diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs index 1bb89a91344e..666efcc4647b 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs @@ -545,6 +545,7 @@ public void CanCreateCollectionWithMismatchedDefinitionAndType() new() { VectorStoreRecordDefinition = definition, PointStructCustomMapper = Mock.Of, PointStruct>>() }); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete [Theory] [MemberData(nameof(TestOptions))] public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool hasNamedVectors, TKey testRecordKey) @@ -593,6 +594,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bo Assert.Equal(new float[] { 1, 2, 3, 4 }, results.First().Record.Vector!.Value.ToArray()); Assert.Equal(0.5f, results.First().Score); } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete private void SetupRetrieveMock(List retrievedPoints) { diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs index 5457582661ee..fb15d0031c2b 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs @@ -415,6 +415,7 @@ public async Task CanUpsertRecordWithCustomMapperAsync() Times.Once); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete [Theory] [InlineData(true, true)] [InlineData(true, false)] @@ -508,6 +509,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc Assert.False(results.First().Record.Vector.HasValue); } } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete /// /// Tests that the collection can be created even if the definition and the type do not match. diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs index 20d1b0da5831..6cfe1f17960e 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs @@ -16,6 +16,8 @@ namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs index 8253801a8cb7..1301ee6a7eb9 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs @@ -8,6 +8,8 @@ namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -70,7 +72,7 @@ public void BuildQueryBuildsRedisQueryWithDefaults() var firstVectorPropertyName = "storage_Vector"; // Act. - var query = RedisVectorStoreCollectionSearchMapping.BuildQuery(byteArray, new VectorSearchOptions(), storagePropertyNames, firstVectorPropertyName, null); + var query = RedisVectorStoreCollectionSearchMapping.BuildQuery(byteArray, new VectorSearchOptions(), storagePropertyNames, firstVectorPropertyName, null); // Assert. Assert.NotNull(query); @@ -86,7 +88,7 @@ public void BuildQueryBuildsRedisQueryWithCustomVectorName() // Arrange. var floatVector = new ReadOnlyMemory(new float[] { 1.0f, 2.0f, 3.0f }); var byteArray = MemoryMarshal.AsBytes(floatVector.Span).ToArray(); - var vectorSearchOptions = new VectorSearchOptions { Top = 5, Skip = 3, VectorPropertyName = "Vector" }; + var vectorSearchOptions = new VectorSearchOptions { Top = 5, Skip = 3, VectorPropertyName = "Vector" }; var storagePropertyNames = new Dictionary() { { "Vector", "storage_Vector" }, @@ -108,7 +110,7 @@ public void BuildQueryFailsForInvalidVectorName() // Arrange. var floatVector = new ReadOnlyMemory(new float[] { 1.0f, 2.0f, 3.0f }); var byteArray = MemoryMarshal.AsBytes(floatVector.Span).ToArray(); - var vectorSearchOptions = new VectorSearchOptions { VectorPropertyName = "UnknownVector" }; + var vectorSearchOptions = new VectorSearchOptions { VectorPropertyName = "UnknownVector" }; var storagePropertyNames = new Dictionary() { { "Vector", "storage_Vector" }, @@ -149,7 +151,7 @@ public void BuildFilterBuildsEqualityFilter(string filterType) }; // Act. - var filter = RedisVectorStoreCollectionSearchMapping.BuildFilter(basicVectorSearchFilter, storagePropertyNames); + var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, storagePropertyNames); // Assert. switch (filterType) @@ -184,7 +186,7 @@ public void BuildFilterThrowsForInvalidValueType() // Act & Assert. Assert.Throws(() => { - var filter = RedisVectorStoreCollectionSearchMapping.BuildFilter(basicVectorSearchFilter, storagePropertyNames); + var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, storagePropertyNames); }); } @@ -201,7 +203,7 @@ public void BuildFilterThrowsForUnknownFieldName() // Act & Assert. Assert.Throws(() => { - var filter = RedisVectorStoreCollectionSearchMapping.BuildFilter(basicVectorSearchFilter, storagePropertyNames); + var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, storagePropertyNames); }); } @@ -211,7 +213,7 @@ public void ResolveDistanceFunctionReturnsCosineSimilarityIfNoDistanceFunctionSp var property = new VectorStoreRecordVectorProperty("Prop", typeof(ReadOnlyMemory)); // Act. - var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property); + var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property); // Assert. Assert.Equal(DistanceFunction.CosineSimilarity, resolvedDistanceFunction); @@ -223,7 +225,7 @@ public void ResolveDistanceFunctionReturnsDistanceFunctionFromFirstPropertyIfNoF var property = new VectorStoreRecordVectorProperty("Prop", typeof(ReadOnlyMemory)) { DistanceFunction = DistanceFunction.DotProductSimilarity }; // Act. - var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property); + var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property); // Assert. Assert.Equal(DistanceFunction.DotProductSimilarity, resolvedDistanceFunction); @@ -236,7 +238,7 @@ public void ResolveDistanceFunctionReturnsDistanceFunctionFromChosenPropertyIfFi var property2 = new VectorStoreRecordVectorProperty("Prop2", typeof(ReadOnlyMemory)) { DistanceFunction = DistanceFunction.DotProductSimilarity }; // Act. - var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions() { VectorPropertyName = "Prop2" }, [property1, property2], property1); + var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions { VectorPropertyName = "Prop2" }, [property1, property2], property1); // Assert. Assert.Equal(DistanceFunction.DotProductSimilarity, resolvedDistanceFunction); @@ -260,4 +262,8 @@ public void GetOutputScoreFromRedisScoreLeavesNonConsineSimilarityUntouched(stri // Act & Assert. Assert.Equal(score, RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(score, distanceFunction)); } + +#pragma warning disable CA1812 // An internal class that is apparently never instantiated. If so, remove the code from the assembly. + private sealed class DummyType; +#pragma warning restore CA1812 } diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs index 6c4f8336654f..a0fa8b4f0ae0 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs @@ -10,6 +10,8 @@ namespace SemanticKernel.Connectors.Weaviate.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Unit tests for class. /// @@ -72,7 +74,7 @@ hotelName hotelCode } """; - var searchOptions = new VectorSearchOptions + var searchOptions = new VectorSearchOptions { Skip = 2, Top = 3, @@ -102,7 +104,7 @@ hotelName hotelCode public void BuildSearchQueryWithIncludedVectorsReturnsValidQuery() { // Arrange - var searchOptions = new VectorSearchOptions + var searchOptions = new VectorSearchOptions { Skip = 2, Top = 3, @@ -133,7 +135,7 @@ public void BuildSearchQueryWithFilterReturnsValidQuery() const string ExpectedFirstSubquery = """{ path: ["hotelName"], operator: Equal, valueText: "Test Name" }"""; const string ExpectedSecondSubquery = """{ path: ["tags"], operator: ContainsAny, valueText: ["t1"] }"""; - var searchOptions = new VectorSearchOptions + var searchOptions = new VectorSearchOptions { Skip = 2, Top = 3, @@ -164,7 +166,7 @@ public void BuildSearchQueryWithFilterReturnsValidQuery() public void BuildSearchQueryWithInvalidFilterValueThrowsException() { // Arrange - var searchOptions = new VectorSearchOptions + var searchOptions = new VectorSearchOptions { Skip = 2, Top = 3, @@ -189,7 +191,7 @@ public void BuildSearchQueryWithInvalidFilterValueThrowsException() public void BuildSearchQueryWithNonExistentPropertyInFilterThrowsException() { // Arrange - var searchOptions = new VectorSearchOptions + var searchOptions = new VectorSearchOptions { Skip = 2, Top = 3, @@ -212,6 +214,9 @@ public void BuildSearchQueryWithNonExistentPropertyInFilterThrowsException() #region private +#pragma warning disable CA1812 // An internal class that is apparently never instantiated. If so, remove the code from the assembly. + private sealed class DummyType; +#pragma warning restore CA1812 private sealed class TestFilterValue; #endregion diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs index 0871c4978977..8f7ea996101d 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs @@ -530,11 +530,12 @@ public async Task VectorizedSearchWithNonExistentVectorPropertyNameThrowsExcepti // Arrange var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); - var searchOptions = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; - // Act & Assert await Assert.ThrowsAsync(async () => - await (await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f]), searchOptions)).Results.ToListAsync()); + await (await sut.VectorizedSearchAsync( + new ReadOnlyMemory([1f, 2f, 3f]), + new() { VectorPropertyName = "non-existent-property" })) + .Results.ToListAsync()); } public void Dispose() diff --git a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/AnyTagEqualToFilterClause.cs b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/AnyTagEqualToFilterClause.cs index 49ffce328e5e..0e001bc0cfae 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/AnyTagEqualToFilterClause.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/AnyTagEqualToFilterClause.cs @@ -1,10 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. +using System; + namespace Microsoft.Extensions.VectorData; /// /// which filters by checking if a field consisting of a list of values contains a specific value. /// +[Obsolete("Use VectorSearchOptions.NewFilter instead of VectorSearchOptions.Filter")] public sealed class AnyTagEqualToFilterClause : FilterClause { /// diff --git a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/EqualToFilterClause.cs b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/EqualToFilterClause.cs index a0eb45c0fbe3..cef8a9670276 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/EqualToFilterClause.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/EqualToFilterClause.cs @@ -1,10 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. +using System; + namespace Microsoft.Extensions.VectorData; /// /// which filters using equality of a field value. /// +[Obsolete("Use VectorSearchOptions.NewFilter instead of VectorSearchOptions.Filter")] public sealed class EqualToFilterClause : FilterClause { /// diff --git a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/FilterClause.cs b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/FilterClause.cs index 4392893f16e3..40c7b291fd10 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/FilterClause.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/FilterClause.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; + namespace Microsoft.Extensions.VectorData; /// @@ -9,6 +11,7 @@ namespace Microsoft.Extensions.VectorData; /// A is used to request that the underlying search service should /// filter search results based on the specified criteria. /// +[Obsolete("Use VectorSearchOptions.NewFilter instead of VectorSearchOptions.Filter")] public abstract class FilterClause { internal FilterClause() diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs index a0d5181b7668..5e39a541ef86 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs @@ -20,6 +20,6 @@ public interface IVectorizableTextSearch /// The records found by the vector search, including their result scores. Task> VectorizableTextSearchAsync( string searchText, - VectorSearchOptions? options = default, + VectorSearchOptions? options = default, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs index 9ac93383b18d..3286fafc15fc 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs @@ -21,6 +21,6 @@ public interface IVectorizedSearch /// The records found by the vector search, including their result scores. Task> VectorizedSearchAsync( TVector vector, - VectorSearchOptions? options = default, + VectorSearchOptions? options = default, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchFilter.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchFilter.cs index a8b941776eff..9d167fcb160b 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchFilter.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchFilter.cs @@ -14,6 +14,7 @@ namespace Microsoft.Extensions.VectorData; /// to request that the underlying service filter the search results. /// All clauses are combined with and. /// +[Obsolete("Use VectorSearchOptions.NewFilter instead of VectorSearchOptions.Filter")] public sealed class VectorSearchFilter { /// The filter clauses to and together. diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs index a5773b0cc606..65d9c6e157c2 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs @@ -1,17 +1,26 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Linq.Expressions; + namespace Microsoft.Extensions.VectorData; /// /// Options for vector search. /// -public class VectorSearchOptions +public class VectorSearchOptions { /// /// Gets or sets a search filter to use before doing the vector search. /// + [Obsolete("Use NewFilter instead")] public VectorSearchFilter? Filter { get; init; } + /// + /// Gets or sets a search filter to use before doing the vector search. + /// + public Expression>? NewFilter { get; init; } + /// /// Gets or sets the name of the vector property to search on. /// Use the name of the vector property from your data model or as provided in the record definition. diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs index e3a420a789f4..f7fb10081c76 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs @@ -14,6 +14,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureAISearch; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Integration tests for class. /// Tests work with an Azure AI Search Instance. @@ -63,7 +65,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe var embedding = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("A great hotel"); var actual = await sut.VectorizedSearchAsync( embedding, - new VectorSearchOptions + new() { IncludeVectors = true, Filter = new VectorSearchFilter().EqualTo("HotelName", "MyHotel Upsert-1") diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs index c5929e0ecaa2..7f471405b8c9 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs @@ -12,6 +12,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.AzureCosmosDBMongoDB; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + [Collection("AzureCosmosDBMongoDBVectorStoreCollection")] public class AzureCosmosDBMongoDBVectorStoreRecordCollectionTests(AzureCosmosDBMongoDBVectorStoreFixture fixture) { diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs index 6a0e249f4d7e..3864a48288ef 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs @@ -13,6 +13,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureCosmosDBNoSQL; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Integration tests for class. /// diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs index 11da55ba3329..fdf07a1acd43 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs @@ -12,13 +12,15 @@ namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + [Collection("MongoDBVectorStoreCollection")] public class MongoDBVectorStoreRecordCollectionTests(MongoDBVectorStoreFixture fixture) { // If null, all tests will be enabled private const string? SkipReason = "The tests are for manual verification."; - [Theory(Skip = SkipReason)] + [Theory] [InlineData("sk-test-hotels", true)] [InlineData("nonexistentcollection", false)] public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs index e30b2f35fbae..7e19c73128d0 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs @@ -15,6 +15,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + [Collection("PineconeVectorStoreTests")] [PineconeApiKeySetCondition] public class PineconeVectorStoreRecordCollectionTests(PineconeVectorStoreFixture fixture) : IClassFixture @@ -293,7 +295,7 @@ public async Task InsertGetModifyDeleteVectorAsync(bool collectionFromVectorStor // update await hotelRecordCollection.UpsertAsync(langriSha); - // this is not great but no vectors are added so we can't query status for number of vectors like we do for insert/delete + // this is not great but no vectors are added so we can't query status for number of vectors like we do for insert/delete await Task.Delay(2000); var updated = await hotelRecordCollection.GetAsync("langri-sha", new GetRecordOptions { IncludeVectors = true }); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 7e3ae3ad9392..6a479f0b10bf 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -10,6 +10,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + [Collection("PostgresVectorStoreCollection")] public sealed class PostgresVectorStoreRecordCollectionTests(PostgresVectorStoreFixture fixture) { diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs index 135d09d025aa..940687525238 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs @@ -15,6 +15,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -66,7 +68,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool hasNamedVec var vector = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("A great hotel"); var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions { Filter = new VectorSearchFilter().EqualTo("HotelCode", 30).AnyTagEqualTo("Tags", "t2") }); + new() { Filter = new VectorSearchFilter().EqualTo("HotelCode", 30).AnyTagEqualTo("Tags", "t2") }); // Assert var collectionExistResult = await sut.CollectionExistsAsync(); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs index ef7ba087cf87..61018b2b7589 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs @@ -13,6 +13,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -65,7 +67,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe var actual = await sut .VectorizedSearchAsync( new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }), - new VectorSearchOptions { Filter = new VectorSearchFilter().EqualTo("HotelCode", 1), IncludeVectors = true }); + new() { Filter = new VectorSearchFilter().EqualTo("HotelCode", 1), IncludeVectors = true }); // Assert var collectionExistResult = await sut.CollectionExistsAsync(); @@ -316,7 +318,7 @@ public async Task ItCanSearchWithFloat32VectorAndFilterAsync(string filterType, // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions + new() { IncludeVectors = includeVectors, Filter = filter @@ -360,7 +362,7 @@ public async Task ItCanSearchWithFloat32VectorAndTopSkipAsync() // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions + new() { Top = 3, Skip = 2 @@ -390,7 +392,7 @@ public async Task ItCanSearchWithFloat64VectorAsync(bool includeVectors) // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions + new() { IncludeVectors = includeVectors, Top = 1 diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs index 1e6c3d9aed0e..a12d710d9446 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs @@ -13,6 +13,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -64,7 +66,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe var getResult = await sut.GetAsync("Upsert-10", new GetRecordOptions { IncludeVectors = true }); var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }), - new VectorSearchOptions { Filter = new VectorSearchFilter().EqualTo("HotelCode", 10) }); + new() { Filter = new VectorSearchFilter().EqualTo("HotelCode", 10) }); // Assert var collectionExistResult = await sut.CollectionExistsAsync(); @@ -346,7 +348,7 @@ public async Task ItCanSearchWithFloat32VectorAndFilterAsync(string filterType) // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions { IncludeVectors = true, Filter = filter }); + new() { IncludeVectors = true, Filter = filter }); // Assert var searchResults = await actual.Results.ToListAsync(); @@ -384,7 +386,7 @@ public async Task ItCanSearchWithFloat32VectorAndTopSkipAsync() // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions + new() { Top = 3, Skip = 2 @@ -414,7 +416,7 @@ public async Task ItCanSearchWithFloat64VectorAsync(bool includeVectors) // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions + new() { IncludeVectors = includeVectors, Top = 1 diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs index 214510438d59..c0dbb5fcf680 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs @@ -10,6 +10,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Sqlite; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Integration tests for class. /// diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs index 9ffaf3172eec..bd6348932937 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs @@ -10,6 +10,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Weaviate; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + [Collection("WeaviateVectorStoreCollection")] public sealed class WeaviateVectorStoreRecordCollectionTests(WeaviateVectorStoreFixture fixture) { diff --git a/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs b/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs index 90ce87f14482..143c61f69e5f 100644 --- a/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs +++ b/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs @@ -102,7 +102,7 @@ public Task>> GenerateEmbeddingsAsync(IList protected sealed class VectorizedSearchWrapper(IVectorizedSearch vectorizedSearch, ITextEmbeddingGenerationService textEmbeddingGeneration) : IVectorizableTextSearch { /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { var vectorizedQuery = await textEmbeddingGeneration!.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs b/dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs new file mode 100644 index 000000000000..616073f54705 --- /dev/null +++ b/dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. + +#if NETSTANDARD2_0 + +// Polyfill for using UnreachableException with .NET Standard 2.0 + +namespace System.Diagnostics; + +#pragma warning disable CA1064 // Exceptions should be public +#pragma warning disable CA1812 // Internal class that is (sometimes) never instantiated. + +/// +/// Exception thrown when the program executes an instruction that was thought to be unreachable. +/// +internal sealed class UnreachableException : Exception +{ + private const string MessageText = "The program executed an instruction that was thought to be unreachable."; + + /// + /// Initializes a new instance of the class with the default error message. + /// + public UnreachableException() + : base(MessageText) + { + } + + /// + /// Initializes a new instance of the + /// class with a specified error message. + /// + /// The error message that explains the reason for the exception. + public UnreachableException(string? message) + : base(message ?? MessageText) + { + } + + /// + /// Initializes a new instance of the + /// class with a specified error message and a reference to the inner exception that is the cause of + /// this exception. + /// + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception. + public UnreachableException(string? message, Exception? innerException) + : base(message ?? MessageText, innerException) + { + } +} + +#endif diff --git a/dotnet/src/InternalUtilities/src/System/IndexRange.cs b/dotnet/src/InternalUtilities/src/System/IndexRange.cs new file mode 100644 index 000000000000..439e6e844fb6 --- /dev/null +++ b/dotnet/src/InternalUtilities/src/System/IndexRange.cs @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft. All rights reserved. + +#if NETSTANDARD2_0 + +// Polyfill for using Index and Range with .NET Standard 2.0 (see https://www.meziantou.net/how-to-use-csharp-8-indices-and-ranges-in-dotnet-standard-2-0-and-dotn.htm) + +// https://github.com/dotnet/runtime/blob/419e949d258ecee4c40a460fb09c66d974229623/src/libraries/System.Private.CoreLib/src/System/Index.cs +// https://github.com/dotnet/runtime/blob/419e949d258ecee4c40a460fb09c66d974229623/src/libraries/System.Private.CoreLib/src/System/Range.cs + +#pragma warning disable RCS1168 +#pragma warning disable RCS1211 +#pragma warning disable IDE0009 +#pragma warning disable IDE0011 +#pragma warning disable IDE0090 + +using System.Runtime.CompilerServices; + +namespace System +{ + /// Represent a type can be used to index a collection either from the start or the end. + /// + /// Index is used by the C# compiler to support the new index syntax + /// + /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 } ; + /// int lastElement = someArray[^1]; // lastElement = 5 + /// + /// + internal readonly struct Index : IEquatable + { + private readonly int _value; + + /// Construct an Index using a value and indicating if the index is from the start or from the end. + /// The index value. it has to be zero or positive number. + /// Indicating if the index is from the start or from the end. + /// + /// If the Index constructed from the end, index value 1 means pointing at the last element and index value 0 means pointing at beyond last element. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Index(int value, bool fromEnd = false) + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); + } + + if (fromEnd) + _value = ~value; + else + _value = value; + } + + // The following private constructors mainly created for perf reason to avoid the checks + private Index(int value) + { + _value = value; + } + + /// Create an Index pointing at first element. + public static Index Start => new Index(0); + + /// Create an Index pointing at beyond last element. + public static Index End => new Index(~0); + + /// Create an Index from the start at the position indicated by the value. + /// The index value from the start. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Index FromStart(int value) + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); + } + + return new Index(value); + } + + /// Create an Index from the end at the position indicated by the value. + /// The index value from the end. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Index FromEnd(int value) + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); + } + + return new Index(~value); + } + + /// Returns the index value. + public int Value + { + get + { + if (_value < 0) + { + return ~_value; + } + else + { + return _value; + } + } + } + + /// Indicates whether the index is from the start or the end. + public bool IsFromEnd => _value < 0; + + /// Calculate the offset from the start using the giving collection length. + /// The length of the collection that the Index will be used with. length has to be a positive value + /// + /// For performance reason, we don't validate the input length parameter and the returned offset value against negative values. + /// we don't validate either the returned offset is greater than the input length. + /// It is expected Index will be used with collections which always have non negative length/count. If the returned offset is negative and + /// then used to index a collection will get out of range exception which will be same affect as the validation. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int GetOffset(int length) + { + var offset = _value; + if (IsFromEnd) + { + // offset = length - (~value) + // offset = length + (~(~value) + 1) + // offset = length + value + 1 + + offset += length + 1; + } + return offset; + } + + /// Indicates whether the current Index object is equal to another object of the same type. + /// An object to compare with this object + public override bool Equals(object? value) => value is Index && _value == ((Index)value)._value; + + /// Indicates whether the current Index object is equal to another Index object. + /// An object to compare with this object + public bool Equals(Index other) => _value == other._value; + + /// Returns the hash code for this instance. + public override int GetHashCode() => _value; + + /// Converts integer number to an Index. + public static implicit operator Index(int value) => FromStart(value); + + /// Converts the value of the current Index object to its equivalent string representation. + public override string ToString() + { + if (IsFromEnd) + return "^" + ((uint)Value).ToString(); + + return ((uint)Value).ToString(); + } + } + + /// Represent a range has start and end indexes. + /// + /// Range is used by the C# compiler to support the range syntax. + /// + /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 }; + /// int[] subArray1 = someArray[0..2]; // { 1, 2 } + /// int[] subArray2 = someArray[1..^0]; // { 2, 3, 4, 5 } + /// + /// + internal readonly struct Range : IEquatable + { + /// Represent the inclusive start index of the Range. + public Index Start { get; } + + /// Represent the exclusive end index of the Range. + public Index End { get; } + + /// Construct a Range object using the start and end indexes. + /// Represent the inclusive start index of the range. + /// Represent the exclusive end index of the range. + public Range(Index start, Index end) + { + Start = start; + End = end; + } + + /// Indicates whether the current Range object is equal to another object of the same type. + /// An object to compare with this object + public override bool Equals(object? value) => + value is Range r && + r.Start.Equals(Start) && + r.End.Equals(End); + + /// Indicates whether the current Range object is equal to another Range object. + /// An object to compare with this object + public bool Equals(Range other) => other.Start.Equals(Start) && other.End.Equals(End); + + /// Returns the hash code for this instance. + public override int GetHashCode() + { + return Start.GetHashCode() * 31 + End.GetHashCode(); + } + + /// Converts the value of the current Range object to its equivalent string representation. + public override string ToString() + { + return Start + ".." + End; + } + + /// Create a Range object starting from start index to the end of the collection. + public static Range StartAt(Index start) => new Range(start, Index.End); + + /// Create a Range object starting from first element in the collection to the end Index. + public static Range EndAt(Index end) => new Range(Index.Start, end); + + /// Create a Range object starting from first element to the end. + public static Range All => new Range(Index.Start, Index.End); + + /// Calculate the start offset and length of range object using a collection length. + /// The length of the collection that the range will be used with. length has to be a positive value. + /// + /// For performance reason, we don't validate the input length parameter against negative values. + /// It is expected Range will be used with collections which always have non negative length/count. + /// We validate the range is inside the length scope though. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public (int Offset, int Length) GetOffsetAndLength(int length) + { + int start; + var startIndex = Start; + if (startIndex.IsFromEnd) + start = length - startIndex.Value; + else + start = startIndex.Value; + + int end; + var endIndex = End; + if (endIndex.IsFromEnd) + end = length - endIndex.Value; + else + end = endIndex.Value; + + if ((uint)end > (uint)length || (uint)start > (uint)end) + { + throw new ArgumentOutOfRangeException(nameof(length)); + } + + return (start, end - start); + } + } +} + +namespace System.Runtime.CompilerServices +{ + internal static class RuntimeHelpers + { + /// + /// Slices the specified array using the specified range. + /// + public static T[] GetSubArray(T[] array, Range range) + { + if (array == null) + { + throw new ArgumentNullException(nameof(array)); + } + + (int offset, int length) = range.GetOffsetAndLength(array.Length); + + if (default(T) != null || typeof(T[]) == array.GetType()) + { + // We know the type of the array to be exactly T[]. + + if (length == 0) + { + return Array.Empty(); + } + + var dest = new T[length]; + Array.Copy(array, offset, dest, 0, length); + return dest; + } + else + { + // The array is actually a U[] where U:T. + var dest = (T[])Array.CreateInstance(array.GetType().GetElementType(), length); + Array.Copy(array, offset, dest, 0, length); + return dest; + } + } + } +} + +#endif diff --git a/dotnet/src/Plugins/Plugins.Web/Bing/BingTextSearch.cs b/dotnet/src/Plugins/Plugins.Web/Bing/BingTextSearch.cs index 97526f388b17..556e04f148d3 100644 --- a/dotnet/src/Plugins/Plugins.Web/Bing/BingTextSearch.cs +++ b/dotnet/src/Plugins/Plugins.Web/Bing/BingTextSearch.cs @@ -241,6 +241,7 @@ public TextSearchResult MapFromResultToTextSearchResult(object result) } } +#pragma warning disable CS0618 // FilterClause is obsolete /// /// Build a query string from the /// @@ -280,5 +281,7 @@ private static string BuildQuery(string query, TextSearchOptions searchOptions) return fullQuery.ToString(); } +#pragma warning restore CS0618 // FilterClause is obsolete + #endregion } diff --git a/dotnet/src/Plugins/Plugins.Web/Google/GoogleTextSearch.cs b/dotnet/src/Plugins/Plugins.Web/Google/GoogleTextSearch.cs index a42500fa7c4e..c4165a2edadc 100644 --- a/dotnet/src/Plugins/Plugins.Web/Google/GoogleTextSearch.cs +++ b/dotnet/src/Plugins/Plugins.Web/Google/GoogleTextSearch.cs @@ -160,6 +160,7 @@ public void Dispose() return await search.ExecuteAsync(cancellationToken).ConfigureAwait(false); } +#pragma warning disable CS0618 // FilterClause is obsolete /// /// Add basic filters to the Google search metadata. /// @@ -192,6 +193,7 @@ private void AddFilters(CseResource.ListRequest search, TextSearchOptions search } } } +#pragma warning restore CS0618 // FilterClause is obsolete /// /// Return the search results as instances of . diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/TextSearch/TextSearchFilter.cs b/dotnet/src/SemanticKernel.Abstractions/Data/TextSearch/TextSearchFilter.cs index d964fb1ecba1..bb679eb7573b 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Data/TextSearch/TextSearchFilter.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Data/TextSearch/TextSearchFilter.cs @@ -6,6 +6,8 @@ namespace Microsoft.SemanticKernel.Data; +#pragma warning disable CS0618 // FilterClause is obsolete + /// /// Used to provide filtering when using . /// diff --git a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs index 454d82ace013..b39976adbebf 100644 --- a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs +++ b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs @@ -13,7 +13,7 @@ public MockVectorizableTextSearch(IEnumerable> searc this._searchResults = ToAsyncEnumerable(searchResults); } - public Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { return Task.FromResult(new VectorSearchResults(this._searchResults)); } diff --git a/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs b/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs index 6970294723ef..3cf8528ea169 100644 --- a/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs +++ b/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs @@ -188,6 +188,7 @@ private TextSearchStringMapper CreateTextSearchStringMapper() }); } +#pragma warning disable CS0618 // FilterClause is obsolete /// /// Execute a vector search and return the results. /// @@ -197,7 +198,7 @@ private TextSearchStringMapper CreateTextSearchStringMapper() private async Task> ExecuteVectorSearchAsync(string query, TextSearchOptions? searchOptions, CancellationToken cancellationToken) { searchOptions ??= new TextSearchOptions(); - var vectorSearchOptions = new VectorSearchOptions + var vectorSearchOptions = new VectorSearchOptions { Filter = searchOptions.Filter?.FilterClauses is not null ? new VectorSearchFilter(searchOptions.Filter.FilterClauses) : null, Skip = searchOptions.Skip, @@ -213,6 +214,7 @@ private async Task> ExecuteVectorSearchAsync(string return await this._vectorizableTextSearch!.VectorizableTextSearchAsync(query, vectorSearchOptions, cancellationToken).ConfigureAwait(false); } +#pragma warning restore CS0618 // FilterClause is obsolete /// /// Return the search results as instances of TRecord. diff --git a/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs b/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs index da062934cfbb..e94f321eed4a 100644 --- a/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs +++ b/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs @@ -31,7 +31,7 @@ public sealed class VolatileVectorStoreRecordCollection : IVector ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// Internal storage for all of the record collections. private readonly ConcurrentDictionary> _internalCollections; @@ -213,7 +213,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record /// #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - Need to satisfy the interface which returns IAsyncEnumerable - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) #pragma warning restore CS1998 { Verify.NotNull(vector); @@ -238,6 +238,11 @@ public async Task> VectorizedSearchAsync(T } // Filter records using the provided filter before doing the vector comparison. + if (internalOptions.NewFilter is not null) + { + throw new NotSupportedException("LINQ-based filtering is not supported with VolatileVectorStore, use Microsoft.SemanticKernel.Connectors.InMemory instead"); + } + var filteredRecords = VolatileVectorStoreCollectionSearchMapping.FilterRecords(internalOptions.Filter, this.GetCollectionDictionary().Values); // Compare each vector in the filtered results with the provided vector. diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs index 262289c567d0..c01fe06eddf4 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs @@ -126,7 +126,7 @@ public Task>> GenerateEmbeddingsAsync(IList public sealed class VectorizedSearchWrapper(IVectorizedSearch vectorizedSearch, ITextEmbeddingGenerationService textEmbeddingGeneration) : IVectorizableTextSearch { /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { var vectorizedQuery = await textEmbeddingGeneration!.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); return await vectorizedSearch.VectorizedSearchAsync(vectorizedQuery, options, cancellationToken); diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs index 9530c48fe574..edd169a725ff 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs @@ -294,7 +294,7 @@ public async Task CanSearchWithVectorAsync(bool useDefinition, TKey testKe // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true }, + new() { IncludeVectors = true }, this._testCancellationToken); // Assert @@ -338,7 +338,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, TK var filter = filterType == "Equality" ? new VectorSearchFilter().EqualTo("Data", $"data {testKey2}") : new VectorSearchFilter().AnyTagEqualTo("Tags", $"tag {testKey2}"); var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true, Filter = filter, IncludeTotalCount = true }, + new() { IncludeVectors = true, Filter = filter, IncludeTotalCount = true }, this._testCancellationToken); // Assert @@ -390,7 +390,7 @@ public async Task CanSearchWithDifferentDistanceFunctionsAsync(string distanceFu // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true }, + new() { IncludeVectors = true }, this._testCancellationToken); // Assert @@ -431,7 +431,7 @@ public async Task CanSearchManyRecordsAsync(bool useDefinition) // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true, Top = 10, Skip = 10, IncludeTotalCount = true }, + new() { IncludeVectors = true, Top = 10, Skip = 10, IncludeTotalCount = true }, this._testCancellationToken); // Assert @@ -507,7 +507,7 @@ public async Task ItCanSearchUsingTheGenericDataModelAsync(TKey testKey1, // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory([1, 1, 1, 1]), - new VectorSearchOptions { IncludeVectors = true, VectorPropertyName = "Vector" }, + new() { IncludeVectors = true, VectorPropertyName = "Vector" }, this._testCancellationToken); // Assert diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj new file mode 100644 index 000000000000..c575ad645e31 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj @@ -0,0 +1,31 @@ + + + + net8.0 + enable + enable + true + false + AzureAISearchIntegrationTests + b7762d10-e29b-4bb1-8b74-b6d69a667dd4 + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs new file mode 100644 index 000000000000..9683543d3e98 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; + +namespace AzureAISearchIntegrationTests.Filter; + +public class AzureAISearchBasicFilterTests(AzureAISearchFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + // Azure AI Search only supports search.in() over strings + public override Task Contains_over_inline_int_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs new file mode 100644 index 000000000000..0dfc4ce4238e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. + +using AzureAISearchIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace AzureAISearchIntegrationTests.Filter; + +public class AzureAISearchFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => AzureAISearchTestStore.Instance; + + // Azure AI search only supports lowercase letters, digits or dashes. + protected override string StoreName => "filter-tests"; + + public override async Task DisposeAsync() + => await base.DisposeAsync(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Properties/AssemblyAttributes.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Properties/AssemblyAttributes.cs new file mode 100644 index 000000000000..786c2742c2b3 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Properties/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: AzureAISearchIntegrationTests.Support.AzureAISearchUrlRequiredAttribute] diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs new file mode 100644 index 000000000000..27e905656870 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; + +namespace AzureAISearchIntegrationTests.Support; + +#pragma warning disable CA1810 // Initialize all static fields when those fields are declared + +internal static class AzureAISearchTestEnvironment +{ + public static readonly string? ServiceUrl, ApiKey; + + public static bool IsConnectionInfoDefined => ServiceUrl is not null && ApiKey is not null; + + static AzureAISearchTestEnvironment() + { + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true) + .AddJsonFile(path: "testsettings.development.json", optional: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + var azureAISearchSection = configuration.GetSection("AzureAISearch"); + ServiceUrl = azureAISearchSection?["ServiceUrl"]; + ApiKey = azureAISearchSection?["ApiKey"]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs new file mode 100644 index 000000000000..791005d55c9a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Azure; +using Azure.Search.Documents.Indexes; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.AzureAISearch; +using VectorDataSpecificationTests.Support; + +namespace AzureAISearchIntegrationTests.Support; + +internal sealed class AzureAISearchTestStore : TestStore +{ + public static AzureAISearchTestStore Instance { get; } = new(); + + private SearchIndexClient? _client; + private AzureAISearchVectorStore? _defaultVectorStore; + + public SearchIndexClient Client + => this._client ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public override IVectorStore DefaultVectorStore + => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public AzureAISearchVectorStore GetVectorStore(AzureAISearchVectorStoreOptions options) + => new(this.Client, options); + + private AzureAISearchTestStore() + { + } + + protected override Task StartAsync() + { + (string? serviceUrl, string? apiKey) = (AzureAISearchTestEnvironment.ServiceUrl, AzureAISearchTestEnvironment.ApiKey); + + if (string.IsNullOrWhiteSpace(serviceUrl) || string.IsNullOrWhiteSpace(apiKey)) + { + throw new InvalidOperationException("Service URL and API key are not configured, set AzureAISearch:ServiceUrl and AzureAISearch:ApiKey"); + } + + this._client = new SearchIndexClient(new Uri(serviceUrl), new AzureKeyCredential(apiKey)); + this._defaultVectorStore = new(this._client); + + return Task.CompletedTask; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs new file mode 100644 index 000000000000..1b30639bc1be --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Xunit; + +namespace AzureAISearchIntegrationTests.Support; + +/// +/// Checks whether the sqlite_vec extension is properly installed, and skips the test(s) otherwise. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] +public sealed class AzureAISearchUrlRequiredAttribute : Attribute, ITestCondition +{ + public ValueTask IsMetAsync() => new(AzureAISearchTestEnvironment.IsConnectionInfoDefined); + + public string Skip { get; set; } = "Service URL and API key are not configured, set AzureAISearch:ServiceUrl and AzureAISearch:ApiKey."; + + public string SkipReason + => this.Skip; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj new file mode 100644 index 000000000000..dbd8e1d18a24 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj @@ -0,0 +1,29 @@ + + + + net8.0 + enable + enable + true + false + MongoDBIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFilterFixture.cs new file mode 100644 index 000000000000..129c7b0cc337 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFilterFixture.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace MongoDBIntegrationTests.Filter; + +public class CosmosMongoFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => CosmosMongoDBTestStore.Instance; + + protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.IvfFlat; + protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFiltersNotSupported.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFiltersNotSupported.cs new file mode 100644 index 000000000000..47b6c4720cf4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFiltersNotSupported.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDB.Driver; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace MongoDBIntegrationTests.Filter; + +public class CosmosMongoFiltersNotSupported(CosmosMongoFilterFixture fixture) : IClassFixture +{ + [ConditionalFact] + public virtual async Task Equal_with_int() + { + // Cosmos MongoDB vCore doesn't yet support filters with vector search: + // Command aggregate failed: $filter is not supported for vector search yet.. + await Assert.ThrowsAsync(() => fixture.Collection.VectorizedSearchAsync( + new ReadOnlyMemory([1, 2, 3]), + new() + { + NewFilter = r => r.Int == 8, + Top = fixture.TestData.Count + })); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Properties/AssemblyAttributes.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Properties/AssemblyAttributes.cs new file mode 100644 index 000000000000..4e8438d68759 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Properties/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: CosmosIntegrationTests.Support.CosmosConnectionStringRequired] diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs new file mode 100644 index 000000000000..c944d36eb78c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSQLIntegrationTests.Support; +using VectorDataSpecificationTests.Xunit; + +namespace CosmosIntegrationTests.Support; + +/// +/// Checks whether the sqlite_vec extension is properly installed, and skips the test(s) otherwise. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] +public sealed class CosmosConnectionStringRequiredAttribute : Attribute, ITestCondition +{ + public ValueTask IsMetAsync() => new(CosmosMongoDBTestEnvironment.IsConnectionStringDefined); + + public string Skip { get; set; } = "The Cosmos connection string hasn't been configured (AzureCosmosDBMongoDB:ConnectionString)."; + + public string SkipReason + => this.Skip; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs new file mode 100644 index 000000000000..1adcb225e66d --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; + +namespace CosmosNoSQLIntegrationTests.Support; + +#pragma warning disable CA1810 // Initialize all static fields when those fields are declared + +public static class CosmosMongoDBTestEnvironment +{ + public static readonly string? ConnectionString; + + public static bool IsConnectionStringDefined => ConnectionString is not null; + + static CosmosMongoDBTestEnvironment() + { + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true) + .AddJsonFile(path: "testsettings.development.json", optional: true) + .AddEnvironmentVariables() + .Build(); + + ConnectionString = configuration["AzureCosmosDBMongoDB:ConnectionString"]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs new file mode 100644 index 000000000000..8432901efb73 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSQLIntegrationTests.Support; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +using MongoDB.Driver; +using VectorDataSpecificationTests.Support; + +namespace MongoDBIntegrationTests.Support; + +public sealed class CosmosMongoDBTestStore : TestStore +{ + public static CosmosMongoDBTestStore Instance { get; } = new(); + + public MongoClient? _client { get; private set; } + public IMongoDatabase? _database { get; private set; } + private AzureCosmosDBMongoDBVectorStore? _defaultVectorStore; + + public MongoClient Client => this._client ?? throw new InvalidOperationException("Not initialized"); + public IMongoDatabase Database => this._database ?? throw new InvalidOperationException("Not initialized"); + + public override IVectorStore DefaultVectorStore + => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public AzureCosmosDBMongoDBVectorStore GetVectorStore(AzureCosmosDBMongoDBVectorStoreOptions options) + => new(this.Database, options); + + private CosmosMongoDBTestStore() + { + } + + protected override Task StartAsync() + { + if (string.IsNullOrWhiteSpace(CosmosMongoDBTestEnvironment.ConnectionString)) + { + throw new InvalidOperationException("Connection string is not configured, set the AzureCosmosDBMongoDB:ConnectionString environment variable"); + } + + this._client = new MongoClient(CosmosMongoDBTestEnvironment.ConnectionString); + this._database = this._client.GetDatabase("VectorSearchTests"); + this._defaultVectorStore = new(this._database); + + return Task.CompletedTask; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj new file mode 100644 index 000000000000..782d7e328e44 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj @@ -0,0 +1,29 @@ + + + + net8.0 + enable + enable + true + false + CosmosNoSQLIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs new file mode 100644 index 000000000000..b67141d82e6c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; + +namespace CosmosNoSQLIntegrationTests.Filter; + +public class CosmosNoSQLBasicFilterTests(CosmosNoSQLFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture; diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLFilterFixture.cs new file mode 100644 index 000000000000..8aaf6b86d4f9 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLFilterFixture.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSQLIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace CosmosNoSQLIntegrationTests.Filter; + +public class CosmosNoSQLFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => CosmosNoSqlTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Properties/AssemblyAttributes.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Properties/AssemblyAttributes.cs new file mode 100644 index 000000000000..183a8a7c926c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Properties/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: CosmosNoSQLIntegrationTests.Support.CosmosConnectionStringRequired] diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs new file mode 100644 index 000000000000..2183f166d3ec --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Xunit; + +namespace CosmosNoSQLIntegrationTests.Support; + +/// +/// Checks whether the sqlite_vec extension is properly installed, and skips the test(s) otherwise. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] +public sealed class CosmosConnectionStringRequiredAttribute : Attribute, ITestCondition +{ + public ValueTask IsMetAsync() => new(CosmosNoSQLTestEnvironment.IsConnectionStringDefined); + + public string Skip { get; set; } = "The Cosmos connection string hasn't been configured (AzureCosmosDBNoSQL:ConnectionString)."; + + public string SkipReason + => this.Skip; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestEnvironment.cs new file mode 100644 index 000000000000..bd2848a2cb8f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestEnvironment.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; + +namespace CosmosNoSQLIntegrationTests.Support; + +#pragma warning disable CA1810 // Initialize all static fields when those fields are declared + +internal static class CosmosNoSQLTestEnvironment +{ + public static readonly string? ConnectionString; + + public static bool IsConnectionStringDefined => ConnectionString is not null; + + static CosmosNoSQLTestEnvironment() + { + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true) + .AddJsonFile(path: "testsettings.development.json", optional: true) + .AddEnvironmentVariables() + .Build(); + + ConnectionString = configuration["AzureCosmosDBNoSQL:ConnectionString"]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs new file mode 100644 index 000000000000..18f58717d461 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; +using VectorDataSpecificationTests.Support; + +namespace CosmosNoSQLIntegrationTests.Support; + +#pragma warning disable CA1001 // Type owns disposable fields (_connection) but is not disposable + +internal sealed class CosmosNoSqlTestStore : TestStore +{ + public static CosmosNoSqlTestStore Instance { get; } = new(); + + private CosmosClient? _client; + private Database? _database; + private AzureCosmosDBNoSQLVectorStore? _defaultVectorStore; + + public CosmosClient Client + => this._client ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public Database Database + => this._database ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public override IVectorStore DefaultVectorStore + => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public AzureCosmosDBNoSQLVectorStore GetVectorStore(AzureCosmosDBNoSQLVectorStoreOptions options) + => new(this.Database, options); + + private CosmosNoSqlTestStore() + { + } + +#pragma warning disable CA5400 // HttpClient may be created without enabling CheckCertificateRevocationList + protected override async Task StartAsync() + { + var connectionString = CosmosNoSQLTestEnvironment.ConnectionString; + + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new InvalidOperationException("Connection string is not configured, set the AzureCosmosDBNoSQL:ConnectionString environment variable"); + } + + var options = new CosmosClientOptions + { + UseSystemTextJsonSerializerWithOptions = JsonSerializerOptions.Default, + ConnectionMode = ConnectionMode.Gateway, + HttpClientFactory = () => new HttpClient(new HttpClientHandler { ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator }) + }; + + this._client = new CosmosClient(connectionString, options); + this._database = this._client.GetDatabase("VectorDataIntegrationTests"); + await this._client.CreateDatabaseIfNotExistsAsync("VectorDataIntegrationTests"); + this._defaultVectorStore = new(this._database); + } +#pragma warning restore CA5400 +} diff --git a/dotnet/src/VectorDataIntegrationTests/Directory.Build.props b/dotnet/src/VectorDataIntegrationTests/Directory.Build.props new file mode 100644 index 000000000000..1cfeef0ee289 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/Directory.Build.props @@ -0,0 +1,15 @@ + + + + + $(NoWarn);CA1515 + $(NoWarn);CA1707 + $(NoWarn);CA1716 + $(NoWarn);CA1720 + $(NoWarn);CA1861 + $(NoWarn);CA2007;VSTHRD111 + $(NoWarn);CS1591 + $(NoWarn);IDE1006 + + + diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs new file mode 100644 index 000000000000..32adf75e9017 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; + +namespace PostgresIntegrationTests.Filter; + +public class InMemoryBasicFilterTests(InMemoryFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture; diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryFilterFixture.cs new file mode 100644 index 000000000000..7952d1dffad3 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryFilterFixture.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using InMemoryIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace PostgresIntegrationTests.Filter; + +public class InMemoryFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => InMemoryTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj new file mode 100644 index 000000000000..4c6988e72e3d --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj @@ -0,0 +1,26 @@ + + + + net8.0 + enable + enable + true + false + InMemoryIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs new file mode 100644 index 000000000000..246d5166c831 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.InMemory; +using VectorDataSpecificationTests.Support; + +namespace InMemoryIntegrationTests.Support; + +internal sealed class InMemoryTestStore : TestStore +{ + public static InMemoryTestStore Instance { get; } = new(); + + private InMemoryVectorStore _vectorStore = new(); + + public override IVectorStore DefaultVectorStore => this._vectorStore; + + private InMemoryTestStore() + { + } + + protected override Task StartAsync() + { + this._vectorStore = new InMemoryVectorStore(); + + return Task.CompletedTask; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs new file mode 100644 index 000000000000..f1a37114e6c6 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace MongoDBIntegrationTests.Filter; + +public class MongoDBBasicFilterTests(MongoDBFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + // Specialized MongoDB syntax for NOT over Contains ($nin) + [ConditionalFact] + public virtual Task Not_over_Contains() + => this.TestFilterAsync(r => !new[] { 8, 10 }.Contains(r.Int)); + + #region Null checking + + // MongoDB currently doesn't support null checking ({ "Foo" : null }) in vector search pre-filters + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_referenceType() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_referenceType()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + #endregion + + #region Not + + // MongoDB currently doesn't support NOT in vector search pre-filters + // (https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter) + public override Task Not_over_And() + => Assert.ThrowsAsync(() => base.Not_over_And()); + + public override Task Not_over_Or() + => Assert.ThrowsAsync(() => base.Not_over_Or()); + + #endregion + + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + // AnyTagEqualTo not (currently) supported on SQLite + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_array() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_array()); + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_List() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_List()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBFilterFixture.cs new file mode 100644 index 000000000000..8774018ffabf --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBFilterFixture.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace MongoDBIntegrationTests.Filter; + +public class MongoDBFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => MongoDBTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj new file mode 100644 index 000000000000..17d39c913623 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj @@ -0,0 +1,27 @@ + + + + net8.0 + enable + enable + true + false + MongoDBIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs new file mode 100644 index 000000000000..7ea67c46d0c7 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Connectors.MongoDB; +using MongoDB.Driver; +using Testcontainers.MongoDb; +using VectorDataSpecificationTests.Support; + +namespace MongoDBIntegrationTests.Support; + +internal sealed class MongoDBTestStore : TestStore +{ + public static MongoDBTestStore Instance { get; } = new(); + + private readonly MongoDbContainer _container = new MongoDbBuilder() + .WithImage("mongodb/mongodb-atlas-local:7.0.6") + .Build(); + + public MongoClient? _client { get; private set; } + public IMongoDatabase? _database { get; private set; } + private MongoDBVectorStore? _defaultVectorStore; + + public MongoClient Client => this._client ?? throw new InvalidOperationException("Not initialized"); + public IMongoDatabase Database => this._database ?? throw new InvalidOperationException("Not initialized"); + + public override MongoDBVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public MongoDBVectorStore GetVectorStore(MongoDBVectorStoreOptions options) + => new(this.Database, options); + + private MongoDBTestStore() + { + } + + protected override async Task StartAsync() + { + await this._container.StartAsync(); + + this._client = new MongoClient(new MongoClientSettings + { + Server = new MongoServerAddress(this._container.Hostname, this._container.GetMappedPublicPort(MongoDbBuilder.MongoDbPort)), + DirectConnection = true, + // ReadConcern = ReadConcern.Linearizable, + // WriteConcern = WriteConcern.WMajority + }); + this._database = this._client.GetDatabase("VectorSearchTests"); + this._defaultVectorStore = new(this._database); + } + + protected override Task StopAsync() + => this._container.StopAsync(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs new file mode 100644 index 000000000000..4fad76458700 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; +using Xunit.Sdk; + +namespace PostgresIntegrationTests.Filter; + +public class PostgresBasicFilterTests(PostgresFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + public override async Task Not_over_Or() + { + // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) + // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. + await Assert.ThrowsAsync(() => base.Not_over_Or()); + + // Compensate by adding a null check: + await this.TestFilterAsync(r => r.String != null && !(r.Int == 8 || r.String == "foo")); + } + + public override async Task NotEqual_with_string() + { + // As above, null semantics + negation + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + + await this.TestFilterAsync(r => r.String != null && r.String != "foo"); + } + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_array() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_array()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresFilterFixture.cs new file mode 100644 index 000000000000..c65b37177003 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresFilterFixture.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace PostgresIntegrationTests.Filter; + +public class PostgresFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => PostgresTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj new file mode 100644 index 000000000000..316ddd7225c1 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj @@ -0,0 +1,27 @@ + + + + net8.0 + enable + enable + true + false + PostgresIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs new file mode 100644 index 000000000000..9bd067b74e6e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; +using Testcontainers.PostgreSql; +using VectorDataSpecificationTests.Support; + +namespace PostgresIntegrationTests.Support; + +#pragma warning disable SKEXP0020 + +internal sealed class PostgresTestStore : TestStore +{ + public static PostgresTestStore Instance { get; } = new(); + + private static readonly PostgreSqlContainer s_container = new PostgreSqlBuilder() + .WithImage("pgvector/pgvector:pg16") + .Build(); + + private NpgsqlDataSource? _dataSource; + private PostgresVectorStore? _defaultVectorStore; + + public NpgsqlDataSource DataSource => this._dataSource ?? throw new InvalidOperationException("Not initialized"); + + public override PostgresVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public PostgresVectorStore GetVectorStore(PostgresVectorStoreOptions options) + => new(this.DataSource, options); + + private PostgresTestStore() + { + } + + protected override async Task StartAsync() + { + await s_container.StartAsync(); + + var dataSourceBuilder = new NpgsqlDataSourceBuilder + { + ConnectionStringBuilder = + { + Host = s_container.Hostname, + Port = s_container.GetMappedPublicPort(5432), + Username = PostgreSqlBuilder.DefaultUsername, + Password = PostgreSqlBuilder.DefaultPassword, + Database = PostgreSqlBuilder.DefaultDatabase + } + }; + + dataSourceBuilder.UseVector(); + + this._dataSource = dataSourceBuilder.Build(); + + await using var connection = this._dataSource.CreateConnection(); + await connection.OpenAsync(); + await using var command = new NpgsqlCommand("CREATE EXTENSION IF NOT EXISTS vector", connection); + await command.ExecuteNonQueryAsync(); + await connection.ReloadTypesAsync(); + + this._defaultVectorStore = new(this._dataSource); + } + + protected override async Task StopAsync() + { + await this._dataSource!.DisposeAsync(); + await s_container.StopAsync(); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs new file mode 100644 index 000000000000..11593833dddf --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; + +namespace QdrantIntegrationTests.Filter; + +public class QdrantBasicFilterTests(QdrantFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture; diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantFilterFixture.cs new file mode 100644 index 000000000000..8c8a6528b4f8 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantFilterFixture.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. + +using QdrantIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace QdrantIntegrationTests.Filter; + +public class QdrantFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => QdrantTestStore.Instance; + + // Qdrant doesn't support the default Flat index kind + protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj new file mode 100644 index 000000000000..8eb6273cbee3 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj @@ -0,0 +1,27 @@ + + + + net8.0 + enable + enable + true + false + QdrantIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs new file mode 100644 index 000000000000..52736ac02681 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Connectors.Qdrant; +using Qdrant.Client; +using QdrantIntegrationTests.Support.TestContainer; +using VectorDataSpecificationTests.Support; + +namespace QdrantIntegrationTests.Support; + +#pragma warning disable CA1001 // Type owns disposable fields but is not disposable + +internal sealed class QdrantTestStore : TestStore +{ + public static QdrantTestStore Instance { get; } = new(); + + private readonly QdrantContainer _container = new QdrantBuilder().Build(); + private QdrantClient? _client; + private QdrantVectorStore? _defaultVectorStore; + + public QdrantClient Client => this._client ?? throw new InvalidOperationException("Not initialized"); + + public override QdrantVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public QdrantVectorStore GetVectorStore(QdrantVectorStoreOptions options) + => new(this.Client, options); + + private QdrantTestStore() + { + } + + protected override async Task StartAsync() + { + await this._container.StartAsync(); + this._client = new QdrantClient(this._container.Hostname, this._container.GetMappedPublicPort(QdrantBuilder.QdrantGrpcPort)); + this._defaultVectorStore = new(this._client); + } + + protected override Task StopAsync() + => this._container.StopAsync(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantBuilder.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantBuilder.cs new file mode 100644 index 000000000000..a3444a9f0ee5 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantBuilder.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet.Models; +using DotNet.Testcontainers.Builders; +using DotNet.Testcontainers.Configurations; +using Qdrant.Client.Grpc; + +namespace QdrantIntegrationTests.Support.TestContainer; + +public sealed class QdrantBuilder : ContainerBuilder +{ + public const string QdrantImage = "qdrant/qdrant:" + QdrantGrpcClient.QdrantVersion; + + public const ushort QdrantHttpPort = 6333; + + public const ushort QdrantGrpcPort = 6334; + + public QdrantBuilder() : this(new QdrantConfiguration()) => this.DockerResourceConfiguration = this.Init().DockerResourceConfiguration; + + private QdrantBuilder(QdrantConfiguration dockerResourceConfiguration) : base(dockerResourceConfiguration) + => this.DockerResourceConfiguration = dockerResourceConfiguration; + + public QdrantBuilder WithConfigFile(string configPath) + => this.Merge(this.DockerResourceConfiguration, new QdrantConfiguration()) + .WithBindMount(configPath, "/qdrant/config/custom_config.yaml"); + + public QdrantBuilder WithCertificate(string certPath, string keyPath) + => this.Merge(this.DockerResourceConfiguration, new QdrantConfiguration()) + .WithBindMount(certPath, "/qdrant/tls/cert.pem") + .WithBindMount(keyPath, "/qdrant/tls/key.pem"); + + public override QdrantContainer Build() + { + this.Validate(); + return new QdrantContainer(this.DockerResourceConfiguration); + } + + protected override QdrantBuilder Init() + => base.Init() + .WithImage(QdrantImage) + .WithPortBinding(QdrantHttpPort, true) + .WithPortBinding(QdrantGrpcPort, true) + .WithWaitStrategy(Wait.ForUnixContainer() + .UntilMessageIsLogged(".*Actix runtime found; starting in Actix runtime.*")); + + protected override QdrantBuilder Clone(IResourceConfiguration resourceConfiguration) + => this.Merge(this.DockerResourceConfiguration, new QdrantConfiguration(resourceConfiguration)); + + protected override QdrantBuilder Merge(QdrantConfiguration oldValue, QdrantConfiguration newValue) + => new(new QdrantConfiguration(oldValue, newValue)); + + protected override QdrantConfiguration DockerResourceConfiguration { get; } + + protected override QdrantBuilder Clone(IContainerConfiguration resourceConfiguration) + => this.Merge(this.DockerResourceConfiguration, new QdrantConfiguration(resourceConfiguration)); +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantConfiguration.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantConfiguration.cs new file mode 100644 index 000000000000..219e4030c581 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantConfiguration.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet.Models; +using DotNet.Testcontainers.Configurations; + +namespace QdrantIntegrationTests.Support.TestContainer; + +public sealed class QdrantConfiguration : ContainerConfiguration +{ + /// + /// Initializes a new instance of the class. + /// + public QdrantConfiguration() + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public QdrantConfiguration(IResourceConfiguration resourceConfiguration) + : base(resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public QdrantConfiguration(IContainerConfiguration resourceConfiguration) + : base(resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public QdrantConfiguration(QdrantConfiguration resourceConfiguration) + : this(new QdrantConfiguration(), resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The old Docker resource configuration. + /// The new Docker resource configuration. + public QdrantConfiguration(QdrantConfiguration oldValue, QdrantConfiguration newValue) + : base(oldValue, newValue) + { + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantContainer.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantContainer.cs new file mode 100644 index 000000000000..f9c1ab05f1cc --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantContainer.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. + +using DotNet.Testcontainers.Containers; + +namespace QdrantIntegrationTests.Support.TestContainer; + +public class QdrantContainer(QdrantConfiguration configuration) : DockerContainer(configuration); diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs new file mode 100644 index 000000000000..2d0bc17f179a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; + +namespace RedisIntegrationTests.Filter; + +public class RedisBasicFilterTests(RedisFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + #region Equality with null + + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_referenceType() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + #endregion + + #region Bool + + public override Task Bool() + => Assert.ThrowsAsync(() => base.Bool()); + + public override Task Not_over_bool() + => Assert.ThrowsAsync(() => base.Not_over_bool()); + + #endregion + + #region Contains + + public override Task Contains_over_inline_int_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); + + public override Task Contains_over_inline_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_string_array()); + + public override Task Contains_over_inline_string_array_with_weird_chars() + => Assert.ThrowsAsync(() => base.Contains_over_inline_string_array_with_weird_chars()); + + public override Task Contains_over_captured_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_captured_string_array()); + + #endregion +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs new file mode 100644 index 000000000000..e450381f91e4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace RedisIntegrationTests.Filter; + +public class RedisFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => RedisTestStore.Instance; + + // Override to remove the bool property, which isn't (currently) supported on Redis + protected override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(bool)).ToList() + }; +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj new file mode 100644 index 000000000000..b3661a620b41 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj @@ -0,0 +1,27 @@ + + + + net8.0 + enable + enable + true + false + RedisIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs new file mode 100644 index 000000000000..c0e2336ed9e5 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Connectors.Redis; +using StackExchange.Redis; +using Testcontainers.Redis; +using VectorDataSpecificationTests.Support; + +namespace RedisIntegrationTests.Support; + +internal sealed class RedisTestStore : TestStore +{ + public static RedisTestStore Instance { get; } = new(); + + private readonly RedisContainer _container = new RedisBuilder() + .WithImage("redis/redis-stack") + .Build(); + + private IDatabase? _database; + private RedisVectorStore? _defaultVectorStore; + + public IDatabase Database => this._database ?? throw new InvalidOperationException("Not initialized"); + + public override RedisVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public RedisVectorStore GetVectorStore(RedisVectorStoreOptions options) + => new(this.Database, options); + + private RedisTestStore() + { + } + + protected override async Task StartAsync() + { + await this._container.StartAsync(); + var redis = await ConnectionMultiplexer.ConnectAsync($"{this._container.Hostname}:{this._container.GetMappedPublicPort(6379)},connectTimeout=60000,connectRetry=5"); + this._database = redis.GetDatabase(); + this._defaultVectorStore = new(this._database); + } + + protected override Task StopAsync() + => this._container.StopAsync(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs new file mode 100644 index 000000000000..9ca7878a414e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; +using Xunit.Sdk; + +namespace SqliteIntegrationTests.Filter; + +public class SqliteBasicFilterTests(SqliteFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + public override async Task Not_over_Or() + { + // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) + // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. + await Assert.ThrowsAsync(() => base.Not_over_Or()); + + // Compensate by adding a null check: + await this.TestFilterAsync(r => r.String != null && !(r.Int == 8 || r.String == "foo")); + } + + public override async Task NotEqual_with_string() + { + // As above, null semantics + negation + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + + await this.TestFilterAsync(r => r.String != null && r.String != "foo"); + } + + // Array fields not (currently) supported on SQLite (see #10343) + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + // List fields not (currently) supported on SQLite (see #10343) + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + // AnyTagEqualTo not (currently) supported on SQLite + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_array() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_array()); + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_List() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_List()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteFilterFixture.cs new file mode 100644 index 000000000000..3dc9a0d10dad --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteFilterFixture.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using SqliteIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace SqliteIntegrationTests.Filter; + +public class SqliteFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => SqliteTestStore.Instance; + + protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; + + // Override to remove the string array property, which isn't (currently) supported on SQLite + protected override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(string[]) && p.PropertyType != typeof(List)).ToList() + }; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Properties/AssemblyAttributes.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Properties/AssemblyAttributes.cs new file mode 100644 index 000000000000..89ee1c5e6025 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Properties/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: SqliteIntegrationTests.Support.SqliteVecRequired] diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj new file mode 100644 index 000000000000..3b1a18810500 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj @@ -0,0 +1,26 @@ + + + + net8.0 + enable + enable + true + false + SqliteIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestEnvironment.cs new file mode 100644 index 000000000000..e7dd76fb76fc --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestEnvironment.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Data; +using Microsoft.Data.Sqlite; + +namespace SqliteIntegrationTests.Support; + +internal static class SqliteTestEnvironment +{ + /// + /// SQLite extension name for vector search. + /// More information here: . + /// + private const string VectorSearchExtensionName = "vec0"; + + private static bool? s_isSqliteVecInstalled; + + internal static bool TryLoadSqliteVec(SqliteConnection connection) + { + if (!s_isSqliteVecInstalled.HasValue) + { + if (connection.State != ConnectionState.Open) + { + throw new ArgumentException("Connection must be open"); + } + + try + { + connection.LoadExtension(VectorSearchExtensionName); + s_isSqliteVecInstalled = true; + } + catch (SqliteException) + { + s_isSqliteVecInstalled = false; + } + } + + return s_isSqliteVecInstalled.Value; + } + + internal static bool IsSqliteVecInstalled + { + get + { + if (!s_isSqliteVecInstalled.HasValue) + { + using var connection = new SqliteConnection("Data Source=:memory:;"); + connection.Open(); + + s_isSqliteVecInstalled = TryLoadSqliteVec(connection); + } + + return s_isSqliteVecInstalled.Value; + } + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs new file mode 100644 index 000000000000..48b884b08ade --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Data.Sqlite; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Sqlite; +using VectorDataSpecificationTests.Support; + +namespace SqliteIntegrationTests.Support; + +#pragma warning disable CA1001 // Type owns disposable fields (_connection) but is not disposable + +internal sealed class SqliteTestStore : TestStore +{ + public static SqliteTestStore Instance { get; } = new(); + + private SqliteConnection? _connection; + public SqliteConnection Connection + => this._connection ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + private SqliteVectorStore? _defaultVectorStore; + public override IVectorStore DefaultVectorStore + => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + private SqliteTestStore() + { + } + + protected override async Task StartAsync() + { + this._connection = new SqliteConnection("Data Source=:memory:"); + + await this.Connection.OpenAsync(); + + if (!SqliteTestEnvironment.TryLoadSqliteVec(this.Connection)) + { + this.Connection.Dispose(); + + // Note that we ignore sqlite_vec loading failures; the tests are decorated with [SqliteVecRequired], which causes + // them to be skipped if sqlite_vec isn't installed (better than an exception triggering failure here) + } + + this._defaultVectorStore = new SqliteVectorStore(this.Connection); + } + + protected override async Task StopAsync() + => await this.Connection.DisposeAsync(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteVecRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteVecRequiredAttribute.cs new file mode 100644 index 000000000000..9351fd679171 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteVecRequiredAttribute.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Xunit; + +namespace SqliteIntegrationTests.Support; + +/// +/// Checks whether the sqlite_vec extension is properly installed, and skips the test(s) otherwise. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] +public sealed class SqliteVecRequiredAttribute : Attribute, ITestCondition +{ + public ValueTask IsMetAsync() => new(SqliteTestEnvironment.IsSqliteVecInstalled); + + public string Skip { get; set; } = "The sqlite_vec extension is not installed."; + + public string SkipReason + => this.Skip; +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs new file mode 100644 index 000000000000..b637035eea53 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs @@ -0,0 +1,283 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq.Expressions; +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.Filter; + +public abstract class BasicFilterTestsBase(FilterFixtureBase fixture) + where TKey : notnull +{ + #region Equality + + [ConditionalFact] + public virtual Task Equal_with_int() + => this.TestFilterAsync(r => r.Int == 8); + + [ConditionalFact] + public virtual Task Equal_with_string() + => this.TestFilterAsync(r => r.String == "foo"); + + [ConditionalFact] + public virtual Task Equal_with_string_containing_special_characters() + => this.TestFilterAsync(r => r.String == """with some special"characters'and\stuff"""); + + [ConditionalFact] + public virtual Task Equal_with_string_is_not_Contains() + => this.TestFilterAsync(r => r.String == "some", expectZeroResults: true); + + [ConditionalFact] + public virtual Task Equal_reversed() + => this.TestFilterAsync(r => 8 == r.Int); + + [ConditionalFact] + public virtual Task Equal_with_null_reference_type() + => this.TestFilterAsync(r => r.String == null); + + [ConditionalFact] + public virtual Task Equal_with_null_captured() + { + string? s = null; + + return this.TestFilterAsync(r => r.String == s); + } + + [ConditionalFact] + public virtual Task NotEqual_with_int() + => this.TestFilterAsync(r => r.Int != 8); + + [ConditionalFact] + public virtual Task NotEqual_with_string() + => this.TestFilterAsync(r => r.String != "foo"); + + [ConditionalFact] + public virtual Task NotEqual_reversed() + => this.TestFilterAsync(r => r.Int != 8); + + [ConditionalFact] + public virtual Task NotEqual_with_null_referenceType() + => this.TestFilterAsync(r => r.String != null); + + [ConditionalFact] + public virtual Task NotEqual_with_null_captured() + { + string? s = null; + + return this.TestFilterAsync(r => r.String != s); + } + + [ConditionalFact] + public virtual Task Bool() + => this.TestFilterAsync(r => r.Bool); + + #endregion Equality + + #region Comparison + + [ConditionalFact] + public virtual Task GreaterThan_with_int() + => this.TestFilterAsync(r => r.Int > 9); + + [ConditionalFact] + public virtual Task GreaterThanOrEqual_with_int() + => this.TestFilterAsync(r => r.Int >= 9); + + [ConditionalFact] + public virtual Task LessThan_with_int() + => this.TestFilterAsync(r => r.Int < 10); + + [ConditionalFact] + public virtual Task LessThanOrEqual_with_int() + => this.TestFilterAsync(r => r.Int <= 10); + + #endregion Comparison + + #region Logical operators + + [ConditionalFact] + public virtual Task And() + => this.TestFilterAsync(r => r.Int == 8 && r.String == "foo"); + + [ConditionalFact] + public virtual Task Or() + => this.TestFilterAsync(r => r.Int == 8 || r.String == "foo"); + + [ConditionalFact] + public virtual Task And_within_And() + => this.TestFilterAsync(r => (r.Int == 8 && r.String == "foo") && r.Int2 == 80); + + [ConditionalFact] + public virtual Task And_within_Or() + => this.TestFilterAsync(r => (r.Int == 8 && r.String == "foo") || r.Int2 == 100); + + [ConditionalFact] + public virtual Task Or_within_And() + => this.TestFilterAsync(r => (r.Int == 8 || r.Int == 9) && r.String == "foo"); + + [ConditionalFact] + public virtual Task Not_over_Equal() + // ReSharper disable once NegativeEqualityExpression + => this.TestFilterAsync(r => !(r.Int == 8)); + + [ConditionalFact] + public virtual Task Not_over_NotEqual() + // ReSharper disable once NegativeEqualityExpression + => this.TestFilterAsync(r => !(r.Int != 8)); + + [ConditionalFact] + public virtual Task Not_over_And() + => this.TestFilterAsync(r => !(r.Int == 8 && r.String == "foo")); + + [ConditionalFact] + public virtual Task Not_over_Or() + => this.TestFilterAsync(r => !(r.Int == 8 || r.String == "foo")); + + [ConditionalFact] + public virtual Task Not_over_bool() + => this.TestFilterAsync(r => !r.Bool); + + #endregion Logical operators + + #region Contains + + [ConditionalFact] + public virtual Task Contains_over_field_string_array() + => this.TestFilterAsync(r => r.StringArray.Contains("x")); + + [ConditionalFact] + public virtual Task Contains_over_field_string_List() + => this.TestFilterAsync(r => r.StringList.Contains("x")); + + [ConditionalFact] + public virtual Task Contains_over_inline_int_array() + => this.TestFilterAsync(r => new[] { 8, 10 }.Contains(r.Int)); + + [ConditionalFact] + public virtual Task Contains_over_inline_string_array() + => this.TestFilterAsync(r => new[] { "foo", "baz", "unknown" }.Contains(r.String)); + + [ConditionalFact] + public virtual Task Contains_over_inline_string_array_with_weird_chars() + => this.TestFilterAsync(r => new[] { "foo", "baz", "un , ' \"" }.Contains(r.String)); + + [ConditionalFact] + public virtual Task Contains_over_captured_string_array() + { + var array = new[] { "foo", "baz", "unknown" }; + + return this.TestFilterAsync(r => array.Contains(r.String)); + } + + #endregion Contains + + [ConditionalFact] + public virtual Task Captured_variable() + { + // ReSharper disable once ConvertToConstant.Local + var i = 8; + + return this.TestFilterAsync(r => r.Int == i); + } + + #region Legacy filter support + + [ConditionalFact] + [Obsolete("Legacy filter support")] + public virtual Task Legacy_equality() + => this.TestLegacyFilterAsync( + new VectorSearchFilter().EqualTo("Int", 8), + r => r.Int == 8); + + [ConditionalFact] + [Obsolete("Legacy filter support")] + public virtual Task Legacy_And() + => this.TestLegacyFilterAsync( + new VectorSearchFilter().EqualTo("Int", 8).EqualTo("String", "foo"), + r => r.Int == 8); + + [ConditionalFact] + [Obsolete("Legacy filter support")] + public virtual Task Legacy_AnyTagEqualTo_array() + => this.TestLegacyFilterAsync( + new VectorSearchFilter().AnyTagEqualTo("StringArray", "x"), + r => r.StringArray.Contains("x")); + + [ConditionalFact] + [Obsolete("Legacy filter support")] + public virtual Task Legacy_AnyTagEqualTo_List() + => this.TestLegacyFilterAsync( + new VectorSearchFilter().AnyTagEqualTo("StringList", "x"), + r => r.StringArray.Contains("x")); + + #endregion Legacy filter support + + protected virtual async Task TestFilterAsync( + Expression, bool>> filter, + bool expectZeroResults = false, + bool expectAllResults = false) + { + var expected = fixture.TestData.AsQueryable().Where(filter).OrderBy(r => r.Key).ToList(); + + if (expected.Count == 0 && !expectZeroResults) + { + Assert.Fail("The test returns zero results, and so is unreliable"); + } + + if (expected.Count == fixture.TestData.Count && !expectAllResults) + { + Assert.Fail("The test returns all results, and so is unreliable"); + } + + var results = await fixture.Collection.VectorizedSearchAsync( + new ReadOnlyMemory([1, 2, 3]), + new() + { + NewFilter = filter, + Top = fixture.TestData.Count + }); + + var actual = await results.Results.Select(r => r.Record).OrderBy(r => r.Key).ToListAsync(); + + Assert.Equal(expected, actual, (e, a) => + e.Int == a.Int && + e.String == a.String && + e.Int2 == a.Int2); + } + + [Obsolete("Legacy filter support")] + protected virtual async Task TestLegacyFilterAsync( + VectorSearchFilter legacyFilter, + Expression, bool>> expectedFilter, + bool expectZeroResults = false, + bool expectAllResults = false) + { + var expected = fixture.TestData.AsQueryable().Where(expectedFilter).OrderBy(r => r.Key).ToList(); + + if (expected.Count == 0 && !expectZeroResults) + { + Assert.Fail("The test returns zero results, and so is unreliable"); + } + + if (expected.Count == fixture.TestData.Count && !expectAllResults) + { + Assert.Fail("The test returns all results, and so is unreliable"); + } + + var results = await fixture.Collection.VectorizedSearchAsync( + new ReadOnlyMemory([1, 2, 3]), + new() + { + Filter = legacyFilter, + Top = fixture.TestData.Count + }); + + var actual = await results.Results.Select(r => r.Record).OrderBy(r => r.Key).ToListAsync(); + + Assert.Equal(expected, actual, (e, a) => + e.Int == a.Int && + e.String == a.String && + e.Int2 == a.Int2); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs new file mode 100644 index 000000000000..9af82882a7fb --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Globalization; +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace VectorDataSpecificationTests.Filter; + +public abstract class FilterFixtureBase : IAsyncLifetime + where TKey : notnull +{ + private int _nextKeyValue = 1; + private List>? _testData; + + protected virtual string StoreName => "FilterTests"; + + protected abstract TestStore TestStore { get; } + + protected virtual string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineSimilarity; + protected virtual string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Flat; + + public virtual async Task InitializeAsync() + { + await this.TestStore.ReferenceCountingStartAsync(); + + this.Collection = this.TestStore.DefaultVectorStore.GetCollection>(this.StoreName, this.GetRecordDefinition()); + + if (await this.Collection.CollectionExistsAsync()) + { + await this.Collection.DeleteCollectionAsync(); + } + + await this.Collection.CreateCollectionAsync(); + await this.SeedAsync(); + + // Some databases upsert asynchronously, meaning that our seed data may not be visible immediately to tests. + // Check and loop until it is. + for (var i = 0; i < 20; i++) + { + var results = await this.Collection.VectorizedSearchAsync( + new ReadOnlyMemory([1, 2, 3]), + new() { Top = this.TestData.Count }); + var count = await results.Results.CountAsync(); + if (count == this.TestData.Count) + { + break; + } + + await Task.Delay(TimeSpan.FromMilliseconds(100)); + } + } + + protected virtual VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(FilterRecord.Key), typeof(TKey)), + new VectorStoreRecordVectorProperty(nameof(FilterRecord.Vector), typeof(ReadOnlyMemory?)) + { + Dimensions = 3, + DistanceFunction = this.DistanceFunction, + IndexKind = this.IndexKind + }, + + new VectorStoreRecordDataProperty(nameof(FilterRecord.Int), typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.String), typeof(string)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.Bool), typeof(bool)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.Int2), typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.StringArray), typeof(string[])) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.StringList), typeof(List)) { IsFilterable = true } + ] + }; + + public virtual IVectorStoreRecordCollection> Collection { get; private set; } = null!; + + public List> TestData => this._testData ??= this.BuildTestData(); + + protected virtual List> BuildTestData() + { + // All records have the same vector - this fixture is about testing criteria filtering only + var vector = new ReadOnlyMemory([1, 2, 3]); + + return + [ + new() + { + Key = this.GenerateNextKey(), + Int = 8, + String = "foo", + Bool = true, + Int2 = 80, + StringArray = ["x", "y"], + StringList = ["x", "y"], + Vector = vector + }, + new() + { + Key = this.GenerateNextKey(), + Int = 9, + String = "bar", + Bool = false, + Int2 = 90, + StringArray = ["a", "b"], + StringList = ["a", "b"], + Vector = vector + }, + new() + { + Key = this.GenerateNextKey(), + Int = 9, + String = "foo", + Bool = true, + Int2 = 9, + StringArray = ["x"], + StringList = ["x"], + Vector = vector + }, + new() + { + Key = this.GenerateNextKey(), + Int = 10, + String = null, + Bool = false, + Int2 = 100, + StringArray = ["x", "y", "z"], + StringList = ["x", "y", "z"], + Vector = vector + }, + new() + { + Key = this.GenerateNextKey(), + Int = 11, + Bool = true, + String = """with some special"characters'and\stuff""", + Int2 = 101, + StringArray = ["y", "z"], + StringList = ["y", "z"], + Vector = vector + } + ]; + } + + protected virtual async Task SeedAsync() + { + // TODO: UpsertBatchAsync returns IAsyncEnumerable (to support server-generated keys?), but this makes it quite hard to use: + await foreach (var _ in this.Collection.UpsertBatchAsync(this.TestData)) + { + } + } + + protected virtual TKey GenerateNextKey() + => typeof(TKey) switch + { + _ when typeof(TKey) == typeof(int) => (TKey)(object)this._nextKeyValue++, + _ when typeof(TKey) == typeof(long) => (TKey)(object)(long)this._nextKeyValue++, + _ when typeof(TKey) == typeof(ulong) => (TKey)(object)(ulong)this._nextKeyValue++, + _ when typeof(TKey) == typeof(string) => (TKey)(object)(this._nextKeyValue++).ToString(CultureInfo.InvariantCulture), + _ when typeof(TKey) == typeof(Guid) => (TKey)(object)new Guid($"00000000-0000-0000-0000-00{this._nextKeyValue++:0000000000}"), + + _ => throw new NotSupportedException($"Unsupported key of type '{typeof(TKey).Name}', override {nameof(this.GenerateNextKey)}") + }; + + public virtual Task DisposeAsync() + => this.TestStore.ReferenceCountingStopAsync(); +} + +#pragma warning disable CS1819 // Properties should not return arrays +#pragma warning disable CA1819 // Properties should not return arrays +public class FilterRecord +{ + public TKey Key { get; init; } = default!; + public ReadOnlyMemory? Vector { get; set; } + + public int Int { get; set; } + public string? String { get; set; } + public bool Bool { get; set; } + public int Int2 { get; set; } + public string[] StringArray { get; set; } = null!; + public List StringList { get; set; } = null!; +} +#pragma warning restore CA1819 // Properties should not return arrays +#pragma warning restore CS1819 diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs new file mode 100644 index 000000000000..de7c0d252062 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; + +namespace VectorDataSpecificationTests.Support; + +#pragma warning disable CA1001 // Type owns disposable fields but is not disposable + +public abstract class TestStore +{ + private readonly SemaphoreSlim _lock = new(1, 1); + private int _referenceCount; + + protected abstract Task StartAsync(); + + protected virtual Task StopAsync() + => Task.CompletedTask; + + public virtual async Task ReferenceCountingStartAsync() + { + await this._lock.WaitAsync(); + try + { + if (this._referenceCount++ == 0) + { + await this.StartAsync(); + } + } + finally + { + this._lock.Release(); + } + } + + public virtual async Task ReferenceCountingStopAsync() + { + await this._lock.WaitAsync(); + try + { + if (--this._referenceCount == 0) + { + await this.StopAsync(); + } + } + finally + { + this._lock.Release(); + } + } + + public abstract IVectorStore DefaultVectorStore { get; } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj new file mode 100644 index 000000000000..9949700b9fce --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj @@ -0,0 +1,21 @@ + + + + net8.0 + enable + enable + true + false + VectorDataSpecificationTests + + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactAttribute.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactAttribute.cs new file mode 100644 index 000000000000..d4d93c8b5035 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactAttribute.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +[AttributeUsage(AttributeTargets.Method)] +[XunitTestCaseDiscoverer("VectorDataSpecificationTests.Xunit.ConditionalFactDiscoverer", "VectorDataIntegrationTests")] +public sealed class ConditionalFactAttribute : FactAttribute; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactDiscoverer.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactDiscoverer.cs new file mode 100644 index 000000000000..1fbeafd3dd1c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactDiscoverer.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +/// +/// Used dynamically from . +/// Make sure to update that class if you move this type. +/// +public class ConditionalFactDiscoverer(IMessageSink messageSink) : FactDiscoverer(messageSink) +{ + protected override IXunitTestCase CreateTestCase( + ITestFrameworkDiscoveryOptions discoveryOptions, + ITestMethod testMethod, + IAttributeInfo factAttribute) + => new ConditionalFactTestCase( + this.DiagnosticMessageSink, + discoveryOptions.MethodDisplayOrDefault(), + discoveryOptions.MethodDisplayOptionsOrDefault(), + testMethod); +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactTestCase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactTestCase.cs new file mode 100644 index 000000000000..3dea216a1084 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactTestCase.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +public sealed class ConditionalFactTestCase : XunitTestCase +{ + [Obsolete("Called by the de-serializer; should only be called by deriving classes for de-serialization purposes")] + public ConditionalFactTestCase() + { + } + + public ConditionalFactTestCase( + IMessageSink diagnosticMessageSink, + TestMethodDisplay defaultMethodDisplay, + TestMethodDisplayOptions defaultMethodDisplayOptions, + ITestMethod testMethod, + object[]? testMethodArguments = null) + : base(diagnosticMessageSink, defaultMethodDisplay, defaultMethodDisplayOptions, testMethod, testMethodArguments) + { + } + + public override async Task RunAsync( + IMessageSink diagnosticMessageSink, + IMessageBus messageBus, + object[] constructorArguments, + ExceptionAggregator aggregator, + CancellationTokenSource cancellationTokenSource) + => await XunitTestCaseExtensions.TrySkipAsync(this, messageBus) + ? new RunSummary { Total = 1, Skipped = 1 } + : await base.RunAsync( + diagnosticMessageSink, + messageBus, + constructorArguments, + aggregator, + cancellationTokenSource); +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs new file mode 100644 index 000000000000..529f42ef1310 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +[AttributeUsage(AttributeTargets.Method)] +[XunitTestCaseDiscoverer("VectorDataSpecificationTests.Xunit.VectorStoreFactDiscoverer", "VectorDataIntegrationTests")] +public sealed class ConditionalTheoryAttribute : TheoryAttribute; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ITestCondition.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ITestCondition.cs new file mode 100644 index 000000000000..deca7716fb1a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ITestCondition.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace VectorDataSpecificationTests.Xunit; + +public interface ITestCondition +{ + ValueTask IsMetAsync(); + + string SkipReason { get; } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/XunitTestCaseExtensions.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/XunitTestCaseExtensions.cs new file mode 100644 index 000000000000..2cf37205ead4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/XunitTestCaseExtensions.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Concurrent; +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +public static class XunitTestCaseExtensions +{ + private static readonly ConcurrentDictionary> s_typeAttributes = new(); + private static readonly ConcurrentDictionary> s_assemblyAttributes = new(); + + public static async ValueTask TrySkipAsync(XunitTestCase testCase, IMessageBus messageBus) + { + var method = testCase.Method; + var type = testCase.TestMethod.TestClass.Class; + var assembly = type.Assembly; + + var skipReasons = new List(); + var attributes = + s_assemblyAttributes.GetOrAdd( + assembly.Name, + a => assembly.GetCustomAttributes(typeof(ITestCondition)).ToList()) + .Concat( + s_typeAttributes.GetOrAdd( + type.Name, + t => type.GetCustomAttributes(typeof(ITestCondition)).ToList())) + .Concat(method.GetCustomAttributes(typeof(ITestCondition))) + .OfType() + .Select(attributeInfo => (ITestCondition)attributeInfo.Attribute); + + foreach (var attribute in attributes) + { + if (!await attribute.IsMetAsync()) + { + skipReasons.Add(attribute.SkipReason); + } + } + + if (skipReasons.Count > 0) + { + messageBus.QueueMessage( + new TestSkipped(new XunitTest(testCase, testCase.DisplayName), string.Join(Environment.NewLine, skipReasons))); + + return true; + } + + return false; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs new file mode 100644 index 000000000000..941053dd98a3 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Filter; +using Xunit; +using Xunit.Sdk; + +namespace WeaviateIntegrationTests.Filter; + +public class WeaviateBasicFilterTests(WeaviateFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + #region Filter by null + + // Null-state indexing needs to be set up, but that's not supported yet (#10358). + // We could interact with Weaviate directly (not via the abstraction) to do this. + + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + public override Task NotEqual_with_null_referenceType() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_referenceType()); + + #endregion + + #region Not + + // Weaviate currently doesn't support NOT (https://github.com/weaviate/weaviate/issues/3683) + public override Task Not_over_And() + => Assert.ThrowsAsync(() => base.Not_over_And()); + + public override Task Not_over_Or() + => Assert.ThrowsAsync(() => base.Not_over_Or()); + + #endregion + + #region Unsupported Contains scenarios + + public override Task Contains_over_captured_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_captured_string_array()); + + public override Task Contains_over_inline_int_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); + + public override Task Contains_over_inline_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); + + public override Task Contains_over_inline_string_array_with_weird_chars() + => Assert.ThrowsAsync(() => base.Contains_over_inline_string_array_with_weird_chars()); + + #endregion + + // In Weaviate, string equality on multi-word textual properties depends on tokenization + // (https://weaviate.io/developers/weaviate/api/graphql/filters#multi-word-queries-in-equal-filters) + public override Task Equal_with_string_is_not_Contains() + => Assert.ThrowsAsync(() => base.Equal_with_string_is_not_Contains()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateFilterFixture.cs new file mode 100644 index 000000000000..f00b884780c2 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateFilterFixture.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using WeaviateIntegrationTests.Support; + +namespace WeaviateIntegrationTests.Filter; + +public class WeaviateFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => WeaviateTestStore.Instance; + + protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs new file mode 100644 index 000000000000..1745a902a348 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet.Models; +using DotNet.Testcontainers.Builders; +using DotNet.Testcontainers.Configurations; + +namespace WeaviateIntegrationTests.Support.TestContainer; + +public sealed class WeaviateBuilder : ContainerBuilder +{ + public const string WeaviateImage = "semitechnologies/weaviate:1.26.4"; + public const ushort WeaviateHttpPort = 8080; + public const ushort WeaviateGrpcPort = 50051; + + public WeaviateBuilder() : this(new WeaviateConfiguration()) => this.DockerResourceConfiguration = this.Init().DockerResourceConfiguration; + + private WeaviateBuilder(WeaviateConfiguration dockerResourceConfiguration) : base(dockerResourceConfiguration) + => this.DockerResourceConfiguration = dockerResourceConfiguration; + + public override WeaviateContainer Build() + { + this.Validate(); + return new WeaviateContainer(this.DockerResourceConfiguration); + } + + protected override WeaviateBuilder Init() + => base.Init() + .WithImage(WeaviateImage) + .WithPortBinding(WeaviateHttpPort, true) + .WithPortBinding(WeaviateGrpcPort, true) + .WithEnvironment("AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED", "true") + .WithEnvironment("PERSISTENCE_DATA_PATH", "/var/lib/weaviate") + .WithWaitStrategy(Wait.ForUnixContainer() + .UntilPortIsAvailable(WeaviateHttpPort) + .UntilPortIsAvailable(WeaviateGrpcPort) + .UntilHttpRequestIsSucceeded(r => r.ForPath("/v1/.well-known/ready").ForPort(WeaviateHttpPort))); + + protected override WeaviateBuilder Clone(IResourceConfiguration resourceConfiguration) + => this.Merge(this.DockerResourceConfiguration, new WeaviateConfiguration(resourceConfiguration)); + + protected override WeaviateBuilder Merge(WeaviateConfiguration oldValue, WeaviateConfiguration newValue) + => new(new WeaviateConfiguration(oldValue, newValue)); + + protected override WeaviateConfiguration DockerResourceConfiguration { get; } + + protected override WeaviateBuilder Clone(IContainerConfiguration resourceConfiguration) + => this.Merge(this.DockerResourceConfiguration, new WeaviateConfiguration(resourceConfiguration)); +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateConfiguration.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateConfiguration.cs new file mode 100644 index 000000000000..56ea40b242e7 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateConfiguration.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet.Models; +using DotNet.Testcontainers.Configurations; + +namespace WeaviateIntegrationTests.Support.TestContainer; + +public sealed class WeaviateConfiguration : ContainerConfiguration +{ + /// + /// Initializes a new instance of the class. + /// + public WeaviateConfiguration() + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public WeaviateConfiguration(IResourceConfiguration resourceConfiguration) + : base(resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public WeaviateConfiguration(IContainerConfiguration resourceConfiguration) + : base(resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public WeaviateConfiguration(WeaviateConfiguration resourceConfiguration) + : this(new WeaviateConfiguration(), resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The old Docker resource configuration. + /// The new Docker resource configuration. + public WeaviateConfiguration(WeaviateConfiguration oldValue, WeaviateConfiguration newValue) + : base(oldValue, newValue) + { + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateContainer.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateContainer.cs new file mode 100644 index 000000000000..c209d662a4d4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateContainer.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. + +using DotNet.Testcontainers.Containers; + +namespace WeaviateIntegrationTests.Support.TestContainer; + +public class WeaviateContainer(WeaviateConfiguration configuration) : DockerContainer(configuration); diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs new file mode 100644 index 000000000000..21e079421674 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Connectors.Weaviate; +using VectorDataSpecificationTests.Support; +using WeaviateIntegrationTests.Support.TestContainer; + +namespace WeaviateIntegrationTests.Support; + +public sealed class WeaviateTestStore : TestStore +{ + public static WeaviateTestStore Instance { get; } = new(); + + private readonly WeaviateContainer _container = new WeaviateBuilder().Build(); + public HttpClient? _httpClient { get; private set; } + private WeaviateVectorStore? _defaultVectorStore; + + public HttpClient Client => this._httpClient ?? throw new InvalidOperationException("Not initialized"); + + public override WeaviateVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public WeaviateVectorStore GetVectorStore(WeaviateVectorStoreOptions options) + => new(this.Client, options); + + private WeaviateTestStore() + { + } + + protected override async Task StartAsync() + { + await this._container.StartAsync(); + this._httpClient = new HttpClient { BaseAddress = new Uri($"http://localhost:{this._container.GetMappedPublicPort(WeaviateBuilder.WeaviateHttpPort)}/v1/") }; + this._defaultVectorStore = new(this._httpClient); + } + + protected override Task StopAsync() + => this._container.StopAsync(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj new file mode 100644 index 000000000000..d99bde7558c3 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj @@ -0,0 +1,27 @@ + + + + net8.0 + enable + enable + true + false + WeaviateIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + From 283d581154f09dec3c825ddea968b1b49f3c4b6d Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Sun, 9 Feb 2025 15:39:19 -0800 Subject: [PATCH 2/8] Generate suppression files --- .../CompatibilitySuppressions.xml | 107 +++++++++++++++++- .../CompatibilitySuppressions.xml | 16 ++- 2 files changed, 121 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml b/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml index 0860b81e7585..cd9bfbaa3ca7 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml +++ b/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml @@ -1,5 +1,5 @@  - + CP0001 @@ -15,6 +15,13 @@ lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0001 + T:Microsoft.Extensions.VectorData.VectorSearchOptions + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0001 T:Microsoft.Extensions.VectorData.DeleteRecordOptions @@ -29,6 +36,13 @@ lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0001 + T:Microsoft.Extensions.VectorData.VectorSearchOptions + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0001 T:Microsoft.Extensions.VectorData.DeleteRecordOptions @@ -43,6 +57,27 @@ lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0001 + T:Microsoft.Extensions.VectorData.VectorSearchOptions + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0002 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) @@ -71,6 +106,20 @@ lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0002 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) @@ -99,6 +148,20 @@ lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0002 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) @@ -127,6 +190,20 @@ lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0006 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,System.Threading.CancellationToken) @@ -155,6 +232,20 @@ lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0006 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,System.Threading.CancellationToken) @@ -183,6 +274,20 @@ lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0006 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,System.Threading.CancellationToken) diff --git a/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml b/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml index de2e33319a56..6c9084abb2ce 100644 --- a/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml +++ b/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml @@ -1,5 +1,5 @@  - + CP0002 @@ -29,6 +29,13 @@ lib/net8.0/Microsoft.SemanticKernel.Core.dll true + + CP0002 + M:Microsoft.SemanticKernel.Data.VolatileVectorStoreRecordCollection`2.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/net8.0/Microsoft.SemanticKernel.Core.dll + lib/net8.0/Microsoft.SemanticKernel.Core.dll + true + CP0002 M:Microsoft.SemanticKernel.Data.VolatileVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) @@ -57,4 +64,11 @@ lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll true + + CP0002 + M:Microsoft.SemanticKernel.Data.VolatileVectorStoreRecordCollection`2.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + true + \ No newline at end of file From 1c893c6182b36465ebe414aebc846b296dcc1cdf Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 10 Feb 2025 11:18:45 +0100 Subject: [PATCH 3/8] Implement Cosmos Mongo filtering --- .../Filter/CosmosMongoBasicFilterTests.cs | 59 +++++++++++++++++++ .../Filter/CosmosMongoFiltersNotSupported.cs | 24 -------- 2 files changed, 59 insertions(+), 24 deletions(-) create mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs delete mode 100644 dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFiltersNotSupported.cs diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs new file mode 100644 index 000000000000..deed1197b728 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace MongoDBIntegrationTests.Filter; + +public class CosmosMongoBasicFilterTests(CosmosMongoFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + // Specialized MongoDB syntax for NOT over Contains ($nin) + [ConditionalFact] + public virtual Task Not_over_Contains() + => this.TestFilterAsync(r => !new[] { 8, 10 }.Contains(r.Int)); + + #region Null checking + + // MongoDB currently doesn't support null checking ({ "Foo" : null }) in vector search pre-filters + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_referenceType() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_referenceType()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + #endregion + + #region Not + + // MongoDB currently doesn't support NOT in vector search pre-filters + // (https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter) + public override Task Not_over_And() + => Assert.ThrowsAsync(() => base.Not_over_And()); + + public override Task Not_over_Or() + => Assert.ThrowsAsync(() => base.Not_over_Or()); + + #endregion + + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + // AnyTagEqualTo not (currently) supported on SQLite + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_array() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_array()); + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_List() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_List()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFiltersNotSupported.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFiltersNotSupported.cs deleted file mode 100644 index 47b6c4720cf4..000000000000 --- a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFiltersNotSupported.cs +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using MongoDB.Driver; -using VectorDataSpecificationTests.Xunit; -using Xunit; - -namespace MongoDBIntegrationTests.Filter; - -public class CosmosMongoFiltersNotSupported(CosmosMongoFilterFixture fixture) : IClassFixture -{ - [ConditionalFact] - public virtual async Task Equal_with_int() - { - // Cosmos MongoDB vCore doesn't yet support filters with vector search: - // Command aggregate failed: $filter is not supported for vector search yet.. - await Assert.ThrowsAsync(() => fixture.Collection.VectorizedSearchAsync( - new ReadOnlyMemory([1, 2, 3]), - new() - { - NewFilter = r => r.Int == 8, - Top = fixture.TestData.Count - })); - } -} From 64fd77b62351cec485b272ae04f0157d657245ad Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Tue, 11 Feb 2025 09:33:02 +0100 Subject: [PATCH 4/8] Address review comments --- .../MappingVectorStoreRecordCollection.cs | 268 ++++++------ .../QdrantFactory.cs | 35 +- .../Step2_Vector_Search.cs | 4 +- .../Step4_NonStringKey_VectorStore.cs | 394 +++++++++--------- ...LVectorStoreCollectionQueryBuilderTests.cs | 13 +- .../Connectors.AzureOpenAI.csproj | 1 + .../AzureAISearchFilterTranslator.cs | 51 ++- ...zureAISearchVectorStoreRecordCollection.cs | 1 - .../AzureCosmosDBNoSqlFilterTranslator.cs | 6 +- .../PostgresFilterTranslator.cs | 4 +- .../RedisFilterTranslator.cs | 3 +- .../SqliteFilterTranslator.cs | 5 +- ...liteVectorStoreCollectionCommandBuilder.cs | 4 +- .../WeaviateFilterTranslator.cs | 3 +- .../WeaviateVectorStoreRecordCollection.cs | 1 - .../Connectors.OpenAI.csproj | 6 + ...ectorStoreCollectionCommandBuilderTests.cs | 2 + .../AnyTagEqualToFilterClause.cs | 3 - .../FilterClauses/EqualToFilterClause.cs | 3 - .../FilterClauses/FilterClause.cs | 3 - .../Functions.OpenApi.csproj | 4 + ...MongoDBVectorStoreRecordCollectionTests.cs | 2 +- .../Data/TextSearch/TextSearchFilter.cs | 2 - .../SemanticKernel.Abstractions.csproj | 3 + .../Data/TextSearch/VectorStoreTextSearch.cs | 4 +- .../SemanticKernel.Core.csproj | 3 +- .../Filter/AzureAISearchFilterFixture.cs | 3 - .../Filter/FilterFixtureBase.cs | 6 +- 28 files changed, 435 insertions(+), 402 deletions(-) diff --git a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs index 5d9dca826e28..1951f3a6dbee 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs @@ -4,135 +4,139 @@ // TODO: The user provides an expression tree accepting a TPublicRecord, but we require an expression tree accepting a TInternalRecord. // TODO: This is something that the user must provide, and is quite advanced. -// using System.Runtime.CompilerServices; -// using Microsoft.Extensions.VectorData; -// -// namespace Memory.VectorStoreLangchainInterop; -// -// /// -// /// Decorator class that allows conversion of keys and records between public and internal representations. -// /// -// /// -// /// This class is useful if a vector store implementation exposes keys or records in a way that is not -// /// suitable for the user of the vector store. E.g. let's say that the vector store supports Guid keys -// /// but you want to work with string keys that contain Guids. This class allows you to map between the -// /// public string Guids and the internal Guids. -// /// -// /// The type of the key that the user of this class will use. -// /// The type of the key that the internal collection exposes. -// /// The type of the record that the user of this class will use. -// /// The type of the record that the internal collection exposes. -// internal sealed class MappingVectorStoreRecordCollection : IVectorStoreRecordCollection -// where TPublicKey : notnull -// where TInternalKey : notnull -// { -// private readonly IVectorStoreRecordCollection _collection; -// private readonly Func _publicToInternalKeyMapper; -// private readonly Func _internalToPublicKeyMapper; -// private readonly Func _publicToInternalRecordMapper; -// private readonly Func _internalToPublicRecordMapper; -// -// public MappingVectorStoreRecordCollection( -// IVectorStoreRecordCollection collection, -// Func publicToInternalKeyMapper, -// Func internalToPublicKeyMapper, -// Func publicToInternalRecordMapper, -// Func internalToPublicRecordMapper) -// { -// this._collection = collection; -// this._publicToInternalKeyMapper = publicToInternalKeyMapper; -// this._internalToPublicKeyMapper = internalToPublicKeyMapper; -// this._publicToInternalRecordMapper = publicToInternalRecordMapper; -// this._internalToPublicRecordMapper = internalToPublicRecordMapper; -// } -// -// /// -// public string CollectionName => this._collection.CollectionName; -// -// /// -// public Task CollectionExistsAsync(CancellationToken cancellationToken = default) -// { -// return this._collection.CollectionExistsAsync(cancellationToken); -// } -// -// /// -// public Task CreateCollectionAsync(CancellationToken cancellationToken = default) -// { -// return this._collection.CreateCollectionAsync(cancellationToken); -// } -// -// /// -// public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) -// { -// return this._collection.CreateCollectionIfNotExistsAsync(cancellationToken); -// } -// -// /// -// public Task DeleteAsync(TPublicKey key, CancellationToken cancellationToken = default) -// { -// return this._collection.DeleteAsync(this._publicToInternalKeyMapper(key), cancellationToken); -// } -// -// /// -// public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) -// { -// return this._collection.DeleteBatchAsync(keys.Select(this._publicToInternalKeyMapper), cancellationToken); -// } -// -// /// -// public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) -// { -// return this._collection.DeleteCollectionAsync(cancellationToken); -// } -// -// /// -// public async Task GetAsync(TPublicKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) -// { -// var internalRecord = await this._collection.GetAsync(this._publicToInternalKeyMapper(key), options, cancellationToken).ConfigureAwait(false); -// if (internalRecord == null) -// { -// return default; -// } -// -// return this._internalToPublicRecordMapper(internalRecord); -// } -// -// /// -// public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) -// { -// var internalRecords = this._collection.GetBatchAsync(keys.Select(this._publicToInternalKeyMapper), options, cancellationToken); -// return internalRecords.Select(this._internalToPublicRecordMapper); -// } -// -// /// -// public async Task UpsertAsync(TPublicRecord record, CancellationToken cancellationToken = default) -// { -// var internalRecord = this._publicToInternalRecordMapper(record); -// var internalKey = await this._collection.UpsertAsync(internalRecord, cancellationToken).ConfigureAwait(false); -// return this._internalToPublicKeyMapper(internalKey); -// } -// -// /// -// public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) -// { -// var internalRecords = records.Select(this._publicToInternalRecordMapper); -// var internalKeys = this._collection.UpsertBatchAsync(internalRecords, cancellationToken); -// await foreach (var internalKey in internalKeys.ConfigureAwait(false)) -// { -// yield return this._internalToPublicKeyMapper(internalKey); -// } -// } -// -// /// -// public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) -// { -// var searchResults = await this._collection.VectorizedSearchAsync(vector, options, cancellationToken).ConfigureAwait(false); -// var publicResultRecords = searchResults.Results.Select(result => new VectorSearchResult(this._internalToPublicRecordMapper(result.Record), result.Score)); -// -// return new VectorSearchResults(publicResultRecords) -// { -// TotalCount = searchResults.TotalCount, -// Metadata = searchResults.Metadata, -// }; -// } -// } +#if DISABLED + +using System.Runtime.CompilerServices; +using Microsoft.Extensions.VectorData; + +namespace Memory.VectorStoreLangchainInterop; + +/// +/// Decorator class that allows conversion of keys and records between public and internal representations. +/// +/// +/// This class is useful if a vector store implementation exposes keys or records in a way that is not +/// suitable for the user of the vector store. E.g. let's say that the vector store supports Guid keys +/// but you want to work with string keys that contain Guids. This class allows you to map between the +/// public string Guids and the internal Guids. +/// +/// The type of the key that the user of this class will use. +/// The type of the key that the internal collection exposes. +/// The type of the record that the user of this class will use. +/// The type of the record that the internal collection exposes. +internal sealed class MappingVectorStoreRecordCollection : IVectorStoreRecordCollection + where TPublicKey : notnull + where TInternalKey : notnull +{ + private readonly IVectorStoreRecordCollection _collection; + private readonly Func _publicToInternalKeyMapper; + private readonly Func _internalToPublicKeyMapper; + private readonly Func _publicToInternalRecordMapper; + private readonly Func _internalToPublicRecordMapper; + + public MappingVectorStoreRecordCollection( + IVectorStoreRecordCollection collection, + Func publicToInternalKeyMapper, + Func internalToPublicKeyMapper, + Func publicToInternalRecordMapper, + Func internalToPublicRecordMapper) + { + this._collection = collection; + this._publicToInternalKeyMapper = publicToInternalKeyMapper; + this._internalToPublicKeyMapper = internalToPublicKeyMapper; + this._publicToInternalRecordMapper = publicToInternalRecordMapper; + this._internalToPublicRecordMapper = internalToPublicRecordMapper; + } + + /// + public string CollectionName => this._collection.CollectionName; + + /// + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + return this._collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + return this._collection.CreateCollectionAsync(cancellationToken); + } + + /// + public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + return this._collection.CreateCollectionIfNotExistsAsync(cancellationToken); + } + + /// + public Task DeleteAsync(TPublicKey key, CancellationToken cancellationToken = default) + { + return this._collection.DeleteAsync(this._publicToInternalKeyMapper(key), cancellationToken); + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + { + return this._collection.DeleteBatchAsync(keys.Select(this._publicToInternalKeyMapper), cancellationToken); + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + return this._collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public async Task GetAsync(TPublicKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + var internalRecord = await this._collection.GetAsync(this._publicToInternalKeyMapper(key), options, cancellationToken).ConfigureAwait(false); + if (internalRecord == null) + { + return default; + } + + return this._internalToPublicRecordMapper(internalRecord); + } + + /// + public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + var internalRecords = this._collection.GetBatchAsync(keys.Select(this._publicToInternalKeyMapper), options, cancellationToken); + return internalRecords.Select(this._internalToPublicRecordMapper); + } + + /// + public async Task UpsertAsync(TPublicRecord record, CancellationToken cancellationToken = default) + { + var internalRecord = this._publicToInternalRecordMapper(record); + var internalKey = await this._collection.UpsertAsync(internalRecord, cancellationToken).ConfigureAwait(false); + return this._internalToPublicKeyMapper(internalKey); + } + + /// + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var internalRecords = records.Select(this._publicToInternalRecordMapper); + var internalKeys = this._collection.UpsertBatchAsync(internalRecords, cancellationToken); + await foreach (var internalKey in internalKeys.ConfigureAwait(false)) + { + yield return this._internalToPublicKeyMapper(internalKey); + } + } + + /// + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + { + var searchResults = await this._collection.VectorizedSearchAsync(vector, options, cancellationToken).ConfigureAwait(false); + var publicResultRecords = searchResults.Results.Select(result => new VectorSearchResult(this._internalToPublicRecordMapper(result.Record), result.Score)); + + return new VectorSearchResults(publicResultRecords) + { + TotalCount = searchResults.TotalCount, + Metadata = searchResults.Metadata, + }; + } +} + +#endif diff --git a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs index d0f63727b471..f34fdc72e812 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/QdrantFactory.cs @@ -76,23 +76,24 @@ public IVectorStoreRecordCollection CreateVectorStoreRecordCollec return (collection as IVectorStoreRecordCollection)!; } - // TODO: See note on MappingVectorStoreRecordCollection - // // If the user asked for a string key, we can add a decorator which converts back and forth between string and guid. - // // The string that the user provides will still need to contain a valid guid, since the Langchain created collection - // // uses guid keys. - // // Supporting string keys like this is useful since it means you can work with the collection in the same way as with - // // collections from other vector stores that support string keys. - // if (typeof(TKey) == typeof(string) && typeof(TRecord) == typeof(LangchainDocument)) - // { - // var stringKeyCollection = new MappingVectorStoreRecordCollection, LangchainDocument>( - // collection, - // p => Guid.Parse(p), - // i => i.ToString("D"), - // p => new LangchainDocument { Key = Guid.Parse(p.Key), Content = p.Content, Source = p.Source, Embedding = p.Embedding }, - // i => new LangchainDocument { Key = i.Key.ToString("D"), Content = i.Content, Source = i.Source, Embedding = i.Embedding }); - // - // return (stringKeyCollection as IVectorStoreRecordCollection)!; - // } +#if DISABLED_FOR_NOW // TODO: See note on MappingVectorStoreRecordCollection + // If the user asked for a string key, we can add a decorator which converts back and forth between string and guid. + // The string that the user provides will still need to contain a valid guid, since the Langchain created collection + // uses guid keys. + // Supporting string keys like this is useful since it means you can work with the collection in the same way as with + // collections from other vector stores that support string keys. + if (typeof(TKey) == typeof(string) && typeof(TRecord) == typeof(LangchainDocument)) + { + var stringKeyCollection = new MappingVectorStoreRecordCollection, LangchainDocument>( + collection, + p => Guid.Parse(p), + i => i.ToString("D"), + p => new LangchainDocument { Key = Guid.Parse(p.Key), Content = p.Content, Source = p.Source, Embedding = p.Embedding }, + i => new LangchainDocument { Key = i.Key.ToString("D"), Content = i.Content, Source = i.Source, Embedding = i.Embedding }); + + return (stringKeyCollection as IVectorStoreRecordCollection)!; + } +#endif throw new NotSupportedException("This VectorStore is only usable with Guid keys and LangchainDocument record types or string keys and LangchainDocument record types"); } diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs b/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs index 7cf1363e3351..9b7e889b25dd 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs +++ b/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs @@ -53,7 +53,6 @@ internal static async Task> SearchVectorStoreAsync( return searchResultItems.First(); } -#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Do a more complex vector search with pre-filtering. /// @@ -72,7 +71,7 @@ public async Task SearchAnInMemoryVectorStoreWithFilteringAsync() new() { Top = 1, - Filter = new VectorSearchFilter().EqualTo(nameof(Glossary.Category), "AI") + NewFilter = g => g.Category == "AI" }); var searchResultItems = await searchResult.Results.ToListAsync(); @@ -80,7 +79,6 @@ public async Task SearchAnInMemoryVectorStoreWithFilteringAsync() Console.WriteLine(searchResultItems.First().Record.Definition); Console.WriteLine(searchResultItems.First().Score); } -#pragma warning restore CS0618 // VectorSearchFilter is obsolete private async Task> GetVectorStoreCollectionWithDataAsync() { diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs b/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs index 9ca726f1fb97..35ca4822a824 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs +++ b/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs @@ -1,198 +1,200 @@ // Copyright (c) Microsoft. All rights reserved. -// TODO: See note in MappingVectorStoreRecordCollection - -// using System.Runtime.CompilerServices; -// using Microsoft.Extensions.VectorData; -// using Microsoft.SemanticKernel.Connectors.Qdrant; -// using Qdrant.Client; -// -// namespace GettingStartedWithVectorStores; -// -// -// /// -// /// Example that shows that you can switch between different vector stores with the same code, in this case -// /// with a vector store that doesn't use string keys. -// /// This sample demonstrates one possible approach, however it is also possible to use generics -// /// in the common code to achieve code reuse. -// /// -// public class Step4_NonStringKey_VectorStore(ITestOutputHelper output, VectorStoresFixture fixture) : BaseTest(output), IClassFixture -// { -// /// -// /// Here we are going to use the same code that we used in and -// /// but now with an . -// /// Qdrant uses Guid or ulong as the key type, but the common code works with a string key. The string keys of the records created -// /// in contain numbers though, so it's possible for us to convert them to ulong. -// /// In this example, we'll demonstrate how to do that. -// /// -// /// This example requires a Qdrant server up and running. To run a Qdrant server in a Docker container, use the following command: -// /// docker run -d --name qdrant -p 6333:6333 -p 6334:6334 qdrant/qdrant:latest -// /// -// [Fact] -// public async Task UseAQdrantVectorStoreAsync() -// { -// // Construct a Qdrant vector store collection. -// var collection = new QdrantVectorStoreRecordCollection(new QdrantClient("localhost"), "skglossary"); -// -// // Wrap the collection using a decorator that allows us to expose a version that uses string keys, but internally -// // we convert to and from ulong. -// var stringKeyCollection = new MappingVectorStoreRecordCollection( -// collection, -// p => ulong.Parse(p), -// i => i.ToString(), -// p => new UlongGlossary { Key = ulong.Parse(p.Key), Category = p.Category, Term = p.Term, Definition = p.Definition, DefinitionEmbedding = p.DefinitionEmbedding }, -// i => new Glossary { Key = i.Key.ToString("D"), Category = i.Category, Term = i.Term, Definition = i.Definition, DefinitionEmbedding = i.DefinitionEmbedding }); -// -// // Ingest data into the collection using the same code as we used in Step1 with the InMemory Vector Store. -// await Step1_Ingest_Data.IngestDataIntoVectorStoreAsync(stringKeyCollection, fixture.TextEmbeddingGenerationService); -// -// // Search the vector store using the same code as we used in Step2 with the InMemory Vector Store. -// var searchResultItem = await Step2_Vector_Search.SearchVectorStoreAsync( -// stringKeyCollection, -// "What is an Application Programming Interface?", -// fixture.TextEmbeddingGenerationService); -// -// // Write the search result with its score to the console. -// Console.WriteLine(searchResultItem.Record.Definition); -// Console.WriteLine(searchResultItem.Score); -// } -// -// /// -// /// Data model that uses a ulong as the key type instead of a string. -// /// -// private sealed class UlongGlossary -// { -// [VectorStoreRecordKey] -// public ulong Key { get; set; } -// -// [VectorStoreRecordData(IsFilterable = true)] -// public string Category { get; set; } -// -// [VectorStoreRecordData] -// public string Term { get; set; } -// -// [VectorStoreRecordData] -// public string Definition { get; set; } -// -// [VectorStoreRecordVector(Dimensions: 1536)] -// public ReadOnlyMemory DefinitionEmbedding { get; set; } -// } -// -// /// -// /// Simple decorator class that allows conversion of keys and records from one type to another. -// /// -// private sealed class MappingVectorStoreRecordCollection : IVectorStoreRecordCollection -// where TPublicKey : notnull -// where TInternalKey : notnull -// { -// private readonly IVectorStoreRecordCollection _collection; -// private readonly Func _publicToInternalKeyMapper; -// private readonly Func _internalToPublicKeyMapper; -// private readonly Func _publicToInternalRecordMapper; -// private readonly Func _internalToPublicRecordMapper; -// -// public MappingVectorStoreRecordCollection( -// IVectorStoreRecordCollection collection, -// Func publicToInternalKeyMapper, -// Func internalToPublicKeyMapper, -// Func publicToInternalRecordMapper, -// Func internalToPublicRecordMapper) -// { -// this._collection = collection; -// this._publicToInternalKeyMapper = publicToInternalKeyMapper; -// this._internalToPublicKeyMapper = internalToPublicKeyMapper; -// this._publicToInternalRecordMapper = publicToInternalRecordMapper; -// this._internalToPublicRecordMapper = internalToPublicRecordMapper; -// } -// -// /// -// public string CollectionName => this._collection.CollectionName; -// -// /// -// public Task CollectionExistsAsync(CancellationToken cancellationToken = default) -// { -// return this._collection.CollectionExistsAsync(cancellationToken); -// } -// -// /// -// public Task CreateCollectionAsync(CancellationToken cancellationToken = default) -// { -// return this._collection.CreateCollectionAsync(cancellationToken); -// } -// -// /// -// public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) -// { -// return this._collection.CreateCollectionIfNotExistsAsync(cancellationToken); -// } -// -// /// -// public Task DeleteAsync(TPublicKey key, CancellationToken cancellationToken = default) -// { -// return this._collection.DeleteAsync(this._publicToInternalKeyMapper(key), cancellationToken); -// } -// -// /// -// public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) -// { -// return this._collection.DeleteBatchAsync(keys.Select(this._publicToInternalKeyMapper), cancellationToken); -// } -// -// /// -// public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) -// { -// return this._collection.DeleteCollectionAsync(cancellationToken); -// } -// -// /// -// public async Task GetAsync(TPublicKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) -// { -// var internalRecord = await this._collection.GetAsync(this._publicToInternalKeyMapper(key), options, cancellationToken).ConfigureAwait(false); -// if (internalRecord == null) -// { -// return default; -// } -// -// return this._internalToPublicRecordMapper(internalRecord); -// } -// -// /// -// public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) -// { -// var internalRecords = this._collection.GetBatchAsync(keys.Select(this._publicToInternalKeyMapper), options, cancellationToken); -// return internalRecords.Select(this._internalToPublicRecordMapper); -// } -// -// /// -// public async Task UpsertAsync(TPublicRecord record, CancellationToken cancellationToken = default) -// { -// var internalRecord = this._publicToInternalRecordMapper(record); -// var internalKey = await this._collection.UpsertAsync(internalRecord, cancellationToken).ConfigureAwait(false); -// return this._internalToPublicKeyMapper(internalKey); -// } -// -// /// -// public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) -// { -// var internalRecords = records.Select(this._publicToInternalRecordMapper); -// var internalKeys = this._collection.UpsertBatchAsync(internalRecords, cancellationToken); -// await foreach (var internalKey in internalKeys.ConfigureAwait(false)) -// { -// yield return this._internalToPublicKeyMapper(internalKey); -// } -// } -// -// /// -// public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) -// { -// var searchResults = await this._collection.VectorizedSearchAsync(vector, options, cancellationToken).ConfigureAwait(false); -// var publicResultRecords = searchResults.Results.Select(result => new VectorSearchResult(this._internalToPublicRecordMapper(result.Record), result.Score)); -// -// return new VectorSearchResults(publicResultRecords) -// { -// TotalCount = searchResults.TotalCount, -// Metadata = searchResults.Metadata, -// }; -// } -// } -// } +#if DISABLED_FOR_NOW // TODO: See note in MappingVectorStoreRecordCollection + +using System.Runtime.CompilerServices; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Qdrant; +using Qdrant.Client; + +namespace GettingStartedWithVectorStores; + + +/// +/// Example that shows that you can switch between different vector stores with the same code, in this case +/// with a vector store that doesn't use string keys. +/// This sample demonstrates one possible approach, however it is also possible to use generics +/// in the common code to achieve code reuse. +/// +public class Step4_NonStringKey_VectorStore(ITestOutputHelper output, VectorStoresFixture fixture) : BaseTest(output), IClassFixture +{ + /// + /// Here we are going to use the same code that we used in and + /// but now with an . + /// Qdrant uses Guid or ulong as the key type, but the common code works with a string key. The string keys of the records created + /// in contain numbers though, so it's possible for us to convert them to ulong. + /// In this example, we'll demonstrate how to do that. + /// + /// This example requires a Qdrant server up and running. To run a Qdrant server in a Docker container, use the following command: + /// docker run -d --name qdrant -p 6333:6333 -p 6334:6334 qdrant/qdrant:latest + /// + [Fact] + public async Task UseAQdrantVectorStoreAsync() + { + // Construct a Qdrant vector store collection. + var collection = new QdrantVectorStoreRecordCollection(new QdrantClient("localhost"), "skglossary"); + + // Wrap the collection using a decorator that allows us to expose a version that uses string keys, but internally + // we convert to and from ulong. + var stringKeyCollection = new MappingVectorStoreRecordCollection( + collection, + p => ulong.Parse(p), + i => i.ToString(), + p => new UlongGlossary { Key = ulong.Parse(p.Key), Category = p.Category, Term = p.Term, Definition = p.Definition, DefinitionEmbedding = p.DefinitionEmbedding }, + i => new Glossary { Key = i.Key.ToString("D"), Category = i.Category, Term = i.Term, Definition = i.Definition, DefinitionEmbedding = i.DefinitionEmbedding }); + + // Ingest data into the collection using the same code as we used in Step1 with the InMemory Vector Store. + await Step1_Ingest_Data.IngestDataIntoVectorStoreAsync(stringKeyCollection, fixture.TextEmbeddingGenerationService); + + // Search the vector store using the same code as we used in Step2 with the InMemory Vector Store. + var searchResultItem = await Step2_Vector_Search.SearchVectorStoreAsync( + stringKeyCollection, + "What is an Application Programming Interface?", + fixture.TextEmbeddingGenerationService); + + // Write the search result with its score to the console. + Console.WriteLine(searchResultItem.Record.Definition); + Console.WriteLine(searchResultItem.Score); + } + + /// + /// Data model that uses a ulong as the key type instead of a string. + /// + private sealed class UlongGlossary + { + [VectorStoreRecordKey] + public ulong Key { get; set; } + + [VectorStoreRecordData(IsFilterable = true)] + public string Category { get; set; } + + [VectorStoreRecordData] + public string Term { get; set; } + + [VectorStoreRecordData] + public string Definition { get; set; } + + [VectorStoreRecordVector(Dimensions: 1536)] + public ReadOnlyMemory DefinitionEmbedding { get; set; } + } + + /// + /// Simple decorator class that allows conversion of keys and records from one type to another. + /// + private sealed class MappingVectorStoreRecordCollection : IVectorStoreRecordCollection + where TPublicKey : notnull + where TInternalKey : notnull + { + private readonly IVectorStoreRecordCollection _collection; + private readonly Func _publicToInternalKeyMapper; + private readonly Func _internalToPublicKeyMapper; + private readonly Func _publicToInternalRecordMapper; + private readonly Func _internalToPublicRecordMapper; + + public MappingVectorStoreRecordCollection( + IVectorStoreRecordCollection collection, + Func publicToInternalKeyMapper, + Func internalToPublicKeyMapper, + Func publicToInternalRecordMapper, + Func internalToPublicRecordMapper) + { + this._collection = collection; + this._publicToInternalKeyMapper = publicToInternalKeyMapper; + this._internalToPublicKeyMapper = internalToPublicKeyMapper; + this._publicToInternalRecordMapper = publicToInternalRecordMapper; + this._internalToPublicRecordMapper = internalToPublicRecordMapper; + } + + /// + public string CollectionName => this._collection.CollectionName; + + /// + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + return this._collection.CollectionExistsAsync(cancellationToken); + } + + /// + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + return this._collection.CreateCollectionAsync(cancellationToken); + } + + /// + public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + return this._collection.CreateCollectionIfNotExistsAsync(cancellationToken); + } + + /// + public Task DeleteAsync(TPublicKey key, CancellationToken cancellationToken = default) + { + return this._collection.DeleteAsync(this._publicToInternalKeyMapper(key), cancellationToken); + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + { + return this._collection.DeleteBatchAsync(keys.Select(this._publicToInternalKeyMapper), cancellationToken); + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + return this._collection.DeleteCollectionAsync(cancellationToken); + } + + /// + public async Task GetAsync(TPublicKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + var internalRecord = await this._collection.GetAsync(this._publicToInternalKeyMapper(key), options, cancellationToken).ConfigureAwait(false); + if (internalRecord == null) + { + return default; + } + + return this._internalToPublicRecordMapper(internalRecord); + } + + /// + public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + var internalRecords = this._collection.GetBatchAsync(keys.Select(this._publicToInternalKeyMapper), options, cancellationToken); + return internalRecords.Select(this._internalToPublicRecordMapper); + } + + /// + public async Task UpsertAsync(TPublicRecord record, CancellationToken cancellationToken = default) + { + var internalRecord = this._publicToInternalRecordMapper(record); + var internalKey = await this._collection.UpsertAsync(internalRecord, cancellationToken).ConfigureAwait(false); + return this._internalToPublicKeyMapper(internalKey); + } + + /// + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var internalRecords = records.Select(this._publicToInternalRecordMapper); + var internalKeys = this._collection.UpsertBatchAsync(internalRecords, cancellationToken); + await foreach (var internalKey in internalKeys.ConfigureAwait(false)) + { + yield return this._internalToPublicKeyMapper(internalKey); + } + } + + /// + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + { + var searchResults = await this._collection.VectorizedSearchAsync(vector, options, cancellationToken).ConfigureAwait(false); + var publicResultRecords = searchResults.Results.Select(result => new VectorSearchResult(this._internalToPublicRecordMapper(result.Record), result.Score)); + + return new VectorSearchResults(publicResultRecords) + { + TotalCount = searchResults.TotalCount, + Metadata = searchResults.Metadata, + }; + } + } +} + +#endif diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs index f1ab2fc75f16..4c77dc161414 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs @@ -183,10 +183,15 @@ public void BuildSearchQueryWithoutFilterDoesNotContainWhereClause() public void BuildSelectQueryByDefaultReturnsValidQueryDefinition() { // Arrange - const string ExpectedQueryText = "" + - "SELECT x.key,x.property_1,x.property_2 " + - "FROM x " + - "WHERE (x.key_property = @rk0 AND x.partition_key_property = @pk0) "; + // const string ExpectedQueryText = "" + + // "SELECT x.key,x.property_1,x.property_2 " + + // "FROM x " + + // "WHERE (x.key_property = @rk0 AND x.partition_key_property = @pk0) "; + const string ExpectedQueryText = """ + SELECT x.key,x.property_1,x.property_2 + FROM x + WHERE (x.key_property = @rk0 AND x.partition_key_property = @pk0) + """; const string KeyStoragePropertyName = "key_property"; const string PartitionKeyPropertyName = "partition_key_property"; diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI/Connectors.AzureOpenAI.csproj b/dotnet/src/Connectors/Connectors.AzureOpenAI/Connectors.AzureOpenAI.csproj index 15d88496159b..9fcbdecf530e 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI/Connectors.AzureOpenAI.csproj +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI/Connectors.AzureOpenAI.csproj @@ -35,4 +35,5 @@ + \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs index b87183cce8c1..16164c2a3eca 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs @@ -11,7 +11,7 @@ using System.Runtime.CompilerServices; using System.Text; -namespace Microsoft.SemanticKernel.Connectors.Postgres; +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; internal class AzureAISearchFilterTranslator { @@ -20,14 +20,17 @@ internal class AzureAISearchFilterTranslator private readonly StringBuilder _filter = new(); + private static readonly char[] s_searchInDefaultDelimiter = [' ', ',']; + internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) { + Debug.Assert(this._filter.Length == 0); + this._storagePropertyNames = storagePropertyNames; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - this._filter.Clear(); this.Translate(lambdaExpression.Body); return this._filter.ToString(); } @@ -226,26 +229,16 @@ void ProcessInlineEnumerable(IEnumerable elements, Expression item) throw new NotSupportedException("Contains over non-string arrays is not supported"); } - // The default delimiter for search.in() is comma or space. - // If any element contains a comma or space, we switch to using pipe as the delimiter. - // If any contains a pipe, we throw (for now). - var delimiter = ", "; - if (elements.Cast().Any(s => s.Contains(' ') || s.Contains(','))) - { - if (elements.Cast().Any(s => s.Contains('|'))) - { - throw new NotSupportedException(""); - } - - delimiter = "|"; - } - this._filter.Append("search.in("); this.Translate(item); this._filter.Append(", '"); + string delimiter = ", "; + var startingPosition = this._filter.Length; + +RestartLoop: var isFirst = true; - foreach (var element in elements.Cast()) + foreach (string element in elements) { if (isFirst) { @@ -256,6 +249,30 @@ void ProcessInlineEnumerable(IEnumerable elements, Expression item) this._filter.Append(delimiter); } + // The default delimiter for search.in() is comma or space. + // If any element contains a comma or space, we switch to using pipe as the delimiter. + // If any contains a pipe, we throw (for now). + switch (delimiter) + { + case ", ": + if (element.IndexOfAny(s_searchInDefaultDelimiter) > -1) + { + delimiter = "|"; + this._filter.Length = startingPosition; + goto RestartLoop; + } + + break; + + case "|": + if (element.Contains('|')) + { + throw new NotSupportedException("Some elements contain both commas/spaces and pipes, cannot translate Contains"); + } + + break; + } + this._filter.Append(element.Replace("'", "''")); } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs index c3b338b816ad..9e92f5bbb722 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs @@ -15,7 +15,6 @@ using Azure.Search.Documents.Indexes.Models; using Azure.Search.Documents.Models; using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Postgres; namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs index 30019b97a1e1..e18f176c2ea7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs @@ -18,19 +18,17 @@ internal class AzureCosmosDBNoSqlFilterTranslator private ParameterExpression _recordParameter = null!; private readonly Dictionary _parameters = new(); - private readonly StringBuilder _sql = new(); internal (string WhereClause, Dictionary Parameters) Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) { - this._storagePropertyNames = storagePropertyNames; + Debug.Assert(this._sql.Length == 0); - this._parameters.Clear(); + this._storagePropertyNames = storagePropertyNames; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - this._sql.Clear(); this.Translate(lambdaExpression.Body); return (this._sql.ToString(), this._parameters); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs index c1cf9f3633b9..6c68527da5c1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs @@ -27,15 +27,15 @@ internal class PostgresFilterTranslator LambdaExpression lambdaExpression, int startParamIndex) { + Debug.Assert(this._sql.Length == 0); + this._storagePropertyNames = storagePropertyNames; this._parameterIndex = startParamIndex; - this._parameterValues.Clear(); Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - this._sql.Clear(); this._sql.Append("WHERE "); this.Translate(lambdaExpression.Body); return (this._sql.ToString(), this._parameterValues); diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs index 12a28b050c15..ec5bcd73514f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs @@ -20,12 +20,13 @@ internal class RedisFilterTranslator internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) { + Debug.Assert(this._filter.Length == 0); + this._storagePropertyNames = storagePropertyNames; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - this._filter.Clear(); this.Translate(lambdaExpression.Body); return this._filter.ToString(); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs index 65e6e3d4dce2..2cb6b16fc8cd 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs @@ -24,14 +24,13 @@ internal class SqliteFilterTranslator internal (string Clause, Dictionary) Translate(IReadOnlyDictionary storagePropertyNames, LambdaExpression lambdaExpression) { - this._storagePropertyNames = storagePropertyNames; + Debug.Assert(this._sql.Length == 0); - this._parameters.Clear(); + this._storagePropertyNames = storagePropertyNames; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - this._sql.Clear(); this.Translate(lambdaExpression.Body); return (this._sql.ToString(), this._parameters); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs index 802f468e15c3..837e3044ddc7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Data.Common; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text; @@ -277,7 +278,8 @@ private static string GetColumnDefinition(SqliteColumn column) whereClause += extraWhereFilter; - foreach (var p in extraParameters!) + Debug.Assert(extraParameters is not null, "extraParameters must be provided when extraWhereFilter is provided."); + foreach (var p in extraParameters) { command.Parameters.Add(new SqliteParameter(p.Key, p.Value)); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs index 8bd7780929b7..2e4be5391159 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs @@ -22,12 +22,13 @@ internal class WeaviateFilterTranslator internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) { + Debug.Assert(this._filter.Length == 0); + this._storagePropertyNames = storagePropertyNames; Debug.Assert(lambdaExpression.Parameters.Count == 1); this._recordParameter = lambdaExpression.Parameters[0]; - this._filter.Clear(); this.Translate(lambdaExpression.Body); return this._filter.ToString(); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs index d03e2cf83a2e..fe8e965f67e3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs @@ -74,7 +74,6 @@ public sealed class WeaviateVectorStoreRecordCollection : IVectorStoreR private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - // DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, Converters = { new WeaviateDateTimeOffsetConverter(), diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj index 68fbec524a28..0f884f0df59c 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj @@ -39,4 +39,10 @@ + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs index 9d79fd640a33..370756cb4344 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs @@ -233,6 +233,8 @@ public void ItBuildsSelectLeftJoinCommand(string? orderByPropertyName) leftTablePropertyNames, rightTablePropertyNames, conditions, + extraWhereFilter: null, + extraParameters: null, orderByPropertyName); // Assert diff --git a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/AnyTagEqualToFilterClause.cs b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/AnyTagEqualToFilterClause.cs index 0e001bc0cfae..49ffce328e5e 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/AnyTagEqualToFilterClause.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/AnyTagEqualToFilterClause.cs @@ -1,13 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. -using System; - namespace Microsoft.Extensions.VectorData; /// /// which filters by checking if a field consisting of a list of values contains a specific value. /// -[Obsolete("Use VectorSearchOptions.NewFilter instead of VectorSearchOptions.Filter")] public sealed class AnyTagEqualToFilterClause : FilterClause { /// diff --git a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/EqualToFilterClause.cs b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/EqualToFilterClause.cs index cef8a9670276..a0eb45c0fbe3 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/EqualToFilterClause.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/EqualToFilterClause.cs @@ -1,13 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. -using System; - namespace Microsoft.Extensions.VectorData; /// /// which filters using equality of a field value. /// -[Obsolete("Use VectorSearchOptions.NewFilter instead of VectorSearchOptions.Filter")] public sealed class EqualToFilterClause : FilterClause { /// diff --git a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/FilterClause.cs b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/FilterClause.cs index 40c7b291fd10..4392893f16e3 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/FilterClause.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/FilterClauses/FilterClause.cs @@ -1,7 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; - namespace Microsoft.Extensions.VectorData; /// @@ -11,7 +9,6 @@ namespace Microsoft.Extensions.VectorData; /// A is used to request that the underlying search service should /// filter search results based on the specified criteria. /// -[Obsolete("Use VectorSearchOptions.NewFilter instead of VectorSearchOptions.Filter")] public abstract class FilterClause { internal FilterClause() diff --git a/dotnet/src/Functions/Functions.OpenApi/Functions.OpenApi.csproj b/dotnet/src/Functions/Functions.OpenApi/Functions.OpenApi.csproj index 2d37b88dca4a..1d72c971fcba 100644 --- a/dotnet/src/Functions/Functions.OpenApi/Functions.OpenApi.csproj +++ b/dotnet/src/Functions/Functions.OpenApi/Functions.OpenApi.csproj @@ -29,4 +29,8 @@ + + + + \ No newline at end of file diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs index fdf07a1acd43..3f88b10eef4b 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs @@ -20,7 +20,7 @@ public class MongoDBVectorStoreRecordCollectionTests(MongoDBVectorStoreFixture f // If null, all tests will be enabled private const string? SkipReason = "The tests are for manual verification."; - [Theory] + [Theory(Skip = SkipReason)] [InlineData("sk-test-hotels", true)] [InlineData("nonexistentcollection", false)] public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/TextSearch/TextSearchFilter.cs b/dotnet/src/SemanticKernel.Abstractions/Data/TextSearch/TextSearchFilter.cs index bb679eb7573b..d964fb1ecba1 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Data/TextSearch/TextSearchFilter.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Data/TextSearch/TextSearchFilter.cs @@ -6,8 +6,6 @@ namespace Microsoft.SemanticKernel.Data; -#pragma warning disable CS0618 // FilterClause is obsolete - /// /// Used to provide filtering when using . /// diff --git a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj index 235c08e4d52b..47043cbe1df8 100644 --- a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj +++ b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj @@ -57,6 +57,9 @@ + + + diff --git a/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs b/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs index 3cf8528ea169..42781b1c5483 100644 --- a/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs +++ b/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs @@ -188,7 +188,6 @@ private TextSearchStringMapper CreateTextSearchStringMapper() }); } -#pragma warning disable CS0618 // FilterClause is obsolete /// /// Execute a vector search and return the results. /// @@ -200,7 +199,9 @@ private async Task> ExecuteVectorSearchAsync(string searchOptions ??= new TextSearchOptions(); var vectorSearchOptions = new VectorSearchOptions { +#pragma warning disable CS0618 // VectorSearchFilter is obsolete Filter = searchOptions.Filter?.FilterClauses is not null ? new VectorSearchFilter(searchOptions.Filter.FilterClauses) : null, +#pragma warning restore CS0618 // VectorSearchFilter is obsolete Skip = searchOptions.Skip, Top = searchOptions.Top, }; @@ -214,7 +215,6 @@ private async Task> ExecuteVectorSearchAsync(string return await this._vectorizableTextSearch!.VectorizableTextSearchAsync(query, vectorSearchOptions, cancellationToken).ConfigureAwait(false); } -#pragma warning restore CS0618 // FilterClause is obsolete /// /// Return the search results as instances of TRecord. diff --git a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj index c4c4956a3fa8..14aac96c6b73 100644 --- a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj +++ b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj @@ -8,6 +8,7 @@ true true $(NoWarn);SKEXP0001,SKEXP0120 + true true @@ -54,5 +55,5 @@ - + \ No newline at end of file diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs index 0dfc4ce4238e..a5ec5df341dd 100644 --- a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs @@ -12,7 +12,4 @@ public class AzureAISearchFilterFixture : FilterFixtureBase // Azure AI search only supports lowercase letters, digits or dashes. protected override string StoreName => "filter-tests"; - - public override async Task DisposeAsync() - => await base.DisposeAsync(); } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs index 9af82882a7fb..9b1641e5799a 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs @@ -40,7 +40,11 @@ public virtual async Task InitializeAsync() { var results = await this.Collection.VectorizedSearchAsync( new ReadOnlyMemory([1, 2, 3]), - new() { Top = this.TestData.Count }); + new() + { + Top = this.TestData.Count, + NewFilter = r => r.Int > 0 + }); var count = await results.Results.CountAsync(); if (count == this.TestData.Count) { From d4dc9698db0472c78f54e76cec26a2e7b4a9bd81 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Tue, 11 Feb 2025 11:41:37 +0100 Subject: [PATCH 5/8] Target .NET Framework --- dotnet/Directory.Build.props | 5 +++++ .../AzureAISearchIntegrationTests.csproj | 2 +- .../CosmosMongoDBIntegrationTests.csproj | 2 +- .../CosmosNoSQLIntegrationTests.csproj | 2 +- .../Support/CosmosNoSQLTestStore.cs | 1 + .../src/VectorDataIntegrationTests/Directory.Build.props | 5 +++++ .../InMemoryIntegrationTests.csproj | 2 +- .../MongoDBIntegrationTests.csproj | 2 +- .../MongoDBIntegrationTests/Support/MongoDBTestStore.cs | 3 ++- .../PostgresIntegrationTests.csproj | 2 +- .../PostgresIntegrationTests/Support/PostgresTestStore.cs | 5 +++-- .../QdrantIntegrationTests/QdrantIntegrationTests.csproj | 2 +- .../QdrantIntegrationTests/Support/QdrantTestStore.cs | 3 ++- .../RedisIntegrationTests/RedisIntegrationTests.csproj | 2 +- .../RedisIntegrationTests/Support/RedisTestStore.cs | 3 ++- .../SqliteIntegrationTests/SqliteIntegrationTests.csproj | 2 +- .../SqliteIntegrationTests/Support/SqliteTestStore.cs | 8 ++++++++ .../Filter/FilterFixtureBase.cs | 2 +- .../VectorDataIntegrationTests.csproj | 6 +++++- .../WeaviateIntegrationTests/Support/WeaviateTestStore.cs | 4 +++- .../WeaviateIntegrationTests.csproj | 2 +- 21 files changed, 47 insertions(+), 18 deletions(-) diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props index 94d748c78057..c15c377086ab 100644 --- a/dotnet/Directory.Build.props +++ b/dotnet/Directory.Build.props @@ -25,6 +25,11 @@ True + + + $(NoWarn);CS8604;CS8602 + + $([System.IO.Path]::GetDirectoryName($([MSBuild]::GetPathOfFileAbove('.gitignore', '$(MSBuildThisFileDirectory)')))) diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj index c575ad645e31..0fcc13f45809 100644 --- a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj index dbd8e1d18a24..aaf0dcf8160b 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj index 782d7e328e44..dd8e3f7a9ba0 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs index 18f58717d461..392924be8f78 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Net.Http; using System.Text.Json; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.VectorData; diff --git a/dotnet/src/VectorDataIntegrationTests/Directory.Build.props b/dotnet/src/VectorDataIntegrationTests/Directory.Build.props index 1cfeef0ee289..f5d133b5fd9f 100644 --- a/dotnet/src/VectorDataIntegrationTests/Directory.Build.props +++ b/dotnet/src/VectorDataIntegrationTests/Directory.Build.props @@ -12,4 +12,9 @@ $(NoWarn);IDE1006 + + + $(NoWarn);CS8604;CS8602 + + diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj index 4c6988e72e3d..f77fff8de939 100644 --- a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj index 17d39c913623..6aa9923ffaa2 100644 --- a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs index 7ea67c46d0c7..10ee96b890b6 100644 --- a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.MongoDB; using MongoDB.Driver; using Testcontainers.MongoDb; @@ -22,7 +23,7 @@ internal sealed class MongoDBTestStore : TestStore public MongoClient Client => this._client ?? throw new InvalidOperationException("Not initialized"); public IMongoDatabase Database => this._database ?? throw new InvalidOperationException("Not initialized"); - public override MongoDBVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); public MongoDBVectorStore GetVectorStore(MongoDBVectorStoreOptions options) => new(this.Database, options); diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj index 316ddd7225c1..0a039793dc49 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs index 9bd067b74e6e..1d4c540c216a 100644 --- a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Postgres; using Npgsql; using Testcontainers.PostgreSql; @@ -22,7 +23,7 @@ internal sealed class PostgresTestStore : TestStore public NpgsqlDataSource DataSource => this._dataSource ?? throw new InvalidOperationException("Not initialized"); - public override PostgresVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); public PostgresVectorStore GetVectorStore(PostgresVectorStoreOptions options) => new(this.DataSource, options); @@ -53,7 +54,7 @@ protected override async Task StartAsync() await using var connection = this._dataSource.CreateConnection(); await connection.OpenAsync(); - await using var command = new NpgsqlCommand("CREATE EXTENSION IF NOT EXISTS vector", connection); + using var command = new NpgsqlCommand("CREATE EXTENSION IF NOT EXISTS vector", connection); await command.ExecuteNonQueryAsync(); await connection.ReloadTypesAsync(); diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj index 8eb6273cbee3..0ea8db51c21d 100644 --- a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs index 52736ac02681..3537cf8c64e9 100644 --- a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Qdrant; using Qdrant.Client; using QdrantIntegrationTests.Support.TestContainer; @@ -19,7 +20,7 @@ internal sealed class QdrantTestStore : TestStore public QdrantClient Client => this._client ?? throw new InvalidOperationException("Not initialized"); - public override QdrantVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); public QdrantVectorStore GetVectorStore(QdrantVectorStoreOptions options) => new(this.Client, options); diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj index b3661a620b41..5727b3b2650a 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs index c0e2336ed9e5..a1dd2f02c0bc 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Redis; using StackExchange.Redis; using Testcontainers.Redis; @@ -20,7 +21,7 @@ internal sealed class RedisTestStore : TestStore public IDatabase Database => this._database ?? throw new InvalidOperationException("Not initialized"); - public override RedisVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); public RedisVectorStore GetVectorStore(RedisVectorStoreOptions options) => new(this.Database, options); diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj index 3b1a18810500..a47480e526cd 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs index 48b884b08ade..526eeac3b2d8 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs @@ -42,6 +42,14 @@ protected override async Task StartAsync() this._defaultVectorStore = new SqliteVectorStore(this.Connection); } +#if NET8_0_OR_GREATER protected override async Task StopAsync() => await this.Connection.DisposeAsync(); +#else + protected override Task StopAsync() + { + this.Connection.Dispose(); + return Task.CompletedTask; + } +#endif } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs index 9b1641e5799a..68274beeedd1 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs @@ -174,7 +174,7 @@ public virtual Task DisposeAsync() #pragma warning disable CA1819 // Properties should not return arrays public class FilterRecord { - public TKey Key { get; init; } = default!; + public TKey Key { get; set; } = default!; public ReadOnlyMemory? Vector { get; set; } public int Int { get; set; } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj index 9949700b9fce..bb7aa1bf3497 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true @@ -18,4 +18,8 @@ + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs index 21e079421674..8fed1f0dc042 100644 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Net.Http; +using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Weaviate; using VectorDataSpecificationTests.Support; using WeaviateIntegrationTests.Support.TestContainer; @@ -16,7 +18,7 @@ public sealed class WeaviateTestStore : TestStore public HttpClient Client => this._httpClient ?? throw new InvalidOperationException("Not initialized"); - public override WeaviateVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); public WeaviateVectorStore GetVectorStore(WeaviateVectorStoreOptions options) => new(this.Client, options); diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj index d99bde7558c3..eb98407f35ee 100644 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net8.0;net472 enable enable true From 161ee585b7ac97c38f8ce0f5f67881cf013b8580 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Tue, 11 Feb 2025 14:24:33 +0100 Subject: [PATCH 6/8] Final formatting fixes --- .../Support/CosmosMongoDBTestStore.cs | 4 ++-- .../Support/CosmosNoSQLTestStore.cs | 2 ++ .../VectorDataIntegrationTests.csproj | 3 +-- .../WeaviateIntegrationTests/Support/WeaviateTestStore.cs | 2 ++ 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs index 8432901efb73..b0d4c379ecf4 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs @@ -12,8 +12,8 @@ public sealed class CosmosMongoDBTestStore : TestStore { public static CosmosMongoDBTestStore Instance { get; } = new(); - public MongoClient? _client { get; private set; } - public IMongoDatabase? _database { get; private set; } + private MongoClient? _client; + private IMongoDatabase? _database; private AzureCosmosDBMongoDBVectorStore? _defaultVectorStore; public MongoClient Client => this._client ?? throw new InvalidOperationException("Not initialized"); diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs index 392924be8f78..7e3269ba2a27 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +#if NET472 using System.Net.Http; +#endif using System.Text.Json; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.VectorData; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj index bb7aa1bf3497..de93af95f360 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj @@ -4,8 +4,7 @@ net8.0;net472 enable enable - true - false + true VectorDataSpecificationTests diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs index 8fed1f0dc042..d112a2abfe49 100644 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +#if NET472 using System.Net.Http; +#endif using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Weaviate; using VectorDataSpecificationTests.Support; From eb07d55581af838cf08a149418d0239785656a95 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Tue, 11 Feb 2025 17:30:20 +0100 Subject: [PATCH 7/8] Make non-package again --- .../VectorDataIntegrationTests.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj index de93af95f360..77fc8e90dbb2 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj @@ -4,7 +4,7 @@ net8.0;net472 enable enable - true + false VectorDataSpecificationTests From de13cffb92e7cccba3a4d4d5a53907c8e023289e Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Tue, 11 Feb 2025 23:40:53 +0100 Subject: [PATCH 8/8] Address review comments, implement Redis HashSet tests --- dotnet/Directory.Build.props | 2 +- dotnet/SK-dotnet.sln | 4 +- ...LVectorStoreCollectionQueryBuilderTests.cs | 4 -- ...RedisHashSetVectorStoreRecordCollection.cs | 1 + .../SemanticKernel.Core.csproj | 1 - .../Filter/CosmosMongoBasicFilterTests.cs | 4 +- .../Filter/MongoDBBasicFilterTests.cs | 4 +- .../Filter/RedisBasicFilterTests.cs | 38 +++++++++++++- .../Filter/RedisFilterFixture.cs | 49 ++++++++++++++++++- .../Filter/BasicFilterTestsBase.cs | 2 +- .../Filter/FilterFixtureBase.cs | 5 +- .../Filter/WeaviateBasicFilterTests.cs | 4 +- 12 files changed, 98 insertions(+), 20 deletions(-) diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props index c15c377086ab..13e279d799d6 100644 --- a/dotnet/Directory.Build.props +++ b/dotnet/Directory.Build.props @@ -25,7 +25,7 @@ True - + $(NoWarn);CS8604;CS8602 diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 6b4dae547138..e1953ea0bf7e 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -1206,8 +1206,8 @@ Global {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Release|Any CPU.Build.0 = Release|Any CPU {B29A972F-A774-4140-AECF-6B577C476627}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {B29A972F-A774-4140-AECF-6B577C476627}.Debug|Any CPU.Build.0 = Debug|Any CPU - {B29A972F-A774-4140-AECF-6B577C476627}.Publish|Any CPU.ActiveCfg = Debug|Any CPU - {B29A972F-A774-4140-AECF-6B577C476627}.Publish|Any CPU.Build.0 = Debug|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Publish|Any CPU.ActiveCfg = Publish|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Publish|Any CPU.Build.0 = Publish|Any CPU {B29A972F-A774-4140-AECF-6B577C476627}.Release|Any CPU.ActiveCfg = Release|Any CPU {B29A972F-A774-4140-AECF-6B577C476627}.Release|Any CPU.Build.0 = Release|Any CPU {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Debug|Any CPU.ActiveCfg = Debug|Any CPU diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs index 4c77dc161414..37aa005777d5 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs @@ -183,10 +183,6 @@ public void BuildSearchQueryWithoutFilterDoesNotContainWhereClause() public void BuildSelectQueryByDefaultReturnsValidQueryDefinition() { // Arrange - // const string ExpectedQueryText = "" + - // "SELECT x.key,x.property_1,x.property_2 " + - // "FROM x " + - // "WHERE (x.key_property = @rk0 AND x.partition_key_property = @pk0) "; const string ExpectedQueryText = """ SELECT x.key,x.property_1,x.property_2 FROM x diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs index fb565cad17d8..2a5d324e0171 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs @@ -300,6 +300,7 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancella // Upsert. var maybePrefixedKey = this.PrefixKeyIfNeeded(redisHashSetRecord.Key); + await this.RunOperationAsync( "HSET", () => this._database diff --git a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj index 14aac96c6b73..268c2e470314 100644 --- a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj +++ b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj @@ -8,7 +8,6 @@ true true $(NoWarn);SKEXP0001,SKEXP0120 - true true diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs index deed1197b728..33d14908f537 100644 --- a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs @@ -22,8 +22,8 @@ public override Task Equal_with_null_reference_type() public override Task Equal_with_null_captured() => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); - public override Task NotEqual_with_null_referenceType() - => Assert.ThrowsAsync(() => base.NotEqual_with_null_referenceType()); + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); public override Task NotEqual_with_null_captured() => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs index f1a37114e6c6..a6ad4378f7a1 100644 --- a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs @@ -22,8 +22,8 @@ public override Task Equal_with_null_reference_type() public override Task Equal_with_null_captured() => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); - public override Task NotEqual_with_null_referenceType() - => Assert.ThrowsAsync(() => base.NotEqual_with_null_referenceType()); + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); public override Task NotEqual_with_null_captured() => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs index 2d0bc17f179a..d0017e3a510c 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs @@ -2,10 +2,11 @@ using VectorDataSpecificationTests.Filter; using Xunit; +using Xunit.Sdk; namespace RedisIntegrationTests.Filter; -public class RedisBasicFilterTests(RedisFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +public abstract class RedisBasicFilterTests(FilterFixtureBase fixture) : BasicFilterTestsBase(fixture) { #region Equality with null @@ -15,7 +16,7 @@ public override Task Equal_with_null_reference_type() public override Task Equal_with_null_captured() => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); - public override Task NotEqual_with_null_referenceType() + public override Task NotEqual_with_null_reference_type() => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); public override Task NotEqual_with_null_captured() @@ -49,3 +50,36 @@ public override Task Contains_over_captured_string_array() #endregion } + +public class RedisJsonCollectionBasicFilterTests(RedisJsonCollectionFilterFixture fixture) : RedisBasicFilterTests(fixture), IClassFixture; + +public class RedisHashSetCollectionBasicFilterTests(RedisHashSetCollectionFilterFixture fixture) : RedisBasicFilterTests(fixture), IClassFixture +{ + // Null values are not supported in Redis HashSet + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + // Array fields not supported on Redis HashSet + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_array() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_array()); + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_List() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_List()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs index e450381f91e4..de751f36ca4e 100644 --- a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs @@ -1,20 +1,65 @@ // Copyright (c) Microsoft. All rights reserved. using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Redis; using RedisIntegrationTests.Support; using VectorDataSpecificationTests.Filter; using VectorDataSpecificationTests.Support; namespace RedisIntegrationTests.Filter; -public class RedisFilterFixture : FilterFixtureBase +public class RedisJsonCollectionFilterFixture : FilterFixtureBase { protected override TestStore TestStore => RedisTestStore.Instance; - // Override to remove the bool property, which isn't (currently) supported on Redis + protected override string StoreName => "JsonCollectionFilterTests"; + + // Override to remove the bool property, which isn't (currently) supported on Redis/JSON protected override VectorStoreRecordDefinition GetRecordDefinition() => new() { Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(bool)).ToList() }; + + protected override IVectorStoreRecordCollection> CreateCollection() + => new RedisJsonVectorStoreRecordCollection>( + RedisTestStore.Instance.Database, + this.StoreName, + new() { VectorStoreRecordDefinition = this.GetRecordDefinition() }); +} + +public class RedisHashSetCollectionFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => RedisTestStore.Instance; + + protected override string StoreName => "HashSetCollectionFilterTests"; + + // Override to remove the bool property, which isn't (currently) supported on Redis + protected override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => + p.PropertyType != typeof(bool) && + p.PropertyType != typeof(string[]) && + p.PropertyType != typeof(List)).ToList() + }; + + protected override IVectorStoreRecordCollection> CreateCollection() + => new RedisHashSetVectorStoreRecordCollection>( + RedisTestStore.Instance.Database, + this.StoreName, + new() { VectorStoreRecordDefinition = this.GetRecordDefinition() }); + + protected override List> BuildTestData() + { + var testData = base.BuildTestData(); + + foreach (var record in testData) + { + // Null values are not supported in Redis hashsets + record.String ??= string.Empty; + } + + return testData; + } } diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs index b637035eea53..f2022a2e7c60 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs @@ -57,7 +57,7 @@ public virtual Task NotEqual_reversed() => this.TestFilterAsync(r => r.Int != 8); [ConditionalFact] - public virtual Task NotEqual_with_null_referenceType() + public virtual Task NotEqual_with_null_reference_type() => this.TestFilterAsync(r => r.String != null); [ConditionalFact] diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs index 68274beeedd1..436d1453d552 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs @@ -20,11 +20,14 @@ public abstract class FilterFixtureBase : IAsyncLifetime protected virtual string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineSimilarity; protected virtual string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Flat; + protected virtual IVectorStoreRecordCollection> CreateCollection() + => this.TestStore.DefaultVectorStore.GetCollection>(this.StoreName, this.GetRecordDefinition()); + public virtual async Task InitializeAsync() { await this.TestStore.ReferenceCountingStartAsync(); - this.Collection = this.TestStore.DefaultVectorStore.GetCollection>(this.StoreName, this.GetRecordDefinition()); + this.Collection = this.CreateCollection(); if (await this.Collection.CollectionExistsAsync()) { diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs index 941053dd98a3..2880d1b93859 100644 --- a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs @@ -23,8 +23,8 @@ public override Task Equal_with_null_captured() public override Task NotEqual_with_null_captured() => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); - public override Task NotEqual_with_null_referenceType() - => Assert.ThrowsAsync(() => base.NotEqual_with_null_referenceType()); + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); #endregion