From fc6883998a9719f4a0a4ff06c4ade589200a7560 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 28 Mar 2025 11:29:45 +0000 Subject: [PATCH 01/46] Add memory projects with common abstractions and mem zero implementation --- dotnet/SK-dotnet.sln | 21 ++ .../Memory.Abstractions.csproj | 25 +++ .../Memory.Abstractions/TextMemoryDocument.cs | 39 ++++ .../Memory.Abstractions/TextMemoryStore.cs | 46 ++++ .../Memory.Abstractions/ThreadExtension.cs | 75 +++++++ .../ThreadExtensionsManager.cs | 88 ++++++++ .../Memory/Memory/MemZeroMemoryComponent.cs | 212 ++++++++++++++++++ dotnet/src/Memory/Memory/Memory.csproj | 26 +++ 8 files changed, 532 insertions(+) create mode 100644 dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj create mode 100644 dotnet/src/Memory/Memory.Abstractions/TextMemoryDocument.cs create mode 100644 dotnet/src/Memory/Memory.Abstractions/TextMemoryStore.cs create mode 100644 dotnet/src/Memory/Memory.Abstractions/ThreadExtension.cs create mode 100644 dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs create mode 100644 dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs create mode 100644 dotnet/src/Memory/Memory/Memory.csproj diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 16fa0c43cd0d..16912b79f55e 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -520,6 +520,12 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MCPClient", "samples\Demos\ {12C7E0C7-A7DF-3BC3-0D4B-1A706BCE6981} = {12C7E0C7-A7DF-3BC3-0D4B-1A706BCE6981} EndProjectSection EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "memory", "memory", "{4B850B93-46D6-4F25-9DB1-90D1E6E4AB70}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Memory.Abstractions", "src\Memory\Memory.Abstractions\Memory.Abstractions.csproj", "{F5124057-1DA1-4799-9357-D9A635047678}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Memory", "src\Memory\Memory\Memory.csproj", "{A0538079-AB6F-4C7D-9138-A15258583F80}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -1427,6 +1433,18 @@ Global {B06770D5-2F3E-4271-9F6B-3AA9E716176F}.Publish|Any CPU.Build.0 = Release|Any CPU {B06770D5-2F3E-4271-9F6B-3AA9E716176F}.Release|Any CPU.ActiveCfg = Release|Any CPU {B06770D5-2F3E-4271-9F6B-3AA9E716176F}.Release|Any CPU.Build.0 = Release|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Publish|Any CPU.Build.0 = Release|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Release|Any CPU.Build.0 = Release|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Publish|Any CPU.Build.0 = Release|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1621,6 +1639,9 @@ Global {879545ED-D429-49B1-96F1-2EC55FFED31D} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {12C7E0C7-A7DF-3BC3-0D4B-1A706BCE6981} = {879545ED-D429-49B1-96F1-2EC55FFED31D} {B06770D5-2F3E-4271-9F6B-3AA9E716176F} = {879545ED-D429-49B1-96F1-2EC55FFED31D} + {4B850B93-46D6-4F25-9DB1-90D1E6E4AB70} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} + {F5124057-1DA1-4799-9357-D9A635047678} = {4B850B93-46D6-4F25-9DB1-90D1E6E4AB70} + {A0538079-AB6F-4C7D-9138-A15258583F80} = {4B850B93-46D6-4F25-9DB1-90D1E6E4AB70} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj b/dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj new file mode 100644 index 000000000000..cefd99323ad0 --- /dev/null +++ b/dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj @@ -0,0 +1,25 @@ + + + + Microsoft.SemanticKernel.Memory.Abstractions + Microsoft.SemanticKernel.Memory + net8.0;netstandard2.0 + false + + + + + + + Semantic Kernel - Memory Abstractions + Semantic Kernel interfaces and abstractions for capturing, storing and retrieving memories. + + + + rc + + + + + + diff --git a/dotnet/src/Memory/Memory.Abstractions/TextMemoryDocument.cs b/dotnet/src/Memory/Memory.Abstractions/TextMemoryDocument.cs new file mode 100644 index 000000000000..5fa9408e47df --- /dev/null +++ b/dotnet/src/Memory/Memory.Abstractions/TextMemoryDocument.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Represents a storage record for a single text based memory. +/// +public sealed class TextMemoryDocument +{ + /// + /// Gets or sets a unique identifier for the memory document. + /// + public Guid Key { get; set; } + + /// + /// Gets or sets the namespace for the memory document. + /// + /// + /// A namespace is a logical grouping of memory documents, e.g. may include a user id to scope the memory to a specific user. + /// + public string Namespace { get; set; } = string.Empty; + + /// + /// Gets or sets an optional name for the memory document. + /// + public string Name { get; set; } = string.Empty; + + /// + /// Gets or sets an optional category for the memory document. + /// + public string Category { get; set; } = string.Empty; + + /// + /// Gets or sets the actual memory content as text. + /// + public string MemoryText { get; set; } = string.Empty; +} diff --git a/dotnet/src/Memory/Memory.Abstractions/TextMemoryStore.cs b/dotnet/src/Memory/Memory.Abstractions/TextMemoryStore.cs new file mode 100644 index 000000000000..ded74e1c0398 --- /dev/null +++ b/dotnet/src/Memory/Memory.Abstractions/TextMemoryStore.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Abstract base class for storing and retrieving text based memories. +/// +public abstract class TextMemoryStore +{ + /// + /// Retrieves a memory asynchronously by its document name. + /// + /// The name of the document to retrieve the memory for. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the memory text if found, otherwise null. + public abstract Task GetMemoryAsync(string documentName, CancellationToken cancellationToken = default); + + /// + /// Searches for memories that are similar to the given text asynchronously. + /// + /// The text to search for similar memories. + /// The to monitor for cancellation requests. The default is . + /// An asynchronous enumerable of similar memory texts. + public abstract IAsyncEnumerable SimilaritySearch(string query, CancellationToken cancellationToken = default); + + /// + /// Saves a memory asynchronously with the specified document name. + /// + /// The name of the document to save the memory to. + /// The memory text to save. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous save operation. + public abstract Task SaveMemoryAsync(string documentName, string memoryText, CancellationToken cancellationToken = default); + + /// + /// Saves a memory asynchronously with no document name. + /// + /// The memory text to save. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous save operation. + public abstract Task SaveMemoryAsync(string memoryText, CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/Memory/Memory.Abstractions/ThreadExtension.cs b/dotnet/src/Memory/Memory.Abstractions/ThreadExtension.cs new file mode 100644 index 000000000000..b4584cbafac2 --- /dev/null +++ b/dotnet/src/Memory/Memory.Abstractions/ThreadExtension.cs @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Base class for all thread extensions. +/// +public abstract class ThreadExtension +{ + /// + /// Called just after a new thread is created. + /// + /// + /// Implementers can use this method to do any operations required at the creation of a new thread. + /// For exmple, checking long term storage for any data that is relevant to the current session based on the input text. + /// + /// The ID of the new thread. + /// The input text, typically a user ask. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the context has been loaded. + public virtual Task OnThreadCreateAsync(string threadId, string? inputText = default, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + /// + /// This method is called when a new message has been contributed to the chat by any participant. + /// + /// + /// Inheritors can use this method to update their context based on the new message. + /// + /// The new message. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the context has been updated. + public virtual Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + /// + /// Called just before a thread is deleted. + /// + /// + /// Implementers can use this method to do any operations required before a thread is deleted. + /// For exmple, storing the context to long term storage. + /// + /// The id of the thread that will be deleted. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the context has been saved. + public virtual Task OnThreadDeleteAsync(string threadId, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + /// + /// Called just before the AI is invoked + /// Implementers can load any additional context required at this time, + /// but they should also return any context that should be passed to the AI. + /// + /// The most recent message that the AI is being invoked with. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the context has been rendered and returned. + public abstract Task OnAIInvocationAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default); + + /// + /// Register plugins required by this extension component on the provided . + /// + /// The kernel to register the plugins on. + public virtual void RegisterPlugins(Kernel kernel) + { + } +} diff --git a/dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs b/dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs new file mode 100644 index 000000000000..74e99a9615e5 --- /dev/null +++ b/dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// A container class for thread extension components that manages their lifecycle and interactions. +/// +public class ThreadExtensionsManager +{ + private readonly List _threadExtensions = new(); + + /// + /// Gets the list of registered thread extensions. + /// + public virtual IReadOnlyList ThreadExtensions => this._threadExtensions; + + /// + /// Registers a new thread extensions. + /// + /// The thread extensions to register. + public virtual void RegisterThreadExtension(ThreadExtension threadExtension) + { + this._threadExtensions.Add(threadExtension); + } + + /// + /// Called when a new thread is created. + /// + /// The ID of the new thread. + /// The input text, typically a user ask. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public virtual async Task OnThreadCreateAsync(string threadId, string inputText, CancellationToken cancellationToken = default) + { + await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnThreadCreateAsync(threadId, inputText, cancellationToken)).ToList()).ConfigureAwait(false); + } + + /// + /// Called just before a thread is deleted. + /// + /// The id of the thread that will be deleted. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public virtual async Task OnThreadDeleteAsync(string threadId, CancellationToken cancellationToken = default) + { + await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnThreadDeleteAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + } + + /// + /// This method is called when a new message has been contributed to the chat by any participant. + /// + /// The new message. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnNewMessageAsync(newMessage, cancellationToken)).ToList()).ConfigureAwait(false); + } + + /// + /// Called just before the AI is invoked + /// + /// The most recent message that the AI is being invoked with. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation, containing the combined context from all thread extensions. + public virtual async Task OnAIInvocationAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + var subContexts = await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnAIInvocationAsync(newMessage, cancellationToken)).ToList()).ConfigureAwait(false); + return string.Join("\n", subContexts); + } + + /// + /// Registers plugins required by all thread extensions contained by this manager on the provided . + /// + /// The kernel to register the plugins on. + public virtual void RegisterPlugins(Kernel kernel) + { + foreach (var threadExtension in this.ThreadExtensions) + { + threadExtension.RegisterPlugins(kernel); + } + } +} diff --git a/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs b/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs new file mode 100644 index 000000000000..4b6723fb1a82 --- /dev/null +++ b/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.ComponentModel; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// A component that listenes to messages added to the conversation thread, and automatically captures +/// information about the user. It is also able to retrieve this information and add it to the AI invocation context. +/// +public class MemZeroMemoryComponent : ThreadExtension +{ + private static readonly Uri s_searchUri = new("/search", UriKind.Relative); + private static readonly Uri s_createMemoryUri = new("/memories", UriKind.Relative); + + private readonly string? _agentId; + private string? _threadId; + private readonly string? _userId; + private readonly bool _scopeToThread; + private readonly HttpClient _httpClient; + + private bool _contextLoaded = false; + private string _userInformation = string.Empty; + + /// + /// Initializes a new instance of the class. + /// + /// The HTTP client used for making requests. + /// The ID of the agent. + /// The ID of the thread. + /// The ID of the user. + /// Indicates whether the scope is limited to the thread. + public MemZeroMemoryComponent(HttpClient httpClient, string? agentId = default, string? threadId = default, string? userId = default, bool scopeToThread = false) + { + this._agentId = agentId; + this._threadId = threadId; + this._userId = userId; + this._scopeToThread = scopeToThread; + this._httpClient = httpClient; + } + + /// + public override async Task OnThreadCreateAsync(string threadId, string? inputText = default, CancellationToken cancellationToken = default) + { + if (!this._contextLoaded) + { + this._threadId ??= threadId; + + var searchRequest = new SearchRequest + { + AgentId = this._agentId, + RunId = this._scopeToThread ? this._threadId : null, + UserId = this._userId, + Query = inputText ?? string.Empty + }; + var responseItems = await this.SearchAsync(searchRequest).ConfigureAwait(false); + this._userInformation = string.Join("\n", responseItems); + this._contextLoaded = true; + } + } + + /// + public override async Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + if (newMessage.Role == AuthorRole.User) + { + await this.CreateMemoryAsync( + new CreateMemoryRequest() + { + AgentId = this._agentId, + RunId = this._scopeToThread ? this._threadId : null, + UserId = this._userId, + Messages = new[] + { + new CreateMemoryMemory + { + Content = newMessage.Content ?? string.Empty, + Role = newMessage.Role.Label + } + } + }).ConfigureAwait(false); + } + } + + /// + public override Task OnAIInvocationAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + return Task.FromResult("The following list contains facts about the user:\n" + this._userInformation); + } + + /// + public override void RegisterPlugins(Kernel kernel) + { + base.RegisterPlugins(kernel); + kernel.Plugins.AddFromObject(this, "MemZeroMemory"); + } + + /// + /// Plugin method to clear user preferences stored in memory for the current agent/thread/user. + /// + /// A task that completes when the memory is cleared. + [KernelFunction] + [Description("Deletes any user preferences stored about the user.")] + public async Task ClearUserPreferencesAsync() + { + await this.ClearMemoryAsync().ConfigureAwait(false); + } + + private async Task CreateMemoryAsync(CreateMemoryRequest createMemoryRequest) + { + using var content = new StringContent(JsonSerializer.Serialize(createMemoryRequest), Encoding.UTF8, "application/json"); + var responseMessage = await this._httpClient.PostAsync(s_createMemoryUri, content).ConfigureAwait(false); + responseMessage.EnsureSuccessStatusCode(); + } + + private async Task SearchAsync(SearchRequest searchRequest) + { + using var content = new StringContent(JsonSerializer.Serialize(searchRequest), Encoding.UTF8, "application/json"); + var responseMessage = await this._httpClient.PostAsync(s_searchUri, content).ConfigureAwait(false); + responseMessage.EnsureSuccessStatusCode(); + var response = await responseMessage.Content.ReadAsStringAsync().ConfigureAwait(false); + var searchResponseItems = JsonSerializer.Deserialize(response); + return searchResponseItems?.Select(item => item.Memory).ToArray() ?? Array.Empty(); + } + + private async Task ClearMemoryAsync() + { + try + { + var querystringParams = new string?[3] { this._userId, this._agentId, this._scopeToThread ? this._threadId : null } + .Where(x => !string.IsNullOrWhiteSpace(x)) + .Select((param, index) => $"param{index}={param}"); + var queryString = string.Join("&", querystringParams); + var clearMemoryUrl = new Uri($"/memories?{queryString}", UriKind.Relative); + + var responseMessage = await this._httpClient.DeleteAsync(clearMemoryUrl).ConfigureAwait(false); + responseMessage.EnsureSuccessStatusCode(); + } + catch (Exception ex) + { + Console.WriteLine($"- MemZeroMemory - Error clearing memory: {ex.Message}"); + throw; + } + } + + private sealed class CreateMemoryRequest + { + [JsonPropertyName("agent_id")] + public string? AgentId { get; set; } + [JsonPropertyName("run_id")] + public string? RunId { get; set; } + [JsonPropertyName("user_id")] + public string? UserId { get; set; } + [JsonPropertyName("messages")] + public CreateMemoryMemory[] Messages { get; set; } = []; + } + + private sealed class CreateMemoryMemory + { + [JsonPropertyName("content")] + public string Content { get; set; } = string.Empty; + [JsonPropertyName("role")] + public string Role { get; set; } = string.Empty; + } + + private sealed class SearchRequest + { + [JsonPropertyName("agent_id")] + public string? AgentId { get; set; } = null; + [JsonPropertyName("run_id")] + public string? RunId { get; set; } = null; + [JsonPropertyName("user_id")] + public string? UserId { get; set; } = null; + [JsonPropertyName("query")] + public string Query { get; set; } = string.Empty; + } + +#pragma warning disable CA1812 // Avoid uninstantiated internal classes + private sealed class SearchResponseItem + { + [JsonPropertyName("id")] + public string Id { get; set; } = string.Empty; + [JsonPropertyName("memory")] + public string Memory { get; set; } = string.Empty; + [JsonPropertyName("hash")] + public string Hash { get; set; } = string.Empty; + [JsonPropertyName("metadata")] + public object? Metadata { get; set; } + [JsonPropertyName("score")] + public double Score { get; set; } + [JsonPropertyName("created_at")] + public DateTime CreatedAt { get; set; } + [JsonPropertyName("updated_at")] + public DateTime? UpdatedAt { get; set; } + [JsonPropertyName("user_id")] + public string UserId { get; set; } = string.Empty; + [JsonPropertyName("agent_id")] + public string AgentId { get; set; } = string.Empty; + [JsonPropertyName("run_id")] + public string RunId { get; set; } = string.Empty; + } +#pragma warning restore CA1812 // Avoid uninstantiated internal classes +} diff --git a/dotnet/src/Memory/Memory/Memory.csproj b/dotnet/src/Memory/Memory/Memory.csproj new file mode 100644 index 000000000000..1b1ec608e218 --- /dev/null +++ b/dotnet/src/Memory/Memory/Memory.csproj @@ -0,0 +1,26 @@ + + + + Microsoft.SemanticKernel.Memory.Core + Microsoft.SemanticKernel.Memory + net8.0;netstandard2.0 + false + + + + + + + Semantic Kernel - Memory Core + Semantic Kernel implementations for capturing, storing and retrieving memories. + + + + rc + + + + + + + From f0d6dbf9ee64698d08c9008f5215f12e5653e88a Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 28 Mar 2025 11:48:20 +0000 Subject: [PATCH 02/46] Add user preferences component --- dotnet/src/Memory/Memory/Memory.csproj | 9 + .../Memory/Memory/OptionalTextMemoryStore.cs | 77 +++++++++ .../Memory/UserPreferencesMemoryComponent.cs | 157 ++++++++++++++++++ 3 files changed, 243 insertions(+) create mode 100644 dotnet/src/Memory/Memory/OptionalTextMemoryStore.cs create mode 100644 dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs diff --git a/dotnet/src/Memory/Memory/Memory.csproj b/dotnet/src/Memory/Memory/Memory.csproj index 1b1ec608e218..dfef3edf1395 100644 --- a/dotnet/src/Memory/Memory/Memory.csproj +++ b/dotnet/src/Memory/Memory/Memory.csproj @@ -19,6 +19,15 @@ rc + + + + + + + + + diff --git a/dotnet/src/Memory/Memory/OptionalTextMemoryStore.cs b/dotnet/src/Memory/Memory/OptionalTextMemoryStore.cs new file mode 100644 index 000000000000..af7797776681 --- /dev/null +++ b/dotnet/src/Memory/Memory/OptionalTextMemoryStore.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Memory; + +namespace Microsoft.SemanticKernel.Agents.Memory; + +/// +/// Helper class to get a from the DI container if a name is provided +/// and if the memory store is registered. +/// Class implements no-op methods if no name is provided or the store is not registered. +/// +internal sealed class OptionalTextMemoryStore : TextMemoryStore +{ + private readonly TextMemoryStore? _textMemoryStore; + + /// + /// Initializes a new instance of the class. + /// + /// The kernel to try and get the named store from. + /// The name of the store to get from the DI container. + public OptionalTextMemoryStore(Kernel kernel, string? storeName) + { + if (storeName is not null) + { + this._textMemoryStore = kernel.Services.GetKeyedService(storeName); + } + } + + /// + public override Task GetMemoryAsync(string documentName, CancellationToken cancellationToken = default) + { + if (this._textMemoryStore is not null) + { + return this._textMemoryStore.GetMemoryAsync(documentName, cancellationToken); + } + + return Task.FromResult(null); + } + + /// + public override Task SaveMemoryAsync(string documentName, string memoryText, CancellationToken cancellationToken = default) + { + if (this._textMemoryStore is not null) + { + return this._textMemoryStore.SaveMemoryAsync(documentName, memoryText, cancellationToken); + } + + return Task.CompletedTask; + } + + /// + public override Task SaveMemoryAsync(string memoryText, CancellationToken cancellationToken = default) + { + if (this._textMemoryStore is not null) + { + return this._textMemoryStore.SaveMemoryAsync(memoryText, cancellationToken); + } + + return Task.CompletedTask; + } + + /// + public override IAsyncEnumerable SimilaritySearch(string query, CancellationToken cancellationToken = default) + { + if (this._textMemoryStore is not null) + { + return this._textMemoryStore.SimilaritySearch(query, cancellationToken); + } + + return AsyncEnumerable.Empty(); + } +} diff --git a/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs b/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs new file mode 100644 index 000000000000..c3e4543abf79 --- /dev/null +++ b/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.ComponentModel; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Memory; + +namespace Microsoft.SemanticKernel.Agents.Memory; + +/// +/// A memory component that can retrieve, maintain and store user preferences that +/// are learned from the user's interactions with the agent. +/// +public class UserPreferencesMemoryComponent : ThreadExtension +{ + private readonly Kernel _kernel; + private readonly TextMemoryStore _textMemoryStore; + private string _userPreferences = string.Empty; + private bool _contextLoaded = false; + + /// + /// Initializes a new instance of the class. + /// + /// A kernel to use for making chat completion calls. + /// The memory store to retrieve and save memories from and to. + public UserPreferencesMemoryComponent(Kernel kernel, TextMemoryStore textMemoryStore) + { + this._kernel = kernel; + this._textMemoryStore = textMemoryStore; + } + + /// + /// Initializes a new instance of the class. + /// + /// A kernel to use for making chat completion calls. + /// The service key that the for user preferences is registered under in DI. + public UserPreferencesMemoryComponent(Kernel kernel, string? userPreferencesStoreName = "UserPreferencesStore") + { + this._kernel = kernel; + this._textMemoryStore = new OptionalTextMemoryStore(kernel, userPreferencesStoreName); + } + + /// + /// Gets or sets the name of the document to use for storing user preferences. + /// + public string UserPreferencesDocumentName { get; init; } = "UserPreferences"; + + /// + /// Gets or sets the prompt template to use for extracting user preferences and merging them with existing preferences. + /// + public string MaintainencePromptTemplate { get; init; } = + """ + You are an expert in extracting facts about a user from text and combining these facts with existing facts to output a new list of facts. + Facts are short statements that each contain a single piece of information. + Facts should always be about the user and should always be in the present tense. + Facts should focus on the user's long term preferences and characteristics, not on their short term actions. + + Here are 5 few shot examples: + + EXAMPLES START + + Input text: My name is John. I love dogs and cats, but unfortunately I am allergic to cats. I'm not alergic to dogs though. I have a dog called Fido. + Input facts: User name is John. User is alergic to cats. + Output: User name is John. User loves dogs. User loves cats. User is alergic to cats. User is not alergic to dogs. User has a dog. User dog's name is Fido. + + Input text: My name is Mary. I like active holidays. I enjoy cycling and hiking. + Input facts: User name is Mary. User dislikes cycling. + Output: User name is Mary. User likes cycling. User likes hiking. User likes active holidays. + + Input text: How do I calculate the area of a circle? + Input facts: + Output: + + Input text: What is today's date? + Input facts: User name is Peter. + Output: User name is Peter. + + EXAMPLES END + + Return output for the following inputs like shown in the examples above: + + Input text: {{$inputText}} + Input facts: {{$existingPreferences}} + """; + + /// + public override async Task OnThreadCreateAsync(string threadId, string? inputText = default, CancellationToken cancellationToken = default) + { + if (!this._contextLoaded) + { + this._userPreferences = string.Empty; + + var memoryText = await this._textMemoryStore.GetMemoryAsync("UserPreferences", cancellationToken).ConfigureAwait(false); + if (memoryText is not null) + { + this._userPreferences = memoryText; + } + + this._contextLoaded = true; + } + } + + /// + public override async Task OnThreadDeleteAsync(string threadId, CancellationToken cancellationToken = default) + { + await this._textMemoryStore.SaveMemoryAsync("UserPreferences", this._userPreferences, cancellationToken).ConfigureAwait(false); + } + + /// + public override Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + if (newMessage.Role == AuthorRole.User && !string.IsNullOrWhiteSpace(newMessage.Content)) + { + // Don't wait for task to complete. Just run in the background. + var task = this.ExtractAndSaveMemoriesAsync(newMessage.Content, cancellationToken); + } + + return Task.CompletedTask; + } + + /// + public override Task OnAIInvocationAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + return Task.FromResult("The following list contains facts about the user:\n" + this._userPreferences); + } + + /// + public override void RegisterPlugins(Kernel kernel) + { + base.RegisterPlugins(kernel); + kernel.Plugins.AddFromObject(this, "UserPreferencesMemory"); + } + + /// + /// Plugin method to clear user preferences stored in memory. + /// + [KernelFunction] + [Description("Deletes any user preferences stored about the user.")] + public async Task ClearUserPreferencesAsync(CancellationToken cancellationToken = default) + { + this._userPreferences = string.Empty; + await this._textMemoryStore.SaveMemoryAsync("UserPreferences", this._userPreferences, cancellationToken).ConfigureAwait(false); + } + + private async Task ExtractAndSaveMemoriesAsync(string inputText, CancellationToken cancellationToken = default) + { + var result = await this._kernel.InvokePromptAsync( + this.MaintainencePromptTemplate, + new KernelArguments() { ["inputText"] = inputText, ["existingPreferences"] = this._userPreferences }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + this._userPreferences = result.ToString(); + + await this._textMemoryStore.SaveMemoryAsync("UserPreferences", this._userPreferences, cancellationToken).ConfigureAwait(false); + } +} From 6e926e835f677dbfa19f1ba8394185e9574891fd Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 28 Mar 2025 12:18:51 +0000 Subject: [PATCH 03/46] Update AgentThread and ChatCompletionAgent to support thread extensions --- dotnet/src/Agents/Abstractions/AgentThread.cs | 24 ++++++++ .../Abstractions/Agents.Abstractions.csproj | 1 + dotnet/src/Agents/Core/ChatCompletionAgent.cs | 10 +++- .../Memory.Abstractions/ThreadExtension.cs | 8 +-- .../ThreadExtensionsManager.cs | 27 +++++++-- .../Memory/Memory/MemZeroMemoryComponent.cs | 55 +++++++++++++------ .../Memory/UserPreferencesMemoryComponent.cs | 7 ++- 7 files changed, 100 insertions(+), 32 deletions(-) diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index 74477d556340..cd2c408d7d7c 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Agents; @@ -26,6 +28,11 @@ public abstract class AgentThread /// public virtual bool IsDeleted { get; protected set; } = false; + /// + /// Gets or sets the container for thread extension components that manages their lifecycle and interactions. + /// + public virtual ThreadExtensionsManager ThreadExtensionsManager { get; init; } = new ThreadExtensionsManager(); + /// /// Creates the thread and returns the thread id. /// @@ -45,6 +52,8 @@ protected internal virtual async Task CreateAsync(CancellationToken cancellation } this.Id = await this.CreateInternalAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + + await this.ThreadExtensionsManager.OnThreadCreatedAsync(this.Id!, cancellationToken).ConfigureAwait(false); } /// @@ -65,11 +74,24 @@ public virtual async Task DeleteAsync(CancellationToken cancellationToken = defa throw new InvalidOperationException("This thread cannot be deleted, since it has not been created."); } + await this.ThreadExtensionsManager.OnThreadDeleteAsync(this.Id!, cancellationToken).ConfigureAwait(false); + await this.DeleteInternalAsync(cancellationToken).ConfigureAwait(false); this.IsDeleted = true; } + /// + /// Called just before the AI is invoked + /// + /// The most recent messages that the AI is being invoked with. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation, containing the combined context from all thread extensions. + public virtual async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + { + return await this.ThreadExtensionsManager.OnAIInvocationAsync(newMessages, cancellationToken).ConfigureAwait(false); + } + /// /// This method is called when a new message has been contributed to the chat by any participant. /// @@ -92,6 +114,8 @@ internal virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, Can await this.CreateAsync(cancellationToken).ConfigureAwait(false); } + await this.ThreadExtensionsManager.OnNewMessageAsync(newMessage, cancellationToken).ConfigureAwait(false); + await this.OnNewMessageInternalAsync(newMessage, cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Agents/Abstractions/Agents.Abstractions.csproj b/dotnet/src/Agents/Abstractions/Agents.Abstractions.csproj index 2cc0d9799bc1..4e8c0c3884f8 100644 --- a/dotnet/src/Agents/Abstractions/Agents.Abstractions.csproj +++ b/dotnet/src/Agents/Abstractions/Agents.Abstractions.csproj @@ -30,6 +30,7 @@ + diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 62d334520647..baa04929c94c 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -73,6 +73,9 @@ public override async IAsyncEnumerable> In () => new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); + // Get the thread extensions context contributions + var extensionsContext = await chatHistoryAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + // Invoke Chat Completion with the updated chat history. var chatHistory = new ChatHistory(); await foreach (var existingMessage in chatHistoryAgentThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false)) @@ -92,7 +95,7 @@ public override async IAsyncEnumerable> In }, options?.KernelArguments, options?.Kernel, - options?.AdditionalInstructions, + options?.AdditionalInstructions == null ? extensionsContext : options.AdditionalInstructions + Environment.NewLine + Environment.NewLine + extensionsContext, cancellationToken); // Notify the thread of new messages and return them to the caller. @@ -156,6 +159,9 @@ public override async IAsyncEnumerable new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); + // Get the thread extensions context contributions + var extensionsContext = await chatHistoryAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + // Invoke Chat Completion with the updated chat history. var chatHistory = new ChatHistory(); await foreach (var existingMessage in chatHistoryAgentThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false)) @@ -176,7 +182,7 @@ public override async IAsyncEnumerable /// The ID of the new thread. - /// The input text, typically a user ask. /// The to monitor for cancellation requests. The default is . /// A task that completes when the context has been loaded. - public virtual Task OnThreadCreateAsync(string threadId, string? inputText = default, CancellationToken cancellationToken = default) + public virtual Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) { return Task.CompletedTask; } @@ -60,10 +60,10 @@ public virtual Task OnThreadDeleteAsync(string threadId, CancellationToken cance /// Implementers can load any additional context required at this time, /// but they should also return any context that should be passed to the AI. /// - /// The most recent message that the AI is being invoked with. + /// The most recent messages that the AI is being invoked with. /// The to monitor for cancellation requests. The default is . /// A task that completes when the context has been rendered and returned. - public abstract Task OnAIInvocationAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default); + public abstract Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default); /// /// Register plugins required by this extension component on the provided . diff --git a/dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs b/dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs index 74e99a9615e5..66f0636e4f2e 100644 --- a/dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs +++ b/dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs @@ -19,6 +19,22 @@ public class ThreadExtensionsManager /// public virtual IReadOnlyList ThreadExtensions => this._threadExtensions; + /// + /// Initializes a new instance of the class. + /// + public ThreadExtensionsManager() + { + } + + /// + /// Initializes a new instance of the class with the specified thread extensions. + /// + /// The thread extensions to add to the manager. + public ThreadExtensionsManager(IEnumerable threadExtensions) + { + this._threadExtensions.AddRange(threadExtensions); + } + /// /// Registers a new thread extensions. /// @@ -32,12 +48,11 @@ public virtual void RegisterThreadExtension(ThreadExtension threadExtension) /// Called when a new thread is created. /// /// The ID of the new thread. - /// The input text, typically a user ask. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. - public virtual async Task OnThreadCreateAsync(string threadId, string inputText, CancellationToken cancellationToken = default) + public virtual async Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnThreadCreateAsync(threadId, inputText, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnThreadCreatedAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -65,12 +80,12 @@ public virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, Cance /// /// Called just before the AI is invoked /// - /// The most recent message that the AI is being invoked with. + /// The most recent messages that the AI is being invoked with. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation, containing the combined context from all thread extensions. - public virtual async Task OnAIInvocationAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + public virtual async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { - var subContexts = await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnAIInvocationAsync(newMessage, cancellationToken)).ToList()).ConfigureAwait(false); + var subContexts = await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnAIInvocationAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); return string.Join("\n", subContexts); } diff --git a/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs b/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs index 4b6723fb1a82..d2dbb87e894c 100644 --- a/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.ComponentModel; using System.Linq; using System.Net.Http; @@ -41,6 +42,8 @@ public class MemZeroMemoryComponent : ThreadExtension /// Indicates whether the scope is limited to the thread. public MemZeroMemoryComponent(HttpClient httpClient, string? agentId = default, string? threadId = default, string? userId = default, bool scopeToThread = false) { + Verify.NotNull(httpClient); + this._agentId = agentId; this._threadId = threadId; this._userId = userId; @@ -49,28 +52,17 @@ public MemZeroMemoryComponent(HttpClient httpClient, string? agentId = default, } /// - public override async Task OnThreadCreateAsync(string threadId, string? inputText = default, CancellationToken cancellationToken = default) + public override Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) { - if (!this._contextLoaded) - { - this._threadId ??= threadId; - - var searchRequest = new SearchRequest - { - AgentId = this._agentId, - RunId = this._scopeToThread ? this._threadId : null, - UserId = this._userId, - Query = inputText ?? string.Empty - }; - var responseItems = await this.SearchAsync(searchRequest).ConfigureAwait(false); - this._userInformation = string.Join("\n", responseItems); - this._contextLoaded = true; - } + this._threadId ??= threadId; + return Task.CompletedTask; } /// public override async Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) { + Verify.NotNull(newMessage); + if (newMessage.Role == AuthorRole.User) { await this.CreateMemoryAsync( @@ -92,14 +84,22 @@ await this.CreateMemoryAsync( } /// - public override Task OnAIInvocationAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { - return Task.FromResult("The following list contains facts about the user:\n" + this._userInformation); + Verify.NotNull(newMessages); + + string input = string.Join("\n", newMessages.Where(m => m is not null).Select(m => m.Content)); + + await this.LoadContextAsync(this._threadId, input).ConfigureAwait(false); + + return "The following list contains facts about the user:\n" + this._userInformation; } /// public override void RegisterPlugins(Kernel kernel) { + Verify.NotNull(kernel); + base.RegisterPlugins(kernel); kernel.Plugins.AddFromObject(this, "MemZeroMemory"); } @@ -115,6 +115,25 @@ public async Task ClearUserPreferencesAsync() await this.ClearMemoryAsync().ConfigureAwait(false); } + private async Task LoadContextAsync(string? threadId, string? inputText) + { + if (!this._contextLoaded) + { + this._threadId ??= threadId; + + var searchRequest = new SearchRequest + { + AgentId = this._agentId, + RunId = this._scopeToThread ? this._threadId : null, + UserId = this._userId, + Query = inputText ?? string.Empty + }; + var responseItems = await this.SearchAsync(searchRequest).ConfigureAwait(false); + this._userInformation = string.Join("\n", responseItems); + this._contextLoaded = true; + } + } + private async Task CreateMemoryAsync(CreateMemoryRequest createMemoryRequest) { using var content = new StringContent(JsonSerializer.Serialize(createMemoryRequest), Encoding.UTF8, "application/json"); diff --git a/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs b/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs index c3e4543abf79..93a14c0b29a6 100644 --- a/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; using System.ComponentModel; using System.Threading; using System.Threading.Tasks; @@ -85,7 +86,7 @@ EXAMPLES END """; /// - public override async Task OnThreadCreateAsync(string threadId, string? inputText = default, CancellationToken cancellationToken = default) + public override async Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) { if (!this._contextLoaded) { @@ -120,7 +121,7 @@ public override Task OnNewMessageAsync(ChatMessageContent newMessage, Cancellati } /// - public override Task OnAIInvocationAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + public override Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { return Task.FromResult("The following list contains facts about the user:\n" + this._userPreferences); } @@ -128,6 +129,8 @@ public override Task OnAIInvocationAsync(ChatMessageContent newMessage, /// public override void RegisterPlugins(Kernel kernel) { + Verify.NotNull(kernel); + base.RegisterPlugins(kernel); kernel.Plugins.AddFromObject(this, "UserPreferencesMemory"); } From c13a3c18362ccaccaae2182661586ac303a5c8d4 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 28 Mar 2025 14:50:14 +0000 Subject: [PATCH 04/46] Add memory support to open ai assistant agent --- .../Internal/AssistantRunOptionsFactory.cs | 6 ++- .../OpenAI/Internal/AssistantThreadActions.cs | 8 +++- .../src/Agents/OpenAI/OpenAIAssistantAgent.cs | 40 +++++++++++++----- .../Agents/OpenAI/OpenAIAssistantChannel.cs | 4 +- .../AssistantRunOptionsFactoryTests.cs | 12 +++--- .../AgentWithMemoryTests.cs | 28 +++++++++++++ .../ChatCompletionAgentWithMemoryTests.cs | 42 +++++++++++++++++++ .../OpenAIAssistantAgentWithMemoryTests.cs.cs | 42 +++++++++++++++++++ .../ChatCompletionAgentFixture.cs | 4 ++ .../OpenAIAssistantAgentFixture.cs | 11 +++++ .../IntegrationTests/IntegrationTests.csproj | 1 + .../Memory/UserPreferencesMemoryComponent.cs | 6 +-- 12 files changed, 177 insertions(+), 27 deletions(-) create mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AgentWithMemoryTests.cs create mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs create mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs diff --git a/dotnet/src/Agents/OpenAI/Internal/AssistantRunOptionsFactory.cs b/dotnet/src/Agents/OpenAI/Internal/AssistantRunOptionsFactory.cs index d0245dbb9bdf..243778db83ac 100644 --- a/dotnet/src/Agents/OpenAI/Internal/AssistantRunOptionsFactory.cs +++ b/dotnet/src/Agents/OpenAI/Internal/AssistantRunOptionsFactory.cs @@ -10,12 +10,14 @@ namespace Microsoft.SemanticKernel.Agents.OpenAI.Internal; /// internal static class AssistantRunOptionsFactory { - public static RunCreationOptions GenerateOptions(RunCreationOptions? defaultOptions, string? agentInstructions, RunCreationOptions? invocationOptions) + public static RunCreationOptions GenerateOptions(RunCreationOptions? defaultOptions, string? agentInstructions, RunCreationOptions? invocationOptions, string? threadExtensionsContext) { + var additionalInstructions = (invocationOptions?.AdditionalInstructions ?? defaultOptions?.AdditionalInstructions) + threadExtensionsContext; + RunCreationOptions runOptions = new() { - AdditionalInstructions = invocationOptions?.AdditionalInstructions ?? defaultOptions?.AdditionalInstructions, + AdditionalInstructions = additionalInstructions, InstructionsOverride = invocationOptions?.InstructionsOverride ?? agentInstructions, MaxOutputTokenCount = invocationOptions?.MaxOutputTokenCount ?? defaultOptions?.MaxOutputTokenCount, MaxInputTokenCount = invocationOptions?.MaxInputTokenCount ?? defaultOptions?.MaxInputTokenCount, diff --git a/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs b/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs index e1cb991b643e..c4a603b38922 100644 --- a/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs +++ b/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs @@ -105,6 +105,7 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist /// The assistant client /// The thread identifier /// Options to utilize for the invocation + /// Additional context from thead extensions to pass to the invoke method. /// The logger to utilize (might be agent or channel scoped) /// The plugins and other state. /// Optional arguments to pass to the agents's invocation, including any . @@ -115,6 +116,7 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist AssistantClient client, string threadId, RunCreationOptions? invocationOptions, + string? threadExtensionsContext, ILogger logger, Kernel kernel, KernelArguments? arguments, @@ -133,7 +135,7 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist string? instructions = await agent.GetInstructionsAsync(kernel, arguments, cancellationToken).ConfigureAwait(false); - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(agent.RunOptions, instructions, invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(agent.RunOptions, instructions, invocationOptions, threadExtensionsContext); options.ToolsOverride.AddRange(tools); @@ -335,6 +337,7 @@ async Task PollRunStatusAsync() /// The thread identifier /// The receiver for the completed messages generated /// Options to utilize for the invocation + /// Additional context from thead extensions to pass to the invoke method. /// The logger to utilize (might be agent or channel scoped) /// The plugins and other state. /// Optional arguments to pass to the agents's invocation, including any . @@ -350,6 +353,7 @@ public static async IAsyncEnumerable InvokeStreamin string threadId, IList? messages, RunCreationOptions? invocationOptions, + string? threadExtensionsContext, ILogger logger, Kernel kernel, KernelArguments? arguments, @@ -361,7 +365,7 @@ public static async IAsyncEnumerable InvokeStreamin string? instructions = await agent.GetInstructionsAsync(kernel, arguments, cancellationToken).ConfigureAwait(false); - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(agent.RunOptions, instructions, invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(agent.RunOptions, instructions, invocationOptions, threadExtensionsContext); options.ToolsOverride.AddRange(tools); diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs index d5642c496665..1e8f2858a3f8 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs @@ -411,6 +411,9 @@ public async IAsyncEnumerable> InvokeAsync AdditionalInstructions = options?.AdditionalInstructions, }); + // Get the thread extensions context contributions + var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + var invokeResults = ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(this.Id, this.GetDisplayName(), this.Description), () => InternalInvokeAsync(), @@ -423,6 +426,7 @@ async IAsyncEnumerable InternalInvokeAsync() this.Client, openAIAssistantAgentThread.Id!, internalOptions, + extensionsContext, this.Logger, options?.Kernel ?? this.Kernel, this.MergeArguments(options?.KernelArguments), @@ -498,7 +502,7 @@ async IAsyncEnumerable InternalInvokeAsync() kernel ??= this.Kernel; arguments = this.MergeArguments(arguments); - await foreach ((bool isVisible, ChatMessageContent message) in AssistantThreadActions.InvokeAsync(this, this.Client, threadId, options, this.Logger, kernel, arguments, cancellationToken).ConfigureAwait(false)) + await foreach ((bool isVisible, ChatMessageContent message) in AssistantThreadActions.InvokeAsync(this, this.Client, threadId, options, null, this.Logger, kernel, arguments, cancellationToken).ConfigureAwait(false)) { if (isVisible) { @@ -549,6 +553,9 @@ public async IAsyncEnumerable> In () => new OpenAIAssistantAgentThread(this.Client), cancellationToken).ConfigureAwait(false); + // Get the thread extensions context contributions + var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + // Create options that use the RunCreationOptions from the options param if provided or // falls back to creating a new RunCreationOptions if additional instructions is provided // separately. @@ -557,17 +564,28 @@ public async IAsyncEnumerable> In AdditionalInstructions = options?.AdditionalInstructions, }); -#pragma warning disable CS0618 // Type or member is obsolete - // Invoke the Agent with the thread that we already added our message to. +#pragma warning disable SKEXP0001 // ModelDiagnostics is marked experimental. var newMessagesReceiver = new ChatHistory(); - var invokeResults = this.InvokeStreamingAsync( - openAIAssistantAgentThread.Id!, - internalOptions, - this.MergeArguments(options?.KernelArguments), - options?.Kernel ?? this.Kernel, - newMessagesReceiver, + var invokeResults = ActivityExtensions.RunWithActivityAsync( + () => ModelDiagnostics.StartAgentInvocationActivity(this.Id, this.GetDisplayName(), this.Description), + () => InternalInvokeStreamingAsync(), cancellationToken); -#pragma warning restore CS0618 // Type or member is obsolete +#pragma warning restore SKEXP0001 // ModelDiagnostics is marked experimental. + + IAsyncEnumerable InternalInvokeStreamingAsync() + { + return AssistantThreadActions.InvokeStreamingAsync( + this, + this.Client, + openAIAssistantAgentThread.Id!, + newMessagesReceiver, + internalOptions, + extensionsContext, + this.Logger, + options?.Kernel ?? this.Kernel, + this.MergeArguments(options?.KernelArguments), + cancellationToken); + } // Return the chunks to the caller. await foreach (var result in invokeResults.ConfigureAwait(false)) @@ -642,7 +660,7 @@ IAsyncEnumerable InternalInvokeStreamingAsync() kernel ??= this.Kernel; arguments = this.MergeArguments(arguments); - return AssistantThreadActions.InvokeStreamingAsync(this, this.Client, threadId, messages, options, this.Logger, kernel, arguments, cancellationToken); + return AssistantThreadActions.InvokeStreamingAsync(this, this.Client, threadId, messages, options, null, this.Logger, kernel, arguments, cancellationToken); } } diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs index 39534df768da..ad2331a2fb6e 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs @@ -36,7 +36,7 @@ protected override async Task ReceiveAsync(IEnumerable histo { return ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(agent.Id, agent.GetDisplayName(), agent.Description), - () => AssistantThreadActions.InvokeAsync(agent, this._client, this._threadId, invocationOptions: null, this.Logger, agent.Kernel, agent.Arguments, cancellationToken), + () => AssistantThreadActions.InvokeAsync(agent, this._client, this._threadId, invocationOptions: null, threadExtensionsContext: null, this.Logger, agent.Kernel, agent.Arguments, cancellationToken), cancellationToken); } @@ -45,7 +45,7 @@ protected override IAsyncEnumerable InvokeStreaming { return ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(agent.Id, agent.GetDisplayName(), agent.Description), - () => AssistantThreadActions.InvokeStreamingAsync(agent, this._client, this._threadId, messages, invocationOptions: null, this.Logger, agent.Kernel, agent.Arguments, cancellationToken), + () => AssistantThreadActions.InvokeStreamingAsync(agent, this._client, this._threadId, messages, invocationOptions: null, threadExtensionsContext: null, this.Logger, agent.Kernel, agent.Arguments, cancellationToken), cancellationToken); } diff --git a/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs b/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs index dfca85afc0f2..8777a29b6cfc 100644 --- a/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs +++ b/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs @@ -29,7 +29,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsNullTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, null); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, null, threadExtensionsContext: null); // Assert Assert.NotNull(options); @@ -62,7 +62,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsEquivalentTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, "test", invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, "test", invocationOptions, threadExtensionsContext: null); // Assert Assert.NotNull(options); @@ -97,7 +97,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsOverrideTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions, threadExtensionsContext: null); // Assert Assert.NotNull(options); @@ -134,7 +134,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsMetadataTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions, threadExtensionsContext: null); // Assert Assert.Equal(2, options.Metadata.Count); @@ -163,7 +163,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsMessagesTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions, threadExtensionsContext: null); // Assert Assert.Single(options.AdditionalMessages); @@ -186,7 +186,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsMaxTokensTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, null); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, null, threadExtensionsContext: null); // Assert Assert.Equal(1024, options.MaxInputTokenCount); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AgentWithMemoryTests.cs new file mode 100644 index 000000000000..6327a64ee307 --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AgentWithMemoryTests.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Threading.Tasks; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithMemoryConformance; + +public abstract class AgentWithMemoryTests(Func createAgentFixture) : IAsyncLifetime + where TFixture : AgentFixture +{ +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + private TFixture _agentFixture; +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + + protected TFixture Fixture => this._agentFixture; + + public Task InitializeAsync() + { + this._agentFixture = createAgentFixture(); + return this._agentFixture.InitializeAsync(); + } + + public Task DisposeAsync() + { + return this._agentFixture.DisposeAsync(); + } +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs new file mode 100644 index 000000000000..57525accb5b9 --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.Memory; +using Microsoft.SemanticKernel.ChatCompletion; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithMemoryConformance; + +public class ChatCompletionAgentWithMemoryTests() : AgentWithMemoryTests(() => new ChatCompletionAgentFixture()) +{ + [Fact] + public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() + { + // Arrange + var agent = this.Fixture.Agent; + var memoryComponent = new UserPreferencesMemoryComponent(this.Fixture.Agent.Kernel); + + var agentThread1 = new ChatHistoryAgentThread(); + agentThread1.ThreadExtensionsManager.RegisterThreadExtension(memoryComponent); + + var agentThread2 = new ChatHistoryAgentThread(); + agentThread2.ThreadExtensionsManager.RegisterThreadExtension(memoryComponent); + + // Act + var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); + var results1 = await asyncResults1.ToListAsync(); + + var asyncResults2 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "What is my name?."), agentThread2); + var results2 = await asyncResults2.ToListAsync(); + + // Assert + Assert.Contains("Caoimhe", results2.First().Message.Content); + + // Cleanup + await this.Fixture.DeleteThread(agentThread1); + await this.Fixture.DeleteThread(agentThread2); + } +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs new file mode 100644 index 000000000000..6e03c749a049 --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents.Memory; +using Microsoft.SemanticKernel.Agents.OpenAI; +using Microsoft.SemanticKernel.ChatCompletion; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithMemoryConformance; + +public class OpenAIAssistantAgentWithMemoryTests() : AgentWithMemoryTests(() => new OpenAIAssistantAgentFixture()) +{ + [Fact] + public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() + { + // Arrange + var agent = this.Fixture.Agent; + var memoryComponent = new UserPreferencesMemoryComponent(this.Fixture.Agent.Kernel); + + var agentThread1 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); + agentThread1.ThreadExtensionsManager.RegisterThreadExtension(memoryComponent); + + var agentThread2 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); + agentThread2.ThreadExtensionsManager.RegisterThreadExtension(memoryComponent); + + // Act + var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); + var results1 = await asyncResults1.ToListAsync(); + + var asyncResults2 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "What is my name?."), agentThread2); + var results2 = await asyncResults2.ToListAsync(); + + // Assert + Assert.Contains("Caoimhe", results2.First().Message.Content); + + // Cleanup + await this.Fixture.DeleteThread(agentThread1); + await this.Fixture.DeleteThread(agentThread2); + } +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs index 999a831c0e07..e8ea7839ff27 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs @@ -65,6 +65,10 @@ public async override Task InitializeAsync() deploymentName: configuration.ChatDeploymentName!, endpoint: configuration.Endpoint, credentials: new AzureCliCredential()); + kernelBuilder.AddAzureOpenAITextEmbeddingGeneration( + deploymentName: configuration.EmbeddingModelId!, + endpoint: configuration.Endpoint, + credential: new AzureCliCredential()); Kernel kernel = kernelBuilder.Build(); this._agent = new ChatCompletionAgent() diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs index 3a2a5ded8df9..9a5e15f4e298 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs @@ -34,6 +34,8 @@ public class OpenAIAssistantAgentFixture : AgentFixture private OpenAIAssistantAgentThread? _serviceFailingAgentThread; private OpenAIAssistantAgentThread? _createdServiceFailingAgentThread; + public AssistantClient AssistantClient => this._assistantClient!; + public override KernelAgent Agent => this._agent!; public override AgentThread AgentThread => this._thread!; @@ -85,6 +87,7 @@ public override Task DeleteThread(AgentThread thread) public override async Task InitializeAsync() { + AzureOpenAIConfiguration openAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get()!; AzureAIConfiguration configuration = this._configuration.GetSection("AzureAI").Get()!; var client = OpenAIAssistantAgent.CreateAzureOpenAIClient(new AzureCliCredential(), new Uri(configuration.Endpoint)); this._assistantClient = client.GetAssistantClient(); @@ -96,6 +99,14 @@ await this._assistantClient.CreateAssistantAsync( instructions: "You are a helpful assistant."); var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddAzureOpenAIChatCompletion( + deploymentName: openAIConfiguration.ChatDeploymentName!, + endpoint: openAIConfiguration.Endpoint, + credentials: new AzureCliCredential()); + kernelBuilder.AddAzureOpenAITextEmbeddingGeneration( + deploymentName: openAIConfiguration.EmbeddingModelId!, + endpoint: openAIConfiguration.Endpoint, + credential: new AzureCliCredential()); Kernel kernel = kernelBuilder.Build(); this._agent = new OpenAIAssistantAgent(this._assistant, this._assistantClient) { Kernel = kernel }; diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index cb21d20b7f4a..0b5a37fa05ed 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -96,6 +96,7 @@ + diff --git a/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs b/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs index 93a14c0b29a6..4190f1ac7b13 100644 --- a/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs @@ -109,15 +109,13 @@ public override async Task OnThreadDeleteAsync(string threadId, CancellationToke } /// - public override Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + public override async Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) { if (newMessage.Role == AuthorRole.User && !string.IsNullOrWhiteSpace(newMessage.Content)) { // Don't wait for task to complete. Just run in the background. - var task = this.ExtractAndSaveMemoriesAsync(newMessage.Content, cancellationToken); + await this.ExtractAndSaveMemoriesAsync(newMessage.Content, cancellationToken).ConfigureAwait(false); } - - return Task.CompletedTask; } /// From c28c0aeddc44596cec5b00a546c7d0933b59ac83 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 7 Apr 2025 19:20:13 +0100 Subject: [PATCH 05/46] Add vector data memory store. --- .../Memory.Abstractions/TextMemoryDocument.cs | 39 --- dotnet/src/Memory/Memory/Memory.csproj | 1 + .../Memory/VectorDataTextMemoryStore.cs | 275 ++++++++++++++++++ 3 files changed, 276 insertions(+), 39 deletions(-) delete mode 100644 dotnet/src/Memory/Memory.Abstractions/TextMemoryDocument.cs create mode 100644 dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs diff --git a/dotnet/src/Memory/Memory.Abstractions/TextMemoryDocument.cs b/dotnet/src/Memory/Memory.Abstractions/TextMemoryDocument.cs deleted file mode 100644 index 5fa9408e47df..000000000000 --- a/dotnet/src/Memory/Memory.Abstractions/TextMemoryDocument.cs +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; - -namespace Microsoft.SemanticKernel.Memory; - -/// -/// Represents a storage record for a single text based memory. -/// -public sealed class TextMemoryDocument -{ - /// - /// Gets or sets a unique identifier for the memory document. - /// - public Guid Key { get; set; } - - /// - /// Gets or sets the namespace for the memory document. - /// - /// - /// A namespace is a logical grouping of memory documents, e.g. may include a user id to scope the memory to a specific user. - /// - public string Namespace { get; set; } = string.Empty; - - /// - /// Gets or sets an optional name for the memory document. - /// - public string Name { get; set; } = string.Empty; - - /// - /// Gets or sets an optional category for the memory document. - /// - public string Category { get; set; } = string.Empty; - - /// - /// Gets or sets the actual memory content as text. - /// - public string MemoryText { get; set; } = string.Empty; -} diff --git a/dotnet/src/Memory/Memory/Memory.csproj b/dotnet/src/Memory/Memory/Memory.csproj index dfef3edf1395..dfe3e844db9f 100644 --- a/dotnet/src/Memory/Memory/Memory.csproj +++ b/dotnet/src/Memory/Memory/Memory.csproj @@ -29,6 +29,7 @@ + diff --git a/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs b/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs new file mode 100644 index 000000000000..0315ba58d8d4 --- /dev/null +++ b/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Embeddings; + +namespace Microsoft.SemanticKernel.Memory; + +#pragma warning disable SKEXP0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + +/// +/// Class to store and retrieve text-based memories to and from a vector store. +/// +/// The key type to use with the vector store. +public class VectorDataTextMemoryStore : TextMemoryStore, IDisposable + where TKey : notnull + +{ + private readonly IVectorStore _vectorStore; + private readonly ITextEmbeddingGenerationService _textEmbeddingGenerationService; + private readonly string _storageNamespace; + private readonly int _vectorDimensions; + private readonly Lazy>> _vectorStoreRecordCollection; + private readonly SemaphoreSlim _collectionInitializationLock = new(1, 1); + private bool _collectionInitialized = false; + private bool _disposedValue; + + private readonly VectorStoreRecordDefinition _memoryDocumentDefinition = new() + { + Properties = new List() + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new VectorStoreRecordDataProperty("Namespace", typeof(string)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Name", typeof(string)), + new VectorStoreRecordDataProperty("Category", typeof(string)), + new VectorStoreRecordDataProperty("MemoryText", typeof(string)), + new VectorStoreRecordVectorProperty("MemoryTextEmbedding", typeof(ReadOnlyMemory)), + } + }; + + /// + /// Initializes a new instance of the class. + /// + /// The vector store to store and read the memories from. + /// The service to use for generating embeddings for the memories. + /// The name of the collection in the vector store to store and read the memories from. + /// The namespace to scope memories to within the collection. + /// The number of dimentions to use for the memory embeddings. + /// Thrown if the key type provided is not supported. + public VectorDataTextMemoryStore(IVectorStore vectorStore, ITextEmbeddingGenerationService textEmbeddingGenerationService, string collectionName, string storageNamespace, int vectorDimensions) + { + Verify.NotNull(vectorStore); + Verify.NotNull(textEmbeddingGenerationService); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.NotNullOrWhiteSpace(storageNamespace); + Verify.True(vectorDimensions > 0, "Vector dimensions must be greater than 0"); + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(Guid)) + { + throw new NotSupportedException($"Unsupported key of type '{typeof(TKey).Name}'"); + } + + VectorStoreRecordDefinition _memoryDocumentDefinition = new() + { + Properties = new List() + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new VectorStoreRecordDataProperty("Namespace", typeof(string)), + new VectorStoreRecordDataProperty("Name", typeof(string)), + new VectorStoreRecordDataProperty("Category", typeof(string)), + new VectorStoreRecordDataProperty("MemoryText", typeof(string)), + new VectorStoreRecordVectorProperty("MemoryTextEmbedding", typeof(ReadOnlyMemory)) { Dimensions = vectorDimensions }, + } + }; + + this._vectorStore = vectorStore; + this._textEmbeddingGenerationService = textEmbeddingGenerationService; + this._storageNamespace = storageNamespace; + this._vectorDimensions = vectorDimensions; + this._vectorStoreRecordCollection = new Lazy>>(() => + this._vectorStore.GetCollection>(collectionName, this._memoryDocumentDefinition)); + } + + /// + public override async Task GetMemoryAsync(string documentName, CancellationToken cancellationToken = default) + { + var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + + ReadOnlyMemory vector = new(new float[this._vectorDimensions]); + var searchResult = await vectorStoreRecordCollection.VectorizedSearchAsync( + vector, + options: new() + { + Top = 1, + Filter = x => x.Name == documentName && x.Namespace == this._storageNamespace, + }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + var results = await searchResult.Results.ToListAsync(cancellationToken).ConfigureAwait(false); + + if (results.Count == 0) + { + return null; + } + + return results[0].Record.MemoryText; + } + + /// + public override async IAsyncEnumerable SimilaritySearch(string query, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + + var vector = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(query, cancellationToken: cancellationToken).ConfigureAwait(false); + var searchResult = await vectorStoreRecordCollection.VectorizedSearchAsync( + vector, + options: new() + { + Top = 3, + Filter = x => x.Namespace == this._storageNamespace, + }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + await foreach (var result in searchResult.Results.ConfigureAwait(false)) + { + yield return result.Record.MemoryText; + } + } + + /// + public override async Task SaveMemoryAsync(string documentName, string memoryText, CancellationToken cancellationToken = default) + { + var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + + var vector = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync( + string.IsNullOrWhiteSpace(memoryText) ? "Empty" : memoryText, + cancellationToken: cancellationToken).ConfigureAwait(false); + + var memoryDocument = new MemoryDocument + { + Key = GenerateUniqueKey(), + Namespace = this._storageNamespace, + Name = documentName, + MemoryText = memoryText, + MemoryTextEmbedding = vector, + }; + + await vectorStoreRecordCollection.UpsertAsync(memoryDocument, cancellationToken: cancellationToken).ConfigureAwait(false); + } + + /// + public override Task SaveMemoryAsync(string memoryText, CancellationToken cancellationToken = default) + { + return this.SaveMemoryAsync(null!, memoryText, cancellationToken); + } + + /// + /// Thread safe method to get the collection and ensure that it is created at least once. + /// + /// The to monitor for cancellation requests. The default is . + /// The created collection. + private async Task>> EnsureCollectionCreatedAsync(CancellationToken cancellationToken) + { + var vectorStoreRecordCollection = this._vectorStoreRecordCollection.Value; + + // Return immediately if the collection is already created, no need to do any locking in this case. + if (this._collectionInitialized) + { + return vectorStoreRecordCollection; + } + + // Wait on a lock to ensure that only one thread can create the collection. + await this._collectionInitializationLock.WaitAsync(cancellationToken).ConfigureAwait(false); + + // If multiple threads waited on the lock, and the first already created the collection, + // we can return immediately without doing any work in subsequent threads. + if (this._collectionInitialized) + { + this._collectionInitializationLock.Release(); + return vectorStoreRecordCollection; + } + + // Only the winning thread should reach this point and create the collection. + try + { + await vectorStoreRecordCollection.CreateCollectionIfNotExistsAsync(cancellationToken).ConfigureAwait(false); + this._collectionInitialized = true; + } + finally + { + this._collectionInitializationLock.Release(); + } + + return vectorStoreRecordCollection; + } + + /// + /// Generates a unique key for the memory document. + /// + /// The type of the key to use, since different databases require/support different keys. + /// A new unique key. + /// Thrown if the requested key type is not supported. + private static TDocumentKey GenerateUniqueKey() + => typeof(TDocumentKey) switch + { + _ when typeof(TDocumentKey) == typeof(string) => (TDocumentKey)(object)Guid.NewGuid().ToString(), + _ when typeof(TDocumentKey) == typeof(Guid) => (TDocumentKey)(object)Guid.NewGuid(), + + _ => throw new NotSupportedException($"Unsupported key of type '{typeof(TDocumentKey).Name}'") + }; + + /// + /// The data model to use for storing memory documents in the vector store. + /// + /// The type of the key to use, since different databases require/support different keys. + private sealed class MemoryDocument + { + /// + /// Gets or sets a unique identifier for the memory document. + /// + public TDocumentKey Key { get; set; } = default!; + + /// + /// Gets or sets the namespace for the memory document. + /// + /// + /// A namespace is a logical grouping of memory documents, e.g. may include a user id to scope the memory to a specific user. + /// + public string Namespace { get; set; } = string.Empty; + + /// + /// Gets or sets an optional name for the memory document. + /// + public string Name { get; set; } = string.Empty; + + /// + /// Gets or sets an optional category for the memory document. + /// + public string Category { get; set; } = string.Empty; + + /// + /// Gets or sets the actual memory content as text. + /// + public string MemoryText { get; set; } = string.Empty; + + public ReadOnlyMemory MemoryTextEmbedding { get; set; } + } + + /// + protected virtual void Dispose(bool disposing) + { + if (!this._disposedValue) + { + if (disposing) + { + this._collectionInitializationLock.Dispose(); + } + + this._disposedValue = true; + } + } + + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } +} From 27cb090b4f94af0ac761d97ef8a797f93f2d1132 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 9 Apr 2025 15:35:26 +0100 Subject: [PATCH 06/46] Fix bug in VectorDataTextMemoryStore and add integration test with vector storage --- .../ChatCompletionAgentWithMemoryTests.cs | 47 +++++++++++++++++++ .../Memory/VectorDataTextMemoryStore.cs | 17 ++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index 57525accb5b9..e5c8d1c9522a 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -2,16 +2,29 @@ using System.Linq; using System.Threading.Tasks; +using Azure.Identity; +using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; using Microsoft.SemanticKernel.Agents.Memory; using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.Connectors.InMemory; +using Microsoft.SemanticKernel.Memory; +using SemanticKernel.IntegrationTests.TestSettings; using Xunit; namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithMemoryConformance; public class ChatCompletionAgentWithMemoryTests() : AgentWithMemoryTests(() => new ChatCompletionAgentFixture()) { + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + [Fact] public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() { @@ -39,4 +52,38 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() await this.Fixture.DeleteThread(agentThread1); await this.Fixture.DeleteThread(agentThread2); } + + [Fact] + public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserInputAsync() + { + // Arrange + var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); + + var vectorStore = new InMemoryVectorStore(); + var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); + using var textMemoryStore = new VectorDataTextMemoryStore(vectorStore, textEmbeddingService, "Memories", "user/12345", 1536); + + var agent = this.Fixture.Agent; + + // Act - First invocation with first thread. + var agentThread1 = new ChatHistoryAgentThread(); + agentThread1.ThreadExtensionsManager.RegisterThreadExtension(new UserPreferencesMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + + var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); + var results1 = await asyncResults1.ToListAsync(); + + // Act - Second invocation with second thread. + var agentThread2 = new ChatHistoryAgentThread(); + agentThread2.ThreadExtensionsManager.RegisterThreadExtension(new UserPreferencesMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + + var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); + var results2 = await asyncResults2.ToListAsync(); + + // Assert + Assert.Contains("Caoimhe", results2.First().Message.Content); + + // Cleanup + await this.Fixture.DeleteThread(agentThread1); + await this.Fixture.DeleteThread(agentThread2); + } } diff --git a/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs b/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs index 0315ba58d8d4..6027b5baf923 100644 --- a/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs +++ b/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs @@ -91,6 +91,16 @@ public VectorDataTextMemoryStore(IVectorStore vectorStore, ITextEmbeddingGenerat { var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + // If the database supports string keys, get using the namespace + document name. + if (typeof(TKey) == typeof(string)) + { + var namespaceKey = $"{this._storageNamespace}:{documentName}"; + + var record = await vectorStoreRecordCollection.GetAsync((TKey)(object)namespaceKey, cancellationToken: cancellationToken).ConfigureAwait(false); + return record?.MemoryText; + } + + // Otherwise do a search with a filter on the document name and namespace. ReadOnlyMemory vector = new(new float[this._vectorDimensions]); var searchResult = await vectorStoreRecordCollection.VectorizedSearchAsync( vector, @@ -143,7 +153,7 @@ public override async Task SaveMemoryAsync(string documentName, string memoryTex var memoryDocument = new MemoryDocument { - Key = GenerateUniqueKey(), + Key = GenerateUniqueKey(this._storageNamespace, documentName), Namespace = this._storageNamespace, Name = documentName, MemoryText = memoryText, @@ -202,12 +212,15 @@ private async Task>> Ens /// /// Generates a unique key for the memory document. /// + /// Storage namespace to use for string keys. + /// An optional document name to use for the key if the database supports string keys. /// The type of the key to use, since different databases require/support different keys. /// A new unique key. /// Thrown if the requested key type is not supported. - private static TDocumentKey GenerateUniqueKey() + private static TDocumentKey GenerateUniqueKey(string storageNamespace, string? documentName) => typeof(TDocumentKey) switch { + _ when typeof(TDocumentKey) == typeof(string) && documentName is not null => (TDocumentKey)(object)$"{storageNamespace}:{documentName}", _ when typeof(TDocumentKey) == typeof(string) => (TDocumentKey)(object)Guid.NewGuid().ToString(), _ when typeof(TDocumentKey) == typeof(Guid) => (TDocumentKey)(object)Guid.NewGuid(), From 0169f6e0a54e5d6d804622ef78f99c6daf940c04 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 9 Apr 2025 19:41:48 +0100 Subject: [PATCH 07/46] Rename extension, add experimental attributes and mark memory packages as alpha --- dotnet/docs/EXPERIMENTS.md | 1 + dotnet/src/Agents/Abstractions/AgentThread.cs | 15 ++++-- dotnet/src/Agents/Core/ChatCompletionAgent.cs | 12 +++-- .../src/Agents/OpenAI/OpenAIAssistantAgent.cs | 12 +++-- .../IntegrationTests/IntegrationTests.csproj | 2 +- ...nsion.cs => ConversationStateExtension.cs} | 17 +++++-- ... => ConversationStateExtensionsManager.cs} | 46 ++++++++++--------- .../Memory.Abstractions.csproj | 6 +++ .../Properties/AssemblyInfo.cs | 6 +++ .../Memory/Memory/MemZeroMemoryComponent.cs | 2 +- dotnet/src/Memory/Memory/Memory.csproj | 1 + .../Memory/Memory/Properties/AssemblyInfo.cs | 6 +++ .../Memory/UserPreferencesMemoryComponent.cs | 4 +- 13 files changed, 88 insertions(+), 42 deletions(-) rename dotnet/src/Memory/Memory.Abstractions/{ThreadExtension.cs => ConversationStateExtension.cs} (81%) rename dotnet/src/Memory/Memory.Abstractions/{ThreadExtensionsManager.cs => ConversationStateExtensionsManager.cs} (55%) create mode 100644 dotnet/src/Memory/Memory.Abstractions/Properties/AssemblyInfo.cs create mode 100644 dotnet/src/Memory/Memory/Properties/AssemblyInfo.cs diff --git a/dotnet/docs/EXPERIMENTS.md b/dotnet/docs/EXPERIMENTS.md index 99fd9b56afb4..114211fe01f8 100644 --- a/dotnet/docs/EXPERIMENTS.md +++ b/dotnet/docs/EXPERIMENTS.md @@ -25,6 +25,7 @@ You can use the following diagnostic IDs to ignore warnings or errors for a part | SKEXP0100 | Advanced Semantic Kernel features | | SKEXP0110 | Semantic Kernel Agents | | SKEXP0120 | Native-AOT | +| SKEXP0130 | Memory | ## Experimental Features Tracking diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index cd2c408d7d7c..e11aa7cdef19 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel.Memory; @@ -29,9 +30,10 @@ public abstract class AgentThread public virtual bool IsDeleted { get; protected set; } = false; /// - /// Gets or sets the container for thread extension components that manages their lifecycle and interactions. + /// Gets or sets the container for conversation state extension components that manages their lifecycle and interactions. /// - public virtual ThreadExtensionsManager ThreadExtensionsManager { get; init; } = new ThreadExtensionsManager(); + [Experimental("SKEXP0130")] + public virtual ConversationStateExtensionsManager ThreadExtensionsManager { get; init; } = new ConversationStateExtensionsManager(); /// /// Creates the thread and returns the thread id. @@ -53,7 +55,9 @@ protected internal virtual async Task CreateAsync(CancellationToken cancellation this.Id = await this.CreateInternalAsync(cancellationToken: cancellationToken).ConfigureAwait(false); +#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.ThreadExtensionsManager.OnThreadCreatedAsync(this.Id!, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. } /// @@ -74,7 +78,9 @@ public virtual async Task DeleteAsync(CancellationToken cancellationToken = defa throw new InvalidOperationException("This thread cannot be deleted, since it has not been created."); } +#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.ThreadExtensionsManager.OnThreadDeleteAsync(this.Id!, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.DeleteInternalAsync(cancellationToken).ConfigureAwait(false); @@ -86,7 +92,8 @@ public virtual async Task DeleteAsync(CancellationToken cancellationToken = defa /// /// The most recent messages that the AI is being invoked with. /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation, containing the combined context from all thread extensions. + /// A task that represents the asynchronous operation, containing the combined context from all conversation state extensions. + [Experimental("SKEXP0130")] public virtual async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { return await this.ThreadExtensionsManager.OnAIInvocationAsync(newMessages, cancellationToken).ConfigureAwait(false); @@ -114,7 +121,9 @@ internal virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, Can await this.CreateAsync(cancellationToken).ConfigureAwait(false); } +#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.ThreadExtensionsManager.OnNewMessageAsync(newMessage, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.OnNewMessageInternalAsync(newMessage, cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 7fde212ff67b..4ec17c5d7a7e 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -74,8 +74,10 @@ public override async IAsyncEnumerable> In () => new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); - // Get the thread extensions context contributions - var extensionsContext = await chatHistoryAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + // Get the conversation state extensions context contributions +#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await chatHistoryAgentThread.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Invoke Chat Completion with the updated chat history. var chatHistory = new ChatHistory(); @@ -160,8 +162,10 @@ public override async IAsyncEnumerable new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); - // Get the thread extensions context contributions - var extensionsContext = await chatHistoryAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + // Get the conversation state extensions context contributions +#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await chatHistoryAgentThread.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Invoke Chat Completion with the updated chat history. var chatHistory = new ChatHistory(); diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs index fd6ac3617c86..a900040064b0 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs @@ -411,8 +411,10 @@ public async IAsyncEnumerable> InvokeAsync AdditionalInstructions = options?.AdditionalInstructions, }); - // Get the thread extensions context contributions - var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + // Get the conversation state extensions context contributions +#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await openAIAssistantAgentThread.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var invokeResults = ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(this.Id, this.GetDisplayName(), this.Description), @@ -551,8 +553,10 @@ public async IAsyncEnumerable> In () => new OpenAIAssistantAgentThread(this.Client), cancellationToken).ConfigureAwait(false); - // Get the thread extensions context contributions - var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + // Get the conversation state extensions context contributions +#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await openAIAssistantAgentThread.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Create options that use the RunCreationOptions from the options param if provided or // falls back to creating a new RunCreationOptions if additional instructions is provided diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index 0b5a37fa05ed..e1bff6844998 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -5,7 +5,7 @@ net8.0 true false - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110,OPENAI001 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110,SKEXP0130,OPENAI001 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 diff --git a/dotnet/src/Memory/Memory.Abstractions/ThreadExtension.cs b/dotnet/src/Memory/Memory.Abstractions/ConversationStateExtension.cs similarity index 81% rename from dotnet/src/Memory/Memory.Abstractions/ThreadExtension.cs rename to dotnet/src/Memory/Memory.Abstractions/ConversationStateExtension.cs index c4e4e222991c..0ac251d58508 100644 --- a/dotnet/src/Memory/Memory.Abstractions/ThreadExtension.cs +++ b/dotnet/src/Memory/Memory.Abstractions/ConversationStateExtension.cs @@ -1,15 +1,22 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; namespace Microsoft.SemanticKernel.Memory; /// -/// Base class for all thread extensions. +/// Base class for all conversation state extensions. /// -public abstract class ThreadExtension +/// +/// A conversation state extension is a component that can be used to store additional state related +/// to a conversation, listen to changes in the conversation state, and provide additional context to +/// the AI model in use just before invocation. +/// +[Experimental("SKEXP0130")] +public abstract class ConversationStateExtension { /// /// Called just after a new thread is created. @@ -18,7 +25,7 @@ public abstract class ThreadExtension /// Implementers can use this method to do any operations required at the creation of a new thread. /// For exmple, checking long term storage for any data that is relevant to the current session based on the input text. /// - /// The ID of the new thread. + /// The ID of the new thread, if the thread has an ID. /// The to monitor for cancellation requests. The default is . /// A task that completes when the context has been loaded. public virtual Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) @@ -47,10 +54,10 @@ public virtual Task OnNewMessageAsync(ChatMessageContent newMessage, Cancellatio /// Implementers can use this method to do any operations required before a thread is deleted. /// For exmple, storing the context to long term storage. /// - /// The id of the thread that will be deleted. + /// The ID of the thread that will be deleted, if the thread has an ID. /// The to monitor for cancellation requests. The default is . /// A task that completes when the context has been saved. - public virtual Task OnThreadDeleteAsync(string threadId, CancellationToken cancellationToken = default) + public virtual Task OnThreadDeleteAsync(string? threadId, CancellationToken cancellationToken = default) { return Task.CompletedTask; } diff --git a/dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs b/dotnet/src/Memory/Memory.Abstractions/ConversationStateExtensionsManager.cs similarity index 55% rename from dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs rename to dotnet/src/Memory/Memory.Abstractions/ConversationStateExtensionsManager.cs index 66f0636e4f2e..e07b411e0f04 100644 --- a/dotnet/src/Memory/Memory.Abstractions/ThreadExtensionsManager.cs +++ b/dotnet/src/Memory/Memory.Abstractions/ConversationStateExtensionsManager.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -8,40 +9,41 @@ namespace Microsoft.SemanticKernel.Memory; /// -/// A container class for thread extension components that manages their lifecycle and interactions. +/// A container class for objects that manages their lifecycle and interactions. /// -public class ThreadExtensionsManager +[Experimental("SKEXP0130")] +public class ConversationStateExtensionsManager { - private readonly List _threadExtensions = new(); + private readonly List _conversationStateExtensions = new(); /// - /// Gets the list of registered thread extensions. + /// Gets the list of registered conversation state extensions. /// - public virtual IReadOnlyList ThreadExtensions => this._threadExtensions; + public virtual IReadOnlyList ConversationStateExtensions => this._conversationStateExtensions; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// - public ThreadExtensionsManager() + public ConversationStateExtensionsManager() { } /// - /// Initializes a new instance of the class with the specified thread extensions. + /// Initializes a new instance of the class with the specified conversation state extensions. /// - /// The thread extensions to add to the manager. - public ThreadExtensionsManager(IEnumerable threadExtensions) + /// The conversation state extensions to add to the manager. + public ConversationStateExtensionsManager(IEnumerable conversationtStateExtensions) { - this._threadExtensions.AddRange(threadExtensions); + this._conversationStateExtensions.AddRange(conversationtStateExtensions); } /// - /// Registers a new thread extensions. + /// Registers a new conversation state extensions. /// - /// The thread extensions to register. - public virtual void RegisterThreadExtension(ThreadExtension threadExtension) + /// The conversation state extensions to register. + public virtual void RegisterThreadExtension(ConversationStateExtension conversationtStateExtension) { - this._threadExtensions.Add(threadExtension); + this._conversationStateExtensions.Add(conversationtStateExtension); } /// @@ -52,7 +54,7 @@ public virtual void RegisterThreadExtension(ThreadExtension threadExtension) /// A task that represents the asynchronous operation. public virtual async Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnThreadCreatedAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnThreadCreatedAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -63,7 +65,7 @@ public virtual async Task OnThreadCreatedAsync(string? threadId, CancellationTok /// A task that represents the asynchronous operation. public virtual async Task OnThreadDeleteAsync(string threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnThreadDeleteAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnThreadDeleteAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -74,7 +76,7 @@ public virtual async Task OnThreadDeleteAsync(string threadId, CancellationToken /// A task that represents the asynchronous operation. public virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnNewMessageAsync(newMessage, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnNewMessageAsync(newMessage, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -82,20 +84,20 @@ public virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, Cance /// /// The most recent messages that the AI is being invoked with. /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation, containing the combined context from all thread extensions. + /// A task that represents the asynchronous operation, containing the combined context from all conversation state extensions. public virtual async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { - var subContexts = await Task.WhenAll(this.ThreadExtensions.Select(x => x.OnAIInvocationAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); + var subContexts = await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnAIInvocationAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); return string.Join("\n", subContexts); } /// - /// Registers plugins required by all thread extensions contained by this manager on the provided . + /// Registers plugins required by all conversation state extensions contained by this manager on the provided . /// /// The kernel to register the plugins on. public virtual void RegisterPlugins(Kernel kernel) { - foreach (var threadExtension in this.ThreadExtensions) + foreach (var threadExtension in this.ConversationStateExtensions) { threadExtension.RegisterPlugins(kernel); } diff --git a/dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj b/dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj index cefd99323ad0..fb48dbd9236b 100644 --- a/dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj +++ b/dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj @@ -5,6 +5,7 @@ Microsoft.SemanticKernel.Memory net8.0;netstandard2.0 false + alpha @@ -19,6 +20,11 @@ rc + + + + + diff --git a/dotnet/src/Memory/Memory.Abstractions/Properties/AssemblyInfo.cs b/dotnet/src/Memory/Memory.Abstractions/Properties/AssemblyInfo.cs new file mode 100644 index 000000000000..418ffa6d4b58 --- /dev/null +++ b/dotnet/src/Memory/Memory.Abstractions/Properties/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +// This assembly is currently experimental. +[assembly: Experimental("SKEXP0130")] diff --git a/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs b/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs index d2dbb87e894c..af43cf14bdff 100644 --- a/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs @@ -18,7 +18,7 @@ namespace Microsoft.SemanticKernel.Memory; /// A component that listenes to messages added to the conversation thread, and automatically captures /// information about the user. It is also able to retrieve this information and add it to the AI invocation context. /// -public class MemZeroMemoryComponent : ThreadExtension +public class MemZeroMemoryComponent : ConversationStateExtension { private static readonly Uri s_searchUri = new("/search", UriKind.Relative); private static readonly Uri s_createMemoryUri = new("/memories", UriKind.Relative); diff --git a/dotnet/src/Memory/Memory/Memory.csproj b/dotnet/src/Memory/Memory/Memory.csproj index dfe3e844db9f..196cf06c06a7 100644 --- a/dotnet/src/Memory/Memory/Memory.csproj +++ b/dotnet/src/Memory/Memory/Memory.csproj @@ -5,6 +5,7 @@ Microsoft.SemanticKernel.Memory net8.0;netstandard2.0 false + alpha diff --git a/dotnet/src/Memory/Memory/Properties/AssemblyInfo.cs b/dotnet/src/Memory/Memory/Properties/AssemblyInfo.cs new file mode 100644 index 000000000000..418ffa6d4b58 --- /dev/null +++ b/dotnet/src/Memory/Memory/Properties/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +// This assembly is currently experimental. +[assembly: Experimental("SKEXP0130")] diff --git a/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs b/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs index 4190f1ac7b13..3ecd23133f92 100644 --- a/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs @@ -13,7 +13,7 @@ namespace Microsoft.SemanticKernel.Agents.Memory; /// A memory component that can retrieve, maintain and store user preferences that /// are learned from the user's interactions with the agent. /// -public class UserPreferencesMemoryComponent : ThreadExtension +public class UserPreferencesMemoryComponent : ConversationStateExtension { private readonly Kernel _kernel; private readonly TextMemoryStore _textMemoryStore; @@ -103,7 +103,7 @@ public override async Task OnThreadCreatedAsync(string? threadId, CancellationTo } /// - public override async Task OnThreadDeleteAsync(string threadId, CancellationToken cancellationToken = default) + public override async Task OnThreadDeleteAsync(string? threadId, CancellationToken cancellationToken = default) { await this._textMemoryStore.SaveMemoryAsync("UserPreferences", this._userPreferences, cancellationToken).ConfigureAwait(false); } From a219464b0c4aeec7d2a2d5ed33f312e44dc0861e Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 9 Apr 2025 19:45:39 +0100 Subject: [PATCH 08/46] Move extension classes into SK abstractions --- .../Memory}/ConversationStateExtension.cs | 0 .../Memory}/ConversationStateExtensionsManager.cs | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename dotnet/src/{Memory/Memory.Abstractions => SemanticKernel.Abstractions/Memory}/ConversationStateExtension.cs (100%) rename dotnet/src/{Memory/Memory.Abstractions => SemanticKernel.Abstractions/Memory}/ConversationStateExtensionsManager.cs (100%) diff --git a/dotnet/src/Memory/Memory.Abstractions/ConversationStateExtension.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs similarity index 100% rename from dotnet/src/Memory/Memory.Abstractions/ConversationStateExtension.cs rename to dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs diff --git a/dotnet/src/Memory/Memory.Abstractions/ConversationStateExtensionsManager.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs similarity index 100% rename from dotnet/src/Memory/Memory.Abstractions/ConversationStateExtensionsManager.cs rename to dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs From 5f292832d15381b71f7e34142bc646473633d043 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 11 Apr 2025 13:55:29 +0100 Subject: [PATCH 09/46] Make some tweaks to mem0 and text memory store --- ...oryComponent.cs => Mem0MemoryComponent.cs} | 6 +++--- .../Memory/VectorDataTextMemoryStore.cs | 20 +++---------------- 2 files changed, 6 insertions(+), 20 deletions(-) rename dotnet/src/Memory/Memory/{MemZeroMemoryComponent.cs => Mem0MemoryComponent.cs} (96%) diff --git a/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs b/dotnet/src/Memory/Memory/Mem0MemoryComponent.cs similarity index 96% rename from dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs rename to dotnet/src/Memory/Memory/Mem0MemoryComponent.cs index af43cf14bdff..8bce4f09d0fb 100644 --- a/dotnet/src/Memory/Memory/MemZeroMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/Mem0MemoryComponent.cs @@ -18,7 +18,7 @@ namespace Microsoft.SemanticKernel.Memory; /// A component that listenes to messages added to the conversation thread, and automatically captures /// information about the user. It is also able to retrieve this information and add it to the AI invocation context. /// -public class MemZeroMemoryComponent : ConversationStateExtension +public class Mem0MemoryComponent : ConversationStateExtension { private static readonly Uri s_searchUri = new("/search", UriKind.Relative); private static readonly Uri s_createMemoryUri = new("/memories", UriKind.Relative); @@ -33,14 +33,14 @@ public class MemZeroMemoryComponent : ConversationStateExtension private string _userInformation = string.Empty; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The HTTP client used for making requests. /// The ID of the agent. /// The ID of the thread. /// The ID of the user. /// Indicates whether the scope is limited to the thread. - public MemZeroMemoryComponent(HttpClient httpClient, string? agentId = default, string? threadId = default, string? userId = default, bool scopeToThread = false) + public Mem0MemoryComponent(HttpClient httpClient, string? agentId = default, string? threadId = default, string? userId = default, bool scopeToThread = false) { Verify.NotNull(httpClient); diff --git a/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs b/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs index 6027b5baf923..594fde5e8605 100644 --- a/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs +++ b/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs @@ -19,7 +19,6 @@ namespace Microsoft.SemanticKernel.Memory; /// The key type to use with the vector store. public class VectorDataTextMemoryStore : TextMemoryStore, IDisposable where TKey : notnull - { private readonly IVectorStore _vectorStore; private readonly ITextEmbeddingGenerationService _textEmbeddingGenerationService; @@ -30,19 +29,6 @@ public class VectorDataTextMemoryStore : TextMemoryStore, IDisposable private bool _collectionInitialized = false; private bool _disposedValue; - private readonly VectorStoreRecordDefinition _memoryDocumentDefinition = new() - { - Properties = new List() - { - new VectorStoreRecordKeyProperty("Key", typeof(TKey)), - new VectorStoreRecordDataProperty("Namespace", typeof(string)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("Name", typeof(string)), - new VectorStoreRecordDataProperty("Category", typeof(string)), - new VectorStoreRecordDataProperty("MemoryText", typeof(string)), - new VectorStoreRecordVectorProperty("MemoryTextEmbedding", typeof(ReadOnlyMemory)), - } - }; - /// /// Initializes a new instance of the class. /// @@ -65,12 +51,12 @@ public VectorDataTextMemoryStore(IVectorStore vectorStore, ITextEmbeddingGenerat throw new NotSupportedException($"Unsupported key of type '{typeof(TKey).Name}'"); } - VectorStoreRecordDefinition _memoryDocumentDefinition = new() + VectorStoreRecordDefinition memoryDocumentDefinition = new() { Properties = new List() { new VectorStoreRecordKeyProperty("Key", typeof(TKey)), - new VectorStoreRecordDataProperty("Namespace", typeof(string)), + new VectorStoreRecordDataProperty("Namespace", typeof(string)) { IsFilterable = true }, new VectorStoreRecordDataProperty("Name", typeof(string)), new VectorStoreRecordDataProperty("Category", typeof(string)), new VectorStoreRecordDataProperty("MemoryText", typeof(string)), @@ -83,7 +69,7 @@ public VectorDataTextMemoryStore(IVectorStore vectorStore, ITextEmbeddingGenerat this._storageNamespace = storageNamespace; this._vectorDimensions = vectorDimensions; this._vectorStoreRecordCollection = new Lazy>>(() => - this._vectorStore.GetCollection>(collectionName, this._memoryDocumentDefinition)); + this._vectorStore.GetCollection>(collectionName, memoryDocumentDefinition)); } /// From 33fcb06292bf3c2032e96680b72158b4f7760cc1 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 11 Apr 2025 14:55:39 +0100 Subject: [PATCH 10/46] Add TextRagComponent, TextRagStore and test to use it. --- .../ChatCompletionAgentWithMemoryTests.cs | 102 ++++++ .../Memory/Memory/TextRag/TextRagComponent.cs | 67 ++++ .../Memory/TextRag/TextRagComponentOptions.cs | 32 ++ .../Memory/Memory/TextRag/TextRagDocument.cs | 64 ++++ .../src/Memory/Memory/TextRag/TextRagStore.cs | 301 ++++++++++++++++++ 5 files changed, 566 insertions(+) create mode 100644 dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs create mode 100644 dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs create mode 100644 dotnet/src/Memory/Memory/TextRag/TextRagDocument.cs create mode 100644 dotnet/src/Memory/Memory/TextRag/TextRagStore.cs diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index e5c8d1c9522a..d8f5f904245f 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Azure.Identity; @@ -11,6 +12,7 @@ using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using Microsoft.SemanticKernel.Connectors.InMemory; using Microsoft.SemanticKernel.Memory; +using Microsoft.SemanticKernel.Memory.TextRag; using SemanticKernel.IntegrationTests.TestSettings; using Xunit; @@ -86,4 +88,104 @@ public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserIn await this.Fixture.DeleteThread(agentThread1); await this.Fixture.DeleteThread(agentThread2); } + + [Fact] + public virtual async Task RagComponentWithoutMatchesAsync() + { + // Arrange + var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); + + var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); + + var vectorStore = new InMemoryVectorStore(); + using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "Memories", 1536, "group/g1"); + var ragComponent = new TextRagComponent(ragStore, new TextRagComponentOptions()); + + await ragStore.UpsertDocumentsAsync(GetSampleDocuments()); + + var agent = this.Fixture.Agent; + + // Act + var agentThread = new ChatHistoryAgentThread(); + agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); + + var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); + var results1 = await asyncResults1.ToListAsync(); + + // Assert + Assert.DoesNotContain("174", results1.First().Message.Content); + + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + + [Fact] + public virtual async Task RagComponentWithMatchesAsync() + { + // Arrange + var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); + + var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); + + var vectorStore = new InMemoryVectorStore(); + using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "Memories", 1536, "group/g2"); + var ragComponent = new TextRagComponent(ragStore, new TextRagComponentOptions()); + + await ragStore.UpsertDocumentsAsync(GetSampleDocuments()); + + var agent = this.Fixture.Agent; + + // Act + var agentThread = new ChatHistoryAgentThread(); + agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); + + var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); + var results1 = await asyncResults1.ToListAsync(); + + // Assert + Assert.Contains("174", results1.First().Message.Content); + + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + + private static IEnumerable GetSampleDocuments() + { + yield return new TextRagDocument("The financial results of Contoso Corp for 2024 is as follows:\nIncome EUR 154 000 000\nExpenses EUR 142 000 000") + { + SourceName = "Contoso 2024 Financial Report", + SourceReference = "https://www.consoso.com/reports/2024.pdf", + Namespaces = ["group/g1"] + }; + yield return new TextRagDocument("The financial results of Contoso Corp for 2023 is as follows:\nIncome EUR 174 000 000\nExpenses EUR 152 000 000") + { + SourceName = "Contoso 2023 Financial Report", + SourceReference = "https://www.consoso.com/reports/2023.pdf", + Namespaces = ["group/g2"] + }; + yield return new TextRagDocument("The financial results of Contoso Corp for 2022 is as follows:\nIncome EUR 184 000 000\nExpenses EUR 162 000 000") + { + SourceName = "Contoso 2022 Financial Report", + SourceReference = "https://www.consoso.com/reports/2022.pdf", + Namespaces = ["group/g2"] + }; + yield return new TextRagDocument("The Contoso Corporation is a multinational business with its headquarters in Paris. The company is a manufacturing, sales, and support organization with more than 100,000 products.") + { + SourceName = "About Contoso", + SourceReference = "https://www.consoso.com/about-us", + Namespaces = ["group/g2"] + }; + yield return new TextRagDocument("The financial results of AdventureWorks for 2021 is as follows:\nIncome USD 223 000 000\nExpenses USD 210 000 000") + { + SourceName = "AdventureWorks 2021 Financial Report", + SourceReference = "https://www.adventure-works.com/reports/2021.pdf", + Namespaces = ["group/g1", "group/g2"] + }; + yield return new TextRagDocument("AdventureWorks is a large American business that specializaes in adventure parks and family entertainment.") + { + SourceName = "About AdventureWorks", + SourceReference = "https://www.adventure-works.com/about-us", + Namespaces = ["group/g1", "group/g2"] + }; + } } diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs b/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs new file mode 100644 index 000000000000..d28adc58ce0d --- /dev/null +++ b/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// A component that does a search based on any messages that the AI is invoked with and injects the results into the AI invocation context. +/// +public class TextRagComponent : ConversationStateExtension +{ + private readonly ITextSearch _textSearch; + + /// + /// Initializes a new instance of the class. + /// + /// The text search component to retrieve results from. + /// + /// + public TextRagComponent(ITextSearch textSearch, TextRagComponentOptions options) + { + Verify.NotNull(textSearch); + + this._textSearch = textSearch; + this.Options = options ?? throw new ArgumentNullException(nameof(options)); + } + + /// + /// Gets the options that have been configured for this component. + /// + public TextRagComponentOptions Options { get; } + + /// + public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + { + Verify.NotNull(newMessages); + + string input = string.Join("\n", newMessages.Where(m => m is not null).Select(m => m.Content)); + + var searchResults = await this._textSearch.GetTextSearchResultsAsync( + input, + new() { Top = this.Options.Top }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + // Format the results showing the content with source link and name for each result. + var sb = new StringBuilder(); + sb.AppendLine("Please consider the following source information when responding to the user:"); + await foreach (var result in searchResults.Results.ConfigureAwait(false)) + { + sb.AppendLine($" Source Document Name: {result.Name}"); + sb.AppendLine($" Source Document Link: {result.Link}"); + sb.AppendLine($" Source Document Contents: {result.Value}"); + sb.AppendLine(" -----------------"); + } + + sb.AppendLine("Include citations to the relevant information where it is referenced in the response."); + sb.AppendLine("-------------------"); + + return sb.ToString(); + } +} diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs b/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs new file mode 100644 index 000000000000..146042c2a0b2 --- /dev/null +++ b/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Contains options for the . +/// +public class TextRagComponentOptions +{ + private int _top = 3; + + /// + /// Maximum number of results to return from the similarity search. + /// + /// The value must be greater than 0. + /// The default value is 3 if not set. + public int Top + { + get => this._top; + init + { + if (value < 1) + { + throw new ArgumentOutOfRangeException(nameof(value), "Top must be greater than 0."); + } + + this._top = value; + } + } +} diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagDocument.cs b/dotnet/src/Memory/Memory/TextRag/TextRagDocument.cs new file mode 100644 index 000000000000..7e6146b769a8 --- /dev/null +++ b/dotnet/src/Memory/Memory/TextRag/TextRagDocument.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.Memory.TextRag; + +/// +/// Represents a document that can be used for Retrieval Augmented Generation (RAG). +/// +public class TextRagDocument +{ + /// + /// Initializes a new instance of the class. + /// + /// The text content. + public TextRagDocument(string text) + { + Verify.NotNullOrWhiteSpace(text); + + this.Text = text; + } + + /// + /// Gets or sets an optional list of namespaces that the document should belong to. + /// + /// + /// A namespace is a logical grouping of documents, e.g. may include a group id to scope the document to a specific group of users. + /// + public List Namespaces { get; set; } = []; + + /// + /// Gets or sets an optional source ID for the document. + /// + /// + /// This ID should be unique within the collection that the document is stored in, and can + /// be used to map back to the source artifact for this document. + /// If updates need to be made later or the source document was deleted and this document + /// also needs to be deleted, this id can be used to find the document again. + /// + public string? SourceId { get; set; } + + /// + /// Gets or sets the content as text. + /// + public string Text { get; set; } + + /// + /// Gets or sets an optional name for the source document. + /// + /// + /// This can be used to provide display names for citation links when the document is referenced as + /// part of a response to a query. + /// + public string? SourceName { get; set; } + + /// + /// Gets or sets an optional reference back to the source of the document. + /// + /// + /// This can be used to provide citation links when the document is referenced as + /// part of a response to a query. + /// + public string? SourceReference { get; set; } +} diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagStore.cs b/dotnet/src/Memory/Memory/TextRag/TextRagStore.cs new file mode 100644 index 000000000000..1931c004c6a6 --- /dev/null +++ b/dotnet/src/Memory/Memory/TextRag/TextRagStore.cs @@ -0,0 +1,301 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Data; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.Memory.TextRag; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// A class that allows for easy storage and retrieval of documents in a Vector Store for Retrieval Augmented Generation (RAG). +/// +/// The key type to use with the vector store. +public class TextRagStore : ITextSearch, IDisposable + where TKey : notnull +{ + private readonly IVectorStore _vectorStore; + private readonly ITextEmbeddingGenerationService _textEmbeddingGenerationService; + private readonly int _vectorDimensions; + private readonly string? _searchNamespace; + + private readonly Lazy>> _vectorStoreRecordCollection; + private readonly SemaphoreSlim _collectionInitializationLock = new(1, 1); + private bool _collectionInitialized = false; + private bool _disposedValue; + + /// + /// Initializes a new instance of the class. + /// + /// The vector store to store and read the memories from. + /// The service to use for generating embeddings for the memories. + /// The name of the collection in the vector store to store and read the memories from. + /// The number of dimentions to use for the memory embeddings. + /// An optional namespace to filter search results to. + /// Thrown if the key type provided is not supported. + public TextRagStore(IVectorStore vectorStore, ITextEmbeddingGenerationService textEmbeddingGenerationService, string collectionName, int vectorDimensions, string? searchNamespace) + { + Verify.NotNull(vectorStore); + Verify.NotNull(textEmbeddingGenerationService); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.True(vectorDimensions > 0, "Vector dimensions must be greater than 0"); + + this._vectorStore = vectorStore; + this._textEmbeddingGenerationService = textEmbeddingGenerationService; + this._vectorDimensions = vectorDimensions; + this._searchNamespace = searchNamespace; + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(Guid)) + { + throw new NotSupportedException($"Unsupported key of type '{typeof(TKey).Name}'"); + } + + VectorStoreRecordDefinition ragDocumentDefinition = new() + { + Properties = new List() + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new VectorStoreRecordDataProperty("Namespaces", typeof(List)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("SourceId", typeof(string)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Text", typeof(string)), + new VectorStoreRecordDataProperty("SourceName", typeof(string)), + new VectorStoreRecordDataProperty("SourceReference", typeof(string)), + new VectorStoreRecordVectorProperty("TextEmbedding", typeof(ReadOnlyMemory)) { Dimensions = vectorDimensions }, + } + }; + + this._vectorStoreRecordCollection = new Lazy>>(() => + this._vectorStore.GetCollection>(collectionName, ragDocumentDefinition)); + } + + /// + /// Upserts a batch of documents into the vector store. + /// + /// The documents to upload. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the documents have been upserted. + public async Task UpsertDocumentsAsync(IEnumerable documents, CancellationToken cancellationToken = default) + { + var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + + var storageDocumentsTasks = documents.Select(async document => + { + var key = GenerateUniqueKey(document.SourceId); + var textEmbedding = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(document.Text).ConfigureAwait(false); + + return new TextRagStorageDocument + { + Key = key, + Namespaces = document.Namespaces, + SourceId = document.SourceId, + Text = document.Text, + SourceName = document.SourceName, + SourceReference = document.SourceReference, + TextEmbedding = textEmbedding + }; + }); + + var storageDocuments = await Task.WhenAll(storageDocumentsTasks).ConfigureAwait(false); + await vectorStoreRecordCollection.UpsertBatchAsync(storageDocuments, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); + } + + /// + public async Task> SearchAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); + + return new(searchResult.Results.Select(x => x.Record.Text ?? string.Empty)); + } + + /// + public async Task> GetTextSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); + + var results = searchResult.Results.Select(x => new TextSearchResult(x.Record.Text ?? string.Empty) { Name = x.Record.SourceName, Link = x.Record.SourceReference }); + return new(searchResult.Results.Select(x => + new TextSearchResult(x.Record.Text ?? string.Empty) + { + Name = x.Record.SourceName, + Link = x.Record.SourceReference + })); + } + + /// + public async Task> GetSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); + return new(searchResult.Results.Cast()); + } + + /// + /// Internal search implementation. + /// + /// The text query to find similar documents to. + /// Search options. + /// The to monitor for cancellation requests. The default is . + /// The search results. + private async Task>> SearchInternalAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + + // Optional filter to limit the search to a specific namespace. + Expression, bool>>? filter = string.IsNullOrWhiteSpace(this._searchNamespace) ? null : x => x.Namespaces.Contains(this._searchNamespace); + + var vector = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(query, cancellationToken: cancellationToken).ConfigureAwait(false); + var searchResult = await vectorStoreRecordCollection.VectorizedSearchAsync( + vector, + options: new() + { + Top = searchOptions?.Top ?? 3, + Filter = filter, + }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + return searchResult; + } + + /// + /// Thread safe method to get the collection and ensure that it is created at least once. + /// + /// The to monitor for cancellation requests. The default is . + /// The created collection. + private async Task>> EnsureCollectionCreatedAsync(CancellationToken cancellationToken) + { + var vectorStoreRecordCollection = this._vectorStoreRecordCollection.Value; + + // Return immediately if the collection is already created, no need to do any locking in this case. + if (this._collectionInitialized) + { + return vectorStoreRecordCollection; + } + + // Wait on a lock to ensure that only one thread can create the collection. + await this._collectionInitializationLock.WaitAsync(cancellationToken).ConfigureAwait(false); + + // If multiple threads waited on the lock, and the first already created the collection, + // we can return immediately without doing any work in subsequent threads. + if (this._collectionInitialized) + { + this._collectionInitializationLock.Release(); + return vectorStoreRecordCollection; + } + + // Only the winning thread should reach this point and create the collection. + try + { + await vectorStoreRecordCollection.CreateCollectionIfNotExistsAsync(cancellationToken).ConfigureAwait(false); + this._collectionInitialized = true; + } + finally + { + this._collectionInitializationLock.Release(); + } + + return vectorStoreRecordCollection; + } + + /// + /// Generates a unique key for the RAG document. + /// + /// Source id of the source document for this RAG document. + /// The type of the key to use, since different databases require/support different keys. + /// A new unique key. + /// Thrown if the requested key type is not supported. + private static TDocumentKey GenerateUniqueKey(string? sourceId) + => typeof(TDocumentKey) switch + { + _ when typeof(TDocumentKey) == typeof(string) && !string.IsNullOrWhiteSpace(sourceId) => (TDocumentKey)(object)sourceId!, + _ when typeof(TDocumentKey) == typeof(string) => (TDocumentKey)(object)Guid.NewGuid().ToString(), + _ when typeof(TDocumentKey) == typeof(Guid) => (TDocumentKey)(object)Guid.NewGuid(), + + _ => throw new NotSupportedException($"Unsupported key of type '{typeof(TDocumentKey).Name}'") + }; + + /// + protected virtual void Dispose(bool disposing) + { + if (!this._disposedValue) + { + if (disposing) + { + this._collectionInitializationLock.Dispose(); + } + + this._disposedValue = true; + } + } + + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + /// + /// The data model to use for storing RAG documents in the vector store. + /// + /// The type of the key to use, since different databases require/support different keys. + private sealed class TextRagStorageDocument + { + /// + /// Gets or sets a unique identifier for the memory document. + /// + public TDocumentKey Key { get; set; } = default!; + + /// + /// Gets or sets an optional list of namespaces that the document should belong to. + /// + /// + /// A namespace is a logical grouping of documents, e.g. may include a group id to scope the document to a specific group of users. + /// + public List Namespaces { get; set; } = []; + + /// + /// Gets or sets an optional source ID for the document. + /// + /// + /// This ID should be unique within the collection that the document is stored in, and can + /// be used to map back to the source artifact for this document. + /// If updates need to be made later or the source document was deleted and this document + /// also needs to be deleted, this id can be used to find the document again. + /// + public string? SourceId { get; set; } + + /// + /// Gets or sets the content as text. + /// + public string? Text { get; set; } + + /// + /// Gets or sets an optional name for the source document. + /// + /// + /// This can be used to provide display names for citation links when the document is referenced as + /// part of a response to a query. + /// + public string? SourceName { get; set; } + + /// + /// Gets or sets an optional reference back to the source of the document. + /// + /// + /// This can be used to provide citation links when the document is referenced as + /// part of a response to a query. + /// + public string? SourceReference { get; set; } + + /// + /// Gets or sets the embedding for the text content. + /// + public ReadOnlyMemory TextEmbedding { get; set; } + } +} From 059607cee3e5e5c38e60d8ddf7615e61457e2e4b Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 11 Apr 2025 15:55:41 +0100 Subject: [PATCH 11/46] Add comments to tests. --- .../ChatCompletionAgentWithMemoryTests.cs | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index d8f5f904245f..3d809c7f6d2b 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -92,27 +92,30 @@ public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserIn [Fact] public virtual async Task RagComponentWithoutMatchesAsync() { - // Arrange + // Arrange - Create Embedding Service var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); - var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); + // Arrange - Create Vector Store and Rag Store/Component var vectorStore = new InMemoryVectorStore(); using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "Memories", 1536, "group/g1"); var ragComponent = new TextRagComponent(ragStore, new TextRagComponentOptions()); + // Arrange - Upsert documents into the Rag Store await ragStore.UpsertDocumentsAsync(GetSampleDocuments()); var agent = this.Fixture.Agent; - // Act + // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); + // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); var results1 = await asyncResults1.ToListAsync(); - // Assert + // Assert - Check if the response does not contain the expected value from the database because + // we filtered by group/g1 which doesn't include the required document. Assert.DoesNotContain("174", results1.First().Message.Content); // Cleanup @@ -122,27 +125,29 @@ public virtual async Task RagComponentWithoutMatchesAsync() [Fact] public virtual async Task RagComponentWithMatchesAsync() { - // Arrange + // Arrange - Create Embedding Service var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); - var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); + // Arrange - Create Vector Store and Rag Store/Component var vectorStore = new InMemoryVectorStore(); using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "Memories", 1536, "group/g2"); var ragComponent = new TextRagComponent(ragStore, new TextRagComponentOptions()); + // Arrange - Upsert documents into the Rag Store await ragStore.UpsertDocumentsAsync(GetSampleDocuments()); var agent = this.Fixture.Agent; - // Act + // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); + // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); var results1 = await asyncResults1.ToListAsync(); - // Assert + // Assert - Check if the response contains the expected value from the database. Assert.Contains("174", results1.First().Message.Content); // Cleanup From 9e1c57ea2fe3fae62738a75e75a7eb864e4743c2 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 11 Apr 2025 17:17:12 +0100 Subject: [PATCH 12/46] Rename user preferences to user facts and fix a bug --- .../ChatCompletionAgentWithMemoryTests.cs | 6 +-- .../OpenAIAssistantAgentWithMemoryTests.cs.cs | 2 +- ...mponent.cs => UserFactsMemoryComponent.cs} | 54 +++++++++---------- 3 files changed, 31 insertions(+), 31 deletions(-) rename dotnet/src/Memory/Memory/{UserPreferencesMemoryComponent.cs => UserFactsMemoryComponent.cs} (70%) diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index 3d809c7f6d2b..eb84ef92a1a4 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -32,7 +32,7 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() { // Arrange var agent = this.Fixture.Agent; - var memoryComponent = new UserPreferencesMemoryComponent(this.Fixture.Agent.Kernel); + var memoryComponent = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel); var agentThread1 = new ChatHistoryAgentThread(); agentThread1.ThreadExtensionsManager.RegisterThreadExtension(memoryComponent); @@ -69,14 +69,14 @@ public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserIn // Act - First invocation with first thread. var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.ThreadExtensionsManager.RegisterThreadExtension(new UserPreferencesMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + agentThread1.ThreadExtensionsManager.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); var results1 = await asyncResults1.ToListAsync(); // Act - Second invocation with second thread. var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.ThreadExtensionsManager.RegisterThreadExtension(new UserPreferencesMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + agentThread2.ThreadExtensionsManager.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); var results2 = await asyncResults2.ToListAsync(); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs index 6e03c749a049..0cb0b238c6ee 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs @@ -17,7 +17,7 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() { // Arrange var agent = this.Fixture.Agent; - var memoryComponent = new UserPreferencesMemoryComponent(this.Fixture.Agent.Kernel); + var memoryComponent = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel); var agentThread1 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); agentThread1.ThreadExtensionsManager.RegisterThreadExtension(memoryComponent); diff --git a/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs similarity index 70% rename from dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs rename to dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs index 3ecd23133f92..cbd734a4d565 100644 --- a/dotnet/src/Memory/Memory/UserPreferencesMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs @@ -10,45 +10,45 @@ namespace Microsoft.SemanticKernel.Agents.Memory; /// -/// A memory component that can retrieve, maintain and store user preferences that +/// A memory component that can retrieve, maintain and store user facts that /// are learned from the user's interactions with the agent. /// -public class UserPreferencesMemoryComponent : ConversationStateExtension +public class UserFactsMemoryComponent : ConversationStateExtension { private readonly Kernel _kernel; private readonly TextMemoryStore _textMemoryStore; - private string _userPreferences = string.Empty; + private string _userFacts = string.Empty; private bool _contextLoaded = false; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// A kernel to use for making chat completion calls. /// The memory store to retrieve and save memories from and to. - public UserPreferencesMemoryComponent(Kernel kernel, TextMemoryStore textMemoryStore) + public UserFactsMemoryComponent(Kernel kernel, TextMemoryStore textMemoryStore) { this._kernel = kernel; this._textMemoryStore = textMemoryStore; } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// A kernel to use for making chat completion calls. - /// The service key that the for user preferences is registered under in DI. - public UserPreferencesMemoryComponent(Kernel kernel, string? userPreferencesStoreName = "UserPreferencesStore") + /// The service key that the for user facts is registered under in DI. + public UserFactsMemoryComponent(Kernel kernel, string? userFactsStoreName = "UserFactsStore") { this._kernel = kernel; - this._textMemoryStore = new OptionalTextMemoryStore(kernel, userPreferencesStoreName); + this._textMemoryStore = new OptionalTextMemoryStore(kernel, userFactsStoreName); } /// - /// Gets or sets the name of the document to use for storing user preferences. + /// Gets or sets the name of the document to use for storing user preferfactsences. /// - public string UserPreferencesDocumentName { get; init; } = "UserPreferences"; + public string UserFactsDocumentName { get; init; } = "UserFacts"; /// - /// Gets or sets the prompt template to use for extracting user preferences and merging them with existing preferences. + /// Gets or sets the prompt template to use for extracting user facts and merging them with existing facts. /// public string MaintainencePromptTemplate { get; init; } = """ @@ -82,7 +82,7 @@ EXAMPLES END Return output for the following inputs like shown in the examples above: Input text: {{$inputText}} - Input facts: {{$existingPreferences}} + Input facts: {{existingFacts}} """; /// @@ -90,12 +90,12 @@ public override async Task OnThreadCreatedAsync(string? threadId, CancellationTo { if (!this._contextLoaded) { - this._userPreferences = string.Empty; + this._userFacts = string.Empty; - var memoryText = await this._textMemoryStore.GetMemoryAsync("UserPreferences", cancellationToken).ConfigureAwait(false); + var memoryText = await this._textMemoryStore.GetMemoryAsync(this.UserFactsDocumentName, cancellationToken).ConfigureAwait(false); if (memoryText is not null) { - this._userPreferences = memoryText; + this._userFacts = memoryText; } this._contextLoaded = true; @@ -105,7 +105,7 @@ public override async Task OnThreadCreatedAsync(string? threadId, CancellationTo /// public override async Task OnThreadDeleteAsync(string? threadId, CancellationToken cancellationToken = default) { - await this._textMemoryStore.SaveMemoryAsync("UserPreferences", this._userPreferences, cancellationToken).ConfigureAwait(false); + await this._textMemoryStore.SaveMemoryAsync(this.UserFactsDocumentName, this._userFacts, cancellationToken).ConfigureAwait(false); } /// @@ -121,7 +121,7 @@ public override async Task OnNewMessageAsync(ChatMessageContent newMessage, Canc /// public override Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { - return Task.FromResult("The following list contains facts about the user:\n" + this._userPreferences); + return Task.FromResult("The following list contains facts about the user:\n" + this._userFacts); } /// @@ -130,29 +130,29 @@ public override void RegisterPlugins(Kernel kernel) Verify.NotNull(kernel); base.RegisterPlugins(kernel); - kernel.Plugins.AddFromObject(this, "UserPreferencesMemory"); + kernel.Plugins.AddFromObject(this, "UserFactsMemory"); } /// - /// Plugin method to clear user preferences stored in memory. + /// Plugin method to clear user facts stored in memory. /// [KernelFunction] - [Description("Deletes any user preferences stored about the user.")] - public async Task ClearUserPreferencesAsync(CancellationToken cancellationToken = default) + [Description("Deletes any user facts stored about the user.")] + public async Task ClearUserFactsAsync(CancellationToken cancellationToken = default) { - this._userPreferences = string.Empty; - await this._textMemoryStore.SaveMemoryAsync("UserPreferences", this._userPreferences, cancellationToken).ConfigureAwait(false); + this._userFacts = string.Empty; + await this._textMemoryStore.SaveMemoryAsync(this.UserFactsDocumentName, this._userFacts, cancellationToken).ConfigureAwait(false); } private async Task ExtractAndSaveMemoriesAsync(string inputText, CancellationToken cancellationToken = default) { var result = await this._kernel.InvokePromptAsync( this.MaintainencePromptTemplate, - new KernelArguments() { ["inputText"] = inputText, ["existingPreferences"] = this._userPreferences }, + new KernelArguments() { ["inputText"] = inputText, ["existingFacts"] = this._userFacts }, cancellationToken: cancellationToken).ConfigureAwait(false); - this._userPreferences = result.ToString(); + this._userFacts = result.ToString(); - await this._textMemoryStore.SaveMemoryAsync("UserPreferences", this._userPreferences, cancellationToken).ConfigureAwait(false); + await this._textMemoryStore.SaveMemoryAsync(this.UserFactsDocumentName, this._userFacts, cancellationToken).ConfigureAwait(false); } } From 7150cacddb9b8ba2a0de19740591d58b201e3861 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 11 Apr 2025 17:58:13 +0100 Subject: [PATCH 13/46] Add agents with memory ADR --- docs/decisions/00NN-agents-with-memory.md | 170 ++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 docs/decisions/00NN-agents-with-memory.md diff --git a/docs/decisions/00NN-agents-with-memory.md b/docs/decisions/00NN-agents-with-memory.md new file mode 100644 index 000000000000..5fe2f0399d72 --- /dev/null +++ b/docs/decisions/00NN-agents-with-memory.md @@ -0,0 +1,170 @@ +--- +# These are optional elements. Feel free to remove any of them. +status: {proposed | rejected | accepted | deprecated | … | superseded by [ADR-0001](0001-madr-architecture-decisions.md)} +contact: {person proposing the ADR} +date: {YYYY-MM-DD when the decision was last updated} +deciders: {list everyone involved in the decision} +consulted: {list everyone whose opinions are sought (typically subject-matter experts); and with whom there is a two-way communication} +informed: {list everyone who is kept up-to-date on progress; and with whom there is a one-way communication} +--- + +# Agents with Memory + +## What do we mean by Memory? + +By memory we mean the capability to remember information and skills that are learned during +a conversation and re-use those later in the same conversation or later in a subsequent conversation. + +## Context and Problem Statement + +Today we support multiple agent types with different characteristcs: + +1. In process vs remote. +1. Remote agents that store and maintain conversation state in the service vs those that require the caller to provide conversation state on each invocation. + +We need to support advanced memory capabilities across this range of agent types. + +### Memory Scope + +Another aspect of memory that is important to consider is the scope of different memory types. +Most agent implementations have instructions and skills but the agent is not tied to a single conversation. +On each invocation of the agent, the agent is told which conversation to participate in, during that invocation. + +Memories about a user or about a conversation with a user is therefore extracted from one of these conversation and recalled +during the same or another conversation with the same user. +These memories will typically contain information that the user would not like to share with other users of the system. + +Other types of memories also exist which are not tied to a specific user or conversation. +E.g. an Agent may learn how to do something and be able to do that in many conversations with different users. +With these type of memories there is of cousrse risk in leaking personal information between different users which is important to guard against. + +### Packaging memory capabilities + +All of the above memory types can be supported for any agent by attaching software components to conversation threads. +This is achieved via a simple mechanism of: + +1. Inspecting and using messages as they are passed to and from the agent. +1. Passing additional context to the agent per invocation. + +With our current `AgentThread` implementation, when an agent is invoked, all input and output messages are already passed to the `AgentThread` +and can be made availble to any components attached to the `AgentThread`. +Where agents are remote/external and manage conversation state in the service, passing the messages to the `AgentThread` may not have any +affect on the thread in the service. This is OK, since the service will have already updated the thread during the remote invocation. +It does however, still allow us to subscribe to messages in any attached components. + +For the second requirement of getting addtional context per invocation, the agent may ask the thread passed to it, to in turn ask +each of the components attached to it, to provide context ot pass to the Agent. +This enables the component to provide memories that it contains to the Agent as needed. + +Different memory capabilities can be built using separate components. Each component would have the following characteristics: + +1. May store some context that can be provided to the agent per invocation. +1. May inspect messages from the conversation to learn from the conversation and build its context. +1. May register plugins to allow the agent to directly store, retrieve, update or clear memories. + +### Suspend / Resume + +Building a service to host an agent comes with challenges. +It's hard to build a stateful service, but service consumers expect an experience that looks stateful from the outside. +E.g. on each invocation, the user expects that the service can continue a conversation they are having. + +This means that where the the service is exposing a local agent with local conversation state management (e.g. via `ChatHistory`) +that conversation state needs to be loaded and persisted for each invocation of the service. + +It also means that any memory components that may have some in-memory state will need to be loaded and persisted too. + +## Proposed interface for Memory Components + +The types of events that Memory Components require are not unique to memory, and can be used to package up other capabilities too. +The suggestion is therefore to create a more generally named type that can be used for other scenarios as well and can even +be used for non-agent scenarios too. + +This type should live in the `Microsoft.SemanticKernel.Abstractions` nuget, since these components can be used by systems other than just agents. + +```csharp +namespace Microsoft.SemanticKernel.Memory; + +public abstract class ConversationStateExtension +{ + public virtual Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default); + public virtual Task OnThreadCheckpointAsync(string threadId, CancellationToken cancellationToken = default); + public virtual Task OnThreadDeleteAsync(string? threadId, CancellationToken cancellationToken = default); + + public virtual Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default); + public abstract Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default); + + public virtual void RegisterPlugins(Kernel kernel); +} +``` + +> TODO: Decide about the correct namespace for `ConversationStateExtension` + +## Managing multiple components + +To manage multiple components I propose that we have a `ConversationStateExtensionsManager`. +This class allows registering components and delegating new message notifications, ai invocation calls, etc. to the contained components. + +## Integrating with agents + +I propose to add a `ConversationStateExtensionsManager` to the `AgentThread` class, allowing us to attach components to any `AgentThread`. + +When an `Agent` is invoked, we will call `OnAIInvocationAsync` on each component via the `ConversationStateExtensionsManager` to get +a combined set of context to pass to the agent for this invocation. This will be internal to the `Agent` class and transparent to the user. + +```csharp +var additionalInstructions = await currentAgentThread.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); +``` + +## Usage examples + +### Multiple threads using the same memory component + +```csharp +// Create a vector store for storing memories. +var vectorStore = new InMemoryVectorStore(); +// Create a memory store that is tired to a "Memories" collection in the vector store and stores memories under the "user/12345" namespace. +using var textMemoryStore = new VectorDataTextMemoryStore(vectorStore, textEmbeddingService, "Memories", "user/12345", 1536); + +// Create a memory component to will pull user facts from the conversation, store them in the vector store +// and pass them to the agent as additional instructions. +var userFacts = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore); + +// Create a thread and attach a Memory Component. +var agentThread1 = new ChatHistoryAgentThread(); +agentThread1.ThreadExtensionsManager.RegisterThreadExtension(userFacts); +var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); + +// Create a second thread and attach a Memory Component. +var agentThread2 = new ChatHistoryAgentThread(); +agentThread2.ThreadExtensionsManager.RegisterThreadExtension(userFacts); +var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); +// Expected response contains Caoimhe. +``` + +### Using a RAG component + +```csharp +// Create Vector Store and Rag Store/Component +var vectorStore = new InMemoryVectorStore(); +using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "Memories", 1536, "group/g2"); +var ragComponent = new TextRagComponent(ragStore, new TextRagComponentOptions()); + +// Upsert docs into vector store. +await ragStore.UpsertDocumentsAsync( +[ + new TextRagDocument("The financial results of Contoso Corp for 2023 is as follows:\nIncome EUR 174 000 000\nExpenses EUR 152 000 000") + { + SourceName = "Contoso 2023 Financial Report", + SourceReference = "https://www.consoso.com/reports/2023.pdf", + Namespaces = ["group/g2"] + } +]); + +// Create a new agent thread and register the Rag component +var agentThread = new ChatHistoryAgentThread(); +agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); + +// Inovke the agent. +var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); +// Expected response contains the 174M income from the document. +``` From 09e3e76f4747f4e547f08cb70c18841472252a96 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 14 Apr 2025 13:49:38 +0100 Subject: [PATCH 14/46] Split mem0 into client and component, test against service and add integration test. --- .../ChatCompletionAgentWithMemoryTests.cs | 36 +++ dotnet/src/Memory/Memory/Mem0/Mem0Client.cs | 171 +++++++++++++ .../Memory/Memory/Mem0/Mem0MemoryComponent.cs | 114 +++++++++ .../Memory/Mem0/Mem0MemoryComponentOptions.cs | 49 ++++ .../src/Memory/Memory/Mem0MemoryComponent.cs | 231 ------------------ 5 files changed, 370 insertions(+), 231 deletions(-) create mode 100644 dotnet/src/Memory/Memory/Mem0/Mem0Client.cs create mode 100644 dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs create mode 100644 dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponentOptions.cs delete mode 100644 dotnet/src/Memory/Memory/Mem0MemoryComponent.cs diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index eb84ef92a1a4..1c5282dee624 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -1,7 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Linq; +using System.Net.Http; +using System.Net.Http.Headers; using System.Threading.Tasks; using Azure.Identity; using Microsoft.Extensions.Configuration; @@ -27,6 +30,39 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.Agen .AddUserSecrets() .Build(); + [Fact(Skip = "For manual verification")] + public virtual async Task Mem0ComponentCapturesMemoriesFromUserInputAsync() + { + // Arrange + var agent = this.Fixture.Agent; + + using var httpClient = new HttpClient(); + httpClient.BaseAddress = new Uri("https://api.mem0.ai"); + httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", ""); + + var mem0Component = new Mem0MemoryComponent(httpClient, new() { UserId = "U1" }); + + var agentThread1 = new ChatHistoryAgentThread(); + agentThread1.ThreadExtensionsManager.RegisterThreadExtension(mem0Component); + + var agentThread2 = new ChatHistoryAgentThread(); + agentThread2.ThreadExtensionsManager.RegisterThreadExtension(mem0Component); + + // Act + var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); + var results1 = await asyncResults1.ToListAsync(); + + var asyncResults2 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "What is my name?."), agentThread2); + var results2 = await asyncResults2.ToListAsync(); + + // Assert + Assert.Contains("Caoimhe", results2.First().Message.Content); + + // Cleanup + await this.Fixture.DeleteThread(agentThread1); + await this.Fixture.DeleteThread(agentThread2); + } + [Fact] public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() { diff --git a/dotnet/src/Memory/Memory/Mem0/Mem0Client.cs b/dotnet/src/Memory/Memory/Mem0/Mem0Client.cs new file mode 100644 index 000000000000..1fc9990bbb77 --- /dev/null +++ b/dotnet/src/Memory/Memory/Mem0/Mem0Client.cs @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading.Tasks; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Client for the Mem0 memory service. +/// +internal sealed class Mem0Client +{ + private static readonly Uri s_searchUri = new("/v1/memories/search/", UriKind.Relative); + private static readonly Uri s_createMemoryUri = new("/v1/memories/", UriKind.Relative); + + private readonly HttpClient _httpClient; + + public Mem0Client(HttpClient httpClient) + { + Verify.NotNull(httpClient); + + this._httpClient = httpClient; + } + + public async Task> SearchAsync(string? applicationId, string? agentId, string? threadId, string? userId, string? inputText) + { + if (string.IsNullOrWhiteSpace(applicationId) + && string.IsNullOrWhiteSpace(agentId) + && string.IsNullOrWhiteSpace(threadId) + && string.IsNullOrWhiteSpace(userId)) + { + throw new InvalidOperationException("At least one of applicationId, agentId, threadId, or userId must be provided."); + } + + var searchRequest = new SearchRequest + { + AppId = applicationId, + AgentId = agentId, + RunId = threadId, + UserId = userId, + Query = inputText ?? string.Empty + }; + + // Search. + using var content = new StringContent(JsonSerializer.Serialize(searchRequest), Encoding.UTF8, "application/json"); + var responseMessage = await this._httpClient.PostAsync(s_searchUri, content).ConfigureAwait(false); + responseMessage.EnsureSuccessStatusCode(); + + // Process response. + var response = await responseMessage.Content.ReadAsStringAsync().ConfigureAwait(false); + var searchResponseItems = JsonSerializer.Deserialize(response); + return searchResponseItems?.Select(item => item.Memory) ?? []; + } + + public async Task CreateMemoryAsync(string? applicationId, string? agentId, string? threadId, string? userId, string messageContent, string messageRole) + { + if (string.IsNullOrWhiteSpace(applicationId) + && string.IsNullOrWhiteSpace(agentId) + && string.IsNullOrWhiteSpace(threadId) + && string.IsNullOrWhiteSpace(userId)) + { + throw new InvalidOperationException("At least one of applicationId, agentId, threadId, or userId must be provided."); + } + + var createMemoryRequest = new CreateMemoryRequest() + { + AppId = applicationId, + AgentId = agentId, + RunId = threadId, + UserId = userId, + Messages = new[] + { + new CreateMemoryMemory + { + Content = messageContent, + Role = messageRole + } + } + }; + + using var content = new StringContent(JsonSerializer.Serialize(createMemoryRequest), Encoding.UTF8, "application/json"); + var responseMessage = await this._httpClient.PostAsync(s_createMemoryUri, content).ConfigureAwait(false); + responseMessage.EnsureSuccessStatusCode(); + } + + public async Task ClearMemoryAsync(string? applicationId, string? agentId, string? threadId, string? userId) + { + string[] paramNames = ["app_id", "agent_id", "run_id", "user_id"]; + + // Build query string. + var querystringParams = new string?[4] { applicationId, userId, agentId, threadId } + .Select((param, index) => string.IsNullOrWhiteSpace(param) ? null : $"{paramNames[index]}={param}") + .Where(x => x != null); + var queryString = string.Join("&", querystringParams); + var clearMemoryUrl = new Uri($"/v1/memories?{queryString}", UriKind.Relative); + + // Delete. + var responseMessage = await this._httpClient.DeleteAsync(clearMemoryUrl).ConfigureAwait(false); + responseMessage.EnsureSuccessStatusCode(); + } + + private sealed class CreateMemoryRequest + { + [JsonPropertyName("app_id")] + public string? AppId { get; set; } + [JsonPropertyName("agent_id")] + public string? AgentId { get; set; } + [JsonPropertyName("run_id")] + public string? RunId { get; set; } + [JsonPropertyName("user_id")] + public string? UserId { get; set; } + [JsonPropertyName("messages")] + public CreateMemoryMemory[] Messages { get; set; } = []; + } + + private sealed class CreateMemoryMemory + { + [JsonPropertyName("content")] + public string Content { get; set; } = string.Empty; + [JsonPropertyName("role")] + public string Role { get; set; } = string.Empty; + } + + private sealed class SearchRequest + { + [JsonPropertyName("app_id")] + public string? AppId { get; set; } + [JsonPropertyName("agent_id")] + public string? AgentId { get; set; } = null; + [JsonPropertyName("run_id")] + public string? RunId { get; set; } = null; + [JsonPropertyName("user_id")] + public string? UserId { get; set; } = null; + [JsonPropertyName("query")] + public string Query { get; set; } = string.Empty; + } + +#pragma warning disable CA1812 // Avoid uninstantiated internal classes + private sealed class SearchResponseItem + { + [JsonPropertyName("id")] + public string Id { get; set; } = string.Empty; + [JsonPropertyName("memory")] + public string Memory { get; set; } = string.Empty; + [JsonPropertyName("hash")] + public string Hash { get; set; } = string.Empty; + [JsonPropertyName("metadata")] + public object? Metadata { get; set; } + [JsonPropertyName("score")] + public double Score { get; set; } + [JsonPropertyName("created_at")] + public DateTime CreatedAt { get; set; } + [JsonPropertyName("updated_at")] + public DateTime? UpdatedAt { get; set; } + [JsonPropertyName("user_id")] + public string UserId { get; set; } = string.Empty; + [JsonPropertyName("app_id")] + public string? AppId { get; set; } + [JsonPropertyName("agent_id")] + public string AgentId { get; set; } = string.Empty; + [JsonPropertyName("run_id")] + public string RunId { get; set; } = string.Empty; + } +#pragma warning restore CA1812 // Avoid uninstantiated internal classes +} diff --git a/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs new file mode 100644 index 000000000000..96bdd5373a5b --- /dev/null +++ b/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// A component that listenes to messages added to the conversation thread, and automatically captures +/// information about the user. It is also able to retrieve this information and add it to the AI invocation context. +/// +public class Mem0MemoryComponent : ConversationStateExtension +{ + private readonly string? _applicationId; + private readonly string? _agentId; + private string? _threadId; + private readonly string? _userId; + private readonly bool _scopeToThread; + + private readonly Mem0Client _mem0Client; + + /// + /// Initializes a new instance of the class. + /// + /// The HTTP client used for making requests. + /// Options for configuring the component. + public Mem0MemoryComponent(HttpClient httpClient, Mem0MemoryComponentOptions? options = default) + { + Verify.NotNull(httpClient); + + this._applicationId = options?.ApplicationId; + this._agentId = options?.AgentId; + this._threadId = options?.ThreadId; + this._userId = options?.UserId; + this._scopeToThread = options?.ScopeToThread ?? false; + + this._mem0Client = new(httpClient); + } + + /// + public override Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) + { + this._threadId ??= threadId; + return Task.CompletedTask; + } + + /// + public override async Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + Verify.NotNull(newMessage); + + if (newMessage.Role == AuthorRole.User && !string.IsNullOrWhiteSpace(newMessage.Content)) + { + await this._mem0Client.CreateMemoryAsync( + this._applicationId, + this._agentId, + this._scopeToThread ? this._threadId : null, + this._userId, + newMessage.Content, + newMessage.Role.Label).ConfigureAwait(false); + } + } + + /// + public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + { + Verify.NotNull(newMessages); + + string inputText = string.Join( + "\n", + newMessages. + Where(m => m is not null && !string.IsNullOrWhiteSpace(m.Content)). + Select(m => m.Content)); + + var memories = await this._mem0Client.SearchAsync( + this._applicationId, + this._agentId, + this._scopeToThread ? this._threadId : null, + this._userId, + inputText).ConfigureAwait(false); + + var userInformation = string.Join("\n", memories); + return "The following list contains facts about the user:\n" + userInformation; + } + + /// + public override void RegisterPlugins(Kernel kernel) + { + Verify.NotNull(kernel); + + base.RegisterPlugins(kernel); + kernel.Plugins.AddFromObject(this, "MemZeroMemory"); + } + + /// + /// Plugin method to clear user preferences stored in memory for the current agent/thread/user. + /// + /// A task that completes when the memory is cleared. + [KernelFunction] + [Description("Deletes any user preferences stored about the user.")] + public async Task ClearUserPreferencesAsync() + { + await this._mem0Client.ClearMemoryAsync( + this._applicationId, + this._userId, + this._agentId, + this._scopeToThread ? this._threadId : null).ConfigureAwait(false); + } +} diff --git a/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponentOptions.cs b/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponentOptions.cs new file mode 100644 index 000000000000..6a88d711322a --- /dev/null +++ b/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponentOptions.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Options for the . +/// +public class Mem0MemoryComponentOptions +{ + /// + /// Gets or sets an optional ID for the application to scope memories to. + /// + /// + /// If not set, the scope of the memories will span all applications. + /// + public string? ApplicationId { get; init; } + + /// + /// Gets or sets an optional ID for the agent to scope memories to. + /// + /// + /// If not set, the scope of the memories will span all agents. + /// + public string? AgentId { get; init; } + + /// + /// Gets or sets an optional ID for the thread to scope memories to. + /// + /// + /// This value will be overridden by any thread id provided to the methods of the . + /// + public string? ThreadId { get; init; } + + /// + /// Gets or sets an optional ID for the user to scope memories to. + /// + /// + /// If not set, the scope of the memories will span all users. + /// + public string? UserId { get; init; } + + /// + /// Gets or sets a value indicating whether the scope of the memories is limited to the current thread. + /// + /// + /// If false, will be ignored, and any thread ids passed into the methods of the will also be ignored. + /// + public bool ScopeToThread { get; init; } = false; +} diff --git a/dotnet/src/Memory/Memory/Mem0MemoryComponent.cs b/dotnet/src/Memory/Memory/Mem0MemoryComponent.cs deleted file mode 100644 index 8bce4f09d0fb..000000000000 --- a/dotnet/src/Memory/Memory/Mem0MemoryComponent.cs +++ /dev/null @@ -1,231 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Linq; -using System.Net.Http; -using System.Text; -using System.Text.Json; -using System.Text.Json.Serialization; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.SemanticKernel.ChatCompletion; - -namespace Microsoft.SemanticKernel.Memory; - -/// -/// A component that listenes to messages added to the conversation thread, and automatically captures -/// information about the user. It is also able to retrieve this information and add it to the AI invocation context. -/// -public class Mem0MemoryComponent : ConversationStateExtension -{ - private static readonly Uri s_searchUri = new("/search", UriKind.Relative); - private static readonly Uri s_createMemoryUri = new("/memories", UriKind.Relative); - - private readonly string? _agentId; - private string? _threadId; - private readonly string? _userId; - private readonly bool _scopeToThread; - private readonly HttpClient _httpClient; - - private bool _contextLoaded = false; - private string _userInformation = string.Empty; - - /// - /// Initializes a new instance of the class. - /// - /// The HTTP client used for making requests. - /// The ID of the agent. - /// The ID of the thread. - /// The ID of the user. - /// Indicates whether the scope is limited to the thread. - public Mem0MemoryComponent(HttpClient httpClient, string? agentId = default, string? threadId = default, string? userId = default, bool scopeToThread = false) - { - Verify.NotNull(httpClient); - - this._agentId = agentId; - this._threadId = threadId; - this._userId = userId; - this._scopeToThread = scopeToThread; - this._httpClient = httpClient; - } - - /// - public override Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) - { - this._threadId ??= threadId; - return Task.CompletedTask; - } - - /// - public override async Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) - { - Verify.NotNull(newMessage); - - if (newMessage.Role == AuthorRole.User) - { - await this.CreateMemoryAsync( - new CreateMemoryRequest() - { - AgentId = this._agentId, - RunId = this._scopeToThread ? this._threadId : null, - UserId = this._userId, - Messages = new[] - { - new CreateMemoryMemory - { - Content = newMessage.Content ?? string.Empty, - Role = newMessage.Role.Label - } - } - }).ConfigureAwait(false); - } - } - - /// - public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) - { - Verify.NotNull(newMessages); - - string input = string.Join("\n", newMessages.Where(m => m is not null).Select(m => m.Content)); - - await this.LoadContextAsync(this._threadId, input).ConfigureAwait(false); - - return "The following list contains facts about the user:\n" + this._userInformation; - } - - /// - public override void RegisterPlugins(Kernel kernel) - { - Verify.NotNull(kernel); - - base.RegisterPlugins(kernel); - kernel.Plugins.AddFromObject(this, "MemZeroMemory"); - } - - /// - /// Plugin method to clear user preferences stored in memory for the current agent/thread/user. - /// - /// A task that completes when the memory is cleared. - [KernelFunction] - [Description("Deletes any user preferences stored about the user.")] - public async Task ClearUserPreferencesAsync() - { - await this.ClearMemoryAsync().ConfigureAwait(false); - } - - private async Task LoadContextAsync(string? threadId, string? inputText) - { - if (!this._contextLoaded) - { - this._threadId ??= threadId; - - var searchRequest = new SearchRequest - { - AgentId = this._agentId, - RunId = this._scopeToThread ? this._threadId : null, - UserId = this._userId, - Query = inputText ?? string.Empty - }; - var responseItems = await this.SearchAsync(searchRequest).ConfigureAwait(false); - this._userInformation = string.Join("\n", responseItems); - this._contextLoaded = true; - } - } - - private async Task CreateMemoryAsync(CreateMemoryRequest createMemoryRequest) - { - using var content = new StringContent(JsonSerializer.Serialize(createMemoryRequest), Encoding.UTF8, "application/json"); - var responseMessage = await this._httpClient.PostAsync(s_createMemoryUri, content).ConfigureAwait(false); - responseMessage.EnsureSuccessStatusCode(); - } - - private async Task SearchAsync(SearchRequest searchRequest) - { - using var content = new StringContent(JsonSerializer.Serialize(searchRequest), Encoding.UTF8, "application/json"); - var responseMessage = await this._httpClient.PostAsync(s_searchUri, content).ConfigureAwait(false); - responseMessage.EnsureSuccessStatusCode(); - var response = await responseMessage.Content.ReadAsStringAsync().ConfigureAwait(false); - var searchResponseItems = JsonSerializer.Deserialize(response); - return searchResponseItems?.Select(item => item.Memory).ToArray() ?? Array.Empty(); - } - - private async Task ClearMemoryAsync() - { - try - { - var querystringParams = new string?[3] { this._userId, this._agentId, this._scopeToThread ? this._threadId : null } - .Where(x => !string.IsNullOrWhiteSpace(x)) - .Select((param, index) => $"param{index}={param}"); - var queryString = string.Join("&", querystringParams); - var clearMemoryUrl = new Uri($"/memories?{queryString}", UriKind.Relative); - - var responseMessage = await this._httpClient.DeleteAsync(clearMemoryUrl).ConfigureAwait(false); - responseMessage.EnsureSuccessStatusCode(); - } - catch (Exception ex) - { - Console.WriteLine($"- MemZeroMemory - Error clearing memory: {ex.Message}"); - throw; - } - } - - private sealed class CreateMemoryRequest - { - [JsonPropertyName("agent_id")] - public string? AgentId { get; set; } - [JsonPropertyName("run_id")] - public string? RunId { get; set; } - [JsonPropertyName("user_id")] - public string? UserId { get; set; } - [JsonPropertyName("messages")] - public CreateMemoryMemory[] Messages { get; set; } = []; - } - - private sealed class CreateMemoryMemory - { - [JsonPropertyName("content")] - public string Content { get; set; } = string.Empty; - [JsonPropertyName("role")] - public string Role { get; set; } = string.Empty; - } - - private sealed class SearchRequest - { - [JsonPropertyName("agent_id")] - public string? AgentId { get; set; } = null; - [JsonPropertyName("run_id")] - public string? RunId { get; set; } = null; - [JsonPropertyName("user_id")] - public string? UserId { get; set; } = null; - [JsonPropertyName("query")] - public string Query { get; set; } = string.Empty; - } - -#pragma warning disable CA1812 // Avoid uninstantiated internal classes - private sealed class SearchResponseItem - { - [JsonPropertyName("id")] - public string Id { get; set; } = string.Empty; - [JsonPropertyName("memory")] - public string Memory { get; set; } = string.Empty; - [JsonPropertyName("hash")] - public string Hash { get; set; } = string.Empty; - [JsonPropertyName("metadata")] - public object? Metadata { get; set; } - [JsonPropertyName("score")] - public double Score { get; set; } - [JsonPropertyName("created_at")] - public DateTime CreatedAt { get; set; } - [JsonPropertyName("updated_at")] - public DateTime? UpdatedAt { get; set; } - [JsonPropertyName("user_id")] - public string UserId { get; set; } = string.Empty; - [JsonPropertyName("agent_id")] - public string AgentId { get; set; } = string.Empty; - [JsonPropertyName("run_id")] - public string RunId { get; set; } = string.Empty; - } -#pragma warning restore CA1812 // Avoid uninstantiated internal classes -} From d89e0583b6b09c671e64a6f1442aa2b9036f98b2 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 14 Apr 2025 14:15:11 +0100 Subject: [PATCH 15/46] Add support for suspend and resume --- .../Memory/Memory/UserFactsMemoryComponent.cs | 12 ++++++++ .../Memory/ConversationStateExtension.cs | 30 +++++++++++++++++++ .../ConversationStateExtensionsManager.cs | 30 +++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs index cbd734a4d565..ba4ed5319cf1 100644 --- a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs @@ -133,6 +133,18 @@ public override void RegisterPlugins(Kernel kernel) kernel.Plugins.AddFromObject(this, "UserFactsMemory"); } + /// + public override Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default) + { + return this.OnThreadCreatedAsync(threadId, cancellationToken); + } + + /// + public override Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default) + { + return this.OnThreadDeleteAsync(threadId, cancellationToken); + } + /// /// Plugin method to clear user facts stored in memory. /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs index 0ac251d58508..fe4f9e8fc7b3 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs @@ -79,4 +79,34 @@ public virtual Task OnThreadDeleteAsync(string? threadId, CancellationToken canc public virtual void RegisterPlugins(Kernel kernel) { } + + /// + /// Called when the current conversion is temporarily suspended and any state should be saved. + /// + /// The ID of the current thread, if the thread has an ID. + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the end of each service call. + /// In a client application, this might be when the user closes the chat window or the application. + /// + public virtual Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + /// + /// Called when the current conversion is resumed and any state should be restored. + /// + /// The ID of the current thread, if the thread has an ID. + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the start of each service call where a previous conversation is being continued. + /// In a client application, this might be when the user re-opens the chat window to resume a conversation after having previously closed it. + /// + public virtual Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs index e07b411e0f04..ea510fa40b68 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs @@ -102,4 +102,34 @@ public virtual void RegisterPlugins(Kernel kernel) threadExtension.RegisterPlugins(kernel); } } + + /// + /// Called when the current conversion is temporarily suspended and any state should be saved. + /// + /// The ID of the current thread, if the thread has an ID. + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the end of each service call. + /// In a client application, this might be when the user closes the chat window or the application. + /// + public virtual async Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default) + { + await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnSuspendAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + } + + /// + /// Called when the current conversion is resumed and any state should be restored. + /// + /// The ID of the current thread, if the thread has an ID. + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the start of each service call where a previous conversation is being continued. + /// In a client application, this might be when the user re-opens the chat window to resume a conversation after having previously closed it. + /// + public virtual async Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default) + { + await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnResumeAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + } } From 4dccaa7b4f360e67af11f73d20fc54ac21571d71 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 14 Apr 2025 15:39:28 +0100 Subject: [PATCH 16/46] Add support for loading extensions from DI. --- dotnet/src/Agents/Abstractions/AgentThread.cs | 13 ---- dotnet/src/Agents/Core/ChatCompletionAgent.cs | 4 +- .../src/Agents/OpenAI/OpenAIAssistantAgent.cs | 4 +- .../ChatCompletionAgentWithMemoryTests.cs | 68 ++++++++++++++++++- .../IntegrationTests/IntegrationTests.csproj | 1 + .../Memory/Memory/UserFactsMemoryComponent.cs | 2 +- .../ConversationStateExtensionsManager.cs | 18 ++++- 7 files changed, 88 insertions(+), 22 deletions(-) diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index e11aa7cdef19..ad12eecdcaae 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; @@ -87,18 +86,6 @@ public virtual async Task DeleteAsync(CancellationToken cancellationToken = defa this.IsDeleted = true; } - /// - /// Called just before the AI is invoked - /// - /// The most recent messages that the AI is being invoked with. - /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation, containing the combined context from all conversation state extensions. - [Experimental("SKEXP0130")] - public virtual async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) - { - return await this.ThreadExtensionsManager.OnAIInvocationAsync(newMessages, cancellationToken).ConfigureAwait(false); - } - /// /// This method is called when a new message has been contributed to the chat by any participant. /// diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 4ec17c5d7a7e..d513b326b751 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -76,7 +76,7 @@ public override async IAsyncEnumerable> In // Get the conversation state extensions context contributions #pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await chatHistoryAgentThread.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + var extensionsContext = await chatHistoryAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); #pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Invoke Chat Completion with the updated chat history. @@ -164,7 +164,7 @@ public override async IAsyncEnumerable> InvokeAsync // Get the conversation state extensions context contributions #pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await openAIAssistantAgentThread.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); #pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var invokeResults = ActivityExtensions.RunWithActivityAsync( @@ -555,7 +555,7 @@ public async IAsyncEnumerable> In // Get the conversation state extensions context contributions #pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await openAIAssistantAgentThread.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); #pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Create options that use the RunCreationOptions from the options param if provided or diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index 1c5282dee624..d267508a6ac5 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -8,12 +8,16 @@ using System.Threading.Tasks; using Azure.Identity; using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; using Microsoft.SemanticKernel.Agents.Memory; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using Microsoft.SemanticKernel.Connectors.InMemory; +using Microsoft.SemanticKernel.Embeddings; using Microsoft.SemanticKernel.Memory; using Microsoft.SemanticKernel.Memory.TextRag; using SemanticKernel.IntegrationTests.TestSettings; @@ -91,6 +95,66 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() await this.Fixture.DeleteThread(agentThread2); } + [Fact] + public virtual async Task RegisterComponentsFromDIAsync() + { + var chatConfig = this._configuration.GetSection("AzureOpenAI").Get()!; + var embeddingConfig = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); + + // Arrange - Setup DI container. + var builder = Host.CreateEmptyApplicationBuilder(settings: null); + builder.Services.AddKernel(); + builder.Services.AddInMemoryVectorStore(); + builder.Services.AddAzureOpenAIChatCompletion( + deploymentName: chatConfig.ChatDeploymentName!, + endpoint: chatConfig.Endpoint, + credentials: new AzureCliCredential()); + builder.Services.AddAzureOpenAITextEmbeddingGeneration( + embeddingConfig!.EmbeddingModelId, + embeddingConfig.Endpoint, + new AzureCliCredential()); + builder.Services.AddKeyedTransient>("UserFactsStore", (sp, _) => new VectorDataTextMemoryStore( + sp.GetRequiredService(), + sp.GetRequiredService(), + "Memories", "user/12345", 1536)); + builder.Services.AddTransient(); + builder.Services.AddTransient((sp) => + { + var thread = new ChatHistoryAgentThread(); + thread.ThreadExtensionsManager.RegisterThreadExtensionsFromContainer(sp); + return thread; + }); + var host = builder.Build(); + + // Arrange - Create agent. + var agent = new ChatCompletionAgent() + { + Kernel = host.Services.GetRequiredService(), + Instructions = "You are a helpful assistant.", + }; + + // Act - First invocation + var agentThread1 = host.Services.GetRequiredService(); + + var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); + var results1 = await asyncResults1.ToListAsync(); + + await agentThread1.ThreadExtensionsManager.OnSuspendAsync(null, default); + + // Act - Second invocation + var agentThread2 = host.Services.GetRequiredService(); + + var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); + var results2 = await asyncResults2.ToListAsync(); + + // Assert + Assert.Contains("Caoimhe", results2.First().Message.Content); + + // Cleanup + await this.Fixture.DeleteThread(agentThread1); + await this.Fixture.DeleteThread(agentThread2); + } + [Fact] public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserInputAsync() { @@ -134,7 +198,7 @@ public virtual async Task RagComponentWithoutMatchesAsync() // Arrange - Create Vector Store and Rag Store/Component var vectorStore = new InMemoryVectorStore(); - using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "Memories", 1536, "group/g1"); + using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "FinancialData", 1536, "group/g1"); var ragComponent = new TextRagComponent(ragStore, new TextRagComponentOptions()); // Arrange - Upsert documents into the Rag Store @@ -167,7 +231,7 @@ public virtual async Task RagComponentWithMatchesAsync() // Arrange - Create Vector Store and Rag Store/Component var vectorStore = new InMemoryVectorStore(); - using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "Memories", 1536, "group/g2"); + using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "FinancialData", 1536, "group/g2"); var ragComponent = new TextRagComponent(ragStore, new TextRagComponentOptions()); // Arrange - Upsert documents into the Rag Store diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index e1bff6844998..356b31fb3581 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -48,6 +48,7 @@ + diff --git a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs index ba4ed5319cf1..4430caabb8a1 100644 --- a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs @@ -82,7 +82,7 @@ EXAMPLES END Return output for the following inputs like shown in the examples above: Input text: {{$inputText}} - Input facts: {{existingFacts}} + Input facts: {{$existingFacts}} """; /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs index ea510fa40b68..74e885297859 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs @@ -1,10 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Microsoft.SemanticKernel.Memory; @@ -38,14 +40,26 @@ public ConversationStateExtensionsManager(IEnumerable - /// Registers a new conversation state extensions. + /// Registers a new conversation state extension. /// - /// The conversation state extensions to register. + /// The conversation state extension to register. public virtual void RegisterThreadExtension(ConversationStateExtension conversationtStateExtension) { this._conversationStateExtensions.Add(conversationtStateExtension); } + /// + /// Registers all conversation state extensions registered on the provided dependency injection service provider. + /// + /// The dependency injection service provider to read conversation state extensions from. + public virtual void RegisterThreadExtensionsFromContainer(IServiceProvider serviceProvider) + { + foreach (var extension in serviceProvider.GetServices()) + { + this.RegisterThreadExtension(extension); + } + } + /// /// Called when a new thread is created. /// From d78f987ea5da5af9acb0789c2caaea4fcf761b9f Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 14 Apr 2025 15:52:28 +0100 Subject: [PATCH 17/46] Simplify suspend/resume --- dotnet/src/Agents/Abstractions/AgentThread.cs | 40 ++++++++++ .../ChatCompletionAgentWithMemoryTests.cs | 73 ++++++++++--------- 2 files changed, 77 insertions(+), 36 deletions(-) diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index ad12eecdcaae..6f1792aa2289 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -34,6 +34,46 @@ public abstract class AgentThread [Experimental("SKEXP0130")] public virtual ConversationStateExtensionsManager ThreadExtensionsManager { get; init; } = new ConversationStateExtensionsManager(); + /// + /// Called when the current conversion is temporarily suspended and any state should be saved. + /// + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the end of each service call. + /// In a client application, this might be when the user closes the chat window or the application. + /// + [Experimental("SKEXP0130")] + public virtual Task OnSuspendAsync(CancellationToken cancellationToken = default) + { + return this.ThreadExtensionsManager.OnSuspendAsync(this.Id, cancellationToken); + } + + /// + /// Called when the current conversion is resumed and any state should be restored. + /// + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the start of each service call where a previous conversation is being continued. + /// In a client application, this might be when the user re-opens the chat window to resume a conversation after having previously closed it. + /// + [Experimental("SKEXP0130")] + public virtual Task OnResumeAsync(CancellationToken cancellationToken = default) + { + if (this.IsDeleted) + { + throw new InvalidOperationException("This thread has been deleted and cannot be used anymore."); + } + + if (this.Id is not null) + { + throw new InvalidOperationException("This thread cannot be resumed, since it has not been created."); + } + + return this.ThreadExtensionsManager.OnSuspendAsync(this.Id, cancellationToken); + } + /// /// Creates the thread and returns the thread id. /// diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index d267508a6ac5..230a5c5e69a5 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -96,7 +96,41 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() } [Fact] - public virtual async Task RegisterComponentsFromDIAsync() + public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserInputAsync() + { + // Arrange + var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); + + var vectorStore = new InMemoryVectorStore(); + var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); + using var textMemoryStore = new VectorDataTextMemoryStore(vectorStore, textEmbeddingService, "Memories", "user/12345", 1536); + + var agent = this.Fixture.Agent; + + // Act - First invocation with first thread. + var agentThread1 = new ChatHistoryAgentThread(); + agentThread1.ThreadExtensionsManager.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + + var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); + var results1 = await asyncResults1.ToListAsync(); + + // Act - Second invocation with second thread. + var agentThread2 = new ChatHistoryAgentThread(); + agentThread2.ThreadExtensionsManager.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + + var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); + var results2 = await asyncResults2.ToListAsync(); + + // Assert + Assert.Contains("Caoimhe", results2.First().Message.Content); + + // Cleanup + await this.Fixture.DeleteThread(agentThread1); + await this.Fixture.DeleteThread(agentThread2); + } + + [Fact] + public virtual async Task CapturesMemoriesWhileUsingDIAsync() { var chatConfig = this._configuration.GetSection("AzureOpenAI").Get()!; var embeddingConfig = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); @@ -139,7 +173,8 @@ public virtual async Task RegisterComponentsFromDIAsync() var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); var results1 = await asyncResults1.ToListAsync(); - await agentThread1.ThreadExtensionsManager.OnSuspendAsync(null, default); + // Act - Call suspend on the thread, so that all memory components attached to it, save their state. + await agentThread1.OnSuspendAsync(default); // Act - Second invocation var agentThread2 = host.Services.GetRequiredService(); @@ -155,40 +190,6 @@ public virtual async Task RegisterComponentsFromDIAsync() await this.Fixture.DeleteThread(agentThread2); } - [Fact] - public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserInputAsync() - { - // Arrange - var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); - - var vectorStore = new InMemoryVectorStore(); - var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); - using var textMemoryStore = new VectorDataTextMemoryStore(vectorStore, textEmbeddingService, "Memories", "user/12345", 1536); - - var agent = this.Fixture.Agent; - - // Act - First invocation with first thread. - var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.ThreadExtensionsManager.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); - - var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); - var results1 = await asyncResults1.ToListAsync(); - - // Act - Second invocation with second thread. - var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.ThreadExtensionsManager.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); - - var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); - var results2 = await asyncResults2.ToListAsync(); - - // Assert - Assert.Contains("Caoimhe", results2.First().Message.Content); - - // Cleanup - await this.Fixture.DeleteThread(agentThread1); - await this.Fixture.DeleteThread(agentThread2); - } - [Fact] public virtual async Task RagComponentWithoutMatchesAsync() { From d60aa00f4b00b2483c5d1cb5bbe4c7eded928073 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 14 Apr 2025 15:59:26 +0100 Subject: [PATCH 18/46] Update ADR with onsuspend and onresume and decisions to make list. --- docs/decisions/00NN-agents-with-memory.md | 27 +++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/docs/decisions/00NN-agents-with-memory.md b/docs/decisions/00NN-agents-with-memory.md index 5fe2f0399d72..6f6905974bac 100644 --- a/docs/decisions/00NN-agents-with-memory.md +++ b/docs/decisions/00NN-agents-with-memory.md @@ -73,6 +73,9 @@ that conversation state needs to be loaded and persisted for each invocation of It also means that any memory components that may have some in-memory state will need to be loaded and persisted too. +For cases like this, the `OnSuspend` and `OnResume` methods allow notification of the components that they need to save or reload their state. +It is up to each of these components to decide how and where to save state to or load state from. + ## Proposed interface for Memory Components The types of events that Memory Components require are not unique to memory, and can be used to package up other capabilities too. @@ -94,6 +97,9 @@ public abstract class ConversationStateExtension public abstract Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default); public virtual void RegisterPlugins(Kernel kernel); + + public virtual Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default); + public virtual Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default); } ``` @@ -168,3 +174,24 @@ agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); // Expected response contains the 174M income from the document. ``` + +## Decisions to make + +### Extension base class name + +1. ConversationStateExtension + 1. Long +1. MemoryComponent + 1. Too specific + +### Location for abstractions + +1. Microsoft.SemanticKernel. +1. Microsoft.SemanticKernel.Memory. +1. Microsoft.SemanticKernel.Memory. (in separate nuget) + +### Location for memory components + +1. A nuget for each component +1. Microsoft.SemanticKernel.Core nuget +1. Microsoft.SemanticKernel.Memory nuget From 699d7acd690523cb0db34851225b2deaad309bac Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 14 Apr 2025 19:28:33 +0100 Subject: [PATCH 19/46] Allow RAG via plugin for TextRagComponent --- dotnet/src/Agents/Core/ChatCompletionAgent.cs | 14 ++++-- .../src/Agents/OpenAI/OpenAIAssistantAgent.cs | 14 ++++-- .../ChatCompletionAgentWithMemoryTests.cs | 39 ++++++++++++++++ .../Memory/Memory/TextRag/TextRagComponent.cs | 46 ++++++++++++++++++- .../Memory/TextRag/TextRagComponentOptions.cs | 34 ++++++++++++++ 5 files changed, 137 insertions(+), 10 deletions(-) diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index d513b326b751..8fa2461c4bdc 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -74,9 +74,12 @@ public override async IAsyncEnumerable> In () => new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); - // Get the conversation state extensions context contributions + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var extensionsContext = await chatHistoryAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + chatHistoryAgentThread.ThreadExtensionsManager.RegisterPlugins(kernel); #pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Invoke Chat Completion with the updated chat history. @@ -97,7 +100,7 @@ public override async IAsyncEnumerable> In } }, options?.KernelArguments, - options?.Kernel, + kernel, options?.AdditionalInstructions == null ? extensionsContext : options.AdditionalInstructions + Environment.NewLine + Environment.NewLine + extensionsContext, cancellationToken); @@ -162,9 +165,12 @@ public override async IAsyncEnumerable new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); - // Get the conversation state extensions context contributions + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var extensionsContext = await chatHistoryAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + chatHistoryAgentThread.ThreadExtensionsManager.RegisterPlugins(kernel); #pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Invoke Chat Completion with the updated chat history. @@ -186,7 +192,7 @@ public override async IAsyncEnumerable> InvokeAsync AdditionalInstructions = options?.AdditionalInstructions, }); - // Get the conversation state extensions context contributions + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + openAIAssistantAgentThread.ThreadExtensionsManager.RegisterPlugins(kernel); #pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var invokeResults = ActivityExtensions.RunWithActivityAsync( @@ -430,7 +433,7 @@ async IAsyncEnumerable InternalInvokeAsync() internalOptions, extensionsContext, this.Logger, - options?.Kernel ?? this.Kernel, + kernel, options?.KernelArguments, cancellationToken).ConfigureAwait(false)) { @@ -553,9 +556,12 @@ public async IAsyncEnumerable> In () => new OpenAIAssistantAgentThread(this.Client), cancellationToken).ConfigureAwait(false); - // Get the conversation state extensions context contributions + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + openAIAssistantAgentThread.ThreadExtensionsManager.RegisterPlugins(kernel); #pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Create options that use the RunCreationOptions from the options param if provided or @@ -584,7 +590,7 @@ IAsyncEnumerable InternalInvokeStreamingAsync() internalOptions, extensionsContext, this.Logger, - options?.Kernel ?? this.Kernel, + kernel, options?.KernelArguments, cancellationToken); } diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index 230a5c5e69a5..23ee3f5fd2ab 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -255,6 +255,45 @@ public virtual async Task RagComponentWithMatchesAsync() await this.Fixture.DeleteThread(agentThread); } + [Fact] + public virtual async Task RagComponentWithMatchesOnDemandAsync() + { + // Arrange - Create Embedding Service + var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); + var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); + + // Arrange - Create Vector Store and Rag Store/Component + var vectorStore = new InMemoryVectorStore(); + using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "FinancialData", 1536, "group/g2"); + var ragComponent = new TextRagComponent( + ragStore, + new() + { + SearchTime = TextRagComponentOptions.TextRagSearchTime.ViaPlugin, + PluginSearchFunctionName = "SearchCorporateData", + PluginSearchFunctionDescription = "RAG Search over dataset containing financial data and company information about various companies." + }); + + // Arrange - Upsert documents into the Rag Store + await ragStore.UpsertDocumentsAsync(GetSampleDocuments()); + + var agent = this.Fixture.Agent; + + // Act - Create a new agent thread and register the Rag component + var agentThread = new ChatHistoryAgentThread(); + agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); + + // Act - Invoke the agent with a question + var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread, new() { KernelArguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })}); + var results1 = await asyncResults1.ToListAsync(); + + // Assert - Check if the response contains the expected value from the database. + Assert.Contains("174", results1.First().Message.Content); + + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + private static IEnumerable GetSampleDocuments() { yield return new TextRagDocument("The financial results of Contoso Corp for 2024 is as follows:\nIncome EUR 154 000 000\nExpenses EUR 142 000 000") diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs b/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs index d28adc58ce0d..0eecb4bf9e19 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs +++ b/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel.Data; @@ -23,12 +24,12 @@ public class TextRagComponent : ConversationStateExtension /// The text search component to retrieve results from. /// /// - public TextRagComponent(ITextSearch textSearch, TextRagComponentOptions options) + public TextRagComponent(ITextSearch textSearch, TextRagComponentOptions? options = default) { Verify.NotNull(textSearch); this._textSearch = textSearch; - this.Options = options ?? throw new ArgumentNullException(nameof(options)); + this.Options = options ?? new(); } /// @@ -39,6 +40,11 @@ public TextRagComponent(ITextSearch textSearch, TextRagComponentOptions options) /// public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { + if (this.Options.SearchTime != TextRagComponentOptions.TextRagSearchTime.BeforeAIInvoke) + { + return string.Empty; + } + Verify.NotNull(newMessages); string input = string.Join("\n", newMessages.Where(m => m is not null).Select(m => m.Content)); @@ -64,4 +70,40 @@ public override async Task OnAIInvocationAsync(ICollection + public override void RegisterPlugins(Kernel kernel) + { + if (this.Options.SearchTime != TextRagComponentOptions.TextRagSearchTime.ViaPlugin) + { + return; + } + + Verify.NotNull(kernel); + + KernelFunctionFactory.CreateFromMethod( + typeof(TextRagComponent).GetMethod(nameof(SearchAsync))!, + target: this, + functionName: this.Options.PluginSearchFunctionName ?? "Search", + description: this.Options.PluginSearchFunctionDescription ?? "Allows searching for additional information to help answer the user question."); + + base.RegisterPlugins(kernel); + kernel.Plugins.AddFromObject(this, "UserFactsMemory"); + } + + /// + /// Plugin method to search the database on demand. + /// + [KernelFunction] + public async Task SearchAsync(string userQuestion, CancellationToken cancellationToken = default) + { + var searchResults = await this._textSearch.GetTextSearchResultsAsync( + userQuestion, + new() { Top = this.Options.Top }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + var results = await searchResults.Results.ToListAsync(cancellationToken).ConfigureAwait(false); + + return JsonSerializer.Serialize(results); + } } diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs b/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs index 146042c2a0b2..7b13bb9b0ff6 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs +++ b/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs @@ -29,4 +29,38 @@ public int Top this._top = value; } } + + /// + /// Gets or sets the time at which the text search is performed. + /// + public TextRagSearchTime SearchTime { get; init; } = TextRagSearchTime.BeforeAIInvoke; + + /// + /// Gets or sets the name of the plugin method that will be made available for searching + /// if the option is set to . + /// + public string? PluginSearchFunctionName { get; init; } + + /// + /// Gets or sets the description of the plugin method that will be made available for searching + /// if the option is set to . + /// + public string? PluginSearchFunctionDescription { get; init; } + + /// + /// The time at which the text search is performed. + /// + public enum TextRagSearchTime + { + /// + /// A seach is performed each time that the AI is invoked just before the AI is invoked + /// and the results are provided to the AI via the invocation context. + /// + BeforeAIInvoke, + + /// + /// A search may be performed by the AI on demand via a plugin. + /// + ViaPlugin + } } From 32562c2b4c8607b4b50f96c12bbbeaa15c6d9ed4 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 17 Apr 2025 14:20:06 +0100 Subject: [PATCH 20/46] Update ConversationStateExtensions to use MEAI types. --- docs/decisions/00NN-agents-with-memory.md | 19 +++++-- dotnet/src/Agents/Abstractions/AgentThread.cs | 1 - .../Memory/Memory/Mem0/Mem0MemoryComponent.cs | 37 +++++++------- dotnet/src/Memory/Memory/Memory.csproj | 1 + .../Memory/Memory/TextRag/TextRagComponent.cs | 49 ++++++++++--------- .../Memory/Memory/UserFactsMemoryComponent.cs | 31 ++++++------ .../Memory/ConversationStateExtension.cs | 22 ++++----- .../ConversationStateExtensionsManager.cs | 40 +++++++++------ ...rsationStateExtensionsManagerExtensions.cs | 41 ++++++++++++++++ ...rsationStateExtensionsManagerExtensions.cs | 23 +++++++++ 10 files changed, 174 insertions(+), 90 deletions(-) create mode 100644 dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs create mode 100644 dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs diff --git a/docs/decisions/00NN-agents-with-memory.md b/docs/decisions/00NN-agents-with-memory.md index 6f6905974bac..bc659d03b8b4 100644 --- a/docs/decisions/00NN-agents-with-memory.md +++ b/docs/decisions/00NN-agents-with-memory.md @@ -85,18 +85,20 @@ be used for non-agent scenarios too. This type should live in the `Microsoft.SemanticKernel.Abstractions` nuget, since these components can be used by systems other than just agents. ```csharp -namespace Microsoft.SemanticKernel.Memory; +namespace Microsoft.SemanticKernel; public abstract class ConversationStateExtension { + public virtual IReadOnlyCollection AIFunctions => Array.Empty(); + public virtual Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default); - public virtual Task OnThreadCheckpointAsync(string threadId, CancellationToken cancellationToken = default); public virtual Task OnThreadDeleteAsync(string? threadId, CancellationToken cancellationToken = default); - public virtual Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default); - public abstract Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default); + // OnThreadCheckpointAsync not included in initial release, maybe in future. + public virtual Task OnThreadCheckpointAsync(string? threadId, CancellationToken cancellationToken = default); - public virtual void RegisterPlugins(Kernel kernel); + public virtual Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default); + public abstract Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default); public virtual Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default); public virtual Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default); @@ -184,14 +186,21 @@ var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", 1. MemoryComponent 1. Too specific +Chose ConversationStateExtension. + ### Location for abstractions 1. Microsoft.SemanticKernel. 1. Microsoft.SemanticKernel.Memory. 1. Microsoft.SemanticKernel.Memory. (in separate nuget) +Chose Microsoft.SemanticKernel.. + ### Location for memory components 1. A nuget for each component 1. Microsoft.SemanticKernel.Core nuget 1. Microsoft.SemanticKernel.Memory nuget +1. Microsoft.SemanticKernel.ConversationStateExtensions nuget + +Chose Microsoft.SemanticKernel.Core nuget \ No newline at end of file diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index 6f1792aa2289..46d8ac33c4ea 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -4,7 +4,6 @@ using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; -using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Agents; diff --git a/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs index 96bdd5373a5b..33f30d739d33 100644 --- a/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs +++ b/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs @@ -6,7 +6,7 @@ using System.Net.Http; using System.Threading; using System.Threading.Tasks; -using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.Memory; @@ -22,6 +22,8 @@ public class Mem0MemoryComponent : ConversationStateExtension private readonly string? _userId; private readonly bool _scopeToThread; + private readonly AIFunction[] _aIFunctions; + private readonly Mem0Client _mem0Client; /// @@ -39,9 +41,14 @@ public Mem0MemoryComponent(HttpClient httpClient, Mem0MemoryComponentOptions? op this._userId = options?.UserId; this._scopeToThread = options?.ScopeToThread ?? false; + this._aIFunctions = [AIFunctionFactory.Create(this.ClearStoredUserFactsAsync)]; + this._mem0Client = new(httpClient); } + /// + public override IReadOnlyCollection AIFunctions => this._aIFunctions; + /// public override Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) { @@ -50,32 +57,32 @@ public override Task OnThreadCreatedAsync(string? threadId, CancellationToken ca } /// - public override async Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + public override async Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default) { Verify.NotNull(newMessage); - if (newMessage.Role == AuthorRole.User && !string.IsNullOrWhiteSpace(newMessage.Content)) + if (newMessage.Role == ChatRole.User && !string.IsNullOrWhiteSpace(newMessage.Text)) { await this._mem0Client.CreateMemoryAsync( this._applicationId, this._agentId, this._scopeToThread ? this._threadId : null, this._userId, - newMessage.Content, - newMessage.Role.Label).ConfigureAwait(false); + newMessage.Text, + newMessage.Role.Value).ConfigureAwait(false); } } /// - public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { Verify.NotNull(newMessages); string inputText = string.Join( "\n", newMessages. - Where(m => m is not null && !string.IsNullOrWhiteSpace(m.Content)). - Select(m => m.Content)); + Where(m => m is not null && !string.IsNullOrWhiteSpace(m.Text)). + Select(m => m.Text)); var memories = await this._mem0Client.SearchAsync( this._applicationId, @@ -88,22 +95,12 @@ public override async Task OnAIInvocationAsync(ICollection - public override void RegisterPlugins(Kernel kernel) - { - Verify.NotNull(kernel); - - base.RegisterPlugins(kernel); - kernel.Plugins.AddFromObject(this, "MemZeroMemory"); - } - /// /// Plugin method to clear user preferences stored in memory for the current agent/thread/user. /// /// A task that completes when the memory is cleared. - [KernelFunction] - [Description("Deletes any user preferences stored about the user.")] - public async Task ClearUserPreferencesAsync() + [Description("Deletes any user facts that are stored across multiple conversations.")] + public async Task ClearStoredUserFactsAsync() { await this._mem0Client.ClearMemoryAsync( this._applicationId, diff --git a/dotnet/src/Memory/Memory/Memory.csproj b/dotnet/src/Memory/Memory/Memory.csproj index 196cf06c06a7..9654db9ab1d5 100644 --- a/dotnet/src/Memory/Memory/Memory.csproj +++ b/dotnet/src/Memory/Memory/Memory.csproj @@ -26,6 +26,7 @@ + diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs b/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs index 0eecb4bf9e19..17d149ef6f3c 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs +++ b/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs @@ -7,6 +7,7 @@ using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel.Data; namespace Microsoft.SemanticKernel.Memory; @@ -18,6 +19,8 @@ public class TextRagComponent : ConversationStateExtension { private readonly ITextSearch _textSearch; + private readonly AIFunction[] _aIFunctions; + /// /// Initializes a new instance of the class. /// @@ -30,6 +33,14 @@ public TextRagComponent(ITextSearch textSearch, TextRagComponentOptions? options this._textSearch = textSearch; this.Options = options ?? new(); + + this._aIFunctions = + [ + AIFunctionFactory.Create( + this.SearchAsync, + name: this.Options.PluginSearchFunctionName ?? "Search", + description: this.Options.PluginSearchFunctionDescription ?? "Allows searching for additional information to help answer the user question.") + ]; } /// @@ -38,7 +49,21 @@ public TextRagComponent(ITextSearch textSearch, TextRagComponentOptions? options public TextRagComponentOptions Options { get; } /// - public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + public override IReadOnlyCollection AIFunctions + { + get + { + if (this.Options.SearchTime != TextRagComponentOptions.TextRagSearchTime.ViaPlugin) + { + return Array.Empty(); + } + + return this._aIFunctions; + } + } + + /// + public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { if (this.Options.SearchTime != TextRagComponentOptions.TextRagSearchTime.BeforeAIInvoke) { @@ -47,7 +72,7 @@ public override async Task OnAIInvocationAsync(ICollection m is not null).Select(m => m.Content)); + string input = string.Join("\n", newMessages.Where(m => m is not null).Select(m => m.Text)); var searchResults = await this._textSearch.GetTextSearchResultsAsync( input, @@ -71,26 +96,6 @@ public override async Task OnAIInvocationAsync(ICollection - public override void RegisterPlugins(Kernel kernel) - { - if (this.Options.SearchTime != TextRagComponentOptions.TextRagSearchTime.ViaPlugin) - { - return; - } - - Verify.NotNull(kernel); - - KernelFunctionFactory.CreateFromMethod( - typeof(TextRagComponent).GetMethod(nameof(SearchAsync))!, - target: this, - functionName: this.Options.PluginSearchFunctionName ?? "Search", - description: this.Options.PluginSearchFunctionDescription ?? "Allows searching for additional information to help answer the user question."); - - base.RegisterPlugins(kernel); - kernel.Plugins.AddFromObject(this, "UserFactsMemory"); - } - /// /// Plugin method to search the database on demand. /// diff --git a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs index 4430caabb8a1..ff060d3ad334 100644 --- a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs @@ -4,7 +4,7 @@ using System.ComponentModel; using System.Threading; using System.Threading.Tasks; -using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Agents.Memory; @@ -20,6 +20,8 @@ public class UserFactsMemoryComponent : ConversationStateExtension private string _userFacts = string.Empty; private bool _contextLoaded = false; + private readonly AIFunction[] _aIFunctions; + /// /// Initializes a new instance of the class. /// @@ -29,6 +31,8 @@ public UserFactsMemoryComponent(Kernel kernel, TextMemoryStore textMemoryStore) { this._kernel = kernel; this._textMemoryStore = textMemoryStore; + + this._aIFunctions = [AIFunctionFactory.Create(this.ClearUserFactsAsync)]; } /// @@ -40,8 +44,13 @@ public UserFactsMemoryComponent(Kernel kernel, string? userFactsStoreName = "Use { this._kernel = kernel; this._textMemoryStore = new OptionalTextMemoryStore(kernel, userFactsStoreName); + + this._aIFunctions = [AIFunctionFactory.Create(this.ClearUserFactsAsync)]; } + /// + public override IReadOnlyCollection AIFunctions => this._aIFunctions; + /// /// Gets or sets the name of the document to use for storing user preferfactsences. /// @@ -109,30 +118,21 @@ public override async Task OnThreadDeleteAsync(string? threadId, CancellationTok } /// - public override async Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + public override async Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default) { - if (newMessage.Role == AuthorRole.User && !string.IsNullOrWhiteSpace(newMessage.Content)) + if (newMessage.Role == ChatRole.User && !string.IsNullOrWhiteSpace(newMessage.Text)) { // Don't wait for task to complete. Just run in the background. - await this.ExtractAndSaveMemoriesAsync(newMessage.Content, cancellationToken).ConfigureAwait(false); + await this.ExtractAndSaveMemoriesAsync(newMessage.Text, cancellationToken).ConfigureAwait(false); } } /// - public override Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + public override Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { return Task.FromResult("The following list contains facts about the user:\n" + this._userFacts); } - /// - public override void RegisterPlugins(Kernel kernel) - { - Verify.NotNull(kernel); - - base.RegisterPlugins(kernel); - kernel.Plugins.AddFromObject(this, "UserFactsMemory"); - } - /// public override Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default) { @@ -148,8 +148,7 @@ public override Task OnSuspendAsync(string? threadId, CancellationToken cancella /// /// Plugin method to clear user facts stored in memory. /// - [KernelFunction] - [Description("Deletes any user facts stored about the user.")] + [Description("Deletes any user facts that are stored acros multiple conversations.")] public async Task ClearUserFactsAsync(CancellationToken cancellationToken = default) { this._userFacts = string.Empty; diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs index fe4f9e8fc7b3..1266bf1ebb8b 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs @@ -1,11 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; -namespace Microsoft.SemanticKernel.Memory; +namespace Microsoft.SemanticKernel; /// /// Base class for all conversation state extensions. @@ -18,6 +20,12 @@ namespace Microsoft.SemanticKernel.Memory; [Experimental("SKEXP0130")] public abstract class ConversationStateExtension { + /// + /// Gets the list of AI functions that this extension component exposes + /// and which should be used by the consuming AI when using this component. + /// + public virtual IReadOnlyCollection AIFunctions => Array.Empty(); + /// /// Called just after a new thread is created. /// @@ -42,7 +50,7 @@ public virtual Task OnThreadCreatedAsync(string? threadId, CancellationToken can /// The new message. /// The to monitor for cancellation requests. The default is . /// A task that completes when the context has been updated. - public virtual Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + public virtual Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default) { return Task.CompletedTask; } @@ -70,15 +78,7 @@ public virtual Task OnThreadDeleteAsync(string? threadId, CancellationToken canc /// The most recent messages that the AI is being invoked with. /// The to monitor for cancellation requests. The default is . /// A task that completes when the context has been rendered and returned. - public abstract Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default); - - /// - /// Register plugins required by this extension component on the provided . - /// - /// The kernel to register the plugins on. - public virtual void RegisterPlugins(Kernel kernel) - { - } + public abstract Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default); /// /// Called when the current conversion is temporarily suspended and any state should be saved. diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs index 74e885297859..efe53bc4e6ad 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs @@ -6,9 +6,10 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -namespace Microsoft.SemanticKernel.Memory; +namespace Microsoft.SemanticKernel; /// /// A container class for objects that manages their lifecycle and interactions. @@ -18,6 +19,8 @@ public class ConversationStateExtensionsManager { private readonly List _conversationStateExtensions = new(); + private List? _currentAIFunctions = null; + /// /// Gets the list of registered conversation state extensions. /// @@ -30,6 +33,23 @@ public ConversationStateExtensionsManager() { } + /// + /// Gets the list of AI functions that all contained extension component expose + /// and which should be used by the consuming AI when using these components. + /// + public virtual IReadOnlyCollection AIFunctions + { + get + { + if (this._currentAIFunctions == null) + { + this._currentAIFunctions = this.ConversationStateExtensions.SelectMany(ConversationStateExtensions => ConversationStateExtensions.AIFunctions).ToList(); + } + + return this._currentAIFunctions; + } + } + /// /// Initializes a new instance of the class with the specified conversation state extensions. /// @@ -46,6 +66,7 @@ public ConversationStateExtensionsManager(IEnumerable @@ -58,6 +79,7 @@ public virtual void RegisterThreadExtensionsFromContainer(IServiceProvider servi { this.RegisterThreadExtension(extension); } + this._currentAIFunctions = null; } /// @@ -88,7 +110,7 @@ public virtual async Task OnThreadDeleteAsync(string threadId, CancellationToken /// The new message. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. - public virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + public virtual async Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default) { await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnNewMessageAsync(newMessage, cancellationToken)).ToList()).ConfigureAwait(false); } @@ -99,24 +121,12 @@ public virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, Cance /// The most recent messages that the AI is being invoked with. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation, containing the combined context from all conversation state extensions. - public virtual async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + public virtual async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { var subContexts = await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnAIInvocationAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); return string.Join("\n", subContexts); } - /// - /// Registers plugins required by all conversation state extensions contained by this manager on the provided . - /// - /// The kernel to register the plugins on. - public virtual void RegisterPlugins(Kernel kernel) - { - foreach (var threadExtension in this.ConversationStateExtensions) - { - threadExtension.RegisterPlugins(kernel); - } - } - /// /// Called when the current conversion is temporarily suspended and any state should be saved. /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs new file mode 100644 index 000000000000..bd282a9bf963 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods for . +/// +[Experimental("SKEXP0130")] +public static class ConversationStateExtensionsManagerExtensions +{ + /// + /// This method is called when a new message has been contributed to the chat by any participant. + /// + /// The conversation state manager to pass the new message to. + /// The new message. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public static Task OnNewMessageAsync(this ConversationStateExtensionsManager conversationStateExtensionsManager, ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + return conversationStateExtensionsManager.OnNewMessageAsync(ChatCompletionServiceExtensions.ToChatMessage(newMessage), cancellationToken); + } + + /// + /// Called just before the AI is invoked + /// + /// The conversation state manager to call. + /// The most recent messages that the AI is being invoked with. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation, containing the combined context from all conversation state extensions. + public static Task OnAIInvocationAsync(this ConversationStateExtensionsManager conversationStateExtensionsManager, ICollection newMessages, CancellationToken cancellationToken = default) + { + return conversationStateExtensionsManager.OnAIInvocationAsync(newMessages.Select(ChatCompletionServiceExtensions.ToChatMessage).ToList(), cancellationToken); + } +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs b/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs new file mode 100644 index 000000000000..ceb2e37f8a27 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods for . +/// +[Experimental("SKEXP0130")] +public static class ConversationStateExtensionsManagerExtensions +{ + /// + /// Registers plugins required by all conversation state extensions contained by this manager on the provided . + /// + /// The conversation state manager to get plugins from. + /// The kernel to register the plugins on. + public static void RegisterPlugins(this ConversationStateExtensionsManager conversationStateExtensionsManager, Kernel kernel) + { + kernel.Plugins.AddFromFunctions("Tools", conversationStateExtensionsManager.AIFunctions.Select(x => x.AsKernelFunction())); + } +} From 059e4b4494efff1cbacd39158b4225ba70cae54f Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 17 Apr 2025 14:25:55 +0100 Subject: [PATCH 21/46] Fix typos --- dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs | 2 +- dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs | 2 +- dotnet/src/Memory/Memory/TextRag/TextRagStore.cs | 2 +- dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs | 4 ++-- dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs | 2 +- .../Memory/ConversationStateExtension.cs | 4 ++-- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs index 33f30d739d33..da3cc7689b55 100644 --- a/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs +++ b/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs @@ -11,7 +11,7 @@ namespace Microsoft.SemanticKernel.Memory; /// -/// A component that listenes to messages added to the conversation thread, and automatically captures +/// A component that listens to messages added to the conversation thread, and automatically captures /// information about the user. It is also able to retrieve this information and add it to the AI invocation context. /// public class Mem0MemoryComponent : ConversationStateExtension diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs b/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs index 7b13bb9b0ff6..275074cd60b1 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs +++ b/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs @@ -53,7 +53,7 @@ public int Top public enum TextRagSearchTime { /// - /// A seach is performed each time that the AI is invoked just before the AI is invoked + /// A search is performed each time that the AI is invoked just before the AI is invoked /// and the results are provided to the AI via the invocation context. /// BeforeAIInvoke, diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagStore.cs b/dotnet/src/Memory/Memory/TextRag/TextRagStore.cs index 1931c004c6a6..f713658b492f 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagStore.cs +++ b/dotnet/src/Memory/Memory/TextRag/TextRagStore.cs @@ -36,7 +36,7 @@ public class TextRagStore : ITextSearch, IDisposable /// The vector store to store and read the memories from. /// The service to use for generating embeddings for the memories. /// The name of the collection in the vector store to store and read the memories from. - /// The number of dimentions to use for the memory embeddings. + /// The number of dimensions to use for the memory embeddings. /// An optional namespace to filter search results to. /// Thrown if the key type provided is not supported. public TextRagStore(IVectorStore vectorStore, ITextEmbeddingGenerationService textEmbeddingGenerationService, string collectionName, int vectorDimensions, string? searchNamespace) diff --git a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs index ff060d3ad334..1d56b5d2ffb6 100644 --- a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs @@ -59,7 +59,7 @@ public UserFactsMemoryComponent(Kernel kernel, string? userFactsStoreName = "Use /// /// Gets or sets the prompt template to use for extracting user facts and merging them with existing facts. /// - public string MaintainencePromptTemplate { get; init; } = + public string MaintenancePromptTemplate { get; init; } = """ You are an expert in extracting facts about a user from text and combining these facts with existing facts to output a new list of facts. Facts are short statements that each contain a single piece of information. @@ -158,7 +158,7 @@ public async Task ClearUserFactsAsync(CancellationToken cancellationToken = defa private async Task ExtractAndSaveMemoriesAsync(string inputText, CancellationToken cancellationToken = default) { var result = await this._kernel.InvokePromptAsync( - this.MaintainencePromptTemplate, + this.MaintenancePromptTemplate, new KernelArguments() { ["inputText"] = inputText, ["existingFacts"] = this._userFacts }, cancellationToken: cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs b/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs index 594fde5e8605..fa98021dbba9 100644 --- a/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs +++ b/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs @@ -36,7 +36,7 @@ public class VectorDataTextMemoryStore : TextMemoryStore, IDisposable /// The service to use for generating embeddings for the memories. /// The name of the collection in the vector store to store and read the memories from. /// The namespace to scope memories to within the collection. - /// The number of dimentions to use for the memory embeddings. + /// The number of dimensions to use for the memory embeddings. /// Thrown if the key type provided is not supported. public VectorDataTextMemoryStore(IVectorStore vectorStore, ITextEmbeddingGenerationService textEmbeddingGenerationService, string collectionName, string storageNamespace, int vectorDimensions) { diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs index 1266bf1ebb8b..77e88dfb9cdd 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs @@ -31,7 +31,7 @@ public abstract class ConversationStateExtension /// /// /// Implementers can use this method to do any operations required at the creation of a new thread. - /// For exmple, checking long term storage for any data that is relevant to the current session based on the input text. + /// For example, checking long term storage for any data that is relevant to the current session based on the input text. /// /// The ID of the new thread, if the thread has an ID. /// The to monitor for cancellation requests. The default is . @@ -60,7 +60,7 @@ public virtual Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken /// /// /// Implementers can use this method to do any operations required before a thread is deleted. - /// For exmple, storing the context to long term storage. + /// For example, storing the context to long term storage. /// /// The ID of the thread that will be deleted, if the thread has an ID. /// The to monitor for cancellation requests. The default is . From 8ad185c3f504abb76c70689d6354eba23d890310 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 17 Apr 2025 14:27:44 +0100 Subject: [PATCH 22/46] Fix typos --- docs/decisions/00NN-agents-with-memory.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/decisions/00NN-agents-with-memory.md b/docs/decisions/00NN-agents-with-memory.md index bc659d03b8b4..dc08e9a338ef 100644 --- a/docs/decisions/00NN-agents-with-memory.md +++ b/docs/decisions/00NN-agents-with-memory.md @@ -17,7 +17,7 @@ a conversation and re-use those later in the same conversation or later in a sub ## Context and Problem Statement -Today we support multiple agent types with different characteristcs: +Today we support multiple agent types with different characteristics: 1. In process vs remote. 1. Remote agents that store and maintain conversation state in the service vs those that require the caller to provide conversation state on each invocation. @@ -47,13 +47,13 @@ This is achieved via a simple mechanism of: 1. Passing additional context to the agent per invocation. With our current `AgentThread` implementation, when an agent is invoked, all input and output messages are already passed to the `AgentThread` -and can be made availble to any components attached to the `AgentThread`. +and can be made available to any components attached to the `AgentThread`. Where agents are remote/external and manage conversation state in the service, passing the messages to the `AgentThread` may not have any affect on the thread in the service. This is OK, since the service will have already updated the thread during the remote invocation. It does however, still allow us to subscribe to messages in any attached components. -For the second requirement of getting addtional context per invocation, the agent may ask the thread passed to it, to in turn ask -each of the components attached to it, to provide context ot pass to the Agent. +For the second requirement of getting additional context per invocation, the agent may ask the thread passed to it, to in turn ask +each of the components attached to it, to provide context to pass to the Agent. This enables the component to provide memories that it contains to the Agent as needed. Different memory capabilities can be built using separate components. Each component would have the following characteristics: From 9e488a98af9e3158eb8ba2fe7753ce7e7441e10d Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 17 Apr 2025 15:22:53 +0100 Subject: [PATCH 23/46] Fix experimental flags Move mem0 component to core. Make mem0 component trimming compatible. --- dotnet/Directory.Packages.props | 2 +- dotnet/src/Agents/Abstractions/AgentThread.cs | 18 ++++++------- dotnet/src/Agents/Core/ChatCompletionAgent.cs | 8 +++--- .../src/Agents/OpenAI/OpenAIAssistantAgent.cs | 8 +++--- .../ChatCompletionAgentWithMemoryTests.cs | 2 +- .../Memory/ConversationStateExtension.cs | 2 +- .../ConversationStateExtensionsManager.cs | 2 +- ...rsationStateExtensionsManagerExtensions.cs | 2 +- ...rsationStateExtensionsManagerExtensions.cs | 2 +- .../Memory/Mem0/Mem0Client.cs | 27 ++++++++++++------- .../Memory/Mem0/Mem0MemoryComponent.cs | 12 +++++++++ .../Memory/Mem0/Mem0MemoryComponentOptions.cs | 0 .../SemanticKernel.Core.csproj | 1 + 13 files changed, 54 insertions(+), 32 deletions(-) rename dotnet/src/{Memory => SemanticKernel.Core}/Memory/Mem0/Mem0Client.cs (86%) rename dotnet/src/{Memory => SemanticKernel.Core}/Memory/Mem0/Mem0MemoryComponent.cs (87%) rename dotnet/src/{Memory => SemanticKernel.Core}/Memory/Mem0/Mem0MemoryComponentOptions.cs (100%) diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 6daf4e67a673..5f00aee75a81 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -109,7 +109,7 @@ - + diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index 46d8ac33c4ea..3ec711d428c2 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -30,7 +30,7 @@ public abstract class AgentThread /// /// Gets or sets the container for conversation state extension components that manages their lifecycle and interactions. /// - [Experimental("SKEXP0130")] + [Experimental("SKEXP0110")] public virtual ConversationStateExtensionsManager ThreadExtensionsManager { get; init; } = new ConversationStateExtensionsManager(); /// @@ -42,7 +42,7 @@ public abstract class AgentThread /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the end of each service call. /// In a client application, this might be when the user closes the chat window or the application. /// - [Experimental("SKEXP0130")] + [Experimental("SKEXP0110")] public virtual Task OnSuspendAsync(CancellationToken cancellationToken = default) { return this.ThreadExtensionsManager.OnSuspendAsync(this.Id, cancellationToken); @@ -57,7 +57,7 @@ public virtual Task OnSuspendAsync(CancellationToken cancellationToken = default /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the start of each service call where a previous conversation is being continued. /// In a client application, this might be when the user re-opens the chat window to resume a conversation after having previously closed it. /// - [Experimental("SKEXP0130")] + [Experimental("SKEXP0110")] public virtual Task OnResumeAsync(CancellationToken cancellationToken = default) { if (this.IsDeleted) @@ -93,9 +93,9 @@ protected internal virtual async Task CreateAsync(CancellationToken cancellation this.Id = await this.CreateInternalAsync(cancellationToken: cancellationToken).ConfigureAwait(false); -#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.ThreadExtensionsManager.OnThreadCreatedAsync(this.Id!, cancellationToken).ConfigureAwait(false); -#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. } /// @@ -116,9 +116,9 @@ public virtual async Task DeleteAsync(CancellationToken cancellationToken = defa throw new InvalidOperationException("This thread cannot be deleted, since it has not been created."); } -#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.ThreadExtensionsManager.OnThreadDeleteAsync(this.Id!, cancellationToken).ConfigureAwait(false); -#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.DeleteInternalAsync(cancellationToken).ConfigureAwait(false); @@ -147,9 +147,9 @@ internal virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, Can await this.CreateAsync(cancellationToken).ConfigureAwait(false); } -#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.ThreadExtensionsManager.OnNewMessageAsync(newMessage, cancellationToken).ConfigureAwait(false); -#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.OnNewMessageInternalAsync(newMessage, cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 8fa2461c4bdc..97e160299c1a 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -77,10 +77,10 @@ public override async IAsyncEnumerable> In var kernel = (options?.Kernel ?? this.Kernel).Clone(); // Get the conversation state extensions context contributions and register plugins from the extensions. -#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var extensionsContext = await chatHistoryAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); chatHistoryAgentThread.ThreadExtensionsManager.RegisterPlugins(kernel); -#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Invoke Chat Completion with the updated chat history. var chatHistory = new ChatHistory(); @@ -168,10 +168,10 @@ public override async IAsyncEnumerable> InvokeAsync var kernel = (options?.Kernel ?? this.Kernel).Clone(); // Get the conversation state extensions context contributions and register plugins from the extensions. -#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); openAIAssistantAgentThread.ThreadExtensionsManager.RegisterPlugins(kernel); -#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var invokeResults = ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(this.Id, this.GetDisplayName(), this.Description), @@ -559,10 +559,10 @@ public async IAsyncEnumerable> In var kernel = (options?.Kernel ?? this.Kernel).Clone(); // Get the conversation state extensions context contributions and register plugins from the extensions. -#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); openAIAssistantAgentThread.ThreadExtensionsManager.RegisterPlugins(kernel); -#pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Create options that use the RunCreationOptions from the options param if provided or // falls back to creating a new RunCreationOptions if additional instructions is provided diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index 23ee3f5fd2ab..08cbd30d192e 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -284,7 +284,7 @@ public virtual async Task RagComponentWithMatchesOnDemandAsync() agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); // Act - Invoke the agent with a question - var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread, new() { KernelArguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })}); + var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread, new() { KernelArguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }) }); var results1 = await asyncResults1.ToListAsync(); // Assert - Check if the response contains the expected value from the database. diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs index 77e88dfb9cdd..0568efcc4210 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs @@ -17,7 +17,7 @@ namespace Microsoft.SemanticKernel; /// to a conversation, listen to changes in the conversation state, and provide additional context to /// the AI model in use just before invocation. /// -[Experimental("SKEXP0130")] +[Experimental("SKEXP0001")] public abstract class ConversationStateExtension { /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs index efe53bc4e6ad..0df352fa4ba5 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs @@ -14,7 +14,7 @@ namespace Microsoft.SemanticKernel; /// /// A container class for objects that manages their lifecycle and interactions. /// -[Experimental("SKEXP0130")] +[Experimental("SKEXP0001")] public class ConversationStateExtensionsManager { private readonly List _conversationStateExtensions = new(); diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs index bd282a9bf963..2663b082c82d 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs @@ -12,7 +12,7 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods for . /// -[Experimental("SKEXP0130")] +[Experimental("SKEXP0001")] public static class ConversationStateExtensionsManagerExtensions { /// diff --git a/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs b/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs index ceb2e37f8a27..d6b148d7dcec 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods for . /// -[Experimental("SKEXP0130")] +[Experimental("SKEXP0001")] public static class ConversationStateExtensionsManagerExtensions { /// diff --git a/dotnet/src/Memory/Memory/Mem0/Mem0Client.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs similarity index 86% rename from dotnet/src/Memory/Memory/Mem0/Mem0Client.cs rename to dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs index 1fc9990bbb77..096f293e409d 100644 --- a/dotnet/src/Memory/Memory/Mem0/Mem0Client.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs @@ -48,13 +48,13 @@ public async Task> SearchAsync(string? applicationId, string }; // Search. - using var content = new StringContent(JsonSerializer.Serialize(searchRequest), Encoding.UTF8, "application/json"); + using var content = new StringContent(JsonSerializer.Serialize(searchRequest, Mem0SourceGenerationContext.Default.SearchRequest), Encoding.UTF8, "application/json"); var responseMessage = await this._httpClient.PostAsync(s_searchUri, content).ConfigureAwait(false); responseMessage.EnsureSuccessStatusCode(); // Process response. var response = await responseMessage.Content.ReadAsStringAsync().ConfigureAwait(false); - var searchResponseItems = JsonSerializer.Deserialize(response); + var searchResponseItems = JsonSerializer.Deserialize(response, Mem0SourceGenerationContext.Default.SearchResponseItemArray); return searchResponseItems?.Select(item => item.Memory) ?? []; } @@ -84,7 +84,7 @@ public async Task CreateMemoryAsync(string? applicationId, string? agentId, stri } }; - using var content = new StringContent(JsonSerializer.Serialize(createMemoryRequest), Encoding.UTF8, "application/json"); + using var content = new StringContent(JsonSerializer.Serialize(createMemoryRequest, Mem0SourceGenerationContext.Default.CreateMemoryRequest), Encoding.UTF8, "application/json"); var responseMessage = await this._httpClient.PostAsync(s_createMemoryUri, content).ConfigureAwait(false); responseMessage.EnsureSuccessStatusCode(); } @@ -105,7 +105,7 @@ public async Task ClearMemoryAsync(string? applicationId, string? agentId, strin responseMessage.EnsureSuccessStatusCode(); } - private sealed class CreateMemoryRequest + internal sealed class CreateMemoryRequest { [JsonPropertyName("app_id")] public string? AppId { get; set; } @@ -119,7 +119,7 @@ private sealed class CreateMemoryRequest public CreateMemoryMemory[] Messages { get; set; } = []; } - private sealed class CreateMemoryMemory + internal sealed class CreateMemoryMemory { [JsonPropertyName("content")] public string Content { get; set; } = string.Empty; @@ -127,7 +127,7 @@ private sealed class CreateMemoryMemory public string Role { get; set; } = string.Empty; } - private sealed class SearchRequest + internal sealed class SearchRequest { [JsonPropertyName("app_id")] public string? AppId { get; set; } @@ -141,8 +141,7 @@ private sealed class SearchRequest public string Query { get; set; } = string.Empty; } -#pragma warning disable CA1812 // Avoid uninstantiated internal classes - private sealed class SearchResponseItem + internal sealed class SearchResponseItem { [JsonPropertyName("id")] public string Id { get; set; } = string.Empty; @@ -167,5 +166,15 @@ private sealed class SearchResponseItem [JsonPropertyName("run_id")] public string RunId { get; set; } = string.Empty; } -#pragma warning restore CA1812 // Avoid uninstantiated internal classes +} + +[JsonSourceGenerationOptions(JsonSerializerDefaults.General, + UseStringEnumConverter = false, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + WriteIndented = false)] +[JsonSerializable(typeof(Mem0Client.CreateMemoryRequest))] +[JsonSerializable(typeof(Mem0Client.SearchRequest))] +[JsonSerializable(typeof(Mem0Client.SearchResponseItem[]))] +internal partial class Mem0SourceGenerationContext : JsonSerializerContext +{ } diff --git a/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs similarity index 87% rename from dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs rename to dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs index da3cc7689b55..a5b804d41c7e 100644 --- a/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net.Http; using System.Threading; @@ -14,6 +15,7 @@ namespace Microsoft.SemanticKernel.Memory; /// A component that listens to messages added to the conversation thread, and automatically captures /// information about the user. It is also able to retrieve this information and add it to the AI invocation context. /// +[Experimental("SKEXP0130")] public class Mem0MemoryComponent : ConversationStateExtension { private readonly string? _applicationId; @@ -31,6 +33,16 @@ public class Mem0MemoryComponent : ConversationStateExtension /// /// The HTTP client used for making requests. /// Options for configuring the component. + /// + /// The base address of the required mem0 service and any authentication headers should be set on the + /// already when provided here. E.g.: + /// + /// using var httpClient = new HttpClient(); + /// httpClient.BaseAddress = new Uri("https://api.mem0.ai"); + /// httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", "<Your APIKey>"); + /// new Mem0Client(httpClient); + /// + /// public Mem0MemoryComponent(HttpClient httpClient, Mem0MemoryComponentOptions? options = default) { Verify.NotNull(httpClient); diff --git a/dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs similarity index 100% rename from dotnet/src/Memory/Memory/Mem0/Mem0MemoryComponentOptions.cs rename to dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs diff --git a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj index 2a5d5d03d961..652dd7181965 100644 --- a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj +++ b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj @@ -34,6 +34,7 @@ + From 61dba2fad2c96aec917dc2c6c43c5922bde8af82 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 17 Apr 2025 15:40:32 +0100 Subject: [PATCH 24/46] Fix warning about preview dependency. --- dotnet/src/Memory/Memory/Memory.csproj | 1 + dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/dotnet/src/Memory/Memory/Memory.csproj b/dotnet/src/Memory/Memory/Memory.csproj index 9654db9ab1d5..4caeec7c0426 100644 --- a/dotnet/src/Memory/Memory/Memory.csproj +++ b/dotnet/src/Memory/Memory/Memory.csproj @@ -6,6 +6,7 @@ net8.0;netstandard2.0 false alpha + $(NoWarn);NU5104 diff --git a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj index 652dd7181965..6539a03d547a 100644 --- a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj +++ b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj @@ -6,7 +6,7 @@ Microsoft.SemanticKernel net8.0;netstandard2.0 true - $(NoWarn);SKEXP0001,SKEXP0120 + $(NoWarn);NU5104;SKEXP0001,SKEXP0120 true true From 018c821851d3bb5513bd4f653d45ad32f50a8541 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 17 Apr 2025 15:58:48 +0100 Subject: [PATCH 25/46] Rename ThreadExtensionsManager to StateExtensions --- dotnet/src/Agents/Abstractions/AgentThread.cs | 12 +++++------ .../Abstractions/Agents.Abstractions.csproj | 1 - dotnet/src/Agents/Core/ChatCompletionAgent.cs | 8 ++++---- .../src/Agents/OpenAI/OpenAIAssistantAgent.cs | 8 ++++---- .../ChatCompletionAgentWithMemoryTests.cs | 20 +++++++++---------- .../OpenAIAssistantAgentWithMemoryTests.cs.cs | 4 ++-- 6 files changed, 26 insertions(+), 27 deletions(-) diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index 3ec711d428c2..e5dfa1275657 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -31,7 +31,7 @@ public abstract class AgentThread /// Gets or sets the container for conversation state extension components that manages their lifecycle and interactions. /// [Experimental("SKEXP0110")] - public virtual ConversationStateExtensionsManager ThreadExtensionsManager { get; init; } = new ConversationStateExtensionsManager(); + public virtual ConversationStateExtensionsManager StateExtensions { get; init; } = new ConversationStateExtensionsManager(); /// /// Called when the current conversion is temporarily suspended and any state should be saved. @@ -45,7 +45,7 @@ public abstract class AgentThread [Experimental("SKEXP0110")] public virtual Task OnSuspendAsync(CancellationToken cancellationToken = default) { - return this.ThreadExtensionsManager.OnSuspendAsync(this.Id, cancellationToken); + return this.StateExtensions.OnSuspendAsync(this.Id, cancellationToken); } /// @@ -70,7 +70,7 @@ public virtual Task OnResumeAsync(CancellationToken cancellationToken = default) throw new InvalidOperationException("This thread cannot be resumed, since it has not been created."); } - return this.ThreadExtensionsManager.OnSuspendAsync(this.Id, cancellationToken); + return this.StateExtensions.OnSuspendAsync(this.Id, cancellationToken); } /// @@ -94,7 +94,7 @@ protected internal virtual async Task CreateAsync(CancellationToken cancellation this.Id = await this.CreateInternalAsync(cancellationToken: cancellationToken).ConfigureAwait(false); #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - await this.ThreadExtensionsManager.OnThreadCreatedAsync(this.Id!, cancellationToken).ConfigureAwait(false); + await this.StateExtensions.OnThreadCreatedAsync(this.Id!, cancellationToken).ConfigureAwait(false); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. } @@ -117,7 +117,7 @@ public virtual async Task DeleteAsync(CancellationToken cancellationToken = defa } #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - await this.ThreadExtensionsManager.OnThreadDeleteAsync(this.Id!, cancellationToken).ConfigureAwait(false); + await this.StateExtensions.OnThreadDeleteAsync(this.Id!, cancellationToken).ConfigureAwait(false); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.DeleteInternalAsync(cancellationToken).ConfigureAwait(false); @@ -148,7 +148,7 @@ internal virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, Can } #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - await this.ThreadExtensionsManager.OnNewMessageAsync(newMessage, cancellationToken).ConfigureAwait(false); + await this.StateExtensions.OnNewMessageAsync(newMessage, cancellationToken).ConfigureAwait(false); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.OnNewMessageInternalAsync(newMessage, cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/Agents/Abstractions/Agents.Abstractions.csproj b/dotnet/src/Agents/Abstractions/Agents.Abstractions.csproj index 0ed1395cb54f..0c393b166bc2 100644 --- a/dotnet/src/Agents/Abstractions/Agents.Abstractions.csproj +++ b/dotnet/src/Agents/Abstractions/Agents.Abstractions.csproj @@ -32,7 +32,6 @@ - diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 97e160299c1a..f057f43333b6 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -78,8 +78,8 @@ public override async IAsyncEnumerable> In // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await chatHistoryAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); - chatHistoryAgentThread.ThreadExtensionsManager.RegisterPlugins(kernel); + var extensionsContext = await chatHistoryAgentThread.StateExtensions.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + chatHistoryAgentThread.StateExtensions.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Invoke Chat Completion with the updated chat history. @@ -169,8 +169,8 @@ public override async IAsyncEnumerable> InvokeAsync // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); - openAIAssistantAgentThread.ThreadExtensionsManager.RegisterPlugins(kernel); + var extensionsContext = await openAIAssistantAgentThread.StateExtensions.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + openAIAssistantAgentThread.StateExtensions.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var invokeResults = ActivityExtensions.RunWithActivityAsync( @@ -560,8 +560,8 @@ public async IAsyncEnumerable> In // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await openAIAssistantAgentThread.ThreadExtensionsManager.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); - openAIAssistantAgentThread.ThreadExtensionsManager.RegisterPlugins(kernel); + var extensionsContext = await openAIAssistantAgentThread.StateExtensions.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + openAIAssistantAgentThread.StateExtensions.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Create options that use the RunCreationOptions from the options param if provided or diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index 08cbd30d192e..a25f49c25f30 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -47,10 +47,10 @@ public virtual async Task Mem0ComponentCapturesMemoriesFromUserInputAsync() var mem0Component = new Mem0MemoryComponent(httpClient, new() { UserId = "U1" }); var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.ThreadExtensionsManager.RegisterThreadExtension(mem0Component); + agentThread1.StateExtensions.RegisterThreadExtension(mem0Component); var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.ThreadExtensionsManager.RegisterThreadExtension(mem0Component); + agentThread2.StateExtensions.RegisterThreadExtension(mem0Component); // Act var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); @@ -75,10 +75,10 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() var memoryComponent = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel); var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.ThreadExtensionsManager.RegisterThreadExtension(memoryComponent); + agentThread1.StateExtensions.RegisterThreadExtension(memoryComponent); var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.ThreadExtensionsManager.RegisterThreadExtension(memoryComponent); + agentThread2.StateExtensions.RegisterThreadExtension(memoryComponent); // Act var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); @@ -109,14 +109,14 @@ public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserIn // Act - First invocation with first thread. var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.ThreadExtensionsManager.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + agentThread1.StateExtensions.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); var results1 = await asyncResults1.ToListAsync(); // Act - Second invocation with second thread. var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.ThreadExtensionsManager.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + agentThread2.StateExtensions.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); var results2 = await asyncResults2.ToListAsync(); @@ -155,7 +155,7 @@ public virtual async Task CapturesMemoriesWhileUsingDIAsync() builder.Services.AddTransient((sp) => { var thread = new ChatHistoryAgentThread(); - thread.ThreadExtensionsManager.RegisterThreadExtensionsFromContainer(sp); + thread.StateExtensions.RegisterThreadExtensionsFromContainer(sp); return thread; }); var host = builder.Build(); @@ -209,7 +209,7 @@ public virtual async Task RagComponentWithoutMatchesAsync() // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); - agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); + agentThread.StateExtensions.RegisterThreadExtension(ragComponent); // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); @@ -242,7 +242,7 @@ public virtual async Task RagComponentWithMatchesAsync() // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); - agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); + agentThread.StateExtensions.RegisterThreadExtension(ragComponent); // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); @@ -281,7 +281,7 @@ public virtual async Task RagComponentWithMatchesOnDemandAsync() // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); - agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); + agentThread.StateExtensions.RegisterThreadExtension(ragComponent); // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread, new() { KernelArguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }) }); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs index 0cb0b238c6ee..2bb46da17cdd 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs @@ -20,10 +20,10 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() var memoryComponent = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel); var agentThread1 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); - agentThread1.ThreadExtensionsManager.RegisterThreadExtension(memoryComponent); + agentThread1.StateExtensions.RegisterThreadExtension(memoryComponent); var agentThread2 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); - agentThread2.ThreadExtensionsManager.RegisterThreadExtension(memoryComponent); + agentThread2.StateExtensions.RegisterThreadExtension(memoryComponent); // Act var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); From 081472c5d3b2c41c34b29fbf9970a8215eeb1f04 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 17 Apr 2025 16:06:27 +0100 Subject: [PATCH 26/46] Update ADR header --- docs/decisions/00NN-agents-with-memory.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/decisions/00NN-agents-with-memory.md b/docs/decisions/00NN-agents-with-memory.md index dc08e9a338ef..bbb022e1b2db 100644 --- a/docs/decisions/00NN-agents-with-memory.md +++ b/docs/decisions/00NN-agents-with-memory.md @@ -1,11 +1,11 @@ --- # These are optional elements. Feel free to remove any of them. -status: {proposed | rejected | accepted | deprecated | … | superseded by [ADR-0001](0001-madr-architecture-decisions.md)} -contact: {person proposing the ADR} -date: {YYYY-MM-DD when the decision was last updated} -deciders: {list everyone involved in the decision} -consulted: {list everyone whose opinions are sought (typically subject-matter experts); and with whom there is a two-way communication} -informed: {list everyone who is kept up-to-date on progress; and with whom there is a one-way communication} +status: accepted +contact: westey-m +date: 2025-04-17 +deciders: westey-m, markwallace-microsoft, alliscode, TaoChenOSU, moonbox3, crickman +consulted: westey-m, markwallace-microsoft, alliscode, TaoChenOSU, moonbox3, crickman +informed: westey-m, markwallace-microsoft, alliscode, TaoChenOSU, moonbox3, crickman --- # Agents with Memory From 3a48d3f003d5b821c2a1f525a37f80fb60202dd1 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 17 Apr 2025 17:09:19 +0100 Subject: [PATCH 27/46] Add unit tests, fix bugs, rename long props. Seal manager --- dotnet/src/Agents/Abstractions/AgentThread.cs | 4 +- .../Agents/UnitTests/Core/AgentThreadTests.cs | 132 +++++++++++++ .../ChatCompletionAgentWithMemoryTests.cs | 20 +- .../OpenAIAssistantAgentWithMemoryTests.cs.cs | 4 +- .../ConversationStateExtensionsManager.cs | 64 +++---- .../Memory/ConversationStateExtensionTests.cs | 81 ++++++++ ...onStateExtensionsManagerExtensionsTests.cs | 82 ++++++++ ...ConversationStateExtensionsManagerTests.cs | 176 ++++++++++++++++++ 8 files changed, 517 insertions(+), 46 deletions(-) create mode 100644 dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs create mode 100644 dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs create mode 100644 dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index e5dfa1275657..330d26d5ddb3 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -65,12 +65,12 @@ public virtual Task OnResumeAsync(CancellationToken cancellationToken = default) throw new InvalidOperationException("This thread has been deleted and cannot be used anymore."); } - if (this.Id is not null) + if (this.Id is null) { throw new InvalidOperationException("This thread cannot be resumed, since it has not been created."); } - return this.StateExtensions.OnSuspendAsync(this.Id, cancellationToken); + return this.StateExtensions.OnResumeAsync(this.Id, cancellationToken); } /// diff --git a/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs b/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs index c8e0c1884a87..4a1fb2db791a 100644 --- a/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs +++ b/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs @@ -3,8 +3,11 @@ using System; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.ChatCompletion; +using Moq; using Xunit; namespace SemanticKernel.Agents.UnitTests.Core; @@ -141,6 +144,135 @@ public async Task OnNewMessageShouldThrowIfThreadDeletedAsync() Assert.Equal(0, thread.OnNewMessageInternalAsyncCount); } + /// + /// Tests that the method throws an InvalidOperationException if the thread is not yet created. + /// + [Fact] + public async Task OnResumeShouldThrowIfThreadNotCreatedAsync() + { + // Arrange + var thread = new TestAgentThread(); + + // Act & Assert + await Assert.ThrowsAsync(() => thread.OnResumeAsync()); + } + + /// + /// Tests that the method throws an InvalidOperationException if the thread is deleted. + /// + [Fact] + public async Task OnResumeShouldThrowIfThreadDeletedAsync() + { + // Arrange + var thread = new TestAgentThread(); + await thread.CreateAsync(); + await thread.DeleteAsync(); + + // Act & Assert + await Assert.ThrowsAsync(() => thread.OnResumeAsync()); + } + + /// + /// Tests that the method + /// calls each registered extension in turn. + /// + [Fact] + public async Task OnSuspendShouldCallOnSuspendOnRegisteredExtensionsAsync() + { + // Arrange. + var thread = new TestAgentThread(); + var mockExtension = new Mock(); + thread.StateExtensions.Add(mockExtension.Object); + await thread.CreateAsync(); + + // Act. + await thread.OnSuspendAsync(); + + // Assert. + mockExtension.Verify(x => x.OnSuspendAsync("test-thread-id", It.IsAny()), Times.Once); + } + + /// + /// Tests that the method + /// calls each registered extension in turn. + /// + [Fact] + public async Task OnResumeShouldCallOnResumeOnRegisteredExtensionsAsync() + { + // Arrange. + var thread = new TestAgentThread(); + var mockExtension = new Mock(); + thread.StateExtensions.Add(mockExtension.Object); + await thread.CreateAsync(); + + // Act. + await thread.OnResumeAsync(); + + // Assert. + mockExtension.Verify(x => x.OnResumeAsync("test-thread-id", It.IsAny()), Times.Once); + } + + /// + /// Tests that the method + /// calls each registered extension in turn. + /// + [Fact] + public async Task CreateShouldCallOnThreadCreatedOnRegisteredExtensionsAsync() + { + // Arrange. + var thread = new TestAgentThread(); + var mockExtension = new Mock(); + thread.StateExtensions.Add(mockExtension.Object); + + // Act. + await thread.CreateAsync(); + + // Assert. + mockExtension.Verify(x => x.OnThreadCreatedAsync("test-thread-id", It.IsAny()), Times.Once); + } + + /// + /// Tests that the method + /// calls each registered extension in turn. + /// + [Fact] + public async Task DeleteShouldCallOnThreadDeleteOnRegisteredExtensionsAsync() + { + // Arrange. + var thread = new TestAgentThread(); + var mockExtension = new Mock(); + thread.StateExtensions.Add(mockExtension.Object); + await thread.CreateAsync(); + + // Act. + await thread.DeleteAsync(); + + // Assert. + mockExtension.Verify(x => x.OnThreadDeleteAsync("test-thread-id", It.IsAny()), Times.Once); + } + + /// + /// Tests that the method + /// calls each registered extension in turn. + /// + [Fact] + public async Task OnNewMessageShouldCallOnNewMessageOnRegisteredExtensionsAsync() + { + // Arrange. + var thread = new TestAgentThread(); + var mockExtension = new Mock(); + thread.StateExtensions.Add(mockExtension.Object); + var message = new ChatMessageContent(AuthorRole.User, "Test Message."); + + await thread.CreateAsync(); + + // Act. + await thread.OnNewMessageAsync(message); + + // Assert. + mockExtension.Verify(x => x.OnNewMessageAsync(It.Is(x => x.Text == "Test Message." && x.Role == ChatRole.User), It.IsAny()), Times.Once); + } + private sealed class TestAgentThread : AgentThread { public int CreateInternalAsyncCount { get; private set; } diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index a25f49c25f30..ae18168aa22e 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -47,10 +47,10 @@ public virtual async Task Mem0ComponentCapturesMemoriesFromUserInputAsync() var mem0Component = new Mem0MemoryComponent(httpClient, new() { UserId = "U1" }); var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.StateExtensions.RegisterThreadExtension(mem0Component); + agentThread1.StateExtensions.Add(mem0Component); var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.StateExtensions.RegisterThreadExtension(mem0Component); + agentThread2.StateExtensions.Add(mem0Component); // Act var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); @@ -75,10 +75,10 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() var memoryComponent = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel); var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.StateExtensions.RegisterThreadExtension(memoryComponent); + agentThread1.StateExtensions.Add(memoryComponent); var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.StateExtensions.RegisterThreadExtension(memoryComponent); + agentThread2.StateExtensions.Add(memoryComponent); // Act var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); @@ -109,14 +109,14 @@ public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserIn // Act - First invocation with first thread. var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.StateExtensions.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + agentThread1.StateExtensions.Add(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); var results1 = await asyncResults1.ToListAsync(); // Act - Second invocation with second thread. var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.StateExtensions.RegisterThreadExtension(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + agentThread2.StateExtensions.Add(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); var results2 = await asyncResults2.ToListAsync(); @@ -155,7 +155,7 @@ public virtual async Task CapturesMemoriesWhileUsingDIAsync() builder.Services.AddTransient((sp) => { var thread = new ChatHistoryAgentThread(); - thread.StateExtensions.RegisterThreadExtensionsFromContainer(sp); + thread.StateExtensions.AddFromServiceProvider(sp); return thread; }); var host = builder.Build(); @@ -209,7 +209,7 @@ public virtual async Task RagComponentWithoutMatchesAsync() // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); - agentThread.StateExtensions.RegisterThreadExtension(ragComponent); + agentThread.StateExtensions.Add(ragComponent); // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); @@ -242,7 +242,7 @@ public virtual async Task RagComponentWithMatchesAsync() // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); - agentThread.StateExtensions.RegisterThreadExtension(ragComponent); + agentThread.StateExtensions.Add(ragComponent); // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); @@ -281,7 +281,7 @@ public virtual async Task RagComponentWithMatchesOnDemandAsync() // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); - agentThread.StateExtensions.RegisterThreadExtension(ragComponent); + agentThread.StateExtensions.Add(ragComponent); // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread, new() { KernelArguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }) }); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs index 2bb46da17cdd..54ac5c508ee3 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs @@ -20,10 +20,10 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() var memoryComponent = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel); var agentThread1 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); - agentThread1.StateExtensions.RegisterThreadExtension(memoryComponent); + agentThread1.StateExtensions.Add(memoryComponent); var agentThread2 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); - agentThread2.StateExtensions.RegisterThreadExtension(memoryComponent); + agentThread2.StateExtensions.Add(memoryComponent); // Act var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs index 0df352fa4ba5..5d88420f2ac2 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs @@ -15,16 +15,16 @@ namespace Microsoft.SemanticKernel; /// A container class for objects that manages their lifecycle and interactions. /// [Experimental("SKEXP0001")] -public class ConversationStateExtensionsManager +public sealed class ConversationStateExtensionsManager { - private readonly List _conversationStateExtensions = new(); + private readonly List _extensions = new(); private List? _currentAIFunctions = null; /// /// Gets the list of registered conversation state extensions. /// - public virtual IReadOnlyList ConversationStateExtensions => this._conversationStateExtensions; + public IReadOnlyList Extensions => this._extensions; /// /// Initializes a new instance of the class. @@ -33,17 +33,26 @@ public ConversationStateExtensionsManager() { } + /// + /// Initializes a new instance of the class with the specified conversation state extensions. + /// + /// The conversation state extensions to add to the manager. + public ConversationStateExtensionsManager(IEnumerable conversationtStateExtensions) + { + this._extensions.AddRange(conversationtStateExtensions); + } + /// /// Gets the list of AI functions that all contained extension component expose /// and which should be used by the consuming AI when using these components. /// - public virtual IReadOnlyCollection AIFunctions + public IReadOnlyCollection AIFunctions { get { if (this._currentAIFunctions == null) { - this._currentAIFunctions = this.ConversationStateExtensions.SelectMany(ConversationStateExtensions => ConversationStateExtensions.AIFunctions).ToList(); + this._currentAIFunctions = this.Extensions.SelectMany(ConversationStateExtensions => ConversationStateExtensions.AIFunctions).ToList(); } return this._currentAIFunctions; @@ -51,33 +60,24 @@ public virtual IReadOnlyCollection AIFunctions } /// - /// Initializes a new instance of the class with the specified conversation state extensions. - /// - /// The conversation state extensions to add to the manager. - public ConversationStateExtensionsManager(IEnumerable conversationtStateExtensions) - { - this._conversationStateExtensions.AddRange(conversationtStateExtensions); - } - - /// - /// Registers a new conversation state extension. + /// Adds a new conversation state extension. /// /// The conversation state extension to register. - public virtual void RegisterThreadExtension(ConversationStateExtension conversationtStateExtension) + public void Add(ConversationStateExtension conversationtStateExtension) { - this._conversationStateExtensions.Add(conversationtStateExtension); + this._extensions.Add(conversationtStateExtension); this._currentAIFunctions = null; } /// - /// Registers all conversation state extensions registered on the provided dependency injection service provider. + /// Adds all conversation state extensions registered on the provided dependency injection service provider. /// /// The dependency injection service provider to read conversation state extensions from. - public virtual void RegisterThreadExtensionsFromContainer(IServiceProvider serviceProvider) + public void AddFromServiceProvider(IServiceProvider serviceProvider) { foreach (var extension in serviceProvider.GetServices()) { - this.RegisterThreadExtension(extension); + this.Add(extension); } this._currentAIFunctions = null; } @@ -88,9 +88,9 @@ public virtual void RegisterThreadExtensionsFromContainer(IServiceProvider servi /// The ID of the new thread. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. - public virtual async Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) + public async Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnThreadCreatedAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Extensions.Select(x => x.OnThreadCreatedAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -99,9 +99,9 @@ public virtual async Task OnThreadCreatedAsync(string? threadId, CancellationTok /// The id of the thread that will be deleted. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. - public virtual async Task OnThreadDeleteAsync(string threadId, CancellationToken cancellationToken = default) + public async Task OnThreadDeleteAsync(string threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnThreadDeleteAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Extensions.Select(x => x.OnThreadDeleteAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -110,9 +110,9 @@ public virtual async Task OnThreadDeleteAsync(string threadId, CancellationToken /// The new message. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. - public virtual async Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default) + public async Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnNewMessageAsync(newMessage, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Extensions.Select(x => x.OnNewMessageAsync(newMessage, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -121,9 +121,9 @@ public virtual async Task OnNewMessageAsync(ChatMessage newMessage, Cancellation /// The most recent messages that the AI is being invoked with. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation, containing the combined context from all conversation state extensions. - public virtual async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + public async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) { - var subContexts = await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnAIInvocationAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); + var subContexts = await Task.WhenAll(this.Extensions.Select(x => x.OnAIInvocationAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); return string.Join("\n", subContexts); } @@ -137,9 +137,9 @@ public virtual async Task OnAIInvocationAsync(ICollection n /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the end of each service call. /// In a client application, this might be when the user closes the chat window or the application. /// - public virtual async Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default) + public async Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnSuspendAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Extensions.Select(x => x.OnSuspendAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -152,8 +152,8 @@ public virtual async Task OnSuspendAsync(string? threadId, CancellationToken can /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the start of each service call where a previous conversation is being continued. /// In a client application, this might be when the user re-opens the chat window to resume a conversation after having previously closed it. /// - public virtual async Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default) + public async Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.ConversationStateExtensions.Select(x => x.OnResumeAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Extensions.Select(x => x.OnResumeAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs new file mode 100644 index 000000000000..152716ef3df2 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Contains tests for the class. +/// +public class ConversationStateExtensionTests +{ + [Fact] + public void AIFunctionsBaseImplementationIsEmpty() + { + // Arrange. + var mockExtension = new Mock() { CallBase = true }; + + // Act. + var functions = mockExtension.Object.AIFunctions; + + // Assert. + Assert.NotNull(functions); + Assert.Empty(functions); + } + + [Fact] + public async Task OnThreadCreatedBaseImplementationSucceeds() + { + // Arrange. + var mockExtension = new Mock() { CallBase = true }; + + // Act & Assert. + await mockExtension.Object.OnThreadCreatedAsync("threadId", CancellationToken.None); + } + + [Fact] + public async Task OnNewMessageBaseImplementationSucceeds() + { + // Arrange. + var mockExtension = new Mock() { CallBase = true }; + var newMessage = new ChatMessage(ChatRole.User, "Hello"); + + // Act & Assert. + await mockExtension.Object.OnNewMessageAsync(newMessage, CancellationToken.None); + } + + [Fact] + public async Task OnThreadDeleteBaseImplementationSucceeds() + { + // Arrange. + var mockExtension = new Mock() { CallBase = true }; + + // Act & Assert. + await mockExtension.Object.OnThreadDeleteAsync("threadId", CancellationToken.None); + } + + [Fact] + public async Task OnSuspendBaseImplementationSucceeds() + { + // Arrange. + var mockExtension = new Mock() { CallBase = true }; + + // Act & Assert. + await mockExtension.Object.OnSuspendAsync("threadId", CancellationToken.None); + } + + [Fact] + public async Task OnResumeBaseImplementationSucceeds() + { + // Arrange. + var mockExtension = new Mock() { CallBase = true }; + + // Act & Assert. + await mockExtension.Object.OnResumeAsync("threadId", CancellationToken.None); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs new file mode 100644 index 000000000000..81690f7314c0 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Tests for the ConversationStateExtensionsManagerExtensions class. +/// +public class ConversationStateExtensionsManagerExtensionsTests +{ + [Fact] + public async Task OnNewMessageShouldConvertMessageAndInvokeRegisteredExtensionsAsync() + { + // Arrange + var manager = new ConversationStateExtensionsManager(); + var extensionMock = new Mock(); + manager.Add(extensionMock.Object); + + var newMessage = new ChatMessageContent(AuthorRole.User, "Test Message"); + + // Act + await manager.OnNewMessageAsync(newMessage); + + // Assert + extensionMock.Verify(x => x.OnNewMessageAsync(It.Is(m => m.Text == "Test Message" && m.Role == ChatRole.User), It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnAIInvocationShouldConvertMessagesInvokeRegisteredExtensionsAsync() + { + // Arrange + var manager = new ConversationStateExtensionsManager(); + var extensionMock = new Mock(); + manager.Add(extensionMock.Object); + + var messages = new List + { + new(AuthorRole.User, "Message 1"), + new(AuthorRole.Assistant, "Message 2") + }; + + extensionMock + .Setup(x => x.OnAIInvocationAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync("Combined Context"); + + // Act + var result = await manager.OnAIInvocationAsync(messages); + + // Assert + Assert.Equal("Combined Context", result); + extensionMock.Verify(x => x.OnAIInvocationAsync(It.Is>(m => m.Count == 2), It.IsAny()), Times.Once); + } + + [Fact] + public void RegisterPluginsShouldConvertAIFunctionsAndRegisterAsPlugins() + { + // Arrange + var kernel = new Kernel(); + var manager = new ConversationStateExtensionsManager(); + var extensionMock = new Mock(); + var aiFunctionMock = AIFunctionFactory.Create(() => "Hello", "TestFunction"); + extensionMock + .Setup(x => x.AIFunctions) + .Returns(new List { aiFunctionMock }); + manager.Add(extensionMock.Object); + + // Act + manager.RegisterPlugins(kernel); + + // Assert + var registeredFunction = kernel.Plugins.GetFunction("Tools", aiFunctionMock.Name); + Assert.NotNull(registeredFunction); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs new file mode 100644 index 000000000000..9036496b4df1 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Contains tests for the class. +/// +public class ConversationStateExtensionsManagerTests +{ + [Fact] + public void ConstructorShouldInitializeEmptyExtensionsList() + { + // Act + var manager = new ConversationStateExtensionsManager(); + + // Assert + Assert.NotNull(manager.Extensions); + Assert.Empty(manager.Extensions); + } + + [Fact] + public void ConstructorShouldInitializeWithProvidedExtensions() + { + // Arrange + var mockExtension = new Mock(); + + // Act + var manager = new ConversationStateExtensionsManager(new[] { mockExtension.Object }); + + // Assert + Assert.Single(manager.Extensions); + Assert.Contains(mockExtension.Object, manager.Extensions); + } + + [Fact] + public void AddShouldRegisterNewExtension() + { + // Arrange + var manager = new ConversationStateExtensionsManager(); + var mockExtension = new Mock(); + + // Act + manager.Add(mockExtension.Object); + + // Assert + Assert.Single(manager.Extensions); + Assert.Contains(mockExtension.Object, manager.Extensions); + } + + [Fact] + public void AddFromServiceProviderShouldRegisterExtensionsFromServiceProvider() + { + // Arrange + var serviceCollection = new ServiceCollection(); + var mockExtension = new Mock(); + serviceCollection.AddSingleton(mockExtension.Object); + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var manager = new ConversationStateExtensionsManager(); + + // Act + manager.AddFromServiceProvider(serviceProvider); + + // Assert + Assert.Single(manager.Extensions); + Assert.Contains(mockExtension.Object, manager.Extensions); + } + + [Fact] + public async Task OnThreadCreatedAsyncShouldCallOnThreadCreatedOnAllExtensions() + { + // Arrange + var manager = new ConversationStateExtensionsManager(); + var mockExtension = new Mock(); + manager.Add(mockExtension.Object); + + // Act + await manager.OnThreadCreatedAsync("test-thread-id"); + + // Assert + mockExtension.Verify(x => x.OnThreadCreatedAsync("test-thread-id", It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnThreadDeleteAsyncShouldCallOnThreadDeleteOnAllExtensions() + { + // Arrange + var manager = new ConversationStateExtensionsManager(); + var mockExtension = new Mock(); + manager.Add(mockExtension.Object); + + // Act + await manager.OnThreadDeleteAsync("test-thread-id"); + + // Assert + mockExtension.Verify(x => x.OnThreadDeleteAsync("test-thread-id", It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnNewMessageAsyncShouldCallOnNewMessageOnAllExtensions() + { + // Arrange + var manager = new ConversationStateExtensionsManager(); + var mockExtension = new Mock(); + var message = new ChatMessage(ChatRole.User, "Hello"); + manager.Add(mockExtension.Object); + + // Act + await manager.OnNewMessageAsync(message); + + // Assert + mockExtension.Verify(x => x.OnNewMessageAsync(message, It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnAIInvocationAsyncShouldAggregateContextsFromAllExtensions() + { + // Arrange + var manager = new ConversationStateExtensionsManager(); + var mockExtension1 = new Mock(); + var mockExtension2 = new Mock(); + mockExtension1.Setup(x => x.OnAIInvocationAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync("Context1"); + mockExtension2.Setup(x => x.OnAIInvocationAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync("Context2"); + manager.Add(mockExtension1.Object); + manager.Add(mockExtension2.Object); + + var messages = new List(); + + // Act + var result = await manager.OnAIInvocationAsync(messages); + + // Assert + Assert.Equal("Context1\nContext2", result); + } + + [Fact] + public async Task OnSuspendAsyncShouldCallOnSuspendOnAllExtensions() + { + // Arrange + var manager = new ConversationStateExtensionsManager(); + var mockExtension = new Mock(); + manager.Add(mockExtension.Object); + + // Act + await manager.OnSuspendAsync("test-thread-id"); + + // Assert + mockExtension.Verify(x => x.OnSuspendAsync("test-thread-id", It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnResumeAsyncShouldCallOnResumeOnAllExtensions() + { + // Arrange + var manager = new ConversationStateExtensionsManager(); + var mockExtension = new Mock(); + manager.Add(mockExtension.Object); + + // Act + await manager.OnResumeAsync("test-thread-id"); + + // Assert + mockExtension.Verify(x => x.OnResumeAsync("test-thread-id", It.IsAny()), Times.Once); + } +} From e53fbec3d28d5509e3252e4b3cffa2a367093804 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Tue, 22 Apr 2025 13:55:01 +0100 Subject: [PATCH 28/46] Rename onaiinvoke, add threadid to onnewmessage, update experimental flag. --- docs/decisions/00NN-agents-with-memory.md | 12 ++++++------ dotnet/docs/EXPERIMENTS.md | 2 +- dotnet/src/Agents/Abstractions/AgentThread.cs | 2 +- dotnet/src/Agents/Core/ChatCompletionAgent.cs | 4 ++-- dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs | 4 ++-- .../src/Agents/UnitTests/Agents.UnitTests.csproj | 2 +- .../src/Agents/UnitTests/Core/AgentThreadTests.cs | 2 +- .../src/Memory/Memory/TextRag/TextRagComponent.cs | 2 +- .../src/Memory/Memory/UserFactsMemoryComponent.cs | 4 ++-- .../Memory/ConversationStateExtension.cs | 13 +++++++------ .../Memory/ConversationStateExtensionsManager.cs | 15 ++++++++------- ...onversationStateExtensionsManagerExtensions.cs | 15 ++++++++------- ...onversationStateExtensionsManagerExtensions.cs | 2 +- .../Memory/Mem0/Mem0MemoryComponent.cs | 4 ++-- .../Memory/Mem0/Mem0MemoryComponentOptions.cs | 3 +++ .../Memory/ConversationStateExtensionTests.cs | 2 +- ...sationStateExtensionsManagerExtensionsTests.cs | 10 +++++----- .../ConversationStateExtensionsManagerTests.cs | 10 +++++----- .../SemanticKernel.UnitTests.csproj | 2 +- 19 files changed, 58 insertions(+), 52 deletions(-) diff --git a/docs/decisions/00NN-agents-with-memory.md b/docs/decisions/00NN-agents-with-memory.md index bbb022e1b2db..25adcc43fcd1 100644 --- a/docs/decisions/00NN-agents-with-memory.md +++ b/docs/decisions/00NN-agents-with-memory.md @@ -97,8 +97,8 @@ public abstract class ConversationStateExtension // OnThreadCheckpointAsync not included in initial release, maybe in future. public virtual Task OnThreadCheckpointAsync(string? threadId, CancellationToken cancellationToken = default); - public virtual Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default); - public abstract Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default); + public virtual Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default); + public abstract Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default); public virtual Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default); public virtual Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default); @@ -116,11 +116,11 @@ This class allows registering components and delegating new message notification I propose to add a `ConversationStateExtensionsManager` to the `AgentThread` class, allowing us to attach components to any `AgentThread`. -When an `Agent` is invoked, we will call `OnAIInvocationAsync` on each component via the `ConversationStateExtensionsManager` to get +When an `Agent` is invoked, we will call `OnModelInvokeAsync` on each component via the `ConversationStateExtensionsManager` to get a combined set of context to pass to the agent for this invocation. This will be internal to the `Agent` class and transparent to the user. ```csharp -var additionalInstructions = await currentAgentThread.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); +var additionalInstructions = await currentAgentThread.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); ``` ## Usage examples @@ -139,12 +139,12 @@ var userFacts = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemo // Create a thread and attach a Memory Component. var agentThread1 = new ChatHistoryAgentThread(); -agentThread1.ThreadExtensionsManager.RegisterThreadExtension(userFacts); +agentThread1.ThreadExtensionsManager.Add(userFacts); var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); // Create a second thread and attach a Memory Component. var agentThread2 = new ChatHistoryAgentThread(); -agentThread2.ThreadExtensionsManager.RegisterThreadExtension(userFacts); +agentThread2.ThreadExtensionsManager.Add(userFacts); var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); // Expected response contains Caoimhe. ``` diff --git a/dotnet/docs/EXPERIMENTS.md b/dotnet/docs/EXPERIMENTS.md index 114211fe01f8..8dbfa4d746b7 100644 --- a/dotnet/docs/EXPERIMENTS.md +++ b/dotnet/docs/EXPERIMENTS.md @@ -25,7 +25,7 @@ You can use the following diagnostic IDs to ignore warnings or errors for a part | SKEXP0100 | Advanced Semantic Kernel features | | SKEXP0110 | Semantic Kernel Agents | | SKEXP0120 | Native-AOT | -| SKEXP0130 | Memory | +| SKEXP0130 | Conversation State | ## Experimental Features Tracking diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index 330d26d5ddb3..fecc94c4f88f 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -148,7 +148,7 @@ internal virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, Can } #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - await this.StateExtensions.OnNewMessageAsync(newMessage, cancellationToken).ConfigureAwait(false); + await this.StateExtensions.OnNewMessageAsync(this.Id, newMessage, cancellationToken).ConfigureAwait(false); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.OnNewMessageInternalAsync(newMessage, cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index f057f43333b6..7880a87f43ab 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -78,7 +78,7 @@ public override async IAsyncEnumerable> In // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await chatHistoryAgentThread.StateExtensions.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + var extensionsContext = await chatHistoryAgentThread.StateExtensions.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); chatHistoryAgentThread.StateExtensions.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. @@ -169,7 +169,7 @@ public override async IAsyncEnumerable> InvokeAsync // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await openAIAssistantAgentThread.StateExtensions.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + var extensionsContext = await openAIAssistantAgentThread.StateExtensions.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); openAIAssistantAgentThread.StateExtensions.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. @@ -560,7 +560,7 @@ public async IAsyncEnumerable> In // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await openAIAssistantAgentThread.StateExtensions.OnAIInvocationAsync(messages, cancellationToken).ConfigureAwait(false); + var extensionsContext = await openAIAssistantAgentThread.StateExtensions.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); openAIAssistantAgentThread.StateExtensions.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. diff --git a/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj b/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj index a0222fac89cf..c8550ebc178c 100644 --- a/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj +++ b/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj @@ -8,7 +8,7 @@ true false 12 - $(NoWarn);CA2007,CA1812,CA1861,CA1063,CS0618,VSTHRD111,SKEXP0001,SKEXP0050,SKEXP0110;OPENAI001 + $(NoWarn);CA2007,CA1812,CA1861,CA1063,CS0618,VSTHRD111,SKEXP0001,SKEXP0050,SKEXP0110;SKEXP0130;OPENAI001 diff --git a/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs b/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs index 4a1fb2db791a..550c2ff3d475 100644 --- a/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs +++ b/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs @@ -270,7 +270,7 @@ public async Task OnNewMessageShouldCallOnNewMessageOnRegisteredExtensionsAsync( await thread.OnNewMessageAsync(message); // Assert. - mockExtension.Verify(x => x.OnNewMessageAsync(It.Is(x => x.Text == "Test Message." && x.Role == ChatRole.User), It.IsAny()), Times.Once); + mockExtension.Verify(x => x.OnNewMessageAsync("test-thread-id", It.Is(x => x.Text == "Test Message." && x.Role == ChatRole.User), It.IsAny()), Times.Once); } private sealed class TestAgentThread : AgentThread diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs b/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs index 17d149ef6f3c..3ec32794e899 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs +++ b/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs @@ -63,7 +63,7 @@ public override IReadOnlyCollection AIFunctions } /// - public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + public override async Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default) { if (this.Options.SearchTime != TextRagComponentOptions.TextRagSearchTime.BeforeAIInvoke) { diff --git a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs index 1d56b5d2ffb6..76b82aafb055 100644 --- a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs @@ -118,7 +118,7 @@ public override async Task OnThreadDeleteAsync(string? threadId, CancellationTok } /// - public override async Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default) + public override async Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) { if (newMessage.Role == ChatRole.User && !string.IsNullOrWhiteSpace(newMessage.Text)) { @@ -128,7 +128,7 @@ public override async Task OnNewMessageAsync(ChatMessage newMessage, Cancellatio } /// - public override Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + public override Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default) { return Task.FromResult("The following list contains facts about the user:\n" + this._userFacts); } diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs index 0568efcc4210..e2764abd5264 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs @@ -17,7 +17,7 @@ namespace Microsoft.SemanticKernel; /// to a conversation, listen to changes in the conversation state, and provide additional context to /// the AI model in use just before invocation. /// -[Experimental("SKEXP0001")] +[Experimental("SKEXP0130")] public abstract class ConversationStateExtension { /// @@ -47,10 +47,11 @@ public virtual Task OnThreadCreatedAsync(string? threadId, CancellationToken can /// /// Inheritors can use this method to update their context based on the new message. /// + /// The ID of the thread for the new message, if the thread has an ID. /// The new message. /// The to monitor for cancellation requests. The default is . /// A task that completes when the context has been updated. - public virtual Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default) + public virtual Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) { return Task.CompletedTask; } @@ -71,14 +72,14 @@ public virtual Task OnThreadDeleteAsync(string? threadId, CancellationToken canc } /// - /// Called just before the AI is invoked + /// Called just before the Model/Agent/etc. is invoked /// Implementers can load any additional context required at this time, - /// but they should also return any context that should be passed to the AI. + /// but they should also return any context that should be passed to the Model/Agent/etc. /// - /// The most recent messages that the AI is being invoked with. + /// The most recent messages that the Model/Agent/etc. is being invoked with. /// The to monitor for cancellation requests. The default is . /// A task that completes when the context has been rendered and returned. - public abstract Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default); + public abstract Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default); /// /// Called when the current conversion is temporarily suspended and any state should be saved. diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs index 5d88420f2ac2..ebea6d97f094 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs @@ -14,7 +14,7 @@ namespace Microsoft.SemanticKernel; /// /// A container class for objects that manages their lifecycle and interactions. /// -[Experimental("SKEXP0001")] +[Experimental("SKEXP0130")] public sealed class ConversationStateExtensionsManager { private readonly List _extensions = new(); @@ -107,23 +107,24 @@ public async Task OnThreadDeleteAsync(string threadId, CancellationToken cancell /// /// This method is called when a new message has been contributed to the chat by any participant. /// + /// The ID of the thread for the new message, if the thread has an ID. /// The new message. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. - public async Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default) + public async Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.Extensions.Select(x => x.OnNewMessageAsync(newMessage, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Extensions.Select(x => x.OnNewMessageAsync(threadId, newMessage, cancellationToken)).ToList()).ConfigureAwait(false); } /// - /// Called just before the AI is invoked + /// Called just before the Model/Agent/etc. is invoked /// - /// The most recent messages that the AI is being invoked with. + /// The most recent messages that the Model/Agent/etc. is being invoked with. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation, containing the combined context from all conversation state extensions. - public async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + public async Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default) { - var subContexts = await Task.WhenAll(this.Extensions.Select(x => x.OnAIInvocationAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); + var subContexts = await Task.WhenAll(this.Extensions.Select(x => x.OnModelInvokeAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); return string.Join("\n", subContexts); } diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs index 2663b082c82d..1eff75a6317e 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs @@ -12,30 +12,31 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods for . /// -[Experimental("SKEXP0001")] +[Experimental("SKEXP0130")] public static class ConversationStateExtensionsManagerExtensions { /// /// This method is called when a new message has been contributed to the chat by any participant. /// /// The conversation state manager to pass the new message to. + /// The ID of the thread for the new message, if the thread has an ID. /// The new message. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. - public static Task OnNewMessageAsync(this ConversationStateExtensionsManager conversationStateExtensionsManager, ChatMessageContent newMessage, CancellationToken cancellationToken = default) + public static Task OnNewMessageAsync(this ConversationStateExtensionsManager conversationStateExtensionsManager, string? threadId, ChatMessageContent newMessage, CancellationToken cancellationToken = default) { - return conversationStateExtensionsManager.OnNewMessageAsync(ChatCompletionServiceExtensions.ToChatMessage(newMessage), cancellationToken); + return conversationStateExtensionsManager.OnNewMessageAsync(threadId, ChatCompletionServiceExtensions.ToChatMessage(newMessage), cancellationToken); } /// - /// Called just before the AI is invoked + /// Called just before the Model/Agent/etc. is invoked /// /// The conversation state manager to call. - /// The most recent messages that the AI is being invoked with. + /// The most recent messages that the Model/Agent/etc. is being invoked with. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation, containing the combined context from all conversation state extensions. - public static Task OnAIInvocationAsync(this ConversationStateExtensionsManager conversationStateExtensionsManager, ICollection newMessages, CancellationToken cancellationToken = default) + public static Task OnModelInvokeAsync(this ConversationStateExtensionsManager conversationStateExtensionsManager, ICollection newMessages, CancellationToken cancellationToken = default) { - return conversationStateExtensionsManager.OnAIInvocationAsync(newMessages.Select(ChatCompletionServiceExtensions.ToChatMessage).ToList(), cancellationToken); + return conversationStateExtensionsManager.OnModelInvokeAsync(newMessages.Select(ChatCompletionServiceExtensions.ToChatMessage).ToList(), cancellationToken); } } diff --git a/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs b/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs index d6b148d7dcec..ceb2e37f8a27 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel; /// /// Extension methods for . /// -[Experimental("SKEXP0001")] +[Experimental("SKEXP0130")] public static class ConversationStateExtensionsManagerExtensions { /// diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs index a5b804d41c7e..b71f2bff813e 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs @@ -69,7 +69,7 @@ public override Task OnThreadCreatedAsync(string? threadId, CancellationToken ca } /// - public override async Task OnNewMessageAsync(ChatMessage newMessage, CancellationToken cancellationToken = default) + public override async Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) { Verify.NotNull(newMessage); @@ -86,7 +86,7 @@ await this._mem0Client.CreateMemoryAsync( } /// - public override async Task OnAIInvocationAsync(ICollection newMessages, CancellationToken cancellationToken = default) + public override async Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default) { Verify.NotNull(newMessages); diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs index 6a88d711322a..137857eb15aa 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs @@ -1,10 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Diagnostics.CodeAnalysis; + namespace Microsoft.SemanticKernel.Memory; /// /// Options for the . /// +[Experimental("SKEXP0130")] public class Mem0MemoryComponentOptions { /// diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs index 152716ef3df2..62c82c1ff8eb 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs @@ -46,7 +46,7 @@ public async Task OnNewMessageBaseImplementationSucceeds() var newMessage = new ChatMessage(ChatRole.User, "Hello"); // Act & Assert. - await mockExtension.Object.OnNewMessageAsync(newMessage, CancellationToken.None); + await mockExtension.Object.OnNewMessageAsync("threadId", newMessage, CancellationToken.None); } [Fact] diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs index 81690f7314c0..e1bf25f94443 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs @@ -27,10 +27,10 @@ public async Task OnNewMessageShouldConvertMessageAndInvokeRegisteredExtensionsA var newMessage = new ChatMessageContent(AuthorRole.User, "Test Message"); // Act - await manager.OnNewMessageAsync(newMessage); + await manager.OnNewMessageAsync("test-thread-id", newMessage); // Assert - extensionMock.Verify(x => x.OnNewMessageAsync(It.Is(m => m.Text == "Test Message" && m.Role == ChatRole.User), It.IsAny()), Times.Once); + extensionMock.Verify(x => x.OnNewMessageAsync("test-thread-id", It.Is(m => m.Text == "Test Message" && m.Role == ChatRole.User), It.IsAny()), Times.Once); } [Fact] @@ -48,15 +48,15 @@ public async Task OnAIInvocationShouldConvertMessagesInvokeRegisteredExtensionsA }; extensionMock - .Setup(x => x.OnAIInvocationAsync(It.IsAny>(), It.IsAny())) + .Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) .ReturnsAsync("Combined Context"); // Act - var result = await manager.OnAIInvocationAsync(messages); + var result = await manager.OnModelInvokeAsync(messages); // Assert Assert.Equal("Combined Context", result); - extensionMock.Verify(x => x.OnAIInvocationAsync(It.Is>(m => m.Count == 2), It.IsAny()), Times.Once); + extensionMock.Verify(x => x.OnModelInvokeAsync(It.Is>(m => m.Count == 2), It.IsAny()), Times.Once); } [Fact] diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs index 9036496b4df1..65149e23680b 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs @@ -115,10 +115,10 @@ public async Task OnNewMessageAsyncShouldCallOnNewMessageOnAllExtensions() manager.Add(mockExtension.Object); // Act - await manager.OnNewMessageAsync(message); + await manager.OnNewMessageAsync("test-thread-id", message); // Assert - mockExtension.Verify(x => x.OnNewMessageAsync(message, It.IsAny()), Times.Once); + mockExtension.Verify(x => x.OnNewMessageAsync("test-thread-id", message, It.IsAny()), Times.Once); } [Fact] @@ -128,9 +128,9 @@ public async Task OnAIInvocationAsyncShouldAggregateContextsFromAllExtensions() var manager = new ConversationStateExtensionsManager(); var mockExtension1 = new Mock(); var mockExtension2 = new Mock(); - mockExtension1.Setup(x => x.OnAIInvocationAsync(It.IsAny>(), It.IsAny())) + mockExtension1.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) .ReturnsAsync("Context1"); - mockExtension2.Setup(x => x.OnAIInvocationAsync(It.IsAny>(), It.IsAny())) + mockExtension2.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) .ReturnsAsync("Context2"); manager.Add(mockExtension1.Object); manager.Add(mockExtension2.Object); @@ -138,7 +138,7 @@ public async Task OnAIInvocationAsyncShouldAggregateContextsFromAllExtensions() var messages = new List(); // Act - var result = await manager.OnAIInvocationAsync(messages); + var result = await manager.OnModelInvokeAsync(messages); // Assert Assert.Equal("Context1\nContext2", result); diff --git a/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj b/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj index 8580c9a173ab..7bb2d937033f 100644 --- a/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj +++ b/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj @@ -6,7 +6,7 @@ net8.0 true false - $(NoWarn);CA2007,CA1861,IDE1006,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0050,SKEXP0110,SKEXP0120 + $(NoWarn);CA2007,CA1861,IDE1006,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0050,SKEXP0110,SKEXP0120,SKEXP0130 From f6c7aa0df23e0205b1d94660947b042a3ef00422 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Tue, 22 Apr 2025 14:13:44 +0100 Subject: [PATCH 29/46] Update AzureAIAgent to work with memory. --- .../src/Agents/AzureAI/Agents.AzureAI.csproj | 1 + dotnet/src/Agents/AzureAI/AzureAIAgent.cs | 32 ++++++++++-- .../AzureAIAgentWithMemoryTests.cs | 50 +++++++++++++++++++ .../AzureAIAgentFixture.cs | 2 + 4 files changed, 81 insertions(+), 4 deletions(-) create mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs diff --git a/dotnet/src/Agents/AzureAI/Agents.AzureAI.csproj b/dotnet/src/Agents/AzureAI/Agents.AzureAI.csproj index 2b0694f69986..a33dcd16812a 100644 --- a/dotnet/src/Agents/AzureAI/Agents.AzureAI.csproj +++ b/dotnet/src/Agents/AzureAI/Agents.AzureAI.csproj @@ -34,6 +34,7 @@ + diff --git a/dotnet/src/Agents/AzureAI/AzureAIAgent.cs b/dotnet/src/Agents/AzureAI/AzureAIAgent.cs index 557299ae044e..5c916420de34 100644 --- a/dotnet/src/Agents/AzureAI/AzureAIAgent.cs +++ b/dotnet/src/Agents/AzureAI/AzureAIAgent.cs @@ -186,6 +186,18 @@ public async IAsyncEnumerable> InvokeAsync () => new AzureAIAgentThread(this.Client), cancellationToken).ConfigureAwait(false); + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await azureAIAgentThread.StateExtensions.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + azureAIAgentThread.StateExtensions.RegisterPlugins(kernel); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + + var extensionsContextOptions = options is null ? + new AzureAIAgentInvokeOptions() { AdditionalInstructions = extensionsContext } : + new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = extensionsContext }; + var invokeResults = ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(this.Id, this.GetDisplayName(), this.Description), () => InternalInvokeAsync(), @@ -197,9 +209,9 @@ async IAsyncEnumerable InternalInvokeAsync() this, this.Client, azureAIAgentThread.Id!, - options?.ToAzureAIInvocationOptions(), + extensionsContextOptions?.ToAzureAIInvocationOptions(), this.Logger, - options?.Kernel ?? this.Kernel, + kernel, options?.KernelArguments, cancellationToken).ConfigureAwait(false)) { @@ -303,14 +315,26 @@ public async IAsyncEnumerable> In () => new AzureAIAgentThread(this.Client), cancellationToken).ConfigureAwait(false); + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await azureAIAgentThread.StateExtensions.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + azureAIAgentThread.StateExtensions.RegisterPlugins(kernel); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + + var extensionsContextOptions = options is null ? + new AzureAIAgentInvokeOptions() { AdditionalInstructions = extensionsContext } : + new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = extensionsContext }; + #pragma warning disable CS0618 // Type or member is obsolete // Invoke the Agent with the thread that we already added our message to. var newMessagesReceiver = new ChatHistory(); var invokeResults = this.InvokeStreamingAsync( azureAIAgentThread.Id!, - options?.ToAzureAIInvocationOptions(), + extensionsContextOptions.ToAzureAIInvocationOptions(), options?.KernelArguments, - options?.Kernel ?? this.Kernel, + kernel, newMessagesReceiver, cancellationToken); #pragma warning restore CS0618 // Type or member is obsolete diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs new file mode 100644 index 000000000000..32200a11f503 --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents.AzureAI; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Memory; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithMemoryConformance; + +public class AzureAIAgentWithMemoryTests() : AgentWithMemoryTests(() => new AzureAIAgentFixture()) +{ + [Fact(Skip = "For manual verification")] + public virtual async Task Mem0ComponentCapturesMemoriesFromUserInputAsync() + { + // Arrange + var agent = this.Fixture.Agent; + + using var httpClient = new HttpClient(); + httpClient.BaseAddress = new Uri("https://api.mem0.ai"); + httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", "m0-uWa1CXDyO9PpotOFMUfI9WzZOwAqJjZwH3GTKgqa"); + + var mem0Component = new Mem0MemoryComponent(httpClient, new() { UserId = "U1" }); + + var agentThread1 = new AzureAIAgentThread(this.Fixture.AgentsClient); + agentThread1.StateExtensions.Add(mem0Component); + + var agentThread2 = new AzureAIAgentThread(this.Fixture.AgentsClient); + agentThread2.StateExtensions.Add(mem0Component); + + // Act + var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); + var results1 = await asyncResults1.ToListAsync(); + + var asyncResults2 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "What is my name?."), agentThread2); + var results2 = await asyncResults2.ToListAsync(); + + // Assert + Assert.Contains("Caoimhe", results2.First().Message.Content); + + // Cleanup + await this.Fixture.DeleteThread(agentThread1); + await this.Fixture.DeleteThread(agentThread2); + } +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs index 769e3daec9d7..f2017815ffc7 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs @@ -30,6 +30,8 @@ public class AzureAIAgentFixture : AgentFixture private AzureAIAgentThread? _serviceFailingAgentThread; private AzureAIAgentThread? _createdServiceFailingAgentThread; + public AAIP.AgentsClient AgentsClient => this._agentsClient!; + public override Agent Agent => this._agent!; public override AgentThread AgentThread => this._thread!; From 72e6604ea4d7c7a33b5355434e0fd7ba5ce101d4 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 23 Apr 2025 11:48:35 +0100 Subject: [PATCH 30/46] Rename conversation state extension to part --- dotnet/src/Agents/Abstractions/AgentThread.cs | 14 +- dotnet/src/Agents/AzureAI/AzureAIAgent.cs | 8 +- dotnet/src/Agents/Core/ChatCompletionAgent.cs | 8 +- .../src/Agents/OpenAI/OpenAIAssistantAgent.cs | 8 +- .../Agents/UnitTests/Core/AgentThreadTests.cs | 50 ++--- .../AzureAIAgentWithMemoryTests.cs | 4 +- .../ChatCompletionAgentWithMemoryTests.cs | 22 +-- .../OpenAIAssistantAgentWithMemoryTests.cs.cs | 4 +- .../Memory/Memory/TextRag/TextRagComponent.cs | 2 +- .../Memory/Memory/UserFactsMemoryComponent.cs | 2 +- ...eExtension.cs => ConversationStatePart.cs} | 8 +- ...er.cs => ConversationStatePartsManager.cs} | 58 +++--- ...onversationStatePartsManagerExtensions.cs} | 18 +- ...rsationStateExtensionsManagerExtensions.cs | 23 --- ...ConversationStatePartsManagerExtensions.cs | 23 +++ .../Memory/Mem0/Mem0MemoryComponent.cs | 2 +- ...ConversationStateExtensionsManagerTests.cs | 176 ------------------ ...Tests.cs => ConversationStatePartTests.cs} | 28 +-- ...sationStatePartsManagerExtensionsTests.cs} | 34 ++-- .../ConversationStatePartsManagerTests.cs | 176 ++++++++++++++++++ 20 files changed, 334 insertions(+), 334 deletions(-) rename dotnet/src/SemanticKernel.Abstractions/Memory/{ConversationStateExtension.cs => ConversationStatePart.cs} (95%) rename dotnet/src/SemanticKernel.Abstractions/Memory/{ConversationStateExtensionsManager.cs => ConversationStatePartsManager.cs} (69%) rename dotnet/src/SemanticKernel.Abstractions/Memory/{ConversationStateExtensionsManagerExtensions.cs => ConversationStatePartsManagerExtensions.cs} (58%) delete mode 100644 dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs create mode 100644 dotnet/src/SemanticKernel.Core/Memory/ConversationStatePartsManagerExtensions.cs delete mode 100644 dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs rename dotnet/src/SemanticKernel.UnitTests/Memory/{ConversationStateExtensionTests.cs => ConversationStatePartTests.cs} (51%) rename dotnet/src/SemanticKernel.UnitTests/Memory/{ConversationStateExtensionsManagerExtensionsTests.cs => ConversationStatePartsManagerExtensionsTests.cs} (61%) create mode 100644 dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerTests.cs diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index fecc94c4f88f..01635ae0728a 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -28,10 +28,10 @@ public abstract class AgentThread public virtual bool IsDeleted { get; protected set; } = false; /// - /// Gets or sets the container for conversation state extension components that manages their lifecycle and interactions. + /// Gets or sets the container for conversation state part components that manages their lifecycle and interactions. /// [Experimental("SKEXP0110")] - public virtual ConversationStateExtensionsManager StateExtensions { get; init; } = new ConversationStateExtensionsManager(); + public virtual ConversationStatePartsManager StateParts { get; init; } = new ConversationStatePartsManager(); /// /// Called when the current conversion is temporarily suspended and any state should be saved. @@ -45,7 +45,7 @@ public abstract class AgentThread [Experimental("SKEXP0110")] public virtual Task OnSuspendAsync(CancellationToken cancellationToken = default) { - return this.StateExtensions.OnSuspendAsync(this.Id, cancellationToken); + return this.StateParts.OnSuspendAsync(this.Id, cancellationToken); } /// @@ -70,7 +70,7 @@ public virtual Task OnResumeAsync(CancellationToken cancellationToken = default) throw new InvalidOperationException("This thread cannot be resumed, since it has not been created."); } - return this.StateExtensions.OnResumeAsync(this.Id, cancellationToken); + return this.StateParts.OnResumeAsync(this.Id, cancellationToken); } /// @@ -94,7 +94,7 @@ protected internal virtual async Task CreateAsync(CancellationToken cancellation this.Id = await this.CreateInternalAsync(cancellationToken: cancellationToken).ConfigureAwait(false); #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - await this.StateExtensions.OnThreadCreatedAsync(this.Id!, cancellationToken).ConfigureAwait(false); + await this.StateParts.OnThreadCreatedAsync(this.Id!, cancellationToken).ConfigureAwait(false); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. } @@ -117,7 +117,7 @@ public virtual async Task DeleteAsync(CancellationToken cancellationToken = defa } #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - await this.StateExtensions.OnThreadDeleteAsync(this.Id!, cancellationToken).ConfigureAwait(false); + await this.StateParts.OnThreadDeleteAsync(this.Id!, cancellationToken).ConfigureAwait(false); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.DeleteInternalAsync(cancellationToken).ConfigureAwait(false); @@ -148,7 +148,7 @@ internal virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, Can } #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - await this.StateExtensions.OnNewMessageAsync(this.Id, newMessage, cancellationToken).ConfigureAwait(false); + await this.StateParts.OnNewMessageAsync(this.Id, newMessage, cancellationToken).ConfigureAwait(false); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. await this.OnNewMessageInternalAsync(newMessage, cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/Agents/AzureAI/AzureAIAgent.cs b/dotnet/src/Agents/AzureAI/AzureAIAgent.cs index 5c916420de34..08ea51c7a616 100644 --- a/dotnet/src/Agents/AzureAI/AzureAIAgent.cs +++ b/dotnet/src/Agents/AzureAI/AzureAIAgent.cs @@ -190,8 +190,8 @@ public async IAsyncEnumerable> InvokeAsync // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await azureAIAgentThread.StateExtensions.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); - azureAIAgentThread.StateExtensions.RegisterPlugins(kernel); + var extensionsContext = await azureAIAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + azureAIAgentThread.StateParts.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var extensionsContextOptions = options is null ? @@ -319,8 +319,8 @@ public async IAsyncEnumerable> In // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await azureAIAgentThread.StateExtensions.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); - azureAIAgentThread.StateExtensions.RegisterPlugins(kernel); + var extensionsContext = await azureAIAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + azureAIAgentThread.StateParts.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var extensionsContextOptions = options is null ? diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 7880a87f43ab..be5472c45429 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -78,8 +78,8 @@ public override async IAsyncEnumerable> In // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await chatHistoryAgentThread.StateExtensions.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); - chatHistoryAgentThread.StateExtensions.RegisterPlugins(kernel); + var extensionsContext = await chatHistoryAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + chatHistoryAgentThread.StateParts.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Invoke Chat Completion with the updated chat history. @@ -169,8 +169,8 @@ public override async IAsyncEnumerable> InvokeAsync // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await openAIAssistantAgentThread.StateExtensions.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); - openAIAssistantAgentThread.StateExtensions.RegisterPlugins(kernel); + var extensionsContext = await openAIAssistantAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + openAIAssistantAgentThread.StateParts.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. var invokeResults = ActivityExtensions.RunWithActivityAsync( @@ -560,8 +560,8 @@ public async IAsyncEnumerable> In // Get the conversation state extensions context contributions and register plugins from the extensions. #pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - var extensionsContext = await openAIAssistantAgentThread.StateExtensions.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); - openAIAssistantAgentThread.StateExtensions.RegisterPlugins(kernel); + var extensionsContext = await openAIAssistantAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + openAIAssistantAgentThread.StateParts.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Create options that use the RunCreationOptions from the options param if provided or diff --git a/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs b/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs index 550c2ff3d475..c2810951ba31 100644 --- a/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs +++ b/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs @@ -174,94 +174,94 @@ public async Task OnResumeShouldThrowIfThreadDeletedAsync() /// /// Tests that the method - /// calls each registered extension in turn. + /// calls each registered state part in turn. /// [Fact] - public async Task OnSuspendShouldCallOnSuspendOnRegisteredExtensionsAsync() + public async Task OnSuspendShouldCallOnSuspendOnRegisteredPartsAsync() { // Arrange. var thread = new TestAgentThread(); - var mockExtension = new Mock(); - thread.StateExtensions.Add(mockExtension.Object); + var mockPart = new Mock(); + thread.StateParts.Add(mockPart.Object); await thread.CreateAsync(); // Act. await thread.OnSuspendAsync(); // Assert. - mockExtension.Verify(x => x.OnSuspendAsync("test-thread-id", It.IsAny()), Times.Once); + mockPart.Verify(x => x.OnSuspendAsync("test-thread-id", It.IsAny()), Times.Once); } /// /// Tests that the method - /// calls each registered extension in turn. + /// calls each registered state part in turn. /// [Fact] - public async Task OnResumeShouldCallOnResumeOnRegisteredExtensionsAsync() + public async Task OnResumeShouldCallOnResumeOnRegisteredPartsAsync() { // Arrange. var thread = new TestAgentThread(); - var mockExtension = new Mock(); - thread.StateExtensions.Add(mockExtension.Object); + var mockPart = new Mock(); + thread.StateParts.Add(mockPart.Object); await thread.CreateAsync(); // Act. await thread.OnResumeAsync(); // Assert. - mockExtension.Verify(x => x.OnResumeAsync("test-thread-id", It.IsAny()), Times.Once); + mockPart.Verify(x => x.OnResumeAsync("test-thread-id", It.IsAny()), Times.Once); } /// /// Tests that the method - /// calls each registered extension in turn. + /// calls each registered state parts in turn. /// [Fact] - public async Task CreateShouldCallOnThreadCreatedOnRegisteredExtensionsAsync() + public async Task CreateShouldCallOnThreadCreatedOnRegisteredPartsAsync() { // Arrange. var thread = new TestAgentThread(); - var mockExtension = new Mock(); - thread.StateExtensions.Add(mockExtension.Object); + var mockPart = new Mock(); + thread.StateParts.Add(mockPart.Object); // Act. await thread.CreateAsync(); // Assert. - mockExtension.Verify(x => x.OnThreadCreatedAsync("test-thread-id", It.IsAny()), Times.Once); + mockPart.Verify(x => x.OnThreadCreatedAsync("test-thread-id", It.IsAny()), Times.Once); } /// /// Tests that the method - /// calls each registered extension in turn. + /// calls each registered state parts in turn. /// [Fact] - public async Task DeleteShouldCallOnThreadDeleteOnRegisteredExtensionsAsync() + public async Task DeleteShouldCallOnThreadDeleteOnRegisteredPartsAsync() { // Arrange. var thread = new TestAgentThread(); - var mockExtension = new Mock(); - thread.StateExtensions.Add(mockExtension.Object); + var mockPart = new Mock(); + thread.StateParts.Add(mockPart.Object); await thread.CreateAsync(); // Act. await thread.DeleteAsync(); // Assert. - mockExtension.Verify(x => x.OnThreadDeleteAsync("test-thread-id", It.IsAny()), Times.Once); + mockPart.Verify(x => x.OnThreadDeleteAsync("test-thread-id", It.IsAny()), Times.Once); } /// /// Tests that the method - /// calls each registered extension in turn. + /// calls each registered state part in turn. /// [Fact] - public async Task OnNewMessageShouldCallOnNewMessageOnRegisteredExtensionsAsync() + public async Task OnNewMessageShouldCallOnNewMessageOnRegisteredPartsAsync() { // Arrange. var thread = new TestAgentThread(); - var mockExtension = new Mock(); - thread.StateExtensions.Add(mockExtension.Object); + var mockPart = new Mock(); + thread.StateParts.Add(mockPart.Object); var message = new ChatMessageContent(AuthorRole.User, "Test Message."); await thread.CreateAsync(); @@ -270,7 +270,7 @@ public async Task OnNewMessageShouldCallOnNewMessageOnRegisteredExtensionsAsync( await thread.OnNewMessageAsync(message); // Assert. - mockExtension.Verify(x => x.OnNewMessageAsync("test-thread-id", It.Is(x => x.Text == "Test Message." && x.Role == ChatRole.User), It.IsAny()), Times.Once); + mockPart.Verify(x => x.OnNewMessageAsync("test-thread-id", It.Is(x => x.Text == "Test Message." && x.Role == ChatRole.User), It.IsAny()), Times.Once); } private sealed class TestAgentThread : AgentThread diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs index 32200a11f503..f4ed887da87f 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs @@ -28,10 +28,10 @@ public virtual async Task Mem0ComponentCapturesMemoriesFromUserInputAsync() var mem0Component = new Mem0MemoryComponent(httpClient, new() { UserId = "U1" }); var agentThread1 = new AzureAIAgentThread(this.Fixture.AgentsClient); - agentThread1.StateExtensions.Add(mem0Component); + agentThread1.StateParts.Add(mem0Component); var agentThread2 = new AzureAIAgentThread(this.Fixture.AgentsClient); - agentThread2.StateExtensions.Add(mem0Component); + agentThread2.StateParts.Add(mem0Component); // Act var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs index ae18168aa22e..1d373b237e81 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs @@ -47,10 +47,10 @@ public virtual async Task Mem0ComponentCapturesMemoriesFromUserInputAsync() var mem0Component = new Mem0MemoryComponent(httpClient, new() { UserId = "U1" }); var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.StateExtensions.Add(mem0Component); + agentThread1.StateParts.Add(mem0Component); var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.StateExtensions.Add(mem0Component); + agentThread2.StateParts.Add(mem0Component); // Act var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); @@ -75,10 +75,10 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() var memoryComponent = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel); var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.StateExtensions.Add(memoryComponent); + agentThread1.StateParts.Add(memoryComponent); var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.StateExtensions.Add(memoryComponent); + agentThread2.StateParts.Add(memoryComponent); // Act var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); @@ -109,14 +109,14 @@ public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserIn // Act - First invocation with first thread. var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.StateExtensions.Add(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + agentThread1.StateParts.Add(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); var results1 = await asyncResults1.ToListAsync(); // Act - Second invocation with second thread. var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.StateExtensions.Add(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); + agentThread2.StateParts.Add(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); var results2 = await asyncResults2.ToListAsync(); @@ -151,11 +151,11 @@ public virtual async Task CapturesMemoriesWhileUsingDIAsync() sp.GetRequiredService(), sp.GetRequiredService(), "Memories", "user/12345", 1536)); - builder.Services.AddTransient(); + builder.Services.AddTransient(); builder.Services.AddTransient((sp) => { var thread = new ChatHistoryAgentThread(); - thread.StateExtensions.AddFromServiceProvider(sp); + thread.StateParts.AddFromServiceProvider(sp); return thread; }); var host = builder.Build(); @@ -209,7 +209,7 @@ public virtual async Task RagComponentWithoutMatchesAsync() // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); - agentThread.StateExtensions.Add(ragComponent); + agentThread.StateParts.Add(ragComponent); // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); @@ -242,7 +242,7 @@ public virtual async Task RagComponentWithMatchesAsync() // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); - agentThread.StateExtensions.Add(ragComponent); + agentThread.StateParts.Add(ragComponent); // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); @@ -281,7 +281,7 @@ public virtual async Task RagComponentWithMatchesOnDemandAsync() // Act - Create a new agent thread and register the Rag component var agentThread = new ChatHistoryAgentThread(); - agentThread.StateExtensions.Add(ragComponent); + agentThread.StateParts.Add(ragComponent); // Act - Invoke the agent with a question var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread, new() { KernelArguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }) }); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs index 54ac5c508ee3..b3ece08315c5 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs @@ -20,10 +20,10 @@ public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() var memoryComponent = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel); var agentThread1 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); - agentThread1.StateExtensions.Add(memoryComponent); + agentThread1.StateParts.Add(memoryComponent); var agentThread2 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); - agentThread2.StateExtensions.Add(memoryComponent); + agentThread2.StateParts.Add(memoryComponent); // Act var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs b/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs index 3ec32794e899..29f6a6175e46 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs +++ b/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs @@ -15,7 +15,7 @@ namespace Microsoft.SemanticKernel.Memory; /// /// A component that does a search based on any messages that the AI is invoked with and injects the results into the AI invocation context. /// -public class TextRagComponent : ConversationStateExtension +public class TextRagComponent : ConversationStatePart { private readonly ITextSearch _textSearch; diff --git a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs index 76b82aafb055..dc3180fcbe14 100644 --- a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs @@ -13,7 +13,7 @@ namespace Microsoft.SemanticKernel.Agents.Memory; /// A memory component that can retrieve, maintain and store user facts that /// are learned from the user's interactions with the agent. /// -public class UserFactsMemoryComponent : ConversationStateExtension +public class UserFactsMemoryComponent : ConversationStatePart { private readonly Kernel _kernel; private readonly TextMemoryStore _textMemoryStore; diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePart.cs similarity index 95% rename from dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs rename to dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePart.cs index e2764abd5264..eaa533a121ea 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtension.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePart.cs @@ -10,18 +10,18 @@ namespace Microsoft.SemanticKernel; /// -/// Base class for all conversation state extensions. +/// Base class for all conversation state parts. /// /// -/// A conversation state extension is a component that can be used to store additional state related +/// A conversation state part is a component that can be used to store additional state related /// to a conversation, listen to changes in the conversation state, and provide additional context to /// the AI model in use just before invocation. /// [Experimental("SKEXP0130")] -public abstract class ConversationStateExtension +public abstract class ConversationStatePart { /// - /// Gets the list of AI functions that this extension component exposes + /// Gets the list of AI functions that this component exposes /// and which should be used by the consuming AI when using this component. /// public virtual IReadOnlyCollection AIFunctions => Array.Empty(); diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManager.cs similarity index 69% rename from dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs rename to dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManager.cs index ebea6d97f094..f4c62b4b8443 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManager.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManager.cs @@ -12,39 +12,39 @@ namespace Microsoft.SemanticKernel; /// -/// A container class for objects that manages their lifecycle and interactions. +/// A container class for objects that manages their lifecycle and interactions. /// [Experimental("SKEXP0130")] -public sealed class ConversationStateExtensionsManager +public sealed class ConversationStatePartsManager { - private readonly List _extensions = new(); + private readonly List _parts = new(); private List? _currentAIFunctions = null; /// - /// Gets the list of registered conversation state extensions. + /// Gets the list of registered conversation state parts. /// - public IReadOnlyList Extensions => this._extensions; + public IReadOnlyList Parts => this._parts; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// - public ConversationStateExtensionsManager() + public ConversationStatePartsManager() { } /// - /// Initializes a new instance of the class with the specified conversation state extensions. + /// Initializes a new instance of the class with the specified conversation state parts. /// - /// The conversation state extensions to add to the manager. - public ConversationStateExtensionsManager(IEnumerable conversationtStateExtensions) + /// The conversation state parts to add to the manager. + public ConversationStatePartsManager(IEnumerable conversationtStateExtensions) { - this._extensions.AddRange(conversationtStateExtensions); + this._parts.AddRange(conversationtStateExtensions); } /// - /// Gets the list of AI functions that all contained extension component expose - /// and which should be used by the consuming AI when using these components. + /// Gets the list of AI functions that all contained parts expose + /// and which should be used by the consuming AI when using these parts. /// public IReadOnlyCollection AIFunctions { @@ -52,7 +52,7 @@ public IReadOnlyCollection AIFunctions { if (this._currentAIFunctions == null) { - this._currentAIFunctions = this.Extensions.SelectMany(ConversationStateExtensions => ConversationStateExtensions.AIFunctions).ToList(); + this._currentAIFunctions = this.Parts.SelectMany(conversationStateParts => conversationStateParts.AIFunctions).ToList(); } return this._currentAIFunctions; @@ -60,24 +60,24 @@ public IReadOnlyCollection AIFunctions } /// - /// Adds a new conversation state extension. + /// Adds a new conversation state part. /// - /// The conversation state extension to register. - public void Add(ConversationStateExtension conversationtStateExtension) + /// The conversation state part to register. + public void Add(ConversationStatePart conversationtStatePart) { - this._extensions.Add(conversationtStateExtension); + this._parts.Add(conversationtStatePart); this._currentAIFunctions = null; } /// - /// Adds all conversation state extensions registered on the provided dependency injection service provider. + /// Adds all conversation state parts registered on the provided dependency injection service provider. /// - /// The dependency injection service provider to read conversation state extensions from. + /// The dependency injection service provider to read conversation state parts from. public void AddFromServiceProvider(IServiceProvider serviceProvider) { - foreach (var extension in serviceProvider.GetServices()) + foreach (var part in serviceProvider.GetServices()) { - this.Add(extension); + this.Add(part); } this._currentAIFunctions = null; } @@ -90,7 +90,7 @@ public void AddFromServiceProvider(IServiceProvider serviceProvider) /// A task that represents the asynchronous operation. public async Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.Extensions.Select(x => x.OnThreadCreatedAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Parts.Select(x => x.OnThreadCreatedAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -101,7 +101,7 @@ public async Task OnThreadCreatedAsync(string? threadId, CancellationToken cance /// A task that represents the asynchronous operation. public async Task OnThreadDeleteAsync(string threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.Extensions.Select(x => x.OnThreadDeleteAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Parts.Select(x => x.OnThreadDeleteAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -113,7 +113,7 @@ public async Task OnThreadDeleteAsync(string threadId, CancellationToken cancell /// A task that represents the asynchronous operation. public async Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.Extensions.Select(x => x.OnNewMessageAsync(threadId, newMessage, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Parts.Select(x => x.OnNewMessageAsync(threadId, newMessage, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -121,10 +121,10 @@ public async Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, Ca /// /// The most recent messages that the Model/Agent/etc. is being invoked with. /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation, containing the combined context from all conversation state extensions. + /// A task that represents the asynchronous operation, containing the combined context from all conversation state parts. public async Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default) { - var subContexts = await Task.WhenAll(this.Extensions.Select(x => x.OnModelInvokeAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); + var subContexts = await Task.WhenAll(this.Parts.Select(x => x.OnModelInvokeAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); return string.Join("\n", subContexts); } @@ -140,7 +140,7 @@ public async Task OnModelInvokeAsync(ICollection newMessage /// public async Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.Extensions.Select(x => x.OnSuspendAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Parts.Select(x => x.OnSuspendAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } /// @@ -155,6 +155,6 @@ public async Task OnSuspendAsync(string? threadId, CancellationToken cancellatio /// public async Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default) { - await Task.WhenAll(this.Extensions.Select(x => x.OnResumeAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + await Task.WhenAll(this.Parts.Select(x => x.OnResumeAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManagerExtensions.cs similarity index 58% rename from dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs rename to dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManagerExtensions.cs index 1eff75a6317e..e64de1d0b180 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStateExtensionsManagerExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManagerExtensions.cs @@ -10,33 +10,33 @@ namespace Microsoft.SemanticKernel; /// -/// Extension methods for . +/// Extension methods for . /// [Experimental("SKEXP0130")] -public static class ConversationStateExtensionsManagerExtensions +public static class ConversationStatePartsManagerExtensions { /// /// This method is called when a new message has been contributed to the chat by any participant. /// - /// The conversation state manager to pass the new message to. + /// The conversation state manager to pass the new message to. /// The ID of the thread for the new message, if the thread has an ID. /// The new message. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. - public static Task OnNewMessageAsync(this ConversationStateExtensionsManager conversationStateExtensionsManager, string? threadId, ChatMessageContent newMessage, CancellationToken cancellationToken = default) + public static Task OnNewMessageAsync(this ConversationStatePartsManager conversationStatePartsManager, string? threadId, ChatMessageContent newMessage, CancellationToken cancellationToken = default) { - return conversationStateExtensionsManager.OnNewMessageAsync(threadId, ChatCompletionServiceExtensions.ToChatMessage(newMessage), cancellationToken); + return conversationStatePartsManager.OnNewMessageAsync(threadId, ChatCompletionServiceExtensions.ToChatMessage(newMessage), cancellationToken); } /// /// Called just before the Model/Agent/etc. is invoked /// - /// The conversation state manager to call. + /// The conversation state manager to call. /// The most recent messages that the Model/Agent/etc. is being invoked with. /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation, containing the combined context from all conversation state extensions. - public static Task OnModelInvokeAsync(this ConversationStateExtensionsManager conversationStateExtensionsManager, ICollection newMessages, CancellationToken cancellationToken = default) + /// A task that represents the asynchronous operation, containing the combined context from all conversation state parts. + public static Task OnModelInvokeAsync(this ConversationStatePartsManager conversationStatePartsManager, ICollection newMessages, CancellationToken cancellationToken = default) { - return conversationStateExtensionsManager.OnModelInvokeAsync(newMessages.Select(ChatCompletionServiceExtensions.ToChatMessage).ToList(), cancellationToken); + return conversationStatePartsManager.OnModelInvokeAsync(newMessages.Select(ChatCompletionServiceExtensions.ToChatMessage).ToList(), cancellationToken); } } diff --git a/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs b/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs deleted file mode 100644 index ceb2e37f8a27..000000000000 --- a/dotnet/src/SemanticKernel.Core/Memory/ConversationStateExtensionsManagerExtensions.cs +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; -using System.Linq; - -namespace Microsoft.SemanticKernel; - -/// -/// Extension methods for . -/// -[Experimental("SKEXP0130")] -public static class ConversationStateExtensionsManagerExtensions -{ - /// - /// Registers plugins required by all conversation state extensions contained by this manager on the provided . - /// - /// The conversation state manager to get plugins from. - /// The kernel to register the plugins on. - public static void RegisterPlugins(this ConversationStateExtensionsManager conversationStateExtensionsManager, Kernel kernel) - { - kernel.Plugins.AddFromFunctions("Tools", conversationStateExtensionsManager.AIFunctions.Select(x => x.AsKernelFunction())); - } -} diff --git a/dotnet/src/SemanticKernel.Core/Memory/ConversationStatePartsManagerExtensions.cs b/dotnet/src/SemanticKernel.Core/Memory/ConversationStatePartsManagerExtensions.cs new file mode 100644 index 000000000000..888c56c758d9 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/ConversationStatePartsManagerExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods for . +/// +[Experimental("SKEXP0130")] +public static class ConversationStatePartsManagerExtensions +{ + /// + /// Registers plugins required by all conversation state parts contained by this manager on the provided . + /// + /// The conversation state manager to get plugins from. + /// The kernel to register the plugins on. + public static void RegisterPlugins(this ConversationStatePartsManager conversationStatePartsManager, Kernel kernel) + { + kernel.Plugins.AddFromFunctions("Tools", conversationStatePartsManager.AIFunctions.Select(x => x.AsKernelFunction())); + } +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs index b71f2bff813e..58fbdf19d73f 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs @@ -16,7 +16,7 @@ namespace Microsoft.SemanticKernel.Memory; /// information about the user. It is also able to retrieve this information and add it to the AI invocation context. /// [Experimental("SKEXP0130")] -public class Mem0MemoryComponent : ConversationStateExtension +public class Mem0MemoryComponent : ConversationStatePart { private readonly string? _applicationId; private readonly string? _agentId; diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs deleted file mode 100644 index 65149e23680b..000000000000 --- a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerTests.cs +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.SemanticKernel; -using Moq; -using Xunit; - -namespace SemanticKernel.UnitTests.Memory; - -/// -/// Contains tests for the class. -/// -public class ConversationStateExtensionsManagerTests -{ - [Fact] - public void ConstructorShouldInitializeEmptyExtensionsList() - { - // Act - var manager = new ConversationStateExtensionsManager(); - - // Assert - Assert.NotNull(manager.Extensions); - Assert.Empty(manager.Extensions); - } - - [Fact] - public void ConstructorShouldInitializeWithProvidedExtensions() - { - // Arrange - var mockExtension = new Mock(); - - // Act - var manager = new ConversationStateExtensionsManager(new[] { mockExtension.Object }); - - // Assert - Assert.Single(manager.Extensions); - Assert.Contains(mockExtension.Object, manager.Extensions); - } - - [Fact] - public void AddShouldRegisterNewExtension() - { - // Arrange - var manager = new ConversationStateExtensionsManager(); - var mockExtension = new Mock(); - - // Act - manager.Add(mockExtension.Object); - - // Assert - Assert.Single(manager.Extensions); - Assert.Contains(mockExtension.Object, manager.Extensions); - } - - [Fact] - public void AddFromServiceProviderShouldRegisterExtensionsFromServiceProvider() - { - // Arrange - var serviceCollection = new ServiceCollection(); - var mockExtension = new Mock(); - serviceCollection.AddSingleton(mockExtension.Object); - var serviceProvider = serviceCollection.BuildServiceProvider(); - - var manager = new ConversationStateExtensionsManager(); - - // Act - manager.AddFromServiceProvider(serviceProvider); - - // Assert - Assert.Single(manager.Extensions); - Assert.Contains(mockExtension.Object, manager.Extensions); - } - - [Fact] - public async Task OnThreadCreatedAsyncShouldCallOnThreadCreatedOnAllExtensions() - { - // Arrange - var manager = new ConversationStateExtensionsManager(); - var mockExtension = new Mock(); - manager.Add(mockExtension.Object); - - // Act - await manager.OnThreadCreatedAsync("test-thread-id"); - - // Assert - mockExtension.Verify(x => x.OnThreadCreatedAsync("test-thread-id", It.IsAny()), Times.Once); - } - - [Fact] - public async Task OnThreadDeleteAsyncShouldCallOnThreadDeleteOnAllExtensions() - { - // Arrange - var manager = new ConversationStateExtensionsManager(); - var mockExtension = new Mock(); - manager.Add(mockExtension.Object); - - // Act - await manager.OnThreadDeleteAsync("test-thread-id"); - - // Assert - mockExtension.Verify(x => x.OnThreadDeleteAsync("test-thread-id", It.IsAny()), Times.Once); - } - - [Fact] - public async Task OnNewMessageAsyncShouldCallOnNewMessageOnAllExtensions() - { - // Arrange - var manager = new ConversationStateExtensionsManager(); - var mockExtension = new Mock(); - var message = new ChatMessage(ChatRole.User, "Hello"); - manager.Add(mockExtension.Object); - - // Act - await manager.OnNewMessageAsync("test-thread-id", message); - - // Assert - mockExtension.Verify(x => x.OnNewMessageAsync("test-thread-id", message, It.IsAny()), Times.Once); - } - - [Fact] - public async Task OnAIInvocationAsyncShouldAggregateContextsFromAllExtensions() - { - // Arrange - var manager = new ConversationStateExtensionsManager(); - var mockExtension1 = new Mock(); - var mockExtension2 = new Mock(); - mockExtension1.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) - .ReturnsAsync("Context1"); - mockExtension2.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) - .ReturnsAsync("Context2"); - manager.Add(mockExtension1.Object); - manager.Add(mockExtension2.Object); - - var messages = new List(); - - // Act - var result = await manager.OnModelInvokeAsync(messages); - - // Assert - Assert.Equal("Context1\nContext2", result); - } - - [Fact] - public async Task OnSuspendAsyncShouldCallOnSuspendOnAllExtensions() - { - // Arrange - var manager = new ConversationStateExtensionsManager(); - var mockExtension = new Mock(); - manager.Add(mockExtension.Object); - - // Act - await manager.OnSuspendAsync("test-thread-id"); - - // Assert - mockExtension.Verify(x => x.OnSuspendAsync("test-thread-id", It.IsAny()), Times.Once); - } - - [Fact] - public async Task OnResumeAsyncShouldCallOnResumeOnAllExtensions() - { - // Arrange - var manager = new ConversationStateExtensionsManager(); - var mockExtension = new Mock(); - manager.Add(mockExtension.Object); - - // Act - await manager.OnResumeAsync("test-thread-id"); - - // Assert - mockExtension.Verify(x => x.OnResumeAsync("test-thread-id", It.IsAny()), Times.Once); - } -} diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartTests.cs similarity index 51% rename from dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs rename to dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartTests.cs index 62c82c1ff8eb..4ee007b2d82e 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartTests.cs @@ -10,18 +10,18 @@ namespace SemanticKernel.UnitTests.Memory; /// -/// Contains tests for the class. +/// Contains tests for the class. /// -public class ConversationStateExtensionTests +public class ConversationStatePartTests { [Fact] public void AIFunctionsBaseImplementationIsEmpty() { // Arrange. - var mockExtension = new Mock() { CallBase = true }; + var mockPart = new Mock() { CallBase = true }; // Act. - var functions = mockExtension.Object.AIFunctions; + var functions = mockPart.Object.AIFunctions; // Assert. Assert.NotNull(functions); @@ -32,50 +32,50 @@ public void AIFunctionsBaseImplementationIsEmpty() public async Task OnThreadCreatedBaseImplementationSucceeds() { // Arrange. - var mockExtension = new Mock() { CallBase = true }; + var mockPart = new Mock() { CallBase = true }; // Act & Assert. - await mockExtension.Object.OnThreadCreatedAsync("threadId", CancellationToken.None); + await mockPart.Object.OnThreadCreatedAsync("threadId", CancellationToken.None); } [Fact] public async Task OnNewMessageBaseImplementationSucceeds() { // Arrange. - var mockExtension = new Mock() { CallBase = true }; + var mockPart = new Mock() { CallBase = true }; var newMessage = new ChatMessage(ChatRole.User, "Hello"); // Act & Assert. - await mockExtension.Object.OnNewMessageAsync("threadId", newMessage, CancellationToken.None); + await mockPart.Object.OnNewMessageAsync("threadId", newMessage, CancellationToken.None); } [Fact] public async Task OnThreadDeleteBaseImplementationSucceeds() { // Arrange. - var mockExtension = new Mock() { CallBase = true }; + var mockPart = new Mock() { CallBase = true }; // Act & Assert. - await mockExtension.Object.OnThreadDeleteAsync("threadId", CancellationToken.None); + await mockPart.Object.OnThreadDeleteAsync("threadId", CancellationToken.None); } [Fact] public async Task OnSuspendBaseImplementationSucceeds() { // Arrange. - var mockExtension = new Mock() { CallBase = true }; + var mockPart = new Mock() { CallBase = true }; // Act & Assert. - await mockExtension.Object.OnSuspendAsync("threadId", CancellationToken.None); + await mockPart.Object.OnSuspendAsync("threadId", CancellationToken.None); } [Fact] public async Task OnResumeBaseImplementationSucceeds() { // Arrange. - var mockExtension = new Mock() { CallBase = true }; + var mockPart = new Mock() { CallBase = true }; // Act & Assert. - await mockExtension.Object.OnResumeAsync("threadId", CancellationToken.None); + await mockPart.Object.OnResumeAsync("threadId", CancellationToken.None); } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerExtensionsTests.cs similarity index 61% rename from dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs rename to dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerExtensionsTests.cs index e1bf25f94443..b074b44d3b09 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStateExtensionsManagerExtensionsTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerExtensionsTests.cs @@ -12,17 +12,17 @@ namespace SemanticKernel.UnitTests.Memory; /// -/// Tests for the ConversationStateExtensionsManagerExtensions class. +/// Tests for the ConversationStatePartsManagerExtensions class. /// -public class ConversationStateExtensionsManagerExtensionsTests +public class ConversationStatePartsManagerExtensionsTests { [Fact] - public async Task OnNewMessageShouldConvertMessageAndInvokeRegisteredExtensionsAsync() + public async Task OnNewMessageShouldConvertMessageAndInvokeRegisteredPartsAsync() { // Arrange - var manager = new ConversationStateExtensionsManager(); - var extensionMock = new Mock(); - manager.Add(extensionMock.Object); + var manager = new ConversationStatePartsManager(); + var partMock = new Mock(); + manager.Add(partMock.Object); var newMessage = new ChatMessageContent(AuthorRole.User, "Test Message"); @@ -30,16 +30,16 @@ public async Task OnNewMessageShouldConvertMessageAndInvokeRegisteredExtensionsA await manager.OnNewMessageAsync("test-thread-id", newMessage); // Assert - extensionMock.Verify(x => x.OnNewMessageAsync("test-thread-id", It.Is(m => m.Text == "Test Message" && m.Role == ChatRole.User), It.IsAny()), Times.Once); + partMock.Verify(x => x.OnNewMessageAsync("test-thread-id", It.Is(m => m.Text == "Test Message" && m.Role == ChatRole.User), It.IsAny()), Times.Once); } [Fact] - public async Task OnAIInvocationShouldConvertMessagesInvokeRegisteredExtensionsAsync() + public async Task OnAIInvocationShouldConvertMessagesInvokeRegisteredPartsAsync() { // Arrange - var manager = new ConversationStateExtensionsManager(); - var extensionMock = new Mock(); - manager.Add(extensionMock.Object); + var manager = new ConversationStatePartsManager(); + var partMock = new Mock(); + manager.Add(partMock.Object); var messages = new List { @@ -47,7 +47,7 @@ public async Task OnAIInvocationShouldConvertMessagesInvokeRegisteredExtensionsA new(AuthorRole.Assistant, "Message 2") }; - extensionMock + partMock .Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) .ReturnsAsync("Combined Context"); @@ -56,7 +56,7 @@ public async Task OnAIInvocationShouldConvertMessagesInvokeRegisteredExtensionsA // Assert Assert.Equal("Combined Context", result); - extensionMock.Verify(x => x.OnModelInvokeAsync(It.Is>(m => m.Count == 2), It.IsAny()), Times.Once); + partMock.Verify(x => x.OnModelInvokeAsync(It.Is>(m => m.Count == 2), It.IsAny()), Times.Once); } [Fact] @@ -64,13 +64,13 @@ public void RegisterPluginsShouldConvertAIFunctionsAndRegisterAsPlugins() { // Arrange var kernel = new Kernel(); - var manager = new ConversationStateExtensionsManager(); - var extensionMock = new Mock(); + var manager = new ConversationStatePartsManager(); + var partMock = new Mock(); var aiFunctionMock = AIFunctionFactory.Create(() => "Hello", "TestFunction"); - extensionMock + partMock .Setup(x => x.AIFunctions) .Returns(new List { aiFunctionMock }); - manager.Add(extensionMock.Object); + manager.Add(partMock.Object); // Act manager.RegisterPlugins(kernel); diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerTests.cs new file mode 100644 index 000000000000..e76b13eda2e2 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerTests.cs @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Contains tests for the class. +/// +public class ConversationStatePartsManagerTests +{ + [Fact] + public void ConstructorShouldInitializeEmptyPartsList() + { + // Act + var manager = new ConversationStatePartsManager(); + + // Assert + Assert.NotNull(manager.Parts); + Assert.Empty(manager.Parts); + } + + [Fact] + public void ConstructorShouldInitializeWithProvidedParts() + { + // Arrange + var mockPart = new Mock(); + + // Act + var manager = new ConversationStatePartsManager(new[] { mockPart.Object }); + + // Assert + Assert.Single(manager.Parts); + Assert.Contains(mockPart.Object, manager.Parts); + } + + [Fact] + public void AddShouldRegisterNewPart() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + + // Act + manager.Add(mockPart.Object); + + // Assert + Assert.Single(manager.Parts); + Assert.Contains(mockPart.Object, manager.Parts); + } + + [Fact] + public void AddFromServiceProviderShouldRegisterPartsFromServiceProvider() + { + // Arrange + var serviceCollection = new ServiceCollection(); + var mockPart = new Mock(); + serviceCollection.AddSingleton(mockPart.Object); + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var manager = new ConversationStatePartsManager(); + + // Act + manager.AddFromServiceProvider(serviceProvider); + + // Assert + Assert.Single(manager.Parts); + Assert.Contains(mockPart.Object, manager.Parts); + } + + [Fact] + public async Task OnThreadCreatedAsyncShouldCallOnThreadCreatedOnAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + manager.Add(mockPart.Object); + + // Act + await manager.OnThreadCreatedAsync("test-thread-id"); + + // Assert + mockPart.Verify(x => x.OnThreadCreatedAsync("test-thread-id", It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnThreadDeleteAsyncShouldCallOnThreadDeleteOnAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + manager.Add(mockPart.Object); + + // Act + await manager.OnThreadDeleteAsync("test-thread-id"); + + // Assert + mockPart.Verify(x => x.OnThreadDeleteAsync("test-thread-id", It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnNewMessageAsyncShouldCallOnNewMessageOnAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + var message = new ChatMessage(ChatRole.User, "Hello"); + manager.Add(mockPart.Object); + + // Act + await manager.OnNewMessageAsync("test-thread-id", message); + + // Assert + mockPart.Verify(x => x.OnNewMessageAsync("test-thread-id", message, It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnAIInvocationAsyncShouldAggregateContextsFromAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart1 = new Mock(); + var mockPart2 = new Mock(); + mockPart1.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync("Context1"); + mockPart2.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync("Context2"); + manager.Add(mockPart1.Object); + manager.Add(mockPart2.Object); + + var messages = new List(); + + // Act + var result = await manager.OnModelInvokeAsync(messages); + + // Assert + Assert.Equal("Context1\nContext2", result); + } + + [Fact] + public async Task OnSuspendAsyncShouldCallOnSuspendOnAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + manager.Add(mockPart.Object); + + // Act + await manager.OnSuspendAsync("test-thread-id"); + + // Assert + mockPart.Verify(x => x.OnSuspendAsync("test-thread-id", It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnResumeAsyncShouldCallOnResumeOnAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + manager.Add(mockPart.Object); + + // Act + await manager.OnResumeAsync("test-thread-id"); + + // Assert + mockPart.Verify(x => x.OnResumeAsync("test-thread-id", It.IsAny()), Times.Once); + } +} From 49e8ec219cebcc80b4b917aa3ac2da5afed682fb Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 23 Apr 2025 14:34:11 +0100 Subject: [PATCH 31/46] Update integration tests to be general purpose --- .../AgentFixture.cs | 2 + .../AgentWithMemoryTests.cs | 28 -- .../ChatCompletionAgentWithMemoryTests.cs | 336 ------------------ .../AgentWithStatePartTests.cs | 96 +++++ .../AzureAIAgentWithStatePartTests.cs} | 4 +- .../ChatCompletionAgentWithStatePartTests.cs | 7 + ...nAIAssistantAgentWithStatePartTests.cs.cs} | 4 +- .../AzureAIAgentFixture.cs | 5 + .../BedrockAgentFixture.cs | 5 + .../ChatCompletionAgentFixture.cs | 5 + .../OpenAIAssistantAgentFixture.cs | 5 + 11 files changed, 129 insertions(+), 368 deletions(-) delete mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AgentWithMemoryTests.cs delete mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs create mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs rename dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/{AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs => AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs} (91%) create mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/ChatCompletionAgentWithStatePartTests.cs rename dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/{AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs => AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs.cs} (88%) diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs index 8be11475493c..89ea56fa4e27 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs @@ -23,6 +23,8 @@ public abstract class AgentFixture : IAsyncLifetime public abstract AgentThread CreatedServiceFailingAgentThread { get; } + public abstract AgentThread GetNewThread(); + public abstract Task GetChatHistory(); public abstract Task DeleteThread(AgentThread thread); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AgentWithMemoryTests.cs deleted file mode 100644 index 6327a64ee307..000000000000 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AgentWithMemoryTests.cs +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Threading.Tasks; -using Xunit; - -namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithMemoryConformance; - -public abstract class AgentWithMemoryTests(Func createAgentFixture) : IAsyncLifetime - where TFixture : AgentFixture -{ -#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. - private TFixture _agentFixture; -#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. - - protected TFixture Fixture => this._agentFixture; - - public Task InitializeAsync() - { - this._agentFixture = createAgentFixture(); - return this._agentFixture.InitializeAsync(); - } - - public Task DisposeAsync() - { - return this._agentFixture.DisposeAsync(); - } -} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs deleted file mode 100644 index 1d373b237e81..000000000000 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/ChatCompletionAgentWithMemoryTests.cs +++ /dev/null @@ -1,336 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Net.Http; -using System.Net.Http.Headers; -using System.Threading.Tasks; -using Azure.Identity; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Hosting; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Agents; -using Microsoft.SemanticKernel.Agents.Memory; -using Microsoft.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel.Connectors.AzureOpenAI; -using Microsoft.SemanticKernel.Connectors.InMemory; -using Microsoft.SemanticKernel.Embeddings; -using Microsoft.SemanticKernel.Memory; -using Microsoft.SemanticKernel.Memory.TextRag; -using SemanticKernel.IntegrationTests.TestSettings; -using Xunit; - -namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithMemoryConformance; - -public class ChatCompletionAgentWithMemoryTests() : AgentWithMemoryTests(() => new ChatCompletionAgentFixture()) -{ - private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() - .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) - .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) - .AddEnvironmentVariables() - .AddUserSecrets() - .Build(); - - [Fact(Skip = "For manual verification")] - public virtual async Task Mem0ComponentCapturesMemoriesFromUserInputAsync() - { - // Arrange - var agent = this.Fixture.Agent; - - using var httpClient = new HttpClient(); - httpClient.BaseAddress = new Uri("https://api.mem0.ai"); - httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", ""); - - var mem0Component = new Mem0MemoryComponent(httpClient, new() { UserId = "U1" }); - - var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.StateParts.Add(mem0Component); - - var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.StateParts.Add(mem0Component); - - // Act - var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); - var results1 = await asyncResults1.ToListAsync(); - - var asyncResults2 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "What is my name?."), agentThread2); - var results2 = await asyncResults2.ToListAsync(); - - // Assert - Assert.Contains("Caoimhe", results2.First().Message.Content); - - // Cleanup - await this.Fixture.DeleteThread(agentThread1); - await this.Fixture.DeleteThread(agentThread2); - } - - [Fact] - public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() - { - // Arrange - var agent = this.Fixture.Agent; - var memoryComponent = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel); - - var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.StateParts.Add(memoryComponent); - - var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.StateParts.Add(memoryComponent); - - // Act - var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); - var results1 = await asyncResults1.ToListAsync(); - - var asyncResults2 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "What is my name?."), agentThread2); - var results2 = await asyncResults2.ToListAsync(); - - // Assert - Assert.Contains("Caoimhe", results2.First().Message.Content); - - // Cleanup - await this.Fixture.DeleteThread(agentThread1); - await this.Fixture.DeleteThread(agentThread2); - } - - [Fact] - public virtual async Task MemoryComponentCapturesMemoriesInVectorStoreFromUserInputAsync() - { - // Arrange - var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); - - var vectorStore = new InMemoryVectorStore(); - var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); - using var textMemoryStore = new VectorDataTextMemoryStore(vectorStore, textEmbeddingService, "Memories", "user/12345", 1536); - - var agent = this.Fixture.Agent; - - // Act - First invocation with first thread. - var agentThread1 = new ChatHistoryAgentThread(); - agentThread1.StateParts.Add(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); - - var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); - var results1 = await asyncResults1.ToListAsync(); - - // Act - Second invocation with second thread. - var agentThread2 = new ChatHistoryAgentThread(); - agentThread2.StateParts.Add(new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore)); - - var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); - var results2 = await asyncResults2.ToListAsync(); - - // Assert - Assert.Contains("Caoimhe", results2.First().Message.Content); - - // Cleanup - await this.Fixture.DeleteThread(agentThread1); - await this.Fixture.DeleteThread(agentThread2); - } - - [Fact] - public virtual async Task CapturesMemoriesWhileUsingDIAsync() - { - var chatConfig = this._configuration.GetSection("AzureOpenAI").Get()!; - var embeddingConfig = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); - - // Arrange - Setup DI container. - var builder = Host.CreateEmptyApplicationBuilder(settings: null); - builder.Services.AddKernel(); - builder.Services.AddInMemoryVectorStore(); - builder.Services.AddAzureOpenAIChatCompletion( - deploymentName: chatConfig.ChatDeploymentName!, - endpoint: chatConfig.Endpoint, - credentials: new AzureCliCredential()); - builder.Services.AddAzureOpenAITextEmbeddingGeneration( - embeddingConfig!.EmbeddingModelId, - embeddingConfig.Endpoint, - new AzureCliCredential()); - builder.Services.AddKeyedTransient>("UserFactsStore", (sp, _) => new VectorDataTextMemoryStore( - sp.GetRequiredService(), - sp.GetRequiredService(), - "Memories", "user/12345", 1536)); - builder.Services.AddTransient(); - builder.Services.AddTransient((sp) => - { - var thread = new ChatHistoryAgentThread(); - thread.StateParts.AddFromServiceProvider(sp); - return thread; - }); - var host = builder.Build(); - - // Arrange - Create agent. - var agent = new ChatCompletionAgent() - { - Kernel = host.Services.GetRequiredService(), - Instructions = "You are a helpful assistant.", - }; - - // Act - First invocation - var agentThread1 = host.Services.GetRequiredService(); - - var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); - var results1 = await asyncResults1.ToListAsync(); - - // Act - Call suspend on the thread, so that all memory components attached to it, save their state. - await agentThread1.OnSuspendAsync(default); - - // Act - Second invocation - var agentThread2 = host.Services.GetRequiredService(); - - var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); - var results2 = await asyncResults2.ToListAsync(); - - // Assert - Assert.Contains("Caoimhe", results2.First().Message.Content); - - // Cleanup - await this.Fixture.DeleteThread(agentThread1); - await this.Fixture.DeleteThread(agentThread2); - } - - [Fact] - public virtual async Task RagComponentWithoutMatchesAsync() - { - // Arrange - Create Embedding Service - var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); - var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); - - // Arrange - Create Vector Store and Rag Store/Component - var vectorStore = new InMemoryVectorStore(); - using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "FinancialData", 1536, "group/g1"); - var ragComponent = new TextRagComponent(ragStore, new TextRagComponentOptions()); - - // Arrange - Upsert documents into the Rag Store - await ragStore.UpsertDocumentsAsync(GetSampleDocuments()); - - var agent = this.Fixture.Agent; - - // Act - Create a new agent thread and register the Rag component - var agentThread = new ChatHistoryAgentThread(); - agentThread.StateParts.Add(ragComponent); - - // Act - Invoke the agent with a question - var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); - var results1 = await asyncResults1.ToListAsync(); - - // Assert - Check if the response does not contain the expected value from the database because - // we filtered by group/g1 which doesn't include the required document. - Assert.DoesNotContain("174", results1.First().Message.Content); - - // Cleanup - await this.Fixture.DeleteThread(agentThread); - } - - [Fact] - public virtual async Task RagComponentWithMatchesAsync() - { - // Arrange - Create Embedding Service - var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); - var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); - - // Arrange - Create Vector Store and Rag Store/Component - var vectorStore = new InMemoryVectorStore(); - using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "FinancialData", 1536, "group/g2"); - var ragComponent = new TextRagComponent(ragStore, new TextRagComponentOptions()); - - // Arrange - Upsert documents into the Rag Store - await ragStore.UpsertDocumentsAsync(GetSampleDocuments()); - - var agent = this.Fixture.Agent; - - // Act - Create a new agent thread and register the Rag component - var agentThread = new ChatHistoryAgentThread(); - agentThread.StateParts.Add(ragComponent); - - // Act - Invoke the agent with a question - var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); - var results1 = await asyncResults1.ToListAsync(); - - // Assert - Check if the response contains the expected value from the database. - Assert.Contains("174", results1.First().Message.Content); - - // Cleanup - await this.Fixture.DeleteThread(agentThread); - } - - [Fact] - public virtual async Task RagComponentWithMatchesOnDemandAsync() - { - // Arrange - Create Embedding Service - var config = this._configuration.GetRequiredSection("AzureOpenAIEmbeddings").Get(); - var textEmbeddingService = new AzureOpenAITextEmbeddingGenerationService(config!.EmbeddingModelId, config.Endpoint, new AzureCliCredential()); - - // Arrange - Create Vector Store and Rag Store/Component - var vectorStore = new InMemoryVectorStore(); - using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "FinancialData", 1536, "group/g2"); - var ragComponent = new TextRagComponent( - ragStore, - new() - { - SearchTime = TextRagComponentOptions.TextRagSearchTime.ViaPlugin, - PluginSearchFunctionName = "SearchCorporateData", - PluginSearchFunctionDescription = "RAG Search over dataset containing financial data and company information about various companies." - }); - - // Arrange - Upsert documents into the Rag Store - await ragStore.UpsertDocumentsAsync(GetSampleDocuments()); - - var agent = this.Fixture.Agent; - - // Act - Create a new agent thread and register the Rag component - var agentThread = new ChatHistoryAgentThread(); - agentThread.StateParts.Add(ragComponent); - - // Act - Invoke the agent with a question - var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread, new() { KernelArguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }) }); - var results1 = await asyncResults1.ToListAsync(); - - // Assert - Check if the response contains the expected value from the database. - Assert.Contains("174", results1.First().Message.Content); - - // Cleanup - await this.Fixture.DeleteThread(agentThread); - } - - private static IEnumerable GetSampleDocuments() - { - yield return new TextRagDocument("The financial results of Contoso Corp for 2024 is as follows:\nIncome EUR 154 000 000\nExpenses EUR 142 000 000") - { - SourceName = "Contoso 2024 Financial Report", - SourceReference = "https://www.consoso.com/reports/2024.pdf", - Namespaces = ["group/g1"] - }; - yield return new TextRagDocument("The financial results of Contoso Corp for 2023 is as follows:\nIncome EUR 174 000 000\nExpenses EUR 152 000 000") - { - SourceName = "Contoso 2023 Financial Report", - SourceReference = "https://www.consoso.com/reports/2023.pdf", - Namespaces = ["group/g2"] - }; - yield return new TextRagDocument("The financial results of Contoso Corp for 2022 is as follows:\nIncome EUR 184 000 000\nExpenses EUR 162 000 000") - { - SourceName = "Contoso 2022 Financial Report", - SourceReference = "https://www.consoso.com/reports/2022.pdf", - Namespaces = ["group/g2"] - }; - yield return new TextRagDocument("The Contoso Corporation is a multinational business with its headquarters in Paris. The company is a manufacturing, sales, and support organization with more than 100,000 products.") - { - SourceName = "About Contoso", - SourceReference = "https://www.consoso.com/about-us", - Namespaces = ["group/g2"] - }; - yield return new TextRagDocument("The financial results of AdventureWorks for 2021 is as follows:\nIncome USD 223 000 000\nExpenses USD 210 000 000") - { - SourceName = "AdventureWorks 2021 Financial Report", - SourceReference = "https://www.adventure-works.com/reports/2021.pdf", - Namespaces = ["group/g1", "group/g2"] - }; - yield return new TextRagDocument("AdventureWorks is a large American business that specializaes in adventure parks and family entertainment.") - { - SourceName = "About AdventureWorks", - SourceReference = "https://www.adventure-works.com/about-us", - Namespaces = ["group/g1", "group/g2"] - }; - } -} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs new file mode 100644 index 000000000000..c2934ebfc9a0 --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Moq; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; + +public abstract class AgentWithStatePartTests(Func createAgentFixture) : IAsyncLifetime + where TFixture : AgentFixture +{ +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + private TFixture _agentFixture; +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + + protected TFixture Fixture => this._agentFixture; + + [Fact] + public virtual async Task StatePartReceivesMessagesFromAgentAsync() + { + // Arrange + var mockStatePart = new Mock() { CallBase = true }; + mockStatePart.Setup(x => x.OnNewMessageAsync(It.IsAny(), It.IsAny(), It.IsAny())); + + var agent = this.Fixture.Agent; + + var agentThread = this.Fixture.GetNewThread(); + + try + { + agentThread.StateParts.Add(mockStatePart.Object); + + // Act + var inputMessage = "What is the capital of France?"; + var asyncResults1 = agent.InvokeAsync(inputMessage, agentThread); + var result = await asyncResults1.FirstAsync(); + + // Assert + Assert.Contains("Paris", result.Message.Content); + mockStatePart.Verify(x => x.OnNewMessageAsync(It.IsAny(), It.Is(cm => cm.Text == inputMessage), It.IsAny()), Times.Once); + mockStatePart.Verify(x => x.OnNewMessageAsync(It.IsAny(), It.Is(cm => cm.Text == result.Message.Content), It.IsAny()), Times.Once); + } + finally + { + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + } + + [Fact] + public virtual async Task StatePartPreInvokeStateIsUsedByAgentAsync() + { + // Arrange + var mockStatePart = new Mock() { CallBase = true }; + mockStatePart.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())).ReturnsAsync("User name is Caoimhe"); + + var agent = this.Fixture.Agent; + + var agentThread = this.Fixture.GetNewThread(); + + try + { + agentThread.StateParts.Add(mockStatePart.Object); + + // Act + var inputMessage = "What is my name?."; + var asyncResults1 = agent.InvokeAsync(inputMessage, agentThread); + var result = await asyncResults1.FirstAsync(); + + // Assert + Assert.Contains("Caoimhe", result.Message.Content); + } + finally + { + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + } + + public Task InitializeAsync() + { + this._agentFixture = createAgentFixture(); + return this._agentFixture.InitializeAsync(); + } + + public Task DisposeAsync() + { + return this._agentFixture.DisposeAsync(); + } +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs similarity index 91% rename from dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs rename to dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs index f4ed887da87f..32ba5519673c 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/AzureAIAgentWithMemoryTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs @@ -11,9 +11,9 @@ using Microsoft.SemanticKernel.Memory; using Xunit; -namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithMemoryConformance; +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; -public class AzureAIAgentWithMemoryTests() : AgentWithMemoryTests(() => new AzureAIAgentFixture()) +public class AzureAIAgentWithStatePartTests() : AgentWithStatePartTests(() => new AzureAIAgentFixture()) { [Fact(Skip = "For manual verification")] public virtual async Task Mem0ComponentCapturesMemoriesFromUserInputAsync() diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/ChatCompletionAgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/ChatCompletionAgentWithStatePartTests.cs new file mode 100644 index 000000000000..838043376c9b --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/ChatCompletionAgentWithStatePartTests.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; + +public class ChatCompletionAgentWithStatePartTests() : AgentWithStatePartTests(() => new ChatCompletionAgentFixture()) +{ +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs.cs similarity index 88% rename from dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs rename to dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs.cs index b3ece08315c5..cd683f597122 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithMemoryConformance/OpenAIAssistantAgentWithMemoryTests.cs.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs.cs @@ -8,9 +8,9 @@ using Microsoft.SemanticKernel.ChatCompletion; using Xunit; -namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithMemoryConformance; +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; -public class OpenAIAssistantAgentWithMemoryTests() : AgentWithMemoryTests(() => new OpenAIAssistantAgentFixture()) +public class OpenAIAssistantAgentWithMemoryTests() : AgentWithStatePartTests(() => new OpenAIAssistantAgentFixture()) { [Fact] public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs index f2017815ffc7..4ceb4d0605ef 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs @@ -42,6 +42,11 @@ public class AzureAIAgentFixture : AgentFixture public override AgentThread CreatedServiceFailingAgentThread => this._createdServiceFailingAgentThread!; + public override AgentThread GetNewThread() + { + return new AzureAIAgentThread(this._agentsClient!); + } + public override async Task GetChatHistory() { var chatHistory = new ChatHistory(); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs index 1dd85e085cce..957b8f76b283 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs @@ -47,6 +47,11 @@ internal sealed class BedrockAgentFixture : AgentFixture, IAsyncDisposable public override AgentThread CreatedServiceFailingAgentThread => this._createdServiceFailingAgentThread!; + public override AgentThread GetNewThread() + { + return new BedrockAgentThread(this._runtimeClient); + } + public override async Task DeleteThread(AgentThread thread) { await this._runtimeClient!.EndSessionAsync(new EndSessionRequest() { SessionIdentifier = thread.Id }); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs index 53734e986a90..ea1f9c3d2af4 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs @@ -36,6 +36,11 @@ public class ChatCompletionAgentFixture : AgentFixture public override AgentThread CreatedServiceFailingAgentThread => null!; + public override AgentThread GetNewThread() + { + return new ChatHistoryAgentThread(); + } + public override async Task GetChatHistory() { var chatHistory = new ChatHistory(); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs index 53d4cc61e972..2cf64795d79d 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs @@ -46,6 +46,11 @@ public class OpenAIAssistantAgentFixture : AgentFixture public override AgentThread CreatedServiceFailingAgentThread => this._createdServiceFailingAgentThread!; + public override AgentThread GetNewThread() + { + return new OpenAIAssistantAgentThread(this._assistantClient!); + } + public override async Task GetChatHistory() { var chatHistory = new ChatHistory(); From 46e9959c9e3bc94f96623a4e4eeb0bc5a842d7dc Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 23 Apr 2025 15:19:18 +0100 Subject: [PATCH 32/46] use string.concat for instructions concatenation. --- dotnet/src/Agents/Core/ChatCompletionAgent.cs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index be5472c45429..af9a1f20a286 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -101,7 +101,9 @@ public override async IAsyncEnumerable> In }, options?.KernelArguments, kernel, - options?.AdditionalInstructions == null ? extensionsContext : options.AdditionalInstructions + Environment.NewLine + Environment.NewLine + extensionsContext, + options?.AdditionalInstructions == null ? + extensionsContext : + string.Concat(options.AdditionalInstructions, Environment.NewLine, Environment.NewLine, extensionsContext), cancellationToken); // Notify the thread of new messages and return them to the caller. @@ -193,7 +195,9 @@ public override async IAsyncEnumerable Date: Wed, 23 Apr 2025 17:02:08 +0100 Subject: [PATCH 33/46] Add mem0 tests and fix bugs. --- .../Memory/Mem0MemoryComponentTests.cs | 87 +++++++++++++++++++ .../TestSettings/Mem0Configuration.cs | 13 +++ dotnet/src/IntegrationTests/testsettings.json | 4 + .../Memory/Mem0/Mem0Client.cs | 4 +- .../Memory/Mem0/Mem0MemoryComponent.cs | 16 ++-- 5 files changed, 115 insertions(+), 9 deletions(-) create mode 100644 dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs create mode 100644 dotnet/src/IntegrationTests/TestSettings/Mem0Configuration.cs diff --git a/dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs b/dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs new file mode 100644 index 000000000000..9213769fb07b --- /dev/null +++ b/dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel.Memory; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Memory; + +/// +/// Contains tests for the class. +/// +public class Mem0MemoryComponentTests : IDisposable +{ + // If null, all tests will be enabled + private const string SkipReason = "Requires a Mem0 service configured"; + + private readonly Mem0MemoryComponent _sut; + private readonly HttpClient _httpClient; + private bool _disposedValue; + + public Mem0MemoryComponentTests() + { + IConfigurationRoot configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + var mem0Settings = configuration.GetRequiredSection("Mem0").Get()!; + + this._httpClient = new HttpClient(); + this._httpClient.BaseAddress = new Uri(mem0Settings.ServiceUri); + this._httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", mem0Settings.ApiKey); + this._sut = new Mem0MemoryComponent(this._httpClient, new() { ThreadId = "test-thread-id", UserId = "test-user-id", ScopeToThread = true }); + } + + [Fact(Skip = SkipReason)] + public async Task Mem0ComponentCanAddAndRetrieveMemoriesAsync() + { + // Arrange + var question = new ChatMessage(ChatRole.User, "What is my name?"); + var input = new ChatMessage(ChatRole.User, "Hello, my name is Caoimhe."); + + await this._sut.ClearStoredUserFactsAsync(); + var answerBeforeAdding = await this._sut.OnModelInvokeAsync([question]); + Assert.DoesNotContain("Caoimhe", answerBeforeAdding); + + // Act + await this._sut.OnNewMessageAsync("test-thread-id", input); + + await this._sut.OnNewMessageAsync("test-thread-id", question); + var answerAfterAdding = await this._sut.OnModelInvokeAsync([question]); + + await this._sut.ClearStoredUserFactsAsync(); + var answerAfterClearing = await this._sut.OnModelInvokeAsync([question]); + + // Assert + Assert.Contains("Caoimhe", answerAfterAdding); + Assert.DoesNotContain("Caoimhe", answerAfterClearing); + } + + protected virtual void Dispose(bool disposing) + { + if (!this._disposedValue) + { + if (disposing) + { + this._httpClient.Dispose(); + } + + this._disposedValue = true; + } + } + + public void Dispose() + { + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } +} diff --git a/dotnet/src/IntegrationTests/TestSettings/Mem0Configuration.cs b/dotnet/src/IntegrationTests/TestSettings/Mem0Configuration.cs new file mode 100644 index 000000000000..a89f68e45c0a --- /dev/null +++ b/dotnet/src/IntegrationTests/TestSettings/Mem0Configuration.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace SemanticKernel.IntegrationTests.TestSettings; + +[SuppressMessage("Performance", "CA1812:Internal class that is apparently never instantiated", + Justification = "Configuration classes are instantiated through IConfiguration.")] +internal sealed class Mem0Configuration +{ + public string ServiceUri { get; init; } = string.Empty; + public string ApiKey { get; init; } = string.Empty; +} diff --git a/dotnet/src/IntegrationTests/testsettings.json b/dotnet/src/IntegrationTests/testsettings.json index ef269839e1c3..287fe122e2ad 100644 --- a/dotnet/src/IntegrationTests/testsettings.json +++ b/dotnet/src/IntegrationTests/testsettings.json @@ -126,5 +126,9 @@ "BedrockAgent": { "AgentResourceRoleArn": "", "FoundationModel": "anthropic.claude-3-haiku-20240307-v1:0" + }, + "Mem0": { + "ServiceUri": "https://api.mem0.ai", + "ApiKey": "" } } \ No newline at end of file diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs index 096f293e409d..285b76a45c58 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs @@ -94,11 +94,11 @@ public async Task ClearMemoryAsync(string? applicationId, string? agentId, strin string[] paramNames = ["app_id", "agent_id", "run_id", "user_id"]; // Build query string. - var querystringParams = new string?[4] { applicationId, userId, agentId, threadId } + var querystringParams = new string?[4] { applicationId, agentId, threadId, userId } .Select((param, index) => string.IsNullOrWhiteSpace(param) ? null : $"{paramNames[index]}={param}") .Where(x => x != null); var queryString = string.Join("&", querystringParams); - var clearMemoryUrl = new Uri($"/v1/memories?{queryString}", UriKind.Relative); + var clearMemoryUrl = new Uri($"/v1/memories/?{queryString}", UriKind.Relative); // Delete. var responseMessage = await this._httpClient.DeleteAsync(clearMemoryUrl).ConfigureAwait(false); diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs index 58fbdf19d73f..62fc85239106 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics.CodeAnalysis; @@ -34,8 +35,8 @@ public class Mem0MemoryComponent : ConversationStatePart /// The HTTP client used for making requests. /// Options for configuring the component. /// - /// The base address of the required mem0 service and any authentication headers should be set on the - /// already when provided here. E.g.: + /// The base address of the required mem0 service, and any authentication headers, should be set on the + /// already, when passed as a parameter here. E.g.: /// /// using var httpClient = new HttpClient(); /// httpClient.BaseAddress = new Uri("https://api.mem0.ai"); @@ -72,6 +73,7 @@ public override Task OnThreadCreatedAsync(string? threadId, CancellationToken ca public override async Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) { Verify.NotNull(newMessage); + this._threadId ??= threadId; if (newMessage.Role == ChatRole.User && !string.IsNullOrWhiteSpace(newMessage.Text)) { @@ -91,7 +93,7 @@ public override async Task OnModelInvokeAsync(ICollection n Verify.NotNull(newMessages); string inputText = string.Join( - "\n", + Environment.NewLine, newMessages. Where(m => m is not null && !string.IsNullOrWhiteSpace(m.Text)). Select(m => m.Text)); @@ -103,8 +105,8 @@ public override async Task OnModelInvokeAsync(ICollection n this._userId, inputText).ConfigureAwait(false); - var userInformation = string.Join("\n", memories); - return "The following list contains facts about the user:\n" + userInformation; + var userInformation = string.Join(Environment.NewLine, memories); + return string.Join(Environment.NewLine, "The following list contains facts about the user:", userInformation); } /// @@ -116,8 +118,8 @@ public async Task ClearStoredUserFactsAsync() { await this._mem0Client.ClearMemoryAsync( this._applicationId, - this._userId, this._agentId, - this._scopeToThread ? this._threadId : null).ConfigureAwait(false); + this._scopeToThread ? this._threadId : null, + this._userId).ConfigureAwait(false); } } From 41467eddf8917b8945b5112cc6411c67c96e4c75 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 23 Apr 2025 17:39:56 +0100 Subject: [PATCH 34/46] Add exclude from coverage for mem0 component --- dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs | 2 ++ .../src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs | 1 + .../Memory/Mem0/Mem0MemoryComponentOptions.cs | 1 + 3 files changed, 4 insertions(+) diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs index 285b76a45c58..1dfaa5a72d76 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net.Http; using System.Text; @@ -14,6 +15,7 @@ namespace Microsoft.SemanticKernel.Memory; /// /// Client for the Mem0 memory service. /// +[ExcludeFromCodeCoverage] // Tested via integration tests. internal sealed class Mem0Client { private static readonly Uri s_searchUri = new("/v1/memories/search/", UriKind.Relative); diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs index 62fc85239106..bb95d89648aa 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs @@ -17,6 +17,7 @@ namespace Microsoft.SemanticKernel.Memory; /// information about the user. It is also able to retrieve this information and add it to the AI invocation context. /// [Experimental("SKEXP0130")] +[ExcludeFromCodeCoverage] // Tested via integration tests. public class Mem0MemoryComponent : ConversationStatePart { private readonly string? _applicationId; diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs index 137857eb15aa..972db66c667d 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs @@ -8,6 +8,7 @@ namespace Microsoft.SemanticKernel.Memory; /// Options for the . /// [Experimental("SKEXP0130")] +[ExcludeFromCodeCoverage] // Tested via integration tests. public class Mem0MemoryComponentOptions { /// From 9a7990eaf72412c60f1481ffdf0770ad6721aa24 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:56:42 +0100 Subject: [PATCH 35/46] Add more tests for mem0 and improve thread handling. --- .../Memory/Mem0MemoryComponentTests.cs | 73 ++++++++++++++++--- .../Memory/Mem0/Mem0Client.cs | 8 +- .../Memory/Mem0/Mem0MemoryComponent.cs | 56 +++++++++++--- .../Memory/Mem0/Mem0MemoryComponentOptions.cs | 12 ++- 4 files changed, 122 insertions(+), 27 deletions(-) diff --git a/dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs b/dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs index 9213769fb07b..c1f4c90cdfdf 100644 --- a/dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs +++ b/dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs @@ -20,7 +20,6 @@ public class Mem0MemoryComponentTests : IDisposable // If null, all tests will be enabled private const string SkipReason = "Requires a Mem0 service configured"; - private readonly Mem0MemoryComponent _sut; private readonly HttpClient _httpClient; private bool _disposedValue; @@ -38,7 +37,6 @@ public Mem0MemoryComponentTests() this._httpClient = new HttpClient(); this._httpClient.BaseAddress = new Uri(mem0Settings.ServiceUri); this._httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", mem0Settings.ApiKey); - this._sut = new Mem0MemoryComponent(this._httpClient, new() { ThreadId = "test-thread-id", UserId = "test-user-id", ScopeToThread = true }); } [Fact(Skip = SkipReason)] @@ -48,24 +46,81 @@ public async Task Mem0ComponentCanAddAndRetrieveMemoriesAsync() var question = new ChatMessage(ChatRole.User, "What is my name?"); var input = new ChatMessage(ChatRole.User, "Hello, my name is Caoimhe."); - await this._sut.ClearStoredUserFactsAsync(); - var answerBeforeAdding = await this._sut.OnModelInvokeAsync([question]); + var sut = new Mem0MemoryComponent(this._httpClient, new() { ThreadId = "test-thread-id", UserId = "test-user-id", ScopeToPerOperationThreadId = true }); + + await sut.ClearStoredUserFactsAsync(); + var answerBeforeAdding = await sut.OnModelInvokeAsync([question]); Assert.DoesNotContain("Caoimhe", answerBeforeAdding); // Act - await this._sut.OnNewMessageAsync("test-thread-id", input); + await sut.OnNewMessageAsync("test-thread-id", input); - await this._sut.OnNewMessageAsync("test-thread-id", question); - var answerAfterAdding = await this._sut.OnModelInvokeAsync([question]); + await sut.OnNewMessageAsync("test-thread-id", question); + var answerAfterAdding = await sut.OnModelInvokeAsync([question]); - await this._sut.ClearStoredUserFactsAsync(); - var answerAfterClearing = await this._sut.OnModelInvokeAsync([question]); + await sut.ClearStoredUserFactsAsync(); + var answerAfterClearing = await sut.OnModelInvokeAsync([question]); // Assert Assert.Contains("Caoimhe", answerAfterAdding); Assert.DoesNotContain("Caoimhe", answerAfterClearing); } + [Fact(Skip = SkipReason)] + public async Task Mem0ComponentDoesNotLeakMessagesAcrossScopesAsync() + { + // Arrange + var question = new ChatMessage(ChatRole.User, "What is your name?"); + var input = new ChatMessage(ChatRole.Assistant, "I'm an AI tutor with a personality. My name is Caoimhe."); + + var sut1 = new Mem0MemoryComponent(this._httpClient, new() { AgentId = "test-agent-id-1" }); + var sut2 = new Mem0MemoryComponent(this._httpClient, new() { AgentId = "test-agent-id-2" }); + + await sut1.ClearStoredUserFactsAsync(); + await sut2.ClearStoredUserFactsAsync(); + + var answerBeforeAdding1 = await sut1.OnModelInvokeAsync([question]); + var answerBeforeAdding2 = await sut2.OnModelInvokeAsync([question]); + Assert.DoesNotContain("Caoimhe", answerBeforeAdding1); + Assert.DoesNotContain("Caoimhe", answerBeforeAdding2); + + // Act + await sut1.OnNewMessageAsync("test-thread-id-1", input); + var answerAfterAdding = await sut1.OnModelInvokeAsync([question]); + + await sut2.OnNewMessageAsync("test-thread-id-2", question); + var answerAfterAddingOnOtherScope = await sut2.OnModelInvokeAsync([question]); + + // Assert + Assert.Contains("Caoimhe", answerAfterAdding); + Assert.DoesNotContain("Caoimhe", answerAfterAddingOnOtherScope); + + // Cleanup. + await sut1.ClearStoredUserFactsAsync(); + await sut2.ClearStoredUserFactsAsync(); + } + + [Fact(Skip = SkipReason)] + public async Task Mem0ComponentDoesNotWorkWithMultiplePerOperationThreadsAsync() + { + // Arrange + var input = new ChatMessage(ChatRole.User, "Hello, my name is Caoimhe."); + + var sut = new Mem0MemoryComponent(this._httpClient, new() { UserId = "test-user-id", ScopeToPerOperationThreadId = true }); + + await sut.ClearStoredUserFactsAsync(); + + // Act & Assert + await sut.OnThreadCreatedAsync("test-thread-id-1"); + await Assert.ThrowsAsync(() => sut.OnThreadCreatedAsync("test-thread-id-2")); + + await sut.OnNewMessageAsync("test-thread-id-1", input); + await Assert.ThrowsAsync(() => sut.OnNewMessageAsync("test-thread-id-2", input)); + + // Cleanup + await sut.ClearStoredUserFactsAsync(); + } + protected virtual void Dispose(bool disposing) { if (!this._disposedValue) diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs index 1dfaa5a72d76..6b5d797eaf18 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs @@ -78,7 +78,7 @@ public async Task CreateMemoryAsync(string? applicationId, string? agentId, stri UserId = userId, Messages = new[] { - new CreateMemoryMemory + new CreateMemoryMessage { Content = messageContent, Role = messageRole @@ -118,10 +118,10 @@ internal sealed class CreateMemoryRequest [JsonPropertyName("user_id")] public string? UserId { get; set; } [JsonPropertyName("messages")] - public CreateMemoryMemory[] Messages { get; set; } = []; + public CreateMemoryMessage[] Messages { get; set; } = []; } - internal sealed class CreateMemoryMemory + internal sealed class CreateMemoryMessage { [JsonPropertyName("content")] public string Content { get; set; } = string.Empty; @@ -165,7 +165,7 @@ internal sealed class SearchResponseItem public string? AppId { get; set; } [JsonPropertyName("agent_id")] public string AgentId { get; set; } = string.Empty; - [JsonPropertyName("run_id")] + [JsonPropertyName("session_id")] public string RunId { get; set; } = string.Empty; } } diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs index bb95d89648aa..6b235d4a2c72 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs @@ -16,15 +16,35 @@ namespace Microsoft.SemanticKernel.Memory; /// A component that listens to messages added to the conversation thread, and automatically captures /// information about the user. It is also able to retrieve this information and add it to the AI invocation context. /// +/// +/// +/// Mem0 allows memories to be stored under one or more optional scopes: application, agent, thread, and user. +/// At least one scope must always be provided. +/// +/// +/// There are some special considerations when using thread as a scope. +/// A thread id may not be available at the time that this component is instantiated. +/// It is therefore possible to provide no thread id when instantiating this class and instead set +/// to . +/// The component will then capture a thread id when a thread is created or when messages are received +/// and use this thread id to scope the memories in mem0. +/// +/// +/// Note that this component will keep the current thread id in a private field for the duration of +/// the component's lifetime, and therefore using the component with multiple threads, with +/// set to is not supported. +/// +/// [Experimental("SKEXP0130")] [ExcludeFromCodeCoverage] // Tested via integration tests. -public class Mem0MemoryComponent : ConversationStatePart +public sealed class Mem0MemoryComponent : ConversationStatePart { private readonly string? _applicationId; private readonly string? _agentId; - private string? _threadId; + private readonly string? _threadId; + private string? _perOperationThreadId; private readonly string? _userId; - private readonly bool _scopeToThread; + private readonly bool _scopeToPerOperationThreadId; private readonly AIFunction[] _aIFunctions; @@ -53,7 +73,7 @@ public Mem0MemoryComponent(HttpClient httpClient, Mem0MemoryComponentOptions? op this._agentId = options?.AgentId; this._threadId = options?.ThreadId; this._userId = options?.UserId; - this._scopeToThread = options?.ScopeToThread ?? false; + this._scopeToPerOperationThreadId = options?.ScopeToPerOperationThreadId ?? false; this._aIFunctions = [AIFunctionFactory.Create(this.ClearStoredUserFactsAsync)]; @@ -66,7 +86,9 @@ public Mem0MemoryComponent(HttpClient httpClient, Mem0MemoryComponentOptions? op /// public override Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) { - this._threadId ??= threadId; + this.ValidatePerOperationThreadId(threadId); + + this._perOperationThreadId ??= threadId; return Task.CompletedTask; } @@ -74,14 +96,16 @@ public override Task OnThreadCreatedAsync(string? threadId, CancellationToken ca public override async Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) { Verify.NotNull(newMessage); - this._threadId ??= threadId; + this.ValidatePerOperationThreadId(threadId); + + this._perOperationThreadId ??= threadId; - if (newMessage.Role == ChatRole.User && !string.IsNullOrWhiteSpace(newMessage.Text)) + if (!string.IsNullOrWhiteSpace(newMessage.Text)) { await this._mem0Client.CreateMemoryAsync( this._applicationId, this._agentId, - this._scopeToThread ? this._threadId : null, + this._scopeToPerOperationThreadId ? this._perOperationThreadId : this._threadId, this._userId, newMessage.Text, newMessage.Role.Value).ConfigureAwait(false); @@ -102,7 +126,7 @@ public override async Task OnModelInvokeAsync(ICollection n var memories = await this._mem0Client.SearchAsync( this._applicationId, this._agentId, - this._scopeToThread ? this._threadId : null, + this._scopeToPerOperationThreadId ? this._perOperationThreadId : this._threadId, this._userId, inputText).ConfigureAwait(false); @@ -120,7 +144,19 @@ public async Task ClearStoredUserFactsAsync() await this._mem0Client.ClearMemoryAsync( this._applicationId, this._agentId, - this._scopeToThread ? this._threadId : null, + this._scopeToPerOperationThreadId ? this._perOperationThreadId : this._threadId, this._userId).ConfigureAwait(false); } + + /// + /// Validate that we are not receiving a new thread id when the component has already received one before. + /// + /// The new thread id. + private void ValidatePerOperationThreadId(string? threadId) + { + if (this._scopeToPerOperationThreadId && !string.IsNullOrWhiteSpace(threadId) && this._perOperationThreadId != null && threadId != this._perOperationThreadId) + { + throw new InvalidOperationException($"The {nameof(Mem0MemoryComponent)} can only be used with one thread at a time when {nameof(Mem0MemoryComponentOptions.ScopeToPerOperationThreadId)} is set to true."); + } + } } diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs index 972db66c667d..4ef04f363adb 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Memory; /// [Experimental("SKEXP0130")] [ExcludeFromCodeCoverage] // Tested via integration tests. -public class Mem0MemoryComponentOptions +public sealed class Mem0MemoryComponentOptions { /// /// Gets or sets an optional ID for the application to scope memories to. @@ -44,10 +44,14 @@ public class Mem0MemoryComponentOptions public string? UserId { get; init; } /// - /// Gets or sets a value indicating whether the scope of the memories is limited to the current thread. + /// Gets or sets a value indicating whether memories should be scoped to the thread id provided on a per operation basis. /// /// - /// If false, will be ignored, and any thread ids passed into the methods of the will also be ignored. + /// This setting is useful if the thread id is not known when the is instantiated, but + /// per thread scoping is desired. + /// If , and is not set, there will be no per thread scoping. + /// if , and is set, will be used for scoping. + /// If , the thread id will be set to the thread id of the current operation, regardless of the value of . /// - public bool ScopeToThread { get; init; } = false; + public bool ScopeToPerOperationThreadId { get; init; } = false; } From 058a7dabbf787c3141797a76da22720588663b2a Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 24 Apr 2025 16:04:34 +0100 Subject: [PATCH 36/46] Improve AzureAI and OpenAIAssistant Agent tests --- .../Internal/AssistantRunOptionsFactory.cs | 5 ++- .../AssistantRunOptionsFactoryTests.cs | 33 ++++++++++++++ .../AzureAIAgentWithStatePartTests.cs | 43 ------------------- .../OpenAIAssistantAgentWithStatePartTests.cs | 7 +++ ...enAIAssistantAgentWithStatePartTests.cs.cs | 42 ------------------ 5 files changed, 44 insertions(+), 86 deletions(-) create mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs delete mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs.cs diff --git a/dotnet/src/Agents/OpenAI/Internal/AssistantRunOptionsFactory.cs b/dotnet/src/Agents/OpenAI/Internal/AssistantRunOptionsFactory.cs index 243778db83ac..1fc85f793a50 100644 --- a/dotnet/src/Agents/OpenAI/Internal/AssistantRunOptionsFactory.cs +++ b/dotnet/src/Agents/OpenAI/Internal/AssistantRunOptionsFactory.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using OpenAI.Assistants; @@ -12,7 +13,9 @@ internal static class AssistantRunOptionsFactory { public static RunCreationOptions GenerateOptions(RunCreationOptions? defaultOptions, string? agentInstructions, RunCreationOptions? invocationOptions, string? threadExtensionsContext) { - var additionalInstructions = (invocationOptions?.AdditionalInstructions ?? defaultOptions?.AdditionalInstructions) + threadExtensionsContext; + var additionalInstructions = string.Concat( + (invocationOptions?.AdditionalInstructions ?? defaultOptions?.AdditionalInstructions), + string.IsNullOrWhiteSpace(threadExtensionsContext) ? string.Empty : string.Concat(Environment.NewLine, Environment.NewLine, threadExtensionsContext)); RunCreationOptions runOptions = new() diff --git a/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs b/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs index 8777a29b6cfc..8e0cee8a382c 100644 --- a/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs +++ b/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents.OpenAI; using Microsoft.SemanticKernel.Agents.OpenAI.Internal; @@ -192,4 +193,36 @@ public void AssistantRunOptionsFactoryExecutionOptionsMaxTokensTest() Assert.Equal(1024, options.MaxInputTokenCount); Assert.Equal(4096, options.MaxOutputTokenCount); } + + /// + /// Verify run options generation with metadata. + /// + [Fact] + public void AssistantRunOptionsFactoryAdditionalInstructionsTest() + { + // Arrange + RunCreationOptions defaultOptions = + new() + { + ModelOverride = "gpt-anything", + Temperature = 0.5F, + MaxOutputTokenCount = 4096, + MaxInputTokenCount = 1024, + AdditionalInstructions = "DefaultInstructions" + }; + + RunCreationOptions invocationOptions = + new() + { + AdditionalInstructions = "OverrideInstructions", + }; + + // Act + RunCreationOptions optionsWithOverride = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions, threadExtensionsContext: "Context"); + RunCreationOptions optionsWithoutOverride = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, null, threadExtensionsContext: "Context"); + + // Assert + Assert.Equal($"OverrideInstructions{Environment.NewLine}{Environment.NewLine}Context", optionsWithOverride.AdditionalInstructions); + Assert.Equal($"DefaultInstructions{Environment.NewLine}{Environment.NewLine}Context", optionsWithoutOverride.AdditionalInstructions); + } } diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs index 32ba5519673c..59c3fbd0bf35 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs @@ -1,50 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. -using System; -using System.Linq; -using System.Net.Http; -using System.Net.Http.Headers; -using System.Threading.Tasks; -using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Agents.AzureAI; -using Microsoft.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel.Memory; -using Xunit; - namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; public class AzureAIAgentWithStatePartTests() : AgentWithStatePartTests(() => new AzureAIAgentFixture()) { - [Fact(Skip = "For manual verification")] - public virtual async Task Mem0ComponentCapturesMemoriesFromUserInputAsync() - { - // Arrange - var agent = this.Fixture.Agent; - - using var httpClient = new HttpClient(); - httpClient.BaseAddress = new Uri("https://api.mem0.ai"); - httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", "m0-uWa1CXDyO9PpotOFMUfI9WzZOwAqJjZwH3GTKgqa"); - - var mem0Component = new Mem0MemoryComponent(httpClient, new() { UserId = "U1" }); - - var agentThread1 = new AzureAIAgentThread(this.Fixture.AgentsClient); - agentThread1.StateParts.Add(mem0Component); - - var agentThread2 = new AzureAIAgentThread(this.Fixture.AgentsClient); - agentThread2.StateParts.Add(mem0Component); - - // Act - var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); - var results1 = await asyncResults1.ToListAsync(); - - var asyncResults2 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "What is my name?."), agentThread2); - var results2 = await asyncResults2.ToListAsync(); - - // Assert - Assert.Contains("Caoimhe", results2.First().Message.Content); - - // Cleanup - await this.Fixture.DeleteThread(agentThread1); - await this.Fixture.DeleteThread(agentThread2); - } } diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs new file mode 100644 index 000000000000..62e6ab81309f --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; + +public class OpenAIAssistantAgentWithStatePartTests() : AgentWithStatePartTests(() => new OpenAIAssistantAgentFixture()) +{ +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs.cs deleted file mode 100644 index cd683f597122..000000000000 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs.cs +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Linq; -using System.Threading.Tasks; -using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Agents.Memory; -using Microsoft.SemanticKernel.Agents.OpenAI; -using Microsoft.SemanticKernel.ChatCompletion; -using Xunit; - -namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; - -public class OpenAIAssistantAgentWithMemoryTests() : AgentWithStatePartTests(() => new OpenAIAssistantAgentFixture()) -{ - [Fact] - public virtual async Task MemoryComponentCapturesMemoriesFromUserInputAsync() - { - // Arrange - var agent = this.Fixture.Agent; - var memoryComponent = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel); - - var agentThread1 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); - agentThread1.StateParts.Add(memoryComponent); - - var agentThread2 = new OpenAIAssistantAgentThread(this.Fixture.AssistantClient); - agentThread2.StateParts.Add(memoryComponent); - - // Act - var asyncResults1 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hello, my name is Caoimhe."), agentThread1); - var results1 = await asyncResults1.ToListAsync(); - - var asyncResults2 = agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "What is my name?."), agentThread2); - var results2 = await asyncResults2.ToListAsync(); - - // Assert - Assert.Contains("Caoimhe", results2.First().Message.Content); - - // Cleanup - await this.Fixture.DeleteThread(agentThread1); - await this.Fixture.DeleteThread(agentThread2); - } -} From cd62b9f4a9b95edb29907fd5deba8cef7e8ad615 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 25 Apr 2025 12:03:24 +0100 Subject: [PATCH 37/46] Add mem0 unit tests and some small improvements. --- .../Memory/Mem0/Mem0Client.cs | 2 - .../Memory/Mem0/Mem0MemoryComponent.cs | 10 +- .../Memory/Mem0/Mem0MemoryComponentOptions.cs | 10 +- .../Memory/Mem0MemoryComponentTests.cs | 203 ++++++++++++++++++ 4 files changed, 220 insertions(+), 5 deletions(-) create mode 100644 dotnet/src/SemanticKernel.UnitTests/Memory/Mem0MemoryComponentTests.cs diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs index 6b5d797eaf18..12fbb1551dda 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net.Http; using System.Text; @@ -15,7 +14,6 @@ namespace Microsoft.SemanticKernel.Memory; /// /// Client for the Mem0 memory service. /// -[ExcludeFromCodeCoverage] // Tested via integration tests. internal sealed class Mem0Client { private static readonly Uri s_searchUri = new("/v1/memories/search/", UriKind.Relative); diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs index 6b235d4a2c72..60198313571f 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs @@ -36,7 +36,6 @@ namespace Microsoft.SemanticKernel.Memory; /// /// [Experimental("SKEXP0130")] -[ExcludeFromCodeCoverage] // Tested via integration tests. public sealed class Mem0MemoryComponent : ConversationStatePart { private readonly string? _applicationId; @@ -45,6 +44,7 @@ public sealed class Mem0MemoryComponent : ConversationStatePart private string? _perOperationThreadId; private readonly string? _userId; private readonly bool _scopeToPerOperationThreadId; + private readonly string _contextPrompt; private readonly AIFunction[] _aIFunctions; @@ -69,11 +69,17 @@ public Mem0MemoryComponent(HttpClient httpClient, Mem0MemoryComponentOptions? op { Verify.NotNull(httpClient); + if (string.IsNullOrWhiteSpace(httpClient.BaseAddress?.AbsolutePath)) + { + throw new ArgumentException("The BaseAddress of the provided httpClient parameter must be set.", nameof(httpClient)); + } + this._applicationId = options?.ApplicationId; this._agentId = options?.AgentId; this._threadId = options?.ThreadId; this._userId = options?.UserId; this._scopeToPerOperationThreadId = options?.ScopeToPerOperationThreadId ?? false; + this._contextPrompt = options?.ContextPrompt ?? "Consider the following memories when answering user questions:"; this._aIFunctions = [AIFunctionFactory.Create(this.ClearStoredUserFactsAsync)]; @@ -131,7 +137,7 @@ public override async Task OnModelInvokeAsync(ICollection n inputText).ConfigureAwait(false); var userInformation = string.Join(Environment.NewLine, memories); - return string.Join(Environment.NewLine, "The following list contains facts about the user:", userInformation); + return string.Join(Environment.NewLine, this._contextPrompt, userInformation); } /// diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs index 4ef04f363adb..192446489475 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs @@ -8,7 +8,6 @@ namespace Microsoft.SemanticKernel.Memory; /// Options for the . /// [Experimental("SKEXP0130")] -[ExcludeFromCodeCoverage] // Tested via integration tests. public sealed class Mem0MemoryComponentOptions { /// @@ -54,4 +53,13 @@ public sealed class Mem0MemoryComponentOptions /// If , the thread id will be set to the thread id of the current operation, regardless of the value of . /// public bool ScopeToPerOperationThreadId { get; init; } = false; + + /// + /// When providing the memories found in Mem0 to the AI model on invocation, this string is prefixed + /// to those memories, in order to provide some context to the model. + /// + /// + /// Defaults to "Consider the following memories when answering user questions:" + /// + public string? ContextPrompt { get; init; } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/Mem0MemoryComponentTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/Mem0MemoryComponentTests.cs new file mode 100644 index 000000000000..3929d220aefd --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/Mem0MemoryComponentTests.cs @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft. All rights reserved. + +#pragma warning disable CA1054 // URI-like parameters should not be strings + +using System; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.Memory; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Contains tests for the class. +/// +public class Mem0MemoryComponentTests : IDisposable +{ + private readonly HttpClient _httpClient; + private readonly Mock _mockMessageHandler; + private bool _disposedValue; + + public Mem0MemoryComponentTests() + { + this._mockMessageHandler = new Mock() { CallBase = true }; + this._httpClient = new HttpClient(this._mockMessageHandler.Object) + { + BaseAddress = new Uri("https://localhost/fakepath") + }; + } + + [Fact] + public void ValidatesHttpClientBaseAddress() + { + // Arrange + using var httpClientWithoutBaseAddress = new HttpClient(); + + // Act & Assert + var exception = Assert.Throws(() => + { + new Mem0MemoryComponent(httpClientWithoutBaseAddress); + }); + + Assert.Equal("The BaseAddress of the provided httpClient parameter must be set. (Parameter 'httpClient')", exception.Message); + } + + [Fact] + public void AIFunctionsAreSetCorrectly() + { + // Arrange + var sut = new Mem0MemoryComponent(this._httpClient, new() { ApplicationId = "test-app-id" }); + + // Act + var aiFunctions = sut.AIFunctions; + + // Assert + Assert.NotNull(aiFunctions); + Assert.Single(aiFunctions); + Assert.Equal("ClearStoredUserFacts", aiFunctions.First().Name); + } + + [Theory] + [InlineData(false, "test-thread-id")] + [InlineData(true, "test-thread-id-1")] + public async Task PostsMemoriesOnNewMessage(bool scopePerOperationThread, string expectedThreadId) + { + // Arrange + using var httpResponse = new HttpResponseMessage() { StatusCode = System.Net.HttpStatusCode.OK }; + this._mockMessageHandler + .Setup(x => x.MockableSendAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(httpResponse); + + var sut = new Mem0MemoryComponent(this._httpClient, new() { ApplicationId = "test-app-id", AgentId = "test-agent-id", ThreadId = "test-thread-id", UserId = "test-user-id", ScopeToPerOperationThreadId = scopePerOperationThread }); + + // Act + await sut.OnNewMessageAsync("test-thread-id-1", new ChatMessage(ChatRole.User, "Hello, my name is Caoimhe.")); + + // Assert + var expectedPayload = $$""" + {"app_id":"test-app-id","agent_id":"test-agent-id","run_id":"{{expectedThreadId}}","user_id":"test-user-id","messages":[{"content":"Hello, my name is Caoimhe.","role":"user"}]} + """; + this._mockMessageHandler.Verify(x => x.MockableSendAsync(HttpMethod.Post, "https://localhost/v1/memories/", expectedPayload, It.IsAny()), Times.Once); + } + + [Theory] + [InlineData(false, "test-thread-id", null, "Consider the following memories when answering user questions:{0}Name is Caoimhe")] + [InlineData(true, "test-thread-id-1", "Custom Prompt:", "Custom Prompt:{0}Name is Caoimhe")] + public async Task SearchesForMemoriesOnModelInvoke(bool scopePerOperationThread, string expectedThreadId, string? customContextPrompt, string expectedAdditionalInstructions) + { + // Arrange + var expectedResponseString = """ + [{"id":"1","memory":"Name is Caoimhe","hash":"abc123","metadata":null,"score":0.9,"created_at":"2023-01-01T00:00:00Z","updated_at":null,"user_id":"test-user-id","app_id":null,"agent_id":"test-agent-id","session_id":"test-thread-id-1"}] + """; + using var httpResponse = new HttpResponseMessage() + { + StatusCode = System.Net.HttpStatusCode.OK, + Content = new StringContent(expectedResponseString, Encoding.UTF8, "application/json") + }; + + this._mockMessageHandler + .Setup(x => x.MockableSendAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(httpResponse); + + var sut = new Mem0MemoryComponent(this._httpClient, new() + { + ApplicationId = "test-app-id", + AgentId = "test-agent-id", + ThreadId = "test-thread-id", + UserId = "test-user-id", + ScopeToPerOperationThreadId = scopePerOperationThread, + ContextPrompt = customContextPrompt + }); + await sut.OnThreadCreatedAsync("test-thread-id-1"); + + // Act + var actual = await sut.OnModelInvokeAsync(new[] { new ChatMessage(ChatRole.User, "What is my name?") }); + + // Assert + var expectedPayload = $$""" + {"app_id":"test-app-id","agent_id":"test-agent-id","run_id":"{{expectedThreadId}}","user_id":"test-user-id","query":"What is my name?"} + """; + this._mockMessageHandler.Verify(x => x.MockableSendAsync(HttpMethod.Post, "https://localhost/v1/memories/search/", expectedPayload, It.IsAny()), Times.Once); + + Assert.Equal(string.Format(expectedAdditionalInstructions, Environment.NewLine), actual); + } + + [Theory] + [InlineData(false, "test-thread-id")] + [InlineData(true, "test-thread-id-1")] + public async Task ClearsStoredUserFacts(bool scopePerOperationThread, string expectedThreadId) + { + // Arrange + using var httpResponse = new HttpResponseMessage() { StatusCode = System.Net.HttpStatusCode.OK }; + this._mockMessageHandler + .Setup(x => x.MockableSendAsync(It.IsAny(), It.IsAny(), null, It.IsAny())) + .ReturnsAsync(httpResponse); + + var sut = new Mem0MemoryComponent(this._httpClient, new() { ApplicationId = "test-app-id", AgentId = "test-agent-id", ThreadId = "test-thread-id", UserId = "test-user-id", ScopeToPerOperationThreadId = scopePerOperationThread }); + await sut.OnThreadCreatedAsync("test-thread-id-1"); + + // Act + await sut.ClearStoredUserFactsAsync(); + + // Assert + var expectedUrl = $"https://localhost/v1/memories/?app_id=test-app-id&agent_id=test-agent-id&run_id={expectedThreadId}&user_id=test-user-id"; + this._mockMessageHandler.Verify(x => x.MockableSendAsync(HttpMethod.Delete, expectedUrl, null, It.IsAny()), Times.Once); + } + + [Fact] + public async Task ThrowsExceptionWhenThreadIdChangesAfterBeingSet() + { + // Arrange + var sut = new Mem0MemoryComponent(this._httpClient, new() { ScopeToPerOperationThreadId = true }); + + // Act + await sut.OnThreadCreatedAsync("initial-thread-id"); + + // Assert + var exception = await Assert.ThrowsAsync(async () => + { + await sut.OnThreadCreatedAsync("new-thread-id"); + }); + + Assert.Equal("The Mem0MemoryComponent can only be used with one thread at a time when ScopeToPerOperationThreadId is set to true.", exception.Message); + } + + protected virtual void Dispose(bool disposing) + { + if (!this._disposedValue) + { + if (disposing) + { + this._httpClient.Dispose(); + } + + this._disposedValue = true; + } + } + + public void Dispose() + { + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + public class MockableMessageHandler : DelegatingHandler + { + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + string? contentString = request.Content is null ? null : await request.Content.ReadAsStringAsync(cancellationToken); + return await this.MockableSendAsync(request.Method, request.RequestUri?.AbsoluteUri, contentString, cancellationToken); + } + + public virtual Task MockableSendAsync(HttpMethod method, string? absoluteUri, string? content, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + } +} From 54f527f7101b51a2cf6aa15da5f6fc37e99854a8 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 25 Apr 2025 14:10:32 +0100 Subject: [PATCH 38/46] Move default context prompt to const variable. --- .../SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs index 60198313571f..c5010873bcbf 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs @@ -38,6 +38,8 @@ namespace Microsoft.SemanticKernel.Memory; [Experimental("SKEXP0130")] public sealed class Mem0MemoryComponent : ConversationStatePart { + private const string DefaultContextPrompt = "Consider the following memories when answering user questions:"; + private readonly string? _applicationId; private readonly string? _agentId; private readonly string? _threadId; @@ -79,7 +81,7 @@ public Mem0MemoryComponent(HttpClient httpClient, Mem0MemoryComponentOptions? op this._threadId = options?.ThreadId; this._userId = options?.UserId; this._scopeToPerOperationThreadId = options?.ScopeToPerOperationThreadId ?? false; - this._contextPrompt = options?.ContextPrompt ?? "Consider the following memories when answering user questions:"; + this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; this._aIFunctions = [AIFunctionFactory.Create(this.ClearStoredUserFactsAsync)]; From 79e30d5b48ad55fcf887e60a9c2ea82266edfaef Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 25 Apr 2025 14:32:27 +0100 Subject: [PATCH 39/46] Update Bedrock Agent to support conversation state and improve tests. --- dotnet/src/Agents/AzureAI/AzureAIAgent.cs | 25 ++++++-- .../src/Agents/Bedrock/Agents.Bedrock.csproj | 1 + dotnet/src/Agents/Bedrock/BedrockAgent.cs | 53 +++++++++++++-- .../BedrockAgentThreadTests.cs | 14 ++-- .../AgentWithStatePartTests.cs | 64 +++++++++++++++++++ .../BedrockAgentWithStatePartTests.cs | 35 ++++++++++ .../BedrockAgentFixture.cs | 2 +- .../BedrockAgentInvokeTests.cs | 18 +++--- 8 files changed, 188 insertions(+), 24 deletions(-) create mode 100644 dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/BedrockAgentWithStatePartTests.cs diff --git a/dotnet/src/Agents/AzureAI/AzureAIAgent.cs b/dotnet/src/Agents/AzureAI/AzureAIAgent.cs index 08ea51c7a616..b3a0393eae90 100644 --- a/dotnet/src/Agents/AzureAI/AzureAIAgent.cs +++ b/dotnet/src/Agents/AzureAI/AzureAIAgent.cs @@ -194,9 +194,10 @@ public async IAsyncEnumerable> InvokeAsync azureAIAgentThread.StateParts.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext); var extensionsContextOptions = options is null ? - new AzureAIAgentInvokeOptions() { AdditionalInstructions = extensionsContext } : - new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = extensionsContext }; + new AzureAIAgentInvokeOptions() { AdditionalInstructions = mergedAdditionalInstructions } : + new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = mergedAdditionalInstructions }; var invokeResults = ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(this.Id, this.GetDisplayName(), this.Description), @@ -323,9 +324,10 @@ public async IAsyncEnumerable> In azureAIAgentThread.StateParts.RegisterPlugins(kernel); #pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext); var extensionsContextOptions = options is null ? - new AzureAIAgentInvokeOptions() { AdditionalInstructions = extensionsContext } : - new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = extensionsContext }; + new AzureAIAgentInvokeOptions() { AdditionalInstructions = mergedAdditionalInstructions } : + new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = mergedAdditionalInstructions }; #pragma warning disable CS0618 // Type or member is obsolete // Invoke the Agent with the thread that we already added our message to. @@ -461,4 +463,19 @@ protected override async Task RestoreChannelAsync(string channelSt return new AzureAIChannel(this.Client, thread.Id); } + + private static string MergeAdditionalInstructions(string? optionsAdditionalInstructions, string extensionsContext) => + (optionsAdditionalInstructions, extensionsContext) switch + { + (string ai, string ec) when !string.IsNullOrWhiteSpace(ai) && !string.IsNullOrWhiteSpace(ec) => string.Concat( + ai, + Environment.NewLine, + Environment.NewLine, + ec), + (string ai, string ec) when string.IsNullOrWhiteSpace(ai) => ec, + (string ai, string ec) when string.IsNullOrWhiteSpace(ec) => ai, + (null, string ec) => ec, + (string ai, null) => ai, + _ => string.Empty + }; } diff --git a/dotnet/src/Agents/Bedrock/Agents.Bedrock.csproj b/dotnet/src/Agents/Bedrock/Agents.Bedrock.csproj index e17d43f63fcc..c3aa00b13330 100644 --- a/dotnet/src/Agents/Bedrock/Agents.Bedrock.csproj +++ b/dotnet/src/Agents/Bedrock/Agents.Bedrock.csproj @@ -33,6 +33,7 @@ + diff --git a/dotnet/src/Agents/Bedrock/BedrockAgent.cs b/dotnet/src/Agents/Bedrock/BedrockAgent.cs index 22d4187b0bd2..ae54bec77288 100644 --- a/dotnet/src/Agents/Bedrock/BedrockAgent.cs +++ b/dotnet/src/Agents/Bedrock/BedrockAgent.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using System.Threading; using System.Threading.Tasks; using Amazon.BedrockAgent; @@ -117,11 +118,18 @@ public override async IAsyncEnumerable> In () => new BedrockAgentThread(this.RuntimeClient), cancellationToken).ConfigureAwait(false); + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await bedrockThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + // Ensure that the last message provided is a user message string? message = this.ExtractUserMessage(messages.Last()); - // Build session state with conversation history if needed + // Build session state with conversation history and override instructions if needed SessionState sessionState = this.ExtractSessionState(messages); + var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext); + sessionState.PromptSessionAttributes = new() { ["AdditionalInstructions"] = mergedAdditionalInstructions }; // Configure the agent request with the provided options var invokeAgentRequest = this.ConfigureAgentRequest(options, () => @@ -346,11 +354,18 @@ public override async IAsyncEnumerable new BedrockAgentThread(this.RuntimeClient), cancellationToken).ConfigureAwait(false); + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await bedrockThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + // Ensure that the last message provided is a user message string? message = this.ExtractUserMessage(messages.Last()); - // Build session state with conversation history if needed + // Build session state with conversation history and override instructions if needed SessionState sessionState = this.ExtractSessionState(messages); + var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext); + sessionState.PromptSessionAttributes = new() { ["AdditionalInstructions"] = mergedAdditionalInstructions }; // Configure the agent request with the provided options var invokeAgentRequest = this.ConfigureAgentRequest(options, () => @@ -639,20 +654,35 @@ private IAsyncEnumerable InvokeStreamingInternalAsy async IAsyncEnumerable InvokeInternal() { + var combinedResponseMessageBuilder = new StringBuilder(); + StreamingChatMessageContent? lastMessage = null; + // The Bedrock agent service has the same API for both streaming and non-streaming responses. // We are invoking the same method as the non-streaming response with the streaming configuration set, // and converting the chat message content to streaming chat message content. await foreach (var chatMessageContent in this.InternalInvokeAsync(invokeAgentRequest, arguments, cancellationToken).ConfigureAwait(false)) { - await this.NotifyThreadOfNewMessage(thread, chatMessageContent, cancellationToken).ConfigureAwait(false); - yield return new StreamingChatMessageContent(chatMessageContent.Role, chatMessageContent.Content) + lastMessage = new StreamingChatMessageContent(chatMessageContent.Role, chatMessageContent.Content) { AuthorName = chatMessageContent.AuthorName, ModelId = chatMessageContent.ModelId, InnerContent = chatMessageContent.InnerContent, Metadata = chatMessageContent.Metadata, }; + yield return lastMessage; + + combinedResponseMessageBuilder.Append(chatMessageContent.Content); } + + // Build a combined message containing the text from all response parts + // to send to the thread. + var combinedMessage = new ChatMessageContent(AuthorRole.Assistant, combinedResponseMessageBuilder.ToString()) + { + AuthorName = lastMessage?.AuthorName, + ModelId = lastMessage?.ModelId, + Metadata = lastMessage?.Metadata, + }; + await this.NotifyThreadOfNewMessage(thread, combinedMessage, cancellationToken).ConfigureAwait(false); } } @@ -726,4 +756,19 @@ private Amazon.BedrockAgentRuntime.ConversationRole MapBedrockAgentUser(AuthorRo } #endregion + + private static string MergeAdditionalInstructions(string? optionsAdditionalInstructions, string extensionsContext) => + (optionsAdditionalInstructions, extensionsContext) switch + { + (string ai, string ec) when !string.IsNullOrWhiteSpace(ai) && !string.IsNullOrWhiteSpace(ec) => string.Concat( + ai, + Environment.NewLine, + Environment.NewLine, + ec), + (string ai, string ec) when string.IsNullOrWhiteSpace(ai) => ec, + (string ai, string ec) when string.IsNullOrWhiteSpace(ec) => ai, + (null, string ec) => ec, + (string ai, null) => ai, + _ => string.Empty + }; } diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/BedrockAgentThreadTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/BedrockAgentThreadTests.cs index 93fff6c7f636..ce09c4a3a68d 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/BedrockAgentThreadTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/BedrockAgentThreadTests.cs @@ -7,38 +7,40 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.Agen public class BedrockAgentThreadTests() : AgentThreadTests(() => new BedrockAgentFixture()) { - [Fact(Skip = "Manual verification only")] + private const string ManualVerificationSkipReason = "This test is for manual verification."; + + [Fact(Skip = ManualVerificationSkipReason)] public override Task OnNewMessageWithServiceFailureThrowsAgentOperationExceptionAsync() { // The Bedrock agent does not support writing to a thread with OnNewMessage. return Task.CompletedTask; } - [Fact(Skip = "Manual verification only")] + [Fact(Skip = ManualVerificationSkipReason)] public override Task DeletingThreadTwiceDoesNotThrowAsync() { return base.DeletingThreadTwiceDoesNotThrowAsync(); } - [Fact(Skip = "Manual verification only")] + [Fact(Skip = ManualVerificationSkipReason)] public override Task UsingThreadAfterDeleteThrowsAsync() { return base.UsingThreadAfterDeleteThrowsAsync(); } - [Fact(Skip = "Manual verification only")] + [Fact(Skip = ManualVerificationSkipReason)] public override Task DeleteThreadBeforeCreateThrowsAsync() { return base.DeleteThreadBeforeCreateThrowsAsync(); } - [Fact(Skip = "Manual verification only")] + [Fact(Skip = ManualVerificationSkipReason)] public override Task UsingThreadBeforeCreateCreatesAsync() { return base.UsingThreadBeforeCreateCreatesAsync(); } - [Fact(Skip = "Manual verification only")] + [Fact(Skip = ManualVerificationSkipReason)] public override Task DeleteThreadWithServiceFailureThrowsAgentOperationExceptionAsync() { return base.DeleteThreadWithServiceFailureThrowsAgentOperationExceptionAsync(); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs index c2934ebfc9a0..12e9011b4eb7 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs @@ -53,6 +53,39 @@ public virtual async Task StatePartReceivesMessagesFromAgentAsync() } } + [Fact] + public virtual async Task StatePartReceivesMessagesFromAgentWhenStreamingAsync() + { + // Arrange + var mockStatePart = new Mock() { CallBase = true }; + mockStatePart.Setup(x => x.OnNewMessageAsync(It.IsAny(), It.IsAny(), It.IsAny())); + + var agent = this.Fixture.Agent; + + var agentThread = this.Fixture.GetNewThread(); + + try + { + agentThread.StateParts.Add(mockStatePart.Object); + + // Act + var inputMessage = "What is the capital of France?"; + var asyncResults1 = agent.InvokeStreamingAsync(inputMessage, agentThread); + var results = await asyncResults1.ToListAsync(); + + // Assert + var responseMessage = string.Concat(results.Select(x => x.Message.Content)); + Assert.Contains("Paris", responseMessage); + mockStatePart.Verify(x => x.OnNewMessageAsync(It.IsAny(), It.Is(cm => cm.Text == inputMessage), It.IsAny()), Times.Once); + mockStatePart.Verify(x => x.OnNewMessageAsync(It.IsAny(), It.Is(cm => cm.Text == responseMessage), It.IsAny()), Times.Once); + } + finally + { + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + } + [Fact] public virtual async Task StatePartPreInvokeStateIsUsedByAgentAsync() { @@ -83,6 +116,37 @@ public virtual async Task StatePartPreInvokeStateIsUsedByAgentAsync() } } + [Fact] + public virtual async Task StatePartPreInvokeStateIsUsedByAgentWhenStreamingAsync() + { + // Arrange + var mockStatePart = new Mock() { CallBase = true }; + mockStatePart.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())).ReturnsAsync("User name is Caoimhe"); + + var agent = this.Fixture.Agent; + + var agentThread = this.Fixture.GetNewThread(); + + try + { + agentThread.StateParts.Add(mockStatePart.Object); + + // Act + var inputMessage = "What is my name?."; + var asyncResults1 = agent.InvokeStreamingAsync(inputMessage, agentThread); + var results = await asyncResults1.ToListAsync(); + + // Assert + var responseMessage = string.Concat(results.Select(x => x.Message.Content)); + Assert.Contains("Caoimhe", responseMessage); + } + finally + { + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + } + public Task InitializeAsync() { this._agentFixture = createAgentFixture(); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/BedrockAgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/BedrockAgentWithStatePartTests.cs new file mode 100644 index 000000000000..05f16f5d7292 --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/BedrockAgentWithStatePartTests.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading.Tasks; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; + +public class BedrockAgentWithStatePartTests() : AgentWithStatePartTests(() => new BedrockAgentFixture()) +{ + private const string ManualVerificationSkipReason = "This test is for manual verification."; + + [Fact(Skip = ManualVerificationSkipReason)] + public override Task StatePartReceivesMessagesFromAgentAsync() + { + return base.StatePartReceivesMessagesFromAgentAsync(); + } + + [Fact(Skip = ManualVerificationSkipReason)] + public override Task StatePartReceivesMessagesFromAgentWhenStreamingAsync() + { + return base.StatePartReceivesMessagesFromAgentWhenStreamingAsync(); + } + + [Fact(Skip = ManualVerificationSkipReason)] + public override Task StatePartPreInvokeStateIsUsedByAgentAsync() + { + return base.StatePartPreInvokeStateIsUsedByAgentAsync(); + } + + [Fact(Skip = ManualVerificationSkipReason)] + public override Task StatePartPreInvokeStateIsUsedByAgentWhenStreamingAsync() + { + return base.StatePartPreInvokeStateIsUsedByAgentWhenStreamingAsync(); + } +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs index 957b8f76b283..d9ba2a333003 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs @@ -18,7 +18,7 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance; -internal sealed class BedrockAgentFixture : AgentFixture, IAsyncDisposable +public sealed class BedrockAgentFixture : AgentFixture, IAsyncDisposable { private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/InvokeConformance/BedrockAgentInvokeTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/InvokeConformance/BedrockAgentInvokeTests.cs index e9eff01241c6..84742ca908dc 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/InvokeConformance/BedrockAgentInvokeTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/InvokeConformance/BedrockAgentInvokeTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Linq; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -11,7 +10,9 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.Invo public class BedrockAgentInvokeTests() : InvokeTests(() => new BedrockAgentFixture()) { - [Fact(Skip = "This test is for manual verification.")] + private const string ManualVerificationSkipReason = "This test is for manual verification."; + + [Fact(Skip = ManualVerificationSkipReason)] public override async Task ConversationMaintainsHistoryAsync() { var q1 = "What is the capital of France."; @@ -32,7 +33,7 @@ public override async Task ConversationMaintainsHistoryAsync() //Assert.Contains("Eiffel", result2.Message.Content); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = ManualVerificationSkipReason)] public override async Task InvokeReturnsResultAsync() { var agent = this.Fixture.Agent; @@ -48,7 +49,7 @@ public override async Task InvokeReturnsResultAsync() //Assert.Contains("Paris", firstResult.Message.Content); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = ManualVerificationSkipReason)] public override async Task InvokeWithoutThreadCreatesThreadAsync() { var agent = this.Fixture.Agent; @@ -66,20 +67,19 @@ public override async Task InvokeWithoutThreadCreatesThreadAsync() await this.Fixture.DeleteThread(firstResult.Thread); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = "The BedrockAgent does not support invoking without a message.")] public override Task InvokeWithoutMessageCreatesThreadAsync() { - // The Bedrock agent does not support invoking without a message. - return Assert.ThrowsAsync(async () => await base.InvokeWithoutThreadCreatesThreadAsync()); + return base.InvokeWithoutMessageCreatesThreadAsync(); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = "The BedrockAgent does not yet support plugins")] public override Task MultiStepInvokeWithPluginAndArgOverridesAsync() { return base.MultiStepInvokeWithPluginAndArgOverridesAsync(); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = "The BedrockAgent does not yet support plugins")] public override Task InvokeWithPluginNotifiesForAllMessagesAsync() { return base.InvokeWithPluginNotifiesForAllMessagesAsync(); From 6f0c5f08945b8ca55e8a576cc41acee9283f60ce Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 25 Apr 2025 18:39:59 +0100 Subject: [PATCH 40/46] Move text rag to core and clean up. --- .../Memory/Memory/UserFactsMemoryComponent.cs | 4 +++ .../Memory}/TextMemoryStore.cs | 2 ++ .../Memory/OptionalTextMemoryStore.cs | 0 .../Memory/TextRag/TextRagComponent.cs | 31 ++++++++++++++----- .../Memory/TextRag/TextRagComponentOptions.cs | 20 ++++++++++++ .../Memory/TextRag/TextRagDocument.cs | 2 ++ .../Memory/TextRag/TextRagStore.cs | 12 ++++--- .../Memory/VectorDataTextMemoryStore.cs | 2 ++ 8 files changed, 61 insertions(+), 12 deletions(-) rename dotnet/src/{Memory/Memory.Abstractions => SemanticKernel.Abstractions/Memory}/TextMemoryStore.cs (97%) rename dotnet/src/{Memory => SemanticKernel.Core}/Memory/OptionalTextMemoryStore.cs (100%) rename dotnet/src/{Memory => SemanticKernel.Core}/Memory/TextRag/TextRagComponent.cs (70%) rename dotnet/src/{Memory => SemanticKernel.Core}/Memory/TextRag/TextRagComponentOptions.cs (70%) rename dotnet/src/{Memory => SemanticKernel.Core}/Memory/TextRag/TextRagDocument.cs (97%) rename dotnet/src/{Memory => SemanticKernel.Core}/Memory/TextRag/TextRagStore.cs (96%) rename dotnet/src/{Memory => SemanticKernel.Core}/Memory/VectorDataTextMemoryStore.cs (99%) diff --git a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs index dc3180fcbe14..2c77ef77e3f5 100644 --- a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs +++ b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +#if DISABLED + using System.Collections.Generic; using System.ComponentModel; using System.Threading; @@ -167,3 +169,5 @@ private async Task ExtractAndSaveMemoriesAsync(string inputText, CancellationTok await this._textMemoryStore.SaveMemoryAsync(this.UserFactsDocumentName, this._userFacts, cancellationToken).ConfigureAwait(false); } } + +#endif diff --git a/dotnet/src/Memory/Memory.Abstractions/TextMemoryStore.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/TextMemoryStore.cs similarity index 97% rename from dotnet/src/Memory/Memory.Abstractions/TextMemoryStore.cs rename to dotnet/src/SemanticKernel.Abstractions/Memory/TextMemoryStore.cs index ded74e1c0398..599ec7dd1eae 100644 --- a/dotnet/src/Memory/Memory.Abstractions/TextMemoryStore.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/TextMemoryStore.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; @@ -9,6 +10,7 @@ namespace Microsoft.SemanticKernel.Memory; /// /// Abstract base class for storing and retrieving text based memories. /// +[Experimental("SKEXP0001")] public abstract class TextMemoryStore { /// diff --git a/dotnet/src/Memory/Memory/OptionalTextMemoryStore.cs b/dotnet/src/SemanticKernel.Core/Memory/OptionalTextMemoryStore.cs similarity index 100% rename from dotnet/src/Memory/Memory/OptionalTextMemoryStore.cs rename to dotnet/src/SemanticKernel.Core/Memory/OptionalTextMemoryStore.cs diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs similarity index 70% rename from dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs rename to dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs index 29f6a6175e46..c52490d280a4 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs @@ -2,9 +2,11 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -13,10 +15,16 @@ namespace Microsoft.SemanticKernel.Memory; /// -/// A component that does a search based on any messages that the AI is invoked with and injects the results into the AI invocation context. +/// A component that does a search based on any messages that the AI model is invoked with and injects the results into the AI model invocation context. /// +[Experimental("SKEXP0130")] public class TextRagComponent : ConversationStatePart { + private const string DefaultPluginSearchFunctionName = "Search"; + private const string DefaultPluginSearchFunctionDescription = "Allows searching for additional information to help answer the user question."; + private const string DefaultContextPrompt = "Consider the following source information when responding to the user:"; + private const string DefaultIncludeCitationsPrompt = "Include citations to the relevant information where it is referenced in the response."; + private readonly ITextSearch _textSearch; private readonly AIFunction[] _aIFunctions; @@ -25,7 +33,7 @@ public class TextRagComponent : ConversationStatePart /// Initializes a new instance of the class. /// /// The text search component to retrieve results from. - /// + /// Options that configure the behavior of the component. /// public TextRagComponent(ITextSearch textSearch, TextRagComponentOptions? options = default) { @@ -38,8 +46,8 @@ public TextRagComponent(ITextSearch textSearch, TextRagComponentOptions? options [ AIFunctionFactory.Create( this.SearchAsync, - name: this.Options.PluginSearchFunctionName ?? "Search", - description: this.Options.PluginSearchFunctionDescription ?? "Allows searching for additional information to help answer the user question.") + name: this.Options.PluginSearchFunctionName ?? DefaultPluginSearchFunctionName, + description: this.Options.PluginSearchFunctionDescription ?? DefaultPluginSearchFunctionDescription) ]; } @@ -81,7 +89,7 @@ public override async Task OnModelInvokeAsync(ICollection n // Format the results showing the content with source link and name for each result. var sb = new StringBuilder(); - sb.AppendLine("Please consider the following source information when responding to the user:"); + sb.AppendLine(this.Options.ContextPrompt ?? DefaultContextPrompt); await foreach (var result in searchResults.Results.ConfigureAwait(false)) { sb.AppendLine($" Source Document Name: {result.Name}"); @@ -90,7 +98,7 @@ public override async Task OnModelInvokeAsync(ICollection n sb.AppendLine(" -----------------"); } - sb.AppendLine("Include citations to the relevant information where it is referenced in the response."); + sb.AppendLine(this.Options.InclueCitationsPrompt ?? DefaultIncludeCitationsPrompt); sb.AppendLine("-------------------"); return sb.ToString(); @@ -109,6 +117,15 @@ public async Task SearchAsync(string userQuestion, CancellationToken can var results = await searchResults.Results.ToListAsync(cancellationToken).ConfigureAwait(false); - return JsonSerializer.Serialize(results); + return JsonSerializer.Serialize(results, TextRagSourceGenerationContext.Default.ListTextSearchResult); } } + +[JsonSourceGenerationOptions(JsonSerializerDefaults.General, + UseStringEnumConverter = false, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + WriteIndented = false)] +[JsonSerializable(typeof(List))] +internal partial class TextRagSourceGenerationContext : JsonSerializerContext +{ +} diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs similarity index 70% rename from dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs rename to dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs index 275074cd60b1..7a0f85f8be28 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagComponentOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs @@ -1,12 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics.CodeAnalysis; namespace Microsoft.SemanticKernel.Memory; /// /// Contains options for the . /// +[Experimental("SKEXP0130")] public class TextRagComponentOptions { private int _top = 3; @@ -47,6 +49,24 @@ public int Top /// public string? PluginSearchFunctionDescription { get; init; } + /// + /// When providing the text chunks to the AI model on invocation, this string is prefixed + /// to those chunks, in order to provide some context to the model. + /// + /// + /// Defaults to "Consider the following source information when responding to the user::" + /// + public string? ContextPrompt { get; init; } + + /// + /// When providing the text chunks to the AI model on invocation, this string is postfixed + /// to those chunks, in order to instruct the model to include citations. + /// + /// + /// Defaults to "Include citations to the relevant information where it is referenced in the response.:" + /// + public string? InclueCitationsPrompt { get; init; } + /// /// The time at which the text search is performed. /// diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagDocument.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs similarity index 97% rename from dotnet/src/Memory/Memory/TextRag/TextRagDocument.cs rename to dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs index 7e6146b769a8..fa1ab41f06ee 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagDocument.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs @@ -1,12 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; namespace Microsoft.SemanticKernel.Memory.TextRag; /// /// Represents a document that can be used for Retrieval Augmented Generation (RAG). /// +[Experimental("SKEXP0130")] public class TextRagDocument { /// diff --git a/dotnet/src/Memory/Memory/TextRag/TextRagStore.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs similarity index 96% rename from dotnet/src/Memory/Memory/TextRag/TextRagStore.cs rename to dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs index f713658b492f..1a641b88b2a4 100644 --- a/dotnet/src/Memory/Memory/TextRag/TextRagStore.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; using System.Threading; @@ -17,6 +18,7 @@ namespace Microsoft.SemanticKernel.Memory; /// A class that allows for easy storage and retrieval of documents in a Vector Store for Retrieval Augmented Generation (RAG). /// /// The key type to use with the vector store. +[Experimental("SKEXP0130")] public class TextRagStore : ITextSearch, IDisposable where TKey : notnull { @@ -110,7 +112,7 @@ public async Task> SearchAsync(string query, TextSea { var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); - return new(searchResult.Results.Select(x => x.Record.Text ?? string.Empty)); + return new(searchResult.Results.SelectAsync(x => x.Record.Text ?? string.Empty, cancellationToken)); } /// @@ -118,20 +120,20 @@ public async Task> GetTextSearchResultsAsy { var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); - var results = searchResult.Results.Select(x => new TextSearchResult(x.Record.Text ?? string.Empty) { Name = x.Record.SourceName, Link = x.Record.SourceReference }); - return new(searchResult.Results.Select(x => + var results = searchResult.Results.SelectAsync(x => new TextSearchResult(x.Record.Text ?? string.Empty) { Name = x.Record.SourceName, Link = x.Record.SourceReference }, cancellationToken); + return new(searchResult.Results.SelectAsync(x => new TextSearchResult(x.Record.Text ?? string.Empty) { Name = x.Record.SourceName, Link = x.Record.SourceReference - })); + }, cancellationToken)); } /// public async Task> GetSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) { var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); - return new(searchResult.Results.Cast()); + return new(searchResult.Results.SelectAsync(x => (object)x, cancellationToken)); } /// diff --git a/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs b/dotnet/src/SemanticKernel.Core/Memory/VectorDataTextMemoryStore.cs similarity index 99% rename from dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs rename to dotnet/src/SemanticKernel.Core/Memory/VectorDataTextMemoryStore.cs index fa98021dbba9..c3681a1c7745 100644 --- a/dotnet/src/Memory/Memory/VectorDataTextMemoryStore.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/VectorDataTextMemoryStore.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Threading; @@ -17,6 +18,7 @@ namespace Microsoft.SemanticKernel.Memory; /// Class to store and retrieve text-based memories to and from a vector store. /// /// The key type to use with the vector store. +[Experimental("SKEXP0001")] public class VectorDataTextMemoryStore : TextMemoryStore, IDisposable where TKey : notnull { From 3fdf5e2ee5fe09633c47219300f9fc9dd56d8a16 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 28 Apr 2025 15:07:35 +0100 Subject: [PATCH 41/46] Add unit tests for TextRagComponent and rename a few settings --- .../Memory/TextRag/TextRagComponent.cs | 49 +++-- .../Memory/TextRag/TextRagComponentOptions.cs | 30 +-- .../Memory/TextRag/TextRagStore.cs | 7 +- .../Memory/TextRagComponentTests.cs | 175 ++++++++++++++++++ 4 files changed, 228 insertions(+), 33 deletions(-) create mode 100644 dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs index c52490d280a4..241e13419a42 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs @@ -46,8 +46,8 @@ public TextRagComponent(ITextSearch textSearch, TextRagComponentOptions? options [ AIFunctionFactory.Create( this.SearchAsync, - name: this.Options.PluginSearchFunctionName ?? DefaultPluginSearchFunctionName, - description: this.Options.PluginSearchFunctionDescription ?? DefaultPluginSearchFunctionDescription) + name: this.Options.PluginFunctionName ?? DefaultPluginSearchFunctionName, + description: this.Options.PluginFunctionDescription ?? DefaultPluginSearchFunctionDescription) ]; } @@ -61,7 +61,7 @@ public override IReadOnlyCollection AIFunctions { get { - if (this.Options.SearchTime != TextRagComponentOptions.TextRagSearchTime.ViaPlugin) + if (this.Options.SearchTime != TextRagComponentOptions.RagBehavior.ViaPlugin) { return Array.Empty(); } @@ -73,7 +73,7 @@ public override IReadOnlyCollection AIFunctions /// public override async Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default) { - if (this.Options.SearchTime != TextRagComponentOptions.TextRagSearchTime.BeforeAIInvoke) + if (this.Options.SearchTime != TextRagComponentOptions.RagBehavior.BeforeAIInvoke) { return string.Empty; } @@ -87,28 +87,16 @@ public override async Task OnModelInvokeAsync(ICollection n new() { Top = this.Options.Top }, cancellationToken: cancellationToken).ConfigureAwait(false); - // Format the results showing the content with source link and name for each result. - var sb = new StringBuilder(); - sb.AppendLine(this.Options.ContextPrompt ?? DefaultContextPrompt); - await foreach (var result in searchResults.Results.ConfigureAwait(false)) - { - sb.AppendLine($" Source Document Name: {result.Name}"); - sb.AppendLine($" Source Document Link: {result.Link}"); - sb.AppendLine($" Source Document Contents: {result.Value}"); - sb.AppendLine(" -----------------"); - } - - sb.AppendLine(this.Options.InclueCitationsPrompt ?? DefaultIncludeCitationsPrompt); - sb.AppendLine("-------------------"); + var results = await searchResults.Results.ToListAsync(cancellationToken).ConfigureAwait(false); - return sb.ToString(); + return this.FormatResults(results); } /// /// Plugin method to search the database on demand. /// [KernelFunction] - public async Task SearchAsync(string userQuestion, CancellationToken cancellationToken = default) + internal async Task SearchAsync(string userQuestion, CancellationToken cancellationToken = default) { var searchResults = await this._textSearch.GetTextSearchResultsAsync( userQuestion, @@ -117,7 +105,28 @@ public async Task SearchAsync(string userQuestion, CancellationToken can var results = await searchResults.Results.ToListAsync(cancellationToken).ConfigureAwait(false); - return JsonSerializer.Serialize(results, TextRagSourceGenerationContext.Default.ListTextSearchResult); + return this.FormatResults(results); + } + + /// + /// Format the results showing the content with source link and name for each result. + /// + /// The results to format. + /// The formatted results. + private string FormatResults(List results) + { + var sb = new StringBuilder(); + sb.AppendLine(this.Options.ContextPrompt ?? DefaultContextPrompt); + foreach (var result in results) + { + sb.AppendLine($" Source Document Name: {result.Name}"); + sb.AppendLine($" Source Document Link: {result.Link}"); + sb.AppendLine($" Source Document Contents: {result.Value}"); + sb.AppendLine(" -----------------"); + } + sb.AppendLine(this.Options.IncludeCitationsPrompt ?? DefaultIncludeCitationsPrompt); + sb.AppendLine("-------------------"); + return sb.ToString(); } } diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs index 7a0f85f8be28..639c2651a1ab 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs @@ -35,26 +35,32 @@ public int Top /// /// Gets or sets the time at which the text search is performed. /// - public TextRagSearchTime SearchTime { get; init; } = TextRagSearchTime.BeforeAIInvoke; + public RagBehavior SearchTime { get; init; } = RagBehavior.BeforeAIInvoke; /// /// Gets or sets the name of the plugin method that will be made available for searching - /// if the option is set to . + /// if the option is set to . /// - public string? PluginSearchFunctionName { get; init; } + /// + /// Defaults to "Search" if not set. + /// + public string? PluginFunctionName { get; init; } /// /// Gets or sets the description of the plugin method that will be made available for searching - /// if the option is set to . + /// if the option is set to . /// - public string? PluginSearchFunctionDescription { get; init; } + /// + /// Defaults to "Allows searching for additional information to help answer the user question." if not set. + /// + public string? PluginFunctionDescription { get; init; } /// /// When providing the text chunks to the AI model on invocation, this string is prefixed /// to those chunks, in order to provide some context to the model. /// /// - /// Defaults to "Consider the following source information when responding to the user::" + /// Defaults to "Consider the following source information when responding to the user:" /// public string? ContextPrompt { get; init; } @@ -65,21 +71,21 @@ public int Top /// /// Defaults to "Include citations to the relevant information where it is referenced in the response.:" /// - public string? InclueCitationsPrompt { get; init; } + public string? IncludeCitationsPrompt { get; init; } /// - /// The time at which the text search is performed. + /// Choices for controlling the behavior of the . /// - public enum TextRagSearchTime + public enum RagBehavior { /// - /// A search is performed each time that the AI is invoked just before the AI is invoked - /// and the results are provided to the AI via the invocation context. + /// A search is performed each time that the model/agent is invoked just before invocation + /// and the results are provided to the model/agent via the invocation context. /// BeforeAIInvoke, /// - /// A search may be performed by the AI on demand via a plugin. + /// A search may be performed by the model/agent on demand via function calling. /// ViaPlugin } diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs index 1a641b88b2a4..c187cb899a36 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs @@ -41,7 +41,12 @@ public class TextRagStore : ITextSearch, IDisposable /// The number of dimensions to use for the memory embeddings. /// An optional namespace to filter search results to. /// Thrown if the key type provided is not supported. - public TextRagStore(IVectorStore vectorStore, ITextEmbeddingGenerationService textEmbeddingGenerationService, string collectionName, int vectorDimensions, string? searchNamespace) + public TextRagStore( + IVectorStore vectorStore, + ITextEmbeddingGenerationService textEmbeddingGenerationService, + string collectionName, + int vectorDimensions, + string? searchNamespace) { Verify.NotNull(vectorStore); Verify.NotNull(textEmbeddingGenerationService); diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs new file mode 100644 index 000000000000..33fb49683aff --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.Data; +using Microsoft.SemanticKernel.Memory; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Contains tests for +/// +public class TextRagComponentTests +{ + [Theory] + [InlineData(null, null, "Consider the following source information when responding to the user:", "Include citations to the relevant information where it is referenced in the response.")] + [InlineData("Custom context prompt", "Custom citations prompt", "Custom context prompt", "Custom citations prompt")] + public async Task OnModelInvokeShouldIncludeSearchResultsInOutputAsync( + string? overrideContextPrompt, + string? overrideCitationsPrompt, + string expectedContextPrompt, + string expectedCitationsPrompt) + { + // Arrange + var mockTextSearch = new Mock(); + var searchResults = new Mock>(); + var mockEnumerator = new Mock>(); + + // Mock search results + var results = new List + { + new("Content of Doc1") { Name = "Doc1", Link = "http://example.com/doc1" }, + new("Content of Doc2") { Name = "Doc2", Link = "http://example.com/doc2" } + }; + + mockEnumerator.SetupSequence(e => e.MoveNextAsync()) + .ReturnsAsync(true) + .ReturnsAsync(true) + .ReturnsAsync(false); + + mockEnumerator.SetupSequence(e => e.Current) + .Returns(results[0]) + .Returns(results[1]); + + searchResults.Setup(r => r.GetAsyncEnumerator(It.IsAny())) + .Returns(mockEnumerator.Object); + + mockTextSearch.Setup(ts => ts.GetTextSearchResultsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new KernelSearchResults(searchResults.Object)); + + var options = new TextRagComponentOptions + { + SearchTime = TextRagComponentOptions.RagBehavior.BeforeAIInvoke, + Top = 2, + ContextPrompt = overrideContextPrompt, + IncludeCitationsPrompt = overrideCitationsPrompt + }; + + var component = new TextRagComponent(mockTextSearch.Object, options); + + // Act + var result = await component.OnModelInvokeAsync([new ChatMessage(ChatRole.User, "Sample user question?")], CancellationToken.None); + + // Assert + Assert.Contains(expectedContextPrompt, result); + Assert.Contains("Source Document Name: Doc1", result); + Assert.Contains("Source Document Link: http://example.com/doc1", result); + Assert.Contains("Source Document Contents: Content of Doc1", result); + Assert.Contains("Source Document Name: Doc2", result); + Assert.Contains("Source Document Link: http://example.com/doc2", result); + Assert.Contains("Source Document Contents: Content of Doc2", result); + Assert.Contains(expectedCitationsPrompt, result); + } + + [Theory] + [InlineData(null, null, "Search", "Allows searching for additional information to help answer the user question.")] + [InlineData("CustomSearch", "CustomDescription", "CustomSearch", "CustomDescription")] + public void AIFunctionsShouldBeRegisteredCorrectly( + string? overridePluginFunctionName, + string? overridePluginFunctionDescription, + string expectedPluginFunctionName, + string expectedPluginFunctionDescription) + { + // Arrange + var mockTextSearch = new Mock(); + var options = new TextRagComponentOptions + { + SearchTime = TextRagComponentOptions.RagBehavior.ViaPlugin, + PluginFunctionName = overridePluginFunctionName, + PluginFunctionDescription = overridePluginFunctionDescription + }; + + var component = new TextRagComponent(mockTextSearch.Object, options); + + // Act + var aiFunctions = component.AIFunctions; + + // Assert + Assert.NotNull(aiFunctions); + Assert.Single(aiFunctions); + var aiFunction = aiFunctions.First(); + Assert.Equal(expectedPluginFunctionName, aiFunction.Name); + Assert.Equal(expectedPluginFunctionDescription, aiFunction.Description); + } + + [Theory] + [InlineData(null, null, "Consider the following source information when responding to the user:", "Include citations to the relevant information where it is referenced in the response.")] + [InlineData("Custom context prompt", "Custom citations prompt", "Custom context prompt", "Custom citations prompt")] + public async Task SearchAsyncShouldIncludeSearchResultsInOutputAsync( + string? overrideContextPrompt, + string? overrideCitationsPrompt, + string expectedContextPrompt, + string expectedCitationsPrompt) + + { + // Arrange + var mockTextSearch = new Mock(); + var searchResults = new Mock>(); + var mockEnumerator = new Mock>(); + + // Mock search results + var results = new List + { + new("Content of Doc1") { Name = "Doc1", Link = "http://example.com/doc1" }, + new("Content of Doc2") { Name = "Doc2", Link = "http://example.com/doc2" } + }; + + mockEnumerator.SetupSequence(e => e.MoveNextAsync()) + .ReturnsAsync(true) + .ReturnsAsync(true) + .ReturnsAsync(false); + + mockEnumerator.SetupSequence(e => e.Current) + .Returns(results[0]) + .Returns(results[1]); + + searchResults.Setup(r => r.GetAsyncEnumerator(It.IsAny())) + .Returns(mockEnumerator.Object); + + mockTextSearch.Setup(ts => ts.GetTextSearchResultsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new KernelSearchResults(searchResults.Object)); + + var options = new TextRagComponentOptions + { + ContextPrompt = overrideContextPrompt, + IncludeCitationsPrompt = overrideCitationsPrompt + }; + + var component = new TextRagComponent(mockTextSearch.Object, options); + + // Act + var result = await component.SearchAsync("Sample user question?", CancellationToken.None); + + // Assert + Assert.Contains(expectedContextPrompt, result); + Assert.Contains("Source Document Name: Doc1", result); + Assert.Contains("Source Document Link: http://example.com/doc1", result); + Assert.Contains("Source Document Contents: Content of Doc1", result); + Assert.Contains("Source Document Name: Doc2", result); + Assert.Contains("Source Document Link: http://example.com/doc2", result); + Assert.Contains("Source Document Contents: Content of Doc2", result); + Assert.Contains(expectedCitationsPrompt, result); + } +} From 2c50f7e55c4a83cfd93681e92545ba041b3e6958 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Tue, 29 Apr 2025 22:00:21 +0100 Subject: [PATCH 42/46] Add TextRagStore improvements. --- .../Memory/TextRag/TextRagDocument.cs | 25 +++------ .../Memory/TextRag/TextRagStore.cs | 55 ++++++++++++------- .../Memory/TextRag/TextRagStoreOptions.cs | 33 +++++++++++ 3 files changed, 76 insertions(+), 37 deletions(-) create mode 100644 dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs index fa1ab41f06ee..144bf4bc5fc3 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs @@ -11,17 +11,6 @@ namespace Microsoft.SemanticKernel.Memory.TextRag; [Experimental("SKEXP0130")] public class TextRagDocument { - /// - /// Initializes a new instance of the class. - /// - /// The text content. - public TextRagDocument(string text) - { - Verify.NotNullOrWhiteSpace(text); - - this.Text = text; - } - /// /// Gets or sets an optional list of namespaces that the document should belong to. /// @@ -30,6 +19,11 @@ public TextRagDocument(string text) /// public List Namespaces { get; set; } = []; + /// + /// Gets or sets the content as text. + /// + public string? Text { get; set; } + /// /// Gets or sets an optional source ID for the document. /// @@ -41,11 +35,6 @@ public TextRagDocument(string text) /// public string? SourceId { get; set; } - /// - /// Gets or sets the content as text. - /// - public string Text { get; set; } - /// /// Gets or sets an optional name for the source document. /// @@ -56,11 +45,11 @@ public TextRagDocument(string text) public string? SourceName { get; set; } /// - /// Gets or sets an optional reference back to the source of the document. + /// Gets or sets an optional link back to the source of the document. /// /// /// This can be used to provide citation links when the document is referenced as /// part of a response to a query. /// - public string? SourceReference { get; set; } + public string? SourceLink { get; set; } } diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs index c187cb899a36..86dad15d5ebc 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs @@ -25,7 +25,7 @@ public class TextRagStore : ITextSearch, IDisposable private readonly IVectorStore _vectorStore; private readonly ITextEmbeddingGenerationService _textEmbeddingGenerationService; private readonly int _vectorDimensions; - private readonly string? _searchNamespace; + private readonly TextRagStoreOptions _options; private readonly Lazy>> _vectorStoreRecordCollection; private readonly SemaphoreSlim _collectionInitializationLock = new(1, 1); @@ -39,14 +39,14 @@ public class TextRagStore : ITextSearch, IDisposable /// The service to use for generating embeddings for the memories. /// The name of the collection in the vector store to store and read the memories from. /// The number of dimensions to use for the memory embeddings. - /// An optional namespace to filter search results to. + /// Options to configure the behavior of this class. /// Thrown if the key type provided is not supported. public TextRagStore( IVectorStore vectorStore, ITextEmbeddingGenerationService textEmbeddingGenerationService, string collectionName, int vectorDimensions, - string? searchNamespace) + TextRagStoreOptions? options) { Verify.NotNull(vectorStore); Verify.NotNull(textEmbeddingGenerationService); @@ -56,13 +56,18 @@ public TextRagStore( this._vectorStore = vectorStore; this._textEmbeddingGenerationService = textEmbeddingGenerationService; this._vectorDimensions = vectorDimensions; - this._searchNamespace = searchNamespace; + this._options = options ?? new TextRagStoreOptions(); if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(Guid)) { throw new NotSupportedException($"Unsupported key of type '{typeof(TKey).Name}'"); } + if (typeof(TKey) != typeof(string) && this._options.UseSourceIdAsPrimaryKey is true) + { + throw new NotSupportedException($"The {nameof(TextRagStoreOptions.UseSourceIdAsPrimaryKey)} option can only be used when the key type is 'string'."); + } + VectorStoreRecordDefinition ragDocumentDefinition = new() { Properties = new List() @@ -89,11 +94,16 @@ public TextRagStore( /// A task that completes when the documents have been upserted. public async Task UpsertDocumentsAsync(IEnumerable documents, CancellationToken cancellationToken = default) { - var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + var vectorStoreRecordCollection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); var storageDocumentsTasks = documents.Select(async document => { - var key = GenerateUniqueKey(document.SourceId); + if (string.IsNullOrWhiteSpace(document.Text) && string.IsNullOrWhiteSpace(document.SourceId) && string.IsNullOrWhiteSpace(document.SourceLink)) + { + throw new ArgumentException($"Either the document {nameof(TextRagDocument.Text)}, {nameof(TextRagDocument.SourceId)} or {nameof(TextRagDocument.SourceLink)} properties must be set.", nameof(document)); + } + + var key = GenerateUniqueKey(this._options.UseSourceIdAsPrimaryKey ?? false ? document.SourceId : null); var textEmbedding = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(document.Text).ConfigureAwait(false); return new TextRagStorageDocument @@ -103,7 +113,7 @@ public async Task UpsertDocumentsAsync(IEnumerable documents, C SourceId = document.SourceId, Text = document.Text, SourceName = document.SourceName, - SourceReference = document.SourceReference, + SourceLink = document.SourceLink, TextEmbedding = textEmbedding }; }); @@ -115,6 +125,13 @@ public async Task UpsertDocumentsAsync(IEnumerable documents, C /// public async Task> SearchAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) { + var vectorStoreRecordCollection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); + var textSearch = new VectorStoreTextSearch>( + vectorStoreRecordCollection, + this._textEmbeddingGenerationService, + r => r is TextRagStorageDocument doc ? doc.Text ?? string.Empty : string.Empty, + r => r is TextRagStorageDocument doc ? new TextSearchResult(doc.Text ?? string.Empty) { Name = doc.SourceName, Link = doc.SourceLink } : new TextSearchResult(string.Empty)); + var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); return new(searchResult.Results.SelectAsync(x => x.Record.Text ?? string.Empty, cancellationToken)); @@ -125,12 +142,12 @@ public async Task> GetTextSearchResultsAsy { var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); - var results = searchResult.Results.SelectAsync(x => new TextSearchResult(x.Record.Text ?? string.Empty) { Name = x.Record.SourceName, Link = x.Record.SourceReference }, cancellationToken); + var results = searchResult.Results.SelectAsync(x => new TextSearchResult(x.Record.Text ?? string.Empty) { Name = x.Record.SourceName, Link = x.Record.SourceLink }, cancellationToken); return new(searchResult.Results.SelectAsync(x => new TextSearchResult(x.Record.Text ?? string.Empty) { Name = x.Record.SourceName, - Link = x.Record.SourceReference + Link = x.Record.SourceLink }, cancellationToken)); } @@ -150,10 +167,10 @@ public async Task> GetSearchResultsAsync(string quer /// The search results. private async Task>> SearchInternalAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) { - var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + var vectorStoreRecordCollection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); // Optional filter to limit the search to a specific namespace. - Expression, bool>>? filter = string.IsNullOrWhiteSpace(this._searchNamespace) ? null : x => x.Namespaces.Contains(this._searchNamespace); + Expression, bool>>? filter = string.IsNullOrWhiteSpace(this._options.SearchNamespace) ? null : x => x.Namespaces.Contains(this._options.SearchNamespace); var vector = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(query, cancellationToken: cancellationToken).ConfigureAwait(false); var searchResult = await vectorStoreRecordCollection.VectorizedSearchAsync( @@ -173,7 +190,7 @@ private async Task>> SearchInte /// /// The to monitor for cancellation requests. The default is . /// The created collection. - private async Task>> EnsureCollectionCreatedAsync(CancellationToken cancellationToken) + private async Task>> EnsureCollectionExistsAsync(CancellationToken cancellationToken) { var vectorStoreRecordCollection = this._vectorStoreRecordCollection.Value; @@ -266,6 +283,11 @@ private sealed class TextRagStorageDocument /// public List Namespaces { get; set; } = []; + /// + /// Gets or sets the content as text. + /// + public string? Text { get; set; } + /// /// Gets or sets an optional source ID for the document. /// @@ -277,11 +299,6 @@ private sealed class TextRagStorageDocument /// public string? SourceId { get; set; } - /// - /// Gets or sets the content as text. - /// - public string? Text { get; set; } - /// /// Gets or sets an optional name for the source document. /// @@ -292,13 +309,13 @@ private sealed class TextRagStorageDocument public string? SourceName { get; set; } /// - /// Gets or sets an optional reference back to the source of the document. + /// Gets or sets an optional link back to the source of the document. /// /// /// This can be used to provide citation links when the document is referenced as /// part of a response to a query. /// - public string? SourceReference { get; set; } + public string? SourceLink { get; set; } /// /// Gets or sets the embedding for the text content. diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs new file mode 100644 index 000000000000..298430ee4054 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Contains options for the . +/// +public class TextRagStoreOptions +{ + /// + /// Gets or sets an optional namespace to pre-filter the possible + /// records with when doing a vector search. + /// + public string? SearchNamespace { get; init; } + + /// + /// Gets or sets a value indicating whether to use the source ID as the primary key for records. + /// + /// + /// + /// Using the source ID as the primary key allows for easy updates from the source for any changed + /// records, since those records can just be upserted again, and will overwrite the previous version + /// of the same record. + /// + /// + /// This setting can only be used when the chosen key type is a string. + /// + /// + /// + /// Defaults to false if not set. + /// + public bool? UseSourceIdAsPrimaryKey { get; init; } +} From 34bbed2e58a49add79c247910e5a39964e2c0b77 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 30 Apr 2025 11:25:44 +0100 Subject: [PATCH 43/46] Simplify context formatting and alow formatting override. --- .../Memory/TextRag/TextRagComponent.cs | 25 ++++-- .../Memory/TextRag/TextRagComponentOptions.cs | 29 ++++++- .../Memory/TextRagComponentTests.cs | 84 +++++++++++++++---- 3 files changed, 116 insertions(+), 22 deletions(-) diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs index 241e13419a42..558cc26cff3c 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs @@ -22,7 +22,7 @@ public class TextRagComponent : ConversationStatePart { private const string DefaultPluginSearchFunctionName = "Search"; private const string DefaultPluginSearchFunctionDescription = "Allows searching for additional information to help answer the user question."; - private const string DefaultContextPrompt = "Consider the following source information when responding to the user:"; + private const string DefaultContextPrompt = "Consider the following information when responding to the user:"; private const string DefaultIncludeCitationsPrompt = "Include citations to the relevant information where it is referenced in the response."; private readonly ITextSearch _textSearch; @@ -115,17 +115,28 @@ internal async Task SearchAsync(string userQuestion, CancellationToken c /// The formatted results. private string FormatResults(List results) { + if (this.Options.ContextFormatter is not null) + { + return this.Options.ContextFormatter(results); + } + + if (results.Count == 0) + { + return string.Empty; + } + var sb = new StringBuilder(); sb.AppendLine(this.Options.ContextPrompt ?? DefaultContextPrompt); - foreach (var result in results) + for (int i = 0; i < results.Count; i++) { - sb.AppendLine($" Source Document Name: {result.Name}"); - sb.AppendLine($" Source Document Link: {result.Link}"); - sb.AppendLine($" Source Document Contents: {result.Value}"); - sb.AppendLine(" -----------------"); + var result = results[i]; + sb.AppendLine($"Item {i + 1}:"); + sb.AppendLine($"Name: {result.Name}"); + sb.AppendLine($"Link: {result.Link}"); + sb.AppendLine($"Contents: {result.Value}"); } sb.AppendLine(this.Options.IncludeCitationsPrompt ?? DefaultIncludeCitationsPrompt); - sb.AppendLine("-------------------"); + sb.AppendLine(); return sb.ToString(); } } diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs index 639c2651a1ab..c60f8c323a57 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using Microsoft.SemanticKernel.Data; namespace Microsoft.SemanticKernel.Memory; @@ -60,7 +62,7 @@ public int Top /// to those chunks, in order to provide some context to the model. /// /// - /// Defaults to "Consider the following source information when responding to the user:" + /// Defaults to "Consider the following information when responding to the user:" /// public string? ContextPrompt { get; init; } @@ -73,6 +75,24 @@ public int Top /// public string? IncludeCitationsPrompt { get; init; } + /// + /// Optional delegate to override the default context creation implementation. + /// + /// + /// + /// If provided, this delegate will be used to do the following: + /// 1. Create the output context provided by the when invoking the AI model. + /// 2. Create the response text when invoking the component via a plugin. + /// + /// + /// Note that the delegate should include the context prompt and the + /// include citations prompt if they are required in the output. + /// The and settings + /// will not be used if providing this delegate. + /// + /// + public ContextFormatterType? ContextFormatter { get; init; } + /// /// Choices for controlling the behavior of the . /// @@ -89,4 +109,11 @@ public enum RagBehavior /// ViaPlugin } + + /// + /// Delegate type for formatting the output context for the component. + /// + /// The results returned by the text search. + /// The formatted context. + public delegate string ContextFormatterType(List results); } diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs index 33fb49683aff..7f7e56f73369 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs @@ -18,7 +18,7 @@ namespace SemanticKernel.UnitTests.Memory; public class TextRagComponentTests { [Theory] - [InlineData(null, null, "Consider the following source information when responding to the user:", "Include citations to the relevant information where it is referenced in the response.")] + [InlineData(null, null, "Consider the following information when responding to the user:", "Include citations to the relevant information where it is referenced in the response.")] [InlineData("Custom context prompt", "Custom citations prompt", "Custom context prompt", "Custom citations prompt")] public async Task OnModelInvokeShouldIncludeSearchResultsInOutputAsync( string? overrideContextPrompt, @@ -71,12 +71,14 @@ public async Task OnModelInvokeShouldIncludeSearchResultsInOutputAsync( // Assert Assert.Contains(expectedContextPrompt, result); - Assert.Contains("Source Document Name: Doc1", result); - Assert.Contains("Source Document Link: http://example.com/doc1", result); - Assert.Contains("Source Document Contents: Content of Doc1", result); - Assert.Contains("Source Document Name: Doc2", result); - Assert.Contains("Source Document Link: http://example.com/doc2", result); - Assert.Contains("Source Document Contents: Content of Doc2", result); + Assert.Contains("Item 1:", result); + Assert.Contains("Name: Doc1", result); + Assert.Contains("Link: http://example.com/doc1", result); + Assert.Contains("Contents: Content of Doc1", result); + Assert.Contains("Item 2:", result); + Assert.Contains("Name: Doc2", result); + Assert.Contains("Link: http://example.com/doc2", result); + Assert.Contains("Contents: Content of Doc2", result); Assert.Contains(expectedCitationsPrompt, result); } @@ -112,7 +114,7 @@ public void AIFunctionsShouldBeRegisteredCorrectly( } [Theory] - [InlineData(null, null, "Consider the following source information when responding to the user:", "Include citations to the relevant information where it is referenced in the response.")] + [InlineData(null, null, "Consider the following information when responding to the user:", "Include citations to the relevant information where it is referenced in the response.")] [InlineData("Custom context prompt", "Custom citations prompt", "Custom context prompt", "Custom citations prompt")] public async Task SearchAsyncShouldIncludeSearchResultsInOutputAsync( string? overrideContextPrompt, @@ -164,12 +166,66 @@ public async Task SearchAsyncShouldIncludeSearchResultsInOutputAsync( // Assert Assert.Contains(expectedContextPrompt, result); - Assert.Contains("Source Document Name: Doc1", result); - Assert.Contains("Source Document Link: http://example.com/doc1", result); - Assert.Contains("Source Document Contents: Content of Doc1", result); - Assert.Contains("Source Document Name: Doc2", result); - Assert.Contains("Source Document Link: http://example.com/doc2", result); - Assert.Contains("Source Document Contents: Content of Doc2", result); + Assert.Contains("Item 1:", result); + Assert.Contains("Name: Doc1", result); + Assert.Contains("Link: http://example.com/doc1", result); + Assert.Contains("Contents: Content of Doc1", result); + Assert.Contains("Item 2:", result); + Assert.Contains("Name: Doc2", result); + Assert.Contains("Link: http://example.com/doc2", result); + Assert.Contains("Contents: Content of Doc2", result); Assert.Contains(expectedCitationsPrompt, result); } + + [Fact] + public async Task OnModelInvokeShouldUseOverrideContextFormatterIfProvidedAsync() + { + // Arrange + var mockTextSearch = new Mock(); + var searchResults = new Mock>(); + var mockEnumerator = new Mock>(); + + // Mock search results + var results = new List + { + new("Content of Doc1") { Name = "Doc1", Link = "http://example.com/doc1" }, + new("Content of Doc2") { Name = "Doc2", Link = "http://example.com/doc2" } + }; + + mockEnumerator.SetupSequence(e => e.MoveNextAsync()) + .ReturnsAsync(true) + .ReturnsAsync(true) + .ReturnsAsync(false); + + mockEnumerator.SetupSequence(e => e.Current) + .Returns(results[0]) + .Returns(results[1]); + + searchResults.Setup(r => r.GetAsyncEnumerator(It.IsAny())) + .Returns(mockEnumerator.Object); + + mockTextSearch.Setup(ts => ts.GetTextSearchResultsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new KernelSearchResults(searchResults.Object)); + + var customFormatter = new TextRagComponentOptions.ContextFormatterType(results => + $"Custom formatted context with {results.Count} results."); + + var options = new TextRagComponentOptions + { + SearchTime = TextRagComponentOptions.RagBehavior.BeforeAIInvoke, + Top = 2, + ContextFormatter = customFormatter + }; + + var component = new TextRagComponent(mockTextSearch.Object, options); + + // Act + var result = await component.OnModelInvokeAsync([new ChatMessage(ChatRole.User, "Sample user question?")], CancellationToken.None); + + // Assert + Assert.Equal("Custom formatted context with 2 results.", result); + } } From e89ee4fde13e09acc8dc4862b7f4eb27e76c0c1f Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 30 Apr 2025 14:14:05 +0100 Subject: [PATCH 44/46] Seal TextRagComponent and TextRagStore --- .../SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs | 2 +- .../Memory/TextRag/TextRagComponentOptions.cs | 2 +- .../src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs | 2 +- dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs | 4 ++-- .../SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs index 558cc26cff3c..6564cca67df0 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs @@ -18,7 +18,7 @@ namespace Microsoft.SemanticKernel.Memory; /// A component that does a search based on any messages that the AI model is invoked with and injects the results into the AI model invocation context. /// [Experimental("SKEXP0130")] -public class TextRagComponent : ConversationStatePart +public sealed class TextRagComponent : ConversationStatePart { private const string DefaultPluginSearchFunctionName = "Search"; private const string DefaultPluginSearchFunctionDescription = "Allows searching for additional information to help answer the user question."; diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs index c60f8c323a57..1bb97e4ac729 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs @@ -11,7 +11,7 @@ namespace Microsoft.SemanticKernel.Memory; /// Contains options for the . /// [Experimental("SKEXP0130")] -public class TextRagComponentOptions +public sealed class TextRagComponentOptions { private int _top = 3; diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs index 144bf4bc5fc3..e3b221fc74d5 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Memory.TextRag; /// Represents a document that can be used for Retrieval Augmented Generation (RAG). /// [Experimental("SKEXP0130")] -public class TextRagDocument +public sealed class TextRagDocument { /// /// Gets or sets an optional list of namespaces that the document should belong to. diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs index 86dad15d5ebc..2c3efc32fab8 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs @@ -19,7 +19,7 @@ namespace Microsoft.SemanticKernel.Memory; /// /// The key type to use with the vector store. [Experimental("SKEXP0130")] -public class TextRagStore : ITextSearch, IDisposable +public sealed class TextRagStore : ITextSearch, IDisposable where TKey : notnull { private readonly IVectorStore _vectorStore; @@ -243,7 +243,7 @@ _ when typeof(TDocumentKey) == typeof(Guid) => (TDocumentKey)(object)Guid.NewGui }; /// - protected virtual void Dispose(bool disposing) + private void Dispose(bool disposing) { if (!this._disposedValue) { diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs index 298430ee4054..d47e7d90c15d 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs @@ -5,7 +5,7 @@ namespace Microsoft.SemanticKernel.Memory; /// /// Contains options for the . /// -public class TextRagStoreOptions +public sealed class TextRagStoreOptions { /// /// Gets or sets an optional namespace to pre-filter the possible From 64618aa2359263cc298af2d4b4c8fbe4a2dca7fe Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 30 Apr 2025 19:29:14 +0100 Subject: [PATCH 45/46] Add more tests and options for TextRagStore --- .../Memory/TextRag/TextRagStore.cs | 29 ++- .../TextRag/TextRagStoreUpsertOptions.cs | 17 ++ .../SemanticKernel.Core.csproj | 1 + .../Memory/TextRagStoreTests.cs | 187 ++++++++++++++++++ 4 files changed, 226 insertions(+), 8 deletions(-) create mode 100644 dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreUpsertOptions.cs create mode 100644 dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs index 2c3efc32fab8..4a308e83b9d5 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs @@ -46,7 +46,7 @@ public TextRagStore( ITextEmbeddingGenerationService textEmbeddingGenerationService, string collectionName, int vectorDimensions, - TextRagStoreOptions? options) + TextRagStoreOptions? options = default) { Verify.NotNull(vectorStore); Verify.NotNull(textEmbeddingGenerationService); @@ -90,31 +90,44 @@ public TextRagStore( /// Upserts a batch of documents into the vector store. /// /// The documents to upload. + /// Optional options to control the upsert behavior. /// The to monitor for cancellation requests. The default is . /// A task that completes when the documents have been upserted. - public async Task UpsertDocumentsAsync(IEnumerable documents, CancellationToken cancellationToken = default) + public async Task UpsertDocumentsAsync(IEnumerable documents, TextRagStoreUpsertOptions? options = null, CancellationToken cancellationToken = default) { + Verify.NotNull(documents); + var vectorStoreRecordCollection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); var storageDocumentsTasks = documents.Select(async document => { - if (string.IsNullOrWhiteSpace(document.Text) && string.IsNullOrWhiteSpace(document.SourceId) && string.IsNullOrWhiteSpace(document.SourceLink)) + if (document is null) + { + throw new ArgumentNullException(nameof(documents), "One of the provided documents is null."); + } + + if (string.IsNullOrWhiteSpace(document.Text)) + { + throw new ArgumentException($"The {nameof(TextRagDocument.Text)} property must be set.", nameof(document)); + } + + if (options?.PersistSourceText is false && string.IsNullOrWhiteSpace(document.SourceId) && string.IsNullOrWhiteSpace(document.SourceLink)) { - throw new ArgumentException($"Either the document {nameof(TextRagDocument.Text)}, {nameof(TextRagDocument.SourceId)} or {nameof(TextRagDocument.SourceLink)} properties must be set.", nameof(document)); + throw new ArgumentException($"Either the {nameof(TextRagDocument.SourceId)} or {nameof(TextRagDocument.SourceLink)} properties must be set when the {nameof(TextRagStoreUpsertOptions.PersistSourceText)} setting is false.", nameof(document)); } var key = GenerateUniqueKey(this._options.UseSourceIdAsPrimaryKey ?? false ? document.SourceId : null); - var textEmbedding = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(document.Text).ConfigureAwait(false); + var textEmbeddings = await this._textEmbeddingGenerationService.GenerateEmbeddingsAsync([document.Text!]).ConfigureAwait(false); return new TextRagStorageDocument { Key = key, Namespaces = document.Namespaces, SourceId = document.SourceId, - Text = document.Text, + Text = options?.PersistSourceText is false ? null : document.Text, SourceName = document.SourceName, SourceLink = document.SourceLink, - TextEmbedding = textEmbedding + TextEmbedding = textEmbeddings.Single() }; }); @@ -268,7 +281,7 @@ public void Dispose() /// The data model to use for storing RAG documents in the vector store. /// /// The type of the key to use, since different databases require/support different keys. - private sealed class TextRagStorageDocument + internal sealed class TextRagStorageDocument { /// /// Gets or sets a unique identifier for the memory document. diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreUpsertOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreUpsertOptions.cs new file mode 100644 index 000000000000..83b20179c064 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreUpsertOptions.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Memory.TextRag; + +/// +/// Contains options for . +/// +public sealed class TextRagStoreUpsertOptions +{ + /// + /// Gets or sets a value indicating whether the source text should be persisted in the database. + /// + /// + /// Defaults to if not set. + /// + public bool? PersistSourceText { get; init; } +} diff --git a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj index 6539a03d547a..c4aa91cc2359 100644 --- a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj +++ b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj @@ -46,6 +46,7 @@ + diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs new file mode 100644 index 000000000000..2ba1f9dd9465 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs @@ -0,0 +1,187 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.Memory; +using Microsoft.SemanticKernel.Memory.TextRag; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +public class TextRagStoreTests +{ + private readonly Mock _vectorStoreMock; + private readonly Mock _embeddingServiceMock; + private readonly Mock.TextRagStorageDocument>> _recordCollectionMock; + + public TextRagStoreTests() + { + this._vectorStoreMock = new Mock(); + this._recordCollectionMock = new Mock.TextRagStorageDocument>>(); + this._embeddingServiceMock = new Mock(); + + this._vectorStoreMock + .Setup(v => v.GetCollection.TextRagStorageDocument>("testCollection", It.IsAny())) + .Returns(this._recordCollectionMock.Object); + + this._embeddingServiceMock + .Setup(e => e.GenerateEmbeddingsAsync(It.IsAny>(), null, It.IsAny())) + .ReturnsAsync(new[] { new ReadOnlyMemory(new float[128]) }); + } + + [Fact] + public async Task UpsertDocumentsAsyncThrowsWhenDocumentsAreNull() + { + // Arrange + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128); + + // Act & Assert + await Assert.ThrowsAsync(() => store.UpsertDocumentsAsync(null!)); + } + + [Theory] + [InlineData(null)] + [InlineData(" ")] + public async Task UpsertDocumentsAsyncThrowsDocumentTextIsNullOrWhiteSpace(string? text) + { + // Arrange + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128); + + var documents = new List + { + new() { Text = text } + }; + + // Act & Assert + await Assert.ThrowsAsync(() => store.UpsertDocumentsAsync(documents)); + } + + [Fact] + public async Task UpsertDocumentsAsyncCreatesCollectionGeneratesVectorAndUpsertsDocument() + { + // Arrange + this._recordCollectionMock + .Setup(r => r.UpsertBatchAsync(It.IsAny.TextRagStorageDocument>>(), It.IsAny())) + .Returns(AsyncEnumerable.Empty()); + + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128); + + var documents = new List + { + new() { Text = "Sample text", Namespaces = ["ns1"], SourceId = "sid", SourceLink = "sl", SourceName = "sn" } + }; + + // Act + await store.UpsertDocumentsAsync(documents); + + // Assert + this._recordCollectionMock.Verify(r => r.CreateCollectionIfNotExistsAsync(It.IsAny()), Times.Once); + this._embeddingServiceMock.Verify(e => e.GenerateEmbeddingsAsync(It.Is>(texts => texts.Count == 1 && texts[0] == "Sample text"), null, It.IsAny()), Times.Once); + this._recordCollectionMock.Verify(r => r.UpsertBatchAsync( + It.Is.TextRagStorageDocument>>(doc => + doc.Count() == 1 && + doc.First().Text == "Sample text" && + doc.First().Namespaces.Count == 1 && + doc.First().Namespaces[0] == "ns1" && + doc.First().SourceId == "sid" && + doc.First().SourceLink == "sl" && + doc.First().SourceName == "sn" && + doc.First().TextEmbedding.Length == 128), + It.IsAny()), Times.Once); + } + + [Fact] + public async Task UpsertDocumentsAsyncCreatesCollectionGeneratesVectorAndUpsertsDocumentWithSourceIdAsId() + { + // Arrange + this._recordCollectionMock + .Setup(r => r.UpsertBatchAsync(It.IsAny.TextRagStorageDocument>>(), It.IsAny())) + .Returns(AsyncEnumerable.Empty()); + + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128, new() { UseSourceIdAsPrimaryKey = true }); + + var documents = new List + { + new() { Text = "Sample text", Namespaces = ["ns1"], SourceId = "sid", SourceLink = "sl", SourceName = "sn" } + }; + + // Act + await store.UpsertDocumentsAsync(documents); + + // Assert + this._recordCollectionMock.Verify(r => r.UpsertBatchAsync( + It.Is.TextRagStorageDocument>>(doc => + doc.Count() == 1 && + doc.First().Key == "sid" && + doc.First().Text == "Sample text" && + doc.First().Namespaces.Count == 1 && + doc.First().Namespaces[0] == "ns1" && + doc.First().SourceId == "sid" && + doc.First().SourceLink == "sl" && + doc.First().SourceName == "sn" && + doc.First().TextEmbedding.Length == 128), + It.IsAny()), Times.Once); + } + + [Fact] + public async Task UpsertDocumentsAsyncCreatesCollectionGeneratesVectorAndUpsertsDocumentWithoutSourceText() + { + // Arrange + this._recordCollectionMock + .Setup(r => r.UpsertBatchAsync(It.IsAny.TextRagStorageDocument>>(), It.IsAny())) + .Returns(AsyncEnumerable.Empty()); + + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128); + + var documents = new List + { + new() { Text = "Sample text", Namespaces = ["ns1"], SourceId = "sid", SourceLink = "sl", SourceName = "sn" } + }; + + // Act + await store.UpsertDocumentsAsync(documents, new() { PersistSourceText = false }); + + // Assert + this._recordCollectionMock.Verify(r => r.UpsertBatchAsync( + It.Is.TextRagStorageDocument>>(doc => + doc.Count() == 1 && + doc.First().Text == null && + doc.First().Namespaces.Count == 1 && + doc.First().Namespaces[0] == "ns1" && + doc.First().SourceId == "sid" && + doc.First().SourceLink == "sl" && + doc.First().SourceName == "sn" && + doc.First().TextEmbedding.Length == 128), + It.IsAny()), Times.Once); + } + + [Fact] + public async Task SearchAsyncReturnsSearchResults() + { + // Arrange + var mockResults = new List.TextRagStorageDocument>> + { + new(new TextRagStore.TextRagStorageDocument { Text = "Sample text" }, 0.9f) + }; + + this._recordCollectionMock + .Setup(r => r.VectorizedSearchAsync(It.IsAny>(), It.IsAny.TextRagStorageDocument>>(), It.IsAny())) + .ReturnsAsync(new VectorSearchResults.TextRagStorageDocument>(mockResults.ToAsyncEnumerable())); + + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128, null); + + // Act + var actualResults = await store.SearchAsync("query"); + + // Assert + var actualResultsList = await actualResults.Results.ToListAsync(); + Assert.Single(actualResultsList); + Assert.Equal("Sample text", actualResultsList[0]); + } +} From ad3100b1395ce7b431d63c9dd72c4e30c674838a Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 1 May 2025 21:32:41 +0100 Subject: [PATCH 46/46] Add support for text hydration --- .../Memory/TextRag/TextRagStore.cs | 92 ++++++++++++++----- .../Memory/TextRag/TextRagStoreOptions.cs | 17 ++++ .../Memory/TextRagStoreTests.cs | 45 ++++++++- 3 files changed, 128 insertions(+), 26 deletions(-) diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs index 4a308e83b9d5..143af21f34a5 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs @@ -33,7 +33,7 @@ public sealed class TextRagStore : ITextSearch, IDisposable private bool _disposedValue; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The vector store to store and read the memories from. /// The service to use for generating embeddings for the memories. @@ -48,26 +48,29 @@ public TextRagStore( int vectorDimensions, TextRagStoreOptions? options = default) { + // Verify Verify.NotNull(vectorStore); Verify.NotNull(textEmbeddingGenerationService); Verify.NotNullOrWhiteSpace(collectionName); Verify.True(vectorDimensions > 0, "Vector dimensions must be greater than 0"); - this._vectorStore = vectorStore; - this._textEmbeddingGenerationService = textEmbeddingGenerationService; - this._vectorDimensions = vectorDimensions; - this._options = options ?? new TextRagStoreOptions(); - if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(Guid)) { throw new NotSupportedException($"Unsupported key of type '{typeof(TKey).Name}'"); } - if (typeof(TKey) != typeof(string) && this._options.UseSourceIdAsPrimaryKey is true) + if (typeof(TKey) != typeof(string) && options?.UseSourceIdAsPrimaryKey is true) { throw new NotSupportedException($"The {nameof(TextRagStoreOptions.UseSourceIdAsPrimaryKey)} option can only be used when the key type is 'string'."); } + // Assign + this._vectorStore = vectorStore; + this._textEmbeddingGenerationService = textEmbeddingGenerationService; + this._vectorDimensions = vectorDimensions; + this._options = options ?? new TextRagStoreOptions(); + + // Create a definition so that we can use the dimensions provided at runtime. VectorStoreRecordDefinition ragDocumentDefinition = new() { Properties = new List() @@ -106,11 +109,13 @@ public async Task UpsertDocumentsAsync(IEnumerable documents, T throw new ArgumentNullException(nameof(documents), "One of the provided documents is null."); } + // Without text we cannot generate a vector. if (string.IsNullOrWhiteSpace(document.Text)) { throw new ArgumentException($"The {nameof(TextRagDocument.Text)} property must be set.", nameof(document)); } + // If we aren't persisting the text, we need a source id or link to refer back to the original document. if (options?.PersistSourceText is false && string.IsNullOrWhiteSpace(document.SourceId) && string.IsNullOrWhiteSpace(document.SourceLink)) { throw new ArgumentException($"Either the {nameof(TextRagDocument.SourceId)} or {nameof(TextRagDocument.SourceLink)} properties must be set when the {nameof(TextRagStoreUpsertOptions.PersistSourceText)} setting is false.", nameof(document)); @@ -138,16 +143,9 @@ public async Task UpsertDocumentsAsync(IEnumerable documents, T /// public async Task> SearchAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) { - var vectorStoreRecordCollection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); - var textSearch = new VectorStoreTextSearch>( - vectorStoreRecordCollection, - this._textEmbeddingGenerationService, - r => r is TextRagStorageDocument doc ? doc.Text ?? string.Empty : string.Empty, - r => r is TextRagStorageDocument doc ? new TextSearchResult(doc.Text ?? string.Empty) { Name = doc.SourceName, Link = doc.SourceLink } : new TextSearchResult(string.Empty)); - var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); - return new(searchResult.Results.SelectAsync(x => x.Record.Text ?? string.Empty, cancellationToken)); + return new(searchResult.Select(x => x.Text ?? string.Empty).ToAsyncEnumerable()); } /// @@ -155,36 +153,37 @@ public async Task> GetTextSearchResultsAsy { var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); - var results = searchResult.Results.SelectAsync(x => new TextSearchResult(x.Record.Text ?? string.Empty) { Name = x.Record.SourceName, Link = x.Record.SourceLink }, cancellationToken); - return new(searchResult.Results.SelectAsync(x => - new TextSearchResult(x.Record.Text ?? string.Empty) + var results = searchResult.Select(x => new TextSearchResult(x.Text ?? string.Empty) { Name = x.SourceName, Link = x.SourceLink }); + return new(searchResult.Select(x => + new TextSearchResult(x.Text ?? string.Empty) { - Name = x.Record.SourceName, - Link = x.Record.SourceLink - }, cancellationToken)); + Name = x.SourceName, + Link = x.SourceLink + }).ToAsyncEnumerable()); } /// public async Task> GetSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) { var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); - return new(searchResult.Results.SelectAsync(x => (object)x, cancellationToken)); + return new(searchResult.Select(x => (object)x).ToAsyncEnumerable()); } /// - /// Internal search implementation. + /// Internal search implementation with hydration of id / link only storage. /// /// The text query to find similar documents to. /// Search options. /// The to monitor for cancellation requests. The default is . /// The search results. - private async Task>> SearchInternalAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + private async Task>> SearchInternalAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) { var vectorStoreRecordCollection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); // Optional filter to limit the search to a specific namespace. Expression, bool>>? filter = string.IsNullOrWhiteSpace(this._options.SearchNamespace) ? null : x => x.Namespaces.Contains(this._options.SearchNamespace); + // Generate the vector for the query and search. var vector = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(query, cancellationToken: cancellationToken).ConfigureAwait(false); var searchResult = await vectorStoreRecordCollection.VectorizedSearchAsync( vector, @@ -195,7 +194,50 @@ private async Task>> SearchInte }, cancellationToken: cancellationToken).ConfigureAwait(false); - return searchResult; + // Retrieve the documents from the search results. + var retrievedDocs = await searchResult + .Results + .SelectAsync(x => x.Record, cancellationToken) + .ToListAsync(cancellationToken) + .ConfigureAwait(false); + + // Find any source ids and links for which the text needs to be retrieved. + var sourceIdsToRetrieve = retrievedDocs + .Where(x => string.IsNullOrWhiteSpace(x.Text)) + .Select(x => (x.SourceId, x.SourceLink)) + .ToList(); + + if (sourceIdsToRetrieve.Count > 0) + { + if (this._options.SourceRetrievalCallback is null) + { + throw new InvalidOperationException($"The {nameof(TextRagStoreOptions.SourceRetrievalCallback)} option must be set if retrieving documents without stored text."); + } + + var retrievedText = await this._options.SourceRetrievalCallback(sourceIdsToRetrieve).ConfigureAwait(false); + + if (retrievedText is null) + { + throw new InvalidOperationException($"The {nameof(TextRagStoreOptions.SourceRetrievalCallback)} must return a non-null value."); + } + + // Update the retrieved documents with the retrieved text. + retrievedDocs = retrievedDocs.GroupJoin( + retrievedText, + retrievedDocs => (retrievedDocs.SourceId, retrievedDocs.SourceLink), + retrievedText => (retrievedText.sourceId, retrievedText.sourceLink), + (retrievedDoc, retrievedText) => (retrievedDoc, retrievedText)) + .SelectMany( + joinedSet => joinedSet.retrievedText.DefaultIfEmpty(), + (combined, retrievedText) => + { + combined.retrievedDoc.Text = retrievedText.text ?? combined.retrievedDoc.Text; + return combined.retrievedDoc; + }) + .ToList(); + } + + return retrievedDocs; } /// diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs index d47e7d90c15d..6c2ecf3254a6 100644 --- a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs @@ -1,5 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; +using System.Threading.Tasks; + namespace Microsoft.SemanticKernel.Memory; /// @@ -30,4 +33,18 @@ public sealed class TextRagStoreOptions /// Defaults to false if not set. /// public bool? UseSourceIdAsPrimaryKey { get; init; } + + /// + /// Gets or sets an optional callback to load the source text from the source id or source link + /// if the source text is not persisted in the database. + /// + public SourceRetriever? SourceRetrievalCallback { get; init; } + + /// + /// Delegate type for loading the source text from the source id or source link + /// if the source text is not persisted in the database. + /// + /// The ids and links of the text to load. + /// The source text with the source id or source link. + public delegate Task> SourceRetriever(List<(string? sourceId, string? sourceLink)> sourceIds); } diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs index 2ba1f9dd9465..8d91b09b3deb 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs @@ -174,7 +174,7 @@ public async Task SearchAsyncReturnsSearchResults() .Setup(r => r.VectorizedSearchAsync(It.IsAny>(), It.IsAny.TextRagStorageDocument>>(), It.IsAny())) .ReturnsAsync(new VectorSearchResults.TextRagStorageDocument>(mockResults.ToAsyncEnumerable())); - using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128, null); + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128); // Act var actualResults = await store.SearchAsync("query"); @@ -184,4 +184,47 @@ public async Task SearchAsyncReturnsSearchResults() Assert.Single(actualResultsList); Assert.Equal("Sample text", actualResultsList[0]); } + + [Fact] + public async Task SearchAsyncWithHydrationCallsCallbackAndReturnsSearchResults() + { + // Arrange + var mockResults = new List.TextRagStorageDocument>> + { + new(new TextRagStore.TextRagStorageDocument { SourceId = "sid1", SourceLink = "sl1", Text = "Sample text 1" }, 0.9f), + new(new TextRagStore.TextRagStorageDocument { SourceId = "sid2", SourceLink = "sl2" }, 0.9f), + new(new TextRagStore.TextRagStorageDocument { SourceId = "sid3", SourceLink = "sl3", Text = "Sample text 3" }, 0.9f), + }; + + this._recordCollectionMock + .Setup(r => r.VectorizedSearchAsync(It.IsAny>(), It.IsAny.TextRagStorageDocument>>(), It.IsAny())) + .ReturnsAsync(new VectorSearchResults.TextRagStorageDocument>(mockResults.ToAsyncEnumerable())); + + using var store = new TextRagStore( + this._vectorStoreMock.Object, + this._embeddingServiceMock.Object, + "testCollection", + 128, + new() + { + SourceRetrievalCallback = sourceIds => + { + Assert.Single(sourceIds); + Assert.Equal("sid2", sourceIds[0].sourceId); + Assert.Equal("sl2", sourceIds[0].sourceLink); + + return Task.FromResult>([("sid2", "sl2", "Sample text 2")]); + } + }); + + // Act + var actualResults = await store.SearchAsync("query"); + + // Assert + var actualResultsList = await actualResults.Results.ToListAsync(); + Assert.Equal(3, actualResultsList.Count); + Assert.Equal("Sample text 1", actualResultsList[0]); + Assert.Equal("Sample text 2", actualResultsList[1]); + Assert.Equal("Sample text 3", actualResultsList[2]); + } }