From 77fd7be4b5d7632b641d007be94dac697e568af5 Mon Sep 17 00:00:00 2001 From: carlodek <56030624+carlodek@users.noreply.github.com> Date: Sun, 1 Dec 2024 12:45:23 +0100 Subject: [PATCH] Implement Response Streaming (#726) ## Motivation and Context (Why the change? What's the scenario?) Add option to stream Ask result tokens without waiting for the full answer to be ready. ## High level description (Approach, Design) - New `stream` boolean option for the `Ask` API, false by default. When true, answer tokens are streamed as soon as they are generated by LLMs. - New `MemoryAnswer.StreamState` enum property: `Error`, `Reset`, `Append`, `Last`. - If moderation is enabled, the content is validated at the end. In case of moderation failure, the service returns an answer with `StreamState` = `Reset` and the new content to show to the end user. - Streaming uses SSE message format. - By default, SSE streams end with a `[DONE]` token. This can be disabled via KM settings. - SSE payload is optimized, returning `RelevantSources` only in the first SSE message. --------- Co-authored-by: Carlo Co-authored-by: Devis Lucato Co-authored-by: Devis Lucato --- .github/_typos.toml | 1 + KernelMemory.sln.DotSettings | 1 + .../SemanticKernelPlugin/MemoryPlugin.cs | 1 + clients/dotnet/WebClient/MemoryWebClient.cs | 28 +++- examples/001-dotnet-WebClient/Program.cs | 46 ++++-- examples/002-dotnet-Serverless/Program.cs | 46 ++++-- .../AzureAISearch.TestApplication/Program.cs | 3 +- .../AzureAISearch/AzureAISearchMemory.cs | 11 +- service/Abstractions/Abstractions.csproj | 3 +- service/Abstractions/HTTP/SSE.cs | 68 ++++++++ service/Abstractions/IKernelMemory.cs | 9 +- .../Abstractions/KernelMemoryExtensions.cs | 43 +++++ service/Abstractions/Models/MemoryAnswer.cs | 37 +++-- service/Abstractions/Models/MemoryQuery.cs | 4 + service/Abstractions/Models/SearchResult.cs | 6 +- service/Abstractions/Models/StreamStates.cs | 58 +++++++ service/Abstractions/Search/ISearchClient.cs | 18 +++ service/Abstractions/Search/SearchOptions.cs | 29 ++++ service/Core/Configuration/ServiceConfig.cs | 5 + service/Core/MemoryServerless.cs | 30 +++- service/Core/MemoryService.cs | 30 +++- service/Core/Search/AnswerGenerator.cs | 83 +++++----- service/Core/Search/SearchClient.cs | 129 ++++++++++++--- service/Core/Search/SearchClientResult.cs | 66 +++++++- service/Service.AspNetCore/WebAPIEndpoints.cs | 88 +++++++++-- service/Service/appsettings.json | 2 + .../Abstractions.UnitTests/Http/SSETest.cs | 149 ++++++++++++++++++ 27 files changed, 841 insertions(+), 153 deletions(-) create mode 100644 service/Abstractions/HTTP/SSE.cs create mode 100644 service/Abstractions/Models/StreamStates.cs create mode 100644 service/Abstractions/Search/SearchOptions.cs create mode 100644 service/tests/Abstractions.UnitTests/Http/SSETest.cs diff --git a/.github/_typos.toml b/.github/_typos.toml index 5fa7634f9..000d278e9 100644 --- a/.github/_typos.toml +++ b/.github/_typos.toml @@ -15,6 +15,7 @@ extend-exclude = [ "encoder.json", "appsettings.development.json", "appsettings.Development.json", + "appsettings.*.json.*", "AzureAISearchFilteringTest.cs", "KernelMemory.sln.DotSettings" ] diff --git a/KernelMemory.sln.DotSettings b/KernelMemory.sln.DotSettings index 4d9e859e9..56b15a2c0 100644 --- a/KernelMemory.sln.DotSettings +++ b/KernelMemory.sln.DotSettings @@ -120,6 +120,7 @@ SHA SK SKHTTP + SSE SSL TTL UI diff --git a/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs b/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs index 4446459ab..a4db5b6a0 100644 --- a/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs +++ b/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs @@ -404,6 +404,7 @@ public async Task AskAsync( MemoryAnswer answer = await this._memory.AskAsync( question: question, index: index ?? this._defaultIndex, + options: new SearchOptions { Stream = false }, filter: TagsToMemoryFilter(tags ?? this._defaultRetrievalTags), minRelevance: minRelevance, cancellationToken: cancellationToken).ConfigureAwait(false); diff --git a/clients/dotnet/WebClient/MemoryWebClient.cs b/clients/dotnet/WebClient/MemoryWebClient.cs index a89b25a55..f767881fd 100644 --- a/clients/dotnet/WebClient/MemoryWebClient.cs +++ b/clients/dotnet/WebClient/MemoryWebClient.cs @@ -7,11 +7,13 @@ using System.Linq; using System.Net; using System.Net.Http; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.KernelMemory.Context; +using Microsoft.KernelMemory.HTTP; using Microsoft.KernelMemory.Internals; namespace Microsoft.KernelMemory; @@ -337,28 +339,30 @@ public async Task SearchAsync( } /// - public async Task AskAsync( + public async IAsyncEnumerable AskStreamingAsync( string question, string? index = null, MemoryFilter? filter = null, ICollection? filters = null, double minRelevance = 0, + SearchOptions? options = null, IContext? context = null, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { if (filter != null) { - if (filters == null) { filters = []; } - + filters ??= []; filters.Add(filter); } + var useStreaming = options?.Stream ?? false; MemoryQuery request = new() { Index = index, Question = question, Filters = (filters is { Count: > 0 }) ? filters.ToList() : [], MinRelevance = minRelevance, + Stream = useStreaming, ContextArguments = (context?.Arguments ?? new Dictionary()).ToDictionary(), }; using StringContent content = new(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json"); @@ -367,8 +371,20 @@ public async Task AskAsync( HttpResponseMessage response = await this._client.PostAsync(url, content, cancellationToken).ConfigureAwait(false); response.EnsureSuccessStatusCode(); - var json = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer(); + if (useStreaming) + { + Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + IAsyncEnumerable answers = SSE.ParseStreamAsync(stream, cancellationToken); + await foreach (MemoryAnswer answer in answers.ConfigureAwait(false)) + { + yield return answer; + } + } + else + { + var json = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + yield return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer(); + } } #region private diff --git a/examples/001-dotnet-WebClient/Program.cs b/examples/001-dotnet-WebClient/Program.cs index 0eec905c4..f1e6d4db4 100644 --- a/examples/001-dotnet-WebClient/Program.cs +++ b/examples/001-dotnet-WebClient/Program.cs @@ -15,7 +15,7 @@ * without extracting memories. */ internal static class Program { - private static MemoryWebClient? s_memory; + private static MemoryWebClient s_memory = null!; private static readonly List s_toDelete = []; // Change this to True and configure Azure Document Intelligence to test OCR and support for images @@ -55,8 +55,8 @@ public static async Task Main() // === RETRIEVAL ========= // ======================= - await AskSimpleQuestion(); - await AskSimpleQuestionAndShowSources(); + await AskSimpleQuestionStreamingTheAnswer(); + await AskSimpleQuestionStreamingAndShowSources(); await AskQuestionAboutImageContent(); await AskQuestionUsingFilter(); await AskQuestionsFilteringByUser(); @@ -249,16 +249,25 @@ private static async Task StoreJson() // ======================= // Question without filters - private static async Task AskSimpleQuestion() + private static async Task AskSimpleQuestionStreamingTheAnswer() { var question = "What's E = m*c^2?"; Console.WriteLine($"Question: {question}"); Console.WriteLine($"Expected result: formula explanation using the information loaded"); - var answer = await s_memory.AskAsync(question, minRelevance: 0.6); - Console.WriteLine($"\nAnswer: {answer.Result}"); + Console.Write("\nAnswer: "); + var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.6, + options: new SearchOptions { Stream = true }); - Console.WriteLine("\n====================================\n"); + await foreach (var answer in answerStream) + { + // Print token received by LLM + Console.Write(answer.Result); + // Slow down the stream for demo purpose + await Task.Delay(25); + } + + Console.WriteLine("\n\n====================================\n"); /* OUTPUT @@ -275,17 +284,32 @@ due to the speed of light being a very large number when squared. This concept i } // Another question without filters and show sources - private static async Task AskSimpleQuestionAndShowSources() + private static async Task AskSimpleQuestionStreamingAndShowSources() { var question = "What's Kernel Memory?"; Console.WriteLine($"Question: {question}"); Console.WriteLine($"Expected result: it should explain what KM project is (not generic kernel memory)"); - var answer = await s_memory.AskAsync(question, minRelevance: 0.5); - Console.WriteLine($"\nAnswer: {answer.Result}\n\n Sources:\n"); + Console.Write("\nAnswer: "); + var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.5, + options: new SearchOptions { Stream = true }); + + List sources = []; + await foreach (var answer in answerStream) + { + // Print token received by LLM + Console.Write(answer.Result); + + // Collect sources + sources.AddRange(answer.RelevantSources); + + // Slow down the stream for demo purpose + await Task.Delay(5); + } // Show sources / citations - foreach (var x in answer.RelevantSources) + Console.WriteLine("\n\nSources:\n"); + foreach (var x in sources) { Console.WriteLine(x.SourceUrl != null ? $" - {x.SourceUrl} [{x.Partitions.First().LastUpdate:D}]" diff --git a/examples/002-dotnet-Serverless/Program.cs b/examples/002-dotnet-Serverless/Program.cs index ee5f680d6..26a9fb5c8 100644 --- a/examples/002-dotnet-Serverless/Program.cs +++ b/examples/002-dotnet-Serverless/Program.cs @@ -13,7 +13,7 @@ #pragma warning disable CS8602 // by design public static class Program { - private static MemoryServerless? s_memory; + private static MemoryServerless s_memory = null!; private static readonly List s_toDelete = []; // Remember to configure Azure Document Intelligence to test OCR and support for images @@ -107,8 +107,8 @@ public static async Task Main() // === RETRIEVAL ========= // ======================= - await AskSimpleQuestion(); - await AskSimpleQuestionAndShowSources(); + await AskSimpleQuestionStreamingTheAnswer(); + await AskSimpleQuestionStreamingAndShowSources(); await AskQuestionAboutImageContent(); await AskQuestionUsingFilter(); await AskQuestionsFilteringByUser(); @@ -303,16 +303,25 @@ private static async Task StoreJson() // ======================= // Question without filters - private static async Task AskSimpleQuestion() + private static async Task AskSimpleQuestionStreamingTheAnswer() { var question = "What's E = m*c^2?"; Console.WriteLine($"Question: {question}"); Console.WriteLine($"Expected result: formula explanation using the information loaded"); - var answer = await s_memory.AskAsync(question, minRelevance: 0.6); - Console.WriteLine($"\nAnswer: {answer.Result}"); + Console.Write("\nAnswer: "); + var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.6, + options: new SearchOptions { Stream = true }); - Console.WriteLine("\n====================================\n"); + await foreach (var answer in answerStream) + { + // Print token received by LLM + Console.Write(answer.Result); + // Slow down the stream for demo purpose + await Task.Delay(25); + } + + Console.WriteLine("\n\n====================================\n"); /* OUTPUT @@ -329,17 +338,32 @@ due to the speed of light being a very large number when squared. This concept i } // Another question without filters and show sources - private static async Task AskSimpleQuestionAndShowSources() + private static async Task AskSimpleQuestionStreamingAndShowSources() { var question = "What's Kernel Memory?"; Console.WriteLine($"Question: {question}"); Console.WriteLine($"Expected result: it should explain what KM project is (not generic kernel memory)"); - var answer = await s_memory.AskAsync(question, minRelevance: 0.5); - Console.WriteLine($"\nAnswer: {answer.Result}\n\n Sources:\n"); + Console.Write("\nAnswer: "); + var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.5, + options: new SearchOptions { Stream = true }); + + List sources = []; + await foreach (var answer in answerStream) + { + // Print token received by LLM + Console.Write(answer.Result); + + // Collect sources + sources.AddRange(answer.RelevantSources); + + // Slow down the stream for demo purpose + await Task.Delay(5); + } // Show sources / citations - foreach (var x in answer.RelevantSources) + Console.WriteLine("\n\nSources:\n"); + foreach (var x in sources) { Console.WriteLine(x.SourceUrl != null ? $" - {x.SourceUrl} [{x.Partitions.First().LastUpdate:D}]" diff --git a/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs b/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs index 93618ad48..3aee0565b 100644 --- a/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs +++ b/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs @@ -8,6 +8,7 @@ using Microsoft.KernelMemory; using Microsoft.KernelMemory.MemoryDb.AzureAISearch; using Microsoft.KernelMemory.MemoryStorage; +using AISearchOptions = Azure.Search.Documents.SearchOptions; namespace Microsoft.AzureAISearch.TestApplication; @@ -246,7 +247,7 @@ private static async Task> SearchByFieldValueAsync( fieldValue1 = fieldValue1.Replace("'", "''", StringComparison.Ordinal); fieldValue2 = fieldValue2.Replace("'", "''", StringComparison.Ordinal); - SearchOptions options = new() + AISearchOptions options = new() { Filter = fieldIsCollection ? $"{fieldName}/any(s: s eq '{fieldValue1}') and {fieldName}/any(s: s eq '{fieldValue2}')" diff --git a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs index fb838e862..b8ad546c6 100644 --- a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs +++ b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs @@ -19,6 +19,7 @@ using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.DocumentStorage; using Microsoft.KernelMemory.MemoryStorage; +using AISearchOptions = Azure.Search.Documents.SearchOptions; namespace Microsoft.KernelMemory.MemoryDb.AzureAISearch; @@ -184,7 +185,7 @@ await client.IndexDocumentsAsync( Exhaustive = false }; - SearchOptions options = new() + AISearchOptions options = new() { VectorSearch = new() { @@ -246,7 +247,7 @@ public async IAsyncEnumerable GetListAsync( { var client = this.GetSearchClient(index); - SearchOptions options = this.PrepareSearchOptions(null, withEmbeddings, filters, limit); + AISearchOptions options = this.PrepareSearchOptions(null, withEmbeddings, filters, limit); Response>? searchResult = null; try @@ -596,13 +597,13 @@ at Azure.Search.Documents.SearchClient.SearchInternal[T](SearchOptions options, return indexSchema; } - private SearchOptions PrepareSearchOptions( - SearchOptions? options, + private AISearchOptions PrepareSearchOptions( + AISearchOptions? options, bool withEmbeddings, ICollection? filters = null, int limit = 1) { - options ??= new SearchOptions(); + options ??= new AISearchOptions(); // Define which fields to fetch options.Select.Add(AzureAISearchMemoryRecord.IdField); diff --git a/service/Abstractions/Abstractions.csproj b/service/Abstractions/Abstractions.csproj index 6fab4f39c..caac4445d 100644 --- a/service/Abstractions/Abstractions.csproj +++ b/service/Abstractions/Abstractions.csproj @@ -4,7 +4,7 @@ net8.0 Microsoft.KernelMemory.Abstractions Microsoft.KernelMemory - $(NoWarn);KMEXP00;CA1711;CA1724;CS1574;SKEXP0001; + $(NoWarn);KMEXP00;SKEXP0001;CA1711;CA1724;CS1574;CA1812; @@ -13,6 +13,7 @@ + diff --git a/service/Abstractions/HTTP/SSE.cs b/service/Abstractions/HTTP/SSE.cs new file mode 100644 index 000000000..e6d9254fd --- /dev/null +++ b/service/Abstractions/HTTP/SSE.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading; + +namespace Microsoft.KernelMemory.HTTP; + +// See https://developer.mozilla.org/docs/Web/API/Server-sent_events/Using_server-sent_events +public static class SSE +{ + public const string DataPrefix = "data: "; + public const string LastToken = "[DONE]"; + public const string DoneMessage = $"{DataPrefix}{LastToken}"; + + public async static IAsyncEnumerable ParseStreamAsync( + Stream stream, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + using var reader = new StreamReader(stream, Encoding.UTF8); + StringBuilder buffer = new(); + + while (await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false) is { } line) + { + if (string.IsNullOrWhiteSpace(line)) // \n\n detected => Message delimiter + { + if (buffer.Length == 0) { continue; } + + string message = buffer.ToString(); + buffer.Clear(); + if (message.Trim() == DoneMessage) { yield break; } + + var memoryAnswer = ParseMessage(message); + if (memoryAnswer != null) { yield return memoryAnswer; } + } + else + { + buffer.AppendLine(line); + } + } + + // Process any remaining text as the last message + if (buffer.Length > 0) + { + string message = buffer.ToString(); + if (message.Trim() == DoneMessage) { yield break; } + + var memoryAnswer = ParseMessage(message); + if (memoryAnswer != null) { yield return memoryAnswer; } + } + } + + public static T? ParseMessage(string? message) + { + if (string.IsNullOrWhiteSpace(message)) { return default; } + + string json = string.Join("", + message.Split('\n', StringSplitOptions.RemoveEmptyEntries) + .Where(line => line.StartsWith(DataPrefix, StringComparison.OrdinalIgnoreCase)) + .Select(line => line[DataPrefix.Length..])); + + return JsonSerializer.Deserialize(json); + } +} diff --git a/service/Abstractions/IKernelMemory.cs b/service/Abstractions/IKernelMemory.cs index 89dc57009..f1152ff6a 100644 --- a/service/Abstractions/IKernelMemory.cs +++ b/service/Abstractions/IKernelMemory.cs @@ -211,21 +211,28 @@ public Task SearchAsync( /// /// Search the given index for an answer to the given query. + /// + /// Use this method to work with IAsyncEnumerable and optionally stream the output. + /// - Note: you must set options.Stream = true to enable token streaming. + /// + /// Use the .AskAsync() extension method to receive the complete answer without streaming. /// /// Question to answer /// Optional index name /// Filter to match /// Filters to match (using inclusive OR logic). If 'filter' is provided too, the value is merged into this list. /// Minimum Cosine Similarity required + /// Options for the request, such as whether to stream results /// Unstructured data supporting custom business logic in the current request. /// Async task cancellation token /// Answer to the query, if possible - public Task AskAsync( + public IAsyncEnumerable AskStreamingAsync( string question, string? index = null, MemoryFilter? filter = null, ICollection? filters = null, double minRelevance = 0, + SearchOptions? options = null, IContext? context = null, CancellationToken cancellationToken = default); } diff --git a/service/Abstractions/KernelMemoryExtensions.cs b/service/Abstractions/KernelMemoryExtensions.cs index 62443a890..0e5dfb53a 100644 --- a/service/Abstractions/KernelMemoryExtensions.cs +++ b/service/Abstractions/KernelMemoryExtensions.cs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.KernelMemory.Context; namespace Microsoft.KernelMemory; @@ -11,6 +13,47 @@ namespace Microsoft.KernelMemory; /// public static class KernelMemoryExtensions { + /// + /// Search the given index for an answer to the given query + /// and return it without streaming the content. + /// + /// Memory instance + /// Question to answer + /// Optional index name + /// Filter to match + /// Filters to match (using inclusive OR logic). If 'filter' is provided too, the value is merged into this list. + /// Minimum Cosine Similarity required + /// Options for the request, such as whether to stream results + /// Unstructured data supporting custom business logic in the current request. + /// Async task cancellation token + /// Answer to the query, if possible + public static async Task AskAsync( + this IKernelMemory memory, + string question, + string? index = null, + MemoryFilter? filter = null, + ICollection? filters = null, + double minRelevance = 0, + SearchOptions? options = null, + IContext? context = null, + CancellationToken cancellationToken = default) + { + var optionsOverride = options.Clone() ?? new SearchOptions(); + optionsOverride.Stream = false; + + return await memory.AskStreamingAsync( + question: question, + index: index, + filter: filter, + filters: filters, + minRelevance: minRelevance, + options: optionsOverride, + context: context, + cancellationToken) + .FirstAsync(cancellationToken: cancellationToken) + .ConfigureAwait(false); + } + /// /// Return a list of synthetic memories of the specified type /// diff --git a/service/Abstractions/Models/MemoryAnswer.cs b/service/Abstractions/Models/MemoryAnswer.cs index eeb9d4d03..c78e695c6 100644 --- a/service/Abstractions/Models/MemoryAnswer.cs +++ b/service/Abstractions/Models/MemoryAnswer.cs @@ -11,15 +11,20 @@ namespace Microsoft.KernelMemory; public class MemoryAnswer { - private static readonly JsonSerializerOptions s_indentedJsonOptions = new() { WriteIndented = true }; - private static readonly JsonSerializerOptions s_notIndentedJsonOptions = new() { WriteIndented = false }; - private static readonly JsonSerializerOptions s_caseInsensitiveJsonOptions = new() { PropertyNameCaseInsensitive = true }; + /// + /// Used only when streaming. How to handle the current record. + /// + [JsonPropertyName("streamState")] + [JsonPropertyOrder(0)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public StreamStates? StreamState { get; set; } = null; /// /// Client question. /// [JsonPropertyName("question")] [JsonPropertyOrder(1)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public string Question { get; set; } = string.Empty; [JsonPropertyName("noResult")] @@ -48,23 +53,31 @@ public class MemoryAnswer /// [JsonPropertyName("relevantSources")] [JsonPropertyOrder(20)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public List RelevantSources { get; set; } = []; /// /// Serialize using .NET JSON serializer, e.g. to avoid ambiguity /// with other serializers and other options /// - /// Whether to keep the JSON readable, e.g. for debugging and views + /// Whether to reduce the payload size for SSE /// JSON serialization - public string ToJson(bool indented = false) + public string ToJson(bool optimizeForStream) { - return JsonSerializer.Serialize(this, indented ? s_indentedJsonOptions : s_notIndentedJsonOptions); - } + if (!optimizeForStream || this.StreamState != StreamStates.Append) + { + return JsonSerializer.Serialize(this); + } - public MemoryAnswer FromJson(string json) - { - return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions) - ?? new MemoryAnswer(); + MemoryAnswer clone = JsonSerializer.Deserialize(JsonSerializer.Serialize(this))!; + +#pragma warning disable CA1820 + if (clone.Question == string.Empty) { clone.Question = null!; } +#pragma warning restore CA1820 + + if (clone.RelevantSources.Count == 0) { clone.RelevantSources = null!; } + + return JsonSerializer.Serialize(clone); } public override string ToString() @@ -72,7 +85,7 @@ public override string ToString() var result = new StringBuilder(); result.AppendLine(this.Result); - if (!this.NoResult) + if (!this.NoResult && this.RelevantSources is { Count: > 0 }) { var sources = new Dictionary(); foreach (var x in this.RelevantSources) diff --git a/service/Abstractions/Models/MemoryQuery.cs b/service/Abstractions/Models/MemoryQuery.cs index f6e8dfae4..3676b945a 100644 --- a/service/Abstractions/Models/MemoryQuery.cs +++ b/service/Abstractions/Models/MemoryQuery.cs @@ -23,6 +23,10 @@ public class MemoryQuery [JsonPropertyOrder(2)] public double MinRelevance { get; set; } = 0; + [JsonPropertyName("stream")] + [JsonPropertyOrder(3)] + public bool Stream { get; set; } = false; + [JsonPropertyName("args")] [JsonPropertyOrder(100)] public Dictionary ContextArguments { get; set; } = []; diff --git a/service/Abstractions/Models/SearchResult.cs b/service/Abstractions/Models/SearchResult.cs index 019720c22..4e2e35128 100644 --- a/service/Abstractions/Models/SearchResult.cs +++ b/service/Abstractions/Models/SearchResult.cs @@ -26,10 +26,7 @@ public class SearchResult [JsonPropertyOrder(2)] public bool NoResult { - get - { - return this.Results.Count == 0; - } + get => this.Results == null || this.Results.Count == 0; private set { } } @@ -40,6 +37,7 @@ private set { } /// [JsonPropertyName("results")] [JsonPropertyOrder(3)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public List Results { get; set; } = []; /// diff --git a/service/Abstractions/Models/StreamStates.cs b/service/Abstractions/Models/StreamStates.cs new file mode 100644 index 000000000..041187dd7 --- /dev/null +++ b/service/Abstractions/Models/StreamStates.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.KernelMemory; + +[JsonConverter(typeof(StreamStatesConverter))] +public enum StreamStates +{ + // Inform the client the stream ended to an error. + Error = 0, + + // When streaming, inform the client to discard any previous data + // and start collecting again using this record as the first one. + Reset = 1, + + // When streaming, append the current result to the data + // already received so far. + Append = 2, + + // Inform the client the end of the stream has been reached + // and that this is the last record to append. + Last = 3, +} + +#pragma warning disable CA1308 +internal sealed class StreamStatesConverter : JsonConverter +{ + public override StreamStates Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + string value = reader.GetString()!; + return value.ToLowerInvariant() switch + { + "error" => StreamStates.Error, + "reset" => StreamStates.Reset, + "append" => StreamStates.Append, + "last" => StreamStates.Last, + _ => throw new JsonException($"Unknown {nameof(StreamStates)} value: {value}") + }; + } + + public override void Write(Utf8JsonWriter writer, StreamStates value, JsonSerializerOptions options) + { + string serializedValue = value switch + { + StreamStates.Error => "error", + StreamStates.Reset => "reset", + StreamStates.Append => "append", + StreamStates.Last => "last", + _ => throw new JsonException($"Cannot serialize {nameof(StreamStates)} value: {value}") + }; + + writer.WriteStringValue(serializedValue); + } +} +#pragma warning restore CA1308 diff --git a/service/Abstractions/Search/ISearchClient.cs b/service/Abstractions/Search/ISearchClient.cs index a8da8f1cc..4329538e3 100644 --- a/service/Abstractions/Search/ISearchClient.cs +++ b/service/Abstractions/Search/ISearchClient.cs @@ -50,6 +50,24 @@ Task AskAsync( IContext? context = null, CancellationToken cancellationToken = default); + /// + /// Answer the given question, if possible, grounding the response with relevant memories matching the given criteria. + /// + /// Index (aka collection) to search for grounding information + /// Question to answer + /// Filtering criteria to select memories to consider + /// Minimum relevance of the memories considered + /// Optional context carrying optional information used by internal logic + /// Async task cancellation token + /// Answer to the given question + IAsyncEnumerable AskStreamingAsync( + string index, + string question, + ICollection? filters = null, + double minRelevance = 0, + IContext? context = null, + CancellationToken cancellationToken = default); + /// /// List the available memory indexes (aka collections). /// diff --git a/service/Abstractions/Search/SearchOptions.cs b/service/Abstractions/Search/SearchOptions.cs new file mode 100644 index 000000000..4f88b1f0c --- /dev/null +++ b/service/Abstractions/Search/SearchOptions.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft. All rights reserved. + +#pragma warning disable IDE0130 // reduce number of "using" statements +// ReSharper disable once CheckNamespace - reduce number of "using" statements +namespace Microsoft.KernelMemory; + +// TODO: move minRelevance to this class +// TODO: move filter to this class +// TODO: move filters to this class +public sealed class SearchOptions +{ + /// + /// Whether to stream results back to the client + /// + public bool Stream { get; set; } = false; +} + +public static class SearchOptionsExtensions +{ + public static SearchOptions? Clone(this SearchOptions? options) + { + if (options == null) { return null; } + + return new SearchOptions + { + Stream = options.Stream + }; + } +} diff --git a/service/Core/Configuration/ServiceConfig.cs b/service/Core/Configuration/ServiceConfig.cs index 4b89ff129..2a9ab78ec 100644 --- a/service/Core/Configuration/ServiceConfig.cs +++ b/service/Core/Configuration/ServiceConfig.cs @@ -24,6 +24,11 @@ public class ServiceConfig /// public bool OpenApiEnabled { get; set; } = false; + /// + /// Whether to send a [DONE] message at the end of SSE streams. + /// + public bool SendSSEDoneMessage { get; set; } = true; + /// /// List of handlers to enable /// diff --git a/service/Core/MemoryServerless.cs b/service/Core/MemoryServerless.cs index 2aba4cf0d..240ddd52a 100644 --- a/service/Core/MemoryServerless.cs +++ b/service/Core/MemoryServerless.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -262,30 +263,47 @@ public Task SearchAsync( } /// - public Task AskAsync( + public async IAsyncEnumerable AskStreamingAsync( string question, string? index = null, MemoryFilter? filter = null, ICollection? filters = null, double minRelevance = 0, + SearchOptions? options = null, IContext? context = null, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { this._contextProvider.InitContext(context); if (filter != null) { - if (filters == null) { filters = []; } - + filters ??= []; filters.Add(filter); } index = IndexName.CleanName(index, this._defaultIndexName); - return this._searchClient.AskAsync( + + if (options is { Stream: true }) + { + await foreach (var answer in this._searchClient.AskStreamingAsync( + index: index, + question: question, + filters: filters, + minRelevance: minRelevance, + context: context, + cancellationToken).ConfigureAwait(false)) + { + yield return answer; + } + + yield break; + } + + yield return await this._searchClient.AskAsync( index: index, question: question, filters: filters, minRelevance: minRelevance, context: context, - cancellationToken: cancellationToken); + cancellationToken).ConfigureAwait(false); } } diff --git a/service/Core/MemoryService.cs b/service/Core/MemoryService.cs index 2ee34b163..5a88813da 100644 --- a/service/Core/MemoryService.cs +++ b/service/Core/MemoryService.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -228,29 +229,46 @@ public Task SearchAsync( } /// - public Task AskAsync( + public async IAsyncEnumerable AskStreamingAsync( string question, string? index = null, MemoryFilter? filter = null, ICollection? filters = null, double minRelevance = 0, + SearchOptions? options = null, IContext? context = null, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { if (filter != null) { - if (filters == null) { filters = []; } - + filters ??= []; filters.Add(filter); } index = IndexName.CleanName(index, this._defaultIndexName); - return this._searchClient.AskAsync( + + if (options is { Stream: true }) + { + await foreach (var answer in this._searchClient.AskStreamingAsync( + index: index, + question: question, + filters: filters, + minRelevance: minRelevance, + context: context, + cancellationToken).ConfigureAwait(false)) + { + yield return answer; + } + + yield break; + } + + yield return await this._searchClient.AskAsync( index: index, question: question, filters: filters, minRelevance: minRelevance, context: context, - cancellationToken: cancellationToken); + cancellationToken).ConfigureAwait(false); } } diff --git a/service/Core/Search/AnswerGenerator.cs b/service/Core/Search/AnswerGenerator.cs index 14abb981b..352160851 100644 --- a/service/Core/Search/AnswerGenerator.cs +++ b/service/Core/Search/AnswerGenerator.cs @@ -2,8 +2,8 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -44,86 +44,75 @@ public AnswerGenerator( { throw new KernelMemoryException("Text generator not configured"); } + + if (this._contentModeration == null || !this._config.UseContentModeration) + { + this._log.LogInformation("Content moderation is not enabled."); + } } - internal async Task GenerateAnswerAsync( - string question, SearchClientResult result, IContext? context, CancellationToken cancellationToken) + internal async IAsyncEnumerable GenerateAnswerAsync( + string question, SearchClientResult result, + IContext? context, [EnumeratorCancellation] CancellationToken cancellationToken) { if (result.FactsAvailableCount > 0 && result.FactsUsedCount == 0) { this._log.LogError("Unable to inject memories in the prompt, not enough tokens available"); - result.AskResult.NoResultReason = "Unable to use memories"; - return result.AskResult; + yield return result.InsufficientTokensResult; + yield break; } if (result.FactsUsedCount == 0) { this._log.LogWarning("No memories available"); - result.AskResult.NoResultReason = "No memories available"; - return result.AskResult; + yield return result.NoFactsResult; + yield break; } - // Collect the LLM output - var text = new StringBuilder(); - var charsGenerated = 0; - var watch = new Stopwatch(); - watch.Restart(); - await foreach (var x in this.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(false)) + var completeAnswer = new StringBuilder(); + await foreach (var answerToken in this.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(false)) { - text.Append(x); - - if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) - { - charsGenerated = text.Length; - this._log.LogTrace("{0} chars generated", charsGenerated); - } + completeAnswer.Append(answerToken); + result.AskResult.Result = answerToken; + yield return result.AskResult; } - watch.Stop(); - // Finalize the answer, checking if it's empty - result.AskResult.Result = text.ToString(); - this._log.LogSensitive("Answer: {0}", result.AskResult.Result); - result.AskResult.NoResult = ValueIsEquivalentTo(result.AskResult.Result, this._config.EmptyAnswer); - if (result.AskResult.NoResult) - { - result.AskResult.NoResultReason = "No relevant memories found"; - this._log.LogTrace("Answer generated in {0} msecs. No relevant memories found", watch.ElapsedMilliseconds); - } - else + result.AskResult.Result = completeAnswer.ToString(); + if (string.IsNullOrWhiteSpace(result.AskResult.Result) + || ValueIsEquivalentTo(result.AskResult.Result, this._config.EmptyAnswer)) { - this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds); + this._log.LogInformation("No relevant memories found, returning empty answer."); + yield return result.NoFactsResult; + yield break; } - // Validate the LLM output - if (this._contentModeration != null && this._config.UseContentModeration) + this._log.LogSensitive("Answer: {0}", result.AskResult.Result); + + if (this._config.UseContentModeration + && this._contentModeration != null + && !await this._contentModeration.IsSafeAsync(result.AskResult.Result, cancellationToken).ConfigureAwait(false)) { - var isSafe = await this._contentModeration.IsSafeAsync(result.AskResult.Result, cancellationToken).ConfigureAwait(false); - if (!isSafe) - { - this._log.LogWarning("Unsafe answer detected. Returning error message instead."); - this._log.LogSensitive("Unsafe answer: {0}", result.AskResult.Result); - result.AskResult.NoResultReason = "Content moderation failure"; - result.AskResult.Result = this._config.ModeratedAnswer; - } + this._log.LogWarning("Unsafe answer detected. Returning error message instead."); + yield return result.UnsafeAnswerResult; } - - return result.AskResult; } private IAsyncEnumerable GenerateAnswerTokensAsync(string question, string facts, IContext? context, CancellationToken cancellationToken) { string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); + string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer); int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens); double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature); double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP); - prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase); - question = question.Trim(); question = question.EndsWith('?') ? question : $"{question}?"; + + prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase); prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase); - prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase); + prompt = prompt.Replace("{{$notFound}}", emptyAnswer, StringComparison.OrdinalIgnoreCase); + this._log.LogInformation("New prompt: {0}", prompt); var options = new TextGenerationOptions { diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index cf410ad54..f6b3ddae7 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -5,6 +5,8 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -120,6 +122,81 @@ public async Task AskAsync( double minRelevance = 0, IContext? context = null, CancellationToken cancellationToken = default) + { + var result = new MemoryAnswer(); + + var stream = this.AskStreamingAsync( + index: index, question: question, filters, minRelevance, context, cancellationToken) + .ConfigureAwait(false); + + var done = false; + StringBuilder text = new(result.Result); + await foreach (var part in stream.ConfigureAwait(false)) + { + if (done) { break; } + + switch (part.StreamState) + { + case StreamStates.Error: + text.Clear(); + result = part; + + done = true; + break; + + case StreamStates.Reset: + text.Clear(); + text.Append(part.Result); + result = part; + break; + + case StreamStates.Append: + result.NoResult = part.NoResult; + result.NoResultReason = part.NoResultReason; + + text.Append(part.Result); + result.Result = text.ToString(); + + if (result.RelevantSources != null && part.RelevantSources != null) + { + result.RelevantSources = result.RelevantSources.Union(part.RelevantSources).ToList(); + } + + break; + + case StreamStates.Last: + result.NoResult = part.NoResult; + result.NoResultReason = part.NoResultReason; + + text.Append(part.Result); + result.Result = text.ToString(); + + if (result.RelevantSources != null && part.RelevantSources != null) + { + result.RelevantSources = result.RelevantSources.Union(part.RelevantSources).ToList(); + } + + done = true; + break; + + default: + throw new ArgumentOutOfRangeException(nameof(part.StreamState)); + } + } + + result.Question = question; + result.StreamState = null; + return result; + } + + /// + public async IAsyncEnumerable AskStreamingAsync( + string index, + string question, + ICollection? filters = null, + double minRelevance = 0, + IContext? context = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) { string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer); string answerPrompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); @@ -129,9 +206,11 @@ public async Task AskAsync( ? this._config.MaxAskPromptSize : this._textGenerator.MaxTokenTotal; + // Prepare results (empty, error, etc.) SearchClientResult result = SearchClientResult.AskResultInstance( question: question, emptyAnswer: emptyAnswer, + moderatedAnswer: this._config.ModeratedAnswer, maxGroundingFacts: limit, tokensAvailable: maxTokens - this._textGenerator.CountTokens(answerPrompt) @@ -142,7 +221,8 @@ public async Task AskAsync( if (string.IsNullOrEmpty(question)) { this._log.LogWarning("No question provided"); - return result.AskResult; + yield return result.NoQuestionResult; + yield break; } this._log.LogTrace("Fetching relevant memories"); @@ -171,7 +251,22 @@ public async Task AskAsync( this._log.LogTrace("{Count} records processed", result.RecordCount); - return await this._answerGenerator.GenerateAnswerAsync(question, result, context, cancellationToken).ConfigureAwait(false); + var first = true; + await foreach (var answer in this._answerGenerator.GenerateAnswerAsync(question, result, context, cancellationToken).ConfigureAwait(false)) + { + yield return answer; + + if (first) + { + // Remove redundant data, sent only once in the first record, to reduce payload + first = false; + + // Note: we keep the sources in the other collections (e.g. AskResult.ErrorResult.RelevantSources), + // so in case of a stream reset the sources are sent again. + result.AskResult.RelevantSources.Clear(); + result.AskResult.Question = null!; + } + } } /// @@ -248,29 +343,17 @@ private SearchClientResult ProcessMemoryRecord( this._log.LogTrace("Adding content #{0} with relevance {1}", result.FactsUsedCount, recordRelevance); } - Citation? citation; - if (result.Mode == SearchMode.SearchMode) + var citation = result.Mode switch { - citation = result.SearchResult.Results.FirstOrDefault(x => x.Link == linkToFile); - if (citation == null) - { - citation = new Citation(); - result.SearchResult.Results.Add(citation); - } - } - else if (result.Mode == SearchMode.AskMode) - { - // If the file is already in the list of citations, only add the partition - citation = result.AskResult.RelevantSources.FirstOrDefault(x => x.Link == linkToFile); - if (citation == null) - { - citation = new Citation(); - result.AskResult.RelevantSources.Add(citation); - } - } - else + SearchMode.SearchMode => result.SearchResult.Results.FirstOrDefault(x => x.Link == linkToFile), + SearchMode.AskMode => result.AskResult.RelevantSources.FirstOrDefault(x => x.Link == linkToFile), + _ => throw new ArgumentOutOfRangeException(nameof(result.Mode)) + }; + + if (citation == null) { - throw new ArgumentOutOfRangeException(nameof(result.Mode)); + citation = new Citation(); + result.AddSource(citation); } citation.Index = index; diff --git a/service/Core/Search/SearchClientResult.cs b/service/Core/Search/SearchClientResult.cs index c509809a7..66ff58f15 100644 --- a/service/Core/Search/SearchClientResult.cs +++ b/service/Core/Search/SearchClientResult.cs @@ -23,10 +23,16 @@ internal class SearchClientResult public SearchState State { get; set; } public int RecordCount { get; set; } - // Use by in Search and Ask mode - public MemoryAnswer AskResult { get; private init; } = new(); + // Use by Search and Ask mode public int MaxRecordCount { get; private init; } + public MemoryAnswer AskResult { get; private init; } = new(); + public MemoryAnswer NoFactsResult { get; private init; } = new(); + public MemoryAnswer NoQuestionResult { get; private init; } = new(); + public MemoryAnswer UnsafeAnswerResult { get; private init; } = new(); + public MemoryAnswer InsufficientTokensResult { get; private init; } = new(); + public MemoryAnswer ErrorResult { get; private init; } = new(); + // Use by Ask mode public SearchResult SearchResult { get; private init; } = new(); public StringBuilder Facts { get; } = new(); @@ -37,7 +43,9 @@ internal class SearchClientResult /// /// Create new instance in Ask mode /// - public static SearchClientResult AskResultInstance(string question, string emptyAnswer, int maxGroundingFacts, int tokensAvailable) + public static SearchClientResult AskResultInstance( + string question, string emptyAnswer, string moderatedAnswer, + int maxGroundingFacts, int tokensAvailable) { return new SearchClientResult { @@ -46,14 +54,64 @@ public static SearchClientResult AskResultInstance(string question, string empty MaxRecordCount = maxGroundingFacts, AskResult = new MemoryAnswer { + StreamState = StreamStates.Append, + Question = question, + NoResult = false + }, + NoFactsResult = new MemoryAnswer + { + StreamState = StreamStates.Reset, + Question = question, + NoResult = true, + NoResultReason = "No relevant memories available", + Result = emptyAnswer + }, + NoQuestionResult = new MemoryAnswer + { + StreamState = StreamStates.Reset, Question = question, NoResult = true, NoResultReason = "No question provided", - Result = emptyAnswer, + Result = emptyAnswer + }, + InsufficientTokensResult = new MemoryAnswer + { + StreamState = StreamStates.Reset, + Question = question, + NoResult = true, + NoResultReason = "Unable to use memory, max tokens reached", + Result = emptyAnswer + }, + UnsafeAnswerResult = new MemoryAnswer + { + StreamState = StreamStates.Reset, + Question = question, + NoResult = true, + NoResultReason = "Content moderation", + Result = moderatedAnswer + }, + ErrorResult = new MemoryAnswer + { + StreamState = StreamStates.Error, + Question = question, + NoResult = true, + NoResultReason = "An error occurred" } }; } + /// + /// Add source to all the collections + /// + public void AddSource(Citation citation) + { + this.SearchResult.Results?.Add(citation); + this.AskResult.RelevantSources?.Add(citation); + this.InsufficientTokensResult.RelevantSources?.Add(citation); + this.UnsafeAnswerResult.RelevantSources?.Add(citation); + this.ErrorResult.RelevantSources?.Add(citation); + } + /// /// Create new instance in Search mode /// diff --git a/service/Service.AspNetCore/WebAPIEndpoints.cs b/service/Service.AspNetCore/WebAPIEndpoints.cs index cef2388b8..8e63c1aa3 100644 --- a/service/Service.AspNetCore/WebAPIEndpoints.cs +++ b/service/Service.AspNetCore/WebAPIEndpoints.cs @@ -3,6 +3,8 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; @@ -15,6 +17,7 @@ using Microsoft.KernelMemory.Configuration; using Microsoft.KernelMemory.Context; using Microsoft.KernelMemory.DocumentStorage; +using Microsoft.KernelMemory.HTTP; using Microsoft.KernelMemory.Service.AspNetCore.Models; namespace Microsoft.KernelMemory.Service.AspNetCore; @@ -31,11 +34,10 @@ public static IEndpointRouteBuilder AddKernelMemoryEndpoints( builder.AddGetIndexesEndpoint(apiPrefix).AddFilters(filters); builder.AddDeleteIndexesEndpoint(apiPrefix).AddFilters(filters); builder.AddDeleteDocumentsEndpoint(apiPrefix).AddFilters(filters); - builder.AddAskEndpoint(apiPrefix).AddFilters(filters); + builder.AddAskEndpoint(apiPrefix, kmConfig?.Service.SendSSEDoneMessage ?? true).AddFilters(filters); builder.AddSearchEndpoint(apiPrefix).AddFilters(filters); builder.AddUploadStatusEndpoint(apiPrefix).AddFilters(filters); builder.AddGetDownloadEndpoint(apiPrefix).AddFilters(filters); - return builder; } @@ -212,13 +214,14 @@ await service.DeleteDocumentAsync(documentId: documentId, index: index, cancella } public static RouteHandlerBuilder AddAskEndpoint( - this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter[]? filters = null) + this IEndpointRouteBuilder builder, string apiPrefix = "/", bool sseSendDoneMessage = true, IEndpointFilter[]? filters = null) { RouteGroupBuilder group = builder.MapGroup(apiPrefix); // Ask endpoint var route = group.MapPost(Constants.HttpAskEndpoint, - async Task ( + async Task ( + HttpContext httpContext, MemoryQuery query, IKernelMemory service, ILogger log, @@ -228,20 +231,75 @@ async Task ( // Allow internal classes to access custom arguments via IContextProvider contextProvider.InitContextArgs(query.ContextArguments); - log.LogTrace("New search request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance); - MemoryAnswer answer = await service.AskAsync( - question: query.Question, - index: query.Index, - filters: query.Filters, - minRelevance: query.MinRelevance, - context: contextProvider.GetContext(), - cancellationToken: cancellationToken) - .ConfigureAwait(false); - return Results.Ok(answer); + log.LogTrace("New ask request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance); + + IAsyncEnumerable answerStream = service.AskStreamingAsync( + question: query.Question, + index: query.Index, + filters: query.Filters, + minRelevance: query.MinRelevance, + options: new SearchOptions { Stream = query.Stream }, + context: contextProvider.GetContext(), + cancellationToken: cancellationToken); + + httpContext.Response.StatusCode = StatusCodes.Status200OK; + + try + { + if (query.Stream) + { + httpContext.Response.ContentType = "text/event-stream; charset=utf-8"; + await foreach (var answer in answerStream.ConfigureAwait(false)) + { + string json = answer.ToJson(true); + await httpContext.Response.WriteAsync($"{SSE.DataPrefix}{json}\n\n", cancellationToken).ConfigureAwait(false); + await httpContext.Response.Body.FlushAsync(cancellationToken).ConfigureAwait(false); + } + } + else + { + httpContext.Response.ContentType = "application/json; charset=utf-8"; + MemoryAnswer answer = await answerStream.FirstAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + string json = answer.ToJson(false); + await httpContext.Response.WriteAsync(json, cancellationToken).ConfigureAwait(false); + } + } + catch (Exception e) + { + log.LogError(e, "An error occurred while preparing the response"); + + // Attempt to set the status code, in case the output hasn't started yet + httpContext.Response.StatusCode = StatusCodes.Status503ServiceUnavailable; + + var json = query.Stream + ? JsonSerializer.Serialize(new MemoryAnswer + { + StreamState = StreamStates.Error, + Question = query.Question, + NoResult = true, + NoResultReason = $"Error: {e.Message} [{e.GetType().FullName}]" + }) + : JsonSerializer.Serialize(new ProblemDetails + { + Status = StatusCodes.Status503ServiceUnavailable, + Title = "Service Unavailable", + Detail = $"{e.Message} [{e.GetType().FullName}]" + }); + + await httpContext.Response.WriteAsync(query.Stream ? $"{SSE.DataPrefix}{json}\n\n" : json, cancellationToken).ConfigureAwait(false); + } + + if (query.Stream && sseSendDoneMessage) + { + await httpContext.Response.WriteAsync($"{SSE.DoneMessage}\n\n", cancellationToken: cancellationToken).ConfigureAwait(false); + } + + await httpContext.Response.Body.FlushAsync(cancellationToken).ConfigureAwait(false); }) .Produces(StatusCodes.Status200OK) .Produces(StatusCodes.Status401Unauthorized) - .Produces(StatusCodes.Status403Forbidden); + .Produces(StatusCodes.Status403Forbidden) + .Produces(StatusCodes.Status503ServiceUnavailable); return route; } diff --git a/service/Service/appsettings.json b/service/Service/appsettings.json index 506f52c2d..db88ad35c 100644 --- a/service/Service/appsettings.json +++ b/service/Service/appsettings.json @@ -48,6 +48,8 @@ // If not set the solution defaults to 30,000,000 bytes (~28.6 MB) // Note: this applies only to KM HTTP service. "MaxUploadSizeMb": null, + // Whether to send a [DONE] message at the end of SSE streams. + "SendSSEDoneMessage": true, // Whether to run the asynchronous pipeline handlers // Use these booleans to deploy the web service and the handlers on same/different VMs "RunHandlers": true, diff --git a/service/tests/Abstractions.UnitTests/Http/SSETest.cs b/service/tests/Abstractions.UnitTests/Http/SSETest.cs new file mode 100644 index 000000000..ffa463737 --- /dev/null +++ b/service/tests/Abstractions.UnitTests/Http/SSETest.cs @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; +using Microsoft.KernelMemory; +using Microsoft.KernelMemory.HTTP; + +namespace Microsoft.KM.Abstractions.UnitTests.Http; + +public class SSETest +{ + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" ")] + [InlineData(" \n")] + public void ItParsesEmptyStrings(string input) + { + Assert.Null(SSE.ParseMessage(input)); + } + + [Fact] + public void ItParsesSingleLineMessage() + { + // Arrange + var message = """ + data: { "question": "q" } + """; + + // Act + var x = SSE.ParseMessage(message); + + // Assert + Assert.NotNull(x); + Assert.Equal("q", x.Question); + } + + [Fact] + public void ItParsesSingleLineMessageWithSeparator() + { + // Arrange + var message = """ + data: { "question": "q" } + + + """; + + // Act + var x = SSE.ParseMessage(message); + + // Assert + Assert.NotNull(x); + Assert.Equal("q", x.Question); + } + + [Fact] + public void ItParsesMultiLineMessage() + { + // Arrange + var message = """ + data: { "question": "q" + data: , "noResultReason": "abc" + data: } + """; + + // Act + var x = SSE.ParseMessage(message); + + // Assert + Assert.NotNull(x); + Assert.Equal("q", x.Question); + Assert.Equal("abc", x.NoResultReason); + } + + [Theory] + [InlineData("data: [DONE]")] + [InlineData("data: [DONE]\n")] + [InlineData("data: [DONE]\n\n")] + public async Task ItParsesEmptyStreams(string input) + { + // Arrange + using var stream = new MemoryStream(Encoding.UTF8.GetBytes(input)); + + // Act + var result = SSE.ParseStreamAsync(stream); + + // Assert + var messages = new List(); + await foreach (var message in result) + { + messages.Add(message); + } + + Assert.Equal(0, messages.Count); + } + + [Theory] + [InlineData("data: { \"question\": \"qq\" }")] + [InlineData("data: { \"question\": \"qq\" }\n")] + [InlineData("data: { \"question\": \"qq\" }\n\n")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: [DONE]")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: [DONE]\n")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: [DONE]\n\n")] + public async Task ItParsesStreamsWithASingleMessage(string input) + { + // Arrange + using var stream = new MemoryStream(Encoding.UTF8.GetBytes(input)); + + // Act + var result = SSE.ParseStreamAsync(stream); + + // Assert + var messages = new List(); + await foreach (var message in result) + { + messages.Add(message); + } + + Assert.Equal(1, messages.Count); + Assert.NotNull(messages[0]); + Assert.Equal("qq", messages[0].Question); + } + + [Theory] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\n")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\ndata: [DONE]")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\ndata: [DONE]\n")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\ndata: [DONE]\n\n")] + public async Task ItParsesStreamsWithMultipleMessage(string input) + { + // Arrange + using var stream = new MemoryStream(Encoding.UTF8.GetBytes(input)); + + // Act + var result = SSE.ParseStreamAsync(stream); + + // Assert + var messages = new List(); + await foreach (var message in result) + { + messages.Add(message); + } + + Assert.Equal(2, messages.Count); + Assert.NotNull(messages[0]); + Assert.Equal("qq", messages[0].Question); + Assert.NotNull(messages[1]); + Assert.Equal("kk", messages[1].Question); + } +}