Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.Shared.Diagnostics;

#pragma warning disable S127 // "for" loop stop conditions should be invariant
#pragma warning disable SA1202 // Elements should be ordered by access

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -45,11 +46,19 @@ protected CachingChatClient(IChatClient innerClient)
public bool CoalesceStreamingUpdates { get; set; } = true;

/// <inheritdoc />
public override async Task<ChatResponse> GetResponseAsync(
public override Task<ChatResponse> GetResponseAsync(
    IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
    _ = Throw.IfNull(messages);

    // Bypass the cache entirely when caching isn't applicable for these options;
    // otherwise route through the caching path.
    if (!UseCaching(options))
    {
        return base.GetResponseAsync(messages, options, cancellationToken);
    }

    return GetCachedResponseAsync(messages, options, cancellationToken);
}

private async Task<ChatResponse> GetCachedResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
// We're only storing the final result, not the in-flight task, so that we can avoid caching failures
// or having problems when one of the callers cancels but others don't. This has the drawback that
// concurrent callers might trigger duplicate requests, but that's acceptable.
Expand All @@ -65,11 +74,19 @@ public override async Task<ChatResponse> GetResponseAsync(
}

/// <inheritdoc />
public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public override IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
    IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
    _ = Throw.IfNull(messages);

    // Bypass the cache entirely when caching isn't applicable for these options;
    // otherwise route through the caching path.
    if (!UseCaching(options))
    {
        return base.GetStreamingResponseAsync(messages, options, cancellationToken);
    }

    return GetCachedStreamingResponseAsync(messages, options, cancellationToken);
}

private async IAsyncEnumerable<ChatResponseUpdate> GetCachedStreamingResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (CoalesceStreamingUpdates)
{
// When coalescing updates, we cache non-streaming results coalesced from streaming ones. That means
Expand Down Expand Up @@ -178,4 +195,13 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
/// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList<ChatResponseUpdate> value, CancellationToken cancellationToken);

/// <summary>Determines whether caching should be applied to the request.</summary>
private static bool UseCaching(ChatOptions? options) =>
    // Caching is skipped whenever options.ConversationId is non-null: a conversation ID implies
    // there's state influencing the response that isn't represented in the messages themselves.
    // Since that state could change even under the same ID, a cached result can't be assumed valid.
    options?.ConversationId is null;
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ public void Ctor_ExpectedDefaults()
Assert.True(cachingClient.CoalesceStreamingUpdates);
}

[Fact]
public async Task CachesSuccessResultsAsync()
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task CachesSuccessResultsAsync(bool conversationIdSet)
{
// Arrange
ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };

// Verify that all the expected properties will round-trip through the cache,
// even if this involves serialization
Expand Down Expand Up @@ -82,20 +85,20 @@ public async Task CachesSuccessResultsAsync()
};

// Make the initial request and do a quick sanity check
var result1 = await outer.GetResponseAsync("some input");
var result1 = await outer.GetResponseAsync("some input", options);
Assert.Same(expectedResponse, result1);
Assert.Equal(1, innerCallCount);

// Act
var result2 = await outer.GetResponseAsync("some input");
var result2 = await outer.GetResponseAsync("some input", options);

// Assert
Assert.Equal(1, innerCallCount);
Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
AssertResponsesEqual(expectedResponse, result2);

// Act/Assert 2: Cache misses do not return cached results
await outer.GetResponseAsync("some modified input");
Assert.Equal(2, innerCallCount);
await outer.GetResponseAsync("some modified input", options);
Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
}

[Fact]
Expand Down Expand Up @@ -207,10 +210,13 @@ public async Task DoesNotCacheCanceledResultsAsync()
Assert.Equal("A good result", result2.Text);
}

[Fact]
public async Task StreamingCachesSuccessResultsAsync()
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task StreamingCachesSuccessResultsAsync(bool conversationIdSet)
{
// Arrange
ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };

// Verify that all the expected properties will round-trip through the cache,
// even if this involves serialization
Expand Down Expand Up @@ -255,20 +261,20 @@ public async Task StreamingCachesSuccessResultsAsync()
};

// Make the initial request and do a quick sanity check
var result1 = outer.GetStreamingResponseAsync("some input");
var result1 = outer.GetStreamingResponseAsync("some input", options);
await AssertResponsesEqualAsync(actualUpdate, result1);
Assert.Equal(1, innerCallCount);

// Act
var result2 = outer.GetStreamingResponseAsync("some input");
var result2 = outer.GetStreamingResponseAsync("some input", options);

// Assert
Assert.Equal(1, innerCallCount);
await AssertResponsesEqualAsync(expectedCachedResponse, result2);
Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
await AssertResponsesEqualAsync(conversationIdSet ? actualUpdate : expectedCachedResponse, result2);

// Act/Assert 2: Cache misses do not return cached results
await ToListAsync(outer.GetStreamingResponseAsync("some modified input"));
Assert.Equal(2, innerCallCount);
await ToListAsync(outer.GetStreamingResponseAsync("some modified input", options));
Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
}

[Theory]
Expand Down
Loading