diff --git a/docs/decisions/00NN-agents-with-memory.md b/docs/decisions/00NN-agents-with-memory.md new file mode 100644 index 000000000000..25adcc43fcd1 --- /dev/null +++ b/docs/decisions/00NN-agents-with-memory.md @@ -0,0 +1,206 @@ +--- +# These are optional elements. Feel free to remove any of them. +status: accepted +contact: westey-m +date: 2025-04-17 +deciders: westey-m, markwallace-microsoft, alliscode, TaoChenOSU, moonbox3, crickman +consulted: westey-m, markwallace-microsoft, alliscode, TaoChenOSU, moonbox3, crickman +informed: westey-m, markwallace-microsoft, alliscode, TaoChenOSU, moonbox3, crickman +--- + +# Agents with Memory + +## What do we mean by Memory? + +By memory we mean the capability to remember information and skills that are learned during +a conversation and re-use those later in the same conversation or later in a subsequent conversation. + +## Context and Problem Statement + +Today we support multiple agent types with different characteristics: + +1. In process vs remote. +1. Remote agents that store and maintain conversation state in the service vs those that require the caller to provide conversation state on each invocation. + +We need to support advanced memory capabilities across this range of agent types. + +### Memory Scope + +Another aspect of memory that is important to consider is the scope of different memory types. +Most agent implementations have instructions and skills but the agent is not tied to a single conversation. +On each invocation of the agent, the agent is told which conversation to participate in, during that invocation. + +Memories about a user or about a conversation with a user is therefore extracted from one of these conversation and recalled +during the same or another conversation with the same user. +These memories will typically contain information that the user would not like to share with other users of the system. + +Other types of memories also exist which are not tied to a specific user or conversation. +E.g. an Agent may learn how to do something and be able to do that in many conversations with different users. +With these type of memories there is of cousrse risk in leaking personal information between different users which is important to guard against. + +### Packaging memory capabilities + +All of the above memory types can be supported for any agent by attaching software components to conversation threads. +This is achieved via a simple mechanism of: + +1. Inspecting and using messages as they are passed to and from the agent. +1. Passing additional context to the agent per invocation. + +With our current `AgentThread` implementation, when an agent is invoked, all input and output messages are already passed to the `AgentThread` +and can be made available to any components attached to the `AgentThread`. +Where agents are remote/external and manage conversation state in the service, passing the messages to the `AgentThread` may not have any +affect on the thread in the service. This is OK, since the service will have already updated the thread during the remote invocation. +It does however, still allow us to subscribe to messages in any attached components. + +For the second requirement of getting additional context per invocation, the agent may ask the thread passed to it, to in turn ask +each of the components attached to it, to provide context to pass to the Agent. +This enables the component to provide memories that it contains to the Agent as needed. + +Different memory capabilities can be built using separate components. Each component would have the following characteristics: + +1. May store some context that can be provided to the agent per invocation. +1. May inspect messages from the conversation to learn from the conversation and build its context. +1. May register plugins to allow the agent to directly store, retrieve, update or clear memories. + +### Suspend / Resume + +Building a service to host an agent comes with challenges. +It's hard to build a stateful service, but service consumers expect an experience that looks stateful from the outside. +E.g. on each invocation, the user expects that the service can continue a conversation they are having. + +This means that where the the service is exposing a local agent with local conversation state management (e.g. via `ChatHistory`) +that conversation state needs to be loaded and persisted for each invocation of the service. + +It also means that any memory components that may have some in-memory state will need to be loaded and persisted too. + +For cases like this, the `OnSuspend` and `OnResume` methods allow notification of the components that they need to save or reload their state. +It is up to each of these components to decide how and where to save state to or load state from. + +## Proposed interface for Memory Components + +The types of events that Memory Components require are not unique to memory, and can be used to package up other capabilities too. +The suggestion is therefore to create a more generally named type that can be used for other scenarios as well and can even +be used for non-agent scenarios too. + +This type should live in the `Microsoft.SemanticKernel.Abstractions` nuget, since these components can be used by systems other than just agents. + +```csharp +namespace Microsoft.SemanticKernel; + +public abstract class ConversationStateExtension +{ + public virtual IReadOnlyCollection AIFunctions => Array.Empty(); + + public virtual Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default); + public virtual Task OnThreadDeleteAsync(string? threadId, CancellationToken cancellationToken = default); + + // OnThreadCheckpointAsync not included in initial release, maybe in future. + public virtual Task OnThreadCheckpointAsync(string? threadId, CancellationToken cancellationToken = default); + + public virtual Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default); + public abstract Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default); + + public virtual Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default); + public virtual Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default); +} +``` + +> TODO: Decide about the correct namespace for `ConversationStateExtension` + +## Managing multiple components + +To manage multiple components I propose that we have a `ConversationStateExtensionsManager`. +This class allows registering components and delegating new message notifications, ai invocation calls, etc. to the contained components. + +## Integrating with agents + +I propose to add a `ConversationStateExtensionsManager` to the `AgentThread` class, allowing us to attach components to any `AgentThread`. + +When an `Agent` is invoked, we will call `OnModelInvokeAsync` on each component via the `ConversationStateExtensionsManager` to get +a combined set of context to pass to the agent for this invocation. This will be internal to the `Agent` class and transparent to the user. + +```csharp +var additionalInstructions = await currentAgentThread.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); +``` + +## Usage examples + +### Multiple threads using the same memory component + +```csharp +// Create a vector store for storing memories. +var vectorStore = new InMemoryVectorStore(); +// Create a memory store that is tired to a "Memories" collection in the vector store and stores memories under the "user/12345" namespace. +using var textMemoryStore = new VectorDataTextMemoryStore(vectorStore, textEmbeddingService, "Memories", "user/12345", 1536); + +// Create a memory component to will pull user facts from the conversation, store them in the vector store +// and pass them to the agent as additional instructions. +var userFacts = new UserFactsMemoryComponent(this.Fixture.Agent.Kernel, textMemoryStore); + +// Create a thread and attach a Memory Component. +var agentThread1 = new ChatHistoryAgentThread(); +agentThread1.ThreadExtensionsManager.Add(userFacts); +var asyncResults1 = agent.InvokeAsync("Hello, my name is Caoimhe.", agentThread1); + +// Create a second thread and attach a Memory Component. +var agentThread2 = new ChatHistoryAgentThread(); +agentThread2.ThreadExtensionsManager.Add(userFacts); +var asyncResults2 = agent.InvokeAsync("What is my name?.", agentThread2); +// Expected response contains Caoimhe. +``` + +### Using a RAG component + +```csharp +// Create Vector Store and Rag Store/Component +var vectorStore = new InMemoryVectorStore(); +using var ragStore = new TextRagStore(vectorStore, textEmbeddingService, "Memories", 1536, "group/g2"); +var ragComponent = new TextRagComponent(ragStore, new TextRagComponentOptions()); + +// Upsert docs into vector store. +await ragStore.UpsertDocumentsAsync( +[ + new TextRagDocument("The financial results of Contoso Corp for 2023 is as follows:\nIncome EUR 174 000 000\nExpenses EUR 152 000 000") + { + SourceName = "Contoso 2023 Financial Report", + SourceReference = "https://www.consoso.com/reports/2023.pdf", + Namespaces = ["group/g2"] + } +]); + +// Create a new agent thread and register the Rag component +var agentThread = new ChatHistoryAgentThread(); +agentThread.ThreadExtensionsManager.RegisterThreadExtension(ragComponent); + +// Inovke the agent. +var asyncResults1 = agent.InvokeAsync("What was the income of Contoso for 2023", agentThread); +// Expected response contains the 174M income from the document. +``` + +## Decisions to make + +### Extension base class name + +1. ConversationStateExtension + 1. Long +1. MemoryComponent + 1. Too specific + +Chose ConversationStateExtension. + +### Location for abstractions + +1. Microsoft.SemanticKernel. +1. Microsoft.SemanticKernel.Memory. +1. Microsoft.SemanticKernel.Memory. (in separate nuget) + +Chose Microsoft.SemanticKernel.. + +### Location for memory components + +1. A nuget for each component +1. Microsoft.SemanticKernel.Core nuget +1. Microsoft.SemanticKernel.Memory nuget +1. Microsoft.SemanticKernel.ConversationStateExtensions nuget + +Chose Microsoft.SemanticKernel.Core nuget \ No newline at end of file diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index acf61a114486..a2e1b9925d06 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -524,6 +524,11 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MCPClient", "samples\Demos\ {12C7E0C7-A7DF-3BC3-0D4B-1A706BCE6981} = {12C7E0C7-A7DF-3BC3-0D4B-1A706BCE6981} EndProjectSection EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "memory", "memory", "{4B850B93-46D6-4F25-9DB1-90D1E6E4AB70}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Memory.Abstractions", "src\Memory\Memory.Abstractions\Memory.Abstractions.csproj", "{F5124057-1DA1-4799-9357-D9A635047678}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Memory", "src\Memory\Memory\Memory.csproj", "{A0538079-AB6F-4C7D-9138-A15258583F80}" Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ProcessWithCloudEvents", "ProcessWithCloudEvents", "{7C092DD9-9985-4D18-A817-15317D984149}" ProjectSection(SolutionItems) = preProject samples\Demos\ProcessWithCloudEvents\README.md = samples\Demos\ProcessWithCloudEvents\README.md @@ -1460,6 +1465,18 @@ Global {B06770D5-2F3E-4271-9F6B-3AA9E716176F}.Publish|Any CPU.Build.0 = Release|Any CPU {B06770D5-2F3E-4271-9F6B-3AA9E716176F}.Release|Any CPU.ActiveCfg = Release|Any CPU {B06770D5-2F3E-4271-9F6B-3AA9E716176F}.Release|Any CPU.Build.0 = Release|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Publish|Any CPU.Build.0 = Release|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F5124057-1DA1-4799-9357-D9A635047678}.Release|Any CPU.Build.0 = Release|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Publish|Any CPU.Build.0 = Release|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A0538079-AB6F-4C7D-9138-A15258583F80}.Release|Any CPU.Build.0 = Release|Any CPU {31F6608A-FD36-F529-A5FC-C954A0B5E29E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {31F6608A-FD36-F529-A5FC-C954A0B5E29E}.Debug|Any CPU.Build.0 = Debug|Any CPU {31F6608A-FD36-F529-A5FC-C954A0B5E29E}.Publish|Any CPU.ActiveCfg = Release|Any CPU @@ -1703,6 +1720,9 @@ Global {879545ED-D429-49B1-96F1-2EC55FFED31D} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {12C7E0C7-A7DF-3BC3-0D4B-1A706BCE6981} = {879545ED-D429-49B1-96F1-2EC55FFED31D} {B06770D5-2F3E-4271-9F6B-3AA9E716176F} = {879545ED-D429-49B1-96F1-2EC55FFED31D} + {4B850B93-46D6-4F25-9DB1-90D1E6E4AB70} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} + {F5124057-1DA1-4799-9357-D9A635047678} = {4B850B93-46D6-4F25-9DB1-90D1E6E4AB70} + {A0538079-AB6F-4C7D-9138-A15258583F80} = {4B850B93-46D6-4F25-9DB1-90D1E6E4AB70} {7C092DD9-9985-4D18-A817-15317D984149} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {31F6608A-FD36-F529-A5FC-C954A0B5E29E} = {7C092DD9-9985-4D18-A817-15317D984149} {08D84994-794A-760F-95FD-4EFA8998A16D} = {7C092DD9-9985-4D18-A817-15317D984149} diff --git a/dotnet/docs/EXPERIMENTS.md b/dotnet/docs/EXPERIMENTS.md index 99fd9b56afb4..8dbfa4d746b7 100644 --- a/dotnet/docs/EXPERIMENTS.md +++ b/dotnet/docs/EXPERIMENTS.md @@ -25,6 +25,7 @@ You can use the following diagnostic IDs to ignore warnings or errors for a part | SKEXP0100 | Advanced Semantic Kernel features | | SKEXP0110 | Semantic Kernel Agents | | SKEXP0120 | Native-AOT | +| SKEXP0130 | Conversation State | ## Experimental Features Tracking diff --git a/dotnet/samples/Demos/ModelContextProtocolPlugin/ModelContextProtocolPlugin.csproj b/dotnet/samples/Demos/ModelContextProtocolPlugin/ModelContextProtocolPlugin.csproj index 8b872e1db766..9d65e0b8a19d 100644 --- a/dotnet/samples/Demos/ModelContextProtocolPlugin/ModelContextProtocolPlugin.csproj +++ b/dotnet/samples/Demos/ModelContextProtocolPlugin/ModelContextProtocolPlugin.csproj @@ -16,6 +16,7 @@ + diff --git a/dotnet/src/Agents/Abstractions/AgentThread.cs b/dotnet/src/Agents/Abstractions/AgentThread.cs index 74477d556340..01635ae0728a 100644 --- a/dotnet/src/Agents/Abstractions/AgentThread.cs +++ b/dotnet/src/Agents/Abstractions/AgentThread.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; @@ -26,6 +27,52 @@ public abstract class AgentThread /// public virtual bool IsDeleted { get; protected set; } = false; + /// + /// Gets or sets the container for conversation state part components that manages their lifecycle and interactions. + /// + [Experimental("SKEXP0110")] + public virtual ConversationStatePartsManager StateParts { get; init; } = new ConversationStatePartsManager(); + + /// + /// Called when the current conversion is temporarily suspended and any state should be saved. + /// + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the end of each service call. + /// In a client application, this might be when the user closes the chat window or the application. + /// + [Experimental("SKEXP0110")] + public virtual Task OnSuspendAsync(CancellationToken cancellationToken = default) + { + return this.StateParts.OnSuspendAsync(this.Id, cancellationToken); + } + + /// + /// Called when the current conversion is resumed and any state should be restored. + /// + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the start of each service call where a previous conversation is being continued. + /// In a client application, this might be when the user re-opens the chat window to resume a conversation after having previously closed it. + /// + [Experimental("SKEXP0110")] + public virtual Task OnResumeAsync(CancellationToken cancellationToken = default) + { + if (this.IsDeleted) + { + throw new InvalidOperationException("This thread has been deleted and cannot be used anymore."); + } + + if (this.Id is null) + { + throw new InvalidOperationException("This thread cannot be resumed, since it has not been created."); + } + + return this.StateParts.OnResumeAsync(this.Id, cancellationToken); + } + /// /// Creates the thread and returns the thread id. /// @@ -45,6 +92,10 @@ protected internal virtual async Task CreateAsync(CancellationToken cancellation } this.Id = await this.CreateInternalAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + await this.StateParts.OnThreadCreatedAsync(this.Id!, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. } /// @@ -65,6 +116,10 @@ public virtual async Task DeleteAsync(CancellationToken cancellationToken = defa throw new InvalidOperationException("This thread cannot be deleted, since it has not been created."); } +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + await this.StateParts.OnThreadDeleteAsync(this.Id!, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + await this.DeleteInternalAsync(cancellationToken).ConfigureAwait(false); this.IsDeleted = true; @@ -92,6 +147,10 @@ internal virtual async Task OnNewMessageAsync(ChatMessageContent newMessage, Can await this.CreateAsync(cancellationToken).ConfigureAwait(false); } +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + await this.StateParts.OnNewMessageAsync(this.Id, newMessage, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + await this.OnNewMessageInternalAsync(newMessage, cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Agents/AzureAI/Agents.AzureAI.csproj b/dotnet/src/Agents/AzureAI/Agents.AzureAI.csproj index 2b0694f69986..a33dcd16812a 100644 --- a/dotnet/src/Agents/AzureAI/Agents.AzureAI.csproj +++ b/dotnet/src/Agents/AzureAI/Agents.AzureAI.csproj @@ -34,6 +34,7 @@ + diff --git a/dotnet/src/Agents/AzureAI/AzureAIAgent.cs b/dotnet/src/Agents/AzureAI/AzureAIAgent.cs index 440dd0c74b65..84b362ea00b2 100644 --- a/dotnet/src/Agents/AzureAI/AzureAIAgent.cs +++ b/dotnet/src/Agents/AzureAI/AzureAIAgent.cs @@ -186,6 +186,19 @@ public async IAsyncEnumerable> InvokeAsync () => new AzureAIAgentThread(this.Client), cancellationToken).ConfigureAwait(false); + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await azureAIAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + azureAIAgentThread.StateParts.RegisterPlugins(kernel); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + + var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext); + var extensionsContextOptions = options is null ? + new AzureAIAgentInvokeOptions() { AdditionalInstructions = mergedAdditionalInstructions } : + new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = mergedAdditionalInstructions }; + var invokeResults = ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(this.Id, this.GetDisplayName(), this.Description), () => InternalInvokeAsync(), @@ -197,9 +210,9 @@ async IAsyncEnumerable InternalInvokeAsync() this, this.Client, azureAIAgentThread.Id!, - options?.ToAzureAIInvocationOptions(), + extensionsContextOptions?.ToAzureAIInvocationOptions(), this.Logger, - options?.Kernel ?? this.Kernel, + kernel, options?.KernelArguments, cancellationToken).ConfigureAwait(false)) { @@ -303,14 +316,27 @@ public async IAsyncEnumerable> In () => new AzureAIAgentThread(this.Client), cancellationToken).ConfigureAwait(false); + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await azureAIAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + azureAIAgentThread.StateParts.RegisterPlugins(kernel); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + + var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext); + var extensionsContextOptions = options is null ? + new AzureAIAgentInvokeOptions() { AdditionalInstructions = mergedAdditionalInstructions } : + new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = mergedAdditionalInstructions }; + #pragma warning disable CS0618 // Type or member is obsolete // Invoke the Agent with the thread that we already added our message to. var newMessagesReceiver = new ChatHistory(); var invokeResults = this.InvokeStreamingAsync( azureAIAgentThread.Id!, - options?.ToAzureAIInvocationOptions(), + extensionsContextOptions.ToAzureAIInvocationOptions(), options?.KernelArguments, - options?.Kernel ?? this.Kernel, + kernel, newMessagesReceiver, cancellationToken); #pragma warning restore CS0618 // Type or member is obsolete @@ -437,4 +463,19 @@ protected override async Task RestoreChannelAsync(string channelSt return new AzureAIChannel(this.Client, thread.Id); } + + private static string MergeAdditionalInstructions(string? optionsAdditionalInstructions, string extensionsContext) => + (optionsAdditionalInstructions, extensionsContext) switch + { + (string ai, string ec) when !string.IsNullOrWhiteSpace(ai) && !string.IsNullOrWhiteSpace(ec) => string.Concat( + ai, + Environment.NewLine, + Environment.NewLine, + ec), + (string ai, string ec) when string.IsNullOrWhiteSpace(ai) => ec, + (string ai, string ec) when string.IsNullOrWhiteSpace(ec) => ai, + (null, string ec) => ec, + (string ai, null) => ai, + _ => string.Empty + }; } diff --git a/dotnet/src/Agents/Bedrock/Agents.Bedrock.csproj b/dotnet/src/Agents/Bedrock/Agents.Bedrock.csproj index e17d43f63fcc..c3aa00b13330 100644 --- a/dotnet/src/Agents/Bedrock/Agents.Bedrock.csproj +++ b/dotnet/src/Agents/Bedrock/Agents.Bedrock.csproj @@ -33,6 +33,7 @@ + diff --git a/dotnet/src/Agents/Bedrock/BedrockAgent.cs b/dotnet/src/Agents/Bedrock/BedrockAgent.cs index 22d4187b0bd2..ae54bec77288 100644 --- a/dotnet/src/Agents/Bedrock/BedrockAgent.cs +++ b/dotnet/src/Agents/Bedrock/BedrockAgent.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using System.Threading; using System.Threading.Tasks; using Amazon.BedrockAgent; @@ -117,11 +118,18 @@ public override async IAsyncEnumerable> In () => new BedrockAgentThread(this.RuntimeClient), cancellationToken).ConfigureAwait(false); + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await bedrockThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + // Ensure that the last message provided is a user message string? message = this.ExtractUserMessage(messages.Last()); - // Build session state with conversation history if needed + // Build session state with conversation history and override instructions if needed SessionState sessionState = this.ExtractSessionState(messages); + var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext); + sessionState.PromptSessionAttributes = new() { ["AdditionalInstructions"] = mergedAdditionalInstructions }; // Configure the agent request with the provided options var invokeAgentRequest = this.ConfigureAgentRequest(options, () => @@ -346,11 +354,18 @@ public override async IAsyncEnumerable new BedrockAgentThread(this.RuntimeClient), cancellationToken).ConfigureAwait(false); + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await bedrockThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + // Ensure that the last message provided is a user message string? message = this.ExtractUserMessage(messages.Last()); - // Build session state with conversation history if needed + // Build session state with conversation history and override instructions if needed SessionState sessionState = this.ExtractSessionState(messages); + var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext); + sessionState.PromptSessionAttributes = new() { ["AdditionalInstructions"] = mergedAdditionalInstructions }; // Configure the agent request with the provided options var invokeAgentRequest = this.ConfigureAgentRequest(options, () => @@ -639,20 +654,35 @@ private IAsyncEnumerable InvokeStreamingInternalAsy async IAsyncEnumerable InvokeInternal() { + var combinedResponseMessageBuilder = new StringBuilder(); + StreamingChatMessageContent? lastMessage = null; + // The Bedrock agent service has the same API for both streaming and non-streaming responses. // We are invoking the same method as the non-streaming response with the streaming configuration set, // and converting the chat message content to streaming chat message content. await foreach (var chatMessageContent in this.InternalInvokeAsync(invokeAgentRequest, arguments, cancellationToken).ConfigureAwait(false)) { - await this.NotifyThreadOfNewMessage(thread, chatMessageContent, cancellationToken).ConfigureAwait(false); - yield return new StreamingChatMessageContent(chatMessageContent.Role, chatMessageContent.Content) + lastMessage = new StreamingChatMessageContent(chatMessageContent.Role, chatMessageContent.Content) { AuthorName = chatMessageContent.AuthorName, ModelId = chatMessageContent.ModelId, InnerContent = chatMessageContent.InnerContent, Metadata = chatMessageContent.Metadata, }; + yield return lastMessage; + + combinedResponseMessageBuilder.Append(chatMessageContent.Content); } + + // Build a combined message containing the text from all response parts + // to send to the thread. + var combinedMessage = new ChatMessageContent(AuthorRole.Assistant, combinedResponseMessageBuilder.ToString()) + { + AuthorName = lastMessage?.AuthorName, + ModelId = lastMessage?.ModelId, + Metadata = lastMessage?.Metadata, + }; + await this.NotifyThreadOfNewMessage(thread, combinedMessage, cancellationToken).ConfigureAwait(false); } } @@ -726,4 +756,19 @@ private Amazon.BedrockAgentRuntime.ConversationRole MapBedrockAgentUser(AuthorRo } #endregion + + private static string MergeAdditionalInstructions(string? optionsAdditionalInstructions, string extensionsContext) => + (optionsAdditionalInstructions, extensionsContext) switch + { + (string ai, string ec) when !string.IsNullOrWhiteSpace(ai) && !string.IsNullOrWhiteSpace(ec) => string.Concat( + ai, + Environment.NewLine, + Environment.NewLine, + ec), + (string ai, string ec) when string.IsNullOrWhiteSpace(ai) => ec, + (string ai, string ec) when string.IsNullOrWhiteSpace(ec) => ai, + (null, string ec) => ec, + (string ai, null) => ai, + _ => string.Empty + }; } diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 0fbcc3a8a198..af9a1f20a286 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -74,6 +74,14 @@ public override async IAsyncEnumerable> In () => new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await chatHistoryAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + chatHistoryAgentThread.StateParts.RegisterPlugins(kernel); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + // Invoke Chat Completion with the updated chat history. var chatHistory = new ChatHistory(); await foreach (var existingMessage in chatHistoryAgentThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false)) @@ -92,8 +100,10 @@ public override async IAsyncEnumerable> In } }, options?.KernelArguments, - options?.Kernel, - options?.AdditionalInstructions, + kernel, + options?.AdditionalInstructions == null ? + extensionsContext : + string.Concat(options.AdditionalInstructions, Environment.NewLine, Environment.NewLine, extensionsContext), cancellationToken); // Notify the thread of new messages and return them to the caller. @@ -157,6 +167,14 @@ public override async IAsyncEnumerable new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await chatHistoryAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + chatHistoryAgentThread.StateParts.RegisterPlugins(kernel); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + // Invoke Chat Completion with the updated chat history. var chatHistory = new ChatHistory(); await foreach (var existingMessage in chatHistoryAgentThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false)) @@ -176,8 +194,10 @@ public override async IAsyncEnumerable internal static class AssistantRunOptionsFactory { - public static RunCreationOptions GenerateOptions(RunCreationOptions? defaultOptions, string? agentInstructions, RunCreationOptions? invocationOptions) + public static RunCreationOptions GenerateOptions(RunCreationOptions? defaultOptions, string? agentInstructions, RunCreationOptions? invocationOptions, string? threadExtensionsContext) { + var additionalInstructions = string.Concat( + (invocationOptions?.AdditionalInstructions ?? defaultOptions?.AdditionalInstructions), + string.IsNullOrWhiteSpace(threadExtensionsContext) ? string.Empty : string.Concat(Environment.NewLine, Environment.NewLine, threadExtensionsContext)); + RunCreationOptions runOptions = new() { - AdditionalInstructions = invocationOptions?.AdditionalInstructions ?? defaultOptions?.AdditionalInstructions, + AdditionalInstructions = additionalInstructions, InstructionsOverride = invocationOptions?.InstructionsOverride ?? agentInstructions, MaxOutputTokenCount = invocationOptions?.MaxOutputTokenCount ?? defaultOptions?.MaxOutputTokenCount, MaxInputTokenCount = invocationOptions?.MaxInputTokenCount ?? defaultOptions?.MaxInputTokenCount, diff --git a/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs b/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs index e1cb991b643e..c4a603b38922 100644 --- a/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs +++ b/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs @@ -105,6 +105,7 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist /// The assistant client /// The thread identifier /// Options to utilize for the invocation + /// Additional context from thead extensions to pass to the invoke method. /// The logger to utilize (might be agent or channel scoped) /// The plugins and other state. /// Optional arguments to pass to the agents's invocation, including any . @@ -115,6 +116,7 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist AssistantClient client, string threadId, RunCreationOptions? invocationOptions, + string? threadExtensionsContext, ILogger logger, Kernel kernel, KernelArguments? arguments, @@ -133,7 +135,7 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist string? instructions = await agent.GetInstructionsAsync(kernel, arguments, cancellationToken).ConfigureAwait(false); - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(agent.RunOptions, instructions, invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(agent.RunOptions, instructions, invocationOptions, threadExtensionsContext); options.ToolsOverride.AddRange(tools); @@ -335,6 +337,7 @@ async Task PollRunStatusAsync() /// The thread identifier /// The receiver for the completed messages generated /// Options to utilize for the invocation + /// Additional context from thead extensions to pass to the invoke method. /// The logger to utilize (might be agent or channel scoped) /// The plugins and other state. /// Optional arguments to pass to the agents's invocation, including any . @@ -350,6 +353,7 @@ public static async IAsyncEnumerable InvokeStreamin string threadId, IList? messages, RunCreationOptions? invocationOptions, + string? threadExtensionsContext, ILogger logger, Kernel kernel, KernelArguments? arguments, @@ -361,7 +365,7 @@ public static async IAsyncEnumerable InvokeStreamin string? instructions = await agent.GetInstructionsAsync(kernel, arguments, cancellationToken).ConfigureAwait(false); - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(agent.RunOptions, instructions, invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(agent.RunOptions, instructions, invocationOptions, threadExtensionsContext); options.ToolsOverride.AddRange(tools); diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs index f14503c484c9..aa0f46c0629b 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs @@ -411,6 +411,14 @@ public async IAsyncEnumerable> InvokeAsync AdditionalInstructions = options?.AdditionalInstructions, }); + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await openAIAssistantAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + openAIAssistantAgentThread.StateParts.RegisterPlugins(kernel); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var invokeResults = ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(this.Id, this.GetDisplayName(), this.Description), () => InternalInvokeAsync(), @@ -423,8 +431,9 @@ async IAsyncEnumerable InternalInvokeAsync() this.Client, openAIAssistantAgentThread.Id!, internalOptions, + extensionsContext, this.Logger, - options?.Kernel ?? this.Kernel, + kernel, options?.KernelArguments, cancellationToken).ConfigureAwait(false)) { @@ -496,7 +505,7 @@ public IAsyncEnumerable InvokeAsync( async IAsyncEnumerable InternalInvokeAsync() { kernel ??= this.Kernel; - await foreach ((bool isVisible, ChatMessageContent message) in AssistantThreadActions.InvokeAsync(this, this.Client, threadId, options, this.Logger, kernel, arguments, cancellationToken).ConfigureAwait(false)) + await foreach ((bool isVisible, ChatMessageContent message) in AssistantThreadActions.InvokeAsync(this, this.Client, threadId, options, null, this.Logger, kernel, arguments, cancellationToken).ConfigureAwait(false)) { if (isVisible) { @@ -547,6 +556,14 @@ public async IAsyncEnumerable> In () => new OpenAIAssistantAgentThread(this.Client), cancellationToken).ConfigureAwait(false); + var kernel = (options?.Kernel ?? this.Kernel).Clone(); + + // Get the conversation state extensions context contributions and register plugins from the extensions. +#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + var extensionsContext = await openAIAssistantAgentThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false); + openAIAssistantAgentThread.StateParts.RegisterPlugins(kernel); +#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + // Create options that use the RunCreationOptions from the options param if provided or // falls back to creating a new RunCreationOptions if additional instructions is provided // separately. @@ -555,17 +572,28 @@ public async IAsyncEnumerable> In AdditionalInstructions = options?.AdditionalInstructions, }); -#pragma warning disable CS0618 // Type or member is obsolete - // Invoke the Agent with the thread that we already added our message to. +#pragma warning disable SKEXP0001 // ModelDiagnostics is marked experimental. var newMessagesReceiver = new ChatHistory(); - var invokeResults = this.InvokeStreamingAsync( - openAIAssistantAgentThread.Id!, - internalOptions, - options?.KernelArguments, - options?.Kernel ?? this.Kernel, - newMessagesReceiver, + var invokeResults = ActivityExtensions.RunWithActivityAsync( + () => ModelDiagnostics.StartAgentInvocationActivity(this.Id, this.GetDisplayName(), this.Description), + () => InternalInvokeStreamingAsync(), cancellationToken); -#pragma warning restore CS0618 // Type or member is obsolete +#pragma warning restore SKEXP0001 // ModelDiagnostics is marked experimental. + + IAsyncEnumerable InternalInvokeStreamingAsync() + { + return AssistantThreadActions.InvokeStreamingAsync( + this, + this.Client, + openAIAssistantAgentThread.Id!, + newMessagesReceiver, + internalOptions, + extensionsContext, + this.Logger, + kernel, + options?.KernelArguments, + cancellationToken); + } // Return the chunks to the caller. await foreach (var result in invokeResults.ConfigureAwait(false)) @@ -638,7 +666,7 @@ public IAsyncEnumerable InvokeStreamingAsync( IAsyncEnumerable InternalInvokeStreamingAsync() { kernel ??= this.Kernel; - return AssistantThreadActions.InvokeStreamingAsync(this, this.Client, threadId, messages, options, this.Logger, kernel, arguments, cancellationToken); + return AssistantThreadActions.InvokeStreamingAsync(this, this.Client, threadId, messages, options, null, this.Logger, kernel, arguments, cancellationToken); } } diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs index 39534df768da..ad2331a2fb6e 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs @@ -36,7 +36,7 @@ protected override async Task ReceiveAsync(IEnumerable histo { return ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(agent.Id, agent.GetDisplayName(), agent.Description), - () => AssistantThreadActions.InvokeAsync(agent, this._client, this._threadId, invocationOptions: null, this.Logger, agent.Kernel, agent.Arguments, cancellationToken), + () => AssistantThreadActions.InvokeAsync(agent, this._client, this._threadId, invocationOptions: null, threadExtensionsContext: null, this.Logger, agent.Kernel, agent.Arguments, cancellationToken), cancellationToken); } @@ -45,7 +45,7 @@ protected override IAsyncEnumerable InvokeStreaming { return ActivityExtensions.RunWithActivityAsync( () => ModelDiagnostics.StartAgentInvocationActivity(agent.Id, agent.GetDisplayName(), agent.Description), - () => AssistantThreadActions.InvokeStreamingAsync(agent, this._client, this._threadId, messages, invocationOptions: null, this.Logger, agent.Kernel, agent.Arguments, cancellationToken), + () => AssistantThreadActions.InvokeStreamingAsync(agent, this._client, this._threadId, messages, invocationOptions: null, threadExtensionsContext: null, this.Logger, agent.Kernel, agent.Arguments, cancellationToken), cancellationToken); } diff --git a/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj b/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj index a0222fac89cf..c8550ebc178c 100644 --- a/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj +++ b/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj @@ -8,7 +8,7 @@ true false 12 - $(NoWarn);CA2007,CA1812,CA1861,CA1063,CS0618,VSTHRD111,SKEXP0001,SKEXP0050,SKEXP0110;OPENAI001 + $(NoWarn);CA2007,CA1812,CA1861,CA1063,CS0618,VSTHRD111,SKEXP0001,SKEXP0050,SKEXP0110;SKEXP0130;OPENAI001 diff --git a/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs b/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs index c8e0c1884a87..c2810951ba31 100644 --- a/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs +++ b/dotnet/src/Agents/UnitTests/Core/AgentThreadTests.cs @@ -3,8 +3,11 @@ using System; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.ChatCompletion; +using Moq; using Xunit; namespace SemanticKernel.Agents.UnitTests.Core; @@ -141,6 +144,135 @@ public async Task OnNewMessageShouldThrowIfThreadDeletedAsync() Assert.Equal(0, thread.OnNewMessageInternalAsyncCount); } + /// + /// Tests that the method throws an InvalidOperationException if the thread is not yet created. + /// + [Fact] + public async Task OnResumeShouldThrowIfThreadNotCreatedAsync() + { + // Arrange + var thread = new TestAgentThread(); + + // Act & Assert + await Assert.ThrowsAsync(() => thread.OnResumeAsync()); + } + + /// + /// Tests that the method throws an InvalidOperationException if the thread is deleted. + /// + [Fact] + public async Task OnResumeShouldThrowIfThreadDeletedAsync() + { + // Arrange + var thread = new TestAgentThread(); + await thread.CreateAsync(); + await thread.DeleteAsync(); + + // Act & Assert + await Assert.ThrowsAsync(() => thread.OnResumeAsync()); + } + + /// + /// Tests that the method + /// calls each registered state part in turn. + /// + [Fact] + public async Task OnSuspendShouldCallOnSuspendOnRegisteredPartsAsync() + { + // Arrange. + var thread = new TestAgentThread(); + var mockPart = new Mock(); + thread.StateParts.Add(mockPart.Object); + await thread.CreateAsync(); + + // Act. + await thread.OnSuspendAsync(); + + // Assert. + mockPart.Verify(x => x.OnSuspendAsync("test-thread-id", It.IsAny()), Times.Once); + } + + /// + /// Tests that the method + /// calls each registered state part in turn. + /// + [Fact] + public async Task OnResumeShouldCallOnResumeOnRegisteredPartsAsync() + { + // Arrange. + var thread = new TestAgentThread(); + var mockPart = new Mock(); + thread.StateParts.Add(mockPart.Object); + await thread.CreateAsync(); + + // Act. + await thread.OnResumeAsync(); + + // Assert. + mockPart.Verify(x => x.OnResumeAsync("test-thread-id", It.IsAny()), Times.Once); + } + + /// + /// Tests that the method + /// calls each registered state parts in turn. + /// + [Fact] + public async Task CreateShouldCallOnThreadCreatedOnRegisteredPartsAsync() + { + // Arrange. + var thread = new TestAgentThread(); + var mockPart = new Mock(); + thread.StateParts.Add(mockPart.Object); + + // Act. + await thread.CreateAsync(); + + // Assert. + mockPart.Verify(x => x.OnThreadCreatedAsync("test-thread-id", It.IsAny()), Times.Once); + } + + /// + /// Tests that the method + /// calls each registered state parts in turn. + /// + [Fact] + public async Task DeleteShouldCallOnThreadDeleteOnRegisteredPartsAsync() + { + // Arrange. + var thread = new TestAgentThread(); + var mockPart = new Mock(); + thread.StateParts.Add(mockPart.Object); + await thread.CreateAsync(); + + // Act. + await thread.DeleteAsync(); + + // Assert. + mockPart.Verify(x => x.OnThreadDeleteAsync("test-thread-id", It.IsAny()), Times.Once); + } + + /// + /// Tests that the method + /// calls each registered state part in turn. + /// + [Fact] + public async Task OnNewMessageShouldCallOnNewMessageOnRegisteredPartsAsync() + { + // Arrange. + var thread = new TestAgentThread(); + var mockPart = new Mock(); + thread.StateParts.Add(mockPart.Object); + var message = new ChatMessageContent(AuthorRole.User, "Test Message."); + + await thread.CreateAsync(); + + // Act. + await thread.OnNewMessageAsync(message); + + // Assert. + mockPart.Verify(x => x.OnNewMessageAsync("test-thread-id", It.Is(x => x.Text == "Test Message." && x.Role == ChatRole.User), It.IsAny()), Times.Once); + } + private sealed class TestAgentThread : AgentThread { public int CreateInternalAsyncCount { get; private set; } diff --git a/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs b/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs index dfca85afc0f2..8e0cee8a382c 100644 --- a/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs +++ b/dotnet/src/Agents/UnitTests/OpenAI/Internal/AssistantRunOptionsFactoryTests.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents.OpenAI; using Microsoft.SemanticKernel.Agents.OpenAI.Internal; @@ -29,7 +30,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsNullTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, null); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, null, threadExtensionsContext: null); // Assert Assert.NotNull(options); @@ -62,7 +63,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsEquivalentTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, "test", invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, "test", invocationOptions, threadExtensionsContext: null); // Assert Assert.NotNull(options); @@ -97,7 +98,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsOverrideTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions, threadExtensionsContext: null); // Assert Assert.NotNull(options); @@ -134,7 +135,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsMetadataTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions, threadExtensionsContext: null); // Assert Assert.Equal(2, options.Metadata.Count); @@ -163,7 +164,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsMessagesTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions, threadExtensionsContext: null); // Assert Assert.Single(options.AdditionalMessages); @@ -186,10 +187,42 @@ public void AssistantRunOptionsFactoryExecutionOptionsMaxTokensTest() }; // Act - RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, null); + RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, null, threadExtensionsContext: null); // Assert Assert.Equal(1024, options.MaxInputTokenCount); Assert.Equal(4096, options.MaxOutputTokenCount); } + + /// + /// Verify run options generation with metadata. + /// + [Fact] + public void AssistantRunOptionsFactoryAdditionalInstructionsTest() + { + // Arrange + RunCreationOptions defaultOptions = + new() + { + ModelOverride = "gpt-anything", + Temperature = 0.5F, + MaxOutputTokenCount = 4096, + MaxInputTokenCount = 1024, + AdditionalInstructions = "DefaultInstructions" + }; + + RunCreationOptions invocationOptions = + new() + { + AdditionalInstructions = "OverrideInstructions", + }; + + // Act + RunCreationOptions optionsWithOverride = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, invocationOptions, threadExtensionsContext: "Context"); + RunCreationOptions optionsWithoutOverride = AssistantRunOptionsFactory.GenerateOptions(defaultOptions, null, null, threadExtensionsContext: "Context"); + + // Assert + Assert.Equal($"OverrideInstructions{Environment.NewLine}{Environment.NewLine}Context", optionsWithOverride.AdditionalInstructions); + Assert.Equal($"DefaultInstructions{Environment.NewLine}{Environment.NewLine}Context", optionsWithoutOverride.AdditionalInstructions); + } } diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs index 8be11475493c..89ea56fa4e27 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentFixture.cs @@ -23,6 +23,8 @@ public abstract class AgentFixture : IAsyncLifetime public abstract AgentThread CreatedServiceFailingAgentThread { get; } + public abstract AgentThread GetNewThread(); + public abstract Task GetChatHistory(); public abstract Task DeleteThread(AgentThread thread); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/BedrockAgentThreadTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/BedrockAgentThreadTests.cs index 93fff6c7f636..ce09c4a3a68d 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/BedrockAgentThreadTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentThreadConformance/BedrockAgentThreadTests.cs @@ -7,38 +7,40 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.Agen public class BedrockAgentThreadTests() : AgentThreadTests(() => new BedrockAgentFixture()) { - [Fact(Skip = "Manual verification only")] + private const string ManualVerificationSkipReason = "This test is for manual verification."; + + [Fact(Skip = ManualVerificationSkipReason)] public override Task OnNewMessageWithServiceFailureThrowsAgentOperationExceptionAsync() { // The Bedrock agent does not support writing to a thread with OnNewMessage. return Task.CompletedTask; } - [Fact(Skip = "Manual verification only")] + [Fact(Skip = ManualVerificationSkipReason)] public override Task DeletingThreadTwiceDoesNotThrowAsync() { return base.DeletingThreadTwiceDoesNotThrowAsync(); } - [Fact(Skip = "Manual verification only")] + [Fact(Skip = ManualVerificationSkipReason)] public override Task UsingThreadAfterDeleteThrowsAsync() { return base.UsingThreadAfterDeleteThrowsAsync(); } - [Fact(Skip = "Manual verification only")] + [Fact(Skip = ManualVerificationSkipReason)] public override Task DeleteThreadBeforeCreateThrowsAsync() { return base.DeleteThreadBeforeCreateThrowsAsync(); } - [Fact(Skip = "Manual verification only")] + [Fact(Skip = ManualVerificationSkipReason)] public override Task UsingThreadBeforeCreateCreatesAsync() { return base.UsingThreadBeforeCreateCreatesAsync(); } - [Fact(Skip = "Manual verification only")] + [Fact(Skip = ManualVerificationSkipReason)] public override Task DeleteThreadWithServiceFailureThrowsAgentOperationExceptionAsync() { return base.DeleteThreadWithServiceFailureThrowsAgentOperationExceptionAsync(); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs new file mode 100644 index 000000000000..12e9011b4eb7 --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AgentWithStatePartTests.cs @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Moq; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; + +public abstract class AgentWithStatePartTests(Func createAgentFixture) : IAsyncLifetime + where TFixture : AgentFixture +{ +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + private TFixture _agentFixture; +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + + protected TFixture Fixture => this._agentFixture; + + [Fact] + public virtual async Task StatePartReceivesMessagesFromAgentAsync() + { + // Arrange + var mockStatePart = new Mock() { CallBase = true }; + mockStatePart.Setup(x => x.OnNewMessageAsync(It.IsAny(), It.IsAny(), It.IsAny())); + + var agent = this.Fixture.Agent; + + var agentThread = this.Fixture.GetNewThread(); + + try + { + agentThread.StateParts.Add(mockStatePart.Object); + + // Act + var inputMessage = "What is the capital of France?"; + var asyncResults1 = agent.InvokeAsync(inputMessage, agentThread); + var result = await asyncResults1.FirstAsync(); + + // Assert + Assert.Contains("Paris", result.Message.Content); + mockStatePart.Verify(x => x.OnNewMessageAsync(It.IsAny(), It.Is(cm => cm.Text == inputMessage), It.IsAny()), Times.Once); + mockStatePart.Verify(x => x.OnNewMessageAsync(It.IsAny(), It.Is(cm => cm.Text == result.Message.Content), It.IsAny()), Times.Once); + } + finally + { + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + } + + [Fact] + public virtual async Task StatePartReceivesMessagesFromAgentWhenStreamingAsync() + { + // Arrange + var mockStatePart = new Mock() { CallBase = true }; + mockStatePart.Setup(x => x.OnNewMessageAsync(It.IsAny(), It.IsAny(), It.IsAny())); + + var agent = this.Fixture.Agent; + + var agentThread = this.Fixture.GetNewThread(); + + try + { + agentThread.StateParts.Add(mockStatePart.Object); + + // Act + var inputMessage = "What is the capital of France?"; + var asyncResults1 = agent.InvokeStreamingAsync(inputMessage, agentThread); + var results = await asyncResults1.ToListAsync(); + + // Assert + var responseMessage = string.Concat(results.Select(x => x.Message.Content)); + Assert.Contains("Paris", responseMessage); + mockStatePart.Verify(x => x.OnNewMessageAsync(It.IsAny(), It.Is(cm => cm.Text == inputMessage), It.IsAny()), Times.Once); + mockStatePart.Verify(x => x.OnNewMessageAsync(It.IsAny(), It.Is(cm => cm.Text == responseMessage), It.IsAny()), Times.Once); + } + finally + { + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + } + + [Fact] + public virtual async Task StatePartPreInvokeStateIsUsedByAgentAsync() + { + // Arrange + var mockStatePart = new Mock() { CallBase = true }; + mockStatePart.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())).ReturnsAsync("User name is Caoimhe"); + + var agent = this.Fixture.Agent; + + var agentThread = this.Fixture.GetNewThread(); + + try + { + agentThread.StateParts.Add(mockStatePart.Object); + + // Act + var inputMessage = "What is my name?."; + var asyncResults1 = agent.InvokeAsync(inputMessage, agentThread); + var result = await asyncResults1.FirstAsync(); + + // Assert + Assert.Contains("Caoimhe", result.Message.Content); + } + finally + { + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + } + + [Fact] + public virtual async Task StatePartPreInvokeStateIsUsedByAgentWhenStreamingAsync() + { + // Arrange + var mockStatePart = new Mock() { CallBase = true }; + mockStatePart.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())).ReturnsAsync("User name is Caoimhe"); + + var agent = this.Fixture.Agent; + + var agentThread = this.Fixture.GetNewThread(); + + try + { + agentThread.StateParts.Add(mockStatePart.Object); + + // Act + var inputMessage = "What is my name?."; + var asyncResults1 = agent.InvokeStreamingAsync(inputMessage, agentThread); + var results = await asyncResults1.ToListAsync(); + + // Assert + var responseMessage = string.Concat(results.Select(x => x.Message.Content)); + Assert.Contains("Caoimhe", responseMessage); + } + finally + { + // Cleanup + await this.Fixture.DeleteThread(agentThread); + } + } + + public Task InitializeAsync() + { + this._agentFixture = createAgentFixture(); + return this._agentFixture.InitializeAsync(); + } + + public Task DisposeAsync() + { + return this._agentFixture.DisposeAsync(); + } +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs new file mode 100644 index 000000000000..59c3fbd0bf35 --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/AzureAIAgentWithStatePartTests.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; + +public class AzureAIAgentWithStatePartTests() : AgentWithStatePartTests(() => new AzureAIAgentFixture()) +{ +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/BedrockAgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/BedrockAgentWithStatePartTests.cs new file mode 100644 index 000000000000..05f16f5d7292 --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/BedrockAgentWithStatePartTests.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading.Tasks; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; + +public class BedrockAgentWithStatePartTests() : AgentWithStatePartTests(() => new BedrockAgentFixture()) +{ + private const string ManualVerificationSkipReason = "This test is for manual verification."; + + [Fact(Skip = ManualVerificationSkipReason)] + public override Task StatePartReceivesMessagesFromAgentAsync() + { + return base.StatePartReceivesMessagesFromAgentAsync(); + } + + [Fact(Skip = ManualVerificationSkipReason)] + public override Task StatePartReceivesMessagesFromAgentWhenStreamingAsync() + { + return base.StatePartReceivesMessagesFromAgentWhenStreamingAsync(); + } + + [Fact(Skip = ManualVerificationSkipReason)] + public override Task StatePartPreInvokeStateIsUsedByAgentAsync() + { + return base.StatePartPreInvokeStateIsUsedByAgentAsync(); + } + + [Fact(Skip = ManualVerificationSkipReason)] + public override Task StatePartPreInvokeStateIsUsedByAgentWhenStreamingAsync() + { + return base.StatePartPreInvokeStateIsUsedByAgentWhenStreamingAsync(); + } +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/ChatCompletionAgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/ChatCompletionAgentWithStatePartTests.cs new file mode 100644 index 000000000000..838043376c9b --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/ChatCompletionAgentWithStatePartTests.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; + +public class ChatCompletionAgentWithStatePartTests() : AgentWithStatePartTests(() => new ChatCompletionAgentFixture()) +{ +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs new file mode 100644 index 000000000000..62e6ab81309f --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AgentWithStatePartConformance/OpenAIAssistantAgentWithStatePartTests.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance; + +public class OpenAIAssistantAgentWithStatePartTests() : AgentWithStatePartTests(() => new OpenAIAssistantAgentFixture()) +{ +} diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs index 769e3daec9d7..4ceb4d0605ef 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/AzureAIAgentFixture.cs @@ -30,6 +30,8 @@ public class AzureAIAgentFixture : AgentFixture private AzureAIAgentThread? _serviceFailingAgentThread; private AzureAIAgentThread? _createdServiceFailingAgentThread; + public AAIP.AgentsClient AgentsClient => this._agentsClient!; + public override Agent Agent => this._agent!; public override AgentThread AgentThread => this._thread!; @@ -40,6 +42,11 @@ public class AzureAIAgentFixture : AgentFixture public override AgentThread CreatedServiceFailingAgentThread => this._createdServiceFailingAgentThread!; + public override AgentThread GetNewThread() + { + return new AzureAIAgentThread(this._agentsClient!); + } + public override async Task GetChatHistory() { var chatHistory = new ChatHistory(); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs index 1dd85e085cce..d9ba2a333003 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/BedrockAgentFixture.cs @@ -18,7 +18,7 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance; -internal sealed class BedrockAgentFixture : AgentFixture, IAsyncDisposable +public sealed class BedrockAgentFixture : AgentFixture, IAsyncDisposable { private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) @@ -47,6 +47,11 @@ internal sealed class BedrockAgentFixture : AgentFixture, IAsyncDisposable public override AgentThread CreatedServiceFailingAgentThread => this._createdServiceFailingAgentThread!; + public override AgentThread GetNewThread() + { + return new BedrockAgentThread(this._runtimeClient); + } + public override async Task DeleteThread(AgentThread thread) { await this._runtimeClient!.EndSessionAsync(new EndSessionRequest() { SessionIdentifier = thread.Id }); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs index c7fa8dbcede3..ea1f9c3d2af4 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/ChatCompletionAgentFixture.cs @@ -36,6 +36,11 @@ public class ChatCompletionAgentFixture : AgentFixture public override AgentThread CreatedServiceFailingAgentThread => null!; + public override AgentThread GetNewThread() + { + return new ChatHistoryAgentThread(); + } + public override async Task GetChatHistory() { var chatHistory = new ChatHistory(); @@ -65,6 +70,10 @@ public async override Task InitializeAsync() deploymentName: configuration.ChatDeploymentName!, endpoint: configuration.Endpoint, credentials: new AzureCliCredential()); + kernelBuilder.AddAzureOpenAITextEmbeddingGeneration( + deploymentName: configuration.EmbeddingModelId!, + endpoint: configuration.Endpoint, + credential: new AzureCliCredential()); Kernel kernel = kernelBuilder.Build(); this._agent = new ChatCompletionAgent() diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/InvokeConformance/BedrockAgentInvokeTests.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/InvokeConformance/BedrockAgentInvokeTests.cs index e9eff01241c6..84742ca908dc 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/InvokeConformance/BedrockAgentInvokeTests.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/InvokeConformance/BedrockAgentInvokeTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Linq; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -11,7 +10,9 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.Invo public class BedrockAgentInvokeTests() : InvokeTests(() => new BedrockAgentFixture()) { - [Fact(Skip = "This test is for manual verification.")] + private const string ManualVerificationSkipReason = "This test is for manual verification."; + + [Fact(Skip = ManualVerificationSkipReason)] public override async Task ConversationMaintainsHistoryAsync() { var q1 = "What is the capital of France."; @@ -32,7 +33,7 @@ public override async Task ConversationMaintainsHistoryAsync() //Assert.Contains("Eiffel", result2.Message.Content); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = ManualVerificationSkipReason)] public override async Task InvokeReturnsResultAsync() { var agent = this.Fixture.Agent; @@ -48,7 +49,7 @@ public override async Task InvokeReturnsResultAsync() //Assert.Contains("Paris", firstResult.Message.Content); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = ManualVerificationSkipReason)] public override async Task InvokeWithoutThreadCreatesThreadAsync() { var agent = this.Fixture.Agent; @@ -66,20 +67,19 @@ public override async Task InvokeWithoutThreadCreatesThreadAsync() await this.Fixture.DeleteThread(firstResult.Thread); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = "The BedrockAgent does not support invoking without a message.")] public override Task InvokeWithoutMessageCreatesThreadAsync() { - // The Bedrock agent does not support invoking without a message. - return Assert.ThrowsAsync(async () => await base.InvokeWithoutThreadCreatesThreadAsync()); + return base.InvokeWithoutMessageCreatesThreadAsync(); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = "The BedrockAgent does not yet support plugins")] public override Task MultiStepInvokeWithPluginAndArgOverridesAsync() { return base.MultiStepInvokeWithPluginAndArgOverridesAsync(); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = "The BedrockAgent does not yet support plugins")] public override Task InvokeWithPluginNotifiesForAllMessagesAsync() { return base.InvokeWithPluginNotifiesForAllMessagesAsync(); diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs index 4c512e559379..2cf64795d79d 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIAssistantAgentFixture.cs @@ -34,6 +34,8 @@ public class OpenAIAssistantAgentFixture : AgentFixture private OpenAIAssistantAgentThread? _serviceFailingAgentThread; private OpenAIAssistantAgentThread? _createdServiceFailingAgentThread; + public AssistantClient AssistantClient => this._assistantClient!; + public override Agent Agent => this._agent!; public override AgentThread AgentThread => this._thread!; @@ -44,6 +46,11 @@ public class OpenAIAssistantAgentFixture : AgentFixture public override AgentThread CreatedServiceFailingAgentThread => this._createdServiceFailingAgentThread!; + public override AgentThread GetNewThread() + { + return new OpenAIAssistantAgentThread(this._assistantClient!); + } + public override async Task GetChatHistory() { var chatHistory = new ChatHistory(); @@ -85,6 +92,7 @@ public override Task DeleteThread(AgentThread thread) public override async Task InitializeAsync() { + AzureOpenAIConfiguration openAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get()!; AzureAIConfiguration configuration = this._configuration.GetSection("AzureAI").Get()!; var client = OpenAIAssistantAgent.CreateAzureOpenAIClient(new AzureCliCredential(), new Uri(configuration.Endpoint)); this._assistantClient = client.GetAssistantClient(); @@ -96,6 +104,14 @@ await this._assistantClient.CreateAssistantAsync( instructions: "You are a helpful assistant."); var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddAzureOpenAIChatCompletion( + deploymentName: openAIConfiguration.ChatDeploymentName!, + endpoint: openAIConfiguration.Endpoint, + credentials: new AzureCliCredential()); + kernelBuilder.AddAzureOpenAITextEmbeddingGeneration( + deploymentName: openAIConfiguration.EmbeddingModelId!, + endpoint: openAIConfiguration.Endpoint, + credential: new AzureCliCredential()); Kernel kernel = kernelBuilder.Build(); this._agent = new OpenAIAssistantAgent(this._assistant, this._assistantClient) { Kernel = kernel }; diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index 284f7d46a978..8a37e317cbe2 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -5,7 +5,7 @@ net8.0 true false - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110,OPENAI001 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110,SKEXP0130,OPENAI001 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 @@ -48,6 +48,7 @@ + @@ -97,6 +98,7 @@ + diff --git a/dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs b/dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs new file mode 100644 index 000000000000..c1f4c90cdfdf --- /dev/null +++ b/dotnet/src/IntegrationTests/Memory/Mem0MemoryComponentTests.cs @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel.Memory; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Memory; + +/// +/// Contains tests for the class. +/// +public class Mem0MemoryComponentTests : IDisposable +{ + // If null, all tests will be enabled + private const string SkipReason = "Requires a Mem0 service configured"; + + private readonly HttpClient _httpClient; + private bool _disposedValue; + + public Mem0MemoryComponentTests() + { + IConfigurationRoot configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + var mem0Settings = configuration.GetRequiredSection("Mem0").Get()!; + + this._httpClient = new HttpClient(); + this._httpClient.BaseAddress = new Uri(mem0Settings.ServiceUri); + this._httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", mem0Settings.ApiKey); + } + + [Fact(Skip = SkipReason)] + public async Task Mem0ComponentCanAddAndRetrieveMemoriesAsync() + { + // Arrange + var question = new ChatMessage(ChatRole.User, "What is my name?"); + var input = new ChatMessage(ChatRole.User, "Hello, my name is Caoimhe."); + + var sut = new Mem0MemoryComponent(this._httpClient, new() { ThreadId = "test-thread-id", UserId = "test-user-id", ScopeToPerOperationThreadId = true }); + + await sut.ClearStoredUserFactsAsync(); + var answerBeforeAdding = await sut.OnModelInvokeAsync([question]); + Assert.DoesNotContain("Caoimhe", answerBeforeAdding); + + // Act + await sut.OnNewMessageAsync("test-thread-id", input); + + await sut.OnNewMessageAsync("test-thread-id", question); + var answerAfterAdding = await sut.OnModelInvokeAsync([question]); + + await sut.ClearStoredUserFactsAsync(); + var answerAfterClearing = await sut.OnModelInvokeAsync([question]); + + // Assert + Assert.Contains("Caoimhe", answerAfterAdding); + Assert.DoesNotContain("Caoimhe", answerAfterClearing); + } + + [Fact(Skip = SkipReason)] + public async Task Mem0ComponentDoesNotLeakMessagesAcrossScopesAsync() + { + // Arrange + var question = new ChatMessage(ChatRole.User, "What is your name?"); + var input = new ChatMessage(ChatRole.Assistant, "I'm an AI tutor with a personality. My name is Caoimhe."); + + var sut1 = new Mem0MemoryComponent(this._httpClient, new() { AgentId = "test-agent-id-1" }); + var sut2 = new Mem0MemoryComponent(this._httpClient, new() { AgentId = "test-agent-id-2" }); + + await sut1.ClearStoredUserFactsAsync(); + await sut2.ClearStoredUserFactsAsync(); + + var answerBeforeAdding1 = await sut1.OnModelInvokeAsync([question]); + var answerBeforeAdding2 = await sut2.OnModelInvokeAsync([question]); + Assert.DoesNotContain("Caoimhe", answerBeforeAdding1); + Assert.DoesNotContain("Caoimhe", answerBeforeAdding2); + + // Act + await sut1.OnNewMessageAsync("test-thread-id-1", input); + var answerAfterAdding = await sut1.OnModelInvokeAsync([question]); + + await sut2.OnNewMessageAsync("test-thread-id-2", question); + var answerAfterAddingOnOtherScope = await sut2.OnModelInvokeAsync([question]); + + // Assert + Assert.Contains("Caoimhe", answerAfterAdding); + Assert.DoesNotContain("Caoimhe", answerAfterAddingOnOtherScope); + + // Cleanup. + await sut1.ClearStoredUserFactsAsync(); + await sut2.ClearStoredUserFactsAsync(); + } + + [Fact(Skip = SkipReason)] + public async Task Mem0ComponentDoesNotWorkWithMultiplePerOperationThreadsAsync() + { + // Arrange + var input = new ChatMessage(ChatRole.User, "Hello, my name is Caoimhe."); + + var sut = new Mem0MemoryComponent(this._httpClient, new() { UserId = "test-user-id", ScopeToPerOperationThreadId = true }); + + await sut.ClearStoredUserFactsAsync(); + + // Act & Assert + await sut.OnThreadCreatedAsync("test-thread-id-1"); + await Assert.ThrowsAsync(() => sut.OnThreadCreatedAsync("test-thread-id-2")); + + await sut.OnNewMessageAsync("test-thread-id-1", input); + await Assert.ThrowsAsync(() => sut.OnNewMessageAsync("test-thread-id-2", input)); + + // Cleanup + await sut.ClearStoredUserFactsAsync(); + } + + protected virtual void Dispose(bool disposing) + { + if (!this._disposedValue) + { + if (disposing) + { + this._httpClient.Dispose(); + } + + this._disposedValue = true; + } + } + + public void Dispose() + { + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } +} diff --git a/dotnet/src/IntegrationTests/TestSettings/Mem0Configuration.cs b/dotnet/src/IntegrationTests/TestSettings/Mem0Configuration.cs new file mode 100644 index 000000000000..a89f68e45c0a --- /dev/null +++ b/dotnet/src/IntegrationTests/TestSettings/Mem0Configuration.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace SemanticKernel.IntegrationTests.TestSettings; + +[SuppressMessage("Performance", "CA1812:Internal class that is apparently never instantiated", + Justification = "Configuration classes are instantiated through IConfiguration.")] +internal sealed class Mem0Configuration +{ + public string ServiceUri { get; init; } = string.Empty; + public string ApiKey { get; init; } = string.Empty; +} diff --git a/dotnet/src/IntegrationTests/testsettings.json b/dotnet/src/IntegrationTests/testsettings.json index ef269839e1c3..287fe122e2ad 100644 --- a/dotnet/src/IntegrationTests/testsettings.json +++ b/dotnet/src/IntegrationTests/testsettings.json @@ -126,5 +126,9 @@ "BedrockAgent": { "AgentResourceRoleArn": "", "FoundationModel": "anthropic.claude-3-haiku-20240307-v1:0" + }, + "Mem0": { + "ServiceUri": "https://api.mem0.ai", + "ApiKey": "" } } \ No newline at end of file diff --git a/dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj b/dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj new file mode 100644 index 000000000000..fb48dbd9236b --- /dev/null +++ b/dotnet/src/Memory/Memory.Abstractions/Memory.Abstractions.csproj @@ -0,0 +1,31 @@ + + + + Microsoft.SemanticKernel.Memory.Abstractions + Microsoft.SemanticKernel.Memory + net8.0;netstandard2.0 + false + alpha + + + + + + + Semantic Kernel - Memory Abstractions + Semantic Kernel interfaces and abstractions for capturing, storing and retrieving memories. + + + + rc + + + + + + + + + + + diff --git a/dotnet/src/Memory/Memory.Abstractions/Properties/AssemblyInfo.cs b/dotnet/src/Memory/Memory.Abstractions/Properties/AssemblyInfo.cs new file mode 100644 index 000000000000..418ffa6d4b58 --- /dev/null +++ b/dotnet/src/Memory/Memory.Abstractions/Properties/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +// This assembly is currently experimental. +[assembly: Experimental("SKEXP0130")] diff --git a/dotnet/src/Memory/Memory/Memory.csproj b/dotnet/src/Memory/Memory/Memory.csproj new file mode 100644 index 000000000000..4caeec7c0426 --- /dev/null +++ b/dotnet/src/Memory/Memory/Memory.csproj @@ -0,0 +1,39 @@ + + + + Microsoft.SemanticKernel.Memory.Core + Microsoft.SemanticKernel.Memory + net8.0;netstandard2.0 + false + alpha + $(NoWarn);NU5104 + + + + + + + Semantic Kernel - Memory Core + Semantic Kernel implementations for capturing, storing and retrieving memories. + + + + rc + + + + + + + + + + + + + + + + + + diff --git a/dotnet/src/Memory/Memory/Properties/AssemblyInfo.cs b/dotnet/src/Memory/Memory/Properties/AssemblyInfo.cs new file mode 100644 index 000000000000..418ffa6d4b58 --- /dev/null +++ b/dotnet/src/Memory/Memory/Properties/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +// This assembly is currently experimental. +[assembly: Experimental("SKEXP0130")] diff --git a/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs new file mode 100644 index 000000000000..2c77ef77e3f5 --- /dev/null +++ b/dotnet/src/Memory/Memory/UserFactsMemoryComponent.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft. All rights reserved. + +#if DISABLED + +using System.Collections.Generic; +using System.ComponentModel; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.Memory; + +namespace Microsoft.SemanticKernel.Agents.Memory; + +/// +/// A memory component that can retrieve, maintain and store user facts that +/// are learned from the user's interactions with the agent. +/// +public class UserFactsMemoryComponent : ConversationStatePart +{ + private readonly Kernel _kernel; + private readonly TextMemoryStore _textMemoryStore; + private string _userFacts = string.Empty; + private bool _contextLoaded = false; + + private readonly AIFunction[] _aIFunctions; + + /// + /// Initializes a new instance of the class. + /// + /// A kernel to use for making chat completion calls. + /// The memory store to retrieve and save memories from and to. + public UserFactsMemoryComponent(Kernel kernel, TextMemoryStore textMemoryStore) + { + this._kernel = kernel; + this._textMemoryStore = textMemoryStore; + + this._aIFunctions = [AIFunctionFactory.Create(this.ClearUserFactsAsync)]; + } + + /// + /// Initializes a new instance of the class. + /// + /// A kernel to use for making chat completion calls. + /// The service key that the for user facts is registered under in DI. + public UserFactsMemoryComponent(Kernel kernel, string? userFactsStoreName = "UserFactsStore") + { + this._kernel = kernel; + this._textMemoryStore = new OptionalTextMemoryStore(kernel, userFactsStoreName); + + this._aIFunctions = [AIFunctionFactory.Create(this.ClearUserFactsAsync)]; + } + + /// + public override IReadOnlyCollection AIFunctions => this._aIFunctions; + + /// + /// Gets or sets the name of the document to use for storing user preferfactsences. + /// + public string UserFactsDocumentName { get; init; } = "UserFacts"; + + /// + /// Gets or sets the prompt template to use for extracting user facts and merging them with existing facts. + /// + public string MaintenancePromptTemplate { get; init; } = + """ + You are an expert in extracting facts about a user from text and combining these facts with existing facts to output a new list of facts. + Facts are short statements that each contain a single piece of information. + Facts should always be about the user and should always be in the present tense. + Facts should focus on the user's long term preferences and characteristics, not on their short term actions. + + Here are 5 few shot examples: + + EXAMPLES START + + Input text: My name is John. I love dogs and cats, but unfortunately I am allergic to cats. I'm not alergic to dogs though. I have a dog called Fido. + Input facts: User name is John. User is alergic to cats. + Output: User name is John. User loves dogs. User loves cats. User is alergic to cats. User is not alergic to dogs. User has a dog. User dog's name is Fido. + + Input text: My name is Mary. I like active holidays. I enjoy cycling and hiking. + Input facts: User name is Mary. User dislikes cycling. + Output: User name is Mary. User likes cycling. User likes hiking. User likes active holidays. + + Input text: How do I calculate the area of a circle? + Input facts: + Output: + + Input text: What is today's date? + Input facts: User name is Peter. + Output: User name is Peter. + + EXAMPLES END + + Return output for the following inputs like shown in the examples above: + + Input text: {{$inputText}} + Input facts: {{$existingFacts}} + """; + + /// + public override async Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) + { + if (!this._contextLoaded) + { + this._userFacts = string.Empty; + + var memoryText = await this._textMemoryStore.GetMemoryAsync(this.UserFactsDocumentName, cancellationToken).ConfigureAwait(false); + if (memoryText is not null) + { + this._userFacts = memoryText; + } + + this._contextLoaded = true; + } + } + + /// + public override async Task OnThreadDeleteAsync(string? threadId, CancellationToken cancellationToken = default) + { + await this._textMemoryStore.SaveMemoryAsync(this.UserFactsDocumentName, this._userFacts, cancellationToken).ConfigureAwait(false); + } + + /// + public override async Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) + { + if (newMessage.Role == ChatRole.User && !string.IsNullOrWhiteSpace(newMessage.Text)) + { + // Don't wait for task to complete. Just run in the background. + await this.ExtractAndSaveMemoriesAsync(newMessage.Text, cancellationToken).ConfigureAwait(false); + } + } + + /// + public override Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default) + { + return Task.FromResult("The following list contains facts about the user:\n" + this._userFacts); + } + + /// + public override Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default) + { + return this.OnThreadCreatedAsync(threadId, cancellationToken); + } + + /// + public override Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default) + { + return this.OnThreadDeleteAsync(threadId, cancellationToken); + } + + /// + /// Plugin method to clear user facts stored in memory. + /// + [Description("Deletes any user facts that are stored acros multiple conversations.")] + public async Task ClearUserFactsAsync(CancellationToken cancellationToken = default) + { + this._userFacts = string.Empty; + await this._textMemoryStore.SaveMemoryAsync(this.UserFactsDocumentName, this._userFacts, cancellationToken).ConfigureAwait(false); + } + + private async Task ExtractAndSaveMemoriesAsync(string inputText, CancellationToken cancellationToken = default) + { + var result = await this._kernel.InvokePromptAsync( + this.MaintenancePromptTemplate, + new KernelArguments() { ["inputText"] = inputText, ["existingFacts"] = this._userFacts }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + this._userFacts = result.ToString(); + + await this._textMemoryStore.SaveMemoryAsync(this.UserFactsDocumentName, this._userFacts, cancellationToken).ConfigureAwait(false); + } +} + +#endif diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePart.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePart.cs new file mode 100644 index 000000000000..eaa533a121ea --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePart.cs @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel; + +/// +/// Base class for all conversation state parts. +/// +/// +/// A conversation state part is a component that can be used to store additional state related +/// to a conversation, listen to changes in the conversation state, and provide additional context to +/// the AI model in use just before invocation. +/// +[Experimental("SKEXP0130")] +public abstract class ConversationStatePart +{ + /// + /// Gets the list of AI functions that this component exposes + /// and which should be used by the consuming AI when using this component. + /// + public virtual IReadOnlyCollection AIFunctions => Array.Empty(); + + /// + /// Called just after a new thread is created. + /// + /// + /// Implementers can use this method to do any operations required at the creation of a new thread. + /// For example, checking long term storage for any data that is relevant to the current session based on the input text. + /// + /// The ID of the new thread, if the thread has an ID. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the context has been loaded. + public virtual Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + /// + /// This method is called when a new message has been contributed to the chat by any participant. + /// + /// + /// Inheritors can use this method to update their context based on the new message. + /// + /// The ID of the thread for the new message, if the thread has an ID. + /// The new message. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the context has been updated. + public virtual Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + /// + /// Called just before a thread is deleted. + /// + /// + /// Implementers can use this method to do any operations required before a thread is deleted. + /// For example, storing the context to long term storage. + /// + /// The ID of the thread that will be deleted, if the thread has an ID. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the context has been saved. + public virtual Task OnThreadDeleteAsync(string? threadId, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + /// + /// Called just before the Model/Agent/etc. is invoked + /// Implementers can load any additional context required at this time, + /// but they should also return any context that should be passed to the Model/Agent/etc. + /// + /// The most recent messages that the Model/Agent/etc. is being invoked with. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the context has been rendered and returned. + public abstract Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default); + + /// + /// Called when the current conversion is temporarily suspended and any state should be saved. + /// + /// The ID of the current thread, if the thread has an ID. + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the end of each service call. + /// In a client application, this might be when the user closes the chat window or the application. + /// + public virtual Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + /// + /// Called when the current conversion is resumed and any state should be restored. + /// + /// The ID of the current thread, if the thread has an ID. + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the start of each service call where a previous conversation is being continued. + /// In a client application, this might be when the user re-opens the chat window to resume a conversation after having previously closed it. + /// + public virtual Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManager.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManager.cs new file mode 100644 index 000000000000..f4c62b4b8443 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManager.cs @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.SemanticKernel; + +/// +/// A container class for objects that manages their lifecycle and interactions. +/// +[Experimental("SKEXP0130")] +public sealed class ConversationStatePartsManager +{ + private readonly List _parts = new(); + + private List? _currentAIFunctions = null; + + /// + /// Gets the list of registered conversation state parts. + /// + public IReadOnlyList Parts => this._parts; + + /// + /// Initializes a new instance of the class. + /// + public ConversationStatePartsManager() + { + } + + /// + /// Initializes a new instance of the class with the specified conversation state parts. + /// + /// The conversation state parts to add to the manager. + public ConversationStatePartsManager(IEnumerable conversationtStateExtensions) + { + this._parts.AddRange(conversationtStateExtensions); + } + + /// + /// Gets the list of AI functions that all contained parts expose + /// and which should be used by the consuming AI when using these parts. + /// + public IReadOnlyCollection AIFunctions + { + get + { + if (this._currentAIFunctions == null) + { + this._currentAIFunctions = this.Parts.SelectMany(conversationStateParts => conversationStateParts.AIFunctions).ToList(); + } + + return this._currentAIFunctions; + } + } + + /// + /// Adds a new conversation state part. + /// + /// The conversation state part to register. + public void Add(ConversationStatePart conversationtStatePart) + { + this._parts.Add(conversationtStatePart); + this._currentAIFunctions = null; + } + + /// + /// Adds all conversation state parts registered on the provided dependency injection service provider. + /// + /// The dependency injection service provider to read conversation state parts from. + public void AddFromServiceProvider(IServiceProvider serviceProvider) + { + foreach (var part in serviceProvider.GetServices()) + { + this.Add(part); + } + this._currentAIFunctions = null; + } + + /// + /// Called when a new thread is created. + /// + /// The ID of the new thread. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public async Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) + { + await Task.WhenAll(this.Parts.Select(x => x.OnThreadCreatedAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + } + + /// + /// Called just before a thread is deleted. + /// + /// The id of the thread that will be deleted. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public async Task OnThreadDeleteAsync(string threadId, CancellationToken cancellationToken = default) + { + await Task.WhenAll(this.Parts.Select(x => x.OnThreadDeleteAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + } + + /// + /// This method is called when a new message has been contributed to the chat by any participant. + /// + /// The ID of the thread for the new message, if the thread has an ID. + /// The new message. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public async Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) + { + await Task.WhenAll(this.Parts.Select(x => x.OnNewMessageAsync(threadId, newMessage, cancellationToken)).ToList()).ConfigureAwait(false); + } + + /// + /// Called just before the Model/Agent/etc. is invoked + /// + /// The most recent messages that the Model/Agent/etc. is being invoked with. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation, containing the combined context from all conversation state parts. + public async Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default) + { + var subContexts = await Task.WhenAll(this.Parts.Select(x => x.OnModelInvokeAsync(newMessages, cancellationToken)).ToList()).ConfigureAwait(false); + return string.Join("\n", subContexts); + } + + /// + /// Called when the current conversion is temporarily suspended and any state should be saved. + /// + /// The ID of the current thread, if the thread has an ID. + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the end of each service call. + /// In a client application, this might be when the user closes the chat window or the application. + /// + public async Task OnSuspendAsync(string? threadId, CancellationToken cancellationToken = default) + { + await Task.WhenAll(this.Parts.Select(x => x.OnSuspendAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + } + + /// + /// Called when the current conversion is resumed and any state should be restored. + /// + /// The ID of the current thread, if the thread has an ID. + /// The to monitor for cancellation requests. The default is . + /// An async task. + /// + /// In a service that hosts an agent, that is invoked via calls to the service, this might be at the start of each service call where a previous conversation is being continued. + /// In a client application, this might be when the user re-opens the chat window to resume a conversation after having previously closed it. + /// + public async Task OnResumeAsync(string? threadId, CancellationToken cancellationToken = default) + { + await Task.WhenAll(this.Parts.Select(x => x.OnResumeAsync(threadId, cancellationToken)).ToList()).ConfigureAwait(false); + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManagerExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManagerExtensions.cs new file mode 100644 index 000000000000..e64de1d0b180 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ConversationStatePartsManagerExtensions.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods for . +/// +[Experimental("SKEXP0130")] +public static class ConversationStatePartsManagerExtensions +{ + /// + /// This method is called when a new message has been contributed to the chat by any participant. + /// + /// The conversation state manager to pass the new message to. + /// The ID of the thread for the new message, if the thread has an ID. + /// The new message. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public static Task OnNewMessageAsync(this ConversationStatePartsManager conversationStatePartsManager, string? threadId, ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + return conversationStatePartsManager.OnNewMessageAsync(threadId, ChatCompletionServiceExtensions.ToChatMessage(newMessage), cancellationToken); + } + + /// + /// Called just before the Model/Agent/etc. is invoked + /// + /// The conversation state manager to call. + /// The most recent messages that the Model/Agent/etc. is being invoked with. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation, containing the combined context from all conversation state parts. + public static Task OnModelInvokeAsync(this ConversationStatePartsManager conversationStatePartsManager, ICollection newMessages, CancellationToken cancellationToken = default) + { + return conversationStatePartsManager.OnModelInvokeAsync(newMessages.Select(ChatCompletionServiceExtensions.ToChatMessage).ToList(), cancellationToken); + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/TextMemoryStore.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/TextMemoryStore.cs new file mode 100644 index 000000000000..599ec7dd1eae --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/TextMemoryStore.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Abstract base class for storing and retrieving text based memories. +/// +[Experimental("SKEXP0001")] +public abstract class TextMemoryStore +{ + /// + /// Retrieves a memory asynchronously by its document name. + /// + /// The name of the document to retrieve the memory for. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the memory text if found, otherwise null. + public abstract Task GetMemoryAsync(string documentName, CancellationToken cancellationToken = default); + + /// + /// Searches for memories that are similar to the given text asynchronously. + /// + /// The text to search for similar memories. + /// The to monitor for cancellation requests. The default is . + /// An asynchronous enumerable of similar memory texts. + public abstract IAsyncEnumerable SimilaritySearch(string query, CancellationToken cancellationToken = default); + + /// + /// Saves a memory asynchronously with the specified document name. + /// + /// The name of the document to save the memory to. + /// The memory text to save. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous save operation. + public abstract Task SaveMemoryAsync(string documentName, string memoryText, CancellationToken cancellationToken = default); + + /// + /// Saves a memory asynchronously with no document name. + /// + /// The memory text to save. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous save operation. + public abstract Task SaveMemoryAsync(string memoryText, CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/ConversationStatePartsManagerExtensions.cs b/dotnet/src/SemanticKernel.Core/Memory/ConversationStatePartsManagerExtensions.cs new file mode 100644 index 000000000000..888c56c758d9 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/ConversationStatePartsManagerExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods for . +/// +[Experimental("SKEXP0130")] +public static class ConversationStatePartsManagerExtensions +{ + /// + /// Registers plugins required by all conversation state parts contained by this manager on the provided . + /// + /// The conversation state manager to get plugins from. + /// The kernel to register the plugins on. + public static void RegisterPlugins(this ConversationStatePartsManager conversationStatePartsManager, Kernel kernel) + { + kernel.Plugins.AddFromFunctions("Tools", conversationStatePartsManager.AIFunctions.Select(x => x.AsKernelFunction())); + } +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs new file mode 100644 index 000000000000..12fbb1551dda --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0Client.cs @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading.Tasks; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Client for the Mem0 memory service. +/// +internal sealed class Mem0Client +{ + private static readonly Uri s_searchUri = new("/v1/memories/search/", UriKind.Relative); + private static readonly Uri s_createMemoryUri = new("/v1/memories/", UriKind.Relative); + + private readonly HttpClient _httpClient; + + public Mem0Client(HttpClient httpClient) + { + Verify.NotNull(httpClient); + + this._httpClient = httpClient; + } + + public async Task> SearchAsync(string? applicationId, string? agentId, string? threadId, string? userId, string? inputText) + { + if (string.IsNullOrWhiteSpace(applicationId) + && string.IsNullOrWhiteSpace(agentId) + && string.IsNullOrWhiteSpace(threadId) + && string.IsNullOrWhiteSpace(userId)) + { + throw new InvalidOperationException("At least one of applicationId, agentId, threadId, or userId must be provided."); + } + + var searchRequest = new SearchRequest + { + AppId = applicationId, + AgentId = agentId, + RunId = threadId, + UserId = userId, + Query = inputText ?? string.Empty + }; + + // Search. + using var content = new StringContent(JsonSerializer.Serialize(searchRequest, Mem0SourceGenerationContext.Default.SearchRequest), Encoding.UTF8, "application/json"); + var responseMessage = await this._httpClient.PostAsync(s_searchUri, content).ConfigureAwait(false); + responseMessage.EnsureSuccessStatusCode(); + + // Process response. + var response = await responseMessage.Content.ReadAsStringAsync().ConfigureAwait(false); + var searchResponseItems = JsonSerializer.Deserialize(response, Mem0SourceGenerationContext.Default.SearchResponseItemArray); + return searchResponseItems?.Select(item => item.Memory) ?? []; + } + + public async Task CreateMemoryAsync(string? applicationId, string? agentId, string? threadId, string? userId, string messageContent, string messageRole) + { + if (string.IsNullOrWhiteSpace(applicationId) + && string.IsNullOrWhiteSpace(agentId) + && string.IsNullOrWhiteSpace(threadId) + && string.IsNullOrWhiteSpace(userId)) + { + throw new InvalidOperationException("At least one of applicationId, agentId, threadId, or userId must be provided."); + } + + var createMemoryRequest = new CreateMemoryRequest() + { + AppId = applicationId, + AgentId = agentId, + RunId = threadId, + UserId = userId, + Messages = new[] + { + new CreateMemoryMessage + { + Content = messageContent, + Role = messageRole + } + } + }; + + using var content = new StringContent(JsonSerializer.Serialize(createMemoryRequest, Mem0SourceGenerationContext.Default.CreateMemoryRequest), Encoding.UTF8, "application/json"); + var responseMessage = await this._httpClient.PostAsync(s_createMemoryUri, content).ConfigureAwait(false); + responseMessage.EnsureSuccessStatusCode(); + } + + public async Task ClearMemoryAsync(string? applicationId, string? agentId, string? threadId, string? userId) + { + string[] paramNames = ["app_id", "agent_id", "run_id", "user_id"]; + + // Build query string. + var querystringParams = new string?[4] { applicationId, agentId, threadId, userId } + .Select((param, index) => string.IsNullOrWhiteSpace(param) ? null : $"{paramNames[index]}={param}") + .Where(x => x != null); + var queryString = string.Join("&", querystringParams); + var clearMemoryUrl = new Uri($"/v1/memories/?{queryString}", UriKind.Relative); + + // Delete. + var responseMessage = await this._httpClient.DeleteAsync(clearMemoryUrl).ConfigureAwait(false); + responseMessage.EnsureSuccessStatusCode(); + } + + internal sealed class CreateMemoryRequest + { + [JsonPropertyName("app_id")] + public string? AppId { get; set; } + [JsonPropertyName("agent_id")] + public string? AgentId { get; set; } + [JsonPropertyName("run_id")] + public string? RunId { get; set; } + [JsonPropertyName("user_id")] + public string? UserId { get; set; } + [JsonPropertyName("messages")] + public CreateMemoryMessage[] Messages { get; set; } = []; + } + + internal sealed class CreateMemoryMessage + { + [JsonPropertyName("content")] + public string Content { get; set; } = string.Empty; + [JsonPropertyName("role")] + public string Role { get; set; } = string.Empty; + } + + internal sealed class SearchRequest + { + [JsonPropertyName("app_id")] + public string? AppId { get; set; } + [JsonPropertyName("agent_id")] + public string? AgentId { get; set; } = null; + [JsonPropertyName("run_id")] + public string? RunId { get; set; } = null; + [JsonPropertyName("user_id")] + public string? UserId { get; set; } = null; + [JsonPropertyName("query")] + public string Query { get; set; } = string.Empty; + } + + internal sealed class SearchResponseItem + { + [JsonPropertyName("id")] + public string Id { get; set; } = string.Empty; + [JsonPropertyName("memory")] + public string Memory { get; set; } = string.Empty; + [JsonPropertyName("hash")] + public string Hash { get; set; } = string.Empty; + [JsonPropertyName("metadata")] + public object? Metadata { get; set; } + [JsonPropertyName("score")] + public double Score { get; set; } + [JsonPropertyName("created_at")] + public DateTime CreatedAt { get; set; } + [JsonPropertyName("updated_at")] + public DateTime? UpdatedAt { get; set; } + [JsonPropertyName("user_id")] + public string UserId { get; set; } = string.Empty; + [JsonPropertyName("app_id")] + public string? AppId { get; set; } + [JsonPropertyName("agent_id")] + public string AgentId { get; set; } = string.Empty; + [JsonPropertyName("session_id")] + public string RunId { get; set; } = string.Empty; + } +} + +[JsonSourceGenerationOptions(JsonSerializerDefaults.General, + UseStringEnumConverter = false, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + WriteIndented = false)] +[JsonSerializable(typeof(Mem0Client.CreateMemoryRequest))] +[JsonSerializable(typeof(Mem0Client.SearchRequest))] +[JsonSerializable(typeof(Mem0Client.SearchResponseItem[]))] +internal partial class Mem0SourceGenerationContext : JsonSerializerContext +{ +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs new file mode 100644 index 000000000000..c5010873bcbf --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponent.cs @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// A component that listens to messages added to the conversation thread, and automatically captures +/// information about the user. It is also able to retrieve this information and add it to the AI invocation context. +/// +/// +/// +/// Mem0 allows memories to be stored under one or more optional scopes: application, agent, thread, and user. +/// At least one scope must always be provided. +/// +/// +/// There are some special considerations when using thread as a scope. +/// A thread id may not be available at the time that this component is instantiated. +/// It is therefore possible to provide no thread id when instantiating this class and instead set +/// to . +/// The component will then capture a thread id when a thread is created or when messages are received +/// and use this thread id to scope the memories in mem0. +/// +/// +/// Note that this component will keep the current thread id in a private field for the duration of +/// the component's lifetime, and therefore using the component with multiple threads, with +/// set to is not supported. +/// +/// +[Experimental("SKEXP0130")] +public sealed class Mem0MemoryComponent : ConversationStatePart +{ + private const string DefaultContextPrompt = "Consider the following memories when answering user questions:"; + + private readonly string? _applicationId; + private readonly string? _agentId; + private readonly string? _threadId; + private string? _perOperationThreadId; + private readonly string? _userId; + private readonly bool _scopeToPerOperationThreadId; + private readonly string _contextPrompt; + + private readonly AIFunction[] _aIFunctions; + + private readonly Mem0Client _mem0Client; + + /// + /// Initializes a new instance of the class. + /// + /// The HTTP client used for making requests. + /// Options for configuring the component. + /// + /// The base address of the required mem0 service, and any authentication headers, should be set on the + /// already, when passed as a parameter here. E.g.: + /// + /// using var httpClient = new HttpClient(); + /// httpClient.BaseAddress = new Uri("https://api.mem0.ai"); + /// httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", "<Your APIKey>"); + /// new Mem0Client(httpClient); + /// + /// + public Mem0MemoryComponent(HttpClient httpClient, Mem0MemoryComponentOptions? options = default) + { + Verify.NotNull(httpClient); + + if (string.IsNullOrWhiteSpace(httpClient.BaseAddress?.AbsolutePath)) + { + throw new ArgumentException("The BaseAddress of the provided httpClient parameter must be set.", nameof(httpClient)); + } + + this._applicationId = options?.ApplicationId; + this._agentId = options?.AgentId; + this._threadId = options?.ThreadId; + this._userId = options?.UserId; + this._scopeToPerOperationThreadId = options?.ScopeToPerOperationThreadId ?? false; + this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; + + this._aIFunctions = [AIFunctionFactory.Create(this.ClearStoredUserFactsAsync)]; + + this._mem0Client = new(httpClient); + } + + /// + public override IReadOnlyCollection AIFunctions => this._aIFunctions; + + /// + public override Task OnThreadCreatedAsync(string? threadId, CancellationToken cancellationToken = default) + { + this.ValidatePerOperationThreadId(threadId); + + this._perOperationThreadId ??= threadId; + return Task.CompletedTask; + } + + /// + public override async Task OnNewMessageAsync(string? threadId, ChatMessage newMessage, CancellationToken cancellationToken = default) + { + Verify.NotNull(newMessage); + this.ValidatePerOperationThreadId(threadId); + + this._perOperationThreadId ??= threadId; + + if (!string.IsNullOrWhiteSpace(newMessage.Text)) + { + await this._mem0Client.CreateMemoryAsync( + this._applicationId, + this._agentId, + this._scopeToPerOperationThreadId ? this._perOperationThreadId : this._threadId, + this._userId, + newMessage.Text, + newMessage.Role.Value).ConfigureAwait(false); + } + } + + /// + public override async Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default) + { + Verify.NotNull(newMessages); + + string inputText = string.Join( + Environment.NewLine, + newMessages. + Where(m => m is not null && !string.IsNullOrWhiteSpace(m.Text)). + Select(m => m.Text)); + + var memories = await this._mem0Client.SearchAsync( + this._applicationId, + this._agentId, + this._scopeToPerOperationThreadId ? this._perOperationThreadId : this._threadId, + this._userId, + inputText).ConfigureAwait(false); + + var userInformation = string.Join(Environment.NewLine, memories); + return string.Join(Environment.NewLine, this._contextPrompt, userInformation); + } + + /// + /// Plugin method to clear user preferences stored in memory for the current agent/thread/user. + /// + /// A task that completes when the memory is cleared. + [Description("Deletes any user facts that are stored across multiple conversations.")] + public async Task ClearStoredUserFactsAsync() + { + await this._mem0Client.ClearMemoryAsync( + this._applicationId, + this._agentId, + this._scopeToPerOperationThreadId ? this._perOperationThreadId : this._threadId, + this._userId).ConfigureAwait(false); + } + + /// + /// Validate that we are not receiving a new thread id when the component has already received one before. + /// + /// The new thread id. + private void ValidatePerOperationThreadId(string? threadId) + { + if (this._scopeToPerOperationThreadId && !string.IsNullOrWhiteSpace(threadId) && this._perOperationThreadId != null && threadId != this._perOperationThreadId) + { + throw new InvalidOperationException($"The {nameof(Mem0MemoryComponent)} can only be used with one thread at a time when {nameof(Mem0MemoryComponentOptions.ScopeToPerOperationThreadId)} is set to true."); + } + } +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs new file mode 100644 index 000000000000..192446489475 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/Mem0/Mem0MemoryComponentOptions.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Options for the . +/// +[Experimental("SKEXP0130")] +public sealed class Mem0MemoryComponentOptions +{ + /// + /// Gets or sets an optional ID for the application to scope memories to. + /// + /// + /// If not set, the scope of the memories will span all applications. + /// + public string? ApplicationId { get; init; } + + /// + /// Gets or sets an optional ID for the agent to scope memories to. + /// + /// + /// If not set, the scope of the memories will span all agents. + /// + public string? AgentId { get; init; } + + /// + /// Gets or sets an optional ID for the thread to scope memories to. + /// + /// + /// This value will be overridden by any thread id provided to the methods of the . + /// + public string? ThreadId { get; init; } + + /// + /// Gets or sets an optional ID for the user to scope memories to. + /// + /// + /// If not set, the scope of the memories will span all users. + /// + public string? UserId { get; init; } + + /// + /// Gets or sets a value indicating whether memories should be scoped to the thread id provided on a per operation basis. + /// + /// + /// This setting is useful if the thread id is not known when the is instantiated, but + /// per thread scoping is desired. + /// If , and is not set, there will be no per thread scoping. + /// if , and is set, will be used for scoping. + /// If , the thread id will be set to the thread id of the current operation, regardless of the value of . + /// + public bool ScopeToPerOperationThreadId { get; init; } = false; + + /// + /// When providing the memories found in Mem0 to the AI model on invocation, this string is prefixed + /// to those memories, in order to provide some context to the model. + /// + /// + /// Defaults to "Consider the following memories when answering user questions:" + /// + public string? ContextPrompt { get; init; } +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/OptionalTextMemoryStore.cs b/dotnet/src/SemanticKernel.Core/Memory/OptionalTextMemoryStore.cs new file mode 100644 index 000000000000..af7797776681 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/OptionalTextMemoryStore.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Memory; + +namespace Microsoft.SemanticKernel.Agents.Memory; + +/// +/// Helper class to get a from the DI container if a name is provided +/// and if the memory store is registered. +/// Class implements no-op methods if no name is provided or the store is not registered. +/// +internal sealed class OptionalTextMemoryStore : TextMemoryStore +{ + private readonly TextMemoryStore? _textMemoryStore; + + /// + /// Initializes a new instance of the class. + /// + /// The kernel to try and get the named store from. + /// The name of the store to get from the DI container. + public OptionalTextMemoryStore(Kernel kernel, string? storeName) + { + if (storeName is not null) + { + this._textMemoryStore = kernel.Services.GetKeyedService(storeName); + } + } + + /// + public override Task GetMemoryAsync(string documentName, CancellationToken cancellationToken = default) + { + if (this._textMemoryStore is not null) + { + return this._textMemoryStore.GetMemoryAsync(documentName, cancellationToken); + } + + return Task.FromResult(null); + } + + /// + public override Task SaveMemoryAsync(string documentName, string memoryText, CancellationToken cancellationToken = default) + { + if (this._textMemoryStore is not null) + { + return this._textMemoryStore.SaveMemoryAsync(documentName, memoryText, cancellationToken); + } + + return Task.CompletedTask; + } + + /// + public override Task SaveMemoryAsync(string memoryText, CancellationToken cancellationToken = default) + { + if (this._textMemoryStore is not null) + { + return this._textMemoryStore.SaveMemoryAsync(memoryText, cancellationToken); + } + + return Task.CompletedTask; + } + + /// + public override IAsyncEnumerable SimilaritySearch(string query, CancellationToken cancellationToken = default) + { + if (this._textMemoryStore is not null) + { + return this._textMemoryStore.SimilaritySearch(query, cancellationToken); + } + + return AsyncEnumerable.Empty(); + } +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs new file mode 100644 index 000000000000..6564cca67df0 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponent.cs @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// A component that does a search based on any messages that the AI model is invoked with and injects the results into the AI model invocation context. +/// +[Experimental("SKEXP0130")] +public sealed class TextRagComponent : ConversationStatePart +{ + private const string DefaultPluginSearchFunctionName = "Search"; + private const string DefaultPluginSearchFunctionDescription = "Allows searching for additional information to help answer the user question."; + private const string DefaultContextPrompt = "Consider the following information when responding to the user:"; + private const string DefaultIncludeCitationsPrompt = "Include citations to the relevant information where it is referenced in the response."; + + private readonly ITextSearch _textSearch; + + private readonly AIFunction[] _aIFunctions; + + /// + /// Initializes a new instance of the class. + /// + /// The text search component to retrieve results from. + /// Options that configure the behavior of the component. + /// + public TextRagComponent(ITextSearch textSearch, TextRagComponentOptions? options = default) + { + Verify.NotNull(textSearch); + + this._textSearch = textSearch; + this.Options = options ?? new(); + + this._aIFunctions = + [ + AIFunctionFactory.Create( + this.SearchAsync, + name: this.Options.PluginFunctionName ?? DefaultPluginSearchFunctionName, + description: this.Options.PluginFunctionDescription ?? DefaultPluginSearchFunctionDescription) + ]; + } + + /// + /// Gets the options that have been configured for this component. + /// + public TextRagComponentOptions Options { get; } + + /// + public override IReadOnlyCollection AIFunctions + { + get + { + if (this.Options.SearchTime != TextRagComponentOptions.RagBehavior.ViaPlugin) + { + return Array.Empty(); + } + + return this._aIFunctions; + } + } + + /// + public override async Task OnModelInvokeAsync(ICollection newMessages, CancellationToken cancellationToken = default) + { + if (this.Options.SearchTime != TextRagComponentOptions.RagBehavior.BeforeAIInvoke) + { + return string.Empty; + } + + Verify.NotNull(newMessages); + + string input = string.Join("\n", newMessages.Where(m => m is not null).Select(m => m.Text)); + + var searchResults = await this._textSearch.GetTextSearchResultsAsync( + input, + new() { Top = this.Options.Top }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + var results = await searchResults.Results.ToListAsync(cancellationToken).ConfigureAwait(false); + + return this.FormatResults(results); + } + + /// + /// Plugin method to search the database on demand. + /// + [KernelFunction] + internal async Task SearchAsync(string userQuestion, CancellationToken cancellationToken = default) + { + var searchResults = await this._textSearch.GetTextSearchResultsAsync( + userQuestion, + new() { Top = this.Options.Top }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + var results = await searchResults.Results.ToListAsync(cancellationToken).ConfigureAwait(false); + + return this.FormatResults(results); + } + + /// + /// Format the results showing the content with source link and name for each result. + /// + /// The results to format. + /// The formatted results. + private string FormatResults(List results) + { + if (this.Options.ContextFormatter is not null) + { + return this.Options.ContextFormatter(results); + } + + if (results.Count == 0) + { + return string.Empty; + } + + var sb = new StringBuilder(); + sb.AppendLine(this.Options.ContextPrompt ?? DefaultContextPrompt); + for (int i = 0; i < results.Count; i++) + { + var result = results[i]; + sb.AppendLine($"Item {i + 1}:"); + sb.AppendLine($"Name: {result.Name}"); + sb.AppendLine($"Link: {result.Link}"); + sb.AppendLine($"Contents: {result.Value}"); + } + sb.AppendLine(this.Options.IncludeCitationsPrompt ?? DefaultIncludeCitationsPrompt); + sb.AppendLine(); + return sb.ToString(); + } +} + +[JsonSourceGenerationOptions(JsonSerializerDefaults.General, + UseStringEnumConverter = false, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + WriteIndented = false)] +[JsonSerializable(typeof(List))] +internal partial class TextRagSourceGenerationContext : JsonSerializerContext +{ +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs new file mode 100644 index 000000000000..1bb97e4ac729 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagComponentOptions.cs @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Contains options for the . +/// +[Experimental("SKEXP0130")] +public sealed class TextRagComponentOptions +{ + private int _top = 3; + + /// + /// Maximum number of results to return from the similarity search. + /// + /// The value must be greater than 0. + /// The default value is 3 if not set. + public int Top + { + get => this._top; + init + { + if (value < 1) + { + throw new ArgumentOutOfRangeException(nameof(value), "Top must be greater than 0."); + } + + this._top = value; + } + } + + /// + /// Gets or sets the time at which the text search is performed. + /// + public RagBehavior SearchTime { get; init; } = RagBehavior.BeforeAIInvoke; + + /// + /// Gets or sets the name of the plugin method that will be made available for searching + /// if the option is set to . + /// + /// + /// Defaults to "Search" if not set. + /// + public string? PluginFunctionName { get; init; } + + /// + /// Gets or sets the description of the plugin method that will be made available for searching + /// if the option is set to . + /// + /// + /// Defaults to "Allows searching for additional information to help answer the user question." if not set. + /// + public string? PluginFunctionDescription { get; init; } + + /// + /// When providing the text chunks to the AI model on invocation, this string is prefixed + /// to those chunks, in order to provide some context to the model. + /// + /// + /// Defaults to "Consider the following information when responding to the user:" + /// + public string? ContextPrompt { get; init; } + + /// + /// When providing the text chunks to the AI model on invocation, this string is postfixed + /// to those chunks, in order to instruct the model to include citations. + /// + /// + /// Defaults to "Include citations to the relevant information where it is referenced in the response.:" + /// + public string? IncludeCitationsPrompt { get; init; } + + /// + /// Optional delegate to override the default context creation implementation. + /// + /// + /// + /// If provided, this delegate will be used to do the following: + /// 1. Create the output context provided by the when invoking the AI model. + /// 2. Create the response text when invoking the component via a plugin. + /// + /// + /// Note that the delegate should include the context prompt and the + /// include citations prompt if they are required in the output. + /// The and settings + /// will not be used if providing this delegate. + /// + /// + public ContextFormatterType? ContextFormatter { get; init; } + + /// + /// Choices for controlling the behavior of the . + /// + public enum RagBehavior + { + /// + /// A search is performed each time that the model/agent is invoked just before invocation + /// and the results are provided to the model/agent via the invocation context. + /// + BeforeAIInvoke, + + /// + /// A search may be performed by the model/agent on demand via function calling. + /// + ViaPlugin + } + + /// + /// Delegate type for formatting the output context for the component. + /// + /// The results returned by the text search. + /// The formatted context. + public delegate string ContextFormatterType(List results); +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs new file mode 100644 index 000000000000..e3b221fc74d5 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagDocument.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Memory.TextRag; + +/// +/// Represents a document that can be used for Retrieval Augmented Generation (RAG). +/// +[Experimental("SKEXP0130")] +public sealed class TextRagDocument +{ + /// + /// Gets or sets an optional list of namespaces that the document should belong to. + /// + /// + /// A namespace is a logical grouping of documents, e.g. may include a group id to scope the document to a specific group of users. + /// + public List Namespaces { get; set; } = []; + + /// + /// Gets or sets the content as text. + /// + public string? Text { get; set; } + + /// + /// Gets or sets an optional source ID for the document. + /// + /// + /// This ID should be unique within the collection that the document is stored in, and can + /// be used to map back to the source artifact for this document. + /// If updates need to be made later or the source document was deleted and this document + /// also needs to be deleted, this id can be used to find the document again. + /// + public string? SourceId { get; set; } + + /// + /// Gets or sets an optional name for the source document. + /// + /// + /// This can be used to provide display names for citation links when the document is referenced as + /// part of a response to a query. + /// + public string? SourceName { get; set; } + + /// + /// Gets or sets an optional link back to the source of the document. + /// + /// + /// This can be used to provide citation links when the document is referenced as + /// part of a response to a query. + /// + public string? SourceLink { get; set; } +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs new file mode 100644 index 000000000000..143af21f34a5 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStore.cs @@ -0,0 +1,380 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Data; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.Memory.TextRag; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// A class that allows for easy storage and retrieval of documents in a Vector Store for Retrieval Augmented Generation (RAG). +/// +/// The key type to use with the vector store. +[Experimental("SKEXP0130")] +public sealed class TextRagStore : ITextSearch, IDisposable + where TKey : notnull +{ + private readonly IVectorStore _vectorStore; + private readonly ITextEmbeddingGenerationService _textEmbeddingGenerationService; + private readonly int _vectorDimensions; + private readonly TextRagStoreOptions _options; + + private readonly Lazy>> _vectorStoreRecordCollection; + private readonly SemaphoreSlim _collectionInitializationLock = new(1, 1); + private bool _collectionInitialized = false; + private bool _disposedValue; + + /// + /// Initializes a new instance of the class. + /// + /// The vector store to store and read the memories from. + /// The service to use for generating embeddings for the memories. + /// The name of the collection in the vector store to store and read the memories from. + /// The number of dimensions to use for the memory embeddings. + /// Options to configure the behavior of this class. + /// Thrown if the key type provided is not supported. + public TextRagStore( + IVectorStore vectorStore, + ITextEmbeddingGenerationService textEmbeddingGenerationService, + string collectionName, + int vectorDimensions, + TextRagStoreOptions? options = default) + { + // Verify + Verify.NotNull(vectorStore); + Verify.NotNull(textEmbeddingGenerationService); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.True(vectorDimensions > 0, "Vector dimensions must be greater than 0"); + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(Guid)) + { + throw new NotSupportedException($"Unsupported key of type '{typeof(TKey).Name}'"); + } + + if (typeof(TKey) != typeof(string) && options?.UseSourceIdAsPrimaryKey is true) + { + throw new NotSupportedException($"The {nameof(TextRagStoreOptions.UseSourceIdAsPrimaryKey)} option can only be used when the key type is 'string'."); + } + + // Assign + this._vectorStore = vectorStore; + this._textEmbeddingGenerationService = textEmbeddingGenerationService; + this._vectorDimensions = vectorDimensions; + this._options = options ?? new TextRagStoreOptions(); + + // Create a definition so that we can use the dimensions provided at runtime. + VectorStoreRecordDefinition ragDocumentDefinition = new() + { + Properties = new List() + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new VectorStoreRecordDataProperty("Namespaces", typeof(List)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("SourceId", typeof(string)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Text", typeof(string)), + new VectorStoreRecordDataProperty("SourceName", typeof(string)), + new VectorStoreRecordDataProperty("SourceReference", typeof(string)), + new VectorStoreRecordVectorProperty("TextEmbedding", typeof(ReadOnlyMemory)) { Dimensions = vectorDimensions }, + } + }; + + this._vectorStoreRecordCollection = new Lazy>>(() => + this._vectorStore.GetCollection>(collectionName, ragDocumentDefinition)); + } + + /// + /// Upserts a batch of documents into the vector store. + /// + /// The documents to upload. + /// Optional options to control the upsert behavior. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the documents have been upserted. + public async Task UpsertDocumentsAsync(IEnumerable documents, TextRagStoreUpsertOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(documents); + + var vectorStoreRecordCollection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); + + var storageDocumentsTasks = documents.Select(async document => + { + if (document is null) + { + throw new ArgumentNullException(nameof(documents), "One of the provided documents is null."); + } + + // Without text we cannot generate a vector. + if (string.IsNullOrWhiteSpace(document.Text)) + { + throw new ArgumentException($"The {nameof(TextRagDocument.Text)} property must be set.", nameof(document)); + } + + // If we aren't persisting the text, we need a source id or link to refer back to the original document. + if (options?.PersistSourceText is false && string.IsNullOrWhiteSpace(document.SourceId) && string.IsNullOrWhiteSpace(document.SourceLink)) + { + throw new ArgumentException($"Either the {nameof(TextRagDocument.SourceId)} or {nameof(TextRagDocument.SourceLink)} properties must be set when the {nameof(TextRagStoreUpsertOptions.PersistSourceText)} setting is false.", nameof(document)); + } + + var key = GenerateUniqueKey(this._options.UseSourceIdAsPrimaryKey ?? false ? document.SourceId : null); + var textEmbeddings = await this._textEmbeddingGenerationService.GenerateEmbeddingsAsync([document.Text!]).ConfigureAwait(false); + + return new TextRagStorageDocument + { + Key = key, + Namespaces = document.Namespaces, + SourceId = document.SourceId, + Text = options?.PersistSourceText is false ? null : document.Text, + SourceName = document.SourceName, + SourceLink = document.SourceLink, + TextEmbedding = textEmbeddings.Single() + }; + }); + + var storageDocuments = await Task.WhenAll(storageDocumentsTasks).ConfigureAwait(false); + await vectorStoreRecordCollection.UpsertBatchAsync(storageDocuments, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); + } + + /// + public async Task> SearchAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); + + return new(searchResult.Select(x => x.Text ?? string.Empty).ToAsyncEnumerable()); + } + + /// + public async Task> GetTextSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); + + var results = searchResult.Select(x => new TextSearchResult(x.Text ?? string.Empty) { Name = x.SourceName, Link = x.SourceLink }); + return new(searchResult.Select(x => + new TextSearchResult(x.Text ?? string.Empty) + { + Name = x.SourceName, + Link = x.SourceLink + }).ToAsyncEnumerable()); + } + + /// + public async Task> GetSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + var searchResult = await this.SearchInternalAsync(query, searchOptions, cancellationToken).ConfigureAwait(false); + return new(searchResult.Select(x => (object)x).ToAsyncEnumerable()); + } + + /// + /// Internal search implementation with hydration of id / link only storage. + /// + /// The text query to find similar documents to. + /// Search options. + /// The to monitor for cancellation requests. The default is . + /// The search results. + private async Task>> SearchInternalAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + var vectorStoreRecordCollection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); + + // Optional filter to limit the search to a specific namespace. + Expression, bool>>? filter = string.IsNullOrWhiteSpace(this._options.SearchNamespace) ? null : x => x.Namespaces.Contains(this._options.SearchNamespace); + + // Generate the vector for the query and search. + var vector = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(query, cancellationToken: cancellationToken).ConfigureAwait(false); + var searchResult = await vectorStoreRecordCollection.VectorizedSearchAsync( + vector, + options: new() + { + Top = searchOptions?.Top ?? 3, + Filter = filter, + }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + // Retrieve the documents from the search results. + var retrievedDocs = await searchResult + .Results + .SelectAsync(x => x.Record, cancellationToken) + .ToListAsync(cancellationToken) + .ConfigureAwait(false); + + // Find any source ids and links for which the text needs to be retrieved. + var sourceIdsToRetrieve = retrievedDocs + .Where(x => string.IsNullOrWhiteSpace(x.Text)) + .Select(x => (x.SourceId, x.SourceLink)) + .ToList(); + + if (sourceIdsToRetrieve.Count > 0) + { + if (this._options.SourceRetrievalCallback is null) + { + throw new InvalidOperationException($"The {nameof(TextRagStoreOptions.SourceRetrievalCallback)} option must be set if retrieving documents without stored text."); + } + + var retrievedText = await this._options.SourceRetrievalCallback(sourceIdsToRetrieve).ConfigureAwait(false); + + if (retrievedText is null) + { + throw new InvalidOperationException($"The {nameof(TextRagStoreOptions.SourceRetrievalCallback)} must return a non-null value."); + } + + // Update the retrieved documents with the retrieved text. + retrievedDocs = retrievedDocs.GroupJoin( + retrievedText, + retrievedDocs => (retrievedDocs.SourceId, retrievedDocs.SourceLink), + retrievedText => (retrievedText.sourceId, retrievedText.sourceLink), + (retrievedDoc, retrievedText) => (retrievedDoc, retrievedText)) + .SelectMany( + joinedSet => joinedSet.retrievedText.DefaultIfEmpty(), + (combined, retrievedText) => + { + combined.retrievedDoc.Text = retrievedText.text ?? combined.retrievedDoc.Text; + return combined.retrievedDoc; + }) + .ToList(); + } + + return retrievedDocs; + } + + /// + /// Thread safe method to get the collection and ensure that it is created at least once. + /// + /// The to monitor for cancellation requests. The default is . + /// The created collection. + private async Task>> EnsureCollectionExistsAsync(CancellationToken cancellationToken) + { + var vectorStoreRecordCollection = this._vectorStoreRecordCollection.Value; + + // Return immediately if the collection is already created, no need to do any locking in this case. + if (this._collectionInitialized) + { + return vectorStoreRecordCollection; + } + + // Wait on a lock to ensure that only one thread can create the collection. + await this._collectionInitializationLock.WaitAsync(cancellationToken).ConfigureAwait(false); + + // If multiple threads waited on the lock, and the first already created the collection, + // we can return immediately without doing any work in subsequent threads. + if (this._collectionInitialized) + { + this._collectionInitializationLock.Release(); + return vectorStoreRecordCollection; + } + + // Only the winning thread should reach this point and create the collection. + try + { + await vectorStoreRecordCollection.CreateCollectionIfNotExistsAsync(cancellationToken).ConfigureAwait(false); + this._collectionInitialized = true; + } + finally + { + this._collectionInitializationLock.Release(); + } + + return vectorStoreRecordCollection; + } + + /// + /// Generates a unique key for the RAG document. + /// + /// Source id of the source document for this RAG document. + /// The type of the key to use, since different databases require/support different keys. + /// A new unique key. + /// Thrown if the requested key type is not supported. + private static TDocumentKey GenerateUniqueKey(string? sourceId) + => typeof(TDocumentKey) switch + { + _ when typeof(TDocumentKey) == typeof(string) && !string.IsNullOrWhiteSpace(sourceId) => (TDocumentKey)(object)sourceId!, + _ when typeof(TDocumentKey) == typeof(string) => (TDocumentKey)(object)Guid.NewGuid().ToString(), + _ when typeof(TDocumentKey) == typeof(Guid) => (TDocumentKey)(object)Guid.NewGuid(), + + _ => throw new NotSupportedException($"Unsupported key of type '{typeof(TDocumentKey).Name}'") + }; + + /// + private void Dispose(bool disposing) + { + if (!this._disposedValue) + { + if (disposing) + { + this._collectionInitializationLock.Dispose(); + } + + this._disposedValue = true; + } + } + + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + /// + /// The data model to use for storing RAG documents in the vector store. + /// + /// The type of the key to use, since different databases require/support different keys. + internal sealed class TextRagStorageDocument + { + /// + /// Gets or sets a unique identifier for the memory document. + /// + public TDocumentKey Key { get; set; } = default!; + + /// + /// Gets or sets an optional list of namespaces that the document should belong to. + /// + /// + /// A namespace is a logical grouping of documents, e.g. may include a group id to scope the document to a specific group of users. + /// + public List Namespaces { get; set; } = []; + + /// + /// Gets or sets the content as text. + /// + public string? Text { get; set; } + + /// + /// Gets or sets an optional source ID for the document. + /// + /// + /// This ID should be unique within the collection that the document is stored in, and can + /// be used to map back to the source artifact for this document. + /// If updates need to be made later or the source document was deleted and this document + /// also needs to be deleted, this id can be used to find the document again. + /// + public string? SourceId { get; set; } + + /// + /// Gets or sets an optional name for the source document. + /// + /// + /// This can be used to provide display names for citation links when the document is referenced as + /// part of a response to a query. + /// + public string? SourceName { get; set; } + + /// + /// Gets or sets an optional link back to the source of the document. + /// + /// + /// This can be used to provide citation links when the document is referenced as + /// part of a response to a query. + /// + public string? SourceLink { get; set; } + + /// + /// Gets or sets the embedding for the text content. + /// + public ReadOnlyMemory TextEmbedding { get; set; } + } +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs new file mode 100644 index 000000000000..6c2ecf3254a6 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreOptions.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace Microsoft.SemanticKernel.Memory; + +/// +/// Contains options for the . +/// +public sealed class TextRagStoreOptions +{ + /// + /// Gets or sets an optional namespace to pre-filter the possible + /// records with when doing a vector search. + /// + public string? SearchNamespace { get; init; } + + /// + /// Gets or sets a value indicating whether to use the source ID as the primary key for records. + /// + /// + /// + /// Using the source ID as the primary key allows for easy updates from the source for any changed + /// records, since those records can just be upserted again, and will overwrite the previous version + /// of the same record. + /// + /// + /// This setting can only be used when the chosen key type is a string. + /// + /// + /// + /// Defaults to false if not set. + /// + public bool? UseSourceIdAsPrimaryKey { get; init; } + + /// + /// Gets or sets an optional callback to load the source text from the source id or source link + /// if the source text is not persisted in the database. + /// + public SourceRetriever? SourceRetrievalCallback { get; init; } + + /// + /// Delegate type for loading the source text from the source id or source link + /// if the source text is not persisted in the database. + /// + /// The ids and links of the text to load. + /// The source text with the source id or source link. + public delegate Task> SourceRetriever(List<(string? sourceId, string? sourceLink)> sourceIds); +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreUpsertOptions.cs b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreUpsertOptions.cs new file mode 100644 index 000000000000..83b20179c064 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/TextRag/TextRagStoreUpsertOptions.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Memory.TextRag; + +/// +/// Contains options for . +/// +public sealed class TextRagStoreUpsertOptions +{ + /// + /// Gets or sets a value indicating whether the source text should be persisted in the database. + /// + /// + /// Defaults to if not set. + /// + public bool? PersistSourceText { get; init; } +} diff --git a/dotnet/src/SemanticKernel.Core/Memory/VectorDataTextMemoryStore.cs b/dotnet/src/SemanticKernel.Core/Memory/VectorDataTextMemoryStore.cs new file mode 100644 index 000000000000..c3681a1c7745 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Memory/VectorDataTextMemoryStore.cs @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Embeddings; + +namespace Microsoft.SemanticKernel.Memory; + +#pragma warning disable SKEXP0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + +/// +/// Class to store and retrieve text-based memories to and from a vector store. +/// +/// The key type to use with the vector store. +[Experimental("SKEXP0001")] +public class VectorDataTextMemoryStore : TextMemoryStore, IDisposable + where TKey : notnull +{ + private readonly IVectorStore _vectorStore; + private readonly ITextEmbeddingGenerationService _textEmbeddingGenerationService; + private readonly string _storageNamespace; + private readonly int _vectorDimensions; + private readonly Lazy>> _vectorStoreRecordCollection; + private readonly SemaphoreSlim _collectionInitializationLock = new(1, 1); + private bool _collectionInitialized = false; + private bool _disposedValue; + + /// + /// Initializes a new instance of the class. + /// + /// The vector store to store and read the memories from. + /// The service to use for generating embeddings for the memories. + /// The name of the collection in the vector store to store and read the memories from. + /// The namespace to scope memories to within the collection. + /// The number of dimensions to use for the memory embeddings. + /// Thrown if the key type provided is not supported. + public VectorDataTextMemoryStore(IVectorStore vectorStore, ITextEmbeddingGenerationService textEmbeddingGenerationService, string collectionName, string storageNamespace, int vectorDimensions) + { + Verify.NotNull(vectorStore); + Verify.NotNull(textEmbeddingGenerationService); + Verify.NotNullOrWhiteSpace(collectionName); + Verify.NotNullOrWhiteSpace(storageNamespace); + Verify.True(vectorDimensions > 0, "Vector dimensions must be greater than 0"); + + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(Guid)) + { + throw new NotSupportedException($"Unsupported key of type '{typeof(TKey).Name}'"); + } + + VectorStoreRecordDefinition memoryDocumentDefinition = new() + { + Properties = new List() + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new VectorStoreRecordDataProperty("Namespace", typeof(string)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Name", typeof(string)), + new VectorStoreRecordDataProperty("Category", typeof(string)), + new VectorStoreRecordDataProperty("MemoryText", typeof(string)), + new VectorStoreRecordVectorProperty("MemoryTextEmbedding", typeof(ReadOnlyMemory)) { Dimensions = vectorDimensions }, + } + }; + + this._vectorStore = vectorStore; + this._textEmbeddingGenerationService = textEmbeddingGenerationService; + this._storageNamespace = storageNamespace; + this._vectorDimensions = vectorDimensions; + this._vectorStoreRecordCollection = new Lazy>>(() => + this._vectorStore.GetCollection>(collectionName, memoryDocumentDefinition)); + } + + /// + public override async Task GetMemoryAsync(string documentName, CancellationToken cancellationToken = default) + { + var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + + // If the database supports string keys, get using the namespace + document name. + if (typeof(TKey) == typeof(string)) + { + var namespaceKey = $"{this._storageNamespace}:{documentName}"; + + var record = await vectorStoreRecordCollection.GetAsync((TKey)(object)namespaceKey, cancellationToken: cancellationToken).ConfigureAwait(false); + return record?.MemoryText; + } + + // Otherwise do a search with a filter on the document name and namespace. + ReadOnlyMemory vector = new(new float[this._vectorDimensions]); + var searchResult = await vectorStoreRecordCollection.VectorizedSearchAsync( + vector, + options: new() + { + Top = 1, + Filter = x => x.Name == documentName && x.Namespace == this._storageNamespace, + }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + var results = await searchResult.Results.ToListAsync(cancellationToken).ConfigureAwait(false); + + if (results.Count == 0) + { + return null; + } + + return results[0].Record.MemoryText; + } + + /// + public override async IAsyncEnumerable SimilaritySearch(string query, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + + var vector = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(query, cancellationToken: cancellationToken).ConfigureAwait(false); + var searchResult = await vectorStoreRecordCollection.VectorizedSearchAsync( + vector, + options: new() + { + Top = 3, + Filter = x => x.Namespace == this._storageNamespace, + }, + cancellationToken: cancellationToken).ConfigureAwait(false); + + await foreach (var result in searchResult.Results.ConfigureAwait(false)) + { + yield return result.Record.MemoryText; + } + } + + /// + public override async Task SaveMemoryAsync(string documentName, string memoryText, CancellationToken cancellationToken = default) + { + var vectorStoreRecordCollection = await this.EnsureCollectionCreatedAsync(cancellationToken).ConfigureAwait(false); + + var vector = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync( + string.IsNullOrWhiteSpace(memoryText) ? "Empty" : memoryText, + cancellationToken: cancellationToken).ConfigureAwait(false); + + var memoryDocument = new MemoryDocument + { + Key = GenerateUniqueKey(this._storageNamespace, documentName), + Namespace = this._storageNamespace, + Name = documentName, + MemoryText = memoryText, + MemoryTextEmbedding = vector, + }; + + await vectorStoreRecordCollection.UpsertAsync(memoryDocument, cancellationToken: cancellationToken).ConfigureAwait(false); + } + + /// + public override Task SaveMemoryAsync(string memoryText, CancellationToken cancellationToken = default) + { + return this.SaveMemoryAsync(null!, memoryText, cancellationToken); + } + + /// + /// Thread safe method to get the collection and ensure that it is created at least once. + /// + /// The to monitor for cancellation requests. The default is . + /// The created collection. + private async Task>> EnsureCollectionCreatedAsync(CancellationToken cancellationToken) + { + var vectorStoreRecordCollection = this._vectorStoreRecordCollection.Value; + + // Return immediately if the collection is already created, no need to do any locking in this case. + if (this._collectionInitialized) + { + return vectorStoreRecordCollection; + } + + // Wait on a lock to ensure that only one thread can create the collection. + await this._collectionInitializationLock.WaitAsync(cancellationToken).ConfigureAwait(false); + + // If multiple threads waited on the lock, and the first already created the collection, + // we can return immediately without doing any work in subsequent threads. + if (this._collectionInitialized) + { + this._collectionInitializationLock.Release(); + return vectorStoreRecordCollection; + } + + // Only the winning thread should reach this point and create the collection. + try + { + await vectorStoreRecordCollection.CreateCollectionIfNotExistsAsync(cancellationToken).ConfigureAwait(false); + this._collectionInitialized = true; + } + finally + { + this._collectionInitializationLock.Release(); + } + + return vectorStoreRecordCollection; + } + + /// + /// Generates a unique key for the memory document. + /// + /// Storage namespace to use for string keys. + /// An optional document name to use for the key if the database supports string keys. + /// The type of the key to use, since different databases require/support different keys. + /// A new unique key. + /// Thrown if the requested key type is not supported. + private static TDocumentKey GenerateUniqueKey(string storageNamespace, string? documentName) + => typeof(TDocumentKey) switch + { + _ when typeof(TDocumentKey) == typeof(string) && documentName is not null => (TDocumentKey)(object)$"{storageNamespace}:{documentName}", + _ when typeof(TDocumentKey) == typeof(string) => (TDocumentKey)(object)Guid.NewGuid().ToString(), + _ when typeof(TDocumentKey) == typeof(Guid) => (TDocumentKey)(object)Guid.NewGuid(), + + _ => throw new NotSupportedException($"Unsupported key of type '{typeof(TDocumentKey).Name}'") + }; + + /// + /// The data model to use for storing memory documents in the vector store. + /// + /// The type of the key to use, since different databases require/support different keys. + private sealed class MemoryDocument + { + /// + /// Gets or sets a unique identifier for the memory document. + /// + public TDocumentKey Key { get; set; } = default!; + + /// + /// Gets or sets the namespace for the memory document. + /// + /// + /// A namespace is a logical grouping of memory documents, e.g. may include a user id to scope the memory to a specific user. + /// + public string Namespace { get; set; } = string.Empty; + + /// + /// Gets or sets an optional name for the memory document. + /// + public string Name { get; set; } = string.Empty; + + /// + /// Gets or sets an optional category for the memory document. + /// + public string Category { get; set; } = string.Empty; + + /// + /// Gets or sets the actual memory content as text. + /// + public string MemoryText { get; set; } = string.Empty; + + public ReadOnlyMemory MemoryTextEmbedding { get; set; } + } + + /// + protected virtual void Dispose(bool disposing) + { + if (!this._disposedValue) + { + if (disposing) + { + this._collectionInitializationLock.Dispose(); + } + + this._disposedValue = true; + } + } + + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } +} diff --git a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj index 2a5d5d03d961..c4aa91cc2359 100644 --- a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj +++ b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj @@ -6,7 +6,7 @@ Microsoft.SemanticKernel net8.0;netstandard2.0 true - $(NoWarn);SKEXP0001,SKEXP0120 + $(NoWarn);NU5104;SKEXP0001,SKEXP0120 true true @@ -34,6 +34,7 @@ + @@ -45,6 +46,7 @@ + diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartTests.cs new file mode 100644 index 000000000000..4ee007b2d82e --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartTests.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Contains tests for the class. +/// +public class ConversationStatePartTests +{ + [Fact] + public void AIFunctionsBaseImplementationIsEmpty() + { + // Arrange. + var mockPart = new Mock() { CallBase = true }; + + // Act. + var functions = mockPart.Object.AIFunctions; + + // Assert. + Assert.NotNull(functions); + Assert.Empty(functions); + } + + [Fact] + public async Task OnThreadCreatedBaseImplementationSucceeds() + { + // Arrange. + var mockPart = new Mock() { CallBase = true }; + + // Act & Assert. + await mockPart.Object.OnThreadCreatedAsync("threadId", CancellationToken.None); + } + + [Fact] + public async Task OnNewMessageBaseImplementationSucceeds() + { + // Arrange. + var mockPart = new Mock() { CallBase = true }; + var newMessage = new ChatMessage(ChatRole.User, "Hello"); + + // Act & Assert. + await mockPart.Object.OnNewMessageAsync("threadId", newMessage, CancellationToken.None); + } + + [Fact] + public async Task OnThreadDeleteBaseImplementationSucceeds() + { + // Arrange. + var mockPart = new Mock() { CallBase = true }; + + // Act & Assert. + await mockPart.Object.OnThreadDeleteAsync("threadId", CancellationToken.None); + } + + [Fact] + public async Task OnSuspendBaseImplementationSucceeds() + { + // Arrange. + var mockPart = new Mock() { CallBase = true }; + + // Act & Assert. + await mockPart.Object.OnSuspendAsync("threadId", CancellationToken.None); + } + + [Fact] + public async Task OnResumeBaseImplementationSucceeds() + { + // Arrange. + var mockPart = new Mock() { CallBase = true }; + + // Act & Assert. + await mockPart.Object.OnResumeAsync("threadId", CancellationToken.None); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerExtensionsTests.cs new file mode 100644 index 000000000000..b074b44d3b09 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerExtensionsTests.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Tests for the ConversationStatePartsManagerExtensions class. +/// +public class ConversationStatePartsManagerExtensionsTests +{ + [Fact] + public async Task OnNewMessageShouldConvertMessageAndInvokeRegisteredPartsAsync() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var partMock = new Mock(); + manager.Add(partMock.Object); + + var newMessage = new ChatMessageContent(AuthorRole.User, "Test Message"); + + // Act + await manager.OnNewMessageAsync("test-thread-id", newMessage); + + // Assert + partMock.Verify(x => x.OnNewMessageAsync("test-thread-id", It.Is(m => m.Text == "Test Message" && m.Role == ChatRole.User), It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnAIInvocationShouldConvertMessagesInvokeRegisteredPartsAsync() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var partMock = new Mock(); + manager.Add(partMock.Object); + + var messages = new List + { + new(AuthorRole.User, "Message 1"), + new(AuthorRole.Assistant, "Message 2") + }; + + partMock + .Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync("Combined Context"); + + // Act + var result = await manager.OnModelInvokeAsync(messages); + + // Assert + Assert.Equal("Combined Context", result); + partMock.Verify(x => x.OnModelInvokeAsync(It.Is>(m => m.Count == 2), It.IsAny()), Times.Once); + } + + [Fact] + public void RegisterPluginsShouldConvertAIFunctionsAndRegisterAsPlugins() + { + // Arrange + var kernel = new Kernel(); + var manager = new ConversationStatePartsManager(); + var partMock = new Mock(); + var aiFunctionMock = AIFunctionFactory.Create(() => "Hello", "TestFunction"); + partMock + .Setup(x => x.AIFunctions) + .Returns(new List { aiFunctionMock }); + manager.Add(partMock.Object); + + // Act + manager.RegisterPlugins(kernel); + + // Assert + var registeredFunction = kernel.Plugins.GetFunction("Tools", aiFunctionMock.Name); + Assert.NotNull(registeredFunction); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerTests.cs new file mode 100644 index 000000000000..e76b13eda2e2 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/ConversationStatePartsManagerTests.cs @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Contains tests for the class. +/// +public class ConversationStatePartsManagerTests +{ + [Fact] + public void ConstructorShouldInitializeEmptyPartsList() + { + // Act + var manager = new ConversationStatePartsManager(); + + // Assert + Assert.NotNull(manager.Parts); + Assert.Empty(manager.Parts); + } + + [Fact] + public void ConstructorShouldInitializeWithProvidedParts() + { + // Arrange + var mockPart = new Mock(); + + // Act + var manager = new ConversationStatePartsManager(new[] { mockPart.Object }); + + // Assert + Assert.Single(manager.Parts); + Assert.Contains(mockPart.Object, manager.Parts); + } + + [Fact] + public void AddShouldRegisterNewPart() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + + // Act + manager.Add(mockPart.Object); + + // Assert + Assert.Single(manager.Parts); + Assert.Contains(mockPart.Object, manager.Parts); + } + + [Fact] + public void AddFromServiceProviderShouldRegisterPartsFromServiceProvider() + { + // Arrange + var serviceCollection = new ServiceCollection(); + var mockPart = new Mock(); + serviceCollection.AddSingleton(mockPart.Object); + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var manager = new ConversationStatePartsManager(); + + // Act + manager.AddFromServiceProvider(serviceProvider); + + // Assert + Assert.Single(manager.Parts); + Assert.Contains(mockPart.Object, manager.Parts); + } + + [Fact] + public async Task OnThreadCreatedAsyncShouldCallOnThreadCreatedOnAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + manager.Add(mockPart.Object); + + // Act + await manager.OnThreadCreatedAsync("test-thread-id"); + + // Assert + mockPart.Verify(x => x.OnThreadCreatedAsync("test-thread-id", It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnThreadDeleteAsyncShouldCallOnThreadDeleteOnAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + manager.Add(mockPart.Object); + + // Act + await manager.OnThreadDeleteAsync("test-thread-id"); + + // Assert + mockPart.Verify(x => x.OnThreadDeleteAsync("test-thread-id", It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnNewMessageAsyncShouldCallOnNewMessageOnAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + var message = new ChatMessage(ChatRole.User, "Hello"); + manager.Add(mockPart.Object); + + // Act + await manager.OnNewMessageAsync("test-thread-id", message); + + // Assert + mockPart.Verify(x => x.OnNewMessageAsync("test-thread-id", message, It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnAIInvocationAsyncShouldAggregateContextsFromAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart1 = new Mock(); + var mockPart2 = new Mock(); + mockPart1.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync("Context1"); + mockPart2.Setup(x => x.OnModelInvokeAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync("Context2"); + manager.Add(mockPart1.Object); + manager.Add(mockPart2.Object); + + var messages = new List(); + + // Act + var result = await manager.OnModelInvokeAsync(messages); + + // Assert + Assert.Equal("Context1\nContext2", result); + } + + [Fact] + public async Task OnSuspendAsyncShouldCallOnSuspendOnAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + manager.Add(mockPart.Object); + + // Act + await manager.OnSuspendAsync("test-thread-id"); + + // Assert + mockPart.Verify(x => x.OnSuspendAsync("test-thread-id", It.IsAny()), Times.Once); + } + + [Fact] + public async Task OnResumeAsyncShouldCallOnResumeOnAllParts() + { + // Arrange + var manager = new ConversationStatePartsManager(); + var mockPart = new Mock(); + manager.Add(mockPart.Object); + + // Act + await manager.OnResumeAsync("test-thread-id"); + + // Assert + mockPart.Verify(x => x.OnResumeAsync("test-thread-id", It.IsAny()), Times.Once); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/Mem0MemoryComponentTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/Mem0MemoryComponentTests.cs new file mode 100644 index 000000000000..3929d220aefd --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/Mem0MemoryComponentTests.cs @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft. All rights reserved. + +#pragma warning disable CA1054 // URI-like parameters should not be strings + +using System; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.Memory; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Contains tests for the class. +/// +public class Mem0MemoryComponentTests : IDisposable +{ + private readonly HttpClient _httpClient; + private readonly Mock _mockMessageHandler; + private bool _disposedValue; + + public Mem0MemoryComponentTests() + { + this._mockMessageHandler = new Mock() { CallBase = true }; + this._httpClient = new HttpClient(this._mockMessageHandler.Object) + { + BaseAddress = new Uri("https://localhost/fakepath") + }; + } + + [Fact] + public void ValidatesHttpClientBaseAddress() + { + // Arrange + using var httpClientWithoutBaseAddress = new HttpClient(); + + // Act & Assert + var exception = Assert.Throws(() => + { + new Mem0MemoryComponent(httpClientWithoutBaseAddress); + }); + + Assert.Equal("The BaseAddress of the provided httpClient parameter must be set. (Parameter 'httpClient')", exception.Message); + } + + [Fact] + public void AIFunctionsAreSetCorrectly() + { + // Arrange + var sut = new Mem0MemoryComponent(this._httpClient, new() { ApplicationId = "test-app-id" }); + + // Act + var aiFunctions = sut.AIFunctions; + + // Assert + Assert.NotNull(aiFunctions); + Assert.Single(aiFunctions); + Assert.Equal("ClearStoredUserFacts", aiFunctions.First().Name); + } + + [Theory] + [InlineData(false, "test-thread-id")] + [InlineData(true, "test-thread-id-1")] + public async Task PostsMemoriesOnNewMessage(bool scopePerOperationThread, string expectedThreadId) + { + // Arrange + using var httpResponse = new HttpResponseMessage() { StatusCode = System.Net.HttpStatusCode.OK }; + this._mockMessageHandler + .Setup(x => x.MockableSendAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(httpResponse); + + var sut = new Mem0MemoryComponent(this._httpClient, new() { ApplicationId = "test-app-id", AgentId = "test-agent-id", ThreadId = "test-thread-id", UserId = "test-user-id", ScopeToPerOperationThreadId = scopePerOperationThread }); + + // Act + await sut.OnNewMessageAsync("test-thread-id-1", new ChatMessage(ChatRole.User, "Hello, my name is Caoimhe.")); + + // Assert + var expectedPayload = $$""" + {"app_id":"test-app-id","agent_id":"test-agent-id","run_id":"{{expectedThreadId}}","user_id":"test-user-id","messages":[{"content":"Hello, my name is Caoimhe.","role":"user"}]} + """; + this._mockMessageHandler.Verify(x => x.MockableSendAsync(HttpMethod.Post, "https://localhost/v1/memories/", expectedPayload, It.IsAny()), Times.Once); + } + + [Theory] + [InlineData(false, "test-thread-id", null, "Consider the following memories when answering user questions:{0}Name is Caoimhe")] + [InlineData(true, "test-thread-id-1", "Custom Prompt:", "Custom Prompt:{0}Name is Caoimhe")] + public async Task SearchesForMemoriesOnModelInvoke(bool scopePerOperationThread, string expectedThreadId, string? customContextPrompt, string expectedAdditionalInstructions) + { + // Arrange + var expectedResponseString = """ + [{"id":"1","memory":"Name is Caoimhe","hash":"abc123","metadata":null,"score":0.9,"created_at":"2023-01-01T00:00:00Z","updated_at":null,"user_id":"test-user-id","app_id":null,"agent_id":"test-agent-id","session_id":"test-thread-id-1"}] + """; + using var httpResponse = new HttpResponseMessage() + { + StatusCode = System.Net.HttpStatusCode.OK, + Content = new StringContent(expectedResponseString, Encoding.UTF8, "application/json") + }; + + this._mockMessageHandler + .Setup(x => x.MockableSendAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(httpResponse); + + var sut = new Mem0MemoryComponent(this._httpClient, new() + { + ApplicationId = "test-app-id", + AgentId = "test-agent-id", + ThreadId = "test-thread-id", + UserId = "test-user-id", + ScopeToPerOperationThreadId = scopePerOperationThread, + ContextPrompt = customContextPrompt + }); + await sut.OnThreadCreatedAsync("test-thread-id-1"); + + // Act + var actual = await sut.OnModelInvokeAsync(new[] { new ChatMessage(ChatRole.User, "What is my name?") }); + + // Assert + var expectedPayload = $$""" + {"app_id":"test-app-id","agent_id":"test-agent-id","run_id":"{{expectedThreadId}}","user_id":"test-user-id","query":"What is my name?"} + """; + this._mockMessageHandler.Verify(x => x.MockableSendAsync(HttpMethod.Post, "https://localhost/v1/memories/search/", expectedPayload, It.IsAny()), Times.Once); + + Assert.Equal(string.Format(expectedAdditionalInstructions, Environment.NewLine), actual); + } + + [Theory] + [InlineData(false, "test-thread-id")] + [InlineData(true, "test-thread-id-1")] + public async Task ClearsStoredUserFacts(bool scopePerOperationThread, string expectedThreadId) + { + // Arrange + using var httpResponse = new HttpResponseMessage() { StatusCode = System.Net.HttpStatusCode.OK }; + this._mockMessageHandler + .Setup(x => x.MockableSendAsync(It.IsAny(), It.IsAny(), null, It.IsAny())) + .ReturnsAsync(httpResponse); + + var sut = new Mem0MemoryComponent(this._httpClient, new() { ApplicationId = "test-app-id", AgentId = "test-agent-id", ThreadId = "test-thread-id", UserId = "test-user-id", ScopeToPerOperationThreadId = scopePerOperationThread }); + await sut.OnThreadCreatedAsync("test-thread-id-1"); + + // Act + await sut.ClearStoredUserFactsAsync(); + + // Assert + var expectedUrl = $"https://localhost/v1/memories/?app_id=test-app-id&agent_id=test-agent-id&run_id={expectedThreadId}&user_id=test-user-id"; + this._mockMessageHandler.Verify(x => x.MockableSendAsync(HttpMethod.Delete, expectedUrl, null, It.IsAny()), Times.Once); + } + + [Fact] + public async Task ThrowsExceptionWhenThreadIdChangesAfterBeingSet() + { + // Arrange + var sut = new Mem0MemoryComponent(this._httpClient, new() { ScopeToPerOperationThreadId = true }); + + // Act + await sut.OnThreadCreatedAsync("initial-thread-id"); + + // Assert + var exception = await Assert.ThrowsAsync(async () => + { + await sut.OnThreadCreatedAsync("new-thread-id"); + }); + + Assert.Equal("The Mem0MemoryComponent can only be used with one thread at a time when ScopeToPerOperationThreadId is set to true.", exception.Message); + } + + protected virtual void Dispose(bool disposing) + { + if (!this._disposedValue) + { + if (disposing) + { + this._httpClient.Dispose(); + } + + this._disposedValue = true; + } + } + + public void Dispose() + { + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + public class MockableMessageHandler : DelegatingHandler + { + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + string? contentString = request.Content is null ? null : await request.Content.ReadAsStringAsync(cancellationToken); + return await this.MockableSendAsync(request.Method, request.RequestUri?.AbsoluteUri, contentString, cancellationToken); + } + + public virtual Task MockableSendAsync(HttpMethod method, string? absoluteUri, string? content, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs new file mode 100644 index 000000000000..7f7e56f73369 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagComponentTests.cs @@ -0,0 +1,231 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.Data; +using Microsoft.SemanticKernel.Memory; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +/// +/// Contains tests for +/// +public class TextRagComponentTests +{ + [Theory] + [InlineData(null, null, "Consider the following information when responding to the user:", "Include citations to the relevant information where it is referenced in the response.")] + [InlineData("Custom context prompt", "Custom citations prompt", "Custom context prompt", "Custom citations prompt")] + public async Task OnModelInvokeShouldIncludeSearchResultsInOutputAsync( + string? overrideContextPrompt, + string? overrideCitationsPrompt, + string expectedContextPrompt, + string expectedCitationsPrompt) + { + // Arrange + var mockTextSearch = new Mock(); + var searchResults = new Mock>(); + var mockEnumerator = new Mock>(); + + // Mock search results + var results = new List + { + new("Content of Doc1") { Name = "Doc1", Link = "http://example.com/doc1" }, + new("Content of Doc2") { Name = "Doc2", Link = "http://example.com/doc2" } + }; + + mockEnumerator.SetupSequence(e => e.MoveNextAsync()) + .ReturnsAsync(true) + .ReturnsAsync(true) + .ReturnsAsync(false); + + mockEnumerator.SetupSequence(e => e.Current) + .Returns(results[0]) + .Returns(results[1]); + + searchResults.Setup(r => r.GetAsyncEnumerator(It.IsAny())) + .Returns(mockEnumerator.Object); + + mockTextSearch.Setup(ts => ts.GetTextSearchResultsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new KernelSearchResults(searchResults.Object)); + + var options = new TextRagComponentOptions + { + SearchTime = TextRagComponentOptions.RagBehavior.BeforeAIInvoke, + Top = 2, + ContextPrompt = overrideContextPrompt, + IncludeCitationsPrompt = overrideCitationsPrompt + }; + + var component = new TextRagComponent(mockTextSearch.Object, options); + + // Act + var result = await component.OnModelInvokeAsync([new ChatMessage(ChatRole.User, "Sample user question?")], CancellationToken.None); + + // Assert + Assert.Contains(expectedContextPrompt, result); + Assert.Contains("Item 1:", result); + Assert.Contains("Name: Doc1", result); + Assert.Contains("Link: http://example.com/doc1", result); + Assert.Contains("Contents: Content of Doc1", result); + Assert.Contains("Item 2:", result); + Assert.Contains("Name: Doc2", result); + Assert.Contains("Link: http://example.com/doc2", result); + Assert.Contains("Contents: Content of Doc2", result); + Assert.Contains(expectedCitationsPrompt, result); + } + + [Theory] + [InlineData(null, null, "Search", "Allows searching for additional information to help answer the user question.")] + [InlineData("CustomSearch", "CustomDescription", "CustomSearch", "CustomDescription")] + public void AIFunctionsShouldBeRegisteredCorrectly( + string? overridePluginFunctionName, + string? overridePluginFunctionDescription, + string expectedPluginFunctionName, + string expectedPluginFunctionDescription) + { + // Arrange + var mockTextSearch = new Mock(); + var options = new TextRagComponentOptions + { + SearchTime = TextRagComponentOptions.RagBehavior.ViaPlugin, + PluginFunctionName = overridePluginFunctionName, + PluginFunctionDescription = overridePluginFunctionDescription + }; + + var component = new TextRagComponent(mockTextSearch.Object, options); + + // Act + var aiFunctions = component.AIFunctions; + + // Assert + Assert.NotNull(aiFunctions); + Assert.Single(aiFunctions); + var aiFunction = aiFunctions.First(); + Assert.Equal(expectedPluginFunctionName, aiFunction.Name); + Assert.Equal(expectedPluginFunctionDescription, aiFunction.Description); + } + + [Theory] + [InlineData(null, null, "Consider the following information when responding to the user:", "Include citations to the relevant information where it is referenced in the response.")] + [InlineData("Custom context prompt", "Custom citations prompt", "Custom context prompt", "Custom citations prompt")] + public async Task SearchAsyncShouldIncludeSearchResultsInOutputAsync( + string? overrideContextPrompt, + string? overrideCitationsPrompt, + string expectedContextPrompt, + string expectedCitationsPrompt) + + { + // Arrange + var mockTextSearch = new Mock(); + var searchResults = new Mock>(); + var mockEnumerator = new Mock>(); + + // Mock search results + var results = new List + { + new("Content of Doc1") { Name = "Doc1", Link = "http://example.com/doc1" }, + new("Content of Doc2") { Name = "Doc2", Link = "http://example.com/doc2" } + }; + + mockEnumerator.SetupSequence(e => e.MoveNextAsync()) + .ReturnsAsync(true) + .ReturnsAsync(true) + .ReturnsAsync(false); + + mockEnumerator.SetupSequence(e => e.Current) + .Returns(results[0]) + .Returns(results[1]); + + searchResults.Setup(r => r.GetAsyncEnumerator(It.IsAny())) + .Returns(mockEnumerator.Object); + + mockTextSearch.Setup(ts => ts.GetTextSearchResultsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new KernelSearchResults(searchResults.Object)); + + var options = new TextRagComponentOptions + { + ContextPrompt = overrideContextPrompt, + IncludeCitationsPrompt = overrideCitationsPrompt + }; + + var component = new TextRagComponent(mockTextSearch.Object, options); + + // Act + var result = await component.SearchAsync("Sample user question?", CancellationToken.None); + + // Assert + Assert.Contains(expectedContextPrompt, result); + Assert.Contains("Item 1:", result); + Assert.Contains("Name: Doc1", result); + Assert.Contains("Link: http://example.com/doc1", result); + Assert.Contains("Contents: Content of Doc1", result); + Assert.Contains("Item 2:", result); + Assert.Contains("Name: Doc2", result); + Assert.Contains("Link: http://example.com/doc2", result); + Assert.Contains("Contents: Content of Doc2", result); + Assert.Contains(expectedCitationsPrompt, result); + } + + [Fact] + public async Task OnModelInvokeShouldUseOverrideContextFormatterIfProvidedAsync() + { + // Arrange + var mockTextSearch = new Mock(); + var searchResults = new Mock>(); + var mockEnumerator = new Mock>(); + + // Mock search results + var results = new List + { + new("Content of Doc1") { Name = "Doc1", Link = "http://example.com/doc1" }, + new("Content of Doc2") { Name = "Doc2", Link = "http://example.com/doc2" } + }; + + mockEnumerator.SetupSequence(e => e.MoveNextAsync()) + .ReturnsAsync(true) + .ReturnsAsync(true) + .ReturnsAsync(false); + + mockEnumerator.SetupSequence(e => e.Current) + .Returns(results[0]) + .Returns(results[1]); + + searchResults.Setup(r => r.GetAsyncEnumerator(It.IsAny())) + .Returns(mockEnumerator.Object); + + mockTextSearch.Setup(ts => ts.GetTextSearchResultsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new KernelSearchResults(searchResults.Object)); + + var customFormatter = new TextRagComponentOptions.ContextFormatterType(results => + $"Custom formatted context with {results.Count} results."); + + var options = new TextRagComponentOptions + { + SearchTime = TextRagComponentOptions.RagBehavior.BeforeAIInvoke, + Top = 2, + ContextFormatter = customFormatter + }; + + var component = new TextRagComponent(mockTextSearch.Object, options); + + // Act + var result = await component.OnModelInvokeAsync([new ChatMessage(ChatRole.User, "Sample user question?")], CancellationToken.None); + + // Assert + Assert.Equal("Custom formatted context with 2 results.", result); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs new file mode 100644 index 000000000000..8d91b09b3deb --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Memory/TextRagStoreTests.cs @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.Memory; +using Microsoft.SemanticKernel.Memory.TextRag; +using Moq; +using Xunit; + +namespace SemanticKernel.UnitTests.Memory; + +public class TextRagStoreTests +{ + private readonly Mock _vectorStoreMock; + private readonly Mock _embeddingServiceMock; + private readonly Mock.TextRagStorageDocument>> _recordCollectionMock; + + public TextRagStoreTests() + { + this._vectorStoreMock = new Mock(); + this._recordCollectionMock = new Mock.TextRagStorageDocument>>(); + this._embeddingServiceMock = new Mock(); + + this._vectorStoreMock + .Setup(v => v.GetCollection.TextRagStorageDocument>("testCollection", It.IsAny())) + .Returns(this._recordCollectionMock.Object); + + this._embeddingServiceMock + .Setup(e => e.GenerateEmbeddingsAsync(It.IsAny>(), null, It.IsAny())) + .ReturnsAsync(new[] { new ReadOnlyMemory(new float[128]) }); + } + + [Fact] + public async Task UpsertDocumentsAsyncThrowsWhenDocumentsAreNull() + { + // Arrange + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128); + + // Act & Assert + await Assert.ThrowsAsync(() => store.UpsertDocumentsAsync(null!)); + } + + [Theory] + [InlineData(null)] + [InlineData(" ")] + public async Task UpsertDocumentsAsyncThrowsDocumentTextIsNullOrWhiteSpace(string? text) + { + // Arrange + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128); + + var documents = new List + { + new() { Text = text } + }; + + // Act & Assert + await Assert.ThrowsAsync(() => store.UpsertDocumentsAsync(documents)); + } + + [Fact] + public async Task UpsertDocumentsAsyncCreatesCollectionGeneratesVectorAndUpsertsDocument() + { + // Arrange + this._recordCollectionMock + .Setup(r => r.UpsertBatchAsync(It.IsAny.TextRagStorageDocument>>(), It.IsAny())) + .Returns(AsyncEnumerable.Empty()); + + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128); + + var documents = new List + { + new() { Text = "Sample text", Namespaces = ["ns1"], SourceId = "sid", SourceLink = "sl", SourceName = "sn" } + }; + + // Act + await store.UpsertDocumentsAsync(documents); + + // Assert + this._recordCollectionMock.Verify(r => r.CreateCollectionIfNotExistsAsync(It.IsAny()), Times.Once); + this._embeddingServiceMock.Verify(e => e.GenerateEmbeddingsAsync(It.Is>(texts => texts.Count == 1 && texts[0] == "Sample text"), null, It.IsAny()), Times.Once); + this._recordCollectionMock.Verify(r => r.UpsertBatchAsync( + It.Is.TextRagStorageDocument>>(doc => + doc.Count() == 1 && + doc.First().Text == "Sample text" && + doc.First().Namespaces.Count == 1 && + doc.First().Namespaces[0] == "ns1" && + doc.First().SourceId == "sid" && + doc.First().SourceLink == "sl" && + doc.First().SourceName == "sn" && + doc.First().TextEmbedding.Length == 128), + It.IsAny()), Times.Once); + } + + [Fact] + public async Task UpsertDocumentsAsyncCreatesCollectionGeneratesVectorAndUpsertsDocumentWithSourceIdAsId() + { + // Arrange + this._recordCollectionMock + .Setup(r => r.UpsertBatchAsync(It.IsAny.TextRagStorageDocument>>(), It.IsAny())) + .Returns(AsyncEnumerable.Empty()); + + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128, new() { UseSourceIdAsPrimaryKey = true }); + + var documents = new List + { + new() { Text = "Sample text", Namespaces = ["ns1"], SourceId = "sid", SourceLink = "sl", SourceName = "sn" } + }; + + // Act + await store.UpsertDocumentsAsync(documents); + + // Assert + this._recordCollectionMock.Verify(r => r.UpsertBatchAsync( + It.Is.TextRagStorageDocument>>(doc => + doc.Count() == 1 && + doc.First().Key == "sid" && + doc.First().Text == "Sample text" && + doc.First().Namespaces.Count == 1 && + doc.First().Namespaces[0] == "ns1" && + doc.First().SourceId == "sid" && + doc.First().SourceLink == "sl" && + doc.First().SourceName == "sn" && + doc.First().TextEmbedding.Length == 128), + It.IsAny()), Times.Once); + } + + [Fact] + public async Task UpsertDocumentsAsyncCreatesCollectionGeneratesVectorAndUpsertsDocumentWithoutSourceText() + { + // Arrange + this._recordCollectionMock + .Setup(r => r.UpsertBatchAsync(It.IsAny.TextRagStorageDocument>>(), It.IsAny())) + .Returns(AsyncEnumerable.Empty()); + + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128); + + var documents = new List + { + new() { Text = "Sample text", Namespaces = ["ns1"], SourceId = "sid", SourceLink = "sl", SourceName = "sn" } + }; + + // Act + await store.UpsertDocumentsAsync(documents, new() { PersistSourceText = false }); + + // Assert + this._recordCollectionMock.Verify(r => r.UpsertBatchAsync( + It.Is.TextRagStorageDocument>>(doc => + doc.Count() == 1 && + doc.First().Text == null && + doc.First().Namespaces.Count == 1 && + doc.First().Namespaces[0] == "ns1" && + doc.First().SourceId == "sid" && + doc.First().SourceLink == "sl" && + doc.First().SourceName == "sn" && + doc.First().TextEmbedding.Length == 128), + It.IsAny()), Times.Once); + } + + [Fact] + public async Task SearchAsyncReturnsSearchResults() + { + // Arrange + var mockResults = new List.TextRagStorageDocument>> + { + new(new TextRagStore.TextRagStorageDocument { Text = "Sample text" }, 0.9f) + }; + + this._recordCollectionMock + .Setup(r => r.VectorizedSearchAsync(It.IsAny>(), It.IsAny.TextRagStorageDocument>>(), It.IsAny())) + .ReturnsAsync(new VectorSearchResults.TextRagStorageDocument>(mockResults.ToAsyncEnumerable())); + + using var store = new TextRagStore(this._vectorStoreMock.Object, this._embeddingServiceMock.Object, "testCollection", 128); + + // Act + var actualResults = await store.SearchAsync("query"); + + // Assert + var actualResultsList = await actualResults.Results.ToListAsync(); + Assert.Single(actualResultsList); + Assert.Equal("Sample text", actualResultsList[0]); + } + + [Fact] + public async Task SearchAsyncWithHydrationCallsCallbackAndReturnsSearchResults() + { + // Arrange + var mockResults = new List.TextRagStorageDocument>> + { + new(new TextRagStore.TextRagStorageDocument { SourceId = "sid1", SourceLink = "sl1", Text = "Sample text 1" }, 0.9f), + new(new TextRagStore.TextRagStorageDocument { SourceId = "sid2", SourceLink = "sl2" }, 0.9f), + new(new TextRagStore.TextRagStorageDocument { SourceId = "sid3", SourceLink = "sl3", Text = "Sample text 3" }, 0.9f), + }; + + this._recordCollectionMock + .Setup(r => r.VectorizedSearchAsync(It.IsAny>(), It.IsAny.TextRagStorageDocument>>(), It.IsAny())) + .ReturnsAsync(new VectorSearchResults.TextRagStorageDocument>(mockResults.ToAsyncEnumerable())); + + using var store = new TextRagStore( + this._vectorStoreMock.Object, + this._embeddingServiceMock.Object, + "testCollection", + 128, + new() + { + SourceRetrievalCallback = sourceIds => + { + Assert.Single(sourceIds); + Assert.Equal("sid2", sourceIds[0].sourceId); + Assert.Equal("sl2", sourceIds[0].sourceLink); + + return Task.FromResult>([("sid2", "sl2", "Sample text 2")]); + } + }); + + // Act + var actualResults = await store.SearchAsync("query"); + + // Assert + var actualResultsList = await actualResults.Results.ToListAsync(); + Assert.Equal(3, actualResultsList.Count); + Assert.Equal("Sample text 1", actualResultsList[0]); + Assert.Equal("Sample text 2", actualResultsList[1]); + Assert.Equal("Sample text 3", actualResultsList[2]); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj b/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj index 8580c9a173ab..7bb2d937033f 100644 --- a/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj +++ b/dotnet/src/SemanticKernel.UnitTests/SemanticKernel.UnitTests.csproj @@ -6,7 +6,7 @@ net8.0 true false - $(NoWarn);CA2007,CA1861,IDE1006,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0050,SKEXP0110,SKEXP0120 + $(NoWarn);CA2007,CA1861,IDE1006,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0050,SKEXP0110,SKEXP0120,SKEXP0130