diff --git a/docs/decisions/0023-kernel-streaming.md b/docs/decisions/0023-kernel-streaming.md new file mode 100644 index 000000000000..c4116cb5eaa9 --- /dev/null +++ b/docs/decisions/0023-kernel-streaming.md @@ -0,0 +1,350 @@ +--- +# These are optional elements. Feel free to remove any of them. +status: proposed +date: 2023-11-13 +deciders: rogerbarreto,markwallace-microsoft,SergeyMenshykh,dmytrostruk +consulted: +informed: +--- + +# Streaming Capability for Kernel and Functions usage - Phase 1 + +## Context and Problem Statement + +It is quite common in co-pilot implementations to have a streamlined output of messages from the LLM (large language models)M and currently that is not possible while using ISKFunctions.InvokeAsync or Kernel.RunAsync methods, which enforces users to work around the Kernel and Functions to use `ITextCompletion` and `IChatCompletion` services directly as the only interfaces that currently support streaming. + +Currently streaming is a capability that not all providers do support and this as part of our design we try to ensure the services will have the proper abstractions to support streaming not only of text but be open to other types of data like images, audio, video, etc. + +Needs to be clear for the sk developer when he is attempting to get streaming data. + +## Decision Drivers + +1. The sk developer should be able to get streaming data from the Kernel and Functions using Kernel.RunAsync or ISKFunctions.InvokeAsync methods + +2. The sk developer should be able to get the data in a generic way, so the Kernel and Functions can be able to stream data of any type, not limited to text. + +3. The sk developer when using streaming from a model that does not support streaming should still be able to use it with only one streaming update representing the whole data. + +## Out of Scope + +- Streaming with plans will not be supported in this phase. Attempting to do so will throw an exception. 
+- Kernel streaming will not support multiple functions (pipeline). +- Input streaming will not be supported in this phase. +- Post Hook Skipping, Repeat and Cancelling of streaming functions are not supported. + +## Considered Options + +### Option 1 - Dedicated Streaming Interfaces + +Using dedicated streaming interfaces that allow the sk developer to get the streaming data in a generic way, including string, byte array directly from the connector as well as allowing the Kernel and Functions implementations to be able to stream data of any type, not limited to text. + +This approach also exposes dedicated interfaces in the kernel and functions to use streaming making it clear to the sk developer what is the type of data being returned in IAsyncEnumerable format. + +`ITextCompletion` and `IChatCompletion` will have new APIs to get `byte[]` and `string` streaming data directly as well as the specialized `StreamingContent` return. + +The sk developer will be able to specify a generic type to the `Kernel.RunStreamingAsync()` and `ISKFunction.InvokeStreamingAsync` to get the streaming data. If the type is not specified, the Kernel and Functions will return the data as StreamingContent. + +If the type specified is not supported or if the string representation cannot be cast, an exception will be thrown. + +If the type specified is `StreamingContent` or any other type supported by the connector, no error will be thrown. 
+ +## User Experience Goal + +```csharp +//(providing the type at as generic parameter) + +// Getting a Raw Streaming data from Kernel +await foreach(string update in kernel.RunStreamingAsync(function, variables)) + +// Getting a String as Streaming data from Kernel +await foreach(string update in kernel.RunStreamingAsync(function, variables)) + +// Getting a StreamingContent as Streaming data from Kernel +await foreach(StreamingContent update in kernel.RunStreamingAsync(variables, function)) +// OR +await foreach(StreamingContent update in kernel.RunStreamingAsync(function, variables)) // defaults to Generic above) +{ + Console.WriteLine(update); +} +``` + +Abstraction class for any stream content, connectors will be responsible to provide the specialized type of `StreamingContent` which will contain the data as well as any metadata related to the streaming result. + +```csharp + +public abstract class StreamingContent +{ + public abstract int ChoiceIndex { get; } + + /// Returns a string representation of the chunk content + public abstract override string ToString(); + + /// Abstract byte[] representation of the chunk content in a way it could be composed/appended with previous chunk contents. + /// Depending on the nature of the underlying type, this method may be more efficient than . + public abstract byte[] ToByteArray(); + + /// Internal chunk content object reference. (Breaking glass). + /// Each connector will have its own internal object representing the content chunk content. + /// The usage of this property is considered "unsafe". Use it only if strictly necessary. + public object? InnerContent { get; } + + /// The metadata associated with the content. + public Dictionary? Metadata { get; set; } + + /// The current context associated the function call. + internal SKContext? Context { get; set; } + + /// Inner content object reference + protected StreamingContent(object? 
innerContent) + { + this.InnerContent = innerContent; + } +} +``` + +Specialization example of a StreamingChatContent + +```csharp +// +public class StreamingChatContent : StreamingContent +{ + public override int ChoiceIndex { get; } + public FunctionCall? FunctionCall { get; } + public string? Content { get; } + public AuthorRole? Role { get; } + public string? Name { get; } + + public StreamingChatContent(AzureOpenAIChatMessage chatMessage, int resultIndex) : base(chatMessage) + { + this.ChoiceIndex = resultIndex; + this.FunctionCall = chatMessage.InnerChatMessage?.FunctionCall; + this.Content = chatMessage.Content; + this.Role = new AuthorRole(chatMessage.Role.ToString()); + this.Name = chatMessage.InnerChatMessage?.Name; + } + + public override byte[] ToByteArray() => Encoding.UTF8.GetBytes(this.ToString()); + public override string ToString() => this.Content ?? string.Empty; +} +``` + +`IChatCompletion` and `ITextCompletion` interfaces will have new APIs to get a generic streaming content data. + +```csharp +interface ITextCompletion + IChatCompletion +{ + IAsyncEnumerable GetStreamingContentAsync(...); + + // Throw exception if T is not supported +} + +interface IKernel +{ + // Get streaming function content of T + IAsyncEnumerable RunStreamingAsync(ContextVariables variables, ISKFunction function); +} + +interface ISKFunction +{ + // Get streaming function content of T + IAsyncEnumerable InvokeStreamingAsync(SKContext context); +} +``` + +## Prompt/Semantic Functions Behavior + +When Prompt Functions are invoked using the Streaming API, they will attempt to use the Connectors streaming implementation. +The connector will be responsible to provide the specialized type of `StreamingContent` and even if the underlying backend API don't support streaming the output will be one streamingcontent with the whole data. 
+ +## Method/Native Functions Behavior + +Method Functions will support `StreamingContent` automatically with as a `StreamingMethodContent` wrapping the object returned in the iterator. + +```csharp +public sealed class StreamingMethodContent : StreamingContent +{ + public override int ChoiceIndex => 0; + + /// Method object value that represents the content chunk + public object Value { get; } + + /// Default implementation + public override byte[] ToByteArray() + { + if (this.Value is byte[]) + { + // If the method value is byte[] we return it directly + return (byte[])this.Value; + } + + // By default if a native value is not byte[] we output the UTF8 string representation of the value + return Encoding.UTF8.GetBytes(this.Value?.ToString()); + } + + /// + public override string ToString() + { + return this.Value.ToString(); + } + + /// + /// Initializes a new instance of the class. + /// + /// Underlying object that represents the chunk + public StreamingMethodContent(object innerContent) : base(innerContent) + { + this.Value = innerContent; + } +} +``` + +If a MethodFunction is returning an `IAsyncEnumerable` each enumerable result will be automatically wrapped in the `StreamingMethodContent` keeping the streaming behavior and the overall abstraction consistent. + +When a MethodFunction is not an `IAsyncEnumerable`, the complete result will be wrapped in a `StreamingMethodContent` and will be returned as a single item. + +## Pros + +1. All the User Experience Goal section options will be possible. +2. Kernel and Functions implementations will be able to stream data of any type, not limited to text +3. The sk developer will be able to provide the streaming content type it expects from the `GetStreamingContentAsync` method. +4. Sk developer will be able to get streaming from the Kernel, Functions and Connectors with the same result type. + +## Cons + +1. 
If the sk developer wants to use the specialized type of `StreamingContent` they will need to know which connector is being used in order to use the correct **StreamingContent extension method** or to provide the type directly in ``. +2. Connectors will have greater responsibility to support the correct special types of `StreamingContent`. + +### Option 2 - Dedicated Streaming Interfaces (Returning a Class) + +All changes from option 1 with the small difference below: + +- The Kernel and SKFunction streaming APIs interfaces will return `StreamingFunctionResult` which also implements `IAsyncEnumerable` +- Connectors streaming APIs interfaces will return `StreamingConnectorContent` which also implements `IAsyncEnumerable` + +The `StreamingConnectorContent` class is needed for connectors as one way to pass any information relative to the request and not the chunk that can be used by the functions to fill `StreamingFunctionResult` metadata. + +## User Experience Goal + +Option 2 Biggest benefit: + +```csharp +// When the caller needs to know more about the streaming, they can get the result reference before starting the streaming. 
+var streamingResult = await kernel.RunStreamingAsync(function); +// Do something with streamingResult properties + +// Consuming the streamingResult requires an extra await: +await foreach(StreamingContent chunk content in await streamingResult) +``` + +Using the other operations will be quite similar (only needing an extra `await` to get the iterator) + +```csharp +// Getting a Raw Streaming data from Kernel +await foreach(string update in await kernel.RunStreamingAsync(function, variables)) + +// Getting a String as Streaming data from Kernel +await foreach(string update in await kernel.RunStreamingAsync(function, variables)) + +// Getting a StreamingContent as Streaming data from Kernel +await foreach(StreamingContent update in await kernel.RunStreamingAsync(variables, function)) +// OR +await foreach(StreamingContent update in await kernel.RunStreamingAsync(function, variables)) // defaults to Generic above) +{ + Console.WriteLine(update); +} + +``` + +StreamingConnectorResult is a class that can store information regarding the result before the stream is consumed as well as any underlying object (breaking glass) that the stream consumes at the connector level. + +```csharp + +public sealed class StreamingConnectorResult : IAsyncEnumerable +{ + private readonly IAsyncEnumerable _StreamingContentource; + + public object? InnerResult { get; private set; } = null; + + public StreamingConnectorResult(Func> streamingReference, object? innerConnectorResult) + { + this._StreamingContentource = streamingReference.Invoke(); + this.InnerResult = innerConnectorResult; + } +} + +interface ITextCompletion + IChatCompletion +{ + Task> GetStreamingContentAsync(); + // Throw exception if T is not supported + // Initially connectors +} +``` + +StreamingFunctionResult is a class that can store information regarding the result before the stream is consumed as well as any underlying object (breaking glass) that the stream consumes from Kernel and SKFunctions. 
+ +```csharp +public sealed class StreamingFunctionResult : IAsyncEnumerable +{ + internal Dictionary? _metadata; + private readonly IAsyncEnumerable _streamingResult; + + public string FunctionName { get; internal set; } + public Dictionary Metadata { get; internal set; } + + /// + /// Internal object reference. (Breaking glass). + /// Each connector will have its own internal object representing the result. + /// + public object? InnerResult { get; private set; } = null; + + /// + /// Instance of used by the function. + /// + internal SKContext Context { get; private set; } + + public StreamingFunctionResult(string functionName, SKContext context, Func> streamingResult, object? innerFunctionResult) + { + this.FunctionName = functionName; + this.Context = context; + this._streamingResult = streamingResult.Invoke(); + this.InnerResult = innerFunctionResult; + } +} + +interface ISKFunction +{ + // Extension generic method to get from type + Task> InvokeStreamingAsync(...); +} + +static class KernelExtensions +{ + public static async Task> RunStreamingAsync(this Kernel kernel, ISKFunction skFunction, ContextVariables? variables, CancellationToken cancellationToken) + { + ... + } +} +``` + +## Pros + +1. All benefits from Option 1 + +2. Having StreamingFunctionResults allow sk developer to know more details about the result before consuming the stream, like: + - Any metadata provided by the underlying API, + - SKContext + - Function Name and Details +3. Experience using the Streaming is quite similar (need an extra await to get the result) to option 1 +4. APIs behave similarly to the non-streaming API (returning a result representation to get the value) + +## Cons + +1. All cons from Option 1 + +2. Added complexity as the IAsyncEnumerable cannot be passed directly in the method result demanding a delegate approach to be adapted inside of the Results that implements the IAsyncEnumerator. +3. 
Added complexity where IDisposable is needed to be implemented in the Results to dispose the response object and the caller would need to handle the disposal of the result. +4. As soon the caller gets a `StreamingFunctionResult` a network connection will be kept open until the caller implementation consume it (Enumerate over the `IAsyncEnumerable`). + +## Decision Outcome + +Option 1 was chosen as the best option as small benefit of the Option 2 don't justify the complexity involved described in the Cons. + +Was also decided that the Metadata related to a connector backend response can be added to the `StreamingContent.Metadata` property. This will allow the sk developer to get the metadata even without a `StreamingConnectorResult` or `StreamingFunctionResult`. diff --git a/dotnet/samples/KernelSyntaxExamples/Example16_CustomLLM.cs b/dotnet/samples/KernelSyntaxExamples/Example16_CustomLLM.cs index 494ca48cf6f8..6166eed5cd94 100644 --- a/dotnet/samples/KernelSyntaxExamples/Example16_CustomLLM.cs +++ b/dotnet/samples/KernelSyntaxExamples/Example16_CustomLLM.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Runtime.CompilerServices; +using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -47,6 +48,37 @@ public async IAsyncEnumerable GetStreamingCompletionsAsync { yield return new MyTextCompletionStreamingResult(); } + + public async IAsyncEnumerable GetStreamingContentAsync(string prompt, AIRequestSettings? 
requestSettings = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (typeof(T) == typeof(MyStreamingContent)) + { + yield return (T)(object)new MyStreamingContent("llm content update 1"); + yield return (T)(object)new MyStreamingContent("llm content update 2"); + } + } +} + +public class MyStreamingContent : StreamingContent +{ + public override int ChoiceIndex => 0; + + public string Content { get; } + + public MyStreamingContent(string content) : base(content) + { + this.Content = content; + } + + public override byte[] ToByteArray() + { + return Encoding.UTF8.GetBytes(this.Content); + } + + public override string ToString() + { + return this.Content; + } } public class MyTextCompletionStreamingResult : ITextStreamingResult, ITextResult diff --git a/dotnet/samples/KernelSyntaxExamples/Example72_KernelStreaming.cs b/dotnet/samples/KernelSyntaxExamples/Example72_KernelStreaming.cs new file mode 100644 index 000000000000..cfeba26793ee --- /dev/null +++ b/dotnet/samples/KernelSyntaxExamples/Example72_KernelStreaming.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AI.OpenAI; +using Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk; +using RepoUtils; + +#pragma warning disable RCS1110 // Declare type inside namespace. +#pragma warning disable CA1819 // Properties should not return arrays + +/** + * This example shows how to use multiple prompt template formats. + */ +// ReSharper disable once InconsistentNaming +public static class Example72_KernelStreaming +{ + /// + /// Show how to combine multiple prompt template factories. 
+ /// + public static async Task RunAsync() + { + string apiKey = TestConfiguration.AzureOpenAI.ApiKey; + string chatDeploymentName = TestConfiguration.AzureOpenAI.ChatDeploymentName; + string endpoint = TestConfiguration.AzureOpenAI.Endpoint; + + if (apiKey == null || chatDeploymentName == null || endpoint == null) + { + Console.WriteLine("Azure endpoint, apiKey, or deploymentName not found. Skipping example."); + return; + } + + Kernel kernel = new KernelBuilder() + .WithLoggerFactory(ConsoleLogger.LoggerFactory) + .WithAzureOpenAIChatCompletionService( + deploymentName: chatDeploymentName, + endpoint: endpoint, + serviceId: "AzureOpenAIChat", + apiKey: apiKey) + .Build(); + + var funyParagraphFunction = kernel.CreateFunctionFromPrompt("Write a funny paragraph about streaming", new OpenAIRequestSettings() { MaxTokens = 100, Temperature = 0.4, TopP = 1 }); + + var roleDisplayed = false; + + Console.WriteLine("\n=== Semantic Function - Streaming ===\n"); + + // Streaming can be of any type depending on the underlying service the function is using. + await foreach (var update in kernel.RunStreamingAsync(funyParagraphFunction)) + { + // You will be always able to know the type of the update by checking the Type property. 
+ if (!roleDisplayed && update.Role.HasValue) + { + Console.WriteLine($"Role: {update.Role}"); + roleDisplayed = true; + } + + if (update.Content is { Length: > 0 }) + { + Console.Write(update.Content); + } + }; + } +} diff --git a/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/HuggingFaceTextCompletion.cs b/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/HuggingFaceTextCompletion.cs index 1475058c534a..6ad24932f427 100644 --- a/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/HuggingFaceTextCompletion.cs +++ b/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/HuggingFaceTextCompletion.cs @@ -93,6 +93,36 @@ public async Task> GetCompletionsAsync( return await this.ExecuteGetCompletionsAsync(text, cancellationToken).ConfigureAwait(false); } + /// + public async IAsyncEnumerable GetStreamingContentAsync( + string prompt, + AIRequestSettings? requestSettings = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var result in await this.ExecuteGetCompletionsAsync(prompt, cancellationToken).ConfigureAwait(false)) + { + cancellationToken.ThrowIfCancellationRequested(); + // Gets the non streaming content and returns as one complete result + var content = await result.GetCompletionAsync(cancellationToken).ConfigureAwait(false); + + // If the provided T is a string, return the completion as is + if (typeof(T) == typeof(string)) + { + yield return (T)(object)content; + continue; + } + + // If the provided T is an specialized class of StreamingContent interface + if (typeof(T) == typeof(StreamingTextContent) || + typeof(T) == typeof(StreamingContent)) + { + yield return (T)(object)new StreamingTextContent(content, 1, result); + } + + throw new NotSupportedException($"Type {typeof(T)} is not supported"); + } + } + #region private ================================================================================ private async Task> ExecuteGetCompletionsAsync(string text, 
CancellationToken cancellationToken = default) diff --git a/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/StreamingTextContent.cs b/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/StreamingTextContent.cs new file mode 100644 index 000000000000..3f1a1c516dcc --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/StreamingTextContent.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text; +using Microsoft.SemanticKernel.AI; + +namespace Microsoft.SemanticKernel.Connectors.AI.HuggingFace.TextCompletion; + +/// +/// StreamResponse class in +/// +public class StreamingTextContent : StreamingContent +{ + /// + public override int ChoiceIndex { get; } + + /// + /// Text associated to the update + /// + public string Content { get; } + + /// + /// Create a new instance of the class. + /// + /// Text update + /// Index of the choice + /// Inner chunk object + /// Metadata information + public StreamingTextContent(string text, int resultIndex, object? innerContentObject = null, Dictionary? 
metadata = null) : base(innerContentObject, metadata) + { + this.ChoiceIndex = resultIndex; + this.Content = text; + } + + /// + public override byte[] ToByteArray() + { + return Encoding.UTF8.GetBytes(this.ToString()); + } + + /// + public override string ToString() + { + return this.Content; + } +} diff --git a/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/TextCompletionRequest.cs b/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/TextCompletionRequest.cs index 3f2df2952969..b9d309302e59 100644 --- a/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/TextCompletionRequest.cs +++ b/dotnet/src/Connectors/Connectors.AI.HuggingFace/TextCompletion/TextCompletionRequest.cs @@ -14,4 +14,10 @@ public sealed class TextCompletionRequest /// [JsonPropertyName("inputs")] public string Input { get; set; } = string.Empty; + + /// + /// Enable streaming + /// + [JsonPropertyName("stream")] + public bool Stream { get; set; } = false; } diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/ClientBase.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/ClientBase.cs index 049e2bacf93c..adc2a7c81577 100644 --- a/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/ClientBase.cs +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/ClientBase.cs @@ -16,6 +16,7 @@ using Microsoft.SemanticKernel.AI.TextCompletion; using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletion; using Microsoft.SemanticKernel.Prompt; +using ChatMessage = Azure.AI.OpenAI.ChatMessage; namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk; @@ -153,6 +154,69 @@ private protected async IAsyncEnumerable InternalGetTextStr } } + private protected async IAsyncEnumerable InternalGetTextStreamingUpdatesAsync( + string prompt, + AIRequestSettings? 
requestSettings, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + OpenAIRequestSettings textRequestSettings = OpenAIRequestSettings.FromRequestSettings(requestSettings, OpenAIRequestSettings.DefaultTextMaxTokens); + + ValidateMaxTokens(textRequestSettings.MaxTokens); + + var options = CreateCompletionsOptions(prompt, textRequestSettings); + + Response? response = await RunRequestAsync>( + () => this.Client.GetCompletionsStreamingAsync(this.DeploymentOrModelName, options, cancellationToken)).ConfigureAwait(false); + + using StreamingCompletions streamingChatCompletions = response.Value; + var responseMetadata = GetResponseMetadata(streamingChatCompletions); + + int choiceIndex = 0; + await foreach (StreamingChoice choice in streamingChatCompletions.GetChoicesStreaming(cancellationToken).ConfigureAwait(false)) + { + await foreach (string update in choice.GetTextStreaming(cancellationToken).ConfigureAwait(false)) + { + // If the provided T is a string, return the completion as is + if (typeof(T) == typeof(string)) + { + yield return (T)(object)update; + continue; + } + + // If the provided T is an specialized class of StreamingContent interface + if (typeof(T) == typeof(StreamingTextContent) || + typeof(T) == typeof(StreamingContent)) + { + yield return (T)(object)new StreamingTextContent(update, choiceIndex, update, responseMetadata); + continue; + } + + throw new NotSupportedException($"Type {typeof(T)} is not supported"); + } + choiceIndex++; + } + } + + private static Dictionary GetResponseMetadata(StreamingCompletions streamingChatCompletions) + { + return new Dictionary() + { + { $"{nameof(StreamingCompletions)}.{nameof(streamingChatCompletions.Id)}", streamingChatCompletions.Id }, + { $"{nameof(StreamingCompletions)}.{nameof(streamingChatCompletions.Created)}", streamingChatCompletions.Created }, + { $"{nameof(StreamingCompletions)}.{nameof(streamingChatCompletions.PromptFilterResults)}", 
streamingChatCompletions.PromptFilterResults }, + }; + } + + private static Dictionary GetResponseMetadata(StreamingChatCompletions streamingChatCompletions) + { + return new Dictionary() + { + { $"{nameof(StreamingChatCompletions)}.{nameof(streamingChatCompletions.Id)}", streamingChatCompletions.Id }, + { $"{nameof(StreamingChatCompletions)}.{nameof(streamingChatCompletions.Created)}", streamingChatCompletions.Created }, + { $"{nameof(StreamingChatCompletions)}.{nameof(streamingChatCompletions.PromptFilterResults)}", streamingChatCompletions.PromptFilterResults }, + }; + } + /// /// Generates an embedding from the given . /// @@ -262,6 +326,55 @@ private protected async IAsyncEnumerable InternalGetChatSt } } + private protected async IAsyncEnumerable InternalGetChatStreamingUpdatesAsync( + IEnumerable chat, + AIRequestSettings? requestSettings, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(chat); + + OpenAIRequestSettings chatRequestSettings = OpenAIRequestSettings.FromRequestSettings(requestSettings); + + ValidateMaxTokens(chatRequestSettings.MaxTokens); + + var options = CreateChatCompletionsOptions(chatRequestSettings, chat); + + Response? 
response = await RunRequestAsync>( + () => this.Client.GetChatCompletionsStreamingAsync(this.DeploymentOrModelName, options, cancellationToken)).ConfigureAwait(false); + + if (response is null) + { + throw new SKException("Chat completions null response"); + } + + using StreamingChatCompletions streamingChatCompletions = response.Value; + var responseMetadata = GetResponseMetadata(streamingChatCompletions); + + int choiceIndex = 0; + await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming(cancellationToken).ConfigureAwait(false)) + { + await foreach (ChatMessage chatMessage in choice.GetMessageStreaming(cancellationToken).ConfigureAwait(false)) + { + if (typeof(T) == typeof(string)) + { + yield return (T)(object)chatMessage.Content; + continue; + } + + // If the provided T is an specialized class of StreamingResultChunk interface + if (typeof(T) == typeof(StreamingChatContent) || + typeof(T) == typeof(StreamingContent)) + { + yield return (T)(object)new StreamingChatContent(chatMessage, choiceIndex, responseMetadata); + continue; + } + + throw new NotSupportedException($"Type {typeof(T)} is not supported"); + } + choiceIndex++; + } + } + /// /// Create a new empty chat instance /// diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/StreamingChatContent.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/StreamingChatContent.cs new file mode 100644 index 000000000000..017243f6fd55 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/StreamingChatContent.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text; +using Azure.AI.OpenAI; +using Microsoft.SemanticKernel.AI; +using Microsoft.SemanticKernel.AI.ChatCompletion; + +namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk; + +/// +/// Streaming chat result update. 
+/// +public class StreamingChatContent : StreamingContent +{ + /// + public override int ChoiceIndex { get; } + + /// + /// Function call associated to the message payload + /// + public FunctionCall? FunctionCall { get; } + + /// + /// Text associated to the message payload + /// + public string? Content { get; } + + /// + /// Role of the author of the message + /// + public AuthorRole? Role { get; } + + /// + /// Name of the author of the message. Name is required if the role is 'function'. + /// + public string? Name { get; } + + /// + /// Create a new instance of the class. + /// + /// Internal Azure SDK Message update representation + /// Index of the choice + /// Additional metadata + public StreamingChatContent(Azure.AI.OpenAI.ChatMessage chatMessage, int resultIndex, Dictionary metadata) : base(chatMessage, metadata) + { + this.ChoiceIndex = resultIndex; + this.FunctionCall = chatMessage.FunctionCall; + this.Content = chatMessage.Content; + this.Role = new AuthorRole(chatMessage.Role.ToString()); + this.Name = chatMessage.FunctionCall?.Name; + } + + /// + public override byte[] ToByteArray() => Encoding.UTF8.GetBytes(this.ToString()); + + /// + public override string ToString() => this.Content ?? string.Empty; +} diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/StreamingChatWithDataContent.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/StreamingChatWithDataContent.cs new file mode 100644 index 000000000000..26c3e55c4569 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/StreamingChatWithDataContent.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json; +using Microsoft.SemanticKernel.AI; +using Microsoft.SemanticKernel.AI.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletionWithData; + +namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk; + +/// +/// Streaming chat result update. +/// +public sealed class StreamingChatWithDataContent : StreamingContent +{ + /// + public override int ChoiceIndex { get; } + + /// + /// Chat message abstraction + /// + public ChatMessage ChatMessage { get; } + + /// + /// Create a new instance of the class. + /// + /// Azure message update representation from WithData apis + /// Index of the choice + /// Additional metadata + internal StreamingChatWithDataContent(ChatWithDataStreamingChoice choice, int resultIndex, Dictionary metadata) : base(choice, metadata) + { + this.ChoiceIndex = resultIndex; + var message = choice.Messages.FirstOrDefault(this.IsValidMessage); + + this.ChatMessage = new AzureOpenAIChatMessage(AuthorRole.Assistant.Label, message?.Delta?.Content ?? string.Empty); + } + + /// + public override byte[] ToByteArray() + { + return Encoding.UTF8.GetBytes(this.ToString()); + } + + /// + public override string ToString() + { + return JsonSerializer.Serialize(this); + } + + private bool IsValidMessage(ChatWithDataStreamingMessage message) + { + return !message.EndTurn && + (message.Delta.Role is null || !message.Delta.Role.Equals(AuthorRole.Tool.Label, StringComparison.Ordinal)); + } +} diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/StreamingTextContent.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/StreamingTextContent.cs new file mode 100644 index 000000000000..a7497d131b7f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/StreamingTextContent.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Collections.Generic; +using System.Text; +using Microsoft.SemanticKernel.AI; + +namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk; + +/// +/// Streaming text result update. +/// +public class StreamingTextContent : StreamingContent +{ + /// + public override int ChoiceIndex { get; } + + /// + /// Text associated to the update + /// + public string Content { get; } + + /// + /// Create a new instance of the class. + /// + /// Text update + /// Index of the choice + /// Inner chunk object + /// Metadata information + public StreamingTextContent(string text, int resultIndex, object? innerContentObject = null, Dictionary? metadata = null) : base(innerContentObject, metadata) + { + this.ChoiceIndex = resultIndex; + this.Content = text; + } + + /// + public override byte[] ToByteArray() + { + return Encoding.UTF8.GetBytes(this.ToString()); + } + + /// + public override string ToString() + { + return this.Content; + } +} diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletion/AzureOpenAIChatCompletion.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletion/AzureOpenAIChatCompletion.cs index c40a039c965c..81d81af4d577 100644 --- a/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletion/AzureOpenAIChatCompletion.cs +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletion/AzureOpenAIChatCompletion.cs @@ -125,4 +125,11 @@ public Task> GetCompletionsAsync( this.LogActionDetails(); return this.InternalGetChatResultsAsTextAsync(text, requestSettings, cancellationToken); } + + /// + public IAsyncEnumerable GetStreamingContentAsync(string prompt, AIRequestSettings? 
requestSettings = null, CancellationToken cancellationToken = default) + { + var chatHistory = this.CreateNewChat(prompt); + return this.InternalGetChatStreamingUpdatesAsync(chatHistory, requestSettings, cancellationToken); + } } diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletion/OpenAIChatCompletion.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletion/OpenAIChatCompletion.cs index 75dd523ddf95..4bc81e8dfe9e 100644 --- a/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletion/OpenAIChatCompletion.cs +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletion/OpenAIChatCompletion.cs @@ -101,4 +101,11 @@ public Task> GetCompletionsAsync( this.LogActionDetails(); return this.InternalGetChatResultsAsTextAsync(text, requestSettings, cancellationToken); } + + /// + public IAsyncEnumerable GetStreamingContentAsync(string prompt, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default) + { + var chatHistory = this.CreateNewChat(prompt); + return this.InternalGetChatStreamingUpdatesAsync(chatHistory, requestSettings, cancellationToken); + } } diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletionWithData/AzureOpenAIChatCompletionWithData.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletionWithData/AzureOpenAIChatCompletionWithData.cs index 97ca9b3c0b6f..30a5a17bfb13 100644 --- a/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletionWithData/AzureOpenAIChatCompletionWithData.cs +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletionWithData/AzureOpenAIChatCompletionWithData.cs @@ -15,6 +15,7 @@ using Microsoft.SemanticKernel.AI; using Microsoft.SemanticKernel.AI.ChatCompletion; using Microsoft.SemanticKernel.AI.TextCompletion; +using Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk; using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletion; using Microsoft.SemanticKernel.Http; using Microsoft.SemanticKernel.Services; @@ -119,6 +120,25 @@ public async 
IAsyncEnumerable GetStreamingCompletionsAsync } } + /// + public async IAsyncEnumerable GetStreamingContentAsync( + string prompt, + AIRequestSettings? requestSettings = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + OpenAIRequestSettings chatRequestSettings = OpenAIRequestSettings.FromRequestSettings(requestSettings); + + var chat = this.PrepareChatHistory(prompt, chatRequestSettings); + + using var request = this.GetRequest(chat, chatRequestSettings, isStreamEnabled: true); + using var response = await this.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); + + await foreach (var result in this.GetChatStreamingUpdatesAsync(response)) + { + yield return result; + } + } + #region private ================================================================================ private const string DefaultApiVersion = "2023-06-01-preview"; @@ -227,6 +247,66 @@ private async IAsyncEnumerable GetStreamingResultsAsync(Ht } } + private async IAsyncEnumerable GetChatStreamingUpdatesAsync(HttpResponseMessage response) + { + const string ServerEventPayloadPrefix = "data:"; + + using var stream = await response.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false); + using var reader = new StreamReader(stream); + + while (!reader.EndOfStream) + { + var body = await reader.ReadLineAsync().ConfigureAwait(false); + + if (string.IsNullOrWhiteSpace(body)) + { + continue; + } + + if (body.StartsWith(ServerEventPayloadPrefix, StringComparison.Ordinal)) + { + body = body.Substring(ServerEventPayloadPrefix.Length); + } + + var chatWithDataResponse = this.DeserializeResponse(body); + var responseMetadata = this.GetResponseMetadata(response); + foreach (var choice in chatWithDataResponse.Choices) + { + // If the provided T is an specialized class of StreamingContent interface + if (typeof(T) == typeof(StreamingChatContent) || + typeof(T) == typeof(StreamingContent)) + { + yield return (T)(object)new 
StreamingChatWithDataContent(choice, choice.Index, responseMetadata); + continue; + } + + var result = new ChatWithDataStreamingResult(chatWithDataResponse, choice); + if (typeof(T) == typeof(string)) + { + await foreach (SemanticKernel.AI.ChatCompletion.ChatMessage message in result.GetStreamingChatMessageAsync().ConfigureAwait(false)) + { + yield return (T)(object)message.Content; + } + } + + if (typeof(T) == typeof(ChatWithDataStreamingResult)) + { + yield return (T)(object)result; + } + + throw new NotSupportedException($"Type {typeof(T)} is not supported"); + } + } + } + + private Dictionary GetResponseMetadata(HttpResponseMessage response) + { + return new Dictionary() + { + { nameof(HttpResponseMessage), response }, + }; + } + private T DeserializeResponse(string body) { var response = JsonSerializer.Deserialize(body, JsonOptionsCache.ReadPermissive); diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletionWithData/ChatWithDataStreamingChoice.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletionWithData/ChatWithDataStreamingChoice.cs index af31ce8ba610..ff4ee1f6364b 100644 --- a/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletionWithData/ChatWithDataStreamingChoice.cs +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/ChatCompletionWithData/ChatWithDataStreamingChoice.cs @@ -12,4 +12,7 @@ internal sealed class ChatWithDataStreamingChoice { [JsonPropertyName("messages")] public IList Messages { get; set; } = Array.Empty(); + + [JsonPropertyName("index")] + public int Index { get; set; } = 0; } diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/TextCompletion/AzureTextCompletion.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/TextCompletion/AzureTextCompletion.cs index 99e1faded63d..ad67d76f31fe 100644 --- a/dotnet/src/Connectors/Connectors.AI.OpenAI/TextCompletion/AzureTextCompletion.cs +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/TextCompletion/AzureTextCompletion.cs @@ -20,6 +20,9 @@ namespace 
Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextCompletion; /// public sealed class AzureTextCompletion : AzureOpenAIClientBase, ITextCompletion { + /// + public IReadOnlyDictionary Attributes => this.InternalAttributes; + /// /// Creates a new AzureTextCompletion client instance using API Key auth /// @@ -76,9 +79,6 @@ public AzureTextCompletion( this.AddAttribute(IAIServiceExtensions.ModelIdKey, modelId); } - /// - public IReadOnlyDictionary Attributes => this.InternalAttributes; - /// public IAsyncEnumerable GetStreamingCompletionsAsync( string text, @@ -98,4 +98,10 @@ public Task> GetCompletionsAsync( this.LogActionDetails(); return this.InternalGetTextResultsAsync(text, requestSettings, cancellationToken); } + + /// + public IAsyncEnumerable GetStreamingContentAsync(string prompt, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default) + { + return this.InternalGetTextStreamingUpdatesAsync(prompt, requestSettings, cancellationToken); + } } diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/TextCompletion/OpenAITextCompletion.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/TextCompletion/OpenAITextCompletion.cs index 53e5760ed38c..0b44cb52d708 100644 --- a/dotnet/src/Connectors/Connectors.AI.OpenAI/TextCompletion/OpenAITextCompletion.cs +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/TextCompletion/OpenAITextCompletion.cs @@ -60,4 +60,10 @@ public Task> GetCompletionsAsync( this.LogActionDetails(); return this.InternalGetTextResultsAsync(text, requestSettings, cancellationToken); } + + /// + public IAsyncEnumerable GetStreamingContentAsync(string prompt, AIRequestSettings? 
requestSettings = null, CancellationToken cancellationToken = default) + { + return this.InternalGetTextStreamingUpdatesAsync(prompt, requestSettings, cancellationToken); + } } diff --git a/dotnet/src/IntegrationTests/Extensions/KernelSemanticFunctionExtensionsTests.cs b/dotnet/src/IntegrationTests/Extensions/KernelSemanticFunctionExtensionsTests.cs index 2e2a8003bca3..6ebccaf76888 100644 --- a/dotnet/src/IntegrationTests/Extensions/KernelSemanticFunctionExtensionsTests.cs +++ b/dotnet/src/IntegrationTests/Extensions/KernelSemanticFunctionExtensionsTests.cs @@ -82,6 +82,11 @@ IAsyncEnumerable ITextCompletion.GetStreamingCompletionsAs { throw new NotImplementedException(); // TODO } + + public IAsyncEnumerable GetStreamingContentAsync(string prompt, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } } internal sealed class RedirectTextCompletionResult : ITextResult diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionExtensions.cs index 94dcdd9c39d9..1b7d3b6cda02 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionExtensions.cs @@ -59,4 +59,19 @@ public static async Task GenerateMessageAsync( var firstChatMessage = await chatResults[0].GetChatMessageAsync(cancellationToken).ConfigureAwait(false); return firstChatMessage.Content; } + + /// + /// Get asynchronous stream of . + /// + /// Chat completion target + /// The input string. (May be a JSON for complex objects, Byte64 for binary, will depend on the connector spec). + /// Request settings for the completion API + /// The to monitor for cancellation requests. The default is . 
+ /// Streaming list of different completion streaming result updates generated by the remote model + public static IAsyncEnumerable GetStreamingContentAsync( + this IChatCompletion chatCompletion, + string input, + AIRequestSettings? requestSettings = null, + CancellationToken cancellationToken = default) + => chatCompletion.GetStreamingContentAsync(input, requestSettings, cancellationToken); } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/IChatCompletion.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/IChatCompletion.cs index 6d713703c93d..817d04dd8a25 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/IChatCompletion.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/IChatCompletion.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -42,4 +43,21 @@ IAsyncEnumerable GetStreamingChatCompletionsAsync( ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default); + + /// + /// Get streaming results for the prompt using the specified request settings. + /// Each modality may support for different types of streaming result. + /// + /// + /// Usage of this method may be more efficient if the connector has a dedicated API to return this result without extra allocations for StreamingResultChunk abstraction. + /// + /// Throws if the specified type is not the same or fail to cast + /// The prompt to complete. + /// Request settings for the completion API + /// The to monitor for cancellation requests. The default is . + /// Streaming list of different completion streaming string updates generated by the remote model + IAsyncEnumerable GetStreamingContentAsync( + string prompt, + AIRequestSettings? 
requestSettings = null, + CancellationToken cancellationToken = default); } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/StreamingContent.cs b/dotnet/src/SemanticKernel.Abstractions/AI/StreamingContent.cs new file mode 100644 index 000000000000..7c3b1d2e3733 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/StreamingContent.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.AI; + +/// +/// Represents a single update to a streaming content. +/// +public abstract class StreamingContent +{ + /// + /// In a scenario of multiple choices per request, this represents zero-based index of the choice in the streaming sequence + /// + public abstract int ChoiceIndex { get; } + + /// + /// Internal chunk object reference. (Breaking glass). + /// Each connector will have its own internal object representing the content chunk. + /// + /// + /// The usage of this property is considered "unsafe". Use it only if strictly necessary. + /// + public object? InnerContent { get; } + + /// + /// The metadata associated with the content. + /// + public Dictionary? Metadata { get; } + + /// + /// Abstract string representation of the chunk in a way it could compose/append with previous chunks. + /// + /// + /// Depending on the nature of the underlying type, this method may be more efficient than . + /// + /// String representation of the chunk + public abstract override string ToString(); + + /// + /// Abstract byte[] representation of the chunk in a way it could be composed/appended with previous chunks. + /// + /// + /// Depending on the nature of the underlying type, this method may be more efficient than . + /// + /// Byte array representation of the chunk + public abstract byte[] ToByteArray(); + + /// + /// Initializes a new instance of the class. + /// + /// Inner content object reference + /// + protected StreamingContent(object? innerContent, Dictionary? 
metadata = null) + { + this.InnerContent = innerContent; + this.Metadata = metadata ?? new(); + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/ITextCompletion.cs b/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/ITextCompletion.cs index 42d8f295ef65..40ee7bdd519d 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/ITextCompletion.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/ITextCompletion.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -30,9 +31,26 @@ Task> GetCompletionsAsync( /// The prompt to complete. /// Request settings for the completion API /// The to monitor for cancellation requests. The default is . - /// List of different completion streaming results generated by the remote model + /// Streaming list of different completion streaming results generated by the remote model IAsyncEnumerable GetStreamingCompletionsAsync( string text, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default); + + /// + /// Get streaming results for the prompt using the specified request settings. + /// Each modality may support for different types of streaming contents. + /// + /// + /// Usage of this method with value types may be more efficient if the connector supports it. + /// + /// Throws if the specified type is not the same or fail to cast + /// The prompt to complete. + /// Request settings for the completion API + /// The to monitor for cancellation requests. The default is . + /// Streaming list of different completion streaming string updates generated by the remote model + IAsyncEnumerable GetStreamingContentAsync( + string prompt, + AIRequestSettings? 
requestSettings = null, + CancellationToken cancellationToken = default); } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/TextCompletionExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/TextCompletionExtensions.cs index 4350f709d287..54e713614506 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/TextCompletionExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/TextCompletion/TextCompletionExtensions.cs @@ -82,4 +82,19 @@ public static async IAsyncEnumerable CompleteStreamsAsync(this ITextComp } } } + + /// + /// Get streaming completion results for the prompt and settings. + /// + /// Target text completion + /// The prompt to complete. + /// Request settings for the completion API + /// The to monitor for cancellation requests. The default is . + /// Streaming list of different completion streaming result updates generated by the remote model + public static IAsyncEnumerable GetStreamingContentAsync( + this ITextCompletion textCompletion, + string input, + AIRequestSettings? 
requestSettings = null, + CancellationToken cancellationToken = default) + => textCompletion.GetStreamingContentAsync(input, requestSettings, cancellationToken); } diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs index 7e16379376ef..b172370844a1 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Diagnostics.Metrics; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -97,9 +98,9 @@ public async Task InvokeAsync( { var result = await this.InvokeCoreAsync(kernel, context, requestSettings, cancellationToken).ConfigureAwait(false); - if (logger.IsEnabled(LogLevel.Trace)) + if (logger.IsEnabled(LogLevel.Information)) { - logger.LogTrace("Function succeeded. Result: {Result}", result.GetValue()); // Sensitive data, logging as trace, disabled by default + logger.LogTrace("Function succeeded."); } return result; @@ -124,6 +125,48 @@ public async Task InvokeAsync( } } + /// + /// Invoke the in streaming mode. + /// + /// The kernel + /// SK context + /// LLM completion settings (for semantic functions only) + /// The to monitor for cancellation requests. The default is . + /// A asynchronous list of streaming content chunks + public async IAsyncEnumerable InvokeStreamingAsync( + Kernel kernel, + SKContext context, + AIRequestSettings? 
requestSettings = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + using var activity = s_activitySource.StartActivity(this.Name); + ILogger logger = kernel.LoggerFactory.CreateLogger(this.Name); + + logger.LogInformation("Function streaming invoking."); + + cancellationToken.ThrowIfCancellationRequested(); + + await foreach (var genericChunk in this.InvokeCoreStreamingAsync(kernel, context, requestSettings, cancellationToken)) + { + yield return genericChunk; + } + + // Completion logging is not supported for streaming functions + } + + /// + /// Invoke as streaming the . + /// + /// The kernel. + /// SK context + /// LLM completion settings (for semantic functions only) + /// The updated context, potentially a new one if context switching is implemented. + /// The to monitor for cancellation requests. The default is . + protected abstract IAsyncEnumerable InvokeCoreStreamingAsync(Kernel kernel, + SKContext context, + AIRequestSettings? requestSettings = null, + CancellationToken cancellationToken = default); + /// /// Invoke the . /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionExtensions.cs new file mode 100644 index 000000000000..3a6512d85d89 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionExtensions.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using Microsoft.SemanticKernel.AI; +using Microsoft.SemanticKernel.Orchestration; + +namespace Microsoft.SemanticKernel.Functions; + +/// +/// Kernel function extensions class. +/// +public static class KernelFunctionExtensions +{ + /// + /// Invoke the in streaming mode. + /// + /// Target function + /// The kernel + /// SK context + /// LLM completion settings (for semantic functions only) + /// The to monitor for cancellation requests. The default is . 
+ /// A asynchronous list of streaming result chunks + public static IAsyncEnumerable InvokeStreamingAsync( + this KernelFunction function, + Kernel kernel, + SKContext? context = null, + AIRequestSettings? requestSettings = null, + CancellationToken cancellationToken = default) + { + return function.InvokeStreamingAsync(kernel, context ?? kernel.CreateNewContext(), requestSettings, cancellationToken); + } +} diff --git a/dotnet/src/SemanticKernel.Core/Functions/SKFunctionFromMethod.cs b/dotnet/src/SemanticKernel.Core/Functions/SKFunctionFromMethod.cs index 99fcf123cbdd..44278ec07d2a 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/SKFunctionFromMethod.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/SKFunctionFromMethod.cs @@ -9,6 +9,7 @@ using System.Globalization; using System.Linq; using System.Reflection; +using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; using System.Text.Json; using System.Text.RegularExpressions; @@ -147,6 +148,37 @@ protected override async Task InvokeCoreAsync( } } + /// + protected override async IAsyncEnumerable InvokeCoreStreamingAsync( + Kernel kernel, + SKContext context, + AIRequestSettings? 
requestSettings = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // We don't invoke the hook here as the InvokeCoreAsync will do that for us + var functionResult = await this.InvokeCoreAsync(kernel, context, requestSettings, cancellationToken).ConfigureAwait(false); + if (functionResult.Value is T) + { + yield return (T)functionResult.Value; + yield break; + } + + // Supports the following provided T types for Method streaming + if (typeof(T) == typeof(StreamingContent) || + typeof(T) == typeof(StreamingMethodContent)) + { + if (functionResult.Value is not null) + { + yield return (T)(object)new StreamingMethodContent(functionResult.Value); + } + yield break; + } + + throw new NotSupportedException($"Streaming function {this.Name} does not support type {typeof(T)}"); + + // We don't invoke the hook here as the InvokeCoreAsync will do that for us + } + private FunctionInvokingEventArgs CallFunctionInvoking(Kernel kernel, SKContext context) { var eventArgs = new FunctionInvokingEventArgs(this.GetMetadata(), context); @@ -154,8 +186,9 @@ private FunctionInvokingEventArgs CallFunctionInvoking(Kernel kernel, SKContext return eventArgs; } - private (FunctionInvokedEventArgs, FunctionResult) CallFunctionInvoked(Kernel kernel, SKContext context, FunctionResult result) + private (FunctionInvokedEventArgs, FunctionResult) CallFunctionInvoked(Kernel kernel, SKContext context, FunctionResult? 
result = null) { + result ??= new FunctionResult(this.Name, context); var eventArgs = new FunctionInvokedEventArgs(this.GetMetadata(), result); if (kernel.OnFunctionInvoked(eventArgs)) { @@ -307,7 +340,7 @@ private static bool IsAsyncMethod(MethodInfo method) if (t.IsGenericType) { t = t.GetGenericTypeDefinition(); - if (t == typeof(Task<>) || t == typeof(ValueTask<>)) + if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) { return true; } diff --git a/dotnet/src/SemanticKernel.Core/Functions/SKFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/SKFunctionFromPrompt.cs index 4726c0e52d8e..405477937282 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/SKFunctionFromPrompt.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/SKFunctionFromPrompt.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -138,19 +139,16 @@ protected override async Task InvokeCoreAsync( AIRequestSettings? 
requestSettings = null, CancellationToken cancellationToken = default) { - this.AddDefaultValues(context.Variables); - try { - string renderedPrompt = await this._promptTemplate.RenderAsync(kernel, context, cancellationToken).ConfigureAwait(false); - - var serviceSelector = kernel.ServiceSelector; - (var textCompletion, var defaultRequestSettings) = serviceSelector.SelectAIService(kernel, context, this); - Verify.NotNull(textCompletion); - - var invokingEventArgs = this.CallFunctionInvoking(kernel, context, renderedPrompt); + var (invokingEventArgs, renderedPrompt, textCompletion, defaultRequestSettings) = await this.PrepareInvokeAsync(kernel, context, requestSettings, cancellationToken).ConfigureAwait(false); if (invokingEventArgs.IsSkipRequested || invokingEventArgs.CancelToken.IsCancellationRequested) { + if (this._logger.IsEnabled(LogLevel.Trace)) + { + this._logger.LogTrace("Function {Name} canceled or skipped prior to invocation.", this.Name); + } + return new FunctionResult(this.Name, context) { IsCancellationRequested = invokingEventArgs.CancelToken.IsCancellationRequested, @@ -158,8 +156,6 @@ protected override async Task InvokeCoreAsync( }; } - renderedPrompt = this.GetPromptFromEventArgsMetadataOrDefault(invokingEventArgs, renderedPrompt); - IReadOnlyList completionResults = await textCompletion.GetCompletionsAsync(renderedPrompt, requestSettings ?? 
defaultRequestSettings, cancellationToken).ConfigureAwait(false); string completion = await GetCompletionsResultContentAsync(completionResults, cancellationToken).ConfigureAwait(false); @@ -171,7 +167,6 @@ protected override async Task InvokeCoreAsync( var result = new FunctionResult(this.Name, context, completion); result.Metadata.Add(AIFunctionResultExtensions.ModelResultsMetadataKey, modelResults); - result.Metadata.Add(SKEventArgsExtensions.RenderedPromptMetadataKey, renderedPrompt); (var invokedEventArgs, result) = this.CallFunctionInvoked(kernel, context, result, renderedPrompt); result.IsCancellationRequested = invokedEventArgs.CancelToken.IsCancellationRequested; @@ -181,11 +176,33 @@ protected override async Task InvokeCoreAsync( } catch (Exception ex) when (!ex.IsCriticalException()) { - this._logger?.LogError(ex, "Semantic function {Name} execution failed with error {Error}", this.Name, ex.Message); + this._logger?.LogError(ex, "Prompt function {Name} execution failed with error {Error}", this.Name, ex.Message); throw; } } + protected override async IAsyncEnumerable InvokeCoreStreamingAsync( + Kernel kernel, + SKContext context, + AIRequestSettings? requestSettings = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var (invokingEventArgs, renderedPrompt, textCompletion, defaultRequestSettings) = await this.PrepareInvokeAsync(kernel, context, requestSettings, cancellationToken).ConfigureAwait(false); + if (invokingEventArgs.IsSkipRequested || invokingEventArgs.CancelToken.IsCancellationRequested) + { + yield break; + } + + await foreach (T genericChunk in textCompletion.GetStreamingContentAsync(renderedPrompt, requestSettings ?? defaultRequestSettings, cancellationToken)) + { + cancellationToken.ThrowIfCancellationRequested(); + yield return genericChunk; + } + + // Invoked is not supported for streaming + // There is no post cancellation check to override the result as the stream data was already sent. 
+ } + /// /// JSON serialized string representation of the function. /// @@ -257,8 +274,9 @@ private FunctionInvokingEventArgs CallFunctionInvoking(Kernel kernel, SKContext /// Execution context /// Current function result /// Prompt used by the function - private (FunctionInvokedEventArgs, FunctionResult) CallFunctionInvoked(Kernel kernel, SKContext context, FunctionResult result, string prompt) + private (FunctionInvokedEventArgs, FunctionResult) CallFunctionInvoked(Kernel kernel, SKContext context, FunctionResult? result, string prompt) { + result ??= new FunctionResult(this.Name, context); result.Metadata[SKEventArgsExtensions.RenderedPromptMetadataKey] = prompt; var eventArgs = new FunctionInvokedEventArgs(this.GetMetadata(), result); @@ -276,24 +294,26 @@ private FunctionInvokingEventArgs CallFunctionInvoking(Kernel kernel, SKContext return (eventArgs, result); } - /// - /// Try to get the prompt from the event args metadata. - /// - /// Function invoking event args - /// Default prompt if none is found in metadata - /// - private string GetPromptFromEventArgsMetadataOrDefault(FunctionInvokingEventArgs eventArgs, string defaultPrompt) + /// Create a random, valid function name. + private static string RandomFunctionName() => $"func{Guid.NewGuid():N}"; + + private async Task<(FunctionInvokingEventArgs InvokingEventArgs, string RenderedPrompt, ITextCompletion TextCompletion, AIRequestSettings? DefaultRequestSettings)> PrepareInvokeAsync( + Kernel kernel, + SKContext context, + AIRequestSettings? 
requestSettings, + CancellationToken cancellationToken) { - if (!eventArgs.Metadata.TryGetValue(SKEventArgsExtensions.RenderedPromptMetadataKey, out var renderedPromptFromMetadata)) - { - return defaultPrompt; - } + this.AddDefaultValues(context.Variables); + string renderedPrompt = await this._promptTemplate.RenderAsync(kernel, context, cancellationToken).ConfigureAwait(false); + + var serviceSelector = kernel.ServiceSelector; + (var textCompletion, var defaultRequestSettings) = serviceSelector.SelectAIService(kernel, context, this); + Verify.NotNull(textCompletion); - // If prompt key exists and was modified to null default to an empty string - return renderedPromptFromMetadata?.ToString() ?? string.Empty; + var invokingEventArgs = this.CallFunctionInvoking(kernel, context, renderedPrompt); + + return (invokingEventArgs, renderedPrompt, textCompletion, defaultRequestSettings); } - /// Create a random, valid function name. - private static string RandomFunctionName() => $"func{Guid.NewGuid():N}"; #endregion } diff --git a/dotnet/src/SemanticKernel.Core/Functions/StreamingMethodContent.cs b/dotnet/src/SemanticKernel.Core/Functions/StreamingMethodContent.cs new file mode 100644 index 000000000000..e9425fe3fb7a --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Functions/StreamingMethodContent.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; +using Microsoft.SemanticKernel.AI; + +#pragma warning disable IDE0130 +// ReSharper disable once CheckNamespace - Using the main namespace +namespace Microsoft.SemanticKernel; +#pragma warning restore IDE0130 + +/// +/// This class represents the content of a streaming chunk generated by KernelFunctionsFromMethod. +/// +public sealed class StreamingMethodContent : StreamingContent +{ + /// + public override int ChoiceIndex => 0; + + /// + /// Content of a streaming chunk result of a KernelFunctionFromMethod. 
+ /// + public object Content { get; } + + /// + public override byte[] ToByteArray() + { + if (this.Content is byte[]) + { + return (byte[])this.Content; + } + + // By default if a native value is not Byte[] we output the UTF8 string representation of the value + return Encoding.UTF8.GetBytes(this.Content?.ToString()); + } + + /// + public override string ToString() + { + return this.Content.ToString(); + } + + /// + /// Initializes a new instance of the class. + /// + /// Underlying object that represents the chunk content + public StreamingMethodContent(object innerContent) : base(innerContent) + { + this.Content = innerContent; + } +} diff --git a/dotnet/src/SemanticKernel.Core/KernelExtensions.cs b/dotnet/src/SemanticKernel.Core/KernelExtensions.cs index 97fbea9703f1..af6c9ff87cd2 100644 --- a/dotnet/src/SemanticKernel.Core/KernelExtensions.cs +++ b/dotnet/src/SemanticKernel.Core/KernelExtensions.cs @@ -447,4 +447,50 @@ public static Task RunAsync( return kernel.RunAsync(function, variables ?? new(), cancellationToken); } #endregion + + #region RunStreamingAsync + /// + /// Run a function in streaming mode. + /// + /// The target kernel + /// Target function to run + /// Input to process + /// The to monitor for cancellation requests. + /// Streaming result of the function + public static IAsyncEnumerable RunStreamingAsync(this Kernel kernel, KernelFunction function, ContextVariables? variables = null, CancellationToken cancellationToken = default) + => function.InvokeStreamingAsync(kernel, kernel.CreateNewContext(variables), null, cancellationToken); + + /// + /// Run a function in streaming mode. + /// + /// Target kernel + /// Target function to run + /// Input to process + /// Cancellation token + /// Streaming result of the function + public static IAsyncEnumerable RunStreamingAsync(this Kernel kernel, KernelFunction function, ContextVariables? 
variables = null, CancellationToken cancellationToken = default) + => kernel.RunStreamingAsync(function, variables ?? new ContextVariables(), CancellationToken.None); + + /// + /// Run a function in streaming mode. + /// + /// The target kernel + /// Target function to run + /// Input to process + /// The to monitor for cancellation requests. + /// Streaming result of the function + public static IAsyncEnumerable RunStreamingAsync(this Kernel kernel, KernelFunction function, string input, CancellationToken cancellationToken = default) + => function.InvokeStreamingAsync(kernel, kernel.CreateNewContext(new ContextVariables(input)), null, cancellationToken); + + /// + /// Run a function in streaming mode. + /// + /// Target kernel + /// Target function to run + /// Input to process + /// Cancellation token + /// Streaming result of the function + public static IAsyncEnumerable RunStreamingAsync(this Kernel kernel, KernelFunction function, string input, CancellationToken cancellationToken = default) + => kernel.RunStreamingAsync(function, input, CancellationToken.None); + #endregion } diff --git a/dotnet/src/SemanticKernel.Core/Planning/Plan.cs b/dotnet/src/SemanticKernel.Core/Planning/Plan.cs index c15ce7a1feb4..676085d3e7c9 100644 --- a/dotnet/src/SemanticKernel.Core/Planning/Plan.cs +++ b/dotnet/src/SemanticKernel.Core/Planning/Plan.cs @@ -330,6 +330,17 @@ protected override async Task InvokeCoreAsync( return result; } + /// + protected override IAsyncEnumerable InvokeCoreStreamingAsync( + Kernel kernel, + SKContext context, + AIRequestSettings? 
requestSettings = null, + CancellationToken cancellationToken = default) + { + // Implementation will be added in future streaming feature iteration + throw new NotSupportedException("Streaming currently is not supported for plans"); + } + #endregion ISKFunction implementation /// @@ -425,8 +436,9 @@ private FunctionInvokingEventArgs CallFunctionInvoking(Kernel kernel, SKContext return eventArgs; } - private FunctionInvokedEventArgs CallFunctionInvoked(Kernel kernel, SKContext context, FunctionResult result) + private FunctionInvokedEventArgs CallFunctionInvoked(Kernel kernel, SKContext context, FunctionResult? result = null) { + result ??= new FunctionResult(this.Name, context); var eventArgs = new FunctionInvokedEventArgs(this.GetMetadata(), result); if (kernel.OnFunctionInvoked(eventArgs)) { diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedIAIServiceConfigurationProviderTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedIAIServiceConfigurationProviderTests.cs index e4105f1c5c79..cb59b0827eff 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedIAIServiceConfigurationProviderTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedIAIServiceConfigurationProviderTests.cs @@ -212,6 +212,11 @@ public Task> GetCompletionsAsync(string text, AIReque throw new NotImplementedException(); } + public IAsyncEnumerable GetStreamingContentAsync(string prompt, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + public IAsyncEnumerable GetStreamingCompletionsAsync(string text, AIRequestSettings? 
requestSettings = null, CancellationToken cancellationToken = default) { throw new NotImplementedException(); diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/SemanticFunctionTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/SemanticFunctionTests.cs index 06d03db00f2e..d40426f21ed5 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/SemanticFunctionTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/SemanticFunctionTests.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -10,6 +9,7 @@ using Microsoft.SemanticKernel.AI; using Microsoft.SemanticKernel.AI.TextCompletion; using Microsoft.SemanticKernel.Connectors.AI.OpenAI; +using Microsoft.SemanticKernel.Orchestration; using Microsoft.SemanticKernel.TemplateEngine; using Moq; using Xunit; @@ -346,6 +346,26 @@ public async Task RunAsyncChangeVariableInvokedHandlerAsync() Assert.Equal(newInput, originalInput); } + [Fact] + public async Task InvokeStreamingAsyncCallsConnectorStreamingApiAsync() + { + // Arrange + var mockTextCompletion = this.SetupStreamingMocks(new TestStreamingContent()); + var kernel = new KernelBuilder().WithAIService(null, mockTextCompletion.Object).Build(); + var prompt = "Write a simple phrase about UnitTests {{$input}}"; + var sut = SKFunctionFactory.CreateFromPrompt(prompt); + var variables = new ContextVariables("importance"); + var context = kernel.CreateNewContext(variables); + + // Act + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, context)) + { + } + + // Assert + mockTextCompletion.Verify(m => m.GetStreamingContentAsync(It.IsIn("Write a simple phrase about UnitTests importance"), It.IsAny(), It.IsAny()), Times.Exactly(1)); + } + private (Mock textResultMock, Mock textCompletionMock) SetupMocks(string? 
completionResult = null) { var mockTextResult = new Mock(); @@ -353,12 +373,45 @@ public async Task RunAsyncChangeVariableInvokedHandlerAsync() var mockTextCompletion = new Mock(); mockTextCompletion.Setup(m => m.GetCompletionsAsync(It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(new List { mockTextResult.Object }); - return (mockTextResult, mockTextCompletion); } - private static MethodInfo Method(Delegate method) + private Mock SetupStreamingMocks(T completionResult) { - return method.Method; + var mockTextCompletion = new Mock(); + mockTextCompletion.Setup(m => m.GetStreamingContentAsync(It.IsAny(), It.IsAny(), It.IsAny())).Returns(this.ToAsyncEnumerable(new List { completionResult })); + + return mockTextCompletion; + } + + private sealed class TestStreamingContent : StreamingContent + { + public TestStreamingContent() : base(null) + { + } + + public override int ChoiceIndex => 0; + + public override byte[] ToByteArray() + { + return Array.Empty(); + } + + public override string ToString() + { + return string.Empty; + } + } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously +#pragma warning disable IDE1006 // Naming Styles + private async IAsyncEnumerable ToAsyncEnumerable(IEnumerable enumeration) +#pragma warning restore IDE1006 // Naming Styles +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + { + foreach (var enumerationItem in enumeration) + { + yield return enumerationItem; + } } } diff --git a/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs b/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs index 887da4fbfff3..415ab3016117 100644 --- a/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs @@ -1,16 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. 
using System; -using System.Collections.Generic; using System.ComponentModel; using System.Globalization; using System.Linq; -using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.AI; using Microsoft.SemanticKernel.AI.TextCompletion; using Microsoft.SemanticKernel.Events; using Microsoft.SemanticKernel.Http; @@ -122,12 +119,34 @@ public async Task RunAsyncHandlesPreInvocationAsync() } [Fact] - public async Task RunAsyncHandlesPreInvocationWasCancelledAsync() + public async Task RunStreamingAsyncHandlesPreInvocationAsync() + { + // Arrange + var sut = new KernelBuilder().Build(); + int functionInvocations = 0; + var function = SKFunctionFactory.CreateFromMethod(() => functionInvocations++); + + var handlerInvocations = 0; + sut.FunctionInvoking += (object? sender, FunctionInvokingEventArgs e) => + { + handlerInvocations++; + }; + + // Act + await foreach (var chunk in sut.RunStreamingAsync(function)) { } + + // Assert + Assert.Equal(1, functionInvocations); + Assert.Equal(1, handlerInvocations); + } + + [Fact] + public async Task RunStreamingAsyncHandlesPreInvocationWasCancelledAsync() { // Arrange var sut = new KernelBuilder().Build(); int functionInvocations = 0; - KernelFunction function = SKFunctionFactory.CreateFromMethod(() => functionInvocations++); + var function = SKFunctionFactory.CreateFromMethod(() => functionInvocations++); var handlerInvocations = 0; sut.FunctionInvoking += (object? 
sender, FunctionInvokingEventArgs e) => @@ -137,16 +156,85 @@ public async Task RunAsyncHandlesPreInvocationWasCancelledAsync() }; // Act - var result = await sut.RunAsync(function); + int chunksCount = 0; + await foreach (var chunk in sut.RunStreamingAsync(function)) + { + chunksCount++; + } // Assert Assert.Equal(1, handlerInvocations); Assert.Equal(0, functionInvocations); - Assert.NotNull(result); + Assert.Equal(0, chunksCount); + } + + [Fact] + public async Task RunStreamingAsyncPreInvocationCancelationDontTriggerInvokedHandlerAsync() + { + // Arrange + var sut = new KernelBuilder().Build(); + var functions = sut.ImportPluginFromObject(); + + var invoked = 0; + sut.FunctionInvoking += (object? sender, FunctionInvokingEventArgs e) => + { + e.Cancel(); + }; + + sut.FunctionInvoked += (object? sender, FunctionInvokedEventArgs e) => + { + invoked++; + }; + + // Act + await foreach (var chunk in sut.RunStreamingAsync(functions["GetAnyValue"])) + { + } + + // Assert + Assert.Equal(0, invoked); } [Fact] - public async Task RunAsyncHandlesPreInvocationCancelationDontRunSubsequentFunctionsInThePipelineAsync() + public async Task RunStreamingAsyncPreInvocationSkipDontTriggerInvokedHandlerAsync() + { + // Arrange + var sut = new KernelBuilder().Build(); + int funcInvocations = 0; + var function = SKFunctionFactory.CreateFromMethod(() => funcInvocations++, functionName: "func1"); + + var invoked = 0; + var invoking = 0; + string invokedFunction = string.Empty; + + sut.FunctionInvoking += (object? sender, FunctionInvokingEventArgs e) => + { + invoking++; + if (e.FunctionMetadata.Name == "func1") + { + e.Skip(); + } + }; + + sut.FunctionInvoked += (object? 
sender, FunctionInvokedEventArgs e) => + { + invokedFunction = e.FunctionMetadata.Name; + invoked++; + }; + + // Act + await foreach (var chunk in sut.RunStreamingAsync(function)) + { + } + + // Assert + Assert.Equal(1, invoking); + Assert.Equal(0, invoked); + Assert.Equal(0, funcInvocations); + } + + [Fact] + public async Task RunStreamingAsyncHandlesPostInvocationAsync() { // Arrange var sut = new KernelBuilder().Build(); @@ -154,6 +242,30 @@ public async Task RunAsyncHandlesPreInvocationCancelationDontRunSubsequentFuncti var function = SKFunctionFactory.CreateFromMethod(() => functionInvocations++); int handlerInvocations = 0; + sut.FunctionInvoked += (object? sender, FunctionInvokedEventArgs e) => + { + handlerInvocations++; + }; + + // Act + await foreach (var chunk in sut.RunStreamingAsync(function)) + { + } + + // Assert + Assert.Equal(1, functionInvocations); + Assert.Equal(1, handlerInvocations); + } + + [Fact] + public async Task RunAsyncHandlesPreInvocationWasCancelledAsync() + { + // Arrange + var sut = new KernelBuilder().Build(); + int functionInvocations = 0; + var function = SKFunctionFactory.CreateFromMethod(() => functionInvocations++); + + var handlerInvocations = 0; sut.FunctionInvoking += (object? 
sender, FunctionInvokingEventArgs e) => { handlerInvocations++; @@ -166,6 +278,7 @@ public async Task RunAsyncHandlesPreInvocationCancelationDontRunSubsequentFuncti // Assert Assert.Equal(1, handlerInvocations); Assert.Equal(0, functionInvocations); + Assert.NotNull(result); } [Fact] @@ -199,7 +312,7 @@ public async Task RunAsyncPreInvocationSkipDontTriggerInvokedHandlerAsync() // Arrange var sut = new KernelBuilder().Build(); int funcInvocations = 0; - KernelFunction function = SKFunctionFactory.CreateFromMethod(() => funcInvocations++, functionName: "func1"); + var function = SKFunctionFactory.CreateFromMethod(() => funcInvocations++, functionName: "func1"); var invoked = 0; var invoking = 0; @@ -235,7 +348,7 @@ public async Task RunAsyncHandlesPostInvocationAsync() // Arrange var sut = new KernelBuilder().Build(); int functionInvocations = 0; - KernelFunction function = SKFunctionFactory.CreateFromMethod(() => functionInvocations++); + var function = SKFunctionFactory.CreateFromMethod(() => functionInvocations++); int handlerInvocations = 0; sut.FunctionInvoked += (object? 
sender, FunctionInvokedEventArgs e) => @@ -255,7 +368,7 @@ public async Task RunAsyncHandlesPostInvocationAsync() public async Task RunAsyncChangeVariableInvokingHandlerAsync() { var sut = new KernelBuilder().Build(); - KernelFunction function = SKFunctionFactory.CreateFromMethod(() => { }); + var function = SKFunctionFactory.CreateFromMethod(() => { }); var originalInput = "Importance"; var newInput = "Problems"; @@ -276,7 +389,7 @@ public async Task RunAsyncChangeVariableInvokingHandlerAsync() public async Task RunAsyncChangeVariableInvokedHandlerAsync() { var sut = new KernelBuilder().Build(); - KernelFunction function = SKFunctionFactory.CreateFromMethod(() => { }); + var function = SKFunctionFactory.CreateFromMethod(() => { }); var originalInput = "Importance"; var newInput = "Problems"; @@ -477,20 +590,4 @@ public async Task ReadFunctionCollectionAsync(SKContext context, Kern return context; } } - - private (Mock textResultMock, Mock textCompletionMock) SetupMocks(string? completionResult = null) - { - var mockTextResult = new Mock(); - mockTextResult.Setup(m => m.GetCompletionAsync(It.IsAny())).ReturnsAsync(completionResult ?? "LLM Result about UnitTests"); - - var mockTextCompletion = new Mock(); - mockTextCompletion.Setup(m => m.GetCompletionsAsync(It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(new List { mockTextResult.Object }); - - return (mockTextResult, mockTextCompletion); - } - - private static MethodInfo Method(Delegate method) - { - return method.Method; - } }