diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ResponseCachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ResponseCachingChatClient.cs index 79baf3be88c..bc49d76e1be 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ResponseCachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ResponseCachingChatClient.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; @@ -13,7 +12,6 @@ namespace Microsoft.Extensions.AI.Evaluation.Reporting; internal sealed class ResponseCachingChatClient : DistributedCachingChatClient { - private readonly IReadOnlyList _cachingKeys; private readonly ChatDetails _chatDetails; private readonly ConcurrentDictionary _stopWatches; @@ -24,7 +22,7 @@ internal ResponseCachingChatClient( ChatDetails chatDetails) : base(originalChatClient, cache) { - _cachingKeys = [.. cachingKeys]; + CacheKeyAdditionalValues = [.. cachingKeys]; _chatDetails = chatDetails; _stopWatches = new ConcurrentDictionary(); } @@ -124,7 +122,4 @@ protected override async Task WriteCacheStreamingAsync( cacheHit: false)); } } - - protected override string GetCacheKey(IEnumerable messages, ChatOptions? options, params ReadOnlySpan additionalValues) - => base.GetCacheKey(messages, options, [.. additionalValues, .. _cachingKeys]); } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs index 9fb586f5b79..afaa12235ec 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs @@ -2,7 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -34,9 +36,16 @@ namespace Microsoft.Extensions.AI; /// public class DistributedCachingChatClient : CachingChatClient { + /// Boxed cache version. + /// Bump the cache version to invalidate existing caches if the serialization format changes in a breaking way. + private static readonly object _cacheVersion = 2; + /// The instance that will be used as the backing store for the cache. private readonly IDistributedCache _storage; + /// Additional values used to inform the cache key employed for storing state. + private object[]? _cacheKeyAdditionalValues; + /// The to use when serializing cache data. private JsonSerializerOptions _jsonSerializerOptions = AIJsonUtilities.DefaultOptions; @@ -56,6 +65,14 @@ public JsonSerializerOptions JsonSerializerOptions set => _jsonSerializerOptions = Throw.IfNull(value); } + /// Gets or sets additional values used to inform the cache key employed for storing state. + /// Any values set in this list will augment the other values used to inform the cache key. + public IReadOnlyList? CacheKeyAdditionalValues + { + get => _cacheKeyAdditionalValues; + set => _cacheKeyAdditionalValues = value?.ToArray(); + } + /// protected override async Task ReadCacheAsync(string key, CancellationToken cancellationToken) { @@ -122,9 +139,26 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList /// protected override string GetCacheKey(IEnumerable messages, ChatOptions? options, params ReadOnlySpan additionalValues) { - // Bump the cache version to invalidate existing caches if the serialization format changes in a breaking way. - const int CacheVersion = 2; + const int FixedValuesCount = 3; + + object[] clientValues = _cacheKeyAdditionalValues ?? Array.Empty(); + int length = FixedValuesCount + additionalValues.Length + clientValues.Length; - return AIJsonUtilities.HashDataToString([CacheVersion, messages, options, .. additionalValues], _jsonSerializerOptions); + object?[] arr = ArrayPool.Shared.Rent(length); + try + { + arr[0] = _cacheVersion; + arr[1] = messages; + arr[2] = options; + additionalValues.CopyTo(arr.AsSpan(FixedValuesCount)); + clientValues.CopyTo(arr, FixedValuesCount + additionalValues.Length); + + return AIJsonUtilities.HashDataToString(arr.AsSpan(0, length), _jsonSerializerOptions); + } + finally + { + Array.Clear(arr, 0, length); + ArrayPool.Shared.Return(arr); + } } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs index 2c880d7a22c..926378ad517 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs @@ -54,6 +54,11 @@ public override async Task> GenerateAsync( Throw.InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}."); } + if (generated[0] is null) + { + Throw.InvalidOperationException("Generator produced null embedding."); + } + await WriteCacheAsync(cacheKey, generated[0], cancellationToken); return generated; } diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs index cd26879d040..7da9671554b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs @@ -2,6 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Text.Json.Serialization.Metadata; using System.Threading; @@ -24,7 +27,17 @@ namespace Microsoft.Extensions.AI; public class DistributedCachingEmbeddingGenerator : CachingEmbeddingGenerator where TEmbedding : Embedding { + /// Boxed cache version. + /// Bump the cache version to invalidate existing caches if the serialization format changes in a breaking way. + private static readonly object _cacheVersion = 2; + + /// The instance that will be used as the backing store for the cache. private readonly IDistributedCache _storage; + + /// Additional values used to inform the cache key employed for storing state. + private object[]? _cacheKeyAdditionalValues; + + /// Additional cache key values used to inform the key employed for storing state. private JsonSerializerOptions _jsonSerializerOptions; /// Initializes a new instance of the class. @@ -51,6 +64,14 @@ public JsonSerializerOptions JsonSerializerOptions } } + /// Gets or sets additional values used to inform the cache key employed for storing state. + /// Any values set in this list will augment the other values used to inform the cache key. + public IReadOnlyList? CacheKeyAdditionalValues + { + get => _cacheKeyAdditionalValues; + set => _cacheKeyAdditionalValues = value?.ToArray(); + } + /// protected override async Task ReadCacheAsync(string key, CancellationToken cancellationToken) { @@ -87,6 +108,26 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc /// The generated cache key is not guaranteed to be stable across releases of the library. /// /// - protected override string GetCacheKey(params ReadOnlySpan values) => - AIJsonUtilities.HashDataToString(values, _jsonSerializerOptions); + protected override string GetCacheKey(params ReadOnlySpan values) + { + const int FixedValuesCount = 1; + + object[] clientValues = _cacheKeyAdditionalValues ?? Array.Empty(); + int length = FixedValuesCount + clientValues.Length + values.Length; + + object?[] arr = ArrayPool.Shared.Rent(length); + try + { + arr[0] = _cacheVersion; + values.CopyTo(arr.AsSpan(FixedValuesCount)); + clientValues.CopyTo(arr, FixedValuesCount + values.Length); + + return AIJsonUtilities.HashDataToString(arr.AsSpan(0, length), _jsonSerializerOptions); + } + finally + { + Array.Clear(arr, 0, length); + ArrayPool.Shared.Return(arr); + } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json index f7f246eb35c..4f4317c9978 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json @@ -310,6 +310,10 @@ } ], "Properties": [ + { + "Member": "System.Collections.Generic.IReadOnlyList? Microsoft.Extensions.AI.DistributedCachingChatClient.CacheKeyAdditionalValues { get; set; }", + "Stage": "Stable" + }, { "Member": "System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.DistributedCachingChatClient.JsonSerializerOptions { get; set; }", "Stage": "Stable" @@ -351,6 +355,10 @@ } ], "Properties": [ + { + "Member": "System.Collections.Generic.IReadOnlyList? Microsoft.Extensions.AI.DistributedCachingEmbeddingGenerator.CacheKeyAdditionalValues { get; set; }", + "Stage": "Stable" + }, { "Member": "System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.DistributedCachingEmbeddingGenerator.JsonSerializerOptions { get; set; }", "Stage": "Stable" diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs index 2c755da7be9..4b6f9bc87e6 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs @@ -595,6 +595,52 @@ public async Task CacheKeyVariesByChatOptionsAsync() Assert.Equal("value 2", result4.Text); } + [Fact] + public async Task CacheKeyVariesByAdditionalKeyValuesAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + GetResponseAsyncCallback = async (_, options, _) => + { + innerCallCount++; + await Task.Yield(); + return new(new ChatMessage(ChatRole.Assistant, innerCallCount.ToString())); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + var result1 = await outer.GetResponseAsync([]); + var result2 = await outer.GetResponseAsync([]); + + Assert.Equal(1, innerCallCount); + Assert.Equal("1", result1.Text); + Assert.Equal("1", result2.Text); + + // Change key + outer.CacheKeyAdditionalValues = ["extraKey"]; + + var result3 = await outer.GetResponseAsync([]); + var result4 = await outer.GetResponseAsync([]); + + Assert.Equal(2, innerCallCount); + Assert.Equal("2", result3.Text); + Assert.Equal("2", result4.Text); + + // Remove key + outer.CacheKeyAdditionalValues = []; + + var result5 = await outer.GetResponseAsync([]); + + Assert.Equal(2, innerCallCount); + Assert.Equal("1", result5.Text); + } + [Fact] public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs index 6153ec8ab45..b14d3de83a9 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs @@ -21,6 +21,24 @@ public class DistributedCachingEmbeddingGeneratorTest AdditionalProperties = new() { ["a"] = "b" }, }; + [Fact] + public void Properties_Roundtrip() + { + using var innerGenerator = new TestEmbeddingGenerator(); + using DistributedCachingEmbeddingGenerator> generator = new(innerGenerator, _storage); + + Assert.Same(AIJsonUtilities.DefaultOptions, generator.JsonSerializerOptions); + var jso = new JsonSerializerOptions(); + generator.JsonSerializerOptions = jso; + Assert.Same(jso, generator.JsonSerializerOptions); + + Assert.Null(generator.CacheKeyAdditionalValues); + var additionalValues = new[] { "value1", "value2" }; + generator.CacheKeyAdditionalValues = additionalValues; + Assert.NotSame(additionalValues, generator.CacheKeyAdditionalValues); + Assert.Equal(additionalValues, generator.CacheKeyAdditionalValues); + } + [Fact] public async Task CachesSuccessResultsAsync() { @@ -271,6 +289,49 @@ public async Task CacheKeyVariesByEmbeddingOptionsAsync() AssertEmbeddingsEqual(new("value 2".Select(c => (float)c).ToArray()), result4); } + [Fact] + public async Task CacheKeyVariesByAdditionalKeyValuesAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = async (value, options, cancellationToken) => + { + innerCallCount++; + await Task.Yield(); + return new(new Embedding[] { new Embedding(new float[] { innerCallCount }) }); + } + }; + using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + var result1 = await outer.GenerateAsync("abc"); + var result2 = await outer.GenerateAsync("abc"); + AssertEmbeddingsEqual(result1, result2); + + var result3 = await outer.GenerateAsync("abc"); + AssertEmbeddingsEqual(result1, result3); + + // Change key + outer.CacheKeyAdditionalValues = ["extraKey"]; + + var result4 = await outer.GenerateAsync("abc"); + Assert.NotEqual(result1.Vector.ToArray(), result4.Vector.ToArray()); + + var result5 = await outer.GenerateAsync("abc"); + AssertEmbeddingsEqual(result4, result5); + + // Remove key + outer.CacheKeyAdditionalValues = []; + + var result6 = await outer.GenerateAsync("abc"); + AssertEmbeddingsEqual(result1, result6); + } + [Fact] public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync() {