From dcc3c76d56de9ada83fb831f3997c2784fca8afe Mon Sep 17 00:00:00 2001
From: Stephen Toub
Date: Thu, 8 May 2025 14:47:52 -0400
Subject: [PATCH] Avoid caching in CachingChatClient when ConversationId is set

---
 .../ChatCompletion/CachingChatClient.cs | 32 +++++++++++++++--
 .../DistributedCachingChatClientTest.cs | 36 +++++++++++--------
 2 files changed, 50 insertions(+), 18 deletions(-)

diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs
index 5aa70e4b262..211fc39ec85 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs
@@ -9,6 +9,7 @@
 using Microsoft.Shared.Diagnostics;
 
 #pragma warning disable S127 // "for" loop stop conditions should be invariant
+#pragma warning disable SA1202 // Elements should be ordered by access
 
 namespace Microsoft.Extensions.AI;
 
@@ -45,11 +46,19 @@ protected CachingChatClient(IChatClient innerClient)
     public bool CoalesceStreamingUpdates { get; set; } = true;
 
     /// <inheritdoc />
-    public override async Task<ChatResponse> GetResponseAsync(
+    public override Task<ChatResponse> GetResponseAsync(
         IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
     {
         _ = Throw.IfNull(messages);
 
+        return UseCaching(options) ?
+            GetCachedResponseAsync(messages, options, cancellationToken) :
+            base.GetResponseAsync(messages, options, cancellationToken);
+    }
+
+    private async Task<ChatResponse> GetCachedResponseAsync(
+        IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+    {
         // We're only storing the final result, not the in-flight task, so that we can avoid caching failures
         // or having problems when one of the callers cancels but others don't. This has the drawback that
         // concurrent callers might trigger duplicate requests, but that's acceptable.
@@ -65,11 +74,19 @@ public override async Task<ChatResponse> GetResponseAsync(
     }
 
     /// <inheritdoc />
-    public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
-        IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    public override IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
+        IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
     {
         _ = Throw.IfNull(messages);
 
+        return UseCaching(options) ?
+            GetCachedStreamingResponseAsync(messages, options, cancellationToken) :
+            base.GetStreamingResponseAsync(messages, options, cancellationToken);
+    }
+
+    private async IAsyncEnumerable<ChatResponseUpdate> GetCachedStreamingResponseAsync(
+        IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
         if (CoalesceStreamingUpdates)
         {
             // When coalescing updates, we cache non-streaming results coalesced from streaming ones. That means
@@ -178,4 +195,13 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
     /// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
     /// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
     protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList<ChatResponseUpdate> value, CancellationToken cancellationToken);
+
+    /// <summary>Determine whether to use caching with the request.</summary>
+    private static bool UseCaching(ChatOptions? options)
+    {
+        // We want to skip caching if options.ConversationId is set. If it's set, that implies there's
+        // some state that will impact the response and that's not represented in the messages. Since
+        // that state could change even with the same ID, we have to assume caching isn't valid.
+        return options?.ConversationId is null;
+    }
 }
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
index 374e617adba..4f2427d133c 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
@@ -32,10 +32,13 @@ public void Ctor_ExpectedDefaults()
         Assert.True(cachingClient.CoalesceStreamingUpdates);
     }
 
-    [Fact]
-    public async Task CachesSuccessResultsAsync()
+    [Theory]
+    [InlineData(false)]
+    [InlineData(true)]
+    public async Task CachesSuccessResultsAsync(bool conversationIdSet)
     {
         // Arrange
+        ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };
 
         // Verify that all the expected properties will round-trip through the cache,
         // even if this involves serialization
@@ -82,20 +85,20 @@ public async Task CachesSuccessResultsAsync()
         };
 
         // Make the initial request and do a quick sanity check
-        var result1 = await outer.GetResponseAsync("some input");
+        var result1 = await outer.GetResponseAsync("some input", options);
         Assert.Same(expectedResponse, result1);
         Assert.Equal(1, innerCallCount);
 
         // Act
-        var result2 = await outer.GetResponseAsync("some input");
+        var result2 = await outer.GetResponseAsync("some input", options);
 
         // Assert
-        Assert.Equal(1, innerCallCount);
+        Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
         AssertResponsesEqual(expectedResponse, result2);
 
         // Act/Assert 2: Cache misses do not return cached results
-        await outer.GetResponseAsync("some modified input");
-        Assert.Equal(2, innerCallCount);
+        await outer.GetResponseAsync("some modified input", options);
+        Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
     }
 
     [Fact]
@@ -207,10 +210,13 @@ public async Task DoesNotCacheCanceledResultsAsync()
         Assert.Equal("A good result", result2.Text);
     }
 
-    [Fact]
-    public async Task StreamingCachesSuccessResultsAsync()
+    [Theory]
+    [InlineData(false)]
+    [InlineData(true)]
+    public async Task StreamingCachesSuccessResultsAsync(bool conversationIdSet)
     {
         // Arrange
+        ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };
 
         // Verify that all the expected properties will round-trip through the cache,
         // even if this involves serialization
@@ -255,20 +261,20 @@ public async Task StreamingCachesSuccessResultsAsync()
         };
 
         // Make the initial request and do a quick sanity check
-        var result1 = outer.GetStreamingResponseAsync("some input");
+        var result1 = outer.GetStreamingResponseAsync("some input", options);
         await AssertResponsesEqualAsync(actualUpdate, result1);
         Assert.Equal(1, innerCallCount);
 
         // Act
-        var result2 = outer.GetStreamingResponseAsync("some input");
+        var result2 = outer.GetStreamingResponseAsync("some input", options);
 
         // Assert
-        Assert.Equal(1, innerCallCount);
-        await AssertResponsesEqualAsync(expectedCachedResponse, result2);
+        Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
+        await AssertResponsesEqualAsync(conversationIdSet ? actualUpdate : expectedCachedResponse, result2);
 
         // Act/Assert 2: Cache misses do not return cached results
-        await ToListAsync(outer.GetStreamingResponseAsync("some modified input"));
-        Assert.Equal(2, innerCallCount);
+        await ToListAsync(outer.GetStreamingResponseAsync("some modified input", options));
+        Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
     }
 
     [Theory]