diff --git a/extensions/AzureAISearch/AzureAISearch.FunctionalTests/DefaultTests.cs b/extensions/AzureAISearch/AzureAISearch.FunctionalTests/DefaultTests.cs index c02bb50b3..bfc212328 100644 --- a/extensions/AzureAISearch/AzureAISearch.FunctionalTests/DefaultTests.cs +++ b/extensions/AzureAISearch/AzureAISearch.FunctionalTests/DefaultTests.cs @@ -15,14 +15,41 @@ public DefaultTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, ou { Assert.False(string.IsNullOrEmpty(this.AzureAiSearchConfig.Endpoint)); Assert.False(string.IsNullOrEmpty(this.AzureAiSearchConfig.APIKey)); - Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); - - this._memory = new KernelMemoryBuilder() - .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) - .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) - .WithOpenAI(this.OpenAiConfig) - .WithAzureAISearchMemoryDb(this.AzureAiSearchConfig.Endpoint, this.AzureAiSearchConfig.APIKey) - .Build(); + + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed identities so we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. 
+ Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithAzureAISearchMemoryDb(this.AzureAiSearchConfig.Endpoint, this.AzureAiSearchConfig.APIKey) + .Build(); + } + else + { + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithOpenAI(this.OpenAiConfig) + .WithAzureAISearchMemoryDb(this.AzureAiSearchConfig.Endpoint, this.AzureAiSearchConfig.APIKey) + .Build(); + } } [Fact] diff --git a/extensions/AzureAISearch/AzureAISearch.UnitTests/AzureAISearchFilteringTest.cs b/extensions/AzureAISearch/AzureAISearch.UnitTests/AzureAISearchFilteringTest.cs index 4e57b4e8e..4fd948972 100644 --- a/extensions/AzureAISearch/AzureAISearch.UnitTests/AzureAISearchFilteringTest.cs +++ b/extensions/AzureAISearch/AzureAISearch.UnitTests/AzureAISearchFilteringTest.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. 
using Microsoft.KernelMemory; using Microsoft.KernelMemory.MemoryDb.AzureAISearch; @@ -225,7 +225,13 @@ public void ItHandlesEdgeCase3() // Assert Console.WriteLine($"Result: {result}"); - Assert.Equal("(tags/any(s: s eq 'color:blue') and tags/any(s: s eq 'color:blue'))", result); + + // Note: Before introducing Not filter the test expected the result + // (tags/any(s: s eq 'color:blue') and tags/any(s: s eq 'color:blue')) + // in my opinion it is better to have a more concise result because the + // previous result contains two identical conditions with an and it is + // better to have a single condition. + Assert.Equal("(tags/any(s: s eq 'color:blue'))", result); } [Fact] diff --git a/extensions/AzureAISearch/AzureAISearch/Internals/AzureAISearchFiltering.cs b/extensions/AzureAISearch/AzureAISearch/Internals/AzureAISearchFiltering.cs index af2e1a679..bfcc78d0b 100644 --- a/extensions/AzureAISearch/AzureAISearch/Internals/AzureAISearchFiltering.cs +++ b/extensions/AzureAISearch/AzureAISearch/Internals/AzureAISearchFiltering.cs @@ -32,13 +32,14 @@ internal static string BuildSearchFilter(IEnumerable filters) var filtersForSearchInQuery = filterList // Filters with only one key, but not multiple values (i.e: excluding MemoryFilters.ByTag("department", "HR").ByTag("department", "Marketing") as here we want an `AND`) .Where(filter => !filter.IsEmpty() && filter.Keys.Count == 1 && filter.Values.First().Count == 1) - .SelectMany(filter => filter.Pairs) // Flattening to pairs + .SelectMany(filter => filter.GetFilters()) // Flattening to pairs .GroupBy(pair => pair.Key) // Grouping by the tag key .Where(g => g.Count() > 1) .Select(group => new { Key = group.Key, - Values = group.Select(pair => $"{pair.Key}:{pair.Value?.Replace("'", "''", StringComparison.Ordinal)}").ToList(), + EqualValues = group.OfType().Select(baseFilter => $"{baseFilter.Key}:{baseFilter.Value?.Replace("'", "''", StringComparison.Ordinal)}").ToList(), + NotEqualValues = 
group.OfType().Select(baseFilter => $"{baseFilter.Key}:{baseFilter.Value?.Replace("'", "''", StringComparison.Ordinal)}").ToList(), SearchInDelimiter = s_searchInDelimitersAvailable.FirstOrDefault(specialChar => !group.Any(pair => (pair.Value != null && pair.Value.Contains(specialChar, StringComparison.Ordinal)) || @@ -54,7 +55,15 @@ internal static string BuildSearchFilter(IEnumerable filters) // The default value of this parameter is ' ,' which means that any values with spaces and/or commas between them will be separated. // If you need to use separators other than spaces and commas because your values include those characters, // you can specify alternate delimiters such as '|' in this parameter. - conditions.Add($"tags/any(s: search.in(s, '{string.Join(filterGroup.SearchInDelimiter, filterGroup.Values)}', '{filterGroup.SearchInDelimiter}'))"); + if (filterGroup.EqualValues.Count != 0) + { + conditions.Add($"tags/any(s: search.in(s, '{string.Join(filterGroup.SearchInDelimiter, filterGroup.EqualValues)}', '{filterGroup.SearchInDelimiter}'))"); + } + + if (filterGroup.NotEqualValues.Count != 0) + { + conditions.Add($"not tags/any(s: search.in(s, '{string.Join(filterGroup.SearchInDelimiter, filterGroup.NotEqualValues)}', '{filterGroup.SearchInDelimiter}'))"); + } } //Exclude filters that were grouped before in the search.in process @@ -65,13 +74,25 @@ internal static string BuildSearchFilter(IEnumerable filters) // Note: empty filters would lead to a syntax error, so even if they are supposed // to be removed upstream, we check again and remove them here too. 
- foreach (var filter in remainingFilters.Where(f => !f.IsEmpty())) + foreach (var filter in remainingFilters) { var filterConditions = filter.GetFilters() - .Select(keyValue => + .Select(baseFilter => { - var fieldValue = keyValue.Value?.Replace("'", "''", StringComparison.Ordinal); - return $"tags/any(s: s eq '{keyValue.Key}{Constants.ReservedEqualsChar}{fieldValue}')"; + if (baseFilter is EqualFilter eq) + { + var fieldValue = eq.Value?.Replace("'", "''", StringComparison.Ordinal); + return $"tags/any(s: s eq '{baseFilter.Key}{Constants.ReservedEqualsChar}{fieldValue}')"; + } + else if (baseFilter is NotEqualFilter neq) + { + var fieldValue = neq.Value?.Replace("'", "''", StringComparison.Ordinal); + return $"not tags/any(s: s eq '{baseFilter.Key}{Constants.ReservedEqualsChar}{fieldValue}')"; + } + else + { + throw new AzureAISearchMemoryException($"Filter type {baseFilter.GetType().Name} is not supported."); + } }) .ToList(); diff --git a/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/IndexNameTests.cs b/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/IndexNameTests.cs index ff4e97e49..8f00471db 100644 --- a/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/IndexNameTests.cs +++ b/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/IndexNameTests.cs @@ -73,7 +73,7 @@ public void BadIndexNamesAreRejected(string indexName, int errorCount) $"" + $"The expected number of errors was {errorCount}."); - Assert.True(errorCount == exception.Errors.Count(), $"The number of errprs expected is different than the number of errors found."); + Assert.True(errorCount == exception.Errors.Count(), $"The number of errors expected is different than the number of errors found."); } [Fact] diff --git a/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/KernelMemoryTests.cs b/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/KernelMemoryTests.cs index 20b9bee1f..5c4735637 100644 --- 
a/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/KernelMemoryTests.cs +++ b/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/KernelMemoryTests.cs @@ -15,12 +15,40 @@ public class KernelMemoryTests : MemoryDbFunctionalTest public KernelMemoryTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output) { - this.KernelMemory = new KernelMemoryBuilder() - .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) - .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) - .WithOpenAI(this.OpenAiConfig) - .WithElasticsearchMemoryDb(this.ElasticsearchConfig) - .Build(); + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed identities so we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. 
+ Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + this.KernelMemory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithElasticsearchMemoryDb(this.ElasticsearchConfig) + .Build(); + } + else + { + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + this.KernelMemory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithOpenAI(this.OpenAiConfig) + .WithElasticsearchMemoryDb(this.ElasticsearchConfig) + .Build(); + } } public IKernelMemory KernelMemory { get; } diff --git a/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/MemoryDbFunctionalTest.cs b/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/MemoryDbFunctionalTest.cs index ceede208e..095845c89 100644 --- a/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/MemoryDbFunctionalTest.cs +++ b/extensions/Elasticsearch/Elasticsearch.FunctionalTests/Additional/MemoryDbFunctionalTest.cs @@ -2,6 +2,7 @@ using Elastic.Clients.Elasticsearch; using Microsoft.KernelMemory.AI; +using Microsoft.KernelMemory.AI.AzureOpenAI; using Microsoft.KernelMemory.AI.OpenAI; using Microsoft.KernelMemory.MemoryDb.Elasticsearch; using Microsoft.KernelMemory.MemoryDb.Elasticsearch.Internals; @@ -24,12 +25,21 @@ protected MemoryDbFunctionalTest(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output) { this.Output = output ?? throw new ArgumentNullException(nameof(output)); - #pragma warning disable KMEXP01 // Type is for evaluation purposes only and is subject to change or removal in future updates. 
Suppress this diagnostic to proceed. - this.TextEmbeddingGenerator = new OpenAITextEmbeddingGenerator( - config: base.OpenAiConfig, - textTokenizer: default, - loggerFactory: default); + if (cfg.GetValue("UseAzureOpenAI")) + { + this.TextEmbeddingGenerator = new AzureOpenAITextEmbeddingGenerator( + config: base.AzureOpenAIEmbeddingConfiguration, + textTokenizer: default, + loggerFactory: default); + } + else + { + this.TextEmbeddingGenerator = new OpenAITextEmbeddingGenerator( + config: base.OpenAiConfig, + textTokenizer: default, + loggerFactory: default); + } #pragma warning restore KMEXP01 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. this.Client = new ElasticsearchClient(base.ElasticsearchConfig.ToElasticsearchClientSettings()); diff --git a/extensions/Elasticsearch/Elasticsearch.FunctionalTests/DefaultTests.cs b/extensions/Elasticsearch/Elasticsearch.FunctionalTests/DefaultTests.cs index 3bb9fd093..8199d0b63 100644 --- a/extensions/Elasticsearch/Elasticsearch.FunctionalTests/DefaultTests.cs +++ b/extensions/Elasticsearch/Elasticsearch.FunctionalTests/DefaultTests.cs @@ -15,18 +15,42 @@ public class DefaultTests : BaseFunctionalTestCase public DefaultTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output) { - Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); - this._esConfig = cfg.GetSection("KernelMemory:Services:Elasticsearch").Get()!; - this._memory = new KernelMemoryBuilder() - .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) - .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) - .WithOpenAI(this.OpenAiConfig) - // .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) - // .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) - .WithElasticsearchMemoryDb(this._esConfig) - .Build(); + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed 
identities so we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithElasticsearchMemoryDb(this._esConfig) + .Build(); + } + else + { + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithOpenAI(this.OpenAiConfig) + .WithElasticsearchMemoryDb(this._esConfig) + .Build(); + } } [Fact] diff --git a/extensions/Elasticsearch/Elasticsearch/ElasticsearchMemory.cs b/extensions/Elasticsearch/Elasticsearch/ElasticsearchMemory.cs index c1056949f..b8804a118 100644 --- a/extensions/Elasticsearch/Elasticsearch/ElasticsearchMemory.cs +++ b/extensions/Elasticsearch/Elasticsearch/ElasticsearchMemory.cs @@ -306,16 +306,8 @@ private QueryDescriptor ConvertTagFilters( QueryDescriptor qd, ICollection? 
filters = null) { - if ((filters == null) || (filters.Count == 0)) - { - qd.MatchAll(); - return qd; - } - - filters = filters.Where(f => f.Keys.Count > 0) - .ToList(); // Remove empty filters - - if (filters.Count == 0) + var hasOneNotEmptyFilter = filters != null && filters.Any(f => !f.IsEmpty()); + if (!hasOneNotEmptyFilter) { qd.MatchAll(); return qd; @@ -323,15 +315,15 @@ private QueryDescriptor ConvertTagFilters( List super = new(); - foreach (MemoryFilter filter in filters) + foreach (MemoryFilter filter in filters!) { List thisMust = new(); // Each filter is a list of key/value pairs. - foreach (var pair in filter.Pairs) + foreach (var baseFilter in filter.GetFilters()) { - Query newTagQuery = new TermQuery(ElasticsearchMemoryRecord.TagsName) { Value = pair.Key }; - Query termQuery = new TermQuery(ElasticsearchMemoryRecord.TagsValue) { Value = pair.Value ?? string.Empty }; + Query newTagQuery = new TermQuery(ElasticsearchMemoryRecord.TagsName) { Value = baseFilter.Key }; + Query termQuery = new TermQuery(ElasticsearchMemoryRecord.TagsValue) { Value = baseFilter.Value ?? 
string.Empty }; newTagQuery &= termQuery; @@ -339,7 +331,20 @@ private QueryDescriptor ConvertTagFilters( nestedQd.Path = ElasticsearchMemoryRecord.TagsField; nestedQd.Query = newTagQuery; - thisMust.Add(nestedQd); + if (baseFilter is EqualFilter eq) + { + thisMust.Add(nestedQd); + } + else if (baseFilter is NotEqualFilter neq) + { + var notQuery = new BoolQuery(); + notQuery.MustNot = [nestedQd]; + thisMust.Add(notQuery); + } + else + { + throw new ElasticsearchException($"Filter type {baseFilter.GetType().Name} is not supported."); + } } var filterQuery = new BoolQuery(); diff --git a/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/DefaultTests.cs b/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/DefaultTests.cs index d805cf08d..e3a72c86a 100644 --- a/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/DefaultTests.cs +++ b/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/DefaultTests.cs @@ -29,8 +29,6 @@ public abstract class DefaultTests : BaseFunctionalTestCase protected DefaultTests(IConfiguration cfg, ITestOutputHelper output, bool multiCollection) : base(cfg, output) { - Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey), "OpenAI API Key is empty"); - if (multiCollection) { // this._config = this.MongoDbAtlasConfig; @@ -62,13 +60,40 @@ protected DefaultTests(IConfiguration cfg, ITestOutputHelper output, bool multiC ash.DropDatabaseAsync().Wait(); } - this._memory = new KernelMemoryBuilder() - .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) - .WithOpenAI(this.OpenAiConfig) - // .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) - // .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) - .WithMongoDbAtlasMemoryDb(this.MongoDbAtlasConfig) - .Build(); + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed identities so we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + 
//verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithMongoDbAtlasMemoryDb(this.MongoDbAtlasConfig) + .Build(); + } + else + { + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithOpenAI(this.OpenAiConfig) + .WithMongoDbAtlasMemoryDb(this.MongoDbAtlasConfig) + .Build(); + } } [Fact] diff --git a/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/MongoDbAtlas.FunctionalTests.csproj b/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/MongoDbAtlas.FunctionalTests.csproj index 2e450d8e9..bfdd6f1fe 100644 --- a/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/MongoDbAtlas.FunctionalTests.csproj +++ b/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/MongoDbAtlas.FunctionalTests.csproj @@ -1,4 +1,4 @@ - + Microsoft.MongoDbAtlas.FunctionalTests diff --git a/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/StorageTests.cs b/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/StorageTests.cs index d01a4251c..19d23a85d 100644 --- a/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/StorageTests.cs +++ b/extensions/MongoDbAtlas/MongoDbAtlas.FunctionalTests/StorageTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. 
All rights reserved. using System.Text; +using Microsoft.KernelMemory; using Microsoft.KernelMemory.DocumentStorage; using Microsoft.KernelMemory.MongoDbAtlas; using Microsoft.KM.TestHelpers; @@ -60,6 +61,13 @@ protected override void Dispose(bool disposing) } } + private async Task ReadFileContentAsync(StreamableFileContent streamableFileContent) + { + using var stream = await streamableFileContent.GetStreamAsync(); + using var reader = new StreamReader(stream); + return await reader.ReadToEndAsync(); + } + [Fact] [Trait("Category", "MongoDbAtlas")] public async Task SaveFilesHonorsId() @@ -73,7 +81,7 @@ public async Task SaveFilesHonorsId() // Assert var file = await this._sut.ReadFileAsync(this.IndexName, id, "filename.txt"); - var content = file.ToString(); + var content = await this.ReadFileContentAsync(file); Assert.Equal("Hello World 2", content); } @@ -96,12 +104,12 @@ public async Task SaveDifferentFiles(string extension, string content1, string c // Assert var file = await this._sut.ReadFileAsync(this.IndexName, id, fileName1); - var content = file.ToString(); + var content = await this.ReadFileContentAsync(file); Assert.Equal(content1, content); file = await this._sut.ReadFileAsync(this.IndexName, id, fileName2); - content = file.ToString(); + content = await this.ReadFileContentAsync(file); Assert.Equal(content2, content); } @@ -119,7 +127,7 @@ public async Task SaveFilesHonorsIdWithBinaryContent() // Assert var file = await this._sut.ReadFileAsync(this.IndexName, id, fileName); - var content = file.ToString(); + var content = await this.ReadFileContentAsync(file); Assert.Equal("Hello World 2", content); } @@ -139,12 +147,12 @@ public async Task SaveDifferentFilesWithBinaryContent() // Assert var file = await this._sut.ReadFileAsync(this.IndexName, id, fileName1); - var content = file.ToString(); + var content = await this.ReadFileContentAsync(file); Assert.Equal("Hello World", content); file = await this._sut.ReadFileAsync(this.IndexName, id, 
fileName2); - content = file.ToString(); + content = await this.ReadFileContentAsync(file); Assert.Equal("Hello World 2", content); } diff --git a/extensions/MongoDbAtlas/MongoDbAtlas/MongoDbAtlasMemory.cs b/extensions/MongoDbAtlas/MongoDbAtlas/MongoDbAtlasMemory.cs index 1183188d6..5aadb4965 100644 --- a/extensions/MongoDbAtlas/MongoDbAtlas/MongoDbAtlasMemory.cs +++ b/extensions/MongoDbAtlas/MongoDbAtlas/MongoDbAtlasMemory.cs @@ -222,11 +222,26 @@ public async Task UpsertAsync(string index, MemoryRecord record, Cancell List> filtersArray = new(); foreach (var singleFilter in thisFilter) { - var condition = Builders.Filter.And( - Builders.Filter.Eq("Tags.Key", singleFilter.Key), - Builders.Filter.Eq("Tags.Values", singleFilter.Value) - ); - filtersArray.Add(condition); + if (singleFilter is EqualFilter ef) + { + var condition = Builders.Filter.And( + Builders.Filter.Eq("Tags.Key", singleFilter.Key), + Builders.Filter.Eq("Tags.Values", singleFilter.Value) + ); + filtersArray.Add(condition); + } + else if (singleFilter is NotEqualFilter nef) + { + var condition = Builders.Filter.And( + Builders.Filter.Eq("Tags.Key", singleFilter.Key), + Builders.Filter.Ne("Tags.Values", singleFilter.Value) + ); + filtersArray.Add(condition); + } + else + { + throw new MongoDbAtlasException($"Filter {singleFilter.GetType().Name} is not supported"); + } } // if we have more than one condition, we need to compose all conditions with AND diff --git a/extensions/Postgres/Postgres.FunctionalTests/AdditionalFilteringTests.cs b/extensions/Postgres/Postgres.FunctionalTests/AdditionalFilteringTests.cs index 028802057..de464f6f5 100644 --- a/extensions/Postgres/Postgres.FunctionalTests/AdditionalFilteringTests.cs +++ b/extensions/Postgres/Postgres.FunctionalTests/AdditionalFilteringTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. 
using Microsoft.KernelMemory; using Microsoft.KM.TestHelpers; @@ -11,13 +11,38 @@ public class AdditionalFilteringTests : BaseFunctionalTestCase public AdditionalFilteringTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output) { - this._memory = new KernelMemoryBuilder() - .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) - //.WithOpenAI(this.OpenAiConfig) - .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) - .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) - .WithPostgresMemoryDb(this.PostgresConfig) - .Build(); + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed identities so we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. 
+ Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + this._memory = new KernelMemoryBuilder() + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithPostgresMemoryDb(this.PostgresConfig) + .Build(); + } + else + { + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + this._memory = new KernelMemoryBuilder() + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithOpenAI(this.OpenAiConfig) + .WithPostgresMemoryDb(this.PostgresConfig) + .Build(); + } } [Fact] diff --git a/extensions/Postgres/Postgres.FunctionalTests/DefaultTests.cs b/extensions/Postgres/Postgres.FunctionalTests/DefaultTests.cs index 798229550..31196e5c3 100644 --- a/extensions/Postgres/Postgres.FunctionalTests/DefaultTests.cs +++ b/extensions/Postgres/Postgres.FunctionalTests/DefaultTests.cs @@ -12,16 +12,40 @@ public class DefaultTests : BaseFunctionalTestCase public DefaultTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output) { - Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); - - this._memory = new KernelMemoryBuilder() - .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) - .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) - .WithOpenAI(this.OpenAiConfig) - // .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) - // .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) - .WithPostgresMemoryDb(this.PostgresConfig) - .Build(); + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed identities so we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. 
+ Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithPostgresMemoryDb(this.PostgresConfig) + .Build(); + } + else + { + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithOpenAI(this.OpenAiConfig) + .WithPostgresMemoryDb(this.PostgresConfig) + .Build(); + } } [Fact] diff --git a/extensions/Postgres/Postgres.FunctionalTests/appsettings.json b/extensions/Postgres/Postgres.FunctionalTests/appsettings.json index 28bca58af..debba6947 100644 --- a/extensions/Postgres/Postgres.FunctionalTests/appsettings.json +++ b/extensions/Postgres/Postgres.FunctionalTests/appsettings.json @@ -4,6 +4,7 @@ "Default": "Information" } }, + "UseAzureOpenAI": true, "KernelMemory": { "Services": { "Postgres": { diff --git a/extensions/Postgres/Postgres/PostgresMemory.cs b/extensions/Postgres/Postgres/PostgresMemory.cs index a3353788e..a9c765650 100644 --- a/extensions/Postgres/Postgres/PostgresMemory.cs +++ b/extensions/Postgres/Postgres/PostgresMemory.cs @@ -5,7 +5,6 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; -using System.Text; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; @@ -216,9 
+215,6 @@ public void Dispose() GC.SuppressFinalize(this); } - /// - /// Disposes the managed resources. - /// private void Dispose(bool disposing) { if (disposing) @@ -259,7 +255,10 @@ private static string NormalizeTableNamePrefix(string? name) var sql = ""; Dictionary unsafeSqlUserValues = new(); - if (filters is not { Count: > 0 }) + var nonEmptyFilters = filters?.Where(filters => !filters.IsEmpty()).ToArray() ?? Array.Empty(); + + //No query, we do not have filters. + if (nonEmptyFilters.Length == 0) { return (sql, unsafeSqlUserValues); } @@ -267,24 +266,28 @@ private static string NormalizeTableNamePrefix(string? name) var tagCounter = 0; var orConditions = new List(); - foreach (MemoryFilter filter in filters.Where(f => !f.IsEmpty())) + foreach (MemoryFilter filter in nonEmptyFilters) { - var andSql = new StringBuilder(); - andSql.AppendLine("("); - if (filter is PostgresMemoryFilter extendedFilter) { // use PostgresMemoryFilter filtering logic throw new NotImplementedException("PostgresMemoryFilter is not supported yet"); } - List requiredTags = filter.GetFilters().Select(x => $"{x.Key}{Constants.ReservedEqualsChar}{x.Value}").ToList(); + var allFilters = filter.GetFilters().ToArray(); + + List equalTags = allFilters + .OfType() + .Select(x => $"{x.Key}{Constants.ReservedEqualsChar}{x.Value}").ToList(); List safeSqlPlaceholders = new(); - if (requiredTags.Count > 0) + + List conditions = new(); + + if (equalTags.Count > 0) { var safeSqlPlaceholder = $"@placeholder{tagCounter++}"; safeSqlPlaceholders.Add(safeSqlPlaceholder); - unsafeSqlUserValues[safeSqlPlaceholder] = requiredTags; + unsafeSqlUserValues[safeSqlPlaceholder] = equalTags; // All tags are required // tags @> ARRAY['user:001', 'type:news', '__document_id:b405']::text[] <== all tags are required <=== we are using this @@ -293,11 +296,24 @@ private static string NormalizeTableNamePrefix(string? 
name) // $"{PostgresSchema.PlaceholdersTags} @> " + safeSqlPlaceholder // $"{PostgresSchema.PlaceholdersTags} @> " + safeSqlPlaceholder + "::text[]" // $"{PostgresSchema.PlaceholdersTags} @> ARRAY[" + safeSqlPlaceholder + "]::text[]" - andSql.AppendLine($"{PostgresSchema.PlaceholdersTags} @> " + safeSqlPlaceholder); + conditions.Add($"{PostgresSchema.PlaceholdersTags} @> " + safeSqlPlaceholder); + } + + List notEqualTags = allFilters + .OfType() + .Select(x => $"{x.Key}{Constants.ReservedEqualsChar}{x.Value}").ToList(); + if (notEqualTags.Count > 0) + { + var safeSqlPlaceholder = $"@placeholder{tagCounter++}"; + safeSqlPlaceholders.Add(safeSqlPlaceholder); + unsafeSqlUserValues[safeSqlPlaceholder] = notEqualTags; + + // Tags should not be present now we need to combine this query with the previous one using AND if we + // had equality tag. + conditions.Add($"NOT ({PostgresSchema.PlaceholdersTags} && " + safeSqlPlaceholder + ")"); } - andSql.AppendLine(")"); - orConditions.Add(andSql.ToString()); + orConditions.Add($"( {string.Join(" AND ", conditions)} )"); } sql = string.Join(" OR ", orConditions); diff --git a/extensions/Qdrant/Qdrant.FunctionalTests/DefaultTests.cs b/extensions/Qdrant/Qdrant.FunctionalTests/DefaultTests.cs index f16ed4345..dee9adad6 100644 --- a/extensions/Qdrant/Qdrant.FunctionalTests/DefaultTests.cs +++ b/extensions/Qdrant/Qdrant.FunctionalTests/DefaultTests.cs @@ -13,14 +13,42 @@ public class DefaultTests : BaseFunctionalTestCase public DefaultTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output) { - Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); - - this._memory = new KernelMemoryBuilder() - .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) - .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) - .WithOpenAI(this.OpenAiConfig) - .WithQdrantMemoryDb(this.QdrantConfig) - .Build(); + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed identities so 
we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithQdrantMemoryDb(this.QdrantConfig) + .Build(); + Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + else + { + //use standard openai + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithOpenAI(this.OpenAiConfig) + .WithQdrantMemoryDb(this.QdrantConfig) + .Build(); + } } [Fact] diff --git a/extensions/Qdrant/Qdrant.FunctionalTests/Qdrant.FunctionalTests.csproj b/extensions/Qdrant/Qdrant.FunctionalTests/Qdrant.FunctionalTests.csproj index d2c86c0ed..4cc687dbb 100644 --- a/extensions/Qdrant/Qdrant.FunctionalTests/Qdrant.FunctionalTests.csproj +++ b/extensions/Qdrant/Qdrant.FunctionalTests/Qdrant.FunctionalTests.csproj @@ -1,4 +1,4 @@ - + Microsoft.Qdrant.FunctionalTests diff --git a/extensions/Qdrant/Qdrant.UnitTests/ScrollVectorsRequestTest.cs b/extensions/Qdrant/Qdrant.UnitTests/ScrollVectorsRequestTest.cs index 3550e6027..04c42f4da 100644 --- a/extensions/Qdrant/Qdrant.UnitTests/ScrollVectorsRequestTest.cs +++ 
b/extensions/Qdrant/Qdrant.UnitTests/ScrollVectorsRequestTest.cs @@ -2,6 +2,7 @@ using System.Text.Json; using Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http; +using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; using Microsoft.KM.TestHelpers; using Xunit.Abstractions; @@ -92,11 +93,11 @@ public void FiltersAreRenderedToJson() // Arrange var request = ScrollVectorsRequest .Create("coll") - .HavingAllTags(["user:devis", "type:blog"]) + .HavingAllTags([new TagFilter("user:devis", TagFilterType.Equal), new TagFilter("type:blog", TagFilterType.Equal)]) .HavingSomeTags(new[] { - new[] { "month:january", "year:2000" }, - new[] { "month:july", "year:2003" }, + new[] { new TagFilter("month:january", TagFilterType.Equal), new TagFilter("year:2000", TagFilterType.Equal), }, + new[] { new TagFilter("month:july", TagFilterType.Equal), new TagFilter("year:2003", TagFilterType.Equal), }, }); // Act @@ -157,8 +158,8 @@ public void ItRendersOptimizedConditions() // Arrange var request = ScrollVectorsRequest .Create("coll") - .HavingAllTags(["user:devis", "type:blog"]) - .HavingSomeTags([new[] { "month:january", "year:2000" }]); + .HavingAllTags([new TagFilter("user:devis", TagFilterType.Equal), new TagFilter("type:blog", TagFilterType.Equal)]) + .HavingSomeTags([new[] { new TagFilter("month:january", TagFilterType.Equal), new TagFilter("year:2000", TagFilterType.Equal) }]); // Act var actual = JsonSerializer.Serialize(request); diff --git a/extensions/Qdrant/Qdrant/Internals/Http/Filter.cs b/extensions/Qdrant/Qdrant/Internals/Http/Filter.cs index d7df0ab82..fec4a5742 100644 --- a/extensions/Qdrant/Qdrant/Internals/Http/Filter.cs +++ b/extensions/Qdrant/Qdrant/Internals/Http/Filter.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. 
using System.Collections.Generic; +using System.Linq; using System.Text.Json.Serialization; namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http; @@ -96,6 +97,23 @@ internal void Validate() } } + internal sealed class MustNotClause + { + [JsonPropertyName("must_not")] + public List Clauses { get; set; } + + public MustNotClause(string key, object value) + { + this.Clauses = new(); + this.Clauses.Add(new MatchValueClause(key, value)); + } + + internal void Validate() + { + this.Clauses.Single().Validate(); + } + } + internal sealed class MatchValueClause { [JsonPropertyName("key")] diff --git a/extensions/Qdrant/Qdrant/Internals/Http/ScrollVectorsRequest.cs b/extensions/Qdrant/Qdrant/Internals/Http/ScrollVectorsRequest.cs index 1ba5be2b6..13c57d22b 100644 --- a/extensions/Qdrant/Qdrant/Internals/Http/ScrollVectorsRequest.cs +++ b/extensions/Qdrant/Qdrant/Internals/Http/ScrollVectorsRequest.cs @@ -1,9 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Linq; using System.Net.Http; using System.Text.Json.Serialization; +using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; +using static Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http.Filter; namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http; @@ -38,42 +41,64 @@ public ScrollVectorsRequest HavingExternalId(string id) return this; } - public ScrollVectorsRequest HavingAllTags(IEnumerable? tags) + public ScrollVectorsRequest HavingAllTags(IEnumerable? 
tagFilters) { - if (tags == null) { return this; } + if (tagFilters == null) { return this; } - foreach (var tag in tags) + foreach (var tagFilter in tagFilters) { - if (!string.IsNullOrEmpty(tag)) + if (!string.IsNullOrEmpty(tagFilter.Tag)) { - this.Filters.AndValue(QdrantConstants.PayloadTagsField, tag); + if (tagFilter.FilterType == TagFilterType.NotEqual) + { + this.Filters.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag)); + } + else if (tagFilter.FilterType == TagFilterType.Equal) + { + this.Filters.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag); + } + else + { + throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory"); + } } } return this; } - public ScrollVectorsRequest HavingSomeTags(IEnumerable?>? tagGroups) + public ScrollVectorsRequest HavingSomeTags(IEnumerable>? tagFiltersGroups) { - if (tagGroups == null) { return this; } + if (tagFiltersGroups == null) { return this; } - var list = tagGroups.ToList(); + var list = tagFiltersGroups.ToList(); if (list.Count < 2) { return this.HavingAllTags(list.FirstOrDefault()); } var orFilter = new Filter.OrClause(); - foreach (var tags in list) + foreach (var tagFilters in list) { - if (tags == null) { continue; } + if (tagFilters == null) { continue; } var andFilter = new Filter.AndClause(); - foreach (var tag in tags) + foreach (var tagFilter in tagFilters) { - if (!string.IsNullOrEmpty(tag)) + if (!string.IsNullOrEmpty(tagFilter.Tag)) { - andFilter.AndValue(QdrantConstants.PayloadTagsField, tag); + if (tagFilter.FilterType == TagFilterType.NotEqual) + { + andFilter.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag)); + } + else if (tagFilter.FilterType == TagFilterType.Equal) + { + andFilter.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag); + } + else + { + throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory"); + } } } diff --git 
a/extensions/Qdrant/Qdrant/Internals/Http/SearchVectorsRequest.cs b/extensions/Qdrant/Qdrant/Internals/Http/SearchVectorsRequest.cs index f723f2e1c..3a2941be3 100644 --- a/extensions/Qdrant/Qdrant/Internals/Http/SearchVectorsRequest.cs +++ b/extensions/Qdrant/Qdrant/Internals/Http/SearchVectorsRequest.cs @@ -1,9 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Linq; using System.Net.Http; using System.Text.Json.Serialization; +using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; +using static Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http.Filter; namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http; @@ -56,42 +59,61 @@ public SearchVectorsRequest HavingExternalId(string externalId) return this; } - public SearchVectorsRequest HavingAllTags(IEnumerable? tags) + public SearchVectorsRequest HavingAllTags(IEnumerable? tagFilters) { - if (tags == null) { return this; } + if (tagFilters == null) { return this; } - foreach (var tag in tags) + foreach (var tagFilter in tagFilters) { - if (!string.IsNullOrEmpty(tag)) + if (!string.IsNullOrEmpty(tagFilter.Tag)) { - this.Filters.AndValue(QdrantConstants.PayloadTagsField, tag); + if (tagFilter.FilterType == TagFilterType.NotEqual) + { + this.Filters.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag)); + } + else if (tagFilter.FilterType == TagFilterType.Equal) + { + this.Filters.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag); + } + else + { + throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory"); + } } } return this; } - public SearchVectorsRequest HavingSomeTags(IEnumerable?>? tagGroups) + public SearchVectorsRequest HavingSomeTags(List>? 
tagFiltersGroup) { - if (tagGroups == null) { return this; } + if (tagFiltersGroup == null) { return this; } - var list = tagGroups.ToList(); + var list = tagFiltersGroup.ToList(); if (list.Count < 2) { return this.HavingAllTags(list.FirstOrDefault()); } var orFilter = new Filter.OrClause(); - foreach (var tags in list) + foreach (var tagFilters in list) { - if (tags == null) { continue; } + if (tagFilters == null) { continue; } var andFilter = new Filter.AndClause(); - foreach (var tag in tags) + foreach (var tagFilter in tagFilters.Where(t => !string.IsNullOrEmpty(t.Tag))) { - if (!string.IsNullOrEmpty(tag)) + if (tagFilter.FilterType == TagFilterType.NotEqual) + { + andFilter.And(new MustNotClause(QdrantConstants.PayloadTagsField, tagFilter.Tag)); + } + else if (tagFilter.FilterType == TagFilterType.Equal) + { + andFilter.AndValue(QdrantConstants.PayloadTagsField, tagFilter.Tag); + } + else { - andFilter.AndValue(QdrantConstants.PayloadTagsField, tag); + throw new NotSupportedException($"Filter type {tagFilter.FilterType} is not supported in QDrant Memory"); } } diff --git a/extensions/Qdrant/Qdrant/Internals/QdrantClient.cs b/extensions/Qdrant/Qdrant/Internals/QdrantClient.cs index 8abf979ad..d797bb590 100644 --- a/extensions/Qdrant/Qdrant/Internals/QdrantClient.cs +++ b/extensions/Qdrant/Qdrant/Internals/QdrantClient.cs @@ -12,6 +12,7 @@ using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http; +using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; using Microsoft.KernelMemory.MemoryStorage; namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client; @@ -279,7 +280,7 @@ public async Task DeleteVectorsAsync(string collectionName, IList vectorId /// List of vectors public async Task>> GetListAsync( string collectionName, - IEnumerable?>? requiredTags = null, + List>? 
requiredTags = null, int offset = 0, int limit = 1, bool withVectors = false, @@ -339,7 +340,7 @@ public async Task>> GetListAsync( double scoreThreshold, int limit = 1, bool withVectors = false, - IEnumerable?>? requiredTags = null, + List>? requiredTags = null, CancellationToken cancellationToken = default) { this._log.LogTrace("Searching top {0} nearest vectors", limit); diff --git a/extensions/Qdrant/Qdrant/Internals/TagFilter.cs b/extensions/Qdrant/Qdrant/Internals/TagFilter.cs new file mode 100644 index 000000000..8fe604f1b --- /dev/null +++ b/extensions/Qdrant/Qdrant/Internals/TagFilter.cs @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; + +internal record TagFilter(string Tag, TagFilterType FilterType); diff --git a/extensions/Qdrant/Qdrant/Internals/TagFilterType.cs b/extensions/Qdrant/Qdrant/Internals/TagFilterType.cs new file mode 100644 index 000000000..3150b36de --- /dev/null +++ b/extensions/Qdrant/Qdrant/Internals/TagFilterType.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; + +internal enum TagFilterType +{ + Unknown = 0, + Equal = 1, + NotEqual = 2, +} diff --git a/extensions/Qdrant/Qdrant/QdrantMemory.cs b/extensions/Qdrant/Qdrant/QdrantMemory.cs index 820b20b92..e0540df52 100644 --- a/extensions/Qdrant/Qdrant/QdrantMemory.cs +++ b/extensions/Qdrant/Qdrant/QdrantMemory.cs @@ -12,6 +12,7 @@ using Microsoft.KernelMemory.AI; using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.MemoryDb.Qdrant.Client; +using Microsoft.KernelMemory.MemoryDb.Qdrant.Internals; using Microsoft.KernelMemory.MemoryStorage; namespace Microsoft.KernelMemory.MemoryDb.Qdrant; @@ -160,14 +161,7 @@ public async IAsyncEnumerable UpsertBatchAsync(string index, IEnumerable index = NormalizeIndexName(index); if (limit <= 0) { limit = int.MaxValue; } - // Remove empty filters - filters = filters?.Where(f => !f.IsEmpty()).ToList(); - - var requiredTags = new List>(); - if (filters is { Count: > 0 }) - { - requiredTags.AddRange(filters.Select(filter => filter.GetFilters().Select(x => $"{x.Key}{Constants.ReservedEqualsChar}{x.Value}"))); - } + var requiredTags = CreateRequiredTagsFromMemoryFilters(filters); Embedding textEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, cancellationToken).ConfigureAwait(false); @@ -178,9 +172,9 @@ public async IAsyncEnumerable UpsertBatchAsync(string index, IEnumerable collectionName: index, target: textEmbedding, scoreThreshold: minRelevance, - requiredTags: requiredTags, limit: limit, withVectors: withEmbeddings, + requiredTags: requiredTags, cancellationToken: cancellationToken).ConfigureAwait(false); } catch (IndexNotFoundException e) @@ -207,14 +201,7 @@ public async IAsyncEnumerable GetListAsync( index = NormalizeIndexName(index); if (limit <= 0) { limit = int.MaxValue; } - // Remove empty filters - filters = filters?.Where(f => !f.IsEmpty()).ToList(); - - var requiredTags = new List>(); - if (filters is { Count: > 0 }) - { - 
requiredTags.AddRange(filters.Select(filter => filter.GetFilters().Select(x => $"{x.Key}{Constants.ReservedEqualsChar}{x.Value}"))); - } + var requiredTags = CreateRequiredTagsFromMemoryFilters(filters); List> results; try @@ -282,5 +269,39 @@ private static string NormalizeIndexName(string index) return index.Trim(); } + private static List> CreateRequiredTagsFromMemoryFilters(ICollection? filters) + { + var requiredTags = new List>(); + // Check if we have at least one non-empty filter + var nonEmptyFilters = filters?.Where(filters => !filters.IsEmpty()).ToArray() ?? Array.Empty(); + if (nonEmptyFilters.Length > 0) + { + foreach (var memoryFilter in nonEmptyFilters) + { + var filtersList = memoryFilter.GetFilters(); + List stringFilters = new(); + foreach (var baseFilter in filtersList) + { + if (baseFilter is EqualFilter eq) + { + stringFilters.Add(new TagFilter($"{eq.Key}{Constants.ReservedEqualsChar}{eq.Value}", TagFilterType.Equal)); + } + else if (baseFilter is NotEqualFilter neq) + { + stringFilters.Add(new TagFilter($"{neq.Key}{Constants.ReservedEqualsChar}{neq.Value}", TagFilterType.NotEqual)); + } + else + { + throw new QdrantException($"Filter of type {baseFilter.GetType().Name} is not supported by redis"); + } + } + + requiredTags.Add(stringFilters); + } + } + + return requiredTags; + } + #endregion } diff --git a/extensions/Redis/Redis.FunctionalTests/AdditionalFilteringTests.cs b/extensions/Redis/Redis.FunctionalTests/AdditionalFilteringTests.cs index 4c5396d81..db1c01cbf 100644 --- a/extensions/Redis/Redis.FunctionalTests/AdditionalFilteringTests.cs +++ b/extensions/Redis/Redis.FunctionalTests/AdditionalFilteringTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. 
using Microsoft.KernelMemory; using Microsoft.KernelMemory.DocumentStorage.DevTools; @@ -14,12 +14,40 @@ public class AdditionalFilteringTests : BaseFunctionalTestCase public AdditionalFilteringTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output) { - this._memory = new KernelMemoryBuilder() - .WithRedisMemoryDb(this.RedisConfig) - .WithSimpleFileStorage(new SimpleFileStorageConfig { StorageType = FileSystemTypes.Volatile, Directory = "_files" }) - .WithOpenAITextGeneration(this.OpenAiConfig) - .WithOpenAITextEmbeddingGeneration(this.OpenAiConfig) - .Build(); + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed identities so we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. 
+ Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + this._memory = new KernelMemoryBuilder() + .WithRedisMemoryDb(this.RedisConfig) + .WithSimpleFileStorage(new SimpleFileStorageConfig { StorageType = FileSystemTypes.Volatile, Directory = "_files" }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .Build(); + } + else + { + //use standard openai + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + this._memory = new KernelMemoryBuilder() + .WithRedisMemoryDb(this.RedisConfig) + .WithSimpleFileStorage(new SimpleFileStorageConfig { StorageType = FileSystemTypes.Volatile, Directory = "_files" }) + .WithOpenAITextGeneration(this.OpenAiConfig) + .WithOpenAITextEmbeddingGeneration(this.OpenAiConfig) + .Build(); + } } [Fact] diff --git a/extensions/Redis/Redis.FunctionalTests/DefaultTests.cs b/extensions/Redis/Redis.FunctionalTests/DefaultTests.cs index f6bdc004b..905adc92f 100644 --- a/extensions/Redis/Redis.FunctionalTests/DefaultTests.cs +++ b/extensions/Redis/Redis.FunctionalTests/DefaultTests.cs @@ -13,14 +13,41 @@ public class DefaultTests : BaseFunctionalTestCase public DefaultTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output) { - Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); - - this._memory = new KernelMemoryBuilder() - .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) - .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) - .WithOpenAI(this.OpenAiConfig) - .WithRedisMemoryDb(this.RedisConfig) - .Build(); + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed identities so we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. 
+ Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithRedisMemoryDb(this.RedisConfig) + .Build(); + } + else + { + //use standard openai + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + this._memory = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithOpenAI(this.OpenAiConfig) + .WithRedisMemoryDb(this.RedisConfig) + .Build(); + } } [Fact] diff --git a/extensions/Redis/Redis.FunctionalTests/RedisSpecificTests.cs b/extensions/Redis/Redis.FunctionalTests/RedisSpecificTests.cs index 09dc3b645..31d0fc783 100644 --- a/extensions/Redis/Redis.FunctionalTests/RedisSpecificTests.cs +++ b/extensions/Redis/Redis.FunctionalTests/RedisSpecificTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. 
using Microsoft.KernelMemory; using Microsoft.KM.TestHelpers; @@ -12,13 +12,39 @@ public class RedisSpecificTests : BaseFunctionalTestCase public RedisSpecificTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output) { - Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed identities so we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } - this._memory = new KernelMemoryBuilder() - .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) - .WithOpenAI(this.OpenAiConfig) - .WithRedisMemoryDb(this.RedisConfig) - .Build(); + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + this._memory = new KernelMemoryBuilder() + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithRedisMemoryDb(this.RedisConfig) + .Build(); + } + else + { + //use standard openai + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + this._memory = new KernelMemoryBuilder() + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithOpenAI(this.OpenAiConfig) + .WithRedisMemoryDb(this.RedisConfig) + .Build(); + } } [Fact] diff --git a/extensions/Redis/Redis.FunctionalTests/appsettings.json b/extensions/Redis/Redis.FunctionalTests/appsettings.json index 263215611..045a6891a 100644 --- a/extensions/Redis/Redis.FunctionalTests/appsettings.json +++ b/extensions/Redis/Redis.FunctionalTests/appsettings.json @@ 
-4,6 +4,7 @@ "Default": "Information" } }, + "UseAzureOpenAi" : false, "KernelMemory": { "Services": { "Redis": { diff --git a/extensions/Redis/Redis/RedisMemory.cs b/extensions/Redis/Redis/RedisMemory.cs index 60bd72ccd..2b0d04ba7 100644 --- a/extensions/Redis/Redis/RedisMemory.cs +++ b/extensions/Redis/Redis/RedisMemory.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System.Diagnostics.CodeAnalysis; using System.Globalization; @@ -176,21 +176,30 @@ public async Task UpsertAsync(string index, MemoryRecord record, Cancell }; var sb = new StringBuilder(); - if (filters != null && filters.Any(x => x.Pairs.Any())) + if (filters?.Any(f => !f.IsEmpty()) == true) { sb.Append('('); - foreach (var filter in filters) + //now filtering for each filter that is not empty + foreach (var filter in filters.Where(f => !f.IsEmpty())) { sb.Append('('); - foreach ((string key, string? value) in filter.Pairs) + var validFilters = filter.GetFilters(); + foreach (BaseFilter baseFilter in validFilters) { - if (value is null) + if (baseFilter is EqualFilter eq) { - this._logger.LogError("Attempted to perform null check on tag field. This behavior is not supported by Redis"); - throw new RedisException("Attempted to perform null check on tag field. This behavior is not supported by Redis"); + sb.Append(CultureInfo.InvariantCulture, $"@{eq.Key}:{{{eq.Value}}} "); + } + else if (baseFilter is NotEqualFilter neq) + { + //use the -(tag:value) syntax for not equal to avoid returning ANY record where ONE of the + //tag is the one we are filtering on. 
+ sb.Append(CultureInfo.InvariantCulture, $"-(@{neq.Key}:{{{neq.Value}}}) "); + } + else + { + throw new RedisException($"Filter of type {baseFilter.GetType().Name} is not supported by redis"); } - - sb.Append(CultureInfo.InvariantCulture, $"@{key}:{{{value}}} "); } sb.Replace(" ", ")|", sb.Length - 1, 1); diff --git a/extensions/SQLServer/SQLServer.FunctionalTests/DefaultTests.cs b/extensions/SQLServer/SQLServer.FunctionalTests/DefaultTests.cs index 41b8540bf..fac4a4d8f 100644 --- a/extensions/SQLServer/SQLServer.FunctionalTests/DefaultTests.cs +++ b/extensions/SQLServer/SQLServer.FunctionalTests/DefaultTests.cs @@ -15,23 +15,50 @@ public class DefaultTests : BaseFunctionalTestCase public DefaultTests(IConfiguration cfg, ITestOutputHelper output) : base(cfg, output) { - Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); - - SqlServerConfig sqlServerConfig = cfg.GetSection("KernelMemory:Services:SqlServer").Get()!; - - var builder = new KernelMemoryBuilder(); - - this._memory = builder - .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) - .Configure(kmb => kmb.Services.AddLogging(b => { b.AddConsole().SetMinimumLevel(LogLevel.Trace); })) - .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) - .WithOpenAI(this.OpenAiConfig) - // .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) - // .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) - .WithSqlServerMemoryDb(sqlServerConfig) - .Build(); + IKernelMemoryBuilder builder; + if (cfg.GetValue("UseAzureOpenAI")) + { + //ok in azure we can use managed identities so we need to check the configuration + if (this.AzureOpenAITextConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. 
+ Assert.False(string.IsNullOrEmpty(this.AzureOpenAITextConfiguration.APIKey)); + } + + if (this.AzureOpenAIEmbeddingConfiguration.Auth == AzureOpenAIConfig.AuthTypes.APIKey) + { + //verify that we really have an api key. + Assert.False(string.IsNullOrEmpty(this.AzureOpenAIEmbeddingConfiguration.APIKey)); + } + + SqlServerConfig sqlServerConfig = cfg.GetSection("KernelMemory:Services:SqlServer").Get()!; + + builder = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .Configure(kmb => kmb.Services.AddLogging(b => { b.AddConsole().SetMinimumLevel(LogLevel.Trace); })) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithSqlServerMemoryDb(sqlServerConfig); + } + else + { + Assert.False(string.IsNullOrEmpty(this.OpenAiConfig.APIKey)); + + SqlServerConfig sqlServerConfig = cfg.GetSection("KernelMemory:Services:SqlServer").Get()!; + + builder = new KernelMemoryBuilder() + .With(new KernelMemoryConfig { DefaultIndexName = "default4tests" }) + .Configure(kmb => kmb.Services.AddLogging(b => { b.AddConsole().SetMinimumLevel(LogLevel.Trace); })) + .WithSearchClientConfig(new SearchClientConfig { EmptyAnswer = NotFound }) + .WithOpenAI(this.OpenAiConfig) + // .WithAzureOpenAITextGeneration(this.AzureOpenAITextConfiguration) + // .WithAzureOpenAITextEmbeddingGeneration(this.AzureOpenAIEmbeddingConfiguration) + .WithSqlServerMemoryDb(sqlServerConfig); + } var serviceProvider = builder.Services.BuildServiceProvider(); + this._memory = builder.Build(); this._memoryDb = serviceProvider.GetService()!; } diff --git a/extensions/SQLServer/SQLServer/SqlServerMemory.cs b/extensions/SQLServer/SQLServer/SqlServerMemory.cs index 704c9290e..4890f2571 100644 --- a/extensions/SQLServer/SQLServer/SqlServerMemory.cs +++ 
b/extensions/SQLServer/SQLServer/SqlServerMemory.cs @@ -682,54 +682,92 @@ private string GenerateFilters( { var filterBuilder = new StringBuilder(); - if (filters is null || filters.Count <= 0 || filters.All(f => f.Count <= 0)) + if (filters?.Any(f => !f.IsEmpty()) == true) { - return string.Empty; - } - - filterBuilder.Append("AND ( "); - - for (int i = 0; i < filters.Count; i++) - { - var filter = filters.ElementAt(i); + filterBuilder.Append("AND ( "); - if (i > 0) + var nonEmptyFilter = filters.Where(f => !f.IsEmpty()).ToArray(); + for (int i = 0; i < nonEmptyFilter.Length; i++) { - filterBuilder.Append(" OR "); - } - - for (int j = 0; j < filter.Pairs.Count(); j++) - { - var value = filter.Pairs.ElementAt(j); + var filter = nonEmptyFilter[i]; - if (j > 0) + if (i > 0) { - filterBuilder.Append(" AND "); + filterBuilder.Append(" OR "); } - filterBuilder.Append(" ( "); - - filterBuilder.Append(CultureInfo.CurrentCulture, $@"EXISTS ( - SELECT - 1 - FROM {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tags] - WHERE - [tags].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id] - AND [name] = @filter_{i}_{j}_name - AND [value] = @filter_{i}_{j}_value - ) - "); - - filterBuilder.Append(" ) "); - - parameters.AddWithValue($"@filter_{i}_{j}_name", value.Key); - parameters.AddWithValue($"@filter_{i}_{j}_value", value.Value); + var validFilters = filter.GetFilters().ToArray(); + for (int j = 0; j < validFilters.Length; j++) + { + var baseFilter = validFilters[j]; + + if (j > 0) + { + filterBuilder.Append(" AND "); + } + + filterBuilder.Append(" ( "); + + if (baseFilter is EqualFilter eq) + { + filterBuilder.Append(CultureInfo.CurrentCulture, $@" + EXISTS ( + SELECT + 1 + FROM {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tags] + WHERE + [tags].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id] + AND [name] = @filter_{i}_{j}_name + AND [value] = @filter_{i}_{j}_value + ) +" + ); + + //now 
that I've appended a clause, I need to add the parameters + + parameters.AddWithValue($"@filter_{i}_{j}_name", eq.Key); + parameters.AddWithValue($"@filter_{i}_{j}_value", eq.Value); + } + else if (baseFilter is NotEqualFilter neq) + { + //We change only EXISTS to NOT EXISTS but we left the structure of the code with a sequence + //of else if for future expansino of the query syntax where each filter will have its own clause + filterBuilder.Append(CultureInfo.CurrentCulture, $@" + NOT EXISTS ( + SELECT + 1 + FROM {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tags] + WHERE + [tags].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id] + AND [name] = @filter_{i}_{j}_name + AND [value] = @filter_{i}_{j}_value + ) +" + ); + + //now that I've appended a clause, I need to add the parameters + + parameters.AddWithValue($"@filter_{i}_{j}_name", neq.Key); + parameters.AddWithValue($"@filter_{i}_{j}_value", neq.Value); + } + else + { + throw new SqlServerMemoryException($"Filter of type {baseFilter.GetType().Name} is not supported by redis"); + } + + filterBuilder.Append(" ) "); + } } - } - filterBuilder.Append(" )"); + filterBuilder.Append(" )"); - return filterBuilder.ToString(); + return filterBuilder.ToString(); + } + else + { + //no active filter + return string.Empty; + } } private async Task ReadEntryAsync(SqlDataReader dataReader, bool withEmbedding, CancellationToken cancellationToken = default) diff --git a/service/Abstractions/Models/MemoryFilter.cs b/service/Abstractions/Models/MemoryFilter.cs index d3ee2a1ed..a10d122b8 100644 --- a/service/Abstractions/Models/MemoryFilter.cs +++ b/service/Abstractions/Models/MemoryFilter.cs @@ -1,14 +1,21 @@ // Copyright (c) Microsoft. All rights reserved. 
using System.Collections.Generic; +using System.Linq; namespace Microsoft.KernelMemory; public class MemoryFilter : TagCollection { + /// + /// This collection of tags contains all the tags that are used to + /// negatively filter out memory records. + /// + private readonly TagCollection _notTags = new(); + public bool IsEmpty() { - return this.Count == 0; + return this.Count == 0 && this._notTags.Count == 0; } public MemoryFilter ByTag(string name, string value) @@ -17,18 +24,57 @@ public MemoryFilter ByTag(string name, string value) return this; } + public MemoryFilter ByNotTag(string name, string value) + { + this._notTags.Add(name, value); + return this; + } + public MemoryFilter ByDocument(string docId) { this.Add(Constants.ReservedDocumentIdTag, docId); return this; } - public IEnumerable> GetFilters() + /// + /// Get a composition of all filters, And and Not. + /// + /// + public IEnumerable GetFilters() { - return this.ToKeyValueList(); + var equalFilters = this.Pairs + .Where(f => !string.IsNullOrEmpty(f.Value)) + .Select(pair => (BaseFilter)new EqualFilter(pair.Key, pair.Value!)); + + var notEqualFilters = this._notTags.Pairs + .Where(f => !string.IsNullOrEmpty(f.Value)) + .Select(pair => (BaseFilter)new NotEqualFilter(pair.Key, pair.Value!)); + + return equalFilters.Union(notEqualFilters); } } +/// +/// This is the base filter, which is used to create different types of filters +/// +/// +/// +public record BaseFilter(string Key, string Value); + +/// +/// Filter for equality, tag named <paramref name="Key"/> must have the value <paramref name="Value"/> +/// +/// +/// +public record EqualFilter(string Key, string Value) : BaseFilter(Key, Value); + +/// +/// Filter for inequality, tag named <paramref name="Key"/> must not have the value <paramref name="Value"/> +/// +/// +/// +public record NotEqualFilter(string Key, string Value) : BaseFilter(Key, Value); + /// /// Factory for <see cref="MemoryFilter"/>, to allow for a simpler syntax /// Instead of: new MemoryFilter().ByDocument(id).ByTag(k, v) @@ -41,6 +87,18 @@ public static MemoryFilter ByTag(string name, string value) 
return new MemoryFilter().ByTag(name, value); } + /// + /// Filter for all memory records that do not have the specified tag with that + /// specific value. + /// + /// + /// + /// + public static MemoryFilter ByNotTag(string name, string value) + { + return new MemoryFilter().ByNotTag(name, value); + } + public static MemoryFilter ByDocument(string docId) { return new MemoryFilter().ByDocument(docId); diff --git a/service/Abstractions/Pipeline/MimeTypes.cs b/service/Abstractions/Pipeline/MimeTypes.cs index 59cc8f5c6..9d8cf34d3 100644 --- a/service/Abstractions/Pipeline/MimeTypes.cs +++ b/service/Abstractions/Pipeline/MimeTypes.cs @@ -72,6 +72,7 @@ public static class MimeTypes public const string ArchiveZip = "application/zip"; public const string ArchiveRar = "application/vnd.rar"; public const string Archive7Zip = "application/x-7z-compressed"; + public const string OctetStream = "application/octet-stream"; } public static class FileExtensions @@ -136,6 +137,7 @@ public static class FileExtensions public const string ArchiveZip = ".zip"; public const string ArchiveRar = ".rar"; public const string Archive7Zip = ".7z"; + public const string Bin = ".bin"; } public interface IMimeTypeDetection @@ -209,6 +211,7 @@ public class MimeTypesDetection : IMimeTypeDetection { FileExtensions.ArchiveZip, MimeTypes.ArchiveZip }, { FileExtensions.ArchiveRar, MimeTypes.ArchiveRar }, { FileExtensions.Archive7Zip, MimeTypes.Archive7Zip }, + { FileExtensions.Bin, MimeTypes.OctetStream } }; public string GetFileType(string filename) diff --git a/service/Core/MemoryStorage/DevTools/SimpleTextDb.cs b/service/Core/MemoryStorage/DevTools/SimpleTextDb.cs index 07c05c4b3..b37dcbe10 100644 --- a/service/Core/MemoryStorage/DevTools/SimpleTextDb.cs +++ b/service/Core/MemoryStorage/DevTools/SimpleTextDb.cs @@ -204,12 +204,22 @@ private static bool TagsMatchFilters(TagCollection tags, ICollection> condition in filter) + var allFilters = filter.GetFilters(); + foreach (var baseFilter in 
allFilters) { - // Check if the tag name + value is present - for (int index = 0; match && index < condition.Value.Count; index++) + if (baseFilter is EqualFilter eq) { - match = match && (tags.ContainsKey(condition.Key) && tags[condition.Key].Contains(condition.Value[index])); + // Verify that it contains the required key and value + match = match && tags.ContainsKey(eq.Key) && tags[eq.Key].Contains(eq.Value); + } + else if (baseFilter is NotEqualFilter neq) + { + // Verify that the tag is not contained at all, or if it is present it doesn't contain the value + match = match && (!tags.ContainsKey(neq.Key) || !tags[neq.Key].Contains(neq.Value)); + } + else + { + throw new ArgumentException($"Unknown filter type {baseFilter.GetType()}"); } } diff --git a/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs b/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs index 0d421dfcb..c529f8b0e 100644 --- a/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs +++ b/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs @@ -212,12 +212,22 @@ private static bool TagsMatchFilters(TagCollection tags, ICollection> condition in filter) + var allFilters = filter.GetFilters(); + foreach (var baseFilter in allFilters) { - // Check if the tag name + value is present - for (int index = 0; match && index < condition.Value.Count; index++) + if (baseFilter is EqualFilter eq) { - match = match && (tags.ContainsKey(condition.Key) && tags[condition.Key].Contains(condition.Value[index])); + // Verify that it contains the required key and value + match = match && tags.ContainsKey(eq.Key) && tags[eq.Key].Contains(eq.Value); + } + else if (baseFilter is NotEqualFilter neq) + { + // Verify that the tag is not contained at all, or if it is present it doesn't contain the value + match = match && (!tags.ContainsKey(neq.Key) || !tags[neq.Key].Contains(neq.Value)); + } + else + { + throw new ArgumentException($"Unknown filter type {baseFilter.GetType()}"); } } diff --git 
a/service/tests/Core.FunctionalTests/DefaultTestCases/FilteringTest.cs b/service/tests/Core.FunctionalTests/DefaultTestCases/FilteringTest.cs index df24307ca..76ed89ec1 100644 --- a/service/tests/Core.FunctionalTests/DefaultTestCases/FilteringTest.cs +++ b/service/tests/Core.FunctionalTests/DefaultTestCases/FilteringTest.cs @@ -55,6 +55,21 @@ await memory.ImportDocumentAsync( log(answer.Result); Assert.Contains(Found, answer.Result, StringComparison.OrdinalIgnoreCase); + // Simple filter: NOT the news. + answer = await memory.AskAsync("What is Orion?", filter: MemoryFilters.ByNotTag("type", "news"), index: indexName); + log(answer.Result); + Assert.Contains(NotFound, answer.Result, StringComparison.OrdinalIgnoreCase); + + // Simple filter: the memory is of the user but we do not want to use memory of that user. + answer = await memory.AskAsync("What is Orion?", filter: MemoryFilters.ByNotTag("user", "owner"), index: indexName); + log(answer.Result); + Assert.Contains(NotFound, answer.Result, StringComparison.OrdinalIgnoreCase); + + // Combined filter: equality on one tag AND inequality on another tag + answer = await memory.AskAsync("What is Orion?", filter: MemoryFilters.ByTag("user", "owner").ByNotTag("type", "news"), index: indexName); + log(answer.Result); + Assert.Contains(NotFound, answer.Result, StringComparison.OrdinalIgnoreCase); + // Simple filter: test AND logic with correct values answer = await memory.AskAsync("What is Orion?", filter: MemoryFilters.ByTag("type", "news").ByTag("user", "owner"), index: indexName); log(answer.Result); @@ -100,6 +115,25 @@ await memory.ImportDocumentAsync( log(answer.Result); Assert.Contains(NotFound, answer.Result, StringComparison.OrdinalIgnoreCase); + // Multiple filters: unknown users can see the memory if it is not of a secured type + answer = await memory.AskAsync("What is Orion?", filters: new List + { + MemoryFilters.ByTag("user", "someone1"), + MemoryFilters.ByTag("user", "someone2"), + MemoryFilters.ByNotTag("type", "securenews"), 
}, index: indexName); + log(answer.Result); + Assert.Contains(Found, answer.Result, StringComparison.OrdinalIgnoreCase); + + //Multiple filters: exclude two users (OR of not-equal conditions on the same tag). + answer = await memory.AskAsync("What is Orion?", filters: new List + { + MemoryFilters.ByNotTag("user", "someone1"), + MemoryFilters.ByNotTag("user", "someone2") + }, index: indexName); + log(answer.Result); + Assert.Contains(Found, answer.Result, StringComparison.OrdinalIgnoreCase); + // Multiple filters: unknown users cannot see the memory even if the type is correct (testing AND logic) answer = await memory.AskAsync("What is Orion?", filters: new List {