Skip to content

Commit

Permalink
Implement Response Streaming (#726)
Browse files Browse the repository at this point in the history
## Motivation and Context (Why the change? What's the scenario?)

Add an option to stream the tokens of an `Ask` result as they are
generated, without waiting for the full answer to be ready.

## High level description (Approach, Design)

- New `stream` boolean option for the `Ask` API, false by default. When
true, answer tokens are streamed as soon as they are generated by LLMs.
- New `MemoryAnswer.StreamState` enum property: `Error`, `Reset`,
`Append`, `Last`.
- If moderation is enabled, the content is validated at the end of the
stream. If moderation fails, the service returns an answer with
`StreamState` = `Reset` and the new content to show to the end user instead.
- Streaming uses the Server-Sent Events (SSE) message format.
- By default, SSE streams end with a `[DONE]` token. This can be
disabled via KM settings.
- SSE payload is optimized, returning `RelevantSources` only in the
first SSE message.

---------

Co-authored-by: Carlo <carlo.dechellis@mobilesoft.it>
Co-authored-by: Devis Lucato <dluc@users.noreply.github.com>
Co-authored-by: Devis Lucato <devis@microsoft.com>
  • Loading branch information
4 people authored Dec 1, 2024
1 parent 53db61a commit 77fd7be
Show file tree
Hide file tree
Showing 27 changed files with 841 additions and 153 deletions.
1 change: 1 addition & 0 deletions .github/_typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ extend-exclude = [
"encoder.json",
"appsettings.development.json",
"appsettings.Development.json",
"appsettings.*.json.*",
"AzureAISearchFilteringTest.cs",
"KernelMemory.sln.DotSettings"
]
Expand Down
1 change: 1 addition & 0 deletions KernelMemory.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=SHA/@EntryIndexedValue">SHA</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=SK/@EntryIndexedValue">SK</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=SKHTTP/@EntryIndexedValue">SKHTTP</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=SSE/@EntryIndexedValue">SSE</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=SSL/@EntryIndexedValue">SSL</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=TTL/@EntryIndexedValue">TTL</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=UI/@EntryIndexedValue">UI</s:String>
Expand Down
1 change: 1 addition & 0 deletions clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ public async Task<string> AskAsync(
MemoryAnswer answer = await this._memory.AskAsync(
question: question,
index: index ?? this._defaultIndex,
options: new SearchOptions { Stream = false },
filter: TagsToMemoryFilter(tags ?? this._defaultRetrievalTags),
minRelevance: minRelevance,
cancellationToken: cancellationToken).ConfigureAwait(false);
Expand Down
28 changes: 22 additions & 6 deletions clients/dotnet/WebClient/MemoryWebClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.KernelMemory.Context;
using Microsoft.KernelMemory.HTTP;
using Microsoft.KernelMemory.Internals;

namespace Microsoft.KernelMemory;
Expand Down Expand Up @@ -337,28 +339,30 @@ public async Task<SearchResult> SearchAsync(
}

/// <inheritdoc />
public async Task<MemoryAnswer> AskAsync(
public async IAsyncEnumerable<MemoryAnswer> AskStreamingAsync(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
SearchOptions? options = null,
IContext? context = null,
CancellationToken cancellationToken = default)
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (filter != null)
{
if (filters == null) { filters = []; }

filters ??= [];
filters.Add(filter);
}

var useStreaming = options?.Stream ?? false;
MemoryQuery request = new()
{
Index = index,
Question = question,
Filters = (filters is { Count: > 0 }) ? filters.ToList() : [],
MinRelevance = minRelevance,
Stream = useStreaming,
ContextArguments = (context?.Arguments ?? new Dictionary<string, object?>()).ToDictionary(),
};
using StringContent content = new(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json");
Expand All @@ -367,8 +371,20 @@ public async Task<MemoryAnswer> AskAsync(
HttpResponseMessage response = await this._client.PostAsync(url, content, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();

var json = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
return JsonSerializer.Deserialize<MemoryAnswer>(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer();
if (useStreaming)
{
Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
IAsyncEnumerable<MemoryAnswer> answers = SSE.ParseStreamAsync<MemoryAnswer>(stream, cancellationToken);
await foreach (MemoryAnswer answer in answers.ConfigureAwait(false))
{
yield return answer;
}
}
else
{
var json = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
yield return JsonSerializer.Deserialize<MemoryAnswer>(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer();
}
}

#region private
Expand Down
46 changes: 35 additions & 11 deletions examples/001-dotnet-WebClient/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* without extracting memories. */
internal static class Program
{
private static MemoryWebClient? s_memory;
private static MemoryWebClient s_memory = null!;
private static readonly List<string> s_toDelete = [];

// Change this to True and configure Azure Document Intelligence to test OCR and support for images
Expand Down Expand Up @@ -55,8 +55,8 @@ public static async Task Main()
// === RETRIEVAL =========
// =======================

await AskSimpleQuestion();
await AskSimpleQuestionAndShowSources();
await AskSimpleQuestionStreamingTheAnswer();
await AskSimpleQuestionStreamingAndShowSources();
await AskQuestionAboutImageContent();
await AskQuestionUsingFilter();
await AskQuestionsFilteringByUser();
Expand Down Expand Up @@ -249,16 +249,25 @@ private static async Task StoreJson()
// =======================

// Question without filters
private static async Task AskSimpleQuestion()
private static async Task AskSimpleQuestionStreamingTheAnswer()
{
var question = "What's E = m*c^2?";
Console.WriteLine($"Question: {question}");
Console.WriteLine($"Expected result: formula explanation using the information loaded");

var answer = await s_memory.AskAsync(question, minRelevance: 0.6);
Console.WriteLine($"\nAnswer: {answer.Result}");
Console.Write("\nAnswer: ");
var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.6,
options: new SearchOptions { Stream = true });

Console.WriteLine("\n====================================\n");
await foreach (var answer in answerStream)
{
// Print token received by LLM
Console.Write(answer.Result);
// Slow down the stream for demo purpose
await Task.Delay(25);
}

Console.WriteLine("\n\n====================================\n");

/* OUTPUT
Expand All @@ -275,17 +284,32 @@ due to the speed of light being a very large number when squared. This concept i
}

// Another question without filters and show sources
private static async Task AskSimpleQuestionAndShowSources()
private static async Task AskSimpleQuestionStreamingAndShowSources()
{
var question = "What's Kernel Memory?";
Console.WriteLine($"Question: {question}");
Console.WriteLine($"Expected result: it should explain what KM project is (not generic kernel memory)");

var answer = await s_memory.AskAsync(question, minRelevance: 0.5);
Console.WriteLine($"\nAnswer: {answer.Result}\n\n Sources:\n");
Console.Write("\nAnswer: ");
var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.5,
options: new SearchOptions { Stream = true });

List<Citation> sources = [];
await foreach (var answer in answerStream)
{
// Print token received by LLM
Console.Write(answer.Result);

// Collect sources
sources.AddRange(answer.RelevantSources);

// Slow down the stream for demo purpose
await Task.Delay(5);
}

// Show sources / citations
foreach (var x in answer.RelevantSources)
Console.WriteLine("\n\nSources:\n");
foreach (var x in sources)
{
Console.WriteLine(x.SourceUrl != null
? $" - {x.SourceUrl} [{x.Partitions.First().LastUpdate:D}]"
Expand Down
46 changes: 35 additions & 11 deletions examples/002-dotnet-Serverless/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#pragma warning disable CS8602 // by design
public static class Program
{
private static MemoryServerless? s_memory;
private static MemoryServerless s_memory = null!;
private static readonly List<string> s_toDelete = [];

// Remember to configure Azure Document Intelligence to test OCR and support for images
Expand Down Expand Up @@ -107,8 +107,8 @@ public static async Task Main()
// === RETRIEVAL =========
// =======================

await AskSimpleQuestion();
await AskSimpleQuestionAndShowSources();
await AskSimpleQuestionStreamingTheAnswer();
await AskSimpleQuestionStreamingAndShowSources();
await AskQuestionAboutImageContent();
await AskQuestionUsingFilter();
await AskQuestionsFilteringByUser();
Expand Down Expand Up @@ -303,16 +303,25 @@ private static async Task StoreJson()
// =======================

// Question without filters
private static async Task AskSimpleQuestion()
private static async Task AskSimpleQuestionStreamingTheAnswer()
{
var question = "What's E = m*c^2?";
Console.WriteLine($"Question: {question}");
Console.WriteLine($"Expected result: formula explanation using the information loaded");

var answer = await s_memory.AskAsync(question, minRelevance: 0.6);
Console.WriteLine($"\nAnswer: {answer.Result}");
Console.Write("\nAnswer: ");
var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.6,
options: new SearchOptions { Stream = true });

Console.WriteLine("\n====================================\n");
await foreach (var answer in answerStream)
{
// Print token received by LLM
Console.Write(answer.Result);
// Slow down the stream for demo purpose
await Task.Delay(25);
}

Console.WriteLine("\n\n====================================\n");

/* OUTPUT
Expand All @@ -329,17 +338,32 @@ due to the speed of light being a very large number when squared. This concept i
}

// Another question without filters and show sources
private static async Task AskSimpleQuestionAndShowSources()
private static async Task AskSimpleQuestionStreamingAndShowSources()
{
var question = "What's Kernel Memory?";
Console.WriteLine($"Question: {question}");
Console.WriteLine($"Expected result: it should explain what KM project is (not generic kernel memory)");

var answer = await s_memory.AskAsync(question, minRelevance: 0.5);
Console.WriteLine($"\nAnswer: {answer.Result}\n\n Sources:\n");
Console.Write("\nAnswer: ");
var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.5,
options: new SearchOptions { Stream = true });

List<Citation> sources = [];
await foreach (var answer in answerStream)
{
// Print token received by LLM
Console.Write(answer.Result);

// Collect sources
sources.AddRange(answer.RelevantSources);

// Slow down the stream for demo purpose
await Task.Delay(5);
}

// Show sources / citations
foreach (var x in answer.RelevantSources)
Console.WriteLine("\n\nSources:\n");
foreach (var x in sources)
{
Console.WriteLine(x.SourceUrl != null
? $" - {x.SourceUrl} [{x.Partitions.First().LastUpdate:D}]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.MemoryDb.AzureAISearch;
using Microsoft.KernelMemory.MemoryStorage;
using AISearchOptions = Azure.Search.Documents.SearchOptions;

namespace Microsoft.AzureAISearch.TestApplication;

Expand Down Expand Up @@ -246,7 +247,7 @@ private static async Task<IList<MemoryRecord>> SearchByFieldValueAsync(

fieldValue1 = fieldValue1.Replace("'", "''", StringComparison.Ordinal);
fieldValue2 = fieldValue2.Replace("'", "''", StringComparison.Ordinal);
SearchOptions options = new()
AISearchOptions options = new()
{
Filter = fieldIsCollection
? $"{fieldName}/any(s: s eq '{fieldValue1}') and {fieldName}/any(s: s eq '{fieldValue2}')"
Expand Down
11 changes: 6 additions & 5 deletions extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.DocumentStorage;
using Microsoft.KernelMemory.MemoryStorage;
using AISearchOptions = Azure.Search.Documents.SearchOptions;

namespace Microsoft.KernelMemory.MemoryDb.AzureAISearch;

Expand Down Expand Up @@ -184,7 +185,7 @@ await client.IndexDocumentsAsync(
Exhaustive = false
};

SearchOptions options = new()
AISearchOptions options = new()
{
VectorSearch = new()
{
Expand Down Expand Up @@ -246,7 +247,7 @@ public async IAsyncEnumerable<MemoryRecord> GetListAsync(
{
var client = this.GetSearchClient(index);

SearchOptions options = this.PrepareSearchOptions(null, withEmbeddings, filters, limit);
AISearchOptions options = this.PrepareSearchOptions(null, withEmbeddings, filters, limit);

Response<SearchResults<AzureAISearchMemoryRecord>>? searchResult = null;
try
Expand Down Expand Up @@ -596,13 +597,13 @@ at Azure.Search.Documents.SearchClient.SearchInternal[T](SearchOptions options,
return indexSchema;
}

private SearchOptions PrepareSearchOptions(
SearchOptions? options,
private AISearchOptions PrepareSearchOptions(
AISearchOptions? options,
bool withEmbeddings,
ICollection<MemoryFilter>? filters = null,
int limit = 1)
{
options ??= new SearchOptions();
options ??= new AISearchOptions();

// Define which fields to fetch
options.Select.Add(AzureAISearchMemoryRecord.IdField);
Expand Down
3 changes: 2 additions & 1 deletion service/Abstractions/Abstractions.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<TargetFramework>net8.0</TargetFramework>
<AssemblyName>Microsoft.KernelMemory.Abstractions</AssemblyName>
<RootNamespace>Microsoft.KernelMemory</RootNamespace>
<NoWarn>$(NoWarn);KMEXP00;CA1711;CA1724;CS1574;SKEXP0001;</NoWarn>
<NoWarn>$(NoWarn);KMEXP00;SKEXP0001;CA1711;CA1724;CS1574;CA1812;</NoWarn>
</PropertyGroup>

<ItemGroup>
Expand All @@ -13,6 +13,7 @@
<PackageReference Include="Microsoft.Extensions.Hosting" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<PackageReference Include="Microsoft.SemanticKernel.Abstractions" />
<PackageReference Include="System.Linq.Async" />
<PackageReference Include="System.Memory.Data" />
<PackageReference Include="System.Numerics.Tensors" />
</ItemGroup>
Expand Down
Loading

0 comments on commit 77fd7be

Please sign in to comment.