diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs new file mode 100644 index 00000000000..05ac80dd682 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdateExtensions.cs @@ -0,0 +1,212 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +#if NET +using System.Runtime.InteropServices; +#endif +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable S109 // Magic numbers should not be used +#pragma warning disable S127 // "for" loop stop conditions should be invariant + +namespace Microsoft.Extensions.AI; + +/// +/// Provides extension methods for working with instances. +/// +public static class StreamingChatCompletionUpdateExtensions +{ + /// Combines instances into a single . + /// The updates to be combined. + /// + /// to attempt to coalesce contiguous items, where applicable, + /// into a single , in order to reduce the number of individual content items that are included in + /// the manufactured instances. When , the original content items are used. + /// The default is . + /// + /// The combined . + public static ChatCompletion ToChatCompletion( + this IEnumerable updates, bool coalesceContent = true) + { + _ = Throw.IfNull(updates); + + ChatCompletion completion = new([]); + Dictionary messages = []; + + foreach (var update in updates) + { + ProcessUpdate(update, messages, completion); + } + + AddMessagesToCompletion(messages, completion, coalesceContent); + + return completion; + } + + /// Combines instances into a single . + /// The updates to be combined. + /// + /// to attempt to coalesce contiguous items, where applicable, + /// into a single , in order to reduce the number of individual content items that are included in + /// the manufactured instances. When , the original content items are used. + /// The default is . + /// + /// The to monitor for cancellation requests. The default is . + /// The combined . + public static Task ToChatCompletionAsync( + this IAsyncEnumerable updates, bool coalesceContent = true, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(updates); + + return ToChatCompletionAsync(updates, coalesceContent, cancellationToken); + + static async Task ToChatCompletionAsync( + IAsyncEnumerable updates, bool coalesceContent, CancellationToken cancellationToken) + { + ChatCompletion completion = new([]); + Dictionary messages = []; + + await foreach (var update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + ProcessUpdate(update, messages, completion); + } + + AddMessagesToCompletion(messages, completion, coalesceContent); + + return completion; + } + } + + /// Processes the , incorporating its contents into and . + /// The update to process. + /// The dictionary mapping to the being built for that choice. + /// The object whose properties should be updated based on . + private static void ProcessUpdate(StreamingChatCompletionUpdate update, Dictionary messages, ChatCompletion completion) + { + completion.CompletionId ??= update.CompletionId; + completion.CreatedAt ??= update.CreatedAt; + completion.FinishReason ??= update.FinishReason; + completion.ModelId ??= update.ModelId; + +#if NET + ChatMessage message = CollectionsMarshal.GetValueRefOrAddDefault(messages, update.ChoiceIndex, out _) ??= + new(default, new List()); +#else + if (!messages.TryGetValue(update.ChoiceIndex, out ChatMessage? message)) + { + messages[update.ChoiceIndex] = message = new(default, new List()); + } +#endif + + ((List)message.Contents).AddRange(update.Contents); + + message.AuthorName ??= update.AuthorName; + if (update.Role is ChatRole role && message.Role == default) + { + message.Role = role; + } + + if (update.AdditionalProperties is not null) + { + if (message.AdditionalProperties is null) + { + message.AdditionalProperties = new(update.AdditionalProperties); + } + else + { + foreach (var entry in update.AdditionalProperties) + { + // Use first-wins behavior to match the behavior of the other properties. + _ = message.AdditionalProperties.TryAdd(entry.Key, entry.Value); + } + } + } + } + + /// Finalizes the object by transferring the into it. + /// The messages to process further and transfer into . + /// The result being built. + /// The corresponding option value provided to or . + private static void AddMessagesToCompletion(Dictionary messages, ChatCompletion completion, bool coalesceContent) + { + foreach (var entry in messages) + { + if (entry.Value.Role == default) + { + entry.Value.Role = ChatRole.Assistant; + } + + if (coalesceContent) + { + CoalesceTextContent((List)entry.Value.Contents); + } + + completion.Choices.Add(entry.Value); + + if (completion.Usage is null) + { + foreach (var content in entry.Value.Contents) + { + if (content is UsageContent c) + { + completion.Usage = c.Details; + break; + } + } + } + } + } + + /// Coalesces sequential content elements. + private static void CoalesceTextContent(List contents) + { + StringBuilder? coalescedText = null; + + // Iterate through all of the items in the list looking for contiguous items that can be coalesced. + int start = 0; + while (start < contents.Count - 1) + { + // We need at least two TextContents in a row to be able to coalesce. + if (contents[start] is not TextContent firstText) + { + start++; + continue; + } + + if (contents[start + 1] is not TextContent secondText) + { + start += 2; + continue; + } + + // Append the text from those nodes and continue appending subsequent TextContents until we run out. + // We null out nodes as their text is appended so that we can later remove them all in one O(N) operation. + coalescedText ??= new(); + _ = coalescedText.Clear().Append(firstText.Text).Append(secondText.Text); + contents[start + 1] = null!; + int i = start + 2; + for (; i < contents.Count && contents[i] is TextContent next; i++) + { + _ = coalescedText.Append(next.Text); + contents[i] = null!; + } + + // Store the replacement node. + contents[start] = new TextContent(coalescedText.ToString()) + { + // We inherit the properties of the first text node. We don't currently propagate additional + // properties from the subsequent nodes. If we ever need to, we can add that here. + AdditionalProperties = firstText.AdditionalProperties?.Clone(), + }; + + start = i; + } + + // Remove all of the null slots left over from the coalescing process. + _ = contents.RemoveAll(u => u is null); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs new file mode 100644 index 00000000000..bb0f08325d5 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateExtensionsTests.cs @@ -0,0 +1,200 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +#pragma warning disable SA1204 // Static elements should appear before instance elements + +namespace Microsoft.Extensions.AI; + +public class StreamingChatCompletionUpdateExtensionsTests +{ + [Fact] + public void InvalidArgs_Throws() + { + Assert.Throws("updates", () => ((List)null!).ToChatCompletion()); + } + + public static IEnumerable ToChatCompletion_SuccessfullyCreatesCompletion_MemberData() + { + foreach (bool useAsync in new[] { false, true }) + { + foreach (bool? coalesceContent in new bool?[] { null, false, true }) + { + yield return new object?[] { useAsync, coalesceContent }; + } + } + } + + [Theory] + [MemberData(nameof(ToChatCompletion_SuccessfullyCreatesCompletion_MemberData))] + public async Task ToChatCompletion_SuccessfullyCreatesCompletion(bool useAsync, bool? coalesceContent) + { + StreamingChatCompletionUpdate[] updates = + [ + new() { ChoiceIndex = 0, Text = "Hello", CompletionId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), ModelId = "model123" }, + new() { ChoiceIndex = 1, Text = "Hey", CompletionId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), ModelId = "model124" }, + + new() { ChoiceIndex = 0, Text = ", ", AuthorName = "Someone", Role = ChatRole.User, AdditionalProperties = new() { ["a"] = "b" } }, + new() { ChoiceIndex = 1, Text = ", ", AuthorName = "Else", Role = ChatRole.System, AdditionalProperties = new() { ["g"] = "h" } }, + + new() { ChoiceIndex = 0, Text = "world!", CreatedAt = new DateTimeOffset(2, 2, 3, 4, 5, 6, TimeSpan.Zero), AdditionalProperties = new() { ["c"] = "d" } }, + new() { ChoiceIndex = 1, Text = "you!", Role = ChatRole.Tool, CreatedAt = new DateTimeOffset(3, 2, 3, 4, 5, 6, TimeSpan.Zero), AdditionalProperties = new() { ["e"] = "f", ["i"] = 42 } }, + + new() { ChoiceIndex = 0, Contents = new[] { new UsageContent(new() { InputTokenCount = 1, OutputTokenCount = 2 }) } }, + new() { ChoiceIndex = 3, Contents = new[] { new UsageContent(new() { InputTokenCount = 4, OutputTokenCount = 5 }) } }, + ]; + + ChatCompletion completion = (coalesceContent is bool, useAsync) switch + { + (false, false) => updates.ToChatCompletion(), + (false, true) => await YieldAsync(updates).ToChatCompletionAsync(), + + (true, false) => updates.ToChatCompletion(coalesceContent.GetValueOrDefault()), + (true, true) => await YieldAsync(updates).ToChatCompletionAsync(coalesceContent.GetValueOrDefault()), + }; + Assert.NotNull(completion); + + Assert.Equal("12345", completion.CompletionId); + Assert.Equal(new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), completion.CreatedAt); + Assert.Equal("model123", completion.ModelId); + Assert.Same(Assert.IsType(updates[6].Contents[0]).Details, completion.Usage); + + Assert.Equal(3, completion.Choices.Count); + + ChatMessage message = completion.Choices[0]; + Assert.Equal(ChatRole.User, message.Role); + Assert.Equal("Someone", message.AuthorName); + Assert.NotNull(message.AdditionalProperties); + Assert.Equal(2, message.AdditionalProperties.Count); + Assert.Equal("b", message.AdditionalProperties["a"]); + Assert.Equal("d", message.AdditionalProperties["c"]); + + message = completion.Choices[1]; + Assert.Equal(ChatRole.System, message.Role); + Assert.Equal("Else", message.AuthorName); + Assert.NotNull(message.AdditionalProperties); + Assert.Equal(3, message.AdditionalProperties.Count); + Assert.Equal("h", message.AdditionalProperties["g"]); + Assert.Equal("f", message.AdditionalProperties["e"]); + Assert.Equal(42, message.AdditionalProperties["i"]); + + message = completion.Choices[2]; + Assert.Equal(ChatRole.Assistant, message.Role); + Assert.Null(message.AuthorName); + Assert.Null(message.AdditionalProperties); + Assert.Same(updates[7].Contents[0], Assert.Single(message.Contents)); + + if (coalesceContent is null or true) + { + Assert.Equal("Hello, world!", completion.Choices[0].Text); + Assert.Equal("Hey, you!", completion.Choices[1].Text); + Assert.Null(completion.Choices[2].Text); + } + else + { + Assert.Equal("Hello", completion.Choices[0].Contents[0].ToString()); + Assert.Equal(", ", completion.Choices[0].Contents[1].ToString()); + Assert.Equal("world!", completion.Choices[0].Contents[2].ToString()); + + Assert.Equal("Hey", completion.Choices[1].Contents[0].ToString()); + Assert.Equal(", ", completion.Choices[1].Contents[1].ToString()); + Assert.Equal("you!", completion.Choices[1].Contents[2].ToString()); + + Assert.Null(completion.Choices[2].Text); + } + } + + public static IEnumerable ToChatCompletion_Coalescing_VariousSequenceAndGapLengths_MemberData() + { + foreach (bool useAsync in new[] { false, true }) + { + for (int numSequences = 1; numSequences <= 3; numSequences++) + { + for (int sequenceLength = 1; sequenceLength <= 3; sequenceLength++) + { + for (int gapLength = 1; gapLength <= 3; gapLength++) + { + foreach (bool gapBeginningEnd in new[] { false, true }) + { + yield return new object[] { useAsync, numSequences, sequenceLength, gapLength, false }; + } + } + } + } + } + } + + [Theory] + [MemberData(nameof(ToChatCompletion_Coalescing_VariousSequenceAndGapLengths_MemberData))] + public async Task ToChatCompletion_Coalescing_VariousSequenceAndGapLengths(bool useAsync, int numSequences, int sequenceLength, int gapLength, bool gapBeginningEnd) + { + List updates = []; + + List expected = []; + + if (gapBeginningEnd) + { + AddGap(); + } + + for (int sequenceNum = 0; sequenceNum < numSequences; sequenceNum++) + { + StringBuilder sb = new(); + for (int i = 0; i < sequenceLength; i++) + { + string text = $"{(char)('A' + sequenceNum)}{i}"; + updates.Add(new() { Text = text }); + sb.Append(text); + } + + expected.Add(sb.ToString()); + + if (sequenceNum < numSequences - 1) + { + AddGap(); + } + } + + if (gapBeginningEnd) + { + AddGap(); + } + + void AddGap() + { + for (int i = 0; i < gapLength; i++) + { + updates.Add(new() { Contents = [new ImageContent("https://uri")] }); + } + } + + ChatCompletion completion = useAsync ? await YieldAsync(updates).ToChatCompletionAsync() : updates.ToChatCompletion(); + Assert.Single(completion.Choices); + + ChatMessage message = completion.Message; + Assert.Equal(expected.Count + (gapLength * ((numSequences - 1) + (gapBeginningEnd ? 2 : 0))), message.Contents.Count); + + TextContent[] contents = message.Contents.OfType().ToArray(); + Assert.Equal(expected.Count, contents.Length); + for (int i = 0; i < expected.Count; i++) + { + Assert.Equal(expected[i], contents[i].Text); + } + } + + private static async IAsyncEnumerable YieldAsync(IEnumerable updates) + { + foreach (StreamingChatCompletionUpdate update in updates) + { + await Task.Yield(); + yield return update; + } + } +}