diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs
index 5aa70e4b262..211fc39ec85 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs
@@ -9,6 +9,7 @@
using Microsoft.Shared.Diagnostics;
#pragma warning disable S127 // "for" loop stop conditions should be invariant
+#pragma warning disable SA1202 // Elements should be ordered by access
namespace Microsoft.Extensions.AI;
@@ -45,11 +46,19 @@ protected CachingChatClient(IChatClient innerClient)
public bool CoalesceStreamingUpdates { get; set; } = true;
/// <inheritdoc/>
- public override async Task<ChatResponse> GetResponseAsync(
+ public override Task<ChatResponse> GetResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(messages);
+ return UseCaching(options) ?
+ GetCachedResponseAsync(messages, options, cancellationToken) :
+ base.GetResponseAsync(messages, options, cancellationToken);
+ }
+
+ private async Task<ChatResponse> GetCachedResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+ {
// We're only storing the final result, not the in-flight task, so that we can avoid caching failures
// or having problems when one of the callers cancels but others don't. This has the drawback that
// concurrent callers might trigger duplicate requests, but that's acceptable.
@@ -65,11 +74,19 @@ public override async Task GetResponseAsync(
}
/// <inheritdoc/>
- public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
- IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public override IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(messages);
+ return UseCaching(options) ?
+ GetCachedStreamingResponseAsync(messages, options, cancellationToken) :
+ base.GetStreamingResponseAsync(messages, options, cancellationToken);
+ }
+
+ private async IAsyncEnumerable<ChatResponseUpdate> GetCachedStreamingResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
if (CoalesceStreamingUpdates)
{
// When coalescing updates, we cache non-streaming results coalesced from streaming ones. That means
@@ -178,4 +195,13 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
/// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList<ChatResponseUpdate> value, CancellationToken cancellationToken);
+
+ /// <summary>Determine whether to use caching with the request.</summary>
+ private static bool UseCaching(ChatOptions? options)
+ {
+ // We want to skip caching if options.ConversationId is set. If it's set, that implies there's
+ // some state that will impact the response and that's not represented in the messages. Since
+ // that state could change even with the same ID, we have to assume caching isn't valid.
+ return options?.ConversationId is null;
+ }
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
index 374e617adba..4f2427d133c 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
@@ -32,10 +32,13 @@ public void Ctor_ExpectedDefaults()
Assert.True(cachingClient.CoalesceStreamingUpdates);
}
- [Fact]
- public async Task CachesSuccessResultsAsync()
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public async Task CachesSuccessResultsAsync(bool conversationIdSet)
{
// Arrange
+ ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };
// Verify that all the expected properties will round-trip through the cache,
// even if this involves serialization
@@ -82,20 +85,20 @@ public async Task CachesSuccessResultsAsync()
};
// Make the initial request and do a quick sanity check
- var result1 = await outer.GetResponseAsync("some input");
+ var result1 = await outer.GetResponseAsync("some input", options);
Assert.Same(expectedResponse, result1);
Assert.Equal(1, innerCallCount);
// Act
- var result2 = await outer.GetResponseAsync("some input");
+ var result2 = await outer.GetResponseAsync("some input", options);
// Assert
- Assert.Equal(1, innerCallCount);
+ Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
AssertResponsesEqual(expectedResponse, result2);
// Act/Assert 2: Cache misses do not return cached results
- await outer.GetResponseAsync("some modified input");
- Assert.Equal(2, innerCallCount);
+ await outer.GetResponseAsync("some modified input", options);
+ Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
}
[Fact]
@@ -207,10 +210,13 @@ public async Task DoesNotCacheCanceledResultsAsync()
Assert.Equal("A good result", result2.Text);
}
- [Fact]
- public async Task StreamingCachesSuccessResultsAsync()
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public async Task StreamingCachesSuccessResultsAsync(bool conversationIdSet)
{
// Arrange
+ ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };
// Verify that all the expected properties will round-trip through the cache,
// even if this involves serialization
@@ -255,20 +261,20 @@ public async Task StreamingCachesSuccessResultsAsync()
};
// Make the initial request and do a quick sanity check
- var result1 = outer.GetStreamingResponseAsync("some input");
+ var result1 = outer.GetStreamingResponseAsync("some input", options);
await AssertResponsesEqualAsync(actualUpdate, result1);
Assert.Equal(1, innerCallCount);
// Act
- var result2 = outer.GetStreamingResponseAsync("some input");
+ var result2 = outer.GetStreamingResponseAsync("some input", options);
// Assert
- Assert.Equal(1, innerCallCount);
- await AssertResponsesEqualAsync(expectedCachedResponse, result2);
+ Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
+ await AssertResponsesEqualAsync(conversationIdSet ? actualUpdate : expectedCachedResponse, result2);
// Act/Assert 2: Cache misses do not return cached results
- await ToListAsync(outer.GetStreamingResponseAsync("some modified input"));
- Assert.Equal(2, innerCallCount);
+ await ToListAsync(outer.GetStreamingResponseAsync("some modified input", options));
+ Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
}
[Theory]